]> git.openstreetmap.org Git - nominatim.git/commitdiff
add type annotations to tool functions
authorSarah Hoffmann <lonvia@denofr.de>
Sat, 16 Jul 2022 21:28:02 +0000 (23:28 +0200)
committerSarah Hoffmann <lonvia@denofr.de>
Mon, 18 Jul 2022 07:54:27 +0000 (09:54 +0200)
nominatim/db/connection.py
nominatim/tools/add_osm_data.py
nominatim/tools/admin.py
nominatim/tools/postcodes.py
nominatim/tools/refresh.py
nominatim/tools/replication.py

index 685ac6cbfe112e2f6814a8e074573a4909259f37..4f32dfceb8b56868f6c83c6bc6e8200875349773 100644 (file)
@@ -37,7 +37,7 @@ class Cursor(psycopg2.extras.DictCursor):
 
 
     def execute_values(self, sql: Query, argslist: Iterable[Tuple[Any, ...]],
 
 
     def execute_values(self, sql: Query, argslist: Iterable[Tuple[Any, ...]],
-                       template: Optional[str] = None) -> None:
+                       template: Optional[Query] = None) -> None:
         """ Wrapper for the psycopg2 convenience function to execute
             SQL for a list of values.
         """
         """ Wrapper for the psycopg2 convenience function to execute
             SQL for a list of values.
         """
index b4e77b218e3752ad486e3864095072df0673d642..d5d01754dc2c82e65d3ed55c52e56ec1e01c41bf 100644 (file)
@@ -7,6 +7,7 @@
 """
 Function to add additional OSM data from a file or the API into the database.
 """
 """
 Function to add additional OSM data from a file or the API into the database.
 """
+from typing import Any, MutableMapping
 from pathlib import Path
 import logging
 import urllib
 from pathlib import Path
 import logging
 import urllib
@@ -15,7 +16,7 @@ from nominatim.tools.exec_utils import run_osm2pgsql, get_url
 
 LOG = logging.getLogger()
 
 
 LOG = logging.getLogger()
 
-def add_data_from_file(fname, options):
+def add_data_from_file(fname: str, options: MutableMapping[str, Any]) -> int:
     """ Adds data from a OSM file to the database. The file may be a normal
         OSM file or a diff file in all formats supported by libosmium.
     """
     """ Adds data from a OSM file to the database. The file may be a normal
         OSM file or a diff file in all formats supported by libosmium.
     """
@@ -27,7 +28,8 @@ def add_data_from_file(fname, options):
     return 0
 
 
     return 0
 
 
-def add_osm_object(osm_type, osm_id, use_main_api, options):
+def add_osm_object(osm_type: str, osm_id: int, use_main_api: bool,
+                   options: MutableMapping[str, Any]) -> None:
     """ Add or update a single OSM object from the latest version of the
         API.
     """
     """ Add or update a single OSM object from the latest version of the
         API.
     """
index 1bf217e2d81e77fd2526f3f3b3b98ec03a236a05..49ba75261b92d212f09a3681969b152da1b84cd7 100644 (file)
@@ -7,22 +7,27 @@
 """
 Functions for database analysis and maintenance.
 """
 """
 Functions for database analysis and maintenance.
 """
+from typing import Optional, Tuple, Any
 import logging
 
 from psycopg2.extras import Json, register_hstore
 
 import logging
 
 from psycopg2.extras import Json, register_hstore
 
-from nominatim.db.connection import connect
+from nominatim.config import Configuration
+from nominatim.db.connection import connect, Cursor
 from nominatim.tokenizer import factory as tokenizer_factory
 from nominatim.errors import UsageError
 from nominatim.data.place_info import PlaceInfo
 from nominatim.tokenizer import factory as tokenizer_factory
 from nominatim.errors import UsageError
 from nominatim.data.place_info import PlaceInfo
+from nominatim.typing import DictCursorResult
 
 LOG = logging.getLogger()
 
 
 LOG = logging.getLogger()
 
