]> git.openstreetmap.org Git - nominatim.git/blobdiff - nominatim/tools/special_phrases/sp_importer.py
Merge remote-tracking branch 'upstream/master'
[nominatim.git] / nominatim / tools / special_phrases / sp_importer.py
index 9eefaa1972600cd6a257ff4ed91a555014b74958..06b59fd003d5e3022b3e32cd458936152aaeb669 100644 (file)
     The phrases already present in the database which are not
     valids anymore are removed.
 """
     The phrases already present in the database which are not
     valids anymore are removed.
 """
+from typing import Iterable, Tuple, Mapping, Sequence, Optional, Set
 import logging
 import re
 
 from psycopg2.sql import Identifier, SQL
 import logging
 import re
 
 from psycopg2.sql import Identifier, SQL
+
+from nominatim.config import Configuration
+from nominatim.db.connection import Connection
 from nominatim.tools.special_phrases.importer_statistics import SpecialPhrasesImporterStatistics
 from nominatim.tools.special_phrases.importer_statistics import SpecialPhrasesImporterStatistics
+from nominatim.tools.special_phrases.special_phrase import SpecialPhrase
+from nominatim.tokenizer.base import AbstractTokenizer
+from nominatim.typing import Protocol
 
 LOG = logging.getLogger()
 
 
 LOG = logging.getLogger()
 
-def _classtype_table(phrase_class, phrase_type):
+def _classtype_table(phrase_class: str, phrase_type: str) -> str:
     """ Return the name of the table for the given class and type.
     """
     return f'place_classtype_{phrase_class}_{phrase_type}'
 
     """ Return the name of the table for the given class and type.
     """
     return f'place_classtype_{phrase_class}_{phrase_type}'
 
+
+class SpecialPhraseLoader(Protocol):
+    """ Protocol for classes implementing a loader for special phrases.
+    """
+
+    def generate_phrases(self) -> Iterable[SpecialPhrase]:
+        """ Generates all special phrase terms this loader can produce.
+        """
+
+
 class SPImporter():
     # pylint: disable-msg=too-many-instance-attributes
     """
 class SPImporter():
     # pylint: disable-msg=too-many-instance-attributes
     """
@@ -33,21 +50,22 @@ class SPImporter():
 
         Take a sp loader which load the phrases from an external source.
     """
 
         Take a sp loader which load the phrases from an external source.
     """
-    def __init__(self, config, db_connection, sp_loader) -> None:
+    def __init__(self, config: Configuration, conn: Connection,
+                 sp_loader: SpecialPhraseLoader) -> None:
         self.config = config
         self.config = config
-        self.db_connection = db_connection
+        self.db_connection = conn
         self.sp_loader = sp_loader
         self.statistics_handler = SpecialPhrasesImporterStatistics()
         self.black_list, self.white_list = self._load_white_and_black_lists()
         self.sanity_check_pattern = re.compile(r'^\w+$')
         # This set will contain all existing phrases to be added.
         self.sp_loader = sp_loader
         self.statistics_handler = SpecialPhrasesImporterStatistics()
         self.black_list, self.white_list = self._load_white_and_black_lists()
         self.sanity_check_pattern = re.compile(r'^\w+$')
         # This set will contain all existing phrases to be added.
-        # It contains tuples with the following format: (lable, class, type, operator)
-        self.word_phrases = set()
+        # It contains tuples with the following format: (label, class, type, operator)
+        self.word_phrases: Set[Tuple[str, str, str, str]] = set()
         # This set will contain all existing place_classtype tables which doesn't match any
         # special phrases class/type on the wiki.
         # This set will contain all existing place_classtype tables which doesn't match any
         # special phrases class/type on the wiki.
-        self.table_phrases_to_delete = set()
+        self.table_phrases_to_delete: Set[str] = set()
 
 
-    def import_phrases(self, tokenizer, should_replace):
+    def import_phrases(self, tokenizer: AbstractTokenizer, should_replace: bool) -> None:
         """
             Iterate through all SpecialPhrases extracted from the
             loader and import them into the database.
         """
             Iterate through all SpecialPhrases extracted from the
             loader and import them into the database.
