from typing import Tuple, Optional, Union, Sequence, MutableMapping, Any
import logging
import os
-import selectors
import subprocess
+import asyncio
from pathlib import Path
import psutil
-from psycopg2 import sql as pysql
+import psycopg
+from psycopg import sql as pysql
from ..errors import UsageError
from ..config import Configuration
-from ..db.connection import connect, get_pg_env, Connection, server_version_tuple,\
+from ..db.connection import connect, get_pg_env, Connection, server_version_tuple, \
postgis_version_tuple, drop_tables, table_exists, execute_scalar
-from ..db.async_connection import DBConnection
from ..db.sql_preprocessor import SQLPreprocessor
+from ..db.query_pool import QueryPool
from .exec_utils import run_osm2pgsql
from ..version import POSTGRESQL_REQUIRED_VERSION, POSTGIS_REQUIRED_VERSION
LOG = logging.getLogger()
+
def _require_version(module: str, actual: Tuple[int, int], expected: Tuple[int, int]) -> None:
""" Compares the version for the given module and raises an exception
if the actual version is too old.
with conn.cursor() as cur:
cur.execute('CREATE EXTENSION IF NOT EXISTS hstore')
cur.execute('CREATE EXTENSION IF NOT EXISTS postgis')
-
- postgis_version = postgis_version_tuple(conn)
- if postgis_version[0] >= 3:
- cur.execute('CREATE EXTENSION IF NOT EXISTS postgis_raster')
+ cur.execute('CREATE EXTENSION IF NOT EXISTS postgis_raster')
conn.commit()
with connect(options['dsn']) as conn:
if not ignore_errors:
with conn.cursor() as cur:
- cur.execute('SELECT * FROM place LIMIT 1')
+ cur.execute('SELECT true FROM place LIMIT 1')
if cur.rowcount == 0:
raise UsageError('No data imported by osm2pgsql.')
'extratags', 'geometry')))
-def load_data(dsn: str, threads: int) -> None:
+async def load_data(dsn: str, threads: int) -> None:
-    """ Copy data into the word and placex table.
+    """ Copy data from the place table into placex and
+        location_property_osmline, then run ANALYSE on the database.
+
+        The placex copy is split by `osm_id % max(1, threads - 1)` into
+        chunks that are queued on a QueryPool together with the copy of
+        address interpolation ways. Objects with class 'place' and type
+        'houses' or 'postcode', and invalid geometries, are not copied
+        to placex. A dot is printed roughly once per second while the
+        copy is running.
     """
-    sel = selectors.DefaultSelector()
-    # Then copy data from place to placex in <threads - 1> chunks.
-    place_threads = max(1, threads - 1)
-    for imod in range(place_threads):
-        conn = DBConnection(dsn)
-        conn.connect()
-        conn.perform(
-            pysql.SQL("""INSERT INTO placex ({columns})
-                           SELECT {columns} FROM place
-                           WHERE osm_id % {total} = {mod}
-                             AND NOT (class='place' and (type='houses' or type='postcode'))
-                             AND ST_IsValid(geometry)
-                      """).format(columns=_COPY_COLUMNS,
-                                  total=pysql.Literal(place_threads),
-                                  mod=pysql.Literal(imod)))
-        sel.register(conn, selectors.EVENT_READ, conn)
-
-    # Address interpolations go into another table.
-    conn = DBConnection(dsn)
-    conn.connect()
-    conn.perform("""INSERT INTO location_property_osmline (osm_id, address, linegeo)
-                      SELECT osm_id, address, geometry FROM place
-                      WHERE class='place' and type='houses' and osm_type='W'
-                            and ST_GeometryType(geometry) = 'ST_LineString'
-                 """)
-    sel.register(conn, selectors.EVENT_READ, conn)
-
-    # Now wait for all of them to finish.
-    todo = place_threads + 1
-    while todo > 0:
-        for key, _ in sel.select(1):
-            conn = key.data
-            sel.unregister(conn)
-            conn.wait()
-            conn.close()
-            todo -= 1
+    placex_threads = max(1, threads - 1)
+
+    # Background task that prints progress dots while the copy runs.
+    progress = asyncio.create_task(_progress_print())
+
+    try:
+        async with QueryPool(dsn, placex_threads + 1) as pool:
+            # Copy data from place to placex in <threads - 1> chunks.
+            for imod in range(placex_threads):
+                await pool.put_query(
+                    pysql.SQL("""INSERT INTO placex ({columns})
+                                   SELECT {columns} FROM place
+                                   WHERE osm_id % {total} = {mod}
+                                     AND NOT (class='place'
+                                              and (type='houses' or type='postcode'))
+                                     AND ST_IsValid(geometry)
+                              """).format(columns=_COPY_COLUMNS,
+                                          total=pysql.Literal(placex_threads),
+                                          mod=pysql.Literal(imod)), None)
+
+            # Interpolations need to be copied separately.
+            await pool.put_query("""
+                INSERT INTO location_property_osmline (osm_id, address, linegeo)
+                  SELECT osm_id, address, geometry FROM place
+                  WHERE class='place' and type='houses' and osm_type='W'
+                        and ST_GeometryType(geometry) = 'ST_LineString' """, None)
+    finally:
+        # Stop the progress task even when a query fails; otherwise it keeps
+        # printing dots while the error propagates. Awaiting it (exceptions
+        # collected, not raised) makes sure it has finished before we go on.
+        progress.cancel()
+        await asyncio.gather(progress, return_exceptions=True)
+
+    async with await psycopg.AsyncConnection.connect(dsn) as aconn:
+        await aconn.execute('ANALYSE')
+
+
+async def _progress_print() -> None:
+    """ Print a dot to stdout about once per second. When the task is
+        cancelled, print a newline and return normally.
+    """
+    while True:
+        try:
+            await asyncio.sleep(1)
+        except asyncio.CancelledError:
+            print('', flush=True)
+            break
         print('.', end='', flush=True)
-    print('\n')
-
-    with connect(dsn) as syn_conn:
-        with syn_conn.cursor() as cur:
-            cur.execute('ANALYSE')
-def create_search_indices(conn: Connection, config: Configuration,
-                          drop: bool = False, threads: int = 1) -> None:
-    """ Create tables that have explicit partitioning.
+async def create_search_indices(conn: Connection, config: Configuration,
+                                drop: bool = False, threads: int = 1) -> None:
+    """ Create the indices defined in the SQL file 'indices.sql'.
+
+        The file is run through SQLPreprocessor.run_parallel_sql_file
+        against the configured libpq DSN with a parallelism of
+        min(8, threads). `drop` is forwarded unchanged to that call.
     """
     sql = SQLPreprocessor(conn, config)
-    sql.run_parallel_sql_file(config.get_libpq_dsn(),
-                              'indices.sql', min(8, threads), drop=drop)
+    await sql.run_parallel_sql_file(config.get_libpq_dsn(),
+                                    'indices.sql', min(8, threads), drop=drop)