-def _get_place_info(cursor, osm_id, place_id):
+def _get_place_info(cursor: Cursor, osm_id: Optional[str],
+                    place_id: Optional[int]) -> DictCursorResult:
     sql = """SELECT place_id, extra.*
              FROM placex, LATERAL placex_indexing_prepare(placex) as extra
           """
 
     sql = """SELECT place_id, extra.*
              FROM placex, LATERAL placex_indexing_prepare(placex) as extra
           """
 
+    values: Tuple[Any, ...]
     if osm_id:
         osm_type = osm_id[0].upper()
         if osm_type not in 'NWR' or not osm_id[1:].isdigit():
     if osm_id:
         osm_type = osm_id[0].upper()
         if osm_type not in 'NWR' or not osm_id[1:].isdigit():
@@ -44,10 +49,11 @@ def _get_place_info(cursor, osm_id, place_id):
         LOG.fatal("OSM object %s not found in database.", osm_id)
         raise UsageError("OSM object not found")
 
         LOG.fatal("OSM object %s not found in database.", osm_id)
         raise UsageError("OSM object not found")
 
-    return cursor.fetchone()
+    return cursor.fetchone() # type: ignore[no-untyped-call]
 
 
 
 
-def analyse_indexing(config, osm_id=None, place_id=None):
+def analyse_indexing(config: Configuration, osm_id: Optional[str] = None,
+                     place_id: Optional[int] = None) -> None:
     """ Analyse indexing of a single Nominatim object.
     """
     with connect(config.get_libpq_dsn()) as conn:
     """ Analyse indexing of a single Nominatim object.
     """
     with connect(config.get_libpq_dsn()) as conn:
index 9c66719b5fe1ce55573985f8a653876c093102c6..7171e25d169d0af7b625430fafefdfba183c4df4 100644 (file)
@@ -8,7 +8,9 @@
 Functions for importing, updating and otherwise maintaining the table
 of artificial postcode centroids.
 """
 Functions for importing, updating and otherwise maintaining the table
 of artificial postcode centroids.
 """
+from typing import Optional, Tuple, Dict, List, TextIO
 from collections import defaultdict
 from collections import defaultdict
+from pathlib import Path
 import csv
 import gzip
 import logging
 import csv
 import gzip
 import logging
@@ -16,18 +18,19 @@ from math import isfinite
 
 from psycopg2 import sql as pysql
 
 
 from psycopg2 import sql as pysql
 
-from nominatim.db.connection import connect
+from nominatim.db.connection import connect, Connection
 from nominatim.utils.centroid import PointsCentroid
 from nominatim.utils.centroid import PointsCentroid
-from nominatim.data.postcode_format import PostcodeFormatter
+from nominatim.data.postcode_format import PostcodeFormatter, CountryPostcodeMatcher
+from nominatim.tokenizer.base import AbstractAnalyzer, AbstractTokenizer
 
 LOG = logging.getLogger()
 
 
 LOG = logging.getLogger()
 
-def _to_float(num, max_value):
+def _to_float(numstr: str, max_value: float) -> float:
     """ Convert the number in string into a float. The number is expected
         to be in the range of [-max_value, max_value]. Otherwise rises a
         ValueError.
     """
     """ Convert the number in string into a float. The number is expected
         to be in the range of [-max_value, max_value]. Otherwise rises a
         ValueError.
     """
-    num = float(num)
+    num = float(numstr)
     if not isfinite(num) or num <= -max_value or num >= max_value:
         raise ValueError()
 
     if not isfinite(num) or num <= -max_value or num >= max_value:
         raise ValueError()
 
@@ -37,18 +40,19 @@ class _PostcodeCollector:
     """ Collector for postcodes of a single country.
     """
 
     """ Collector for postcodes of a single country.
     """
 
