diff --git a/src/nominatim_db/tools/database_import.py b/src/nominatim_db/tools/database_import.py
index 84f2f325..415e9d24 100644
--- a/src/nominatim_db/tools/database_import.py
+++ b/src/nominatim_db/tools/database_import.py
@@ -10,23 +10,26 @@ Functions for setting up and importing a new Nominatim database.
 from typing import Tuple, Optional, Union, Sequence, MutableMapping, Any
 import logging
 import os
-import selectors
 import subprocess
+import asyncio
 from pathlib import Path
 
 import psutil
-from psycopg2 import sql as pysql
-
-from nominatim_core.errors import UsageError
-from nominatim_core.config import Configuration
-from nominatim_core.db.connection import connect, get_pg_env, Connection
-from nominatim_core.db.async_connection import DBConnection
-from nominatim_core.db.sql_preprocessor import SQLPreprocessor
+import psycopg
+from psycopg import sql as pysql
+
+from ..errors import UsageError
+from ..config import Configuration
+from ..db.connection import connect, get_pg_env, Connection, server_version_tuple, \
+                            postgis_version_tuple, drop_tables, table_exists, execute_scalar
+from ..db.sql_preprocessor import SQLPreprocessor
+from ..db.query_pool import QueryPool
 from .exec_utils import run_osm2pgsql
 from ..version import POSTGRESQL_REQUIRED_VERSION, POSTGIS_REQUIRED_VERSION
 
 LOG = logging.getLogger()
 
+
 def _require_version(module: str, actual: Tuple[int, int], expected: Tuple[int, int]) -> None:
     """ Compares the version for the given module and raises
         an exception if the actual version is too old.
@@ -40,19 +43,21 @@ def _require_version(module: str, actual: Tuple[int, int], expected: Tuple[int,
 def _require_loaded(extension_name: str, conn: Connection) -> None:
     """ Check that the given extension is loaded. """
-    if not conn.extension_loaded(extension_name):
-        LOG.fatal('Required module %s is not loaded.', extension_name)
-        raise UsageError(f'{extension_name} is not loaded.')
+    with conn.cursor() as cur:
+        cur.execute('SELECT * FROM pg_extension WHERE extname = %s', (extension_name, ))
+        if cur.rowcount <= 0:
+            LOG.fatal('Required module %s is not loaded.', extension_name)
+            raise UsageError(f'{extension_name} is not loaded.')
 
 
 def check_existing_database_plugins(dsn: str) -> None:
     """ Check that the database has the required plugins installed."""
     with connect(dsn) as conn:
         _require_version('PostgreSQL server',
-                         conn.server_version_tuple(),
+                         server_version_tuple(conn),
                          POSTGRESQL_REQUIRED_VERSION)
         _require_version('PostGIS',
-                         conn.postgis_version_tuple(),
+                         postgis_version_tuple(conn),
                          POSTGIS_REQUIRED_VERSION)
         _require_loaded('hstore', conn)
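The rewritten _require_loaded() above replaces the old conn.extension_loaded() wrapper with a direct probe of the pg_extension catalog. A minimal standalone sketch of the same check with plain psycopg follows; the DSN string and the extension_loaded helper name are illustrative assumptions, not part of this patch:

    import psycopg

    def extension_loaded(dsn: str, extension_name: str) -> bool:
        """ Return True when the named extension is installed in the
            database. (Illustrative sketch; the dsn is a placeholder.)
        """
        with psycopg.connect(dsn) as conn:
            with conn.cursor() as cur:
                cur.execute('SELECT 1 FROM pg_extension WHERE extname = %s',
                            (extension_name, ))
                return cur.fetchone() is not None

    # Example usage: extension_loaded('dbname=nominatim', 'hstore')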
@@ -78,31 +83,30 @@ def setup_database_skeleton(dsn: str, rouser: Optional[str] = None) -> None:
 
     with connect(dsn) as conn:
         _require_version('PostgreSQL server',
-                         conn.server_version_tuple(),
+                         server_version_tuple(conn),
                          POSTGRESQL_REQUIRED_VERSION)
 
         if rouser is not None:
-            with conn.cursor() as cur:
-                cnt = cur.scalar('SELECT count(*) FROM pg_user where usename = %s',
+            cnt = execute_scalar(conn, 'SELECT count(*) FROM pg_user where usename = %s',
                                  (rouser, ))
-                if cnt == 0:
-                    LOG.fatal("Web user '%s' does not exist. Create it with:\n"
-                              "\n  createuser %s", rouser, rouser)
-                    raise UsageError('Missing read-only user.')
+            if cnt == 0:
+                LOG.fatal("Web user '%s' does not exist. Create it with:\n"
+                          "\n  createuser %s", rouser, rouser)
+                raise UsageError('Missing read-only user.')
 
         # Create extensions.
         with conn.cursor() as cur:
             cur.execute('CREATE EXTENSION IF NOT EXISTS hstore')
             cur.execute('CREATE EXTENSION IF NOT EXISTS postgis')
 
-            postgis_version = conn.postgis_version_tuple()
+            postgis_version = postgis_version_tuple(conn)
             if postgis_version[0] >= 3:
                 cur.execute('CREATE EXTENSION IF NOT EXISTS postgis_raster')
 
         conn.commit()
 
         _require_version('PostGIS',
-                         conn.postgis_version_tuple(),
+                         postgis_version_tuple(conn),
                          POSTGIS_REQUIRED_VERSION)
@@ -134,12 +138,13 @@ def import_osm_data(osm_files: Union[Path, Sequence[Path]],
     with connect(options['dsn']) as conn:
         if not ignore_errors:
             with conn.cursor() as cur:
-                cur.execute('SELECT * FROM place LIMIT 1')
+                cur.execute('SELECT true FROM place LIMIT 1')
                 if cur.rowcount == 0:
                     raise UsageError('No data imported by osm2pgsql.')
 
         if drop:
-            conn.drop_table('planet_osm_nodes')
+            drop_tables(conn, 'planet_osm_nodes')
+            conn.commit()
 
     if drop and options['flatnode_file']:
         Path(options['flatnode_file']).unlink()
@@ -182,7 +187,7 @@ def truncate_data_tables(conn: Connection) -> None:
         cur.execute('TRUNCATE location_property_tiger')
         cur.execute('TRUNCATE location_property_osmline')
         cur.execute('TRUNCATE location_postcode')
-        if conn.table_exists('search_name'):
+        if table_exists(conn, 'search_name'):
             cur.execute('TRUNCATE search_name')
         cur.execute('DROP SEQUENCE IF EXISTS seq_place')
         cur.execute('CREATE SEQUENCE seq_place start 100000')
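The hunks above consistently swap Connection methods (cur.scalar(), conn.drop_table(), conn.table_exists()) for free functions imported from ..db.connection. Their implementations are not part of this diff; the approximations below, written against a psycopg 3 connection, only illustrate the shape such helpers take. The signatures follow the call sites in the patch, but the bodies are assumptions:

    from typing import Any, Optional
    import psycopg
    from psycopg import sql

    def table_exists(conn: psycopg.Connection, table: str) -> bool:
        # Assumed body: probe the pg_tables catalog for the table name.
        with conn.cursor() as cur:
            cur.execute('SELECT 1 FROM pg_tables WHERE tablename = %s',
                        (table, ))
            return cur.fetchone() is not None

    def drop_tables(conn: psycopg.Connection, *names: str) -> None:
        # Assumed body: drop each table, quoting names as identifiers.
        with conn.cursor() as cur:
            for name in names:
                cur.execute(sql.SQL('DROP TABLE IF EXISTS {}')
                               .format(sql.Identifier(name)))

    def execute_scalar(conn: psycopg.Connection, query: str,
                       params: Optional[Any] = None) -> Any:
        # Assumed body: run the query, return the first column of the
        # first row (None when the query yields no rows).
        with conn.cursor() as cur:
            cur.execute(query, params)
            row = cur.fetchone()
            return None if row is None else row[0]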
@@ -202,55 +207,52 @@ _COPY_COLUMNS = pysql.SQL(',').join(map(pysql.Identifier,
                                          'extratags', 'geometry')))
 
 
-def load_data(dsn: str, threads: int) -> None:
+async def load_data(dsn: str, threads: int) -> None:
     """ Copy data into the word and placex table.
     """
-    sel = selectors.DefaultSelector()
-    # Then copy data from place to placex in chunks.
-    place_threads = max(1, threads - 1)
-    for imod in range(place_threads):
-        conn = DBConnection(dsn)
-        conn.connect()
-        conn.perform(
-            pysql.SQL("""INSERT INTO placex ({columns})
-                           SELECT {columns} FROM place
-                           WHERE osm_id % {total} = {mod}
-                             AND NOT (class='place' and (type='houses' or type='postcode'))
-                             AND ST_IsValid(geometry)
-                      """).format(columns=_COPY_COLUMNS,
-                                  total=pysql.Literal(place_threads),
-                                  mod=pysql.Literal(imod)))
-        sel.register(conn, selectors.EVENT_READ, conn)
-
-    # Address interpolations go into another table.
-    conn = DBConnection(dsn)
-    conn.connect()
-    conn.perform("""INSERT INTO location_property_osmline (osm_id, address, linegeo)
-                      SELECT osm_id, address, geometry FROM place
-                      WHERE class='place' and type='houses' and osm_type='W'
-                            and ST_GeometryType(geometry) = 'ST_LineString'
-                 """)
-    sel.register(conn, selectors.EVENT_READ, conn)
-
-    # Now wait for all of them to finish.
-    todo = place_threads + 1
-    while todo > 0:
-        for key, _ in sel.select(1):
-            conn = key.data
-            sel.unregister(conn)
-            conn.wait()
-            conn.close()
-            todo -= 1
+    placex_threads = max(1, threads - 1)
+
+    progress = asyncio.create_task(_progress_print())
+
+    async with QueryPool(dsn, placex_threads + 1) as pool:
+        # Copy data from place to placex in chunks.
+        for imod in range(placex_threads):
+            await pool.put_query(
+                pysql.SQL("""INSERT INTO placex ({columns})
+                               SELECT {columns} FROM place
+                               WHERE osm_id % {total} = {mod}
+                                 AND NOT (class='place'
+                                          and (type='houses' or type='postcode'))
+                                 AND ST_IsValid(geometry)
+                          """).format(columns=_COPY_COLUMNS,
+                                      total=pysql.Literal(placex_threads),
+                                      mod=pysql.Literal(imod)), None)
+
+        # Interpolations need to be copied separately.
+        await pool.put_query("""
+            INSERT INTO location_property_osmline (osm_id, address, linegeo)
+              SELECT osm_id, address, geometry FROM place
+              WHERE class='place' and type='houses' and osm_type='W'
+                    and ST_GeometryType(geometry) = 'ST_LineString' """, None)
+
+    progress.cancel()
+
+    async with await psycopg.AsyncConnection.connect(dsn) as aconn:
+        await aconn.execute('ANALYSE')
+
+
+async def _progress_print() -> None:
+    while True:
+        try:
+            await asyncio.sleep(1)
+        except asyncio.CancelledError:
+            print('', flush=True)
+            break
         print('.', end='', flush=True)
-    print('\n')
-
-    with connect(dsn) as syn_conn:
-        with syn_conn.cursor() as cur:
-            cur.execute('ANALYSE')
 
 
-def create_search_indices(conn: Connection, config: Configuration,
-                          drop: bool = False, threads: int = 1) -> None:
+async def create_search_indices(conn: Connection, config: Configuration,
+                                drop: bool = False, threads: int = 1) -> None:
     """ Create tables that have explicit partitioning.
     """
@@ -268,5 +270,5 @@ def create_search_indices(conn: Connection, config: Configuration,
 
     sql = SQLPreprocessor(conn, config)
 
-    sql.run_parallel_sql_file(config.get_libpq_dsn(),
-                              'indices.sql', min(8, threads), drop=drop)
+    await sql.run_parallel_sql_file(config.get_libpq_dsn(),
+                                    'indices.sql', min(8, threads), drop=drop)
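The new load_data() drops the hand-rolled selectors loop in favour of an asyncio QueryPool that fans the copy out over several connections, each worker inserting the disjoint slice of rows matching osm_id % total = mod. QueryPool itself lives in ..db.query_pool and is not shown here; the self-contained sketch below reproduces the fan-out idea with plain psycopg.AsyncConnection and asyncio.gather. The DSN and the src/dst table names are placeholders:

    import asyncio
    import psycopg

    async def copy_slice(dsn: str, total: int, mod: int) -> None:
        # One connection per worker; each copies a disjoint slice of rows.
        # A literal '%' must be doubled when query parameters are passed.
        async with await psycopg.AsyncConnection.connect(dsn) as conn:
            await conn.execute(
                'INSERT INTO dst SELECT * FROM src WHERE osm_id %% %s = %s',
                (total, mod))

    async def copy_parallel(dsn: str, threads: int) -> None:
        # Launch all slices concurrently; each connection commits when its
        # context block exits successfully.
        await asyncio.gather(*(copy_slice(dsn, threads, imod)
                               for imod in range(threads)))

    # Example usage: asyncio.run(copy_parallel('dbname=nominatim', 3))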