]> git.openstreetmap.org Git - nominatim.git/blobdiff - src/nominatim_db/tools/database_import.py
Merge remote-tracking branch 'upstream/master'
[nominatim.git] / src / nominatim_db / tools / database_import.py
index d07febc8a3da97fc5e41b8b243e3e8ba0f703de1..415e9d249f3277ce95bc048c9c8858b684cfd1c3 100644 (file)
@@ -10,23 +10,26 @@ Functions for setting up and importing a new Nominatim database.
 from typing import Tuple, Optional, Union, Sequence, MutableMapping, Any
 import logging
 import os
-import selectors
 import subprocess
+import asyncio
 from pathlib import Path
 
 import psutil
-from psycopg2 import sql as pysql
+import psycopg
+from psycopg import sql as pysql
 
 from ..errors import UsageError
 from ..config import Configuration
-from ..db.connection import connect, get_pg_env, Connection
-from ..db.async_connection import DBConnection
+from ..db.connection import connect, get_pg_env, Connection, server_version_tuple, \
+                            postgis_version_tuple, drop_tables, table_exists, execute_scalar
 from ..db.sql_preprocessor import SQLPreprocessor
+from ..db.query_pool import QueryPool
 from .exec_utils import run_osm2pgsql
 from ..version import POSTGRESQL_REQUIRED_VERSION, POSTGIS_REQUIRED_VERSION
 
 LOG = logging.getLogger()
 