-    def __init__(self, country, matcher):
+    def __init__(self, country: str, matcher: Optional[CountryPostcodeMatcher]):
         self.country = country
         self.matcher = matcher
         self.country = country
         self.matcher = matcher
-        self.collected = defaultdict(PointsCentroid)
-        self.normalization_cache = None
+        self.collected: Dict[str, PointsCentroid] = defaultdict(PointsCentroid)
+        self.normalization_cache: Optional[Tuple[str, Optional[str]]] = None
 
 
 
 
-    def add(self, postcode, x, y):
+    def add(self, postcode: str, x: float, y: float) -> None:
         """ Add the given postcode to the collection cache. If the postcode
             already existed, it is overwritten with the new centroid.
         """
         if self.matcher is not None:
         """ Add the given postcode to the collection cache. If the postcode
             already existed, it is overwritten with the new centroid.
         """
         if self.matcher is not None:
+            normalized: Optional[str]
             if self.normalization_cache and self.normalization_cache[0] == postcode:
                 normalized = self.normalization_cache[1]
             else:
             if self.normalization_cache and self.normalization_cache[0] == postcode:
                 normalized = self.normalization_cache[1]
             else:
@@ -60,7 +64,7 @@ class _PostcodeCollector:
                 self.collected[normalized] += (x, y)
 
 
                 self.collected[normalized] += (x, y)
 
 
-    def commit(self, conn, analyzer, project_dir):
+    def commit(self, conn: Connection, analyzer: AbstractAnalyzer, project_dir: Path) -> None:
         """ Update postcodes for the country from the postcodes selected so far
             as well as any externally supplied postcodes.
         """
         """ Update postcodes for the country from the postcodes selected so far
             as well as any externally supplied postcodes.
         """
@@ -94,7 +98,8 @@ class _PostcodeCollector:
                               """).format(pysql.Literal(self.country)), to_update)
 
 
                               """).format(pysql.Literal(self.country)), to_update)
 
 
-    def _compute_changes(self, conn):
+    def _compute_changes(self, conn: Connection) \
+          -> Tuple[List[Tuple[str, float, float]], List[str], List[Tuple[str, float, float]]]:
         """ Compute which postcodes from the collected postcodes have to be
             added or modified and which from the location_postcode table
             have to be deleted.
         """ Compute which postcodes from the collected postcodes have to be
             added or modified and which from the location_postcode table
             have to be deleted.
@@ -116,12 +121,12 @@ class _PostcodeCollector:
                     to_delete.append(postcode)
 
         to_add = [(k, *v.centroid()) for k, v in self.collected.items()]
                     to_delete.append(postcode)
 
         to_add = [(k, *v.centroid()) for k, v in self.collected.items()]
-        self.collected = None
+        self.collected = defaultdict(PointsCentroid)
 
         return to_add, to_delete, to_update
 
 
 
         return to_add, to_delete, to_update
 
 
-    def _update_from_external(self, analyzer, project_dir):
+    def _update_from_external(self, analyzer: AbstractAnalyzer, project_dir: Path) -> None:
         """ Look for an external postcode file for the active country in
             the project directory and add missing postcodes when found.
         """
         """ Look for an external postcode file for the active country in
             the project directory and add missing postcodes when found.
         """
@@ -151,7 +156,7 @@ class _PostcodeCollector:
             csvfile.close()
 
 
             csvfile.close()
 
 
-    def _open_external(self, project_dir):
+    def _open_external(self, project_dir: Path) -> Optional[TextIO]:
         fname = project_dir / f'{self.country}_postcodes.csv'
 
         if fname.is_file():
         fname = project_dir / f'{self.country}_postcodes.csv'
 
         if fname.is_file():
@@ -167,7 +172,7 @@ class _PostcodeCollector:
         return None
 
 
         return None
 
 
-def update_postcodes(dsn, project_dir, tokenizer):
+def update_postcodes(dsn: str, project_dir: Path, tokenizer: AbstractTokenizer) -> None:
     """ Update the table of artificial postcodes.
 
         Computes artificial postcode centroids from the placex table,
     """ Update the table of artificial postcodes.
 
         Computes artificial postcode centroids from the placex table,
