]> git.openstreetmap.org Git - nominatim.git/blobdiff - nominatim/tools/database_import.py
Merge remote-tracking branch 'upstream/master'
[nominatim.git] / nominatim / tools / database_import.py
index 6195b44a2fc60593513dacabaa6644d6d3338f24..de7e6a4aa2018c06e7284b4120973351b8a04ea5 100644 (file)
@@ -7,6 +7,7 @@
 """
 Functions for setting up and importing a new Nominatim database.
 """
 """
 Functions for setting up and importing a new Nominatim database.
 """
+from typing import Tuple, Optional, Union, Sequence, MutableMapping, Any
 import logging
 import os
 import selectors
 import logging
 import os
 import selectors
@@ -16,16 +17,18 @@ from pathlib import Path
 import psutil
 from psycopg2 import sql as pysql
 
 import psutil
 from psycopg2 import sql as pysql
 
-from nominatim.db.connection import connect, get_pg_env
+from nominatim.config import Configuration
+from nominatim.db.connection import connect, get_pg_env, Connection
 from nominatim.db.async_connection import DBConnection
 from nominatim.db.sql_preprocessor import SQLPreprocessor
 from nominatim.tools.exec_utils import run_osm2pgsql
 from nominatim.errors import UsageError
 from nominatim.db.async_connection import DBConnection
 from nominatim.db.sql_preprocessor import SQLPreprocessor
 from nominatim.tools.exec_utils import run_osm2pgsql
 from nominatim.errors import UsageError
-from nominatim.version import POSTGRESQL_REQUIRED_VERSION, POSTGIS_REQUIRED_VERSION
+from nominatim.version import POSTGRESQL_REQUIRED_VERSION, \
+                              POSTGIS_REQUIRED_VERSION
 
 LOG = logging.getLogger()
 
 
 LOG = logging.getLogger()
 
-def _require_version(module, actual, expected):
+def _require_version(module: str, actual: Tuple[int, int], expected: Tuple[int, int]) -> None:
     """ Compares the version for the given module and raises an exception
         if the actual version is too old.
     """
     """ Compares the version for the given module and raises an exception
         if the actual version is too old.
     """
@@ -36,7 +39,26 @@ def _require_version(module, actual, expected):
         raise UsageError(f'{module} is too old.')
 
 
         raise UsageError(f'{module} is too old.')
 
 
-def setup_database_skeleton(dsn, rouser=None):
+def _require_loaded(extension_name: str, conn: Connection) -> None:
+    """ Check that the given extension is loaded. """
+    if not conn.extension_loaded(extension_name):
+        LOG.fatal('Required module %s is not loaded.', extension_name)
+        raise UsageError(f'{extension_name} is not loaded.')
+
+
+def check_existing_database_plugins(dsn: str) -> None:
+    """ Check that the database has the required plugins installed."""
+    with connect(dsn) as conn:
+        _require_version('PostgreSQL server',
+                         conn.server_version_tuple(),
+                         POSTGRESQL_REQUIRED_VERSION)
+        _require_version('PostGIS',
+                         conn.postgis_version_tuple(),
+                         POSTGIS_REQUIRED_VERSION)
+        _require_loaded('hstore', conn)
+
+
+def setup_database_skeleton(dsn: str, rouser: Optional[str] = None) -> None:
     """ Create a new database for Nominatim and populate it with the
         essential extensions.
 
     """ Create a new database for Nominatim and populate it with the
         essential extensions.
 
@@ -73,6 +95,11 @@ def setup_database_skeleton(dsn, rouser=None):
         with conn.cursor() as cur:
             cur.execute('CREATE EXTENSION IF NOT EXISTS hstore')
             cur.execute('CREATE EXTENSION IF NOT EXISTS postgis')
         with conn.cursor() as cur:
             cur.execute('CREATE EXTENSION IF NOT EXISTS hstore')
             cur.execute('CREATE EXTENSION IF NOT EXISTS postgis')
+
+            postgis_version = conn.postgis_version_tuple()
+            if postgis_version[0] >= 3:
+                cur.execute('CREATE EXTENSION IF NOT EXISTS postgis_raster')
+
         conn.commit()
 
         _require_version('PostGIS',
         conn.commit()
 
         _require_version('PostGIS',
@@ -80,7 +107,9 @@ def setup_database_skeleton(dsn, rouser=None):
                          POSTGIS_REQUIRED_VERSION)
 
 
                          POSTGIS_REQUIRED_VERSION)
 
 
