X-Git-Url: https://git.openstreetmap.org./nominatim.git/blobdiff_plain/3ea87169ac7b11919ad1261927619bba2e984da6..87c91ec5c42089797da1525bd270269b2cd9d3ad:/nominatim/tools/database_import.py diff --git a/nominatim/tools/database_import.py b/nominatim/tools/database_import.py index 6195b44a..cb620d41 100644 --- a/nominatim/tools/database_import.py +++ b/nominatim/tools/database_import.py @@ -7,6 +7,7 @@ """ Functions for setting up and importing a new Nominatim database. """ +from typing import Tuple, Optional, Union, Sequence, MutableMapping, Any import logging import os import selectors @@ -16,7 +17,8 @@ from pathlib import Path import psutil from psycopg2 import sql as pysql -from nominatim.db.connection import connect, get_pg_env +from nominatim.config import Configuration +from nominatim.db.connection import connect, get_pg_env, Connection from nominatim.db.async_connection import DBConnection from nominatim.db.sql_preprocessor import SQLPreprocessor from nominatim.tools.exec_utils import run_osm2pgsql @@ -25,7 +27,7 @@ from nominatim.version import POSTGRESQL_REQUIRED_VERSION, POSTGIS_REQUIRED_VERS LOG = logging.getLogger() -def _require_version(module, actual, expected): +def _require_version(module: str, actual: Tuple[int, int], expected: Tuple[int, int]) -> None: """ Compares the version for the given module and raises an exception if the actual version is too old. """ @@ -36,7 +38,7 @@ def _require_version(module, actual, expected): raise UsageError(f'{module} is too old.') -def setup_database_skeleton(dsn, rouser=None): +def setup_database_skeleton(dsn: str, rouser: Optional[str] = None) -> None: """ Create a new database for Nominatim and populate it with the essential extensions. @@ -73,6 +75,11 @@ def setup_database_skeleton(dsn, rouser=None): with conn.cursor() as cur: cur.execute('CREATE EXTENSION IF NOT EXISTS hstore') cur.execute('CREATE EXTENSION IF NOT EXISTS postgis') + + postgis_version = conn.postgis_version_tuple() + if postgis_version[0] >= 3: + cur.execute('CREATE EXTENSION IF NOT EXISTS postgis_raster') + conn.commit() _require_version('PostGIS', @@ -80,7 +87,9 @@ def setup_database_skeleton(dsn, rouser=None): POSTGIS_REQUIRED_VERSION) -def import_osm_data(osm_files, options, drop=False, ignore_errors=False): +def import_osm_data(osm_files: Union[Path, Sequence[Path]], + options: MutableMapping[str, Any], + drop: bool = False, ignore_errors: bool = False) -> None: """ Import the given OSM files. 'options' contains the list of default settings for osm2pgsql. """ @@ -117,7 +126,7 @@ def import_osm_data(osm_files, options, drop=False, ignore_errors=False): Path(options['flatnode_file']).unlink() -def create_tables(conn, config, reverse_only=False): +def create_tables(conn: Connection, config: Configuration, reverse_only: bool = False) -> None: """ Create the set of basic tables. When `reverse_only` is True, then the main table for searching will be skipped and only reverse search is possible. @@ -128,7 +137,7 @@ def create_tables(conn, config, reverse_only=False): sql.run_sql_file(conn, 'tables.sql') -def create_table_triggers(conn, config): +def create_table_triggers(conn: Connection, config: Configuration) -> None: """ Create the triggers for the tables. The trigger functions must already have been imported with refresh.create_functions(). """ @@ -136,14 +145,14 @@ def create_table_triggers(conn, config): sql.run_sql_file(conn, 'table-triggers.sql') -def create_partition_tables(conn, config): +def create_partition_tables(conn: Connection, config: Configuration) -> None: """ Create tables that have explicit partitioning. """ sql = SQLPreprocessor(conn, config) sql.run_sql_file(conn, 'partition-tables.src.sql') -def truncate_data_tables(conn): +def truncate_data_tables(conn: Connection) -> None: """ Truncate all data tables to prepare for a fresh load. """ with conn.cursor() as cur: @@ -174,7 +183,7 @@ _COPY_COLUMNS = pysql.SQL(',').join(map(pysql.Identifier, 'extratags', 'geometry'))) -def load_data(dsn, threads): +def load_data(dsn: str, threads: int) -> None: """ Copy data into the word and placex table. """ sel = selectors.DefaultSelector() @@ -216,12 +225,13 @@ def load_data(dsn, threads): print('.', end='', flush=True) print('\n') - with connect(dsn) as conn: - with conn.cursor() as cur: + with connect(dsn) as syn_conn: + with syn_conn.cursor() as cur: cur.execute('ANALYSE') -def create_search_indices(conn, config, drop=False): +def create_search_indices(conn: Connection, config: Configuration, + drop: bool = False, threads: int = 1) -> None: """ Create tables that have explicit partitioning. """ @@ -239,4 +249,5 @@ def create_search_indices(conn, config, drop=False): sql = SQLPreprocessor(conn, config) - sql.run_sql_file(conn, 'indices.sql', drop=drop) + sql.run_parallel_sql_file(config.get_libpq_dsn(), + 'indices.sql', min(8, threads), drop=drop)