@@ -62,13 +80,12 @@ class SPImporter():
         # Store pairs of class/type for further processing
         class_type_pairs = set()
 
         # Store pairs of class/type for further processing
         class_type_pairs = set()
 
-        for loaded_phrases in self.sp_loader:
-            for phrase in loaded_phrases:
-                result = self._process_phrase(phrase)
-                if result:
-                    class_type_pairs.add(result)
+        for phrase in self.sp_loader.generate_phrases():
+            result = self._process_phrase(phrase)
+            if result:
+                class_type_pairs.add(result)
 
 
-        self._create_place_classtype_table_and_indexes(class_type_pairs)
+        self._create_classtype_table_and_indexes(class_type_pairs)
         if should_replace:
             self._remove_non_existent_tables_from_db()
         self.db_connection.commit()
         if should_replace:
             self._remove_non_existent_tables_from_db()
         self.db_connection.commit()
@@ -80,7 +97,7 @@ class SPImporter():
         self.statistics_handler.notify_import_done()
 
 
         self.statistics_handler.notify_import_done()
 
 
-    def _fetch_existing_place_classtype_tables(self):
+    def _fetch_existing_place_classtype_tables(self) -> None:
         """
             Fetch existing place_classtype tables.
             Fill the table_phrases_to_delete set of the class.
         """
             Fetch existing place_classtype tables.
             Fill the table_phrases_to_delete set of the class.
@@ -96,7 +113,8 @@ class SPImporter():
             for row in db_cursor:
                 self.table_phrases_to_delete.add(row[0])
 
             for row in db_cursor:
                 self.table_phrases_to_delete.add(row[0])
 
-    def _load_white_and_black_lists(self):
+    def _load_white_and_black_lists(self) \
+          -> Tuple[Mapping[str, Sequence[str]], Mapping[str, Sequence[str]]]:
         """
             Load white and black lists from phrases-settings.json.
         """
         """
             Load white and black lists from phrases-settings.json.
         """
@@ -104,7 +122,7 @@ class SPImporter():
 
         return settings['blackList'], settings['whiteList']
 
 
         return settings['blackList'], settings['whiteList']
 
-    def _check_sanity(self, phrase):
+    def _check_sanity(self, phrase: SpecialPhrase) -> bool:
         """
             Check sanity of given inputs in case somebody added garbage in the wiki.
             If a bad class/type is detected the system will exit with an error.
         """
             Check sanity of given inputs in case somebody added garbage in the wiki.
             If a bad class/type is detected the system will exit with an error.
@@ -118,7 +136,7 @@ class SPImporter():
             return False
         return True
 
             return False
         return True
 
-    def _process_phrase(self, phrase):
+    def _process_phrase(self, phrase: SpecialPhrase) -> Optional[Tuple[str, str]]:
         """
             Processes the given phrase by checking black and white list
             and sanity.
         """
             Processes the given phrase by checking black and white list
             and sanity.
@@ -146,7 +164,8 @@ class SPImporter():
         return (phrase.p_class, phrase.p_type)
 
 
         return (phrase.p_class, phrase.p_type)
 
 
-    def _create_place_classtype_table_and_indexes(self, class_type_pairs):
+    def _create_classtype_table_and_indexes(self,
+                                            class_type_pairs: Iterable[Tuple[str, str]]) -> None:
         """
             Create table place_classtype for each given pair.
             Also create indexes on place_id and centroid.
         """
             Create table place_classtype for each given pair.
             Also create indexes on place_id and centroid.
@@ -189,22 +208,25 @@ class SPImporter():
             db_cursor.execute("DROP INDEX idx_placex_classtype")
 
 
             db_cursor.execute("DROP INDEX idx_placex_classtype")
 
 
