]> git.openstreetmap.org Git - nominatim.git/blobdiff - nominatim/tools/special_phrases/sp_importer.py
add type annotations to special phrase importer
[nominatim.git] / nominatim / tools / special_phrases / sp_importer.py
index 195f8387d3d98d33a3339ac5920203d3955e1262..6ca6a1e17b8ef7db34f5015ca65ce61d7e7c0f52 100644 (file)
-    The phrases already present in the database which are not
-    valids anymore are removed.
+    The phrases already present in the database which are no
+    longer valid are removed.
 """
+from typing import Iterable, Tuple, Mapping, Sequence, Optional, Set
 import logging
 import re
 
-from psycopg2.sql import Identifier, Literal, SQL
+from typing_extensions import Protocol
+
+from psycopg2.sql import Identifier, SQL
+
+from nominatim.config import Configuration
+from nominatim.db.connection import Connection
 from nominatim.tools.special_phrases.importer_statistics import SpecialPhrasesImporterStatistics
+from nominatim.tools.special_phrases.special_phrase import SpecialPhrase
+from nominatim.tokenizer.base import AbstractTokenizer
 
 LOG = logging.getLogger()
 
-def _classtype_table(phrase_class, phrase_type):
+def _classtype_table(phrase_class: str, phrase_type: str) -> str:
-    """ Return the name of the table for the given class and type.
-    """
+    """ Return the name of the place_classtype table for the given
+        class and type, i.e. `place_classtype_<class>_<type>`.
+
+        The name is returned unquoted; callers in this module quote it
+        with psycopg2's `Identifier` before putting it into SQL.
+    """
     return f'place_classtype_{phrase_class}_{phrase_type}'
 
+
+class SpecialPhraseLoader(Protocol):
+    """ Protocol for classes implementing a loader for special phrases.
+
+        Being a typing Protocol, it is matched structurally: any object
+        providing a compatible `generate_phrases()` method satisfies it,
+        no subclassing required.
+    """
+
+    def generate_phrases(self) -> Iterable[SpecialPhrase]:
+        """ Generates all special phrase terms this loader can produce.
+
+            SPImporter.import_phrases() iterates the result once and
+            passes every phrase through its black/white list and
+            sanity checks.
+        """
+
+
 class SPImporter():
     # pylint: disable-msg=too-many-instance-attributes
     """
@@ -33,21 +51,22 @@ class SPImporter():
 
         Take a sp loader which load the phrases from an external source.
     """
-    def __init__(self, config, db_connection, sp_loader) -> None:
+    def __init__(self, config: Configuration, conn: Connection,
+                 sp_loader: SpecialPhraseLoader) -> None:
         self.config = config
-        self.db_connection = db_connection
+        self.db_connection = conn  # NOTE(review): param was db_connection; check kw callers
         self.sp_loader = sp_loader
         self.statistics_handler = SpecialPhrasesImporterStatistics()
         self.black_list, self.white_list = self._load_white_and_black_lists()
         self.sanity_check_pattern = re.compile(r'^\w+$')
         # This set will contain all existing phrases to be added.
         # It contains tuples with the following format: (lable, class, type, operator)
-        self.word_phrases = set()
+        self.word_phrases: Set[Tuple[str, str, str, str]] = set()
         # This set will contain all existing place_classtype tables which doesn't match any
         # special phrases class/type on the wiki.
-        self.table_phrases_to_delete = set()
+        self.table_phrases_to_delete: Set[str] = set()
 
-    def import_phrases(self, tokenizer, should_replace):
+    def import_phrases(self, tokenizer: AbstractTokenizer, should_replace: bool) -> None:
         """
             Iterate through all SpecialPhrases extracted from the
             loader and import them into the database.
@@ -62,13 +81,12 @@ class SPImporter():
         # Store pairs of class/type for further processing
         class_type_pairs = set()
 
-        for loaded_phrases in self.sp_loader:
-            for phrase in loaded_phrases:
-                result = self._process_phrase(phrase)
-                if result:
-                    class_type_pairs.add(result)
+        for phrase in self.sp_loader.generate_phrases():
+            result = self._process_phrase(phrase)
+            if result:
+                class_type_pairs.add(result)
 
-        self._create_place_classtype_table_and_indexes(class_type_pairs)
+        self._create_classtype_table_and_indexes(class_type_pairs)
         if should_replace:
             self._remove_non_existent_tables_from_db()
         self.db_connection.commit()
@@ -80,7 +98,7 @@ class SPImporter():
         self.statistics_handler.notify_import_done()
 
 
-    def _fetch_existing_place_classtype_tables(self):
+    def _fetch_existing_place_classtype_tables(self) -> None:
         """
             Fetch existing place_classtype tables.
             Fill the table_phrases_to_delete set of the class.
@@ -96,7 +114,8 @@ class SPImporter():
             for row in db_cursor:
                 self.table_phrases_to_delete.add(row[0])
 
-    def _load_white_and_black_lists(self):
+    def _load_white_and_black_lists(self) \
+          -> Tuple[Mapping[str, Sequence[str]], Mapping[str, Sequence[str]]]:
         """
             Load white and black lists from phrases-settings.json.
         """
@@ -104,7 +123,7 @@ class SPImporter():
 
         return settings['blackList'], settings['whiteList']
 
-    def _check_sanity(self, phrase):
+    def _check_sanity(self, phrase: SpecialPhrase) -> bool:
         """
             Check sanity of given inputs in case somebody added garbage in the wiki.
             If a bad class/type is detected the system will exit with an error.