@@ -220,7 +225,7 @@ def update_postcodes(dsn, project_dir, tokenizer):
 
         analyzer.update_postcodes_from_db()
 
 
         analyzer.update_postcodes_from_db()
 
-def can_compute(dsn):
+def can_compute(dsn: str) -> bool:
     """
         Check that the place table exists so that
         postcodes can be computed.
     """
         Check that the place table exists so that
         postcodes can be computed.
index 257d587e117f41de52571958826e75ae3202ba28..9c5b7b085e50582202a117528d87bc0ca7ff117a 100644 (file)
@@ -7,12 +7,15 @@
 """
 Functions for bringing auxiliary data in the database up-to-date.
 """
 """
 Functions for bringing auxiliary data in the database up-to-date.
 """
+from typing import MutableSequence, Tuple, Any, Type, Mapping, Sequence, List, cast
 import logging
 from textwrap import dedent
 from pathlib import Path
 
 from psycopg2 import sql as pysql
 
 import logging
 from textwrap import dedent
 from pathlib import Path
 
 from psycopg2 import sql as pysql
 
+from nominatim.config import Configuration
+from nominatim.db.connection import Connection
 from nominatim.db.utils import execute_file
 from nominatim.db.sql_preprocessor import SQLPreprocessor
 from nominatim.version import version_str
 from nominatim.db.utils import execute_file
 from nominatim.db.sql_preprocessor import SQLPreprocessor
 from nominatim.version import version_str
@@ -21,7 +24,8 @@ LOG = logging.getLogger()
 
 OSM_TYPE = {'N': 'node', 'W': 'way', 'R': 'relation'}
 
 
 OSM_TYPE = {'N': 'node', 'W': 'way', 'R': 'relation'}
 
-def _add_address_level_rows_from_entry(rows, entry):
+def _add_address_level_rows_from_entry(rows: MutableSequence[Tuple[Any, ...]],
+                                       entry: Mapping[str, Any]) -> None:
     """ Converts a single entry from the JSON format for address rank
         descriptions into a flat format suitable for inserting into a
         PostgreSQL table and adds these lines to `rows`.
     """ Converts a single entry from the JSON format for address rank
         descriptions into a flat format suitable for inserting into a
         PostgreSQL table and adds these lines to `rows`.
@@ -38,14 +42,15 @@ def _add_address_level_rows_from_entry(rows, entry):
             for country in countries:
                 rows.append((country, key, value, rank_search, rank_address))
 
             for country in countries:
                 rows.append((country, key, value, rank_search, rank_address))
 
-def load_address_levels(conn, table, levels):
+
+def load_address_levels(conn: Connection, table: str, levels: Sequence[Mapping[str, Any]]) -> None:
     """ Replace the `address_levels` table with the contents of `levels'.
 
         A new table is created any previously existing table is dropped.
         The table has the following columns:
             country, class, type, rank_search, rank_address
     """
     """ Replace the `address_levels` table with the contents of `levels'.
 
         A new table is created any previously existing table is dropped.
         The table has the following columns:
             country, class, type, rank_search, rank_address
     """
-    rows = []
+    rows: List[Tuple[Any, ...]]  = []
     for entry in levels:
         _add_address_level_rows_from_entry(rows, entry)
 
     for entry in levels:
         _add_address_level_rows_from_entry(rows, entry)
 
@@ -69,7 +74,7 @@ def load_address_levels(conn, table, levels):
     conn.commit()
 
 
     conn.commit()
 
 
-def load_address_levels_from_config(conn, config):
+def load_address_levels_from_config(conn: Connection, config: Configuration) -> None:
     """ Replace the `address_levels` table with the content as
         defined in the given configuration. Uses the parameter
         NOMINATIM_ADDRESS_LEVEL_CONFIG to determine the location of the
     """ Replace the `address_levels` table with the content as
         defined in the given configuration. Uses the parameter
         NOMINATIM_ADDRESS_LEVEL_CONFIG to determine the location of the
