]> git.openstreetmap.org Git - nominatim.git/blobdiff - nominatim/tools/database_import.py
Merge pull request #3122 from miku0/sanitizer-final
[nominatim.git] / nominatim / tools / database_import.py
index 0dd93490d7b4c8c316bdd5635a02a3b1c9a6c6da..cb620d41fb8f31126fe69a622bf14130e38494d1 100644 (file)
@@ -1,6 +1,13 @@
+# SPDX-License-Identifier: GPL-2.0-only
+#
+# This file is part of Nominatim. (https://nominatim.org)
+#
+# Copyright (C) 2022 by the Nominatim developer community.
+# For a full list of authors see the git log.
 """
 Functions for setting up and importing a new Nominatim database.
 """
 """
 Functions for setting up and importing a new Nominatim database.
 """
+from typing import Tuple, Optional, Union, Sequence, MutableMapping, Any
 import logging
 import os
 import selectors
 import logging
 import os
 import selectors
@@ -8,11 +15,10 @@ import subprocess
 from pathlib import Path
 
 import psutil
 from pathlib import Path
 
 import psutil
-import psycopg2.extras
 from psycopg2 import sql as pysql
 
 from psycopg2 import sql as pysql
 
-from nominatim.db.connection import connect, get_pg_env
-from nominatim.db import utils as db_utils
+from nominatim.config import Configuration
+from nominatim.db.connection import connect, get_pg_env, Connection
 from nominatim.db.async_connection import DBConnection
 from nominatim.db.sql_preprocessor import SQLPreprocessor
 from nominatim.tools.exec_utils import run_osm2pgsql
 from nominatim.db.async_connection import DBConnection
 from nominatim.db.sql_preprocessor import SQLPreprocessor
 from nominatim.tools.exec_utils import run_osm2pgsql
@@ -21,24 +27,24 @@ from nominatim.version import POSTGRESQL_REQUIRED_VERSION, POSTGIS_REQUIRED_VERS
 
 LOG = logging.getLogger()
 
 
 LOG = logging.getLogger()
 
-def setup_database_skeleton(dsn, data_dir, no_partitions, rouser=None):
-    """ Create a new database for Nominatim and populate it with the
-        essential extensions and data.
+def _require_version(module: str, actual: Tuple[int, int], expected: Tuple[int, int]) -> None:
+    """ Compares the version for the given module and raises an exception
+        if the actual version is too old.
     """
     """
-    LOG.warning('Creating database')
-    create_db(dsn, rouser)
+    if actual < expected:
+        LOG.fatal('Minimum supported version of %s is %d.%d. '
+                  'Found version %d.%d.',
+                  module, expected[0], expected[1], actual[0], actual[1])
+        raise UsageError(f'{module} is too old.')
 
 
-    LOG.warning('Setting up database')
-    with connect(dsn) as conn:
-        setup_extensions(conn)
 
 
-    LOG.warning('Loading basic data')
-    import_base_data(dsn, data_dir, no_partitions)
+def setup_database_skeleton(dsn: str, rouser: Optional[str] = None) -> None:
+    """ Create a new database for Nominatim and populate it with the
+        essential extensions.
 
 
+        The function fails when the database already exists or Postgresql or
+        PostGIS versions are too old.
 
 
-def create_db(dsn, rouser=None):
-    """ Create a new database for the given DSN. Fails when the database
-        already exists or the PostgreSQL version is too old.
         Uses `createdb` to create the database.
 
         If 'rouser' is given, then the function also checks that the user
         Uses `createdb` to create the database.
 
         If 'rouser' is given, then the function also checks that the user
@@ -52,58 +58,38 @@ def create_db(dsn, rouser=None):
         raise UsageError('Creating new database failed.')
 
     with connect(dsn) as conn:
         raise UsageError('Creating new database failed.')
 
     with connect(dsn) as conn:
-        postgres_version = conn.server_version_tuple()
-        if postgres_version < POSTGRESQL_REQUIRED_VERSION:
-            LOG.fatal('Minimum supported version of Postgresql is %d.%d. '
-                      'Found version %d.%d.',
-                      POSTGRESQL_REQUIRED_VERSION[0], POSTGRESQL_REQUIRED_VERSION[1],
-                      postgres_version[0], postgres_version[1])
-            raise UsageError('PostgreSQL server is too old.')
+        _require_version('PostgreSQL server',
+                         conn.server_version_tuple(),
+                         POSTGRESQL_REQUIRED_VERSION)
 
         if rouser is not None:
             with conn.cursor() as cur:
                 cnt = cur.scalar('SELECT count(*) FROM pg_user where usename = %s',
                                  (rouser, ))
                 if cnt == 0:
 
         if rouser is not None:
             with conn.cursor() as cur:
                 cnt = cur.scalar('SELECT count(*) FROM pg_user where usename = %s',
                                  (rouser, ))
                 if cnt == 0:
-                    LOG.fatal("Web user '%s' does not exists. Create it with:\n"
+                    LOG.fatal("Web user '%s' does not exist. Create it with:\n"
                               "\n      createuser %s", rouser, rouser)
                     raise UsageError('Missing read-only user.')
 
                               "\n      createuser %s", rouser, rouser)
                     raise UsageError('Missing read-only user.')
 
+        # Create extensions.
+        with conn.cursor() as cur:
+            cur.execute('CREATE EXTENSION IF NOT EXISTS hstore')
+            cur.execute('CREATE EXTENSION IF NOT EXISTS postgis')
 
 
+            postgis_version = conn.postgis_version_tuple()
+            if postgis_version[0] >= 3:
+                cur.execute('CREATE EXTENSION IF NOT EXISTS postgis_raster')
 
 
-def setup_extensions(conn):
-    """ Set up all extensions needed for Nominatim. Also checks that the
-        versions of the extensions are sufficient.
-    """
-    with conn.cursor() as cur:
-        cur.execute('CREATE EXTENSION IF NOT EXISTS hstore')
-        cur.execute('CREATE EXTENSION IF NOT EXISTS postgis')
-    conn.commit()
-
-    postgis_version = conn.postgis_version_tuple()
-    if postgis_version < POSTGIS_REQUIRED_VERSION:
-        LOG.fatal('Minimum supported version of PostGIS is %d.%d. '
-                  'Found version %d.%d.',
-                  POSTGIS_REQUIRED_VERSION[0], POSTGIS_REQUIRED_VERSION[1],
-                  postgis_version[0], postgis_version[1])
-        raise UsageError('PostGIS version is too old.')
-
-
-def import_base_data(dsn, sql_dir, ignore_partitions=False):
-    """ Create and populate the tables with basic static data that provides
-        the background for geocoding. Data is assumed to not yet exist.
-    """
-    db_utils.execute_file(dsn, sql_dir / 'country_name.sql')
-    db_utils.execute_file(dsn, sql_dir / 'country_osm_grid.sql.gz')
+        conn.commit()
 
 
-    if ignore_partitions:
-        with connect(dsn) as conn:
-            with conn.cursor() as cur:
-                cur.execute('UPDATE country_name SET partition = 0')
-            conn.commit()
+        _require_version('PostGIS',
+                         conn.postgis_version_tuple(),
+                         POSTGIS_REQUIRED_VERSION)
 
 
 
 
-def import_osm_data(osm_files, options, drop=False, ignore_errors=False):
+def import_osm_data(osm_files: Union[Path, Sequence[Path]],
+                    options: MutableMapping[str, Any],
+                    drop: bool = False, ignore_errors: bool = False) -> None:
     """ Import the given OSM files. 'options' contains the list of
         default settings for osm2pgsql.
     """
     """ Import the given OSM files. 'options' contains the list of
         default settings for osm2pgsql.
     """
@@ -140,7 +126,7 @@ def import_osm_data(osm_files, options, drop=False, ignore_errors=False):
         Path(options['flatnode_file']).unlink()
 
 
         Path(options['flatnode_file']).unlink()
 
 
-def create_tables(conn, config, reverse_only=False):
+def create_tables(conn: Connection, config: Configuration, reverse_only: bool = False) -> None:
     """ Create the set of basic tables.
         When `reverse_only` is True, then the main table for searching will
         be skipped and only reverse search is possible.
     """ Create the set of basic tables.
         When `reverse_only` is True, then the main table for searching will
         be skipped and only reverse search is possible.
@@ -151,7 +137,7 @@ def create_tables(conn, config, reverse_only=False):
     sql.run_sql_file(conn, 'tables.sql')
 
 
     sql.run_sql_file(conn, 'tables.sql')
 
 
-def create_table_triggers(conn, config):
+def create_table_triggers(conn: Connection, config: Configuration) -> None:
     """ Create the triggers for the tables. The trigger functions must already
         have been imported with refresh.create_functions().
     """
     """ Create the triggers for the tables. The trigger functions must already
         have been imported with refresh.create_functions().
     """
