]> git.openstreetmap.org Git - nominatim.git/blobdiff - nominatim/db/sql_preprocessor.py
Merge remote-tracking branch 'upstream/master'
[nominatim.git] / nominatim / db / sql_preprocessor.py
index 7ffe88818ba94a9538a9f8a61e7e8bb658e55e00..af5bc3357959abf52b9518e83144403e2150564b 100644 (file)
@@ -1,10 +1,20 @@
+# SPDX-License-Identifier: GPL-2.0-only
+#
+# This file is part of Nominatim. (https://nominatim.org)
+#
+# Copyright (C) 2022 by the Nominatim developer community.
+# For a full list of authors see the git log.
 """
 Preprocessing of SQL files.
 """
 """
 Preprocessing of SQL files.
 """
+from typing import Set, Dict, Any, cast
 import jinja2
 
 import jinja2
 
+from nominatim.db.connection import Connection
+from nominatim.db.async_connection import WorkerPool
+from nominatim.config import Configuration
 
 
-def _get_partitions(conn):
+def _get_partitions(conn: Connection) -> Set[int]:
     """ Get the set of partitions currently in use.
     """
     with conn.cursor() as cur:
     """ Get the set of partitions currently in use.
     """
     with conn.cursor() as cur:
@@ -16,55 +26,56 @@ def _get_partitions(conn):
     return partitions
 
 
     return partitions
 
 
-def _get_tables(conn):
+def _get_tables(conn: Connection) -> Set[str]:
     """ Return the set of tables currently in use.
     """ Return the set of tables currently in use.
-        Only includes non-partitioned
     """
     with conn.cursor() as cur:
         cur.execute("SELECT tablename FROM pg_tables WHERE schemaname = 'public'")
 
         return set((row[0] for row in list(cur)))
 
     """
     with conn.cursor() as cur:
         cur.execute("SELECT tablename FROM pg_tables WHERE schemaname = 'public'")
 
         return set((row[0] for row in list(cur)))
 
+def _get_middle_db_format(conn: Connection, tables: Set[str]) -> str:
+    """ Returns the version of the slim middle tables.
+    """
+    if 'osm2pgsql_properties' not in tables:
+        return '1'
+
+    with conn.cursor() as cur:
+        cur.execute("SELECT value FROM osm2pgsql_properties WHERE property = 'db_format'")
+        row = cur.fetchone()
+
+        return cast(str, row[0]) if row is not None else '1'
 
 
-def _setup_tablespace_sql(config):
+
+def _setup_tablespace_sql(config: Configuration) -> Dict[str, str]:
     """ Returns a dict with tablespace expressions for the different tablespace
         kinds depending on whether a tablespace is configured or not.
     """
     out = {}
     for subset in ('ADDRESS', 'SEARCH', 'AUX'):
         for kind in ('DATA', 'INDEX'):
     """ Returns a dict with tablespace expressions for the different tablespace
         kinds depending on whether a tablespace is configured or not.
     """
     out = {}
     for subset in ('ADDRESS', 'SEARCH', 'AUX'):
         for kind in ('DATA', 'INDEX'):
-            tspace = getattr(config, 'TABLESPACE_{}_{}'.format(subset, kind))
+            tspace = getattr(config, f'TABLESPACE_{subset}_{kind}')
             if tspace:
             if tspace:
-                tspace = 'TABLESPACE "{}"'.format(tspace)
-            out['{}_{}'.format(subset.lower, kind.lower())] = tspace
+                tspace = f'TABLESPACE "{tspace}"'
+            out[f'{subset.lower()}_{kind.lower()}'] = tspace
 
     return out
 
 
 
     return out
 
 
-def _setup_postgres_sql(conn):
-    """ Set up a dictionary with various Postgresql/Postgis SQL terms which
-        are dependent on the database version in use.
-    """
-    out = {}
-    pg_version = conn.server_version_tuple()
-    # CREATE INDEX IF NOT EXISTS was introduced in PG9.5.
-    # Note that you need to ignore failures on older versions when
-    # using this construct.
-    out['if_index_not_exists'] = ' IF NOT EXISTS ' if pg_version >= (9, 5, 0) else ''
-
-    return out
-
-
-def _setup_postgresql_features(conn):
+def _setup_postgresql_features(conn: Connection) -> Dict[str, Any]:
     """ Set up a dictionary with various optional Postgresql/Postgis features that
         depend on the database version.
     """
     pg_version = conn.server_version_tuple()
     """ Set up a dictionary with various optional Postgresql/Postgis features that
         depend on the database version.
     """
     pg_version = conn.server_version_tuple()
