]> git.openstreetmap.org Git - nominatim.git/blobdiff - nominatim/db/sql_preprocessor.py
bdd: reorganise field comparisons
[nominatim.git] / nominatim / db / sql_preprocessor.py
index 9e0b291298d7739993c6447db4ca159ae517fe40..31b4a8c0f9a0042877cb652f6578fe2e896a899a 100644 (file)
@@ -1,10 +1,20 @@
+# SPDX-License-Identifier: GPL-2.0-only
+#
+# This file is part of Nominatim. (https://nominatim.org)
+#
+# Copyright (C) 2022 by the Nominatim developer community.
+# For a full list of authors see the git log.
 """
 Preprocessing of SQL files.
 """
 """
 Preprocessing of SQL files.
 """
+from typing import Set, Dict, Any
 import jinja2
 
 import jinja2
 
+from nominatim.db.connection import Connection
+from nominatim.db.async_connection import WorkerPool
+from nominatim.config import Configuration
 
 
-def _get_partitions(conn):
+def _get_partitions(conn: Connection) -> Set[int]:
     """ Get the set of partitions currently in use.
     """
     with conn.cursor() as cur:
     """ Get the set of partitions currently in use.
     """
     with conn.cursor() as cur:
@@ -16,7 +26,7 @@ def _get_partitions(conn):
     return partitions
 
 
     return partitions
 
 
-def _get_tables(conn):
+def _get_tables(conn: Connection) -> Set[str]:
     """ Return the set of tables currently in use.
         Only includes non-partitioned
     """
     """ Return the set of tables currently in use.
         Only includes non-partitioned
     """
@@ -26,42 +36,30 @@ def _get_tables(conn):
         return set((row[0] for row in list(cur)))
 
 
         return set((row[0] for row in list(cur)))
 
 
-def _setup_tablespace_sql(config):
+def _setup_tablespace_sql(config: Configuration) -> Dict[str, str]:
     """ Returns a dict with tablespace expressions for the different tablespace
         kinds depending on whether a tablespace is configured or not.
     """
     out = {}
     for subset in ('ADDRESS', 'SEARCH', 'AUX'):
         for kind in ('DATA', 'INDEX'):
     """ Returns a dict with tablespace expressions for the different tablespace
         kinds depending on whether a tablespace is configured or not.
     """
     out = {}
     for subset in ('ADDRESS', 'SEARCH', 'AUX'):
         for kind in ('DATA', 'INDEX'):
-            tspace = getattr(config, 'TABLESPACE_{}_{}'.format(subset, kind))
+            tspace = getattr(config, f'TABLESPACE_{subset}_{kind}')
             if tspace:
             if tspace:
-                tspace = 'TABLESPACE "{}"'.format(tspace)
-            out['{}_{}'.format(subset.lower, kind.lower())] = tspace
+                tspace = f'TABLESPACE "{tspace}"'
+            out[f'{subset.lower()}_{kind.lower()}'] = tspace
 
     return out
 
 
 
     return out
 
 
-def _setup_postgres_sql(conn):
-    """ Set up a dictionary with various Postgresql/Postgis SQL terms which
-        are dependent on the database version in use.
-    """
-    out = {}
-    pg_version = conn.server_version_tuple()
-    # CREATE INDEX IF NOT EXISTS was introduced in PG9.5.
-    # Note that you need to ignore failures on older versions when
-    # using this construct.
-    out['if_index_not_exists'] = ' IF NOT EXISTS ' if pg_version >= (9, 5, 0) else ''
-
-    return out
-
-
-def _setup_postgresql_features(conn):
+def _setup_postgresql_features(conn: Connection) -> Dict[str, Any]:
     """ Set up a dictionary with various optional Postgresql/Postgis features that
         depend on the database version.
     """
     pg_version = conn.server_version_tuple()
     """ Set up a dictionary with various optional Postgresql/Postgis features that
         depend on the database version.
     """
     pg_version = conn.server_version_tuple()
+    postgis_version = conn.postgis_version_tuple()
     return {
     return {
-        'has_index_non_key_column' : pg_version >= (11, 0, 0)
+        'has_index_non_key_column': pg_version >= (11, 0, 0),
+        'spgist_geom' : 'SPGIST' if postgis_version >= (3, 0) else 'GIST'
     }
 
 class SQLPreprocessor:
     }
 
 class SQLPreprocessor:
@@ -75,11 +73,11 @@ class SQLPreprocessor:
         and follows its syntax.
     """
 
         and follows its syntax.
     """
 
-    def __init__(self, conn, config):
+    def __init__(self, conn: Connection, config: Configuration) -> None:
         self.env = jinja2.Environment(autoescape=False,
                                       loader=jinja2.FileSystemLoader(str(config.lib_dir.sql)))
 
         self.env = jinja2.Environment(autoescape=False,
                                       loader=jinja2.FileSystemLoader(str(config.lib_dir.sql)))
 
-        db_info = {}
+        db_info: Dict[str, Any] = {}
         db_info['partitions'] = _get_partitions(conn)
         db_info['tables'] = _get_tables(conn)
         db_info['reverse_only'] = 'search_name' not in db_info['tables']
         db_info['partitions'] = _get_partitions(conn)
         db_info['tables'] = _get_tables(conn)
         db_info['reverse_only'] = 'search_name' not in db_info['tables']
@@ -87,13 +85,10 @@ class SQLPreprocessor:
 
         self.env.globals['config'] = config
         self.env.globals['db'] = db_info
 
         self.env.globals['config'] = config
         self.env.globals['db'] = db_info
-        self.env.globals['sql'] = _setup_postgres_sql(conn)
         self.env.globals['postgres'] = _setup_postgresql_features(conn)
         self.env.globals['postgres'] = _setup_postgresql_features(conn)
-        self.env.globals['modulepath'] = config.DATABASE_MODULE_PATH or \
-                                         str((config.project_dir / 'module').resolve())
 
 
 
 
-    def run_sql_file(self, conn, name, **kwargs):
+    def run_sql_file(self, conn: Connection, name: str, **kwargs: Any) -> None:
         """ Execute the given SQL file on the connection. The keyword arguments
             may supply additional parameters for preprocessing.
         """
         """ Execute the given SQL file on the connection. The keyword arguments
             may supply additional parameters for preprocessing.
         """
@@ -102,3 +97,21 @@ class SQLPreprocessor:
         with conn.cursor() as cur:
             cur.execute(sql)
         conn.commit()
         with conn.cursor() as cur:
             cur.execute(sql)
         conn.commit()
+
+
+    def run_parallel_sql_file(self, dsn: str, name: str, num_threads: int = 1,
+                              **kwargs: Any) -> None:
+        """ Execure the given SQL files using parallel asynchronous connections.
+            The keyword arguments may supply additional parameters for
+            preprocessing.
+
+            After preprocessing the SQL code is cut at lines containing only
+            '---'. Each chunk is sent to one of the `num_threads` workers.
+        """
+        sql = self.env.get_template(name).render(**kwargs)
+
+        parts = sql.split('\n---\n')
+
+        with WorkerPool(dsn, num_threads) as pool:
+            for part in parts:
+                pool.next_free_worker().perform(part)