]> git.openstreetmap.org Git - nominatim.git/blobdiff - nominatim/data/country_info.py
Improved logic.
[nominatim.git] / nominatim / data / country_info.py
index d754b4ddb029365b22d2cc7a77ccaeefc49a2719..eb0190b54e7b3928b61016604c7b1fc049e255b2 100644 (file)
@@ -7,13 +7,17 @@
 """
 Functions for importing and managing static country information.
 """
 """
 Functions for importing and managing static country information.
 """
+from typing import Dict, Any, Iterable, Tuple, Optional, Container, overload
+from pathlib import Path
 import psycopg2.extras
 
 from nominatim.db import utils as db_utils
 import psycopg2.extras
 
 from nominatim.db import utils as db_utils
-from nominatim.db.connection import connect
+from nominatim.db.connection import connect, Connection
 from nominatim.errors import UsageError
 from nominatim.errors import UsageError
+from nominatim.config import Configuration
+from nominatim.tokenizer.base import AbstractTokenizer
 
 
-def _flatten_name_list(names):
+def _flatten_name_list(names: Any) -> Dict[str, str]:
     if names is None:
         return {}
 
     if names is None:
         return {}
 
@@ -41,11 +45,11 @@ class _CountryInfo:
     """ Caches country-specific properties from the configuration file.
     """
 
     """ Caches country-specific properties from the configuration file.
     """
 
-    def __init__(self):
-        self._info = {}
+    def __init__(self) -> None:
+        self._info: Dict[str, Dict[str, Any]] = {}
 
 
 
 
-    def load(self, config):
+    def load(self, config: Configuration) -> None:
         """ Load the country properties from the configuration files,
             if they are not loaded yet.
         """
         """ Load the country properties from the configuration files,
             if they are not loaded yet.
         """
@@ -61,12 +65,12 @@ class _CountryInfo:
                 prop['names'] = _flatten_name_list(prop.get('names'))
 
 
                 prop['names'] = _flatten_name_list(prop.get('names'))
 
 
-    def items(self):
+    def items(self) -> Iterable[Tuple[str, Dict[str, Any]]]:
         """ Return tuples of (country_code, property dict) as iterable.
         """
         return self._info.items()
 
         """ Return tuples of (country_code, property dict) as iterable.
         """
         return self._info.items()
 
-    def get(self, country_code):
+    def get(self, country_code: str) -> Dict[str, Any]:
         """ Get country information for the country with the given country code.
         """
         return self._info.get(country_code, {})
         """ Get country information for the country with the given country code.
         """
         return self._info.get(country_code, {})
@@ -76,15 +80,22 @@ class _CountryInfo:
 _COUNTRY_INFO = _CountryInfo()
 
 
 _COUNTRY_INFO = _CountryInfo()
 
 
-def setup_country_config(config):
+def setup_country_config(config: Configuration) -> None:
     """ Load country properties from the configuration file.
         Needs to be called before using any other functions in this
         file.
     """
     _COUNTRY_INFO.load(config)
 
     """ Load country properties from the configuration file.
         Needs to be called before using any other functions in this
         file.
     """
     _COUNTRY_INFO.load(config)
 
+@overload
+def iterate() -> Iterable[Tuple[str, Dict[str, Any]]]:
+    ...
 
 
-def iterate(prop=None):
+@overload
+def iterate(prop: str) -> Iterable[Tuple[str, Any]]:
+    ...
+
+def iterate(prop: Optional[str] = None) -> Iterable[Tuple[str, Dict[str, Any]]]:
     """ Iterate over country code and properties.
 
         When `prop` is None, all countries are returned with their complete
     """ Iterate over country code and properties.
 
         When `prop` is None, all countries are returned with their complete
@@ -100,7 +111,7 @@ def iterate(prop=None):
     return ((c, p[prop]) for c, p in _COUNTRY_INFO.items() if prop in p)
 
 
     return ((c, p[prop]) for c, p in _COUNTRY_INFO.items() if prop in p)
 
 
-def setup_country_tables(dsn, sql_dir, ignore_partitions=False):
+def setup_country_tables(dsn: str, sql_dir: Path, ignore_partitions: bool = False) -> None:
     """ Create and populate the tables with basic static data that provides
         the background for geocoding. Data is assumed to not yet exist.
     """
     """ Create and populate the tables with basic static data that provides
         the background for geocoding. Data is assumed to not yet exist.
     """
@@ -112,7 +123,7 @@ def setup_country_tables(dsn, sql_dir, ignore_partitions=False):
             if ignore_partitions:
                 partition = 0
             else:
             if ignore_partitions:
                 partition = 0
             else:
-                partition = props.get('partition')
+                partition = props.get('partition', 0)
             lang = props['languages'][0] if len(
                 props['languages']) == 1 else None
 
             lang = props['languages'][0] if len(
                 props['languages']) == 1 else None
 
@@ -135,13 +146,14 @@ def setup_country_tables(dsn, sql_dir, ignore_partitions=False):
         conn.commit()
 
 
         conn.commit()
 
 
-def create_country_names(conn, tokenizer, languages=None):
+def create_country_names(conn: Connection, tokenizer: AbstractTokenizer,
+                         languages: Optional[Container[str]] = None) -> None:
     """ Add default country names to search index. `languages` is a comma-
         separated list of language codes as used in OSM. If `languages` is not
         empty then only name translations for the given languages are added
         to the index.
     """
     """ Add default country names to search index. `languages` is a comma-
         separated list of language codes as used in OSM. If `languages` is not
         empty then only name translations for the given languages are added
         to the index.
     """
-    def _include_key(key):
+    def _include_key(key: str) -> bool:
         return ':' not in key or not languages or \
                key[key.index(':') + 1:] in languages
 
         return ':' not in key or not languages or \
                key[key.index(':') + 1:] in languages