]> git.openstreetmap.org Git - nominatim.git/blobdiff - nominatim/tools/database_import.py
add factory for query analyzer
[nominatim.git] / nominatim / tools / database_import.py
index caec9035a6d46882a66709db0ff1a69c7d67b8db..cb620d41fb8f31126fe69a622bf14130e38494d1 100644 (file)
@@ -7,6 +7,7 @@
 """
 Functions for setting up and importing a new Nominatim database.
 """
 """
 Functions for setting up and importing a new Nominatim database.
 """
+from typing import Tuple, Optional, Union, Sequence, MutableMapping, Any
 import logging
 import os
 import selectors
 import logging
 import os
 import selectors
@@ -16,7 +17,8 @@ from pathlib import Path
 import psutil
 from psycopg2 import sql as pysql
 
 import psutil
 from psycopg2 import sql as pysql
 
-from nominatim.db.connection import connect, get_pg_env
+from nominatim.config import Configuration
+from nominatim.db.connection import connect, get_pg_env, Connection
 from nominatim.db.async_connection import DBConnection
 from nominatim.db.sql_preprocessor import SQLPreprocessor
 from nominatim.tools.exec_utils import run_osm2pgsql
 from nominatim.db.async_connection import DBConnection
 from nominatim.db.sql_preprocessor import SQLPreprocessor
 from nominatim.tools.exec_utils import run_osm2pgsql
@@ -25,7 +27,7 @@ from nominatim.version import POSTGRESQL_REQUIRED_VERSION, POSTGIS_REQUIRED_VERS
 
 LOG = logging.getLogger()
 
 
 LOG = logging.getLogger()
 
-def _require_version(module, actual, expected):
+def _require_version(module: str, actual: Tuple[int, int], expected: Tuple[int, int]) -> None:
     """ Compares the version for the given module and raises an exception
         if the actual version is too old.
     """
     """ Compares the version for the given module and raises an exception
         if the actual version is too old.
     """
@@ -36,7 +38,7 @@ def _require_version(module, actual, expected):
         raise UsageError(f'{module} is too old.')
 
 
         raise UsageError(f'{module} is too old.')
 
 
-def setup_database_skeleton(dsn, rouser=None):
+def setup_database_skeleton(dsn: str, rouser: Optional[str] = None) -> None:
     """ Create a new database for Nominatim and populate it with the
         essential extensions.
 
     """ Create a new database for Nominatim and populate it with the
         essential extensions.
 
@@ -65,7 +67,7 @@ def setup_database_skeleton(dsn, rouser=None):
                 cnt = cur.scalar('SELECT count(*) FROM pg_user where usename = %s',
                                  (rouser, ))
                 if cnt == 0:
                 cnt = cur.scalar('SELECT count(*) FROM pg_user where usename = %s',
                                  (rouser, ))
                 if cnt == 0:
-                    LOG.fatal("Web user '%s' does not exists. Create it with:\n"
+                    LOG.fatal("Web user '%s' does not exist. Create it with:\n"
                               "\n      createuser %s", rouser, rouser)
                     raise UsageError('Missing read-only user.')
 
                               "\n      createuser %s", rouser, rouser)
                     raise UsageError('Missing read-only user.')
 
@@ -73,6 +75,11 @@ def setup_database_skeleton(dsn, rouser=None):
         with conn.cursor() as cur:
             cur.execute('CREATE EXTENSION IF NOT EXISTS hstore')
             cur.execute('CREATE EXTENSION IF NOT EXISTS postgis')
         with conn.cursor() as cur:
             cur.execute('CREATE EXTENSION IF NOT EXISTS hstore')
             cur.execute('CREATE EXTENSION IF NOT EXISTS postgis')
+
+            postgis_version = conn.postgis_version_tuple()
+            if postgis_version[0] >= 3:
+                cur.execute('CREATE EXTENSION IF NOT EXISTS postgis_raster')
+
         conn.commit()
 
         _require_version('PostGIS',
         conn.commit()
 
         _require_version('PostGIS',
@@ -80,7 +87,9 @@ def setup_database_skeleton(dsn, rouser=None):
                          POSTGIS_REQUIRED_VERSION)
 
 
                          POSTGIS_REQUIRED_VERSION)
 
 
-def import_osm_data(osm_files, options, drop=False, ignore_errors=False):
+def import_osm_data(osm_files: Union[Path, Sequence[Path]],
+                    options: MutableMapping[str, Any],
+                    drop: bool = False, ignore_errors: bool = False) -> None:
     """ Import the given OSM files. 'options' contains the list of
         default settings for osm2pgsql.
     """
     """ Import the given OSM files. 'options' contains the list of
         default settings for osm2pgsql.
     """