-    def _create_place_classtype_table(self, sql_tablespace, phrase_class, phrase_type):
+    def _create_place_classtype_table(self, sql_tablespace: str,
+                                      phrase_class: str, phrase_type: str) -> None:
         """
         """
-            Create table place_classtype of the given phrase_class/phrase_type if doesn't exit.
+            Create table place_classtype of the given phrase_class/phrase_type
+            if doesn't exit.
         """
         table_name = _classtype_table(phrase_class, phrase_type)
         """
         table_name = _classtype_table(phrase_class, phrase_type)
-        with self.db_connection.cursor() as db_cursor:
-            db_cursor.execute(SQL("""CREATE TABLE IF NOT EXISTS {} {} AS
-                                       SELECT place_id AS place_id,
-                                              st_centroid(geometry) AS centroid
-                                       FROM placex
-                                       WHERE class = %s AND type = %s""")
-                                 .format(Identifier(table_name), SQL(sql_tablespace)),
-                              (phrase_class, phrase_type))
-
-
-    def _create_place_classtype_indexes(self, sql_tablespace, phrase_class, phrase_type):
+        with self.db_connection.cursor() as cur:
+            cur.execute(SQL("""CREATE TABLE IF NOT EXISTS {} {} AS
+                                 SELECT place_id AS place_id,
+                                        st_centroid(geometry) AS centroid
+                                 FROM placex
+                                 WHERE class = %s AND type = %s
+                             """).format(Identifier(table_name), SQL(sql_tablespace)),
+                        (phrase_class, phrase_type))
+
+
+    def _create_place_classtype_indexes(self, sql_tablespace: str,
+                                        phrase_class: str, phrase_type: str) -> None:
         """
             Create indexes on centroid and place_id for the place_classtype table.
         """
         """
             Create indexes on centroid and place_id for the place_classtype table.
         """
@@ -214,9 +236,9 @@ class SPImporter():
         if not self.db_connection.index_exists(index_prefix + 'centroid'):
             with self.db_connection.cursor() as db_cursor:
                 db_cursor.execute(SQL("CREATE INDEX {} ON {} USING GIST (centroid) {}")
         if not self.db_connection.index_exists(index_prefix + 'centroid'):
             with self.db_connection.cursor() as db_cursor:
                 db_cursor.execute(SQL("CREATE INDEX {} ON {} USING GIST (centroid) {}")
-                                     .format(Identifier(index_prefix + 'centroid'),
-                                             Identifier(base_table),
-                                             SQL(sql_tablespace)))
+                                  .format(Identifier(index_prefix + 'centroid'),
+                                          Identifier(base_table),
+                                          SQL(sql_tablespace)))
 
         # Index on place_id
         if not self.db_connection.index_exists(index_prefix + 'place_id'):
 
         # Index on place_id
         if not self.db_connection.index_exists(index_prefix + 'place_id'):
@@ -227,7 +249,7 @@ class SPImporter():
                                           SQL(sql_tablespace)))
 
 
                                           SQL(sql_tablespace)))
 
 
-    def _grant_access_to_webuser(self, phrase_class, phrase_type):
+    def _grant_access_to_webuser(self, phrase_class: str, phrase_type: str) -> None:
         """
             Grant access on read to the table place_classtype for the webuser.
         """
         """
             Grant access on read to the table place_classtype for the webuser.
         """
@@ -237,7 +259,7 @@ class SPImporter():
                               .format(Identifier(table_name),
                                       Identifier(self.config.DATABASE_WEBUSER)))
 
                               .format(Identifier(table_name),
                                       Identifier(self.config.DATABASE_WEBUSER)))
 
-    def _remove_non_existent_tables_from_db(self):
+    def _remove_non_existent_tables_from_db(self) -> None:
         """
             Remove special phrases which doesn't exist on the wiki anymore.
             Delete the place_classtype tables.
         """
             Remove special phrases which doesn't exist on the wiki anymore.
             Delete the place_classtype tables.