@@ -79,7 +84,9 @@ def load_address_levels_from_config(conn, config):
     load_address_levels(conn, 'address_levels', cfg)
 
 
     load_address_levels(conn, 'address_levels', cfg)
 
 
-def create_functions(conn, config, enable_diff_updates=True, enable_debug=False):
+def create_functions(conn: Connection, config: Configuration,
+                     enable_diff_updates: bool = True,
+                     enable_debug: bool = False) -> None:
     """ (Re)create the PL/pgSQL functions.
     """
     sql = SQLPreprocessor(conn, config)
     """ (Re)create the PL/pgSQL functions.
     """
     sql = SQLPreprocessor(conn, config)
@@ -116,7 +123,7 @@ PHP_CONST_DEFS = (
 )
 
 
 )
 
 
-def import_wikipedia_articles(dsn, data_path, ignore_errors=False):
+def import_wikipedia_articles(dsn: str, data_path: Path, ignore_errors: bool = False) -> int:
     """ Replaces the wikipedia importance tables with new data.
         The import is run in a single transaction so that the new data
         is replace seemlessly.
     """ Replaces the wikipedia importance tables with new data.
         The import is run in a single transaction so that the new data
         is replace seemlessly.
@@ -140,7 +147,7 @@ def import_wikipedia_articles(dsn, data_path, ignore_errors=False):
     return 0
 
 
     return 0
 
 
-def recompute_importance(conn):
+def recompute_importance(conn: Connection) -> None:
     """ Recompute wikipedia links and importance for all entries in placex.
         This is a long-running operations that must not be executed in
         parallel with updates.
     """ Recompute wikipedia links and importance for all entries in placex.
         This is a long-running operations that must not be executed in
         parallel with updates.
@@ -163,12 +170,13 @@ def recompute_importance(conn):
     conn.commit()
 
 
     conn.commit()
 
 
-def _quote_php_variable(var_type, config, conf_name):
+def _quote_php_variable(var_type: Type[Any], config: Configuration,
+                        conf_name: str) -> str:
     if var_type == bool:
         return 'true' if config.get_bool(conf_name) else 'false'
 
     if var_type == int:
     if var_type == bool:
         return 'true' if config.get_bool(conf_name) else 'false'
 
     if var_type == int:
-        return getattr(config, conf_name)
+        return cast(str, getattr(config, conf_name))
 
     if not getattr(config, conf_name):
         return 'false'
 
     if not getattr(config, conf_name):
         return 'false'
@@ -182,7 +190,7 @@ def _quote_php_variable(var_type, config, conf_name):
     return f"'{quoted}'"
 
 
     return f"'{quoted}'"
 
 
-def setup_website(basedir, config, conn):
+def setup_website(basedir: Path, config: Configuration, conn: Connection) -> None:
     """ Create the website script stubs.
     """
     if not basedir.exists():
     """ Create the website script stubs.
     """
     if not basedir.exists():
@@ -215,7 +223,8 @@ def setup_website(basedir, config, conn):
             (basedir / script).write_text(template.format(script), 'utf-8')
 
 
             (basedir / script).write_text(template.format(script), 'utf-8')
 
 
-def invalidate_osm_object(osm_type, osm_id, conn, recursive=True):
+def invalidate_osm_object(osm_type: str, osm_id: int, conn: Connection,
+                          recursive: bool = True) -> None:
     """ Mark the given OSM object for reindexing. When 'recursive' is set
         to True (the default), then all dependent objects are marked for
         reindexing as well.
     """ Mark the given OSM object for reindexing. When 'recursive' is set
         to True (the default), then all dependent objects are marked for
         reindexing as well.
