]> git.openstreetmap.org Git - nominatim.git/blobdiff - src/nominatim_db/data/country_info.py
use autocommit when creating tables and indexes
[nominatim.git] / src / nominatim_db / data / country_info.py
index 35943a50d3ed08cf0a46a7658fafeffac0db2674..bc3f20f598b68155f787a7feb5bdcd3a006fa3f4 100644 (file)
@@ -9,14 +9,14 @@ Functions for importing and managing static country information.
 """
 from typing import Dict, Any, Iterable, Tuple, Optional, Container, overload
 from pathlib import Path
 """
 from typing import Dict, Any, Iterable, Tuple, Optional, Container, overload
 from pathlib import Path
-import psycopg2.extras
 
 
-from nominatim_core.db import utils as db_utils
-from nominatim_core.db.connection import connect, Connection
-from nominatim_core.errors import UsageError
-from nominatim_core.config import Configuration
+from ..db import utils as db_utils
+from ..db.connection import connect, Connection, register_hstore
+from ..errors import UsageError
+from ..config import Configuration
 from ..tokenizer.base import AbstractTokenizer
 
 from ..tokenizer.base import AbstractTokenizer
 
+
 def _flatten_name_list(names: Any) -> Dict[str, str]:
     if names is None:
         return {}
 def _flatten_name_list(names: Any) -> Dict[str, str]:
     if names is None:
         return {}
@@ -40,7 +40,6 @@ def _flatten_name_list(names: Any) -> Dict[str, str]:
     return flat
 
 
     return flat
 
 
-
 class _CountryInfo:
     """ Caches country-specific properties from the configuration file.
     """
 class _CountryInfo:
     """ Caches country-specific properties from the configuration file.
     """
@@ -48,7 +47,6 @@ class _CountryInfo:
     def __init__(self) -> None:
         self._info: Dict[str, Dict[str, Any]] = {}
 
     def __init__(self) -> None:
         self._info: Dict[str, Dict[str, Any]] = {}
 
-
     def load(self, config: Configuration) -> None:
         """ Load the country properties from the configuration files,
             if they are not loaded yet.
     def load(self, config: Configuration) -> None:
         """ Load the country properties from the configuration files,
             if they are not loaded yet.
@@ -64,7 +62,6 @@ class _CountryInfo:
                                          for x in prop['languages'].split(',')]
                 prop['names'] = _flatten_name_list(prop.get('names'))
 
                                          for x in prop['languages'].split(',')]
                 prop['names'] = _flatten_name_list(prop.get('names'))
 
-
     def items(self) -> Iterable[Tuple[str, Dict[str, Any]]]:
         """ Return tuples of (country_code, property dict) as iterable.
         """
     def items(self) -> Iterable[Tuple[str, Dict[str, Any]]]:
         """ Return tuples of (country_code, property dict) as iterable.
         """
@@ -76,7 +73,6 @@ class _CountryInfo:
         return self._info.get(country_code, {})
 
 
         return self._info.get(country_code, {})
 
 
-
 _COUNTRY_INFO = _CountryInfo()
 
 
 _COUNTRY_INFO = _CountryInfo()
 
 
@@ -87,14 +83,17 @@ def setup_country_config(config: Configuration) -> None:
     """
     _COUNTRY_INFO.load(config)
 
     """
     _COUNTRY_INFO.load(config)
 
+
 @overload
 def iterate() -> Iterable[Tuple[str, Dict[str, Any]]]:
     ...
 
 @overload
 def iterate() -> Iterable[Tuple[str, Dict[str, Any]]]:
     ...
 
+
 @overload
 def iterate(prop: str) -> Iterable[Tuple[str, Any]]:
     ...
 
 @overload
 def iterate(prop: str) -> Iterable[Tuple[str, Any]]:
     ...
 
+
 def iterate(prop: Optional[str] = None) -> Iterable[Tuple[str, Dict[str, Any]]]:
     """ Iterate over country code and properties.
 
 def iterate(prop: Optional[str] = None) -> Iterable[Tuple[str, Dict[str, Any]]]:
     """ Iterate over country code and properties.
 
@@ -129,8 +128,8 @@ def setup_country_tables(dsn: str, sql_dir: Path, ignore_partitions: bool = Fals
 
             params.append((ccode, props['names'], lang, partition))
     with connect(dsn) as conn:
 
             params.append((ccode, props['names'], lang, partition))
     with connect(dsn) as conn:
+        register_hstore(conn)
         with conn.cursor() as cur:
         with conn.cursor() as cur:
-            psycopg2.extras.register_hstore(cur)
             cur.execute(
                 """ CREATE TABLE public.country_name (
                         country_code character varying(2),
             cur.execute(
                 """ CREATE TABLE public.country_name (
                         country_code character varying(2),
@@ -139,9 +138,10 @@ def setup_country_tables(dsn: str, sql_dir: Path, ignore_partitions: bool = Fals
                         country_default_language_code text,
                         partition integer
                     ); """)
                         country_default_language_code text,
                         partition integer
                     ); """)
-            cur.execute_values(
+            cur.executemany(
                 """ INSERT INTO public.country_name
                 """ INSERT INTO public.country_name
-                    (country_code, name, country_default_language_code, partition) VALUES %s
+                    (country_code, name, country_default_language_code, partition)
+                    VALUES (%s, %s, %s, %s)
                 """, params)
         conn.commit()
 
                 """, params)
         conn.commit()
 
@@ -157,8 +157,8 @@ def create_country_names(conn: Connection, tokenizer: AbstractTokenizer,
         return ':' not in key or not languages or \
                key[key.index(':') + 1:] in languages
 
         return ':' not in key or not languages or \
                key[key.index(':') + 1:] in languages
 
+    register_hstore(conn)
     with conn.cursor() as cur:
     with conn.cursor() as cur:
-        psycopg2.extras.register_hstore(cur)
         cur.execute("""SELECT country_code, name FROM country_name
                        WHERE country_code is not null""")
 
         cur.execute("""SELECT country_code, name FROM country_name
                        WHERE country_code is not null""")
 
@@ -168,7 +168,7 @@ def create_country_names(conn: Connection, tokenizer: AbstractTokenizer,
 
                 # country names (only in languages as provided)
                 if name:
 
                 # country names (only in languages as provided)
                 if name:
-                    names.update({k : v for k, v in name.items() if _include_key(k)})
+                    names.update({k: v for k, v in name.items() if _include_key(k)})
 
                 analyzer.add_country_names(code, names)
 
 
                 analyzer.add_country_names(code, names)