X-Git-Url: https://git.openstreetmap.org./nominatim.git/blobdiff_plain/3742fa2929619a4c54a50d3e79e0eeadb4d6ca6f..2735ea768aa812998a9498cf411563f118bd6ad6:/src/nominatim_db/tools/database_import.py diff --git a/src/nominatim_db/tools/database_import.py b/src/nominatim_db/tools/database_import.py index 2398d404..415e9d24 100644 --- a/src/nominatim_db/tools/database_import.py +++ b/src/nominatim_db/tools/database_import.py @@ -10,24 +10,26 @@ Functions for setting up and importing a new Nominatim database. from typing import Tuple, Optional, Union, Sequence, MutableMapping, Any import logging import os -import selectors import subprocess +import asyncio from pathlib import Path import psutil -from psycopg2 import sql as pysql +import psycopg +from psycopg import sql as pysql from ..errors import UsageError from ..config import Configuration -from ..db.connection import connect, get_pg_env, Connection, server_version_tuple,\ +from ..db.connection import connect, get_pg_env, Connection, server_version_tuple, \ postgis_version_tuple, drop_tables, table_exists, execute_scalar -from ..db.async_connection import DBConnection from ..db.sql_preprocessor import SQLPreprocessor +from ..db.query_pool import QueryPool from .exec_utils import run_osm2pgsql from ..version import POSTGRESQL_REQUIRED_VERSION, POSTGIS_REQUIRED_VERSION LOG = logging.getLogger() + def _require_version(module: str, actual: Tuple[int, int], expected: Tuple[int, int]) -> None: """ Compares the version for the given module and raises an exception if the actual version is too old. @@ -136,7 +138,7 @@ def import_osm_data(osm_files: Union[Path, Sequence[Path]], with connect(options['dsn']) as conn: if not ignore_errors: with conn.cursor() as cur: - cur.execute('SELECT * FROM place LIMIT 1') + cur.execute('SELECT true FROM place LIMIT 1') if cur.rowcount == 0: raise UsageError('No data imported by osm2pgsql.') @@ -205,55 +207,52 @@ _COPY_COLUMNS = pysql.SQL(',').join(map(pysql.Identifier, 'extratags', 'geometry'))) -def load_data(dsn: str, threads: int) -> None: +async def load_data(dsn: str, threads: int) -> None: """ Copy data into the word and placex table. """ - sel = selectors.DefaultSelector() - # Then copy data from place to placex in chunks. - place_threads = max(1, threads - 1) - for imod in range(place_threads): - conn = DBConnection(dsn) - conn.connect() - conn.perform( - pysql.SQL("""INSERT INTO placex ({columns}) - SELECT {columns} FROM place - WHERE osm_id % {total} = {mod} - AND NOT (class='place' and (type='houses' or type='postcode')) - AND ST_IsValid(geometry) - """).format(columns=_COPY_COLUMNS, - total=pysql.Literal(place_threads), - mod=pysql.Literal(imod))) - sel.register(conn, selectors.EVENT_READ, conn) - - # Address interpolations go into another table. - conn = DBConnection(dsn) - conn.connect() - conn.perform("""INSERT INTO location_property_osmline (osm_id, address, linegeo) - SELECT osm_id, address, geometry FROM place - WHERE class='place' and type='houses' and osm_type='W' - and ST_GeometryType(geometry) = 'ST_LineString' - """) - sel.register(conn, selectors.EVENT_READ, conn) - - # Now wait for all of them to finish. - todo = place_threads + 1 - while todo > 0: - for key, _ in sel.select(1): - conn = key.data - sel.unregister(conn) - conn.wait() - conn.close() - todo -= 1 + placex_threads = max(1, threads - 1) + + progress = asyncio.create_task(_progress_print()) + + async with QueryPool(dsn, placex_threads + 1) as pool: + # Copy data from place to placex in chunks. + for imod in range(placex_threads): + await pool.put_query( + pysql.SQL("""INSERT INTO placex ({columns}) + SELECT {columns} FROM place + WHERE osm_id % {total} = {mod} + AND NOT (class='place' + and (type='houses' or type='postcode')) + AND ST_IsValid(geometry) + """).format(columns=_COPY_COLUMNS, + total=pysql.Literal(placex_threads), + mod=pysql.Literal(imod)), None) + + # Interpolations need to be copied seperately + await pool.put_query(""" + INSERT INTO location_property_osmline (osm_id, address, linegeo) + SELECT osm_id, address, geometry FROM place + WHERE class='place' and type='houses' and osm_type='W' + and ST_GeometryType(geometry) = 'ST_LineString' """, None) + + progress.cancel() + + async with await psycopg.AsyncConnection.connect(dsn) as aconn: + await aconn.execute('ANALYSE') + + +async def _progress_print() -> None: + while True: + try: + await asyncio.sleep(1) + except asyncio.CancelledError: + print('', flush=True) + break print('.', end='', flush=True) - print('\n') - - with connect(dsn) as syn_conn: - with syn_conn.cursor() as cur: - cur.execute('ANALYSE') -def create_search_indices(conn: Connection, config: Configuration, - drop: bool = False, threads: int = 1) -> None: +async def create_search_indices(conn: Connection, config: Configuration, + drop: bool = False, threads: int = 1) -> None: """ Create tables that have explicit partitioning. """ @@ -271,5 +270,5 @@ def create_search_indices(conn: Connection, config: Configuration, sql = SQLPreprocessor(conn, config) - sql.run_parallel_sql_file(config.get_libpq_dsn(), - 'indices.sql', min(8, threads), drop=drop) + await sql.run_parallel_sql_file(config.get_libpq_dsn(), + 'indices.sql', min(8, threads), drop=drop)