X-Git-Url: https://git.openstreetmap.org./nominatim.git/blobdiff_plain/4da4cbfe27a576ae011430b2de205c74435e241b..122ecd46269d48a8b2aacba2474311c0400d2a9d:/src/nominatim_db/data/country_info.py diff --git a/src/nominatim_db/data/country_info.py b/src/nominatim_db/data/country_info.py index c8002ee7..bc3f20f5 100644 --- a/src/nominatim_db/data/country_info.py +++ b/src/nominatim_db/data/country_info.py @@ -9,14 +9,14 @@ Functions for importing and managing static country information. """ from typing import Dict, Any, Iterable, Tuple, Optional, Container, overload from pathlib import Path -import psycopg2.extras from ..db import utils as db_utils -from ..db.connection import connect, Connection +from ..db.connection import connect, Connection, register_hstore from ..errors import UsageError from ..config import Configuration from ..tokenizer.base import AbstractTokenizer + def _flatten_name_list(names: Any) -> Dict[str, str]: if names is None: return {} @@ -40,7 +40,6 @@ def _flatten_name_list(names: Any) -> Dict[str, str]: return flat - class _CountryInfo: """ Caches country-specific properties from the configuration file. """ @@ -48,7 +47,6 @@ class _CountryInfo: def __init__(self) -> None: self._info: Dict[str, Dict[str, Any]] = {} - def load(self, config: Configuration) -> None: """ Load the country properties from the configuration files, if they are not loaded yet. @@ -64,7 +62,6 @@ class _CountryInfo: for x in prop['languages'].split(',')] prop['names'] = _flatten_name_list(prop.get('names')) - def items(self) -> Iterable[Tuple[str, Dict[str, Any]]]: """ Return tuples of (country_code, property dict) as iterable. """ @@ -76,7 +73,6 @@ class _CountryInfo: return self._info.get(country_code, {}) - _COUNTRY_INFO = _CountryInfo() @@ -87,14 +83,17 @@ def setup_country_config(config: Configuration) -> None: """ _COUNTRY_INFO.load(config) + @overload def iterate() -> Iterable[Tuple[str, Dict[str, Any]]]: ... + @overload def iterate(prop: str) -> Iterable[Tuple[str, Any]]: ... + def iterate(prop: Optional[str] = None) -> Iterable[Tuple[str, Dict[str, Any]]]: """ Iterate over country code and properties. @@ -129,8 +128,8 @@ def setup_country_tables(dsn: str, sql_dir: Path, ignore_partitions: bool = Fals params.append((ccode, props['names'], lang, partition)) with connect(dsn) as conn: + register_hstore(conn) with conn.cursor() as cur: - psycopg2.extras.register_hstore(cur) cur.execute( """ CREATE TABLE public.country_name ( country_code character varying(2), @@ -139,9 +138,10 @@ def setup_country_tables(dsn: str, sql_dir: Path, ignore_partitions: bool = Fals country_default_language_code text, partition integer ); """) - cur.execute_values( + cur.executemany( """ INSERT INTO public.country_name - (country_code, name, country_default_language_code, partition) VALUES %s + (country_code, name, country_default_language_code, partition) + VALUES (%s, %s, %s, %s) """, params) conn.commit() @@ -157,8 +157,8 @@ def create_country_names(conn: Connection, tokenizer: AbstractTokenizer, return ':' not in key or not languages or \ key[key.index(':') + 1:] in languages + register_hstore(conn) with conn.cursor() as cur: - psycopg2.extras.register_hstore(cur) cur.execute("""SELECT country_code, name FROM country_name WHERE country_code is not null""") @@ -168,7 +168,7 @@ def create_country_names(conn: Connection, tokenizer: AbstractTokenizer, # country names (only in languages as provided) if name: - names.update({k : v for k, v in name.items() if _include_key(k)}) + names.update({k: v for k, v in name.items() if _include_key(k)}) analyzer.add_country_names(code, names)