@@ -118,7 +137,7 @@ class SPImporter():
             return False
         return True
 
-    def _process_phrase(self, phrase):
+    def _process_phrase(self, phrase: SpecialPhrase) -> Optional[Tuple[str, str]]:
         """
             Processes the given phrase by checking black and white list
             and sanity.
@@ -146,7 +165,8 @@ class SPImporter():
         return (phrase.p_class, phrase.p_type)
 
 
-    def _create_place_classtype_table_and_indexes(self, class_type_pairs):
+    def _create_classtype_table_and_indexes(self,
+                                            class_type_pairs: Iterable[Tuple[str, str]]) -> None:
         """
             Create table place_classtype for each given pair.
             Also create indexes on place_id and centroid.
@@ -189,44 +209,48 @@ class SPImporter():
             db_cursor.execute("DROP INDEX idx_placex_classtype")
 
 
-    def _create_place_classtype_table(self, sql_tablespace, phrase_class, phrase_type):
+    def _create_place_classtype_table(self, sql_tablespace: str,
+                                      phrase_class: str, phrase_type: str) -> None:
         """
-            Create table place_classtype of the given phrase_class/phrase_type if doesn't exit.
+            Create table place_classtype of the given phrase_class/phrase_type
+            if it doesn't exist yet.
         """
         table_name = _classtype_table(phrase_class, phrase_type)
-        with self.db_connection.cursor() as db_cursor:
-            db_cursor.execute(SQL("""
-                    CREATE TABLE IF NOT EXISTS {{}} {}
-                    AS SELECT place_id AS place_id,st_centroid(geometry) AS centroid FROM placex
-                    WHERE class = {{}} AND type = {{}}""".format(sql_tablespace))
-                              .format(Identifier(table_name), Literal(phrase_class),
-                                      Literal(phrase_type)))
-
-
-    def _create_place_classtype_indexes(self, sql_tablespace, phrase_class, phrase_type):
+        with self.db_connection.cursor() as cur:
+            cur.execute(SQL("""CREATE TABLE IF NOT EXISTS {} {} AS
+                                 SELECT place_id AS place_id,
+                                        st_centroid(geometry) AS centroid
+                                 FROM placex
+                                 WHERE class = %s AND type = %s
+                             """).format(Identifier(table_name), SQL(sql_tablespace)),
+                        (phrase_class, phrase_type))  # class/type bound as query params
+
+
+    def _create_place_classtype_indexes(self, sql_tablespace: str,
+                                        phrase_class: str, phrase_type: str) -> None:
         """
             Create indexes on centroid and place_id for the place_classtype table.
         """
-        index_prefix = 'idx_place_classtype_{}_{}_'.format(phrase_class, phrase_type)
+        index_prefix = f'idx_place_classtype_{phrase_class}_{phrase_type}_'
         base_table = _classtype_table(phrase_class, phrase_type)
         # Index on centroid
         if not self.db_connection.index_exists(index_prefix + 'centroid'):
             with self.db_connection.cursor() as db_cursor:
-                db_cursor.execute(SQL("""
-                    CREATE INDEX {{}} ON {{}} USING GIST (centroid) {}""".format(sql_tablespace))
+                db_cursor.execute(SQL("CREATE INDEX {} ON {} USING GIST (centroid) {}")
                                   .format(Identifier(index_prefix + 'centroid'),
-                                          Identifier(base_table)), sql_tablespace)
+                                          Identifier(base_table),
+                                          SQL(sql_tablespace)))  # raw SQL, inserted verbatim

         # Index on place_id
         if not self.db_connection.index_exists(index_prefix + 'place_id'):
             with self.db_connection.cursor() as db_cursor:
-                db_cursor.execute(SQL(
-                    """CREATE INDEX {{}} ON {{}} USING btree(place_id) {}""".format(sql_tablespace))
+                db_cursor.execute(SQL("CREATE INDEX {} ON {} USING btree(place_id) {}")
                                   .format(Identifier(index_prefix + 'place_id'),
-                                          Identifier(base_table)))
+                                          Identifier(base_table),
+                                          SQL(sql_tablespace)))  # raw SQL, inserted verbatim
 
 
-    def _grant_access_to_webuser(self, phrase_class, phrase_type):
+    def _grant_access_to_webuser(self, phrase_class: str, phrase_type: str) -> None:
         """
             Grant access on read to the table place_classtype for the webuser.
         """
@@ -236,7 +260,7 @@ class SPImporter():
                               .format(Identifier(table_name),
                                       Identifier(self.config.DATABASE_WEBUSER)))
 
-    def _remove_non_existent_tables_from_db(self):
+    def _remove_non_existent_tables_from_db(self) -> None:
         """
             Remove special phrases which doesn't exist on the wiki anymore.
             Delete the place_classtype tables.