@@ -117,7 +126,7 @@ def import_osm_data(osm_files, options, drop=False, ignore_errors=False):
         Path(options['flatnode_file']).unlink()
 
 
         Path(options['flatnode_file']).unlink()
 
 
-def create_tables(conn, config, reverse_only=False):
+def create_tables(conn: Connection, config: Configuration, reverse_only: bool = False) -> None:
     """ Create the set of basic tables.
         When `reverse_only` is True, then the main table for searching will
         be skipped and only reverse search is possible.
     """ Create the set of basic tables.
         When `reverse_only` is True, then the main table for searching will
         be skipped and only reverse search is possible.
@@ -128,7 +137,7 @@ def create_tables(conn, config, reverse_only=False):
     sql.run_sql_file(conn, 'tables.sql')
 
 
     sql.run_sql_file(conn, 'tables.sql')
 
 
-def create_table_triggers(conn, config):
+def create_table_triggers(conn: Connection, config: Configuration) -> None:
     """ Create the triggers for the tables. The trigger functions must already
         have been imported with refresh.create_functions().
     """
     """ Create the triggers for the tables. The trigger functions must already
         have been imported with refresh.create_functions().
     """
@@ -136,14 +145,14 @@ def create_table_triggers(conn, config):
     sql.run_sql_file(conn, 'table-triggers.sql')
 
 
     sql.run_sql_file(conn, 'table-triggers.sql')
 
 
-def create_partition_tables(conn, config):
+def create_partition_tables(conn: Connection, config: Configuration) -> None:
     """ Create tables that have explicit partitioning.
     """
     sql = SQLPreprocessor(conn, config)
     sql.run_sql_file(conn, 'partition-tables.src.sql')
 
 
     """ Create tables that have explicit partitioning.
     """
     sql = SQLPreprocessor(conn, config)
     sql.run_sql_file(conn, 'partition-tables.src.sql')
 
 
-def truncate_data_tables(conn):
+def truncate_data_tables(conn: Connection) -> None:
     """ Truncate all data tables to prepare for a fresh load.
     """
     with conn.cursor() as cur:
     """ Truncate all data tables to prepare for a fresh load.
     """
     with conn.cursor() as cur:
@@ -174,7 +183,7 @@ _COPY_COLUMNS = pysql.SQL(',').join(map(pysql.Identifier,
                                          'extratags', 'geometry')))
 
 
                                          'extratags', 'geometry')))
 
 
-def load_data(dsn, threads):
+def load_data(dsn: str, threads: int) -> None:
     """ Copy data into the word and placex table.
     """
     sel = selectors.DefaultSelector()
     """ Copy data into the word and placex table.
     """
     sel = selectors.DefaultSelector()
@@ -216,12 +225,13 @@ def load_data(dsn, threads):
         print('.', end='', flush=True)
     print('\n')
 
         print('.', end='', flush=True)
     print('\n')
 
-    with connect(dsn) as conn:
-        with conn.cursor() as cur:
+    with connect(dsn) as syn_conn:
+        with syn_conn.cursor() as cur:
             cur.execute('ANALYSE')
 
 
             cur.execute('ANALYSE')
 
 
-def create_search_indices(conn, config, drop=False):
+def create_search_indices(conn: Connection, config: Configuration,
+                          drop: bool = False, threads: int = 1) -> None:
     """ Create tables that have explicit partitioning.
     """
 
     """ Create tables that have explicit partitioning.
     """
 
@@ -234,9 +244,10 @@ def create_search_indices(conn, config, drop=False):
         bad_indices = [row[0] for row in list(cur)]
         for idx in bad_indices:
             LOG.info("Drop invalid index %s.", idx)
         bad_indices = [row[0] for row in list(cur)]
         for idx in bad_indices:
             LOG.info("Drop invalid index %s.", idx)
-            cur.execute('DROP INDEX "{}"'.format(idx))
+            cur.execute(pysql.SQL('DROP INDEX {}').format(pysql.Identifier(idx)))
     conn.commit()
 
     sql = SQLPreprocessor(conn, config)
 
     conn.commit()
 
     sql = SQLPreprocessor(conn, config)
 
-    sql.run_sql_file(conn, 'indices.sql', drop=drop)
+    sql.run_parallel_sql_file(config.get_libpq_dsn(),
+                              'indices.sql', min(8, threads), drop=drop)