from typing import Tuple, Optional, Union, Sequence, MutableMapping, Any
import logging
import os
-import selectors
import subprocess
+import asyncio
from pathlib import Path
import psutil
-from psycopg2 import sql as pysql
+import psycopg
+from psycopg import sql as pysql
from ..errors import UsageError
from ..config import Configuration
-from ..db.connection import connect, get_pg_env, Connection
-from ..db.async_connection import DBConnection
+from ..db.connection import connect, get_pg_env, Connection, server_version_tuple, \
+ postgis_version_tuple, drop_tables, table_exists, execute_scalar
from ..db.sql_preprocessor import SQLPreprocessor
+from ..db.query_pool import QueryPool
from .exec_utils import run_osm2pgsql
from ..version import POSTGRESQL_REQUIRED_VERSION, POSTGIS_REQUIRED_VERSION
LOG = logging.getLogger()
+
def _require_version(module: str, actual: Tuple[int, int], expected: Tuple[int, int]) -> None:
""" Compares the version for the given module and raises an exception
if the actual version is too old.
def _require_loaded(extension_name: str, conn: Connection) -> None:
""" Check that the given extension is loaded. """
- if not conn.extension_loaded(extension_name):
- LOG.fatal('Required module %s is not loaded.', extension_name)
- raise UsageError(f'{extension_name} is not loaded.')
+    # Look the extension up in the pg_extension catalogue. Counting via
+    # execute_scalar mirrors the pg_user check for the read-only user and
+    # avoids transferring whole catalogue rows (SELECT *) just to test for
+    # existence or relying on cursor.rowcount after a SELECT.
+    cnt = execute_scalar(conn, 'SELECT count(*) FROM pg_extension WHERE extname = %s',
+                         (extension_name, ))
+    if cnt == 0:
+        LOG.fatal('Required module %s is not loaded.', extension_name)
+        raise UsageError(f'{extension_name} is not loaded.')
def check_existing_database_plugins(dsn: str) -> None:
""" Check that the database has the required plugins installed."""
+    # Checks, in order: the PostgreSQL server version against
+    # POSTGRESQL_REQUIRED_VERSION, the PostGIS version against
+    # POSTGIS_REQUIRED_VERSION, and that the hstore extension is loaded.
+    # The first failing check raises (UsageError for a missing extension).
with connect(dsn) as conn:
_require_version('PostgreSQL server',
- conn.server_version_tuple(),
+ server_version_tuple(conn),
POSTGRESQL_REQUIRED_VERSION)
_require_version('PostGIS',
- conn.postgis_version_tuple(),
+ postgis_version_tuple(conn),
POSTGIS_REQUIRED_VERSION)
_require_loaded('hstore', conn)
with connect(dsn) as conn:
_require_version('PostgreSQL server',
- conn.server_version_tuple(),
+ server_version_tuple(conn),
POSTGRESQL_REQUIRED_VERSION)
if rouser is not None:
- with conn.cursor() as cur:
- cnt = cur.scalar('SELECT count(*) FROM pg_user where usename = %s',
+ cnt = execute_scalar(conn, 'SELECT count(*) FROM pg_user where usename = %s',
(rouser, ))
- if cnt == 0:
- LOG.fatal("Web user '%s' does not exist. Create it with:\n"
- "\n createuser %s", rouser, rouser)
- raise UsageError('Missing read-only user.')
+ if cnt == 0:
+ LOG.fatal("Web user '%s' does not exist. Create it with:\n"
+ "\n createuser %s", rouser, rouser)
+ raise UsageError('Missing read-only user.')
# Create extensions.
with conn.cursor() as cur:
cur.execute('CREATE EXTENSION IF NOT EXISTS hstore')
cur.execute('CREATE EXTENSION IF NOT EXISTS postgis')
-
- postgis_version = conn.postgis_version_tuple()
- if postgis_version[0] >= 3:
- cur.execute('CREATE EXTENSION IF NOT EXISTS postgis_raster')
+ cur.execute('CREATE EXTENSION IF NOT EXISTS postgis_raster')
conn.commit()
_require_version('PostGIS',
- conn.postgis_version_tuple(),
+ postgis_version_tuple(conn),
POSTGIS_REQUIRED_VERSION)
with connect(options['dsn']) as conn:
if not ignore_errors:
with conn.cursor() as cur:
- cur.execute('SELECT * FROM place LIMIT 1')
+ cur.execute('SELECT true FROM place LIMIT 1')
if cur.rowcount == 0:
raise UsageError('No data imported by osm2pgsql.')
if drop:
- conn.drop_table('planet_osm_nodes')
+ drop_tables(conn, 'planet_osm_nodes')
+ conn.commit()
if drop and options['flatnode_file']:
Path(options['flatnode_file']).unlink()
cur.execute('TRUNCATE location_property_tiger')
cur.execute('TRUNCATE location_property_osmline')
cur.execute('TRUNCATE location_postcode')
- if conn.table_exists('search_name'):
+ if table_exists(conn, 'search_name'):
cur.execute('TRUNCATE search_name')
cur.execute('DROP SEQUENCE IF EXISTS seq_place')
cur.execute('CREATE SEQUENCE seq_place start 100000')
'extratags', 'geometry')))
-def load_data(dsn: str, threads: int) -> None:
+async def load_data(dsn: str, threads: int) -> None:
- """ Copy data into the word and placex table.
- """
- sel = selectors.DefaultSelector()
- # Then copy data from place to placex in <threads - 1> chunks.
- place_threads = max(1, threads - 1)
- for imod in range(place_threads):
- conn = DBConnection(dsn)
- conn.connect()
- conn.perform(
- pysql.SQL("""INSERT INTO placex ({columns})
- SELECT {columns} FROM place
- WHERE osm_id % {total} = {mod}
- AND NOT (class='place' and (type='houses' or type='postcode'))
- AND ST_IsValid(geometry)
- """).format(columns=_COPY_COLUMNS,
- total=pysql.Literal(place_threads),
- mod=pysql.Literal(imod)))
- sel.register(conn, selectors.EVENT_READ, conn)
-
- # Address interpolations go into another table.
- conn = DBConnection(dsn)
- conn.connect()
- conn.perform("""INSERT INTO location_property_osmline (osm_id, address, linegeo)
- SELECT osm_id, address, geometry FROM place
- WHERE class='place' and type='houses' and osm_type='W'
- and ST_GeometryType(geometry) = 'ST_LineString'
- """)
- sel.register(conn, selectors.EVENT_READ, conn)
-
- # Now wait for all of them to finish.
- todo = place_threads + 1
- while todo > 0:
- for key, _ in sel.select(1):
- conn = key.data
- sel.unregister(conn)
- conn.wait()
- conn.close()
- todo -= 1
+    """ Copy data from the place table into placex and
+        location_property_osmline, then run ANALYSE.
+
+        The placex copy is split by `osm_id % placex_threads` into
+        max(1, threads - 1) chunks; interpolation lines (class 'place',
+        type 'houses', ways with LineString geometry) are copied with one
+        additional query. All queries are submitted to a single QueryPool.
+    """
+    placex_threads = max(1, threads - 1)
+
+    # Print a dot every second while the copy is running.
+    progress = asyncio.create_task(_progress_print())
+
+    try:
+        async with QueryPool(dsn, placex_threads + 1) as pool:
+            # Copy data from place to placex in <threads - 1> chunks.
+            for imod in range(placex_threads):
+                await pool.put_query(
+                    pysql.SQL("""INSERT INTO placex ({columns})
+                                   SELECT {columns} FROM place
+                                   WHERE osm_id % {total} = {mod}
+                                     AND NOT (class='place'
+                                              and (type='houses' or type='postcode'))
+                                     AND ST_IsValid(geometry)
+                              """).format(columns=_COPY_COLUMNS,
+                                          total=pysql.Literal(placex_threads),
+                                          mod=pysql.Literal(imod)), None)
+
+            # Interpolations need to be copied separately.
+            await pool.put_query("""
+                INSERT INTO location_property_osmline (osm_id, address, linegeo)
+                  SELECT osm_id, address, geometry FROM place
+                  WHERE class='place' and type='houses' and osm_type='W'
+                        and ST_GeometryType(geometry) = 'ST_LineString' """, None)
+    finally:
+        # Stop the progress printer even when the copy fails; without the
+        # finally the task would be left running after load_data raised.
+        progress.cancel()
+        # Wait for it so its closing newline is written before anything
+        # else. With return_exceptions=True, gather() does not raise if the
+        # task was cancelled before it ever started running.
+        await asyncio.gather(progress, return_exceptions=True)
+
+    async with await psycopg.AsyncConnection.connect(dsn) as aconn:
+        await aconn.execute('ANALYSE')
+
+
+async def _progress_print() -> None:
+    """ Print a dot to stdout once per second until the task is cancelled,
+        then terminate the line with a newline and return normally.
+    """
+    while True:
+        try:
+            await asyncio.sleep(1)
+        except asyncio.CancelledError:
+            print('', flush=True)
+            break
print('.', end='', flush=True)
- print('\n')
-
- with connect(dsn) as syn_conn:
- with syn_conn.cursor() as cur:
- cur.execute('ANALYSE')
-def create_search_indices(conn: Connection, config: Configuration,
- drop: bool = False, threads: int = 1) -> None:
+async def create_search_indices(conn: Connection, config: Configuration,
+                                drop: bool = False, threads: int = 1) -> None:
- """ Create tables that have explicit partitioning.
- """
+    """ Create the search indices by running the 'indices.sql' script.
+
+        The script is executed through SQLPreprocessor.run_parallel_sql_file()
+        against the configured libpq DSN, with min(8, threads) as the
+        parallelism argument. `drop` is forwarded to it as a keyword argument.
+    """
sql = SQLPreprocessor(conn, config)
- sql.run_parallel_sql_file(config.get_libpq_dsn(),
- 'indices.sql', min(8, threads), drop=drop)
+    await sql.run_parallel_sql_file(config.get_libpq_dsn(),
+                                    'indices.sql', min(8, threads), drop=drop)