-def import_osm_data(osm_files, options, drop=False, ignore_errors=False):
+def import_osm_data(osm_files: Union[Path, Sequence[Path]],
+                    options: MutableMapping[str, Any],
+                    drop: bool = False, ignore_errors: bool = False) -> None:
     """ Import the given OSM files. 'options' contains the list of
         default settings for osm2pgsql.
     """
     """ Import the given OSM files. 'options' contains the list of
         default settings for osm2pgsql.
     """
@@ -117,7 +146,7 @@ def import_osm_data(osm_files, options, drop=False, ignore_errors=False):
         Path(options['flatnode_file']).unlink()
 
 
         Path(options['flatnode_file']).unlink()
 
 
-def create_tables(conn, config, reverse_only=False):
+def create_tables(conn: Connection, config: Configuration, reverse_only: bool = False) -> None:
     """ Create the set of basic tables.
         When `reverse_only` is True, then the main table for searching will
         be skipped and only reverse search is possible.
     """ Create the set of basic tables.
         When `reverse_only` is True, then the main table for searching will
         be skipped and only reverse search is possible.
@@ -128,7 +157,7 @@ def create_tables(conn, config, reverse_only=False):
     sql.run_sql_file(conn, 'tables.sql')
 
 
     sql.run_sql_file(conn, 'tables.sql')
 
 
-def create_table_triggers(conn, config):
+def create_table_triggers(conn: Connection, config: Configuration) -> None:
     """ Create the triggers for the tables. The trigger functions must already
         have been imported with refresh.create_functions().
     """
     """ Create the triggers for the tables. The trigger functions must already
         have been imported with refresh.create_functions().
     """
@@ -136,14 +165,14 @@ def create_table_triggers(conn, config):
     sql.run_sql_file(conn, 'table-triggers.sql')
 
 
     sql.run_sql_file(conn, 'table-triggers.sql')
 
 
-def create_partition_tables(conn, config):
+def create_partition_tables(conn: Connection, config: Configuration) -> None:
     """ Create tables that have explicit partitioning.
     """
     sql = SQLPreprocessor(conn, config)
     sql.run_sql_file(conn, 'partition-tables.src.sql')
 
 
     """ Create tables that have explicit partitioning.
     """
     sql = SQLPreprocessor(conn, config)
     sql.run_sql_file(conn, 'partition-tables.src.sql')
 
 
-def truncate_data_tables(conn):
+def truncate_data_tables(conn: Connection) -> None:
     """ Truncate all data tables to prepare for a fresh load.
     """
     with conn.cursor() as cur:
     """ Truncate all data tables to prepare for a fresh load.
     """
     with conn.cursor() as cur:
@@ -174,7 +203,7 @@ _COPY_COLUMNS = pysql.SQL(',').join(map(pysql.Identifier,
                                          'extratags', 'geometry')))
 
 
                                          'extratags', 'geometry')))
 
 
-def load_data(dsn, threads):
+def load_data(dsn: str, threads: int) -> None:
     """ Copy data into the word and placex table.
     """
     sel = selectors.DefaultSelector()
     """ Copy data into the word and placex table.
     """
     sel = selectors.DefaultSelector()
@@ -216,12 +245,13 @@ def load_data(dsn, threads):
         print('.', end='', flush=True)
     print('\n')
 
         print('.', end='', flush=True)
     print('\n')
 
-    with connect(dsn) as conn:
-        with conn.cursor() as cur:
+    with connect(dsn) as syn_conn:
+        with syn_conn.cursor() as cur:
             cur.execute('ANALYSE')
 
 
             cur.execute('ANALYSE')
 
 
-def create_search_indices(conn, config, drop=False):
+def create_search_indices(conn: Connection, config: Configuration,
+                          drop: bool = False, threads: int = 1) -> None:
     """ Create tables that have explicit partitioning.
     """
 
     """ Create tables that have explicit partitioning.
     """
 
@@ -239,4 +269,5 @@ def create_search_indices(conn, config, drop=False):
 
     sql = SQLPreprocessor(conn, config)
 
 
     sql = SQLPreprocessor(conn, config)
 
-    sql.run_sql_file(conn, 'indices.sql', drop=drop)
+    sql.run_parallel_sql_file(config.get_libpq_dsn(),
+                              'indices.sql', min(8, threads), drop=drop)