@@ -159,14 +145,14 @@ def create_table_triggers(conn, config):
     sql.run_sql_file(conn, 'table-triggers.sql')
 
 
     sql.run_sql_file(conn, 'table-triggers.sql')
 
 
-def create_partition_tables(conn, config):
+def create_partition_tables(conn: Connection, config: Configuration) -> None:
     """ Create tables that have explicit partitioning.
     """
     sql = SQLPreprocessor(conn, config)
     sql.run_sql_file(conn, 'partition-tables.src.sql')
 
 
     """ Create tables that have explicit partitioning.
     """
     sql = SQLPreprocessor(conn, config)
     sql.run_sql_file(conn, 'partition-tables.src.sql')
 
 
-def truncate_data_tables(conn):
+def truncate_data_tables(conn: Connection) -> None:
     """ Truncate all data tables to prepare for a fresh load.
     """
     with conn.cursor() as cur:
     """ Truncate all data tables to prepare for a fresh load.
     """
     with conn.cursor() as cur:
@@ -197,7 +183,7 @@ _COPY_COLUMNS = pysql.SQL(',').join(map(pysql.Identifier,
                                          'extratags', 'geometry')))
 
 
                                          'extratags', 'geometry')))
 
 
-def load_data(dsn, threads):
+def load_data(dsn: str, threads: int) -> None:
     """ Copy data into the word and placex table.
     """
     sel = selectors.DefaultSelector()
     """ Copy data into the word and placex table.
     """
     sel = selectors.DefaultSelector()
@@ -239,12 +225,13 @@ def load_data(dsn, threads):
         print('.', end='', flush=True)
     print('\n')
 
         print('.', end='', flush=True)
     print('\n')
 
-    with connect(dsn) as conn:
-        with conn.cursor() as cur:
+    with connect(dsn) as syn_conn:
+        with syn_conn.cursor() as cur:
             cur.execute('ANALYSE')
 
 
             cur.execute('ANALYSE')
 
 
-def create_search_indices(conn, config, drop=False):
+def create_search_indices(conn: Connection, config: Configuration,
+                          drop: bool = False, threads: int = 1) -> None:
     """ Create tables that have explicit partitioning.
     """
 
     """ Create tables that have explicit partitioning.
     """
 
@@ -257,44 +244,10 @@ def create_search_indices(conn, config, drop=False):
         bad_indices = [row[0] for row in list(cur)]
         for idx in bad_indices:
             LOG.info("Drop invalid index %s.", idx)
         bad_indices = [row[0] for row in list(cur)]
         for idx in bad_indices:
             LOG.info("Drop invalid index %s.", idx)
-            cur.execute('DROP INDEX "{}"'.format(idx))
+            cur.execute(pysql.SQL('DROP INDEX {}').format(pysql.Identifier(idx)))
     conn.commit()
 
     sql = SQLPreprocessor(conn, config)
 
     conn.commit()
 
     sql = SQLPreprocessor(conn, config)
 
-    sql.run_sql_file(conn, 'indices.sql', drop=drop)
-
-
-def create_country_names(conn, tokenizer, languages=None):
-    """ Add default country names to search index. `languages` is a comma-
-        separated list of language codes as used in OSM. If `languages` is not
-        empty then only name translations for the given languages are added
-        to the index.
-    """
-    if languages:
-        languages = languages.split(',')
-
-    def _include_key(key):
-        return key == 'name' or \
-               (key.startswith('name:') and (not languages or key[5:] in languages))
-
-    with conn.cursor() as cur:
-        psycopg2.extras.register_hstore(cur)
-        cur.execute("""SELECT country_code, name FROM country_name
-                       WHERE country_code is not null""")
-
-        with tokenizer.name_analyzer() as analyzer:
-            for code, name in cur:
-                names = {'countrycode': code}
-                if code == 'gb':
-                    names['short_name'] = 'UK'
-                if code == 'us':
-                    names['short_name'] = 'United States'
-
-                # country names (only in languages as provided)
-                if name:
-                    names.update(((k, v) for k, v in name.items() if _include_key(k)))
-
-                analyzer.add_country_names(code, names)
-
-    conn.commit()
+    sql.run_parallel_sql_file(config.get_libpq_dsn(),
+                              'indices.sql', min(8, threads), drop=drop)