]> git.openstreetmap.org Git - nominatim.git/blobdiff - nominatim/data/country_info.py
Improved logic.
[nominatim.git] / nominatim / data / country_info.py
index d754b4ddb029365b22d2cc7a77ccaeefc49a2719..eb0190b54e7b3928b61016604c7b1fc049e255b2 100644 (file)
@@ -7,13 +7,17 @@
 """
 Functions for importing and managing static country information.
 """
+from typing import Dict, Any, Iterable, Tuple, Optional, Container, overload
+from pathlib import Path
 import psycopg2.extras
 
 from nominatim.db import utils as db_utils
-from nominatim.db.connection import connect
+from nominatim.db.connection import connect, Connection
 from nominatim.errors import UsageError
+from nominatim.config import Configuration
+from nominatim.tokenizer.base import AbstractTokenizer
 
-def _flatten_name_list(names):
+def _flatten_name_list(names: Any) -> Dict[str, str]:
     if names is None:
         return {}
 
@@ -41,11 +45,11 @@ class _CountryInfo:
     """ Caches country-specific properties from the configuration file.
     """
 
-    def __init__(self):
-        self._info = {}
+    def __init__(self) -> None:
+        self._info: Dict[str, Dict[str, Any]] = {}
 
 
-    def load(self, config):
+    def load(self, config: Configuration) -> None:
         """ Load the country properties from the configuration files,
             if they are not loaded yet.
         """
@@ -61,12 +65,12 @@ class _CountryInfo:
                 prop['names'] = _flatten_name_list(prop.get('names'))
 
 
-    def items(self):
+    def items(self) -> Iterable[Tuple[str, Dict[str, Any]]]:
         """ Return tuples of (country_code, property dict) as iterable.
         """
         return self._info.items()
 
-    def get(self, country_code):
+    def get(self, country_code: str) -> Dict[str, Any]:
         """ Get country information for the country with the given country code.
         """
         return self._info.get(country_code, {})
@@ -76,15 +80,22 @@ class _CountryInfo:
 _COUNTRY_INFO = _CountryInfo()
 
 
-def setup_country_config(config):
+def setup_country_config(config: Configuration) -> None:
     """ Load country properties from the configuration file.
         Needs to be called before using any other functions in this
         file.
     """
     _COUNTRY_INFO.load(config)
 
+@overload
+def iterate() -> Iterable[Tuple[str, Dict[str, Any]]]:
+    ...
 
-def iterate(prop=None):
+@overload
+def iterate(prop: str) -> Iterable[Tuple[str, Any]]:
+    ...
+
+def iterate(prop: Optional[str] = None) -> Iterable[Tuple[str, Dict[str, Any]]]:
     """ Iterate over country code and properties.
 
         When `prop` is None, all countries are returned with their complete
@@ -100,7 +111,7 @@ def iterate(prop=None):
     return ((c, p[prop]) for c, p in _COUNTRY_INFO.items() if prop in p)
 
 
-def setup_country_tables(dsn, sql_dir, ignore_partitions=False):
+def setup_country_tables(dsn: str, sql_dir: Path, ignore_partitions: bool = False) -> None:
     """ Create and populate the tables with basic static data that provides
         the background for geocoding. Data is assumed to not yet exist.
     """
@@ -112,7 +123,7 @@ def setup_country_tables(dsn, sql_dir, ignore_partitions=False):
             if ignore_partitions:
                 partition = 0
             else:
-                partition = props.get('partition')
+                partition = props.get('partition', 0)
             lang = props['languages'][0] if len(
                 props['languages']) == 1 else None
 
@@ -135,13 +146,14 @@ def setup_country_tables(dsn, sql_dir, ignore_partitions=False):
         conn.commit()
 
 
-def create_country_names(conn, tokenizer, languages=None):
+def create_country_names(conn: Connection, tokenizer: AbstractTokenizer,
+                         languages: Optional[Container[str]] = None) -> None:
     """ Add default country names to search index. `languages` is a comma-
         separated list of language codes as used in OSM. If `languages` is not
         empty then only name translations for the given languages are added
         to the index.
     """
-    def _include_key(key):
+    def _include_key(key: str) -> bool:
         return ':' not in key or not languages or \
                key[key.index(':') + 1:] in languages