+    postgis_version = conn.postgis_version_tuple()
+    pg11plus = pg_version >= (11, 0, 0)
+    ps3 = postgis_version >= (3, 0)
     return {
     return {
-        'has_index_non_key_column' : pg_version >= (11, 0, 0)
+        'has_index_non_key_column': pg11plus,
+        'spgist_geom' : 'SPGIST' if pg11plus and ps3 else 'GIST'
     }
 
     }
 
-class SQLPreprocessor: # pylint: disable=too-few-public-methods
+class SQLPreprocessor:
     """ A environment for preprocessing SQL files from the
         lib-sql directory.
 
     """ A environment for preprocessing SQL files from the
         lib-sql directory.
 
@@ -75,25 +86,35 @@ class SQLPreprocessor: # pylint: disable=too-few-public-methods
         and follows its syntax.
     """
 
         and follows its syntax.
     """
 
-    def __init__(self, conn, config, sqllib_dir):
+    def __init__(self, conn: Connection, config: Configuration) -> None:
         self.env = jinja2.Environment(autoescape=False,
         self.env = jinja2.Environment(autoescape=False,
-                                      loader=jinja2.FileSystemLoader(str(sqllib_dir)))
+                                      loader=jinja2.FileSystemLoader(str(config.lib_dir.sql)))
 
 
-        db_info = {}
+        db_info: Dict[str, Any] = {}
         db_info['partitions'] = _get_partitions(conn)
         db_info['tables'] = _get_tables(conn)
         db_info['reverse_only'] = 'search_name' not in db_info['tables']
         db_info['tablespace'] = _setup_tablespace_sql(config)
         db_info['partitions'] = _get_partitions(conn)
         db_info['tables'] = _get_tables(conn)
         db_info['reverse_only'] = 'search_name' not in db_info['tables']
         db_info['tablespace'] = _setup_tablespace_sql(config)
+        db_info['middle_db_format'] = _get_middle_db_format(conn, db_info['tables'])
 
         self.env.globals['config'] = config
         self.env.globals['db'] = db_info
 
         self.env.globals['config'] = config
         self.env.globals['db'] = db_info
-        self.env.globals['sql'] = _setup_postgres_sql(conn)
         self.env.globals['postgres'] = _setup_postgresql_features(conn)
         self.env.globals['postgres'] = _setup_postgresql_features(conn)
-        self.env.globals['modulepath'] = config.DATABASE_MODULE_PATH or \
-                                         str((config.project_dir / 'module').resolve())
 
 
 
 
-    def run_sql_file(self, conn, name, **kwargs):
+    def run_string(self, conn: Connection, template: str, **kwargs: Any) -> None:
+        """ Execute the given SQL template string on the connection.
+            The keyword arguments may supply additional parameters
+            for preprocessing.
+        """
+        sql = self.env.from_string(template).render(**kwargs)
+
+        with conn.cursor() as cur:
+            cur.execute(sql)
+        conn.commit()
+
+
+    def run_sql_file(self, conn: Connection, name: str, **kwargs: Any) -> None:
         """ Execute the given SQL file on the connection. The keyword arguments
             may supply additional parameters for preprocessing.
         """
         """ Execute the given SQL file on the connection. The keyword arguments
             may supply additional parameters for preprocessing.
         """
@@ -102,3 +123,21 @@ class SQLPreprocessor: # pylint: disable=too-few-public-methods
         with conn.cursor() as cur:
             cur.execute(sql)
         conn.commit()
         with conn.cursor() as cur:
             cur.execute(sql)
         conn.commit()
+
+
+    def run_parallel_sql_file(self, dsn: str, name: str, num_threads: int = 1,
+                              **kwargs: Any) -> None:
+        """ Execute the given SQL files using parallel asynchronous connections.
+            The keyword arguments may supply additional parameters for
+            preprocessing.
+
+            After preprocessing the SQL code is cut at lines containing only
+            '---'. Each chunk is sent to one of the `num_threads` workers.
+        """
+        sql = self.env.get_template(name).render(**kwargs)
+
+        parts = sql.split('\n---\n')
+
+        with WorkerPool(dsn, num_threads) as pool:
+            for part in parts:
+                pool.next_free_worker().perform(part)