index fab3d2db57dc54822d4178261718a17ca132024e..db706bf67563d795b863da9b0d025412b8554c0f 100644 (file)
@@ -7,6 +7,7 @@
 """
 Functions for updating a database from a replication source.
 """
 """
 Functions for updating a database from a replication source.
 """
+from typing import ContextManager, MutableMapping, Any, Generator, cast
 from contextlib import contextmanager
 import datetime as dt
 from enum import Enum
 from contextlib import contextmanager
 import datetime as dt
 from enum import Enum
@@ -14,6 +15,7 @@ import logging
 import time
 
 from nominatim.db import status
 import time
 
 from nominatim.db import status
+from nominatim.db.connection import Connection
 from nominatim.tools.exec_utils import run_osm2pgsql
 from nominatim.errors import UsageError
 
 from nominatim.tools.exec_utils import run_osm2pgsql
 from nominatim.errors import UsageError
 
@@ -27,7 +29,7 @@ except ImportError as exc:
 
 LOG = logging.getLogger()
 
 
 LOG = logging.getLogger()
 
-def init_replication(conn, base_url):
+def init_replication(conn: Connection, base_url: str) -> None:
     """ Set up replication for the server at the given base URL.
     """
     LOG.info("Using replication source: %s", base_url)
     """ Set up replication for the server at the given base URL.
     """
     LOG.info("Using replication source: %s", base_url)
@@ -51,7 +53,7 @@ def init_replication(conn, base_url):
     LOG.warning("Updates initialised at sequence %s (%s)", seq, date)
 
 
     LOG.warning("Updates initialised at sequence %s (%s)", seq, date)
 
 
-def check_for_updates(conn, base_url):
+def check_for_updates(conn: Connection, base_url: str) -> int:
     """ Check if new data is available from the replication service at the
         given base URL.
     """
     """ Check if new data is available from the replication service at the
         given base URL.
     """
@@ -84,7 +86,7 @@ class UpdateState(Enum):
     NO_CHANGES = 3
 
 
     NO_CHANGES = 3
 
 
-def update(conn, options):
+def update(conn: Connection, options: MutableMapping[str, Any]) -> UpdateState:
     """ Update database from the next batch of data. Returns the state of
         updates according to `UpdateState`.
     """
     """ Update database from the next batch of data. Returns the state of
         updates according to `UpdateState`.
     """
@@ -95,6 +97,8 @@ def update(conn, options):
                   "Please run 'nominatim replication --init' first.")
         raise UsageError("Replication not set up.")
 
                   "Please run 'nominatim replication --init' first.")
         raise UsageError("Replication not set up.")
 
+    assert startdate is not None
+
     if not indexed and options['indexed_only']:
         LOG.info("Skipping update. There is data that needs indexing.")
         return UpdateState.MORE_PENDING
     if not indexed and options['indexed_only']:
         LOG.info("Skipping update. There is data that needs indexing.")
         return UpdateState.MORE_PENDING
@@ -132,17 +136,17 @@ def update(conn, options):
     return UpdateState.UP_TO_DATE
 
 
     return UpdateState.UP_TO_DATE
 
 
-def _make_replication_server(url):
+def _make_replication_server(url: str) -> ContextManager[ReplicationServer]:
     """ Returns a ReplicationServer in form of a context manager.
 
         Creates a light wrapper around older versions of pyosmium that did
         not support the context manager interface.
     """
     if hasattr(ReplicationServer, '__enter__'):
     """ Returns a ReplicationServer in form of a context manager.
 
         Creates a light wrapper around older versions of pyosmium that did
         not support the context manager interface.
     """
     if hasattr(ReplicationServer, '__enter__'):
-        return ReplicationServer(url)
+        return cast(ContextManager[ReplicationServer], ReplicationServer(url))
 
     @contextmanager
 
     @contextmanager
-    def get_cm():
+    def get_cm() -> Generator[ReplicationServer, None, None]:
         yield ReplicationServer(url)
 
     return get_cm()
         yield ReplicationServer(url)
 
     return get_cm()