git.openstreetmap.org Git - nominatim.git/blobdiff - src/nominatim_db/tools/database_import.py
fix style issue found by flake8
[nominatim.git] / src / nominatim_db / tools / database_import.py
index 2398d4043a734ceb6b3cf45971e998133fe6456a..415e9d249f3277ce95bc048c9c8858b684cfd1c3 100644 (file)
@@ -10,24 +10,26 @@ Functions for setting up and importing a new Nominatim database.
 from typing import Tuple, Optional, Union, Sequence, MutableMapping, Any
 import logging
 import os
-import selectors
 import subprocess
+import asyncio
 from pathlib import Path
 
 import psutil
-from psycopg2 import sql as pysql
+import psycopg
+from psycopg import sql as pysql
 
 from ..errors import UsageError
 from ..config import Configuration
-from ..db.connection import connect, get_pg_env, Connection, server_version_tuple,\
+from ..db.connection import connect, get_pg_env, Connection, server_version_tuple, \
                             postgis_version_tuple, drop_tables, table_exists, execute_scalar
-from ..db.async_connection import DBConnection
 from ..db.sql_preprocessor import SQLPreprocessor
+from ..db.query_pool import QueryPool
 from .exec_utils import run_osm2pgsql
 from ..version import POSTGRESQL_REQUIRED_VERSION, POSTGIS_REQUIRED_VERSION
 
 LOG = logging.getLogger()
 
+
 def _require_version(module: str, actual: Tuple[int, int], expected: Tuple[int, int]) -> None:
     """ Compares the version for the given module and raises an exception
         if the actual version is too old.
@@ -136,7 +138,7 @@ def import_osm_data(osm_files: Union[Path, Sequence[Path]],
     with connect(options['dsn']) as conn:
         if not ignore_errors:
             with conn.cursor() as cur:
-                cur.execute('SELECT * FROM place LIMIT 1')
+                cur.execute('SELECT true FROM place LIMIT 1')
                 if cur.rowcount == 0:
                     raise UsageError('No data imported by osm2pgsql.')
 
@@ -205,55 +207,52 @@ _COPY_COLUMNS = pysql.SQL(',').join(map(pysql.Identifier,
                                          'extratags', 'geometry')))
 
 
-def load_data(dsn: str, threads: int) -> None:
+async def load_data(dsn: str, threads: int) -> None:
     """ Copy data into the word and placex table.
     """
-    sel = selectors.DefaultSelector()
-    # Then copy data from place to placex in <threads - 1> chunks.
-    place_threads = max(1, threads - 1)
-    for imod in range(place_threads):
-        conn = DBConnection(dsn)
-        conn.connect()
-        conn.perform(
-            pysql.SQL("""INSERT INTO placex ({columns})
-                           SELECT {columns} FROM place
-                           WHERE osm_id % {total} = {mod}
-                             AND NOT (class='place' and (type='houses' or type='postcode'))
-                             AND ST_IsValid(geometry)
-                      """).format(columns=_COPY_COLUMNS,
-                                  total=pysql.Literal(place_threads),
-                                  mod=pysql.Literal(imod)))
-        sel.register(conn, selectors.EVENT_READ, conn)
-
-    # Address interpolations go into another table.
-    conn = DBConnection(dsn)
-    conn.connect()
-    conn.perform("""INSERT INTO location_property_osmline (osm_id, address, linegeo)
-                      SELECT osm_id, address, geometry FROM place
-                      WHERE class='place' and type='houses' and osm_type='W'
-                            and ST_GeometryType(geometry) = 'ST_LineString'
-                 """)
-    sel.register(conn, selectors.EVENT_READ, conn)
-
-    # Now wait for all of them to finish.
-    todo = place_threads + 1
-    while todo > 0:
-        for key, _ in sel.select(1):
-            conn = key.data
-            sel.unregister(conn)
-            conn.wait()
-            conn.close()
-            todo -= 1
+    placex_threads = max(1, threads - 1)
+
+    progress = asyncio.create_task(_progress_print())
+
+    async with QueryPool(dsn, placex_threads + 1) as pool:
+        # Copy data from place to placex in <threads - 1> chunks.
+        for imod in range(placex_threads):
+            await pool.put_query(
+                pysql.SQL("""INSERT INTO placex ({columns})
+                               SELECT {columns} FROM place
+                                WHERE osm_id % {total} = {mod}
+                                  AND NOT (class='place'
+                                           and (type='houses' or type='postcode'))
+                                  AND ST_IsValid(geometry)
+                          """).format(columns=_COPY_COLUMNS,
+                                      total=pysql.Literal(placex_threads),
+                                      mod=pysql.Literal(imod)), None)
+
+        # Interpolations need to be copied separately
+        await pool.put_query("""
+                INSERT INTO location_property_osmline (osm_id, address, linegeo)
+                  SELECT osm_id, address, geometry FROM place
+                  WHERE class='place' and type='houses' and osm_type='W'
+                        and ST_GeometryType(geometry) = 'ST_LineString' """, None)
+
+    progress.cancel()
+
+    async with await psycopg.AsyncConnection.connect(dsn) as aconn:
+        await aconn.execute('ANALYSE')
+
+
+async def _progress_print() -> None:
+    while True:
+        try:
+            await asyncio.sleep(1)
+        except asyncio.CancelledError:
+            print('', flush=True)
+            break
         print('.', end='', flush=True)
-    print('\n')
-
-    with connect(dsn) as syn_conn:
-        with syn_conn.cursor() as cur:
-            cur.execute('ANALYSE')
 
 
-def create_search_indices(conn: Connection, config: Configuration,
-                          drop: bool = False, threads: int = 1) -> None:
+async def create_search_indices(conn: Connection, config: Configuration,
+                                drop: bool = False, threads: int = 1) -> None:
     """ Create tables that have explicit partitioning.
     """
 
@@ -271,5 +270,5 @@ def create_search_indices(conn: Connection, config: Configuration,
 
     sql = SQLPreprocessor(conn, config)
 
-    sql.run_parallel_sql_file(config.get_libpq_dsn(),
-                              'indices.sql', min(8, threads), drop=drop)
+    await sql.run_parallel_sql_file(config.get_libpq_dsn(),
+                                    'indices.sql', min(8, threads), drop=drop)