+
 def _require_version(module: str, actual: Tuple[int, int], expected: Tuple[int, int]) -> None:
     """ Compares the version for the given module and raises an exception
         if the actual version is too old.
@@ -40,19 +43,21 @@ def _require_version(module: str, actual: Tuple[int, int], expected: Tuple[int,
 
 def _require_loaded(extension_name: str, conn: Connection) -> None:
     """ Check that the given extension is loaded. """
-    if not conn.extension_loaded(extension_name):
-        LOG.fatal('Required module %s is not loaded.', extension_name)
-        raise UsageError(f'{extension_name} is not loaded.')
+    with conn.cursor() as cur:
+        cur.execute('SELECT * FROM pg_extension WHERE extname = %s', (extension_name, ))
+        if cur.rowcount <= 0:  # no pg_extension row with this name -> treat as not loaded
+            LOG.fatal('Required module %s is not loaded.', extension_name)
+            raise UsageError(f'{extension_name} is not loaded.')
 
 
 def check_existing_database_plugins(dsn: str) -> None:
     """ Check that the database has the required plugins installed."""
     with connect(dsn) as conn:
         _require_version('PostgreSQL server',
-                         conn.server_version_tuple(),
+                         server_version_tuple(conn),
                          POSTGRESQL_REQUIRED_VERSION)
         _require_version('PostGIS',
-                         conn.postgis_version_tuple(),
+                         postgis_version_tuple(conn),
                          POSTGIS_REQUIRED_VERSION)
         _require_loaded('hstore', conn)
 
@@ -78,31 +83,30 @@ def setup_database_skeleton(dsn: str, rouser: Optional[str] = None) -> None:
 
     with connect(dsn) as conn:
         _require_version('PostgreSQL server',
-                         conn.server_version_tuple(),
+                         server_version_tuple(conn),
                          POSTGRESQL_REQUIRED_VERSION)
 
         if rouser is not None:
-            with conn.cursor() as cur:
-                cnt = cur.scalar('SELECT count(*) FROM pg_user where usename = %s',
+            cnt = execute_scalar(conn, 'SELECT count(*) FROM pg_user where usename = %s',
                                  (rouser, ))
-                if cnt == 0:
-                    LOG.fatal("Web user '%s' does not exist. Create it with:\n"
-                              "\n      createuser %s", rouser, rouser)
-                    raise UsageError('Missing read-only user.')
+            if cnt == 0:
+                LOG.fatal("Web user '%s' does not exist. Create it with:\n"
+                          "\n      createuser %s", rouser, rouser)
+                raise UsageError('Missing read-only user.')
 
         # Create extensions.
         with conn.cursor() as cur:
             cur.execute('CREATE EXTENSION IF NOT EXISTS hstore')
             cur.execute('CREATE EXTENSION IF NOT EXISTS postgis')
 
-            postgis_version = conn.postgis_version_tuple()
+            postgis_version = postgis_version_tuple(conn)
             if postgis_version[0] >= 3:
                 cur.execute('CREATE EXTENSION IF NOT EXISTS postgis_raster')
 
         conn.commit()
 
         _require_version('PostGIS',
-                         conn.postgis_version_tuple(),
+                         postgis_version_tuple(conn),
                          POSTGIS_REQUIRED_VERSION)
 
 
@@ -134,12 +138,13 @@ def import_osm_data(osm_files: Union[Path, Sequence[Path]],
     with connect(options['dsn']) as conn:
         if not ignore_errors:
             with conn.cursor() as cur:
-                cur.execute('SELECT * FROM place LIMIT 1')
+                cur.execute('SELECT true FROM place LIMIT 1')
                 if cur.rowcount == 0:
                     raise UsageError('No data imported by osm2pgsql.')
 
         if drop:
-            conn.drop_table('planet_osm_nodes')
+            drop_tables(conn, 'planet_osm_nodes')
+            conn.commit()
 
     if drop and options['flatnode_file']:
         Path(options['flatnode_file']).unlink()
@@ -182,7 +187,7 @@ def truncate_data_tables(conn: Connection) -> None:
         cur.execute('TRUNCATE location_property_tiger')
         cur.execute('TRUNCATE location_property_osmline')
         cur.execute('TRUNCATE location_postcode')
-        if conn.table_exists('search_name'):
+        if table_exists(conn, 'search_name'):
             cur.execute('TRUNCATE search_name')
         cur.execute('DROP SEQUENCE IF EXISTS seq_place')
         cur.execute('CREATE SEQUENCE seq_place start 100000')
@@ -202,55 +207,52 @@ _COPY_COLUMNS = pysql.SQL(',').join(map(pysql.Identifier,
                                          'extratags', 'geometry')))
 
 
-def load_data(dsn: str, threads: int) -> None:
-    """ Copy data into the word and placex table.
-    """
+async def load_data(dsn: str, threads: int) -> None:
+    """ Copy data from place into placex and location_property_osmline, then ANALYSE.
+    """
-    sel = selectors.DefaultSelector()
-    # Then copy data from place to placex in <threads - 1> chunks.
-    place_threads = max(1, threads - 1)
-    for imod in range(place_threads):
-        conn = DBConnection(dsn)
-        conn.connect()
-        conn.perform(
-            pysql.SQL("""INSERT INTO placex ({columns})
-                           SELECT {columns} FROM place
-                           WHERE osm_id % {total} = {mod}
-                             AND NOT (class='place' and (type='houses' or type='postcode'))
-                             AND ST_IsValid(geometry)
-                      """).format(columns=_COPY_COLUMNS,
-                                  total=pysql.Literal(place_threads),
-                                  mod=pysql.Literal(imod)))
-        sel.register(conn, selectors.EVENT_READ, conn)
-
-    # Address interpolations go into another table.
-    conn = DBConnection(dsn)
-    conn.connect()
-    conn.perform("""INSERT INTO location_property_osmline (osm_id, address, linegeo)
-                      SELECT osm_id, address, geometry FROM place
-                      WHERE class='place' and type='houses' and osm_type='W'
-                            and ST_GeometryType(geometry) = 'ST_LineString'
-                 """)
-    sel.register(conn, selectors.EVENT_READ, conn)
-
-    # Now wait for all of them to finish.
-    todo = place_threads + 1
-    while todo > 0:
-        for key, _ in sel.select(1):
-            conn = key.data
-            sel.unregister(conn)
-            conn.wait()
-            conn.close()
-            todo -= 1
+    placex_threads = max(1, threads - 1)
+
+    progress = asyncio.create_task(_progress_print())
+    try:
+        async with QueryPool(dsn, placex_threads + 1) as pool:
+            # Copy data from place to placex in <threads - 1> chunks.
+            for imod in range(placex_threads):
+                await pool.put_query(
+                    pysql.SQL("""INSERT INTO placex ({columns})
+                                   SELECT {columns} FROM place
+                                    WHERE osm_id % {total} = {mod}
+                                      AND NOT (class='place'
+                                               and (type='houses' or type='postcode'))
+                                      AND ST_IsValid(geometry)
+                              """).format(columns=_COPY_COLUMNS,
+                                          total=pysql.Literal(placex_threads),
+                                          mod=pysql.Literal(imod)), None)
+
+            # Interpolations need to be copied separately.
+            await pool.put_query("""
+                    INSERT INTO location_property_osmline (osm_id, address, linegeo)
+                      SELECT osm_id, address, geometry FROM place
+                      WHERE class='place' and type='houses' and osm_type='W'
+                            and ST_GeometryType(geometry) = 'ST_LineString' """, None)
+    finally:
+        progress.cancel()  # stop the dots even if one of the copy queries raised
+
+    async with await psycopg.AsyncConnection.connect(dsn) as aconn:
+        await aconn.execute('ANALYSE')
+
+
+async def _progress_print() -> None:
+    """ Print a dot every second until cancelled, then end the output line. """
+    try:
+        while True:
+            await asyncio.sleep(1)
+            print('.', end='', flush=True)
+    except asyncio.CancelledError:
+        print('', flush=True)
-        print('.', end='', flush=True)
-    print('\n')
-
-    with connect(dsn) as syn_conn:
-        with syn_conn.cursor() as cur:
-            cur.execute('ANALYSE')
 
 
-def create_search_indices(conn: Connection, config: Configuration,
-                          drop: bool = False, threads: int = 1) -> None:
+async def create_search_indices(conn: Connection, config: Configuration,
+                                drop: bool = False, threads: int = 1) -> None:
     """ Create tables that have explicit partitioning.
     """
 
@@ -268,5 +270,5 @@ def create_search_indices(conn: Connection, config: Configuration,
 
     sql = SQLPreprocessor(conn, config)
 
-    sql.run_parallel_sql_file(config.get_libpq_dsn(),
-                              'indices.sql', min(8, threads), drop=drop)
+    await sql.run_parallel_sql_file(config.get_libpq_dsn(),
+                                    'indices.sql', min(8, threads), drop=drop)