From: Sarah Hoffmann Date: Fri, 5 Jul 2024 08:43:10 +0000 (+0200) Subject: port code to psycopg3 X-Git-Tag: deploy~4^2~8^2~2 X-Git-Url: https://git.openstreetmap.org./nominatim.git/commitdiff_plain/9659afbade47d1ec6d5359b2b21e1e874516ed80?ds=inline;hp=3742fa2929619a4c54a50d3e79e0eeadb4d6ca6f port code to psycopg3 --- diff --git a/lib-sql/indices.sql b/lib-sql/indices.sql index 8c176fdf..4d92452d 100644 --- a/lib-sql/indices.sql +++ b/lib-sql/indices.sql @@ -31,6 +31,7 @@ CREATE INDEX IF NOT EXISTS idx_placex_geometry ON placex -- Index is needed during import but can be dropped as soon as a full -- geometry index is in place. The partial index is almost as big as the full -- index. +--- DROP INDEX IF EXISTS idx_placex_geometry_lower_rank_ways; --- CREATE INDEX IF NOT EXISTS idx_placex_geometry_reverse_lookupPolygon @@ -60,7 +61,6 @@ CREATE INDEX IF NOT EXISTS idx_postcode_postcode --- DROP INDEX IF EXISTS idx_placex_geometry_address_area_candidates; DROP INDEX IF EXISTS idx_placex_geometry_buildings; - DROP INDEX IF EXISTS idx_placex_geometry_lower_rank_ways; DROP INDEX IF EXISTS idx_placex_wikidata; DROP INDEX IF EXISTS idx_placex_rank_address_sector; DROP INDEX IF EXISTS idx_placex_rank_boundaries_sector; diff --git a/packaging/nominatim-db/pyproject.toml b/packaging/nominatim-db/pyproject.toml index 652f683f..112f5a29 100644 --- a/packaging/nominatim-db/pyproject.toml +++ b/packaging/nominatim-db/pyproject.toml @@ -15,7 +15,7 @@ classifiers = [ "Operating System :: OS Independent", ] dependencies = [ - "psycopg2-binary", + "psycopg[pool]", "python-dotenv", "jinja2", "pyYAML>=5.1", diff --git a/src/nominatim_api/core.py b/src/nominatim_api/core.py index 632c97a7..c460d98c 100644 --- a/src/nominatim_api/core.py +++ b/src/nominatim_api/core.py @@ -7,7 +7,7 @@ """ Implementation of classes for API access via libraries. 
""" -from typing import Mapping, Optional, Any, AsyncIterator, Dict, Sequence, List, Tuple +from typing import Mapping, Optional, Any, AsyncIterator, Dict, Sequence, List, Tuple, cast import asyncio import sys import contextlib @@ -107,16 +107,16 @@ class NominatimAPIAsync: #pylint: disable=too-many-instance-attributes raise UsageError(f"SQlite database '{params.get('dbname')}' does not exist.") else: dsn = self.config.get_database_params() - query = {k: v for k, v in dsn.items() + query = {k: str(v) for k, v in dsn.items() if k not in ('user', 'password', 'dbname', 'host', 'port')} dburl = sa.engine.URL.create( f'postgresql+{PGCORE_LIB}', - database=dsn.get('dbname'), - username=dsn.get('user'), - password=dsn.get('password'), - host=dsn.get('host'), - port=int(dsn['port']) if 'port' in dsn else None, + database=cast(str, dsn.get('dbname')), + username=cast(str, dsn.get('user')), + password=cast(str, dsn.get('password')), + host=cast(str, dsn.get('host')), + port=int(cast(str, dsn['port'])) if 'port' in dsn else None, query=query) engine = sa_asyncio.create_async_engine(dburl, **extra_args) diff --git a/src/nominatim_db/cli.py b/src/nominatim_db/cli.py index 41684fa1..93278668 100644 --- a/src/nominatim_db/cli.py +++ b/src/nominatim_db/cli.py @@ -14,6 +14,7 @@ import logging import os import sys import argparse +import asyncio from pathlib import Path from .config import Configuration @@ -170,22 +171,30 @@ class AdminServe: raise UsageError("PHP frontend not configured.") run_php_server(args.server, args.project_dir / 'website') else: - import uvicorn # pylint: disable=import-outside-toplevel - server_info = args.server.split(':', 1) - host = server_info[0] - if len(server_info) > 1: - if not server_info[1].isdigit(): - raise UsageError('Invalid format for --server parameter. 
Use <host>:<port>') - port = int(server_info[1]) - else: - port = 8088 + asyncio.run(self.run_uvicorn(args)) - server_module = importlib.import_module(f'nominatim_api.server.{args.engine}.server') + return 0 - app = server_module.get_application(args.project_dir) - uvicorn.run(app, host=host, port=port) - return 0 + async def run_uvicorn(self, args: NominatimArgs) -> None: + import uvicorn # pylint: disable=import-outside-toplevel + + server_info = args.server.split(':', 1) + host = server_info[0] + if len(server_info) > 1: + if not server_info[1].isdigit(): + raise UsageError('Invalid format for --server parameter. Use <host>:<port>') + port = int(server_info[1]) + else: + port = 8088 + + server_module = importlib.import_module(f'nominatim_api.server.{args.engine}.server') + + app = server_module.get_application(args.project_dir) + + config = uvicorn.Config(app, host=host, port=port) + server = uvicorn.Server(config) + await server.serve() def get_set_parser() -> CommandlineParser: diff --git a/src/nominatim_db/clicmd/add_data.py b/src/nominatim_db/clicmd/add_data.py index eced9907..a690435c 100644 --- a/src/nominatim_db/clicmd/add_data.py +++ b/src/nominatim_db/clicmd/add_data.py @@ -10,6 +10,7 @@ Implementation of the 'add-data' subcommand.
from typing import cast import argparse import logging +import asyncio import psutil @@ -64,15 +65,10 @@ class UpdateAddData: def run(self, args: NominatimArgs) -> int: - from ..tokenizer import factory as tokenizer_factory - from ..tools import tiger_data, add_osm_data + from ..tools import add_osm_data if args.tiger_data: - tokenizer = tokenizer_factory.get_tokenizer_for_db(args.config) - return tiger_data.add_tiger_data(args.tiger_data, - args.config, - args.threads or psutil.cpu_count() or 1, - tokenizer) + return asyncio.run(self._add_tiger_data(args)) osm2pgsql_params = args.osm2pgsql_options(default_cache=1000, default_threads=1) if args.file or args.diff: @@ -99,3 +95,16 @@ class UpdateAddData: osm2pgsql_params) return 0 + + + async def _add_tiger_data(self, args: NominatimArgs) -> int: + from ..tokenizer import factory as tokenizer_factory + from ..tools import tiger_data + + assert args.tiger_data + + tokenizer = tokenizer_factory.get_tokenizer_for_db(args.config) + return await tiger_data.add_tiger_data(args.tiger_data, + args.config, + args.threads or psutil.cpu_count() or 1, + tokenizer) diff --git a/src/nominatim_db/clicmd/index.py b/src/nominatim_db/clicmd/index.py index 87e0fc03..c0619f34 100644 --- a/src/nominatim_db/clicmd/index.py +++ b/src/nominatim_db/clicmd/index.py @@ -8,6 +8,7 @@ Implementation of the 'index' subcommand. 
""" import argparse +import asyncio import psutil @@ -44,19 +45,7 @@ class UpdateIndex: def run(self, args: NominatimArgs) -> int: - from ..indexer.indexer import Indexer - from ..tokenizer import factory as tokenizer_factory - - tokenizer = tokenizer_factory.get_tokenizer_for_db(args.config) - - indexer = Indexer(args.config.get_libpq_dsn(), tokenizer, - args.threads or psutil.cpu_count() or 1) - - if not args.no_boundaries: - indexer.index_boundaries(args.minrank, args.maxrank) - if not args.boundaries_only: - indexer.index_by_rank(args.minrank, args.maxrank) - indexer.index_postcodes() + asyncio.run(self._do_index(args)) if not args.no_boundaries and not args.boundaries_only \ and args.minrank == 0 and args.maxrank == 30: @@ -64,3 +53,22 @@ class UpdateIndex: status.set_indexed(conn, True) return 0 + + + async def _do_index(self, args: NominatimArgs) -> None: + from ..tokenizer import factory as tokenizer_factory + + tokenizer = tokenizer_factory.get_tokenizer_for_db(args.config) + from ..indexer.indexer import Indexer + + indexer = Indexer(args.config.get_libpq_dsn(), tokenizer, + args.threads or psutil.cpu_count() or 1) + + has_pending = True # run at least once + while has_pending: + if not args.no_boundaries: + await indexer.index_boundaries(args.minrank, args.maxrank) + if not args.boundaries_only: + await indexer.index_by_rank(args.minrank, args.maxrank) + await indexer.index_postcodes() + has_pending = indexer.has_pending() diff --git a/src/nominatim_db/clicmd/refresh.py b/src/nominatim_db/clicmd/refresh.py index 363bad78..adc7ee65 100644 --- a/src/nominatim_db/clicmd/refresh.py +++ b/src/nominatim_db/clicmd/refresh.py @@ -11,6 +11,7 @@ from typing import Tuple, Optional import argparse import logging from pathlib import Path +import asyncio from ..config import Configuration from ..db.connection import connect, table_exists @@ -99,7 +100,7 @@ class UpdateRefresh: args.project_dir, tokenizer) indexer = Indexer(args.config.get_libpq_dsn(), tokenizer, 
args.threads or 1) - indexer.index_postcodes() + asyncio.run(indexer.index_postcodes()) else: LOG.error("The place table doesn't exist. " "Postcode updates on a frozen database is not possible.") diff --git a/src/nominatim_db/clicmd/replication.py b/src/nominatim_db/clicmd/replication.py index 581c731e..ba4c7730 100644 --- a/src/nominatim_db/clicmd/replication.py +++ b/src/nominatim_db/clicmd/replication.py @@ -13,6 +13,7 @@ import datetime as dt import logging import socket import time +import asyncio from ..db import status from ..db.connection import connect @@ -123,7 +124,7 @@ class UpdateReplication: return update_interval - def _update(self, args: NominatimArgs) -> None: + async def _update(self, args: NominatimArgs) -> None: # pylint: disable=too-many-locals from ..tools import replication from ..indexer.indexer import Indexer @@ -161,7 +162,7 @@ class UpdateReplication: if state is not replication.UpdateState.NO_CHANGES and args.do_index: index_start = dt.datetime.now(dt.timezone.utc) - indexer.index_full(analyse=False) + await indexer.index_full(analyse=False) with connect(dsn) as conn: status.set_indexed(conn, True) @@ -172,8 +173,7 @@ class UpdateReplication: if state is replication.UpdateState.NO_CHANGES and \ args.catch_up or update_interval > 40*60: - while indexer.has_pending(): - indexer.index_full(analyse=False) + await indexer.index_full(analyse=False) if LOG.isEnabledFor(logging.WARNING): assert batchdate is not None @@ -196,5 +196,5 @@ class UpdateReplication: if args.check_for_updates: return self._check_for_updates(args) - self._update(args) + asyncio.run(self._update(args)) return 0 diff --git a/src/nominatim_db/clicmd/setup.py b/src/nominatim_db/clicmd/setup.py index f516ba0c..07a76f59 100644 --- a/src/nominatim_db/clicmd/setup.py +++ b/src/nominatim_db/clicmd/setup.py @@ -11,6 +11,7 @@ from typing import Optional import argparse import logging from pathlib import Path +import asyncio import psutil @@ -71,14 +72,6 @@ class SetupAll: def 
run(self, args: NominatimArgs) -> int: # pylint: disable=too-many-statements, too-many-branches - from ..data import country_info - from ..tools import database_import, refresh, postcodes, freeze - from ..indexer.indexer import Indexer - - num_threads = args.threads or psutil.cpu_count() or 1 - - country_info.setup_country_config(args.config) - if args.osm_file is None and args.continue_at is None and not args.prepare_database: raise UsageError("No input files (use --osm-file).") @@ -90,6 +83,16 @@ class SetupAll: "Cannot use --continue and --prepare-database together." ) + return asyncio.run(self.async_run(args)) + + + async def async_run(self, args: NominatimArgs) -> int: + from ..data import country_info + from ..tools import database_import, refresh, postcodes, freeze + from ..indexer.indexer import Indexer + + num_threads = args.threads or psutil.cpu_count() or 1 + country_info.setup_country_config(args.config) if args.prepare_database or args.continue_at is None: LOG.warning('Creating database') @@ -99,39 +102,7 @@ class SetupAll: return 0 if args.continue_at in (None, 'import-from-file'): - files = args.get_osm_file_list() - if not files: - raise UsageError("No input files (use --osm-file).") - - if args.continue_at in ('import-from-file', None): - # Check if the correct plugins are installed - database_import.check_existing_database_plugins(args.config.get_libpq_dsn()) - LOG.warning('Setting up country tables') - country_info.setup_country_tables(args.config.get_libpq_dsn(), - args.config.lib_dir.data, - args.no_partitions) - - LOG.warning('Importing OSM data file') - database_import.import_osm_data(files, - args.osm2pgsql_options(0, 1), - drop=args.no_updates, - ignore_errors=args.ignore_errors) - - LOG.warning('Importing wikipedia importance data') - data_path = Path(args.config.WIKIPEDIA_DATA_PATH or args.project_dir) - if refresh.import_wikipedia_articles(args.config.get_libpq_dsn(), - data_path) > 0: - LOG.error('Wikipedia importance dump file not 
found. ' - 'Calculating importance values of locations will not ' - 'use Wikipedia importance data.') - - LOG.warning('Importing secondary importance raster data') - if refresh.import_secondary_importance(args.config.get_libpq_dsn(), - args.project_dir) != 0: - LOG.error('Secondary importance file not imported. ' - 'Falling back to default ranking.') - - self._setup_tables(args.config, args.reverse_only) + self._base_import(args) if args.continue_at in ('import-from-file', 'load-data', None): LOG.warning('Initialise tables') @@ -139,7 +110,7 @@ class SetupAll: database_import.truncate_data_tables(conn) LOG.warning('Load data into placex table') - database_import.load_data(args.config.get_libpq_dsn(), num_threads) + await database_import.load_data(args.config.get_libpq_dsn(), num_threads) LOG.warning("Setting up tokenizer") tokenizer = self._get_tokenizer(args.continue_at, args.config) @@ -153,13 +124,13 @@ class SetupAll: ('import-from-file', 'load-data', 'indexing', None): LOG.warning('Indexing places') indexer = Indexer(args.config.get_libpq_dsn(), tokenizer, num_threads) - indexer.index_full(analyse=not args.index_noanalyse) + await indexer.index_full(analyse=not args.index_noanalyse) LOG.warning('Post-process tables') with connect(args.config.get_libpq_dsn()) as conn: - database_import.create_search_indices(conn, args.config, - drop=args.no_updates, - threads=num_threads) + await database_import.create_search_indices(conn, args.config, + drop=args.no_updates, + threads=num_threads) LOG.warning('Create search index for default country names.') country_info.create_country_names(conn, tokenizer, args.config.get_str_list('LANGUAGES')) @@ -180,6 +151,45 @@ class SetupAll: return 0 + def _base_import(self, args: NominatimArgs) -> None: + from ..tools import database_import, refresh + from ..data import country_info + + files = args.get_osm_file_list() + if not files: + raise UsageError("No input files (use --osm-file).") + + if args.continue_at in 
('import-from-file', None): + # Check if the correct plugins are installed + database_import.check_existing_database_plugins(args.config.get_libpq_dsn()) + LOG.warning('Setting up country tables') + country_info.setup_country_tables(args.config.get_libpq_dsn(), + args.config.lib_dir.data, + args.no_partitions) + + LOG.warning('Importing OSM data file') + database_import.import_osm_data(files, + args.osm2pgsql_options(0, 1), + drop=args.no_updates, + ignore_errors=args.ignore_errors) + + LOG.warning('Importing wikipedia importance data') + data_path = Path(args.config.WIKIPEDIA_DATA_PATH or args.project_dir) + if refresh.import_wikipedia_articles(args.config.get_libpq_dsn(), + data_path) > 0: + LOG.error('Wikipedia importance dump file not found. ' + 'Calculating importance values of locations will not ' + 'use Wikipedia importance data.') + + LOG.warning('Importing secondary importance raster data') + if refresh.import_secondary_importance(args.config.get_libpq_dsn(), + args.project_dir) != 0: + LOG.error('Secondary importance file not imported. ' + 'Falling back to default ranking.') + + self._setup_tables(args.config, args.reverse_only) + + def _setup_tables(self, config: Configuration, reverse_only: bool) -> None: """ Set up the basic database layout: tables, indexes and functions. """ diff --git a/src/nominatim_db/config.py b/src/nominatim_db/config.py index c4264f0d..5ae3dea3 100644 --- a/src/nominatim_db/config.py +++ b/src/nominatim_db/config.py @@ -7,7 +7,7 @@ """ Nominatim configuration accessor. 
""" -from typing import Dict, Any, List, Mapping, Optional +from typing import Union, Dict, Any, List, Mapping, Optional import importlib.util import logging import os @@ -18,10 +18,7 @@ import yaml from dotenv import dotenv_values -try: - from psycopg2.extensions import parse_dsn -except ModuleNotFoundError: - from psycopg.conninfo import conninfo_to_dict as parse_dsn # type: ignore[assignment] +from psycopg.conninfo import conninfo_to_dict from .typing import StrPath from .errors import UsageError @@ -198,7 +195,7 @@ class Configuration: return dsn - def get_database_params(self) -> Mapping[str, str]: + def get_database_params(self) -> Mapping[str, Union[str, int, None]]: """ Get the configured parameters for the database connection as a mapping. """ @@ -207,7 +204,7 @@ class Configuration: if dsn.startswith('pgsql:'): return dict((p.split('=', 1) for p in dsn[6:].split(';'))) - return parse_dsn(dsn) + return conninfo_to_dict(dsn) def get_import_style_file(self) -> Path: diff --git a/src/nominatim_db/data/country_info.py b/src/nominatim_db/data/country_info.py index e2bf5133..9b714059 100644 --- a/src/nominatim_db/data/country_info.py +++ b/src/nominatim_db/data/country_info.py @@ -138,9 +138,10 @@ def setup_country_tables(dsn: str, sql_dir: Path, ignore_partitions: bool = Fals country_default_language_code text, partition integer ); """) - cur.execute_values( + cur.executemany( """ INSERT INTO public.country_name - (country_code, name, country_default_language_code, partition) VALUES %s + (country_code, name, country_default_language_code, partition) + VALUES (%s, %s, %s, %s) """, params) conn.commit() diff --git a/src/nominatim_db/db/async_connection.py b/src/nominatim_db/db/async_connection.py deleted file mode 100644 index 83e4c865..00000000 --- a/src/nominatim_db/db/async_connection.py +++ /dev/null @@ -1,236 +0,0 @@ -# SPDX-License-Identifier: GPL-3.0-or-later -# -# This file is part of Nominatim. 
(https://nominatim.org) -# -# Copyright (C) 2024 by the Nominatim developer community. -# For a full list of authors see the git log. -""" Non-blocking database connections. -""" -from typing import Callable, Any, Optional, Iterator, Sequence -import logging -import select -import time - -import psycopg2 -from psycopg2.extras import wait_select - -# psycopg2 emits different exceptions pre and post 2.8. Detect if the new error -# module is available and adapt the error handling accordingly. -try: - import psycopg2.errors # pylint: disable=no-name-in-module,import-error - __has_psycopg2_errors__ = True -except ImportError: - __has_psycopg2_errors__ = False - -from ..typing import T_cursor, Query - -LOG = logging.getLogger() - -class DeadlockHandler: - """ Context manager that catches deadlock exceptions and calls - the given handler function. All other exceptions are passed on - normally. - """ - - def __init__(self, handler: Callable[[], None], ignore_sql_errors: bool = False) -> None: - self.handler = handler - self.ignore_sql_errors = ignore_sql_errors - - def __enter__(self) -> 'DeadlockHandler': - return self - - def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> bool: - if __has_psycopg2_errors__: - if exc_type == psycopg2.errors.DeadlockDetected: # pylint: disable=E1101 - self.handler() - return True - elif exc_type == psycopg2.extensions.TransactionRollbackError \ - and exc_value.pgcode == '40P01': - self.handler() - return True - - if self.ignore_sql_errors and isinstance(exc_value, psycopg2.Error): - LOG.info("SQL error ignored: %s", exc_value) - return True - - return False - - -class DBConnection: - """ A single non-blocking database connection. 
- """ - - def __init__(self, dsn: str, - cursor_factory: Optional[Callable[..., T_cursor]] = None, - ignore_sql_errors: bool = False) -> None: - self.dsn = dsn - - self.current_query: Optional[Query] = None - self.current_params: Optional[Sequence[Any]] = None - self.ignore_sql_errors = ignore_sql_errors - - self.conn: Optional['psycopg2._psycopg.connection'] = None - self.cursor: Optional['psycopg2._psycopg.cursor'] = None - self.connect(cursor_factory=cursor_factory) - - def close(self) -> None: - """ Close all open connections. Does not wait for pending requests. - """ - if self.conn is not None: - if self.cursor is not None: - self.cursor.close() - self.cursor = None - self.conn.close() - - self.conn = None - - def connect(self, cursor_factory: Optional[Callable[..., T_cursor]] = None) -> None: - """ (Re)connect to the database. Creates an asynchronous connection - with JIT and parallel processing disabled. If a connection was - already open, it is closed and a new connection established. - The caller must ensure that no query is pending before reconnecting. - """ - self.close() - - # Use a dict to hand in the parameters because async is a reserved - # word in Python3. - self.conn = psycopg2.connect(**{'dsn': self.dsn, 'async': True}) # type: ignore - assert self.conn - self.wait() - - if cursor_factory is not None: - self.cursor = self.conn.cursor(cursor_factory=cursor_factory) - else: - self.cursor = self.conn.cursor() - # Disable JIT and parallel workers as they are known to cause problems. - # Update pg_settings instead of using SET because it does not yield - # errors on older versions of Postgres where the settings are not - # implemented. 
- self.perform( - """ UPDATE pg_settings SET setting = -1 WHERE name = 'jit_above_cost'; - UPDATE pg_settings SET setting = 0 - WHERE name = 'max_parallel_workers_per_gather';""") - self.wait() - - def _deadlock_handler(self) -> None: - LOG.info("Deadlock detected (params = %s), retry.", str(self.current_params)) - assert self.cursor is not None - assert self.current_query is not None - assert self.current_params is not None - - self.cursor.execute(self.current_query, self.current_params) - - def wait(self) -> None: - """ Block until any pending operation is done. - """ - while True: - with DeadlockHandler(self._deadlock_handler, self.ignore_sql_errors): - wait_select(self.conn) - self.current_query = None - return - - def perform(self, sql: Query, args: Optional[Sequence[Any]] = None) -> None: - """ Send SQL query to the server. Returns immediately without - blocking. - """ - assert self.cursor is not None - self.current_query = sql - self.current_params = args - self.cursor.execute(sql, args) - - def fileno(self) -> int: - """ File descriptor to wait for. (Makes this class select()able.) - """ - assert self.conn is not None - return self.conn.fileno() - - def is_done(self) -> bool: - """ Check if the connection is available for a new query. - - Also checks if the previous query has run into a deadlock. - If so, then the previous query is repeated. - """ - assert self.conn is not None - - if self.current_query is None: - return True - - with DeadlockHandler(self._deadlock_handler, self.ignore_sql_errors): - if self.conn.poll() == psycopg2.extensions.POLL_OK: - self.current_query = None - return True - - return False - - -class WorkerPool: - """ A pool of asynchronous database connections. - - The pool may be used as a context manager. 
- """ - REOPEN_CONNECTIONS_AFTER = 100000 - - def __init__(self, dsn: str, pool_size: int, ignore_sql_errors: bool = False) -> None: - self.threads = [DBConnection(dsn, ignore_sql_errors=ignore_sql_errors) - for _ in range(pool_size)] - self.free_workers = self._yield_free_worker() - self.wait_time = 0.0 - - - def finish_all(self) -> None: - """ Wait for all connection to finish. - """ - for thread in self.threads: - while not thread.is_done(): - thread.wait() - - self.free_workers = self._yield_free_worker() - - def close(self) -> None: - """ Close all connections and clear the pool. - """ - for thread in self.threads: - thread.close() - self.threads = [] - self.free_workers = iter([]) - - - def next_free_worker(self) -> DBConnection: - """ Get the next free connection. - """ - return next(self.free_workers) - - - def _yield_free_worker(self) -> Iterator[DBConnection]: - ready = self.threads - command_stat = 0 - while True: - for thread in ready: - if thread.is_done(): - command_stat += 1 - yield thread - - if command_stat > self.REOPEN_CONNECTIONS_AFTER: - self._reconnect_threads() - ready = self.threads - command_stat = 0 - else: - tstart = time.time() - _, ready, _ = select.select([], self.threads, []) - self.wait_time += time.time() - tstart - - - def _reconnect_threads(self) -> None: - for thread in self.threads: - while not thread.is_done(): - thread.wait() - thread.connect() - - - def __enter__(self) -> 'WorkerPool': - return self - - - def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: - self.finish_all() - self.close() diff --git a/src/nominatim_db/db/connection.py b/src/nominatim_db/db/connection.py index 629fad6a..6c7e843f 100644 --- a/src/nominatim_db/db/connection.py +++ b/src/nominatim_db/db/connection.py @@ -7,73 +7,27 @@ """ Specialised connection and cursor functions. 
""" -from typing import Optional, Any, Callable, ContextManager, Dict, cast, overload,\ - Tuple, Iterable -import contextlib +from typing import Optional, Any, Dict, Tuple import logging import os -import psycopg2 -import psycopg2.extensions -import psycopg2.extras -from psycopg2 import sql as pysql +import psycopg +import psycopg.types.hstore +from psycopg import sql as pysql -from ..typing import SysEnv, Query, T_cursor +from ..typing import SysEnv from ..errors import UsageError LOG = logging.getLogger() -class Cursor(psycopg2.extras.DictCursor): - """ A cursor returning dict-like objects and providing specialised - execution functions. - """ - # pylint: disable=arguments-renamed,arguments-differ - def execute(self, query: Query, args: Any = None) -> None: - """ Query execution that logs the SQL query when debugging is enabled. - """ - if LOG.isEnabledFor(logging.DEBUG): - LOG.debug(self.mogrify(query, args).decode('utf-8')) - - super().execute(query, args) - - - def execute_values(self, sql: Query, argslist: Iterable[Tuple[Any, ...]], - template: Optional[Query] = None) -> None: - """ Wrapper for the psycopg2 convenience function to execute - SQL for a list of values. - """ - LOG.debug("SQL execute_values(%s, %s)", sql, argslist) - - psycopg2.extras.execute_values(self, sql, argslist, template=template) - - -class Connection(psycopg2.extensions.connection): - """ A connection that provides the specialised cursor by default and - adds convenience functions for administrating the database. - """ - @overload # type: ignore[override] - def cursor(self) -> Cursor: - ... - - @overload - def cursor(self, name: str) -> Cursor: - ... - - @overload - def cursor(self, cursor_factory: Callable[..., T_cursor]) -> T_cursor: - ... +Cursor = psycopg.Cursor[Any] +Connection = psycopg.Connection[Any] - def cursor(self, cursor_factory = Cursor, **kwargs): # type: ignore - """ Return a new cursor. By default the specialised cursor is returned. 
- """ - return super().cursor(cursor_factory=cursor_factory, **kwargs) - - -def execute_scalar(conn: Connection, sql: Query, args: Any = None) -> Any: +def execute_scalar(conn: Connection, sql: psycopg.abc.Query, args: Any = None) -> Any: """ Execute query that returns a single value. The value is returned. If the query yields more than one row, a ValueError is raised. """ - with conn.cursor() as cur: + with conn.cursor(row_factory=psycopg.rows.tuple_row) as cur: cur.execute(sql, args) if cur.rowcount != 1: @@ -144,7 +98,7 @@ def server_version_tuple(conn: Connection) -> Tuple[int, int]: """ Return the server version as a tuple of (major, minor). Converts correctly for pre-10 and post-10 PostgreSQL versions. """ - version = conn.server_version + version = conn.info.server_version if version < 100000: return (int(version / 10000), int((version % 10000) / 100)) @@ -164,31 +118,25 @@ def postgis_version_tuple(conn: Connection) -> Tuple[int, int]: return (int(version_parts[0]), int(version_parts[1])) + def register_hstore(conn: Connection) -> None: """ Register the hstore type with psycopg for the connection. """ - psycopg2.extras.register_hstore(conn) - - -class ConnectionContext(ContextManager[Connection]): - """ Context manager of the connection that also provides direct access - to the underlying connection. - """ - connection: Connection + info = psycopg.types.TypeInfo.fetch(conn, "hstore") + if info is None: + raise RuntimeError('Hstore extension is requested but not installed.') + psycopg.types.hstore.register_hstore(info, conn) -def connect(dsn: str) -> ConnectionContext: +def connect(dsn: str, **kwargs: Any) -> Connection: """ Open a connection to the database using the specialised connection factory. The returned object may be used in conjunction with 'with'. When used outside a context manager, use the `connection` attribute to get the connection. 
""" try: - conn = psycopg2.connect(dsn, connection_factory=Connection) - ctxmgr = cast(ConnectionContext, contextlib.closing(conn)) - ctxmgr.connection = conn - return ctxmgr - except psycopg2.OperationalError as err: + return psycopg.connect(dsn, row_factory=psycopg.rows.namedtuple_row, **kwargs) + except psycopg.OperationalError as err: raise UsageError(f"Cannot connect to database: {err}") from err @@ -233,10 +181,18 @@ def get_pg_env(dsn: str, """ env = dict(base_env if base_env is not None else os.environ) - for param, value in psycopg2.extensions.parse_dsn(dsn).items(): + for param, value in psycopg.conninfo.conninfo_to_dict(dsn).items(): if param in _PG_CONNECTION_STRINGS: - env[_PG_CONNECTION_STRINGS[param]] = value + env[_PG_CONNECTION_STRINGS[param]] = str(value) else: LOG.error("Unknown connection parameter '%s' ignored.", param) return env + + +async def run_async_query(dsn: str, query: psycopg.abc.Query) -> None: + """ Open a connection to the database and run a single query + asynchronously. + """ + async with await psycopg.AsyncConnection.connect(dsn) as aconn: + await aconn.execute(query) diff --git a/src/nominatim_db/db/query_pool.py b/src/nominatim_db/db/query_pool.py new file mode 100644 index 00000000..2828937f --- /dev/null +++ b/src/nominatim_db/db/query_pool.py @@ -0,0 +1,87 @@ +# SPDX-License-Identifier: GPL-3.0-or-later +# +# This file is part of Nominatim. (https://nominatim.org) +# +# Copyright (C) 2024 by the Nominatim developer community. +# For a full list of authors see the git log. +""" +A connection pool that executes incoming queries in parallel. +""" +from typing import Any, Tuple, Optional +import asyncio +import logging +import time + +import psycopg + +LOG = logging.getLogger() + +QueueItem = Optional[Tuple[psycopg.abc.Query, Any]] + +class QueryPool: + """ Pool to run SQL queries in parallel asynchronous execution. + + All queries are run in autocommit mode. 
If parallel execution leads + to a deadlock, then the query is repeated. + The results of the queries are discarded. + """ + def __init__(self, dsn: str, pool_size: int = 1, **conn_args: Any) -> None: + self.wait_time = 0.0 + self.query_queue: 'asyncio.Queue[QueueItem]' = asyncio.Queue(maxsize=2 * pool_size) + + self.pool = [asyncio.create_task(self._worker_loop(dsn, **conn_args)) + for _ in range(pool_size)] + + + async def put_query(self, query: psycopg.abc.Query, params: Any) -> None: + """ Schedule a query for execution. + """ + tstart = time.time() + await self.query_queue.put((query, params)) + self.wait_time += time.time() - tstart + await asyncio.sleep(0) + + + async def finish(self) -> None: + """ Wait for all queries to finish and close the pool. + """ + for _ in self.pool: + await self.query_queue.put(None) + + tstart = time.time() + await asyncio.wait(self.pool) + self.wait_time += time.time() - tstart + + for task in self.pool: + excp = task.exception() + if excp is not None: + raise excp + + + async def _worker_loop(self, dsn: str, **conn_args: Any) -> None: + conn_args['autocommit'] = True + aconn = await psycopg.AsyncConnection.connect(dsn, **conn_args) + async with aconn: + async with aconn.cursor() as cur: + item = await self.query_queue.get() + while item is not None: + try: + if item[1] is None: + await cur.execute(item[0]) + else: + await cur.execute(item[0], item[1]) + + item = await self.query_queue.get() + except psycopg.errors.DeadlockDetected: + assert item is not None + LOG.info("Deadlock detected (sql = %s, params = %s), retry.", + str(item[0]), str(item[1])) + # item is still valid here, causing a retry + + + async def __aenter__(self) -> 'QueryPool': + return self + + + async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: + await self.finish() diff --git a/src/nominatim_db/db/sql_preprocessor.py b/src/nominatim_db/db/sql_preprocessor.py index 691ab6c5..25faead4 100644 --- 
a/src/nominatim_db/db/sql_preprocessor.py +++ b/src/nominatim_db/db/sql_preprocessor.py @@ -8,11 +8,12 @@ Preprocessing of SQL files. """ from typing import Set, Dict, Any, cast + import jinja2 from .connection import Connection, server_version_tuple, postgis_version_tuple -from .async_connection import WorkerPool from ..config import Configuration +from ..db.query_pool import QueryPool def _get_partitions(conn: Connection) -> Set[int]: """ Get the set of partitions currently in use. @@ -125,8 +126,8 @@ class SQLPreprocessor: conn.commit() - def run_parallel_sql_file(self, dsn: str, name: str, num_threads: int = 1, - **kwargs: Any) -> None: + async def run_parallel_sql_file(self, dsn: str, name: str, num_threads: int = 1, + **kwargs: Any) -> None: """ Execute the given SQL files using parallel asynchronous connections. The keyword arguments may supply additional parameters for preprocessing. @@ -138,6 +139,6 @@ class SQLPreprocessor: parts = sql.split('\n---\n') - with WorkerPool(dsn, num_threads) as pool: + async with QueryPool(dsn, num_threads) as pool: for part in parts: - pool.next_free_worker().perform(part) + await pool.put_query(part, None) diff --git a/src/nominatim_db/db/status.py b/src/nominatim_db/db/status.py index 1d2b3bec..4fe9f444 100644 --- a/src/nominatim_db/db/status.py +++ b/src/nominatim_db/db/status.py @@ -7,7 +7,7 @@ """ Access and helper functions for the status and status log table. """ -from typing import Optional, Tuple, cast +from typing import Optional, Tuple import datetime as dt import logging import re @@ -15,20 +15,11 @@ import re from .connection import Connection, table_exists, execute_scalar from ..utils.url_utils import get_url from ..errors import UsageError -from ..typing import TypedDict LOG = logging.getLogger() ISODATE_FORMAT = '%Y-%m-%dT%H:%M:%S' -class StatusRow(TypedDict): - """ Dictionary of columns of the import_status table. 
- """ - lastimportdate: dt.datetime - sequence_id: Optional[int] - indexed: Optional[bool] - - def compute_database_date(conn: Connection, offline: bool = False) -> dt.datetime: """ Determine the date of the database from the newest object in the data base. @@ -102,8 +93,9 @@ def get_status(conn: Connection) -> Tuple[Optional[dt.datetime], Optional[int], if cur.rowcount < 1: return None, None, None - row = cast(StatusRow, cur.fetchone()) - return row['lastimportdate'], row['sequence_id'], row['indexed'] + row = cur.fetchone() + assert row + return row.lastimportdate, row.sequence_id, row.indexed def set_indexed(conn: Connection, state: bool) -> None: diff --git a/src/nominatim_db/db/utils.py b/src/nominatim_db/db/utils.py index 32bf79ac..02e5bd2d 100644 --- a/src/nominatim_db/db/utils.py +++ b/src/nominatim_db/db/utils.py @@ -7,14 +7,13 @@ """ Helper functions for handling DB accesses. """ -from typing import IO, Optional, Union, Any, Iterable +from typing import IO, Optional, Union import subprocess import logging import gzip -import io from pathlib import Path -from .connection import get_pg_env, Cursor +from .connection import get_pg_env from ..errors import UsageError LOG = logging.getLogger() @@ -72,58 +71,3 @@ def execute_file(dsn: str, fname: Path, if ret != 0 or remain > 0: raise UsageError("Failed to execute SQL file.") - - -# List of characters that need to be quoted for the copy command. -_SQL_TRANSLATION = {ord('\\'): '\\\\', - ord('\t'): '\\t', - ord('\n'): '\\n'} - - -class CopyBuffer: - """ Data collector for the copy_from command. - """ - - def __init__(self) -> None: - self.buffer = io.StringIO() - - - def __enter__(self) -> 'CopyBuffer': - return self - - - def size(self) -> int: - """ Return the number of bytes the buffer currently contains. 
- """ - return self.buffer.tell() - - def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: - if self.buffer is not None: - self.buffer.close() - - - def add(self, *data: Any) -> None: - """ Add another row of data to the copy buffer. - """ - first = True - for column in data: - if first: - first = False - else: - self.buffer.write('\t') - if column is None: - self.buffer.write('\\N') - else: - self.buffer.write(str(column).translate(_SQL_TRANSLATION)) - self.buffer.write('\n') - - - def copy_out(self, cur: Cursor, table: str, columns: Optional[Iterable[str]] = None) -> None: - """ Copy all collected data into the given table. - - The buffer is empty and reusable after this operation. - """ - if self.buffer.tell() > 0: - self.buffer.seek(0) - cur.copy_from(self.buffer, table, columns=columns) - self.buffer = io.StringIO() diff --git a/src/nominatim_db/indexer/indexer.py b/src/nominatim_db/indexer/indexer.py index b4c9732c..9680e5a9 100644 --- a/src/nominatim_db/indexer/indexer.py +++ b/src/nominatim_db/indexer/indexer.py @@ -7,92 +7,20 @@ """ Main work horse for indexing (computing addresses) the database. """ -from typing import Optional, Any, cast +from typing import cast, List, Any import logging import time -import psycopg2.extras +import psycopg -from ..typing import DictCursorResults -from ..db.async_connection import DBConnection, WorkerPool -from ..db.connection import connect, Connection, Cursor, execute_scalar, register_hstore +from ..db.connection import connect, execute_scalar +from ..db.query_pool import QueryPool from ..tokenizer.base import AbstractTokenizer from .progress import ProgressLogger from . import runners LOG = logging.getLogger() - -class PlaceFetcher: - """ Asynchronous connection that fetches place details for processing. 
- """ - def __init__(self, dsn: str, setup_conn: Connection) -> None: - self.wait_time = 0.0 - self.current_ids: Optional[DictCursorResults] = None - self.conn: Optional[DBConnection] = DBConnection(dsn, - cursor_factory=psycopg2.extras.DictCursor) - - # need to fetch those manually because register_hstore cannot - # fetch them on an asynchronous connection below. - hstore_oid = execute_scalar(setup_conn, "SELECT 'hstore'::regtype::oid") - hstore_array_oid = execute_scalar(setup_conn, "SELECT 'hstore[]'::regtype::oid") - - psycopg2.extras.register_hstore(self.conn.conn, oid=hstore_oid, - array_oid=hstore_array_oid) - - - def close(self) -> None: - """ Close the underlying asynchronous connection. - """ - if self.conn: - self.conn.close() - self.conn = None - - - def fetch_next_batch(self, cur: Cursor, runner: runners.Runner) -> bool: - """ Send a request for the next batch of places. - If details for the places are required, they will be fetched - asynchronously. - - Returns true if there is still data available. - """ - ids = cast(Optional[DictCursorResults], cur.fetchmany(100)) - - if not ids: - self.current_ids = None - return False - - assert self.conn is not None - self.current_ids = runner.get_place_details(self.conn, ids) - - return True - - def get_batch(self) -> DictCursorResults: - """ Get the next batch of data, previously requested with - `fetch_next_batch`. 
- """ - assert self.conn is not None - assert self.conn.cursor is not None - - if self.current_ids is not None and not self.current_ids: - tstart = time.time() - self.conn.wait() - self.wait_time += time.time() - tstart - self.current_ids = cast(Optional[DictCursorResults], - self.conn.cursor.fetchall()) - - return self.current_ids if self.current_ids is not None else [] - - def __enter__(self) -> 'PlaceFetcher': - return self - - - def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: - assert self.conn is not None - self.conn.wait() - self.close() - - class Indexer: """ Main indexing routine. """ @@ -114,7 +42,7 @@ class Indexer: return cur.rowcount > 0 - def index_full(self, analyse: bool = True) -> None: + async def index_full(self, analyse: bool = True) -> None: """ Index the complete database. This will first index boundaries followed by all other objects. When `analyse` is True, then the database will be analysed at the appropriate places to @@ -128,23 +56,27 @@ class Indexer: with conn.cursor() as cur: cur.execute('ANALYZE') - if self.index_by_rank(0, 4) > 0: - _analyze() + while True: + if await self.index_by_rank(0, 4) > 0: + _analyze() - if self.index_boundaries(0, 30) > 100: - _analyze() + if await self.index_boundaries(0, 30) > 100: + _analyze() - if self.index_by_rank(5, 25) > 100: - _analyze() + if await self.index_by_rank(5, 25) > 100: + _analyze() - if self.index_by_rank(26, 30) > 1000: - _analyze() + if await self.index_by_rank(26, 30) > 1000: + _analyze() - if self.index_postcodes() > 100: - _analyze() + if await self.index_postcodes() > 100: + _analyze() + if not self.has_pending(): + break - def index_boundaries(self, minrank: int, maxrank: int) -> int: + + async def index_boundaries(self, minrank: int, maxrank: int) -> int: """ Index only administrative boundaries within the given rank range. 
""" total = 0 @@ -153,11 +85,11 @@ class Indexer: with self.tokenizer.name_analyzer() as analyzer: for rank in range(max(minrank, 4), min(maxrank, 26)): - total += self._index(runners.BoundaryRunner(rank, analyzer)) + total += await self._index(runners.BoundaryRunner(rank, analyzer)) return total - def index_by_rank(self, minrank: int, maxrank: int) -> int: + async def index_by_rank(self, minrank: int, maxrank: int) -> int: """ Index all entries of placex in the given rank range (inclusive) in order of their address rank. @@ -171,21 +103,27 @@ class Indexer: with self.tokenizer.name_analyzer() as analyzer: for rank in range(max(1, minrank), maxrank + 1): - total += self._index(runners.RankRunner(rank, analyzer), 20 if rank == 30 else 1) + if rank >= 30: + batch = 20 + elif rank >= 26: + batch = 5 + else: + batch = 1 + total += await self._index(runners.RankRunner(rank, analyzer), batch) if maxrank == 30: - total += self._index(runners.RankRunner(0, analyzer)) - total += self._index(runners.InterpolationRunner(analyzer), 20) + total += await self._index(runners.RankRunner(0, analyzer)) + total += await self._index(runners.InterpolationRunner(analyzer), 20) return total - def index_postcodes(self) -> int: + async def index_postcodes(self) -> int: """Index the entries of the location_postcode table. """ LOG.warning("Starting indexing postcodes using %s threads", self.num_threads) - return self._index(runners.PostcodeRunner(), 20) + return await self._index(runners.PostcodeRunner(), 20) def update_status_table(self) -> None: @@ -197,45 +135,58 @@ class Indexer: conn.commit() - def _index(self, runner: runners.Runner, batch: int = 1) -> int: + async def _index(self, runner: runners.Runner, batch: int = 1) -> int: """ Index a single rank or table. `runner` describes the SQL to use for indexing. 
`batch` describes the number of objects that should be processed with a single SQL statement """ LOG.warning("Starting %s (using batch size %s)", runner.name(), batch) - with connect(self.dsn) as conn: - register_hstore(conn) - total_tuples = execute_scalar(conn, runner.sql_count_objects()) - LOG.debug("Total number of rows: %i", total_tuples) + total_tuples = self._prepare_indexing(runner) - conn.commit() - - progress = ProgressLogger(runner.name(), total_tuples) + progress = ProgressLogger(runner.name(), total_tuples) - if total_tuples > 0: - with conn.cursor(name='places') as cur: - cur.execute(runner.sql_get_objects()) + if total_tuples > 0: + async with await psycopg.AsyncConnection.connect( + self.dsn, row_factory=psycopg.rows.dict_row) as aconn,\ + QueryPool(self.dsn, self.num_threads, autocommit=True) as pool: + fetcher_time = 0.0 + tstart = time.time() + async with aconn.cursor(name='places') as cur: + query = runner.index_places_query(batch) + params: List[Any] = [] + num_places = 0 + async for place in cur.stream(runner.sql_get_objects()): + fetcher_time += time.time() - tstart - with PlaceFetcher(self.dsn, conn) as fetcher: - with WorkerPool(self.dsn, self.num_threads) as pool: - has_more = fetcher.fetch_next_batch(cur, runner) - while has_more: - places = fetcher.get_batch() + params.extend(runner.index_places_params(place)) + num_places += 1 - # asynchronously get the next batch - has_more = fetcher.fetch_next_batch(cur, runner) + if num_places >= batch: + LOG.debug("Processing places: %s", str(params)) + await pool.put_query(query, params) + progress.add(num_places) + params = [] + num_places = 0 - # And insert the current batch - for idx in range(0, len(places), batch): - part = places[idx:idx + batch] - LOG.debug("Processing places: %s", str(part)) - runner.index_places(pool.next_free_worker(), part) - progress.add(len(part)) + tstart = time.time() - LOG.info("Wait time: fetcher: %.2fs, pool: %.2fs", - fetcher.wait_time, pool.wait_time) + if 
num_places > 0: + await pool.put_query(runner.index_places_query(num_places), params) - conn.commit() + LOG.info("Wait time: fetcher: %.2fs, pool: %.2fs", + fetcher_time, pool.wait_time) return progress.done() + + + def _prepare_indexing(self, runner: runners.Runner) -> int: + with connect(self.dsn) as conn: + hstore_info = psycopg.types.TypeInfo.fetch(conn, "hstore") + if hstore_info is None: + raise RuntimeError('Hstore extension is requested but not installed.') + psycopg.types.hstore.register_hstore(hstore_info) + + total_tuples = execute_scalar(conn, runner.sql_count_objects()) + LOG.debug("Total number of rows: %i", total_tuples) + return cast(int, total_tuples) diff --git a/src/nominatim_db/indexer/runners.py b/src/nominatim_db/indexer/runners.py index 7b98e240..d7737c07 100644 --- a/src/nominatim_db/indexer/runners.py +++ b/src/nominatim_db/indexer/runners.py @@ -8,14 +8,14 @@ Mix-ins that provide the actual commands for the indexer for various indexing tasks. """ -from typing import Any, List -import functools +from typing import Any, Sequence -from psycopg2 import sql as pysql -import psycopg2.extras +from psycopg import sql as pysql +from psycopg.abc import Query +from psycopg.rows import DictRow +from psycopg.types.json import Json -from ..typing import Query, DictCursorResult, DictCursorResults, Protocol -from ..db.async_connection import DBConnection +from ..typing import Protocol from ..data.place_info import PlaceInfo from ..tokenizer.base import AbstractAnalyzer @@ -24,58 +24,48 @@ from ..tokenizer.base import AbstractAnalyzer def _mk_valuelist(template: str, num: int) -> pysql.Composed: return pysql.SQL(',').join([pysql.SQL(template)] * num) -def _analyze_place(place: DictCursorResult, analyzer: AbstractAnalyzer) -> psycopg2.extras.Json: - return psycopg2.extras.Json(analyzer.process_place(PlaceInfo(place))) +def _analyze_place(place: DictRow, analyzer: AbstractAnalyzer) -> Json: + return Json(analyzer.process_place(PlaceInfo(place))) class 
Runner(Protocol): def name(self) -> str: ... def sql_count_objects(self) -> Query: ... def sql_get_objects(self) -> Query: ... - def get_place_details(self, worker: DBConnection, - ids: DictCursorResults) -> DictCursorResults: ... - def index_places(self, worker: DBConnection, places: DictCursorResults) -> None: ... + def index_places_query(self, batch_size: int) -> Query: ... + def index_places_params(self, place: DictRow) -> Sequence[Any]: ... +SELECT_SQL = pysql.SQL("""SELECT place_id, extra.* + FROM (SELECT * FROM placex {}) as px, + LATERAL placex_indexing_prepare(px) as extra """) +UPDATE_LINE = "(%s, %s::hstore, %s::hstore, %s::int, %s::jsonb)" + class AbstractPlacexRunner: """ Returns SQL commands for indexing of the placex table. """ - SELECT_SQL = pysql.SQL('SELECT place_id FROM placex ') - UPDATE_LINE = "(%s, %s::hstore, %s::hstore, %s::int, %s::jsonb)" def __init__(self, rank: int, analyzer: AbstractAnalyzer) -> None: self.rank = rank self.analyzer = analyzer - @functools.lru_cache(maxsize=1) - def _index_sql(self, num_places: int) -> pysql.Composed: + def index_places_query(self, batch_size: int) -> Query: return pysql.SQL( """ UPDATE placex SET indexed_status = 0, address = v.addr, token_info = v.ti, name = v.name, linked_place_id = v.linked_place_id FROM (VALUES {}) as v(id, name, addr, linked_place_id, ti) WHERE place_id = v.id - """).format(_mk_valuelist(AbstractPlacexRunner.UPDATE_LINE, num_places)) - - - def get_place_details(self, worker: DBConnection, ids: DictCursorResults) -> DictCursorResults: - worker.perform("""SELECT place_id, extra.* - FROM placex, LATERAL placex_indexing_prepare(placex) as extra - WHERE place_id IN %s""", - (tuple((p[0] for p in ids)), )) + """).format(_mk_valuelist(UPDATE_LINE, batch_size)) - return [] - - def index_places(self, worker: DBConnection, places: DictCursorResults) -> None: - values: List[Any] = [] - for place in places: - for field in ('place_id', 'name', 'address', 'linked_place_id'): - 
values.append(place[field]) - values.append(_analyze_place(place, self.analyzer)) - - worker.perform(self._index_sql(len(places)), values) + def index_places_params(self, place: DictRow) -> Sequence[Any]: + return (place['place_id'], + place['name'], + place['address'], + place['linked_place_id'], + _analyze_place(place, self.analyzer)) class RankRunner(AbstractPlacexRunner): @@ -91,10 +81,10 @@ class RankRunner(AbstractPlacexRunner): """).format(pysql.Literal(self.rank)) def sql_get_objects(self) -> pysql.Composed: - return self.SELECT_SQL + pysql.SQL( - """WHERE indexed_status > 0 and rank_address = {} - ORDER BY geometry_sector - """).format(pysql.Literal(self.rank)) + return SELECT_SQL.format(pysql.SQL( + """WHERE placex.indexed_status > 0 and placex.rank_address = {} + ORDER BY placex.geometry_sector + """).format(pysql.Literal(self.rank))) class BoundaryRunner(AbstractPlacexRunner): @@ -105,19 +95,19 @@ class BoundaryRunner(AbstractPlacexRunner): def name(self) -> str: return f"boundaries rank {self.rank}" - def sql_count_objects(self) -> pysql.Composed: + def sql_count_objects(self) -> Query: return pysql.SQL("""SELECT count(*) FROM placex WHERE indexed_status > 0 AND rank_search = {} AND class = 'boundary' and type = 'administrative' """).format(pysql.Literal(self.rank)) - def sql_get_objects(self) -> pysql.Composed: - return self.SELECT_SQL + pysql.SQL( - """WHERE indexed_status > 0 and rank_search = {} - and class = 'boundary' and type = 'administrative' - ORDER BY partition, admin_level - """).format(pysql.Literal(self.rank)) + def sql_get_objects(self) -> Query: + return SELECT_SQL.format(pysql.SQL( + """WHERE placex.indexed_status > 0 and placex.rank_search = {} + and placex.class = 'boundary' and placex.type = 'administrative' + ORDER BY placex.partition, placex.admin_level + """).format(pysql.Literal(self.rank))) class InterpolationRunner: @@ -132,40 +122,29 @@ class InterpolationRunner: def name(self) -> str: return "interpolation lines 
(location_property_osmline)" - def sql_count_objects(self) -> str: + def sql_count_objects(self) -> Query: return """SELECT count(*) FROM location_property_osmline WHERE indexed_status > 0""" - def sql_get_objects(self) -> str: - return """SELECT place_id + + def sql_get_objects(self) -> Query: + return """SELECT place_id, get_interpolation_address(address, osm_id) as address FROM location_property_osmline WHERE indexed_status > 0 ORDER BY geometry_sector""" - def get_place_details(self, worker: DBConnection, ids: DictCursorResults) -> DictCursorResults: - worker.perform("""SELECT place_id, get_interpolation_address(address, osm_id) as address - FROM location_property_osmline WHERE place_id IN %s""", - (tuple((p[0] for p in ids)), )) - return [] - - - @functools.lru_cache(maxsize=1) - def _index_sql(self, num_places: int) -> pysql.Composed: + def index_places_query(self, batch_size: int) -> Query: return pysql.SQL("""UPDATE location_property_osmline SET indexed_status = 0, address = v.addr, token_info = v.ti FROM (VALUES {}) as v(id, addr, ti) WHERE place_id = v.id - """).format(_mk_valuelist("(%s, %s::hstore, %s::jsonb)", num_places)) - + """).format(_mk_valuelist("(%s, %s::hstore, %s::jsonb)", batch_size)) - def index_places(self, worker: DBConnection, places: DictCursorResults) -> None: - values: List[Any] = [] - for place in places: - values.extend((place[x] for x in ('place_id', 'address'))) - values.append(_analyze_place(place, self.analyzer)) - worker.perform(self._index_sql(len(places)), values) + def index_places_params(self, place: DictRow) -> Sequence[Any]: + return (place['place_id'], place['address'], + _analyze_place(place, self.analyzer)) @@ -177,20 +156,21 @@ class PostcodeRunner(Runner): return "postcodes (location_postcode)" - def sql_count_objects(self) -> str: + def sql_count_objects(self) -> Query: return 'SELECT count(*) FROM location_postcode WHERE indexed_status > 0' - def sql_get_objects(self) -> str: + def sql_get_objects(self) -> Query: 
return """SELECT place_id FROM location_postcode WHERE indexed_status > 0 ORDER BY country_code, postcode""" - def get_place_details(self, worker: DBConnection, ids: DictCursorResults) -> DictCursorResults: - return ids + def index_places_query(self, batch_size: int) -> Query: + return pysql.SQL("""UPDATE location_postcode SET indexed_status = 0 + WHERE place_id IN ({})""")\ + .format(pysql.SQL(',').join((pysql.Placeholder() for _ in range(batch_size)))) + - def index_places(self, worker: DBConnection, places: DictCursorResults) -> None: - worker.perform(pysql.SQL("""UPDATE location_postcode SET indexed_status = 0 - WHERE place_id IN ({})""") - .format(pysql.SQL(',').join((pysql.Literal(i[0]) for i in places)))) + def index_places_params(self, place: DictRow) -> Sequence[Any]: + return (place['place_id'], ) diff --git a/src/nominatim_db/tokenizer/icu_tokenizer.py b/src/nominatim_db/tokenizer/icu_tokenizer.py index 70c5c27a..7cd96d59 100644 --- a/src/nominatim_db/tokenizer/icu_tokenizer.py +++ b/src/nominatim_db/tokenizer/icu_tokenizer.py @@ -11,15 +11,16 @@ libICU instead of the PostgreSQL module. 
from typing import Optional, Sequence, List, Tuple, Mapping, Any, cast, \ Dict, Set, Iterable import itertools -import json import logging from pathlib import Path from textwrap import dedent +from psycopg.types.json import Jsonb +from psycopg import sql as pysql + from ..db.connection import connect, Connection, Cursor, server_version_tuple,\ drop_tables, table_exists, execute_scalar from ..config import Configuration -from ..db.utils import CopyBuffer from ..db.sql_preprocessor import SQLPreprocessor from ..data.place_info import PlaceInfo from ..data.place_name import PlaceName @@ -115,8 +116,8 @@ class ICUTokenizer(AbstractTokenizer): with conn.cursor() as cur: cur.execute('ANALYSE search_name') if threads > 1: - cur.execute('SET max_parallel_workers_per_gather TO %s', - (min(threads, 6),)) + cur.execute(pysql.SQL('SET max_parallel_workers_per_gather TO {}') + .format(pysql.Literal(min(threads, 6),))) if server_version_tuple(conn) < (12, 0): LOG.info('Computing word frequencies') @@ -391,7 +392,7 @@ class ICUNameAnalyzer(AbstractAnalyzer): def __init__(self, dsn: str, sanitizer: PlaceSanitizer, token_analysis: ICUTokenAnalysis) -> None: - self.conn: Optional[Connection] = connect(dsn).connection + self.conn: Optional[Connection] = connect(dsn) self.conn.autocommit = True self.sanitizer = sanitizer self.token_analysis = token_analysis @@ -533,9 +534,7 @@ class ICUNameAnalyzer(AbstractAnalyzer): if terms: with self.conn.cursor() as cur: - cur.execute_values("""SELECT create_postcode_word(pc, var) - FROM (VALUES %s) AS v(pc, var)""", - terms) + cur.executemany("""SELECT create_postcode_word(%s, %s)""", terms) @@ -578,18 +577,15 @@ class ICUNameAnalyzer(AbstractAnalyzer): to_add = new_phrases - existing_phrases added = 0 - with CopyBuffer() as copystr: + with cursor.copy('COPY word(word_token, type, word, info) FROM STDIN') as copy: for word, cls, typ, oper in to_add: term = self._search_normalized(word) if term: - copystr.add(term, 'S', word, - 
json.dumps({'class': cls, 'type': typ, - 'op': oper if oper in ('in', 'near') else None})) + copy.write_row((term, 'S', word, + Jsonb({'class': cls, 'type': typ, + 'op': oper if oper in ('in', 'near') else None}))) added += 1 - copystr.copy_out(cursor, 'word', - columns=['word_token', 'type', 'word', 'info']) - return added @@ -602,11 +598,11 @@ class ICUNameAnalyzer(AbstractAnalyzer): to_delete = existing_phrases - new_phrases if to_delete: - cursor.execute_values( - """ DELETE FROM word USING (VALUES %s) as v(name, in_class, in_type, op) - WHERE type = 'S' and word = name - and info->>'class' = in_class and info->>'type' = in_type - and ((op = '-' and info->>'op' is null) or op = info->>'op') + cursor.executemany( + """ DELETE FROM word + WHERE type = 'S' and word = %s + and info->>'class' = %s and info->>'type' = %s + and %s = coalesce(info->>'op', '-') """, to_delete) return len(to_delete) @@ -653,7 +649,7 @@ class ICUNameAnalyzer(AbstractAnalyzer): gone_tokens.update(existing_tokens[False] & word_tokens) if gone_tokens: cur.execute("""DELETE FROM word - USING unnest(%s) as token + USING unnest(%s::text[]) as token WHERE type = 'C' and word = %s and word_token = token""", (list(gone_tokens), country_code)) @@ -666,12 +662,12 @@ class ICUNameAnalyzer(AbstractAnalyzer): if internal: sql = """INSERT INTO word (word_token, type, word, info) (SELECT token, 'C', %s, '{"internal": "yes"}' - FROM unnest(%s) as token) + FROM unnest(%s::text[]) as token) """ else: sql = """INSERT INTO word (word_token, type, word) (SELECT token, 'C', %s - FROM unnest(%s) as token) + FROM unnest(%s::text[]) as token) """ cur.execute(sql, (country_code, list(new_tokens))) diff --git a/src/nominatim_db/tokenizer/legacy_tokenizer.py b/src/nominatim_db/tokenizer/legacy_tokenizer.py index 0e8dfcf9..fa4b3b99 100644 --- a/src/nominatim_db/tokenizer/legacy_tokenizer.py +++ b/src/nominatim_db/tokenizer/legacy_tokenizer.py @@ -17,7 +17,8 @@ import shutil from textwrap import dedent from icu import 
Transliterator -import psycopg2 +import psycopg +from psycopg import sql as pysql from ..errors import UsageError from ..db.connection import connect, Connection, drop_tables, table_exists,\ @@ -78,12 +79,12 @@ def _check_module(module_dir: str, conn: Connection) -> None: """ with conn.cursor() as cur: try: - cur.execute("""CREATE FUNCTION nominatim_test_import_func(text) - RETURNS text AS %s, 'transliteration' - LANGUAGE c IMMUTABLE STRICT; - DROP FUNCTION nominatim_test_import_func(text) - """, (f'{module_dir}/nominatim.so', )) - except psycopg2.DatabaseError as err: + cur.execute(pysql.SQL("""CREATE FUNCTION nominatim_test_import_func(text) + RETURNS text AS {}, 'transliteration' + LANGUAGE c IMMUTABLE STRICT; + DROP FUNCTION nominatim_test_import_func(text) + """).format(pysql.Literal(f'{module_dir}/nominatim.so'))) + except psycopg.DatabaseError as err: LOG.fatal("Error accessing database module: %s", err) raise UsageError("Database module cannot be accessed.") from err @@ -181,7 +182,7 @@ class LegacyTokenizer(AbstractTokenizer): with connect(self.dsn) as conn: try: out = execute_scalar(conn, "SELECT make_standard_name('a')") - except psycopg2.Error as err: + except psycopg.Error as err: return hint.format(error=str(err)) if out != 'a': @@ -312,7 +313,7 @@ class LegacyNameAnalyzer(AbstractAnalyzer): """ def __init__(self, dsn: str, normalizer: Any): - self.conn: Optional[Connection] = connect(dsn).connection + self.conn: Optional[Connection] = connect(dsn) self.conn.autocommit = True self.normalizer = normalizer register_hstore(self.conn) @@ -405,7 +406,7 @@ class LegacyNameAnalyzer(AbstractAnalyzer): """, (to_delete, )) if to_add: cur.execute("""SELECT count(create_postcode_id(pc)) - FROM unnest(%s) as pc + FROM unnest(%s::text[]) as pc """, (to_add, )) @@ -422,7 +423,7 @@ class LegacyNameAnalyzer(AbstractAnalyzer): with self.conn.cursor() as cur: # Get the old phrases. 
existing_phrases = set() - cur.execute("""SELECT word, class, type, operator FROM word + cur.execute("""SELECT word, class as cls, type, operator FROM word WHERE class != 'place' OR (type != 'house' AND type != 'postcode')""") for label, cls, typ, oper in cur: @@ -432,18 +433,19 @@ class LegacyNameAnalyzer(AbstractAnalyzer): to_delete = existing_phrases - norm_phrases if to_add: - cur.execute_values( + cur.executemany( """ INSERT INTO word (word_id, word_token, word, class, type, search_name_count, operator) (SELECT nextval('seq_word'), ' ' || make_standard_name(name), name, class, type, 0, CASE WHEN op in ('in', 'near') THEN op ELSE null END - FROM (VALUES %s) as v(name, class, type, op))""", + FROM (VALUES (%s, %s, %s, %s)) as v(name, class, type, op))""", to_add) if to_delete and should_replace: - cur.execute_values( - """ DELETE FROM word USING (VALUES %s) as v(name, in_class, in_type, op) + cur.executemany( + """ DELETE FROM word + USING (VALUES (%s, %s, %s, %s)) as v(name, in_class, in_type, op) WHERE word = name and class = in_class and type = in_type and ((op = '-' and operator is null) or op = operator)""", to_delete) @@ -462,7 +464,7 @@ class LegacyNameAnalyzer(AbstractAnalyzer): """INSERT INTO word (word_id, word_token, country_code) (SELECT nextval('seq_word'), lookup_token, %s FROM (SELECT DISTINCT ' ' || make_standard_name(n) as lookup_token - FROM unnest(%s)n) y + FROM unnest(%s::TEXT[])n) y WHERE NOT EXISTS(SELECT * FROM word WHERE word_token = lookup_token and country_code = %s)) """, (country_code, list(names.values()), country_code)) diff --git a/src/nominatim_db/tools/admin.py b/src/nominatim_db/tools/admin.py index 3e502199..e70a7e50 100644 --- a/src/nominatim_db/tools/admin.py +++ b/src/nominatim_db/tools/admin.py @@ -10,8 +10,8 @@ Functions for database analysis and maintenance. 
from typing import Optional, Tuple, Any, cast import logging -from psycopg2.extras import Json -from psycopg2 import DataError +import psycopg +from psycopg.types.json import Json from ..typing import DictCursorResult from ..config import Configuration @@ -59,7 +59,7 @@ def analyse_indexing(config: Configuration, osm_id: Optional[str] = None, """ with connect(config.get_libpq_dsn()) as conn: register_hstore(conn) - with conn.cursor() as cur: + with conn.cursor(row_factory=psycopg.rows.dict_row) as cur: place = _get_place_info(cur, osm_id, place_id) cur.execute("update placex set indexed_status = 2 where place_id = %s", @@ -74,6 +74,9 @@ def analyse_indexing(config: Configuration, osm_id: Optional[str] = None, tokenizer = tokenizer_factory.get_tokenizer_for_db(config) + # Enable printing of messages. + conn.add_notice_handler(lambda diag: print(diag.message_primary)) + with tokenizer.name_analyzer() as analyzer: cur.execute("""UPDATE placex SET indexed_status = 0, address = %s, token_info = %s, @@ -86,9 +89,6 @@ def analyse_indexing(config: Configuration, osm_id: Optional[str] = None, # we do not want to keep the results conn.rollback() - for msg in conn.notices: - print(msg) - def clean_deleted_relations(config: Configuration, age: str) -> None: """ Clean deleted relations older than a given age @@ -101,6 +101,6 @@ def clean_deleted_relations(config: Configuration, age: str) -> None: WHERE p.osm_type = d.osm_type AND p.osm_id = d.osm_id AND age(p.indexed_date) > %s::interval""", (age, )) - except DataError as exc: + except psycopg.DataError as exc: raise UsageError('Invalid PostgreSQL time interval format') from exc conn.commit() diff --git a/src/nominatim_db/tools/check_database.py b/src/nominatim_db/tools/check_database.py index 946f9291..7389c9a2 100644 --- a/src/nominatim_db/tools/check_database.py +++ b/src/nominatim_db/tools/check_database.py @@ -81,7 +81,7 @@ def check_database(config: Configuration) -> int: """ Run a number of checks on the database and 
return the status. """ try: - conn = connect(config.get_libpq_dsn()).connection + conn = connect(config.get_libpq_dsn()) except UsageError as err: conn = _BadConnection(str(err)) # type: ignore[assignment] diff --git a/src/nominatim_db/tools/collect_os_info.py b/src/nominatim_db/tools/collect_os_info.py index db3e773d..d054ef00 100644 --- a/src/nominatim_db/tools/collect_os_info.py +++ b/src/nominatim_db/tools/collect_os_info.py @@ -15,7 +15,6 @@ from pathlib import Path from typing import List, Optional, Union import psutil -from psycopg2.extensions import make_dsn from ..config import Configuration from ..db.connection import connect, server_version_tuple, execute_scalar @@ -97,7 +96,7 @@ def report_system_information(config: Configuration) -> None: """Generate a report about the host system including software versions, memory, storage, and database configuration.""" - with connect(make_dsn(config.get_libpq_dsn(), dbname='postgres')) as conn: + with connect(config.get_libpq_dsn(), dbname='postgres') as conn: postgresql_ver: str = '.'.join(map(str, server_version_tuple(conn))) with conn.cursor() as cur: diff --git a/src/nominatim_db/tools/database_import.py b/src/nominatim_db/tools/database_import.py index 2398d404..e96954dd 100644 --- a/src/nominatim_db/tools/database_import.py +++ b/src/nominatim_db/tools/database_import.py @@ -10,19 +10,20 @@ Functions for setting up and importing a new Nominatim database. 
from typing import Tuple, Optional, Union, Sequence, MutableMapping, Any import logging import os -import selectors import subprocess +import asyncio from pathlib import Path import psutil -from psycopg2 import sql as pysql +import psycopg +from psycopg import sql as pysql from ..errors import UsageError from ..config import Configuration from ..db.connection import connect, get_pg_env, Connection, server_version_tuple,\ postgis_version_tuple, drop_tables, table_exists, execute_scalar -from ..db.async_connection import DBConnection from ..db.sql_preprocessor import SQLPreprocessor +from ..db.query_pool import QueryPool from .exec_utils import run_osm2pgsql from ..version import POSTGRESQL_REQUIRED_VERSION, POSTGIS_REQUIRED_VERSION @@ -136,7 +137,7 @@ def import_osm_data(osm_files: Union[Path, Sequence[Path]], with connect(options['dsn']) as conn: if not ignore_errors: with conn.cursor() as cur: - cur.execute('SELECT * FROM place LIMIT 1') + cur.execute('SELECT true FROM place LIMIT 1') if cur.rowcount == 0: raise UsageError('No data imported by osm2pgsql.') @@ -205,54 +206,51 @@ _COPY_COLUMNS = pysql.SQL(',').join(map(pysql.Identifier, 'extratags', 'geometry'))) -def load_data(dsn: str, threads: int) -> None: +async def load_data(dsn: str, threads: int) -> None: """ Copy data into the word and placex table. """ - sel = selectors.DefaultSelector() - # Then copy data from place to placex in chunks. - place_threads = max(1, threads - 1) - for imod in range(place_threads): - conn = DBConnection(dsn) - conn.connect() - conn.perform( - pysql.SQL("""INSERT INTO placex ({columns}) - SELECT {columns} FROM place - WHERE osm_id % {total} = {mod} - AND NOT (class='place' and (type='houses' or type='postcode')) - AND ST_IsValid(geometry) - """).format(columns=_COPY_COLUMNS, - total=pysql.Literal(place_threads), - mod=pysql.Literal(imod))) - sel.register(conn, selectors.EVENT_READ, conn) - - # Address interpolations go into another table. 
- conn = DBConnection(dsn) - conn.connect() - conn.perform("""INSERT INTO location_property_osmline (osm_id, address, linegeo) - SELECT osm_id, address, geometry FROM place - WHERE class='place' and type='houses' and osm_type='W' - and ST_GeometryType(geometry) = 'ST_LineString' - """) - sel.register(conn, selectors.EVENT_READ, conn) - - # Now wait for all of them to finish. - todo = place_threads + 1 - while todo > 0: - for key, _ in sel.select(1): - conn = key.data - sel.unregister(conn) - conn.wait() - conn.close() - todo -= 1 + placex_threads = max(1, threads - 1) + + progress = asyncio.create_task(_progress_print()) + + async with QueryPool(dsn, placex_threads + 1) as pool: + # Copy data from place to placex in chunks. + for imod in range(placex_threads): + await pool.put_query( + pysql.SQL("""INSERT INTO placex ({columns}) + SELECT {columns} FROM place + WHERE osm_id % {total} = {mod} + AND NOT (class='place' + and (type='houses' or type='postcode')) + AND ST_IsValid(geometry) + """).format(columns=_COPY_COLUMNS, + total=pysql.Literal(placex_threads), + mod=pysql.Literal(imod)), None) + + # Interpolations need to be copied seperately + await pool.put_query(""" + INSERT INTO location_property_osmline (osm_id, address, linegeo) + SELECT osm_id, address, geometry FROM place + WHERE class='place' and type='houses' and osm_type='W' + and ST_GeometryType(geometry) = 'ST_LineString' """, None) + + progress.cancel() + + async with await psycopg.AsyncConnection.connect(dsn) as aconn: + await aconn.execute('ANALYSE') + + +async def _progress_print() -> None: + while True: + try: + await asyncio.sleep(1) + except asyncio.CancelledError: + print('', flush=True) + break print('.', end='', flush=True) - print('\n') - with connect(dsn) as syn_conn: - with syn_conn.cursor() as cur: - cur.execute('ANALYSE') - -def create_search_indices(conn: Connection, config: Configuration, +async def create_search_indices(conn: Connection, config: Configuration, drop: bool = False, 
threads: int = 1) -> None: """ Create tables that have explicit partitioning. """ @@ -271,5 +269,5 @@ def create_search_indices(conn: Connection, config: Configuration, sql = SQLPreprocessor(conn, config) - sql.run_parallel_sql_file(config.get_libpq_dsn(), - 'indices.sql', min(8, threads), drop=drop) + await sql.run_parallel_sql_file(config.get_libpq_dsn(), + 'indices.sql', min(8, threads), drop=drop) diff --git a/src/nominatim_db/tools/freeze.py b/src/nominatim_db/tools/freeze.py index e6d80e1e..c4eedb43 100644 --- a/src/nominatim_db/tools/freeze.py +++ b/src/nominatim_db/tools/freeze.py @@ -10,7 +10,7 @@ Functions for removing unnecessary data from the database. from typing import Optional from pathlib import Path -from psycopg2 import sql as pysql +from psycopg import sql as pysql from ..db.connection import Connection, drop_tables, table_exists diff --git a/src/nominatim_db/tools/migration.py b/src/nominatim_db/tools/migration.py index 46ba0125..54836532 100644 --- a/src/nominatim_db/tools/migration.py +++ b/src/nominatim_db/tools/migration.py @@ -10,7 +10,7 @@ Functions for database migration to newer software versions. 
from typing import List, Tuple, Callable, Any import logging -from psycopg2 import sql as pysql +from psycopg import sql as pysql from ..errors import UsageError from ..config import Configuration diff --git a/src/nominatim_db/tools/postcodes.py b/src/nominatim_db/tools/postcodes.py index a5d8ef8b..357b2bae 100644 --- a/src/nominatim_db/tools/postcodes.py +++ b/src/nominatim_db/tools/postcodes.py @@ -16,7 +16,7 @@ import gzip import logging from math import isfinite -from psycopg2 import sql as pysql +from psycopg import sql as pysql from ..db.connection import connect, Connection, table_exists from ..utils.centroid import PointsCentroid @@ -76,30 +76,30 @@ class _PostcodeCollector: with conn.cursor() as cur: if to_add: - cur.execute_values( + cur.executemany(pysql.SQL( """INSERT INTO location_postcode (place_id, indexed_status, country_code, - postcode, geometry) VALUES %s""", - to_add, - template=pysql.SQL("""(nextval('seq_place'), 1, {}, - %s, 'SRID=4326;POINT(%s %s)') - """).format(pysql.Literal(self.country))) + postcode, geometry) + VALUES (nextval('seq_place'), 1, {}, %s, + ST_SetSRID(ST_MakePoint(%s, %s), 4326)) + """).format(pysql.Literal(self.country)), + to_add) if to_delete: cur.execute("""DELETE FROM location_postcode WHERE country_code = %s and postcode = any(%s) """, (self.country, to_delete)) if to_update: - cur.execute_values( + cur.executemany( pysql.SQL("""UPDATE location_postcode SET indexed_status = 2, - geometry = ST_SetSRID(ST_Point(v.x, v.y), 4326) - FROM (VALUES %s) AS v (pc, x, y) - WHERE country_code = {} and postcode = pc - """).format(pysql.Literal(self.country)), to_update) + geometry = ST_SetSRID(ST_Point(%s, %s), 4326) + WHERE country_code = {} and postcode = %s + """).format(pysql.Literal(self.country)), + to_update) def _compute_changes(self, conn: Connection) \ - -> Tuple[List[Tuple[str, float, float]], List[str], List[Tuple[str, float, float]]]: + -> Tuple[List[Tuple[str, float, float]], List[str], List[Tuple[float, float, 
str]]]: """ Compute which postcodes from the collected postcodes have to be added or modified and which from the location_postcode table have to be deleted. @@ -116,7 +116,7 @@ class _PostcodeCollector: if pcobj: newx, newy = pcobj.centroid() if (x - newx) > 0.0000001 or (y - newy) > 0.0000001: - to_update.append((postcode, newx, newy)) + to_update.append((newx, newy, postcode)) else: to_delete.append(postcode) diff --git a/src/nominatim_db/tools/refresh.py b/src/nominatim_db/tools/refresh.py index 2e2ffabd..d48c4e45 100644 --- a/src/nominatim_db/tools/refresh.py +++ b/src/nominatim_db/tools/refresh.py @@ -14,12 +14,12 @@ import logging from textwrap import dedent from pathlib import Path -from psycopg2 import sql as pysql +from psycopg import sql as pysql from ..config import Configuration from ..db.connection import Connection, connect, postgis_version_tuple,\ drop_tables, table_exists -from ..db.utils import execute_file, CopyBuffer +from ..db.utils import execute_file from ..db.sql_preprocessor import SQLPreprocessor from ..version import NOMINATIM_VERSION @@ -68,8 +68,8 @@ def load_address_levels(conn: Connection, table: str, levels: Sequence[Mapping[s rank_address SMALLINT) """).format(pysql.Identifier(table))) - cur.execute_values(pysql.SQL("INSERT INTO {} VALUES %s") - .format(pysql.Identifier(table)), rows) + cur.executemany(pysql.SQL("INSERT INTO {} VALUES (%s, %s, %s, %s, %s)") + .format(pysql.Identifier(table)), rows) cur.execute(pysql.SQL('CREATE UNIQUE INDEX ON {} (country_code, class, type)') .format(pysql.Identifier(table))) @@ -155,7 +155,7 @@ def import_importance_csv(dsn: str, data_file: Path) -> int: if not data_file.exists(): return 1 - # Only import the first occurence of a wikidata ID. + # Only import the first occurrence of a wikidata ID. # This keeps indexes and table small. 
wd_done = set() @@ -169,24 +169,17 @@ def import_importance_csv(dsn: str, data_file: Path) -> int: wikidata TEXT ) """) - with gzip.open(str(data_file), 'rt') as fd, CopyBuffer() as buf: - for row in csv.DictReader(fd, delimiter='\t', quotechar='|'): - wd_id = int(row['wikidata_id'][1:]) - buf.add(row['language'], row['title'], row['importance'], - None if wd_id in wd_done else row['wikidata_id']) - wd_done.add(wd_id) + copy_cmd = """COPY wikimedia_importance(language, title, importance, wikidata) + FROM STDIN""" + with gzip.open(str(data_file), 'rt') as fd, cur.copy(copy_cmd) as copy: + for row in csv.DictReader(fd, delimiter='\t', quotechar='|'): + wd_id = int(row['wikidata_id'][1:]) + copy.write_row((row['language'], + row['title'], + row['importance'], + None if wd_id in wd_done else row['wikidata_id'])) + wd_done.add(wd_id) - if buf.size() > 10000000: - with conn.cursor() as cur: - buf.copy_out(cur, 'wikimedia_importance', - columns=['language', 'title', 'importance', - 'wikidata']) - - with conn.cursor() as cur: - buf.copy_out(cur, 'wikimedia_importance', - columns=['language', 'title', 'importance', 'wikidata']) - - with conn.cursor() as cur: cur.execute("""CREATE INDEX IF NOT EXISTS idx_wikimedia_importance_title ON wikimedia_importance (title)""") cur.execute("""CREATE INDEX IF NOT EXISTS idx_wikimedia_importance_wikidata diff --git a/src/nominatim_db/tools/special_phrases/sp_importer.py b/src/nominatim_db/tools/special_phrases/sp_importer.py index 4a63ff14..311e37e2 100644 --- a/src/nominatim_db/tools/special_phrases/sp_importer.py +++ b/src/nominatim_db/tools/special_phrases/sp_importer.py @@ -17,7 +17,7 @@ from typing import Iterable, Tuple, Mapping, Sequence, Optional, Set import logging import re -from psycopg2.sql import Identifier, SQL +from psycopg.sql import Identifier, SQL from ...typing import Protocol from ...config import Configuration diff --git a/src/nominatim_db/tools/tiger_data.py b/src/nominatim_db/tools/tiger_data.py index 
7c52b710..f4a7eba7 100644 --- a/src/nominatim_db/tools/tiger_data.py +++ b/src/nominatim_db/tools/tiger_data.py @@ -7,22 +7,22 @@ """ Functions for importing tiger data and handling tarbar and directory files """ -from typing import Any, TextIO, List, Union, cast +from typing import Any, TextIO, List, Union, cast, Iterator, Dict import csv import io import logging import os import tarfile -from psycopg2.extras import Json +from psycopg.types.json import Json from ..config import Configuration from ..db.connection import connect -from ..db.async_connection import WorkerPool from ..db.sql_preprocessor import SQLPreprocessor from ..errors import UsageError +from ..db.query_pool import QueryPool from ..data.place_info import PlaceInfo -from ..tokenizer.base import AbstractAnalyzer, AbstractTokenizer +from ..tokenizer.base import AbstractTokenizer from . import freeze LOG = logging.getLogger() @@ -63,13 +63,13 @@ class TigerInput: self.tar_handle.close() self.tar_handle = None + def __bool__(self) -> bool: + return bool(self.files) - def next_file(self) -> TextIO: + def get_file(self, fname: Union[str, tarfile.TarInfo]) -> TextIO: """ Return a file handle to the next file to be processed. Raises an IndexError if there is no file left. 
""" - fname = self.files.pop(0) - if self.tar_handle is not None: extracted = self.tar_handle.extractfile(fname) assert extracted is not None @@ -78,47 +78,22 @@ class TigerInput: return open(cast(str, fname), encoding='utf-8') - def __len__(self) -> int: - return len(self.files) - - -def handle_threaded_sql_statements(pool: WorkerPool, fd: TextIO, - analyzer: AbstractAnalyzer) -> None: - """ Handles sql statement with multiplexing - """ - lines = 0 - # Using pool of database connections to execute sql statements - - sql = "SELECT tiger_line_import(%s, %s, %s, %s, %s, %s)" - - for row in csv.DictReader(fd, delimiter=';'): - try: - address = dict(street=row['street'], postcode=row['postcode']) - args = ('SRID=4326;' + row['geometry'], - int(row['from']), int(row['to']), row['interpolation'], - Json(analyzer.process_place(PlaceInfo({'address': address}))), - analyzer.normalize_postcode(row['postcode'])) - except ValueError: - continue - pool.next_free_worker().perform(sql, args=args) - - lines += 1 - if lines == 1000: - print('.', end='', flush=True) - lines = 0 + def __iter__(self) -> Iterator[Dict[str, Any]]: + """ Iterate over the lines in each file. + """ + for fname in self.files: + fd = self.get_file(fname) + yield from csv.DictReader(fd, delimiter=';') -def add_tiger_data(data_dir: str, config: Configuration, threads: int, +async def add_tiger_data(data_dir: str, config: Configuration, threads: int, tokenizer: AbstractTokenizer) -> int: """ Import tiger data from directory or tar file `data dir`. """ dsn = config.get_libpq_dsn() with connect(dsn) as conn: - is_frozen = freeze.is_frozen(conn) - conn.close() - - if is_frozen: + if freeze.is_frozen(conn): raise UsageError("Tiger cannot be imported when database frozen (Github issue #3048)") with TigerInput(data_dir) as tar: @@ -133,13 +108,30 @@ def add_tiger_data(data_dir: str, config: Configuration, threads: int, # sql_query in chunks. 
place_threads = max(1, threads - 1) - with WorkerPool(dsn, place_threads, ignore_sql_errors=True) as pool: + async with QueryPool(dsn, place_threads, autocommit=True) as pool: with tokenizer.name_analyzer() as analyzer: - while tar: - with tar.next_file() as fd: - handle_threaded_sql_statements(pool, fd, analyzer) - - print('\n') + lines = 0 + for row in tar: + try: + address = dict(street=row['street'], postcode=row['postcode']) + args = ('SRID=4326;' + row['geometry'], + int(row['from']), int(row['to']), row['interpolation'], + Json(analyzer.process_place(PlaceInfo({'address': address}))), + analyzer.normalize_postcode(row['postcode'])) + except ValueError: + continue + + await pool.put_query( + """SELECT tiger_line_import(%s::GEOMETRY, %s::INT, + %s::INT, %s::TEXT, %s::JSONB, %s::TEXT)""", + args) + + lines += 1 + if lines == 1000: + print('.', end='', flush=True) + lines = 0 + + print('', flush=True) LOG.warning("Creating indexes on Tiger data") with connect(dsn) as conn: diff --git a/src/nominatim_db/typing.py b/src/nominatim_db/typing.py index f1abee82..6f0145c3 100644 --- a/src/nominatim_db/typing.py +++ b/src/nominatim_db/typing.py @@ -16,18 +16,13 @@ from typing import Any, Union, Mapping, TypeVar, Sequence, TYPE_CHECKING # pylint: disable=missing-class-docstring,useless-import-alias if TYPE_CHECKING: - import psycopg2.sql - import psycopg2.extensions - import psycopg2.extras import os StrPath = Union[str, 'os.PathLike[str]'] SysEnv = Mapping[str, str] -# psycopg2-related types - -Query = Union[str, bytes, 'psycopg2.sql.Composable'] +# psycopg-related types T_ResultKey = TypeVar('T_ResultKey', int, str) @@ -36,8 +31,6 @@ class DictCursorResult(Mapping[str, Any]): DictCursorResults = Sequence[DictCursorResult] -T_cursor = TypeVar('T_cursor', bound='psycopg2.extensions.cursor') - # The following typing features require typing_extensions to work # on all supported Python versions. # Only require this for type checking but not for normal operations. 
diff --git a/src/nominatim_db/version.py b/src/nominatim_db/version.py index 70e1ac14..fceee5d0 100644 --- a/src/nominatim_db/version.py +++ b/src/nominatim_db/version.py @@ -31,7 +31,7 @@ class NominatimVersion(NamedTuple): major: int minor: int patch_level: int - db_patch_level: Optional[int] + db_patch_level: int def __str__(self) -> str: if self.db_patch_level is None: @@ -47,6 +47,7 @@ class NominatimVersion(NamedTuple): return f"{self.major}.{self.minor}.{self.patch_level}" + def parse_version(version: str) -> NominatimVersion: """ Parse a version string into a version consisting of a tuple of four ints: major, minor, patch level, database patch level diff --git a/test/bdd/steps/nominatim_environment.py b/test/bdd/steps/nominatim_environment.py index dfbbee28..17a76745 100644 --- a/test/bdd/steps/nominatim_environment.py +++ b/test/bdd/steps/nominatim_environment.py @@ -9,14 +9,14 @@ import importlib import sys import tempfile -import psycopg2 -import psycopg2.extras +import psycopg +from psycopg import sql as pysql sys.path.insert(1, str((Path(__file__) / '..' / '..' / '..' / '..'/ 'src').resolve())) from nominatim_db import cli from nominatim_db.config import Configuration -from nominatim_db.db.connection import Connection +from nominatim_db.db.connection import Connection, register_hstore, execute_scalar from nominatim_db.tools import refresh from nominatim_db.tokenizer import factory as tokenizer_factory from steps.utils import run_script @@ -60,7 +60,7 @@ class NominatimEnvironment: """ Return a connection to the database with the given name. Uses configured host, user and port. 
""" - dbargs = {'database': dbname} + dbargs = {'dbname': dbname, 'row_factory': psycopg.rows.dict_row} if self.db_host: dbargs['host'] = self.db_host if self.db_port: @@ -69,8 +69,7 @@ class NominatimEnvironment: dbargs['user'] = self.db_user if self.db_pass: dbargs['password'] = self.db_pass - conn = psycopg2.connect(connection_factory=Connection, **dbargs) - return conn + return psycopg.connect(**dbargs) def next_code_coverage_file(self): """ Generate the next name for a coverage file. @@ -132,6 +131,8 @@ class NominatimEnvironment: conn = False refresh.setup_website(Path(self.website_dir.name) / 'website', self.get_test_config(), conn) + if conn: + conn.close() def get_test_config(self): @@ -160,11 +161,10 @@ class NominatimEnvironment: def db_drop_database(self, name): """ Drop the database with the given name. """ - conn = self.connect_database('postgres') - conn.set_isolation_level(0) - cur = conn.cursor() - cur.execute('DROP DATABASE IF EXISTS {}'.format(name)) - conn.close() + with self.connect_database('postgres') as conn: + conn.autocommit = True + conn.execute(pysql.SQL('DROP DATABASE IF EXISTS') + + pysql.Identifier(name)) def setup_template_db(self): """ Setup a template database that already contains common test data. @@ -249,16 +249,18 @@ class NominatimEnvironment: """ Setup a test against a fresh, empty test database. 
""" self.setup_template_db() - conn = self.connect_database(self.template_db) - conn.set_isolation_level(0) - cur = conn.cursor() - cur.execute('DROP DATABASE IF EXISTS {}'.format(self.test_db)) - cur.execute('CREATE DATABASE {} TEMPLATE = {}'.format(self.test_db, self.template_db)) - conn.close() + with self.connect_database(self.template_db) as conn: + conn.autocommit = True + conn.execute(pysql.SQL('DROP DATABASE IF EXISTS') + + pysql.Identifier(self.test_db)) + conn.execute(pysql.SQL('CREATE DATABASE {} TEMPLATE = {}').format( + pysql.Identifier(self.test_db), + pysql.Identifier(self.template_db))) + self.write_nominatim_config(self.test_db) context.db = self.connect_database(self.test_db) context.db.autocommit = True - psycopg2.extras.register_hstore(context.db, globally=False) + register_hstore(context.db) def teardown_db(self, context, force_drop=False): """ Remove the test database, if it exists. @@ -276,31 +278,26 @@ class NominatimEnvironment: dropped and always false returned. """ if self.reuse_template: - conn = self.connect_database('postgres') - with conn.cursor() as cur: - cur.execute('select count(*) from pg_database where datname = %s', - (name,)) - if cur.fetchone()[0] == 1: + with self.connect_database('postgres') as conn: + num = execute_scalar(conn, + 'select count(*) from pg_database where datname = %s', + (name,)) + if num == 1: return True - conn.close() else: self.db_drop_database(name) return False + def reindex_placex(self, db): """ Run the indexing step until all data in the placex has been processed. Indexing during updates can produce more data to index under some circumstances. That is why indexing may have to be run multiple times. """ - with db.cursor() as cur: - while True: - self.run_nominatim('index') + self.run_nominatim('index') - cur.execute("SELECT 'a' FROM placex WHERE indexed_status != 0 LIMIT 1") - if cur.rowcount == 0: - return def run_nominatim(self, *cmdline): """ Run the nominatim command-line tool via the library. 
diff --git a/test/bdd/steps/steps_db_ops.py b/test/bdd/steps/steps_db_ops.py index 441198fd..a0dd9b34 100644 --- a/test/bdd/steps/steps_db_ops.py +++ b/test/bdd/steps/steps_db_ops.py @@ -7,7 +7,8 @@ import logging from itertools import chain -import psycopg2.extras +import psycopg +from psycopg import sql as pysql from place_inserter import PlaceColumn from table_compare import NominatimID, DBRow @@ -18,7 +19,7 @@ from nominatim_db.tokenizer import factory as tokenizer_factory def check_database_integrity(context): """ Check some generic constraints on the tables. """ - with context.db.cursor() as cur: + with context.db.cursor(row_factory=psycopg.rows.tuple_row) as cur: # place_addressline should not have duplicate (place_id, address_place_id) cur.execute("""SELECT count(*) FROM (SELECT place_id, address_place_id, count(*) as c @@ -54,7 +55,7 @@ def add_data_to_planet_relations(context): with context.db.cursor() as cur: cur.execute("SELECT value FROM osm2pgsql_properties WHERE property = 'db_format'") row = cur.fetchone() - if row is None or row[0] == '1': + if row is None or row['value'] == '1': for r in context.table: last_node = 0 last_way = 0 @@ -96,8 +97,8 @@ def add_data_to_planet_relations(context): cur.execute("""INSERT INTO planet_osm_rels (id, tags, members) VALUES (%s, %s, %s)""", - (r['id'], psycopg2.extras.Json(tags), - psycopg2.extras.Json(members))) + (r['id'], psycopg.types.json.Json(tags), + psycopg.types.json.Json(members))) @given("the ways") def add_data_to_planet_ways(context): @@ -107,10 +108,10 @@ def add_data_to_planet_ways(context): with context.db.cursor() as cur: cur.execute("SELECT value FROM osm2pgsql_properties WHERE property = 'db_format'") row = cur.fetchone() - json_tags = row is not None and row[0] != '1' + json_tags = row is not None and row['value'] != '1' for r in context.table: if json_tags: - tags = psycopg2.extras.Json({h[5:]: r[h] for h in r.headings if h.startswith("tags+")}) + tags = psycopg.types.json.Json({h[5:]: r[h] 
for h in r.headings if h.startswith("tags+")}) else: tags = list(chain.from_iterable([(h[5:], r[h]) for h in r.headings if h.startswith("tags+")])) @@ -197,7 +198,7 @@ def check_place_contents(context, table, exact): expected rows are expected to be present with at least one database row. When 'exactly' is given, there must not be additional rows in the database. """ - with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: + with context.db.cursor() as cur: expected_content = set() for row in context.table: nid = NominatimID(row['object']) @@ -215,8 +216,9 @@ def check_place_contents(context, table, exact): DBRow(nid, res, context).assert_row(row, ['object']) if exact: - cur.execute('SELECT osm_type, osm_id, class from {}'.format(table)) - actual = set([(r[0], r[1], r[2]) for r in cur]) + cur.execute(pysql.SQL('SELECT osm_type, osm_id, class from') + + pysql.Identifier(table)) + actual = set([(r['osm_type'], r['osm_id'], r['class']) for r in cur]) assert expected_content == actual, \ f"Missing entries: {expected_content - actual}\n" \ f"Not expected in table: {actual - expected_content}" @@ -227,7 +229,7 @@ def check_place_has_entry(context, table, oid): """ Ensure that no database row for the given object exists. The ID must be of the form '[:]'. 
""" - with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: + with context.db.cursor() as cur: NominatimID(oid).query_osm_id(cur, "SELECT * FROM %s where {}" % table) assert cur.rowcount == 0, \ "Found {} entries for ID {}".format(cur.rowcount, oid) @@ -244,7 +246,7 @@ def check_search_name_contents(context, exclude): tokenizer = tokenizer_factory.get_tokenizer_for_db(context.nominatim.get_test_config()) with tokenizer.name_analyzer() as analyzer: - with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: + with context.db.cursor() as cur: for row in context.table: nid = NominatimID(row['object']) nid.row_by_place_id(cur, 'search_name', @@ -276,7 +278,7 @@ def check_search_name_has_entry(context, oid): """ Check that there is noentry in the search_name table for the given objects. IDs are in format '[:]'. """ - with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: + with context.db.cursor() as cur: NominatimID(oid).row_by_place_id(cur, 'search_name') assert cur.rowcount == 0, \ @@ -290,7 +292,7 @@ def check_location_postcode(context): All rows must be present as excepted and there must not be additional rows. 
""" - with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: + with context.db.cursor() as cur: cur.execute("SELECT *, ST_AsText(geometry) as geomtxt FROM location_postcode") assert cur.rowcount == len(list(context.table)), \ "Postcode table has {} rows, expected {}.".format(cur.rowcount, len(list(context.table))) @@ -321,7 +323,7 @@ def check_word_table_for_postcodes(context, exclude, postcodes): plist.sort() - with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: + with context.db.cursor() as cur: if nctx.tokenizer != 'legacy': cur.execute("SELECT word FROM word WHERE type = 'P' and word = any(%s)", (plist,)) @@ -330,7 +332,7 @@ def check_word_table_for_postcodes(context, exclude, postcodes): and class = 'place' and type = 'postcode'""", (plist,)) - found = [row[0] for row in cur] + found = [row['word'] for row in cur] assert len(found) == len(set(found)), f"Duplicate rows for postcodes: {found}" if exclude: @@ -347,7 +349,7 @@ def check_place_addressline(context): representing the addressee and the 'address' column, representing the address item. """ - with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: + with context.db.cursor() as cur: for row in context.table: nid = NominatimID(row['object']) pid = nid.get_place_id(cur) @@ -366,7 +368,7 @@ def check_place_addressline_exclude(context): """ Check that the place_addressline doesn't contain any entries for the given addressee/address item pairs. """ - with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: + with context.db.cursor() as cur: for row in context.table: pid = NominatimID(row['object']).get_place_id(cur) apid = NominatimID(row['address']).get_place_id(cur, allow_empty=True) @@ -381,7 +383,7 @@ def check_place_addressline_exclude(context): def check_location_property_osmline(context, oid, neg): """ Check that the given way is present in the interpolation table. 
""" - with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: + with context.db.cursor() as cur: cur.execute("""SELECT *, ST_AsText(linegeo) as geomtxt FROM location_property_osmline WHERE osm_id = %s AND startnumber IS NOT NULL""", @@ -417,7 +419,7 @@ def check_place_contents(context, exact): expected rows are expected to be present with at least one database row. When 'exactly' is given, there must not be additional rows in the database. """ - with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: + with context.db.cursor() as cur: expected_content = set() for row in context.table: if ':' in row['object']: @@ -447,7 +449,7 @@ def check_place_contents(context, exact): if exact: cur.execute('SELECT osm_id, startnumber from location_property_osmline') - actual = set([(r[0], r[1]) for r in cur]) + actual = set([(r['osm_id'], r['startnumber']) for r in cur]) assert expected_content == actual, \ f"Missing entries: {expected_content - actual}\n" \ f"Not expected in table: {actual - expected_content}" diff --git a/test/bdd/steps/table_compare.py b/test/bdd/steps/table_compare.py index cf2e12f1..4284fad9 100644 --- a/test/bdd/steps/table_compare.py +++ b/test/bdd/steps/table_compare.py @@ -10,6 +10,9 @@ Functions to facilitate accessing and comparing the content of DB tables. import re import json +import psycopg +from psycopg import sql as pysql + from steps.check_functions import Almost ID_REGEX = re.compile(r"(?P[NRW])(?P\d+)(:(?P\w+))?") @@ -73,7 +76,7 @@ class NominatimID: assert cur.rowcount == 1, \ "Place ID {!s} not unique. 
Found {} entries.".format(self, cur.rowcount) - return cur.fetchone()[0] + return cur.fetchone()['place_id'] class DBRow: @@ -152,9 +155,10 @@ class DBRow: def _has_centroid(self, expected): if expected == 'in geometry': - with self.context.db.cursor() as cur: - cur.execute("""SELECT ST_Within(ST_SetSRID(ST_Point({cx}, {cy}), 4326), - ST_SetSRID('{geomtxt}'::geometry, 4326))""".format(**self.db_row)) + with self.context.db.cursor(row_factory=psycopg.rows.tuple_row) as cur: + cur.execute("""SELECT ST_Within(ST_SetSRID(ST_Point(%(cx)s, %(cy)s), 4326), + ST_SetSRID(%(geomtxt)s::geometry, 4326))""", + (self.db_row)) return cur.fetchone()[0] if ' ' in expected: @@ -166,10 +170,11 @@ class DBRow: def _has_geometry(self, expected): geom = self.context.osm.parse_geometry(expected) - with self.context.db.cursor() as cur: - cur.execute("""SELECT ST_Equals(ST_SnapToGrid({}, 0.00001, 0.00001), - ST_SnapToGrid(ST_SetSRID('{}'::geometry, 4326), 0.00001, 0.00001))""".format( - geom, self.db_row['geomtxt'])) + with self.context.db.cursor(row_factory=psycopg.rows.tuple_row) as cur: + cur.execute(pysql.SQL("""SELECT ST_Equals(ST_SnapToGrid({}, 0.00001, 0.00001), + ST_SnapToGrid(ST_SetSRID({}::geometry, 4326), 0.00001, 0.00001))""") + .format(pysql.SQL(geom), + pysql.Literal(self.db_row['geomtxt']))) return cur.fetchone()[0] def assert_msg(self, name, value): @@ -209,7 +214,7 @@ class DBRow: if actual == 0: return "place ID 0" - with self.context.db.cursor() as cur: + with self.context.db.cursor(row_factory=psycopg.rows.tuple_row) as cur: cur.execute("""SELECT osm_type, osm_id, class FROM placex WHERE place_id = %s""", (actual, )) diff --git a/test/python/api/test_api_deletable_v1.py b/test/python/api/test_api_deletable_v1.py index 1b8dc34d..649dd8fc 100644 --- a/test/python/api/test_api_deletable_v1.py +++ b/test/python/api/test_api_deletable_v1.py @@ -13,8 +13,6 @@ from pathlib import Path import pytest import pytest_asyncio -import psycopg2.extras - from fake_adaptor import 
FakeAdaptor, FakeError, FakeResponse import nominatim_api.v1.server_glue as glue @@ -31,7 +29,6 @@ class TestDeletableEndPoint: @pytest.fixture(autouse=True) def setup_deletable_table(self, temp_db_cursor, table_factory, temp_db_with_extensions): - psycopg2.extras.register_hstore(temp_db_cursor) table_factory('import_polygon_delete', definition='osm_id bigint, osm_type char(1), class text, type text', content=[(345, 'N', 'boundary', 'administrative'), diff --git a/test/python/api/test_api_polygons_v1.py b/test/python/api/test_api_polygons_v1.py index bf51cd17..558be813 100644 --- a/test/python/api/test_api_polygons_v1.py +++ b/test/python/api/test_api_polygons_v1.py @@ -14,8 +14,6 @@ from pathlib import Path import pytest import pytest_asyncio -import psycopg2.extras - from fake_adaptor import FakeAdaptor, FakeError, FakeResponse import nominatim_api.v1.server_glue as glue @@ -32,8 +30,6 @@ class TestPolygonsEndPoint: @pytest.fixture(autouse=True) def setup_deletable_table(self, temp_db_cursor, table_factory, temp_db_with_extensions): - psycopg2.extras.register_hstore(temp_db_cursor) - self.now = dt.datetime.now() self.recent = dt.datetime.now() - dt.timedelta(days=3) diff --git a/test/python/cli/conftest.py b/test/python/cli/conftest.py index 1e3ca8ab..d5ade223 100644 --- a/test/python/cli/conftest.py +++ b/test/python/cli/conftest.py @@ -25,6 +25,23 @@ class MockParamCapture: return self.return_value +class AsyncMockParamCapture: + """ Mock that records the parameters with which a function was called + as well as the number of calls. 
+ """ + def __init__(self, retval=0): + self.called = 0 + self.return_value = retval + self.last_args = None + self.last_kwargs = None + + async def __call__(self, *args, **kwargs): + self.called += 1 + self.last_args = args + self.last_kwargs = kwargs + return self.return_value + + class DummyTokenizer: def __init__(self, *args, **kwargs): self.update_sql_functions_called = False @@ -69,6 +86,17 @@ def mock_func_factory(monkeypatch): return get_mock +@pytest.fixture +def async_mock_func_factory(monkeypatch): + def get_mock(module, func): + mock = AsyncMockParamCapture() + mock.func_name = func + monkeypatch.setattr(module, func, mock) + return mock + + return get_mock + + @pytest.fixture def cli_tokenizer_mock(monkeypatch): tok = DummyTokenizer() diff --git a/test/python/cli/test_cli.py b/test/python/cli/test_cli.py index 688afb7c..6586c5ec 100644 --- a/test/python/cli/test_cli.py +++ b/test/python/cli/test_cli.py @@ -17,6 +17,7 @@ import pytest import nominatim_db.indexer.indexer import nominatim_db.tools.add_osm_data import nominatim_db.tools.freeze +import nominatim_db.tools.tiger_data def test_cli_help(cli_call, capsys): @@ -52,8 +53,8 @@ def test_cli_add_data_object_command(cli_call, mock_func_factory, name, oid): -def test_cli_add_data_tiger_data(cli_call, cli_tokenizer_mock, mock_func_factory): - mock = mock_func_factory(nominatim_db.tools.tiger_data, 'add_tiger_data') +def test_cli_add_data_tiger_data(cli_call, cli_tokenizer_mock, async_mock_func_factory): + mock = async_mock_func_factory(nominatim_db.tools.tiger_data, 'add_tiger_data') assert cli_call('add-data', '--tiger-data', 'somewhere') == 0 @@ -68,38 +69,6 @@ def test_cli_serve_php(cli_call, mock_func_factory): assert func.called == 1 -def test_cli_serve_starlette_custom_server(cli_call, mock_func_factory): - pytest.importorskip("starlette") - mod = pytest.importorskip("uvicorn") - func = mock_func_factory(mod, "run") - - cli_call('serve', '--engine', 'starlette', '--server', 'foobar:4545') == 0 - - 
assert func.called == 1 - assert func.last_kwargs['host'] == 'foobar' - assert func.last_kwargs['port'] == 4545 - - -def test_cli_serve_starlette_custom_server_bad_port(cli_call, mock_func_factory): - pytest.importorskip("starlette") - mod = pytest.importorskip("uvicorn") - func = mock_func_factory(mod, "run") - - cli_call('serve', '--engine', 'starlette', '--server', 'foobar:45:45') == 1 - - -@pytest.mark.parametrize("engine", ['falcon', 'starlette']) -def test_cli_serve_uvicorn_based(cli_call, engine, mock_func_factory): - pytest.importorskip(engine) - mod = pytest.importorskip("uvicorn") - func = mock_func_factory(mod, "run") - - cli_call('serve', '--engine', engine) == 0 - - assert func.called == 1 - assert func.last_kwargs['host'] == '127.0.0.1' - assert func.last_kwargs['port'] == 8088 - class TestCliWithDb: @@ -120,16 +89,19 @@ class TestCliWithDb: @pytest.mark.parametrize("params,do_bnds,do_ranks", [ - ([], 1, 1), - (['--boundaries-only'], 1, 0), - (['--no-boundaries'], 0, 1), + ([], 2, 2), + (['--boundaries-only'], 2, 0), + (['--no-boundaries'], 0, 2), (['--boundaries-only', '--no-boundaries'], 0, 0)]) - def test_index_command(self, mock_func_factory, table_factory, + def test_index_command(self, monkeypatch, async_mock_func_factory, table_factory, params, do_bnds, do_ranks): table_factory('import_status', 'indexed bool') - bnd_mock = mock_func_factory(nominatim_db.indexer.indexer.Indexer, 'index_boundaries') - rank_mock = mock_func_factory(nominatim_db.indexer.indexer.Indexer, 'index_by_rank') - postcode_mock = mock_func_factory(nominatim_db.indexer.indexer.Indexer, 'index_postcodes') + bnd_mock = async_mock_func_factory(nominatim_db.indexer.indexer.Indexer, 'index_boundaries') + rank_mock = async_mock_func_factory(nominatim_db.indexer.indexer.Indexer, 'index_by_rank') + postcode_mock = async_mock_func_factory(nominatim_db.indexer.indexer.Indexer, 'index_postcodes') + + monkeypatch.setattr(nominatim_db.indexer.indexer.Indexer, 'has_pending', + [False, 
True].pop) assert self.call_nominatim('index', *params) == 0 diff --git a/test/python/cli/test_cmd_import.py b/test/python/cli/test_cmd_import.py index 85235e1e..e47d713c 100644 --- a/test/python/cli/test_cmd_import.py +++ b/test/python/cli/test_cmd_import.py @@ -34,7 +34,8 @@ class TestCliImportWithDb: @pytest.mark.parametrize('with_updates', [True, False]) - def test_import_full(self, mock_func_factory, with_updates, place_table, property_table): + def test_import_full(self, mock_func_factory, async_mock_func_factory, + with_updates, place_table, property_table): mocks = [ mock_func_factory(nominatim_db.tools.database_import, 'setup_database_skeleton'), mock_func_factory(nominatim_db.data.country_info, 'setup_country_tables'), @@ -42,15 +43,15 @@ class TestCliImportWithDb: mock_func_factory(nominatim_db.tools.refresh, 'import_wikipedia_articles'), mock_func_factory(nominatim_db.tools.refresh, 'import_secondary_importance'), mock_func_factory(nominatim_db.tools.database_import, 'truncate_data_tables'), - mock_func_factory(nominatim_db.tools.database_import, 'load_data'), + async_mock_func_factory(nominatim_db.tools.database_import, 'load_data'), mock_func_factory(nominatim_db.tools.database_import, 'create_tables'), mock_func_factory(nominatim_db.tools.database_import, 'create_table_triggers'), mock_func_factory(nominatim_db.tools.database_import, 'create_partition_tables'), - mock_func_factory(nominatim_db.tools.database_import, 'create_search_indices'), + async_mock_func_factory(nominatim_db.tools.database_import, 'create_search_indices'), mock_func_factory(nominatim_db.data.country_info, 'create_country_names'), mock_func_factory(nominatim_db.tools.refresh, 'load_address_levels_from_config'), mock_func_factory(nominatim_db.tools.postcodes, 'update_postcodes'), - mock_func_factory(nominatim_db.indexer.indexer.Indexer, 'index_full'), + async_mock_func_factory(nominatim_db.indexer.indexer.Indexer, 'index_full'), mock_func_factory(nominatim_db.tools.refresh, 
'setup_website'), ] @@ -72,14 +73,14 @@ class TestCliImportWithDb: assert mock.called == 1, "Mock '{}' not called".format(mock.func_name) - def test_import_continue_load_data(self, mock_func_factory): + def test_import_continue_load_data(self, mock_func_factory, async_mock_func_factory): mocks = [ mock_func_factory(nominatim_db.tools.database_import, 'truncate_data_tables'), - mock_func_factory(nominatim_db.tools.database_import, 'load_data'), - mock_func_factory(nominatim_db.tools.database_import, 'create_search_indices'), + async_mock_func_factory(nominatim_db.tools.database_import, 'load_data'), + async_mock_func_factory(nominatim_db.tools.database_import, 'create_search_indices'), mock_func_factory(nominatim_db.data.country_info, 'create_country_names'), mock_func_factory(nominatim_db.tools.postcodes, 'update_postcodes'), - mock_func_factory(nominatim_db.indexer.indexer.Indexer, 'index_full'), + async_mock_func_factory(nominatim_db.indexer.indexer.Indexer, 'index_full'), mock_func_factory(nominatim_db.tools.refresh, 'setup_website'), mock_func_factory(nominatim_db.db.properties, 'set_property') ] @@ -91,12 +92,12 @@ class TestCliImportWithDb: assert mock.called == 1, "Mock '{}' not called".format(mock.func_name) - def test_import_continue_indexing(self, mock_func_factory, placex_table, - temp_db_conn): + def test_import_continue_indexing(self, mock_func_factory, async_mock_func_factory, + placex_table, temp_db_conn): mocks = [ - mock_func_factory(nominatim_db.tools.database_import, 'create_search_indices'), + async_mock_func_factory(nominatim_db.tools.database_import, 'create_search_indices'), mock_func_factory(nominatim_db.data.country_info, 'create_country_names'), - mock_func_factory(nominatim_db.indexer.indexer.Indexer, 'index_full'), + async_mock_func_factory(nominatim_db.indexer.indexer.Indexer, 'index_full'), mock_func_factory(nominatim_db.tools.refresh, 'setup_website'), mock_func_factory(nominatim_db.db.properties, 'set_property') ] @@ -110,9 +111,9 @@ 
class TestCliImportWithDb: assert self.call_nominatim('import', '--continue', 'indexing') == 0 - def test_import_continue_postprocess(self, mock_func_factory): + def test_import_continue_postprocess(self, mock_func_factory, async_mock_func_factory): mocks = [ - mock_func_factory(nominatim_db.tools.database_import, 'create_search_indices'), + async_mock_func_factory(nominatim_db.tools.database_import, 'create_search_indices'), mock_func_factory(nominatim_db.data.country_info, 'create_country_names'), mock_func_factory(nominatim_db.tools.refresh, 'setup_website'), mock_func_factory(nominatim_db.db.properties, 'set_property') diff --git a/test/python/cli/test_cmd_refresh.py b/test/python/cli/test_cmd_refresh.py index 39396817..9074b2cc 100644 --- a/test/python/cli/test_cmd_refresh.py +++ b/test/python/cli/test_cmd_refresh.py @@ -45,9 +45,9 @@ class TestRefresh: assert self.tokenizer_mock.update_word_tokens_called - def test_refresh_postcodes(self, mock_func_factory, place_table): + def test_refresh_postcodes(self, async_mock_func_factory, mock_func_factory, place_table): func_mock = mock_func_factory(nominatim_db.tools.postcodes, 'update_postcodes') - idx_mock = mock_func_factory(nominatim_db.indexer.indexer.Indexer, 'index_postcodes') + idx_mock = async_mock_func_factory(nominatim_db.indexer.indexer.Indexer, 'index_postcodes') assert self.call_nominatim('refresh', '--postcodes') == 0 assert func_mock.called == 1 diff --git a/test/python/cli/test_cmd_replication.py b/test/python/cli/test_cmd_replication.py index 8c1e8ea6..21c6350d 100644 --- a/test/python/cli/test_cmd_replication.py +++ b/test/python/cli/test_cmd_replication.py @@ -47,8 +47,8 @@ def init_status(temp_db_conn, status_table): @pytest.fixture -def index_mock(mock_func_factory, tokenizer_mock, init_status): - return mock_func_factory(nominatim_db.indexer.indexer.Indexer, 'index_full') +def index_mock(async_mock_func_factory, tokenizer_mock, init_status): + return 
async_mock_func_factory(nominatim_db.indexer.indexer.Indexer, 'index_full') @pytest.fixture diff --git a/test/python/conftest.py b/test/python/conftest.py index d301c973..3ced3205 100644 --- a/test/python/conftest.py +++ b/test/python/conftest.py @@ -8,7 +8,8 @@ import itertools import sys from pathlib import Path -import psycopg2 +import psycopg +from psycopg import sql as pysql import pytest # always test against the source @@ -36,26 +37,23 @@ def temp_db(monkeypatch): exported into NOMINATIM_DATABASE_DSN. """ name = 'test_nominatim_python_unittest' - conn = psycopg2.connect(database='postgres') - conn.set_isolation_level(0) - with conn.cursor() as cur: - cur.execute('DROP DATABASE IF EXISTS {}'.format(name)) - cur.execute('CREATE DATABASE {}'.format(name)) - - conn.close() + with psycopg.connect(dbname='postgres', autocommit=True) as conn: + with conn.cursor() as cur: + cur.execute(pysql.SQL('DROP DATABASE IF EXISTS') + pysql.Identifier(name)) + cur.execute(pysql.SQL('CREATE DATABASE') + pysql.Identifier(name)) monkeypatch.setenv('NOMINATIM_DATABASE_DSN', 'dbname=' + name) - yield name - - conn = psycopg2.connect(database='postgres') + with psycopg.connect(dbname=name) as conn: + with conn.cursor() as cur: + cur.execute('CREATE EXTENSION hstore') - conn.set_isolation_level(0) - with conn.cursor() as cur: - cur.execute('DROP DATABASE IF EXISTS {}'.format(name)) + yield name - conn.close() + with psycopg.connect(dbname='postgres', autocommit=True) as conn: + with conn.cursor() as cur: + cur.execute('DROP DATABASE IF EXISTS {}'.format(name)) @pytest.fixture @@ -65,11 +63,9 @@ def dsn(temp_db): @pytest.fixture def temp_db_with_extensions(temp_db): - conn = psycopg2.connect(database=temp_db) - with conn.cursor() as cur: - cur.execute('CREATE EXTENSION hstore; CREATE EXTENSION postgis;') - conn.commit() - conn.close() + with psycopg.connect(dbname=temp_db) as conn: + with conn.cursor() as cur: + cur.execute('CREATE EXTENSION postgis') return temp_db @@ -77,7 +73,8 @@ 
def temp_db_with_extensions(temp_db): def temp_db_conn(temp_db): """ Connection to the test database. """ - with connection.connect('dbname=' + temp_db) as conn: + with connection.connect('', autocommit=True, dbname=temp_db) as conn: + connection.register_hstore(conn) yield conn @@ -86,22 +83,25 @@ def temp_db_cursor(temp_db): """ Connection and cursor towards the test database. The connection will be in auto-commit mode. """ - conn = psycopg2.connect('dbname=' + temp_db) - conn.set_isolation_level(0) - with conn.cursor(cursor_factory=CursorForTesting) as cur: - yield cur - conn.close() + with psycopg.connect(dbname=temp_db, autocommit=True, cursor_factory=CursorForTesting) as conn: + connection.register_hstore(conn) + with conn.cursor() as cur: + yield cur @pytest.fixture -def table_factory(temp_db_cursor): +def table_factory(temp_db_conn): """ A fixture that creates new SQL tables, potentially filled with content. """ def mk_table(name, definition='id INT', content=None): - temp_db_cursor.execute('CREATE TABLE {} ({})'.format(name, definition)) - if content is not None: - temp_db_cursor.execute_values("INSERT INTO {} VALUES %s".format(name), content) + with psycopg.ClientCursor(temp_db_conn) as cur: + cur.execute('CREATE TABLE {} ({})'.format(name, definition)) + if content: + sql = pysql.SQL("INSERT INTO {} VALUES ({})")\ + .format(pysql.Identifier(name), + pysql.SQL(',').join([pysql.Placeholder() for _ in range(len(content[0]))])) + cur.executemany(sql , content) return mk_table @@ -168,7 +168,6 @@ def place_row(place_table, temp_db_cursor): """ A factory for rows in the place table. The table is created as a prerequisite to the fixture. 
""" - psycopg2.extras.register_hstore(temp_db_cursor) idseq = itertools.count(1001) def _insert(osm_type='N', osm_id=None, cls='amenity', typ='cafe', names=None, admin_level=None, address=None, extratags=None, geom=None): diff --git a/test/python/cursor.py b/test/python/cursor.py index 7d586b3c..b3fc260a 100644 --- a/test/python/cursor.py +++ b/test/python/cursor.py @@ -5,11 +5,11 @@ # Copyright (C) 2024 by the Nominatim developer community. # For a full list of authors see the git log. """ -Specialised psycopg2 cursor with shortcut functions useful for testing. +Specialised psycopg cursor with shortcut functions useful for testing. """ -import psycopg2.extras +import psycopg -class CursorForTesting(psycopg2.extras.DictCursor): +class CursorForTesting(psycopg.Cursor): """ Extension to the DictCursor class that provides execution short-cuts that simplify writing assertions. """ @@ -59,9 +59,3 @@ class CursorForTesting(psycopg2.extras.DictCursor): return self.scalar('SELECT count(*) FROM ' + table) return self.scalar('SELECT count(*) FROM {} WHERE {}'.format(table, where)) - - - def execute_values(self, *args, **kwargs): - """ Execute the execute_values() function on the cursor. - """ - psycopg2.extras.execute_values(self, *args, **kwargs) diff --git a/test/python/db/test_async_connection.py b/test/python/db/test_async_connection.py deleted file mode 100644 index 9647bedc..00000000 --- a/test/python/db/test_async_connection.py +++ /dev/null @@ -1,113 +0,0 @@ -# SPDX-License-Identifier: GPL-3.0-or-later -# -# This file is part of Nominatim. (https://nominatim.org) -# -# Copyright (C) 2024 by the Nominatim developer community. -# For a full list of authors see the git log. -""" -Tests for function providing a non-blocking query interface towards PostgreSQL. 
-""" -from contextlib import closing -import concurrent.futures - -import pytest -import psycopg2 - -from nominatim_db.db.async_connection import DBConnection, DeadlockHandler - - -@pytest.fixture -def conn(temp_db): - with closing(DBConnection('dbname=' + temp_db)) as connection: - yield connection - - -@pytest.fixture -def simple_conns(temp_db): - conn1 = psycopg2.connect('dbname=' + temp_db) - conn2 = psycopg2.connect('dbname=' + temp_db) - - yield conn1.cursor(), conn2.cursor() - - conn1.close() - conn2.close() - - -def test_simple_query(conn, temp_db_cursor): - conn.connect() - - conn.perform('CREATE TABLE foo (id INT)') - conn.wait() - - assert temp_db_cursor.table_exists('foo') - - -def test_wait_for_query(conn): - conn.connect() - - conn.perform('SELECT pg_sleep(1)') - - assert not conn.is_done() - - conn.wait() - - -def test_bad_query(conn): - conn.connect() - - conn.perform('SELECT efasfjsea') - - with pytest.raises(psycopg2.ProgrammingError): - conn.wait() - - -def test_bad_query_ignore(temp_db): - with closing(DBConnection('dbname=' + temp_db, ignore_sql_errors=True)) as conn: - conn.connect() - - conn.perform('SELECT efasfjsea') - - conn.wait() - - -def exec_with_deadlock(cur, sql, detector): - with DeadlockHandler(lambda *args: detector.append(1)): - cur.execute(sql) - - -def test_deadlock(simple_conns): - cur1, cur2 = simple_conns - - cur1.execute("""CREATE TABLE t1 (id INT PRIMARY KEY, t TEXT); - INSERT into t1 VALUES (1, 'a'), (2, 'b')""") - cur1.connection.commit() - - cur1.execute("UPDATE t1 SET t = 'x' WHERE id = 1") - cur2.execute("UPDATE t1 SET t = 'x' WHERE id = 2") - - # This is the tricky part of the test. The first SQL command runs into - # a lock and blocks, so we have to run it in a separate thread. When the - # second deadlocking SQL statement is issued, Postgresql will abort one of - # the two transactions that cause the deadlock. There is no way to tell - # which one of the two. 
Therefore wrap both in a DeadlockHandler and - # expect that exactly one of the two triggers. - with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor: - deadlock_check = [] - try: - future = executor.submit(exec_with_deadlock, cur2, - "UPDATE t1 SET t = 'y' WHERE id = 1", - deadlock_check) - - while not future.running(): - pass - - - exec_with_deadlock(cur1, "UPDATE t1 SET t = 'y' WHERE id = 2", - deadlock_check) - finally: - # Whatever happens, make sure the deadlock gets resolved. - cur1.connection.rollback() - - future.result() - - assert len(deadlock_check) == 1 diff --git a/test/python/db/test_connection.py b/test/python/db/test_connection.py index 9f1442f3..a8b5d677 100644 --- a/test/python/db/test_connection.py +++ b/test/python/db/test_connection.py @@ -8,7 +8,7 @@ Tests for specialised connection and cursor classes. """ import pytest -import psycopg2 +import psycopg import nominatim_db.db.connection as nc @@ -73,7 +73,7 @@ def test_drop_many_tables(db, table_factory): def test_drop_table_non_existing_force(db): - with pytest.raises(psycopg2.ProgrammingError, match='.*does not exist.*'): + with pytest.raises(psycopg.ProgrammingError, match='.*does not exist.*'): nc.drop_tables(db, 'dfkjgjriogjigjgjrdghehtre', if_exists=False) def test_connection_server_version_tuple(db): diff --git a/test/python/db/test_sql_preprocessor.py b/test/python/db/test_sql_preprocessor.py index 114c5244..45109c70 100644 --- a/test/python/db/test_sql_preprocessor.py +++ b/test/python/db/test_sql_preprocessor.py @@ -8,6 +8,7 @@ Tests for SQL preprocessing. 
""" import pytest +import pytest_asyncio from nominatim_db.db.sql_preprocessor import SQLPreprocessor @@ -54,3 +55,17 @@ def test_load_file_with_params(sql_preprocessor, sql_factory, temp_db_conn, temp sql_preprocessor.run_sql_file(temp_db_conn, sqlfile, bar='XX', foo='ZZ') assert temp_db_cursor.scalar('SELECT test()') == 'ZZ XX' + + +@pytest.mark.asyncio +async def test_load_parallel_file(dsn, sql_preprocessor, tmp_path, temp_db_cursor): + (tmp_path / 'test.sql').write_text(""" + CREATE TABLE foo (a TEXT); + CREATE TABLE foo2(a TEXT);""" + + "\n---\nCREATE TABLE bar (b INT);") + + await sql_preprocessor.run_parallel_sql_file(dsn, 'test.sql', num_threads=4) + + assert temp_db_cursor.table_exists('foo') + assert temp_db_cursor.table_exists('foo2') + assert temp_db_cursor.table_exists('bar') diff --git a/test/python/db/test_utils.py b/test/python/db/test_utils.py index b4335ab0..7c46846d 100644 --- a/test/python/db/test_utils.py +++ b/test/python/db/test_utils.py @@ -58,103 +58,3 @@ def test_execute_file_with_post_code(dsn, tmp_path, temp_db_cursor): db_utils.execute_file(dsn, tmpfile, post_code='INSERT INTO test VALUES(23)') assert temp_db_cursor.row_set('SELECT * FROM test') == {(23, )} - - -class TestCopyBuffer: - TABLE_NAME = 'copytable' - - @pytest.fixture(autouse=True) - def setup_test_table(self, table_factory): - table_factory(self.TABLE_NAME, 'col_a INT, col_b TEXT') - - - def table_rows(self, cursor): - return cursor.row_set('SELECT * FROM ' + self.TABLE_NAME) - - - def test_copybuffer_empty(self): - with db_utils.CopyBuffer() as buf: - buf.copy_out(None, "dummy") - - - def test_all_columns(self, temp_db_cursor): - with db_utils.CopyBuffer() as buf: - buf.add(3, 'hum') - buf.add(None, 'f\\t') - - buf.copy_out(temp_db_cursor, self.TABLE_NAME) - - assert self.table_rows(temp_db_cursor) == {(3, 'hum'), (None, 'f\\t')} - - - def test_selected_columns(self, temp_db_cursor): - with db_utils.CopyBuffer() as buf: - buf.add('foo') - - buf.copy_out(temp_db_cursor, 
self.TABLE_NAME, - columns=['col_b']) - - assert self.table_rows(temp_db_cursor) == {(None, 'foo')} - - - def test_reordered_columns(self, temp_db_cursor): - with db_utils.CopyBuffer() as buf: - buf.add('one', 1) - buf.add(' two ', 2) - - buf.copy_out(temp_db_cursor, self.TABLE_NAME, - columns=['col_b', 'col_a']) - - assert self.table_rows(temp_db_cursor) == {(1, 'one'), (2, ' two ')} - - - def test_special_characters(self, temp_db_cursor): - with db_utils.CopyBuffer() as buf: - buf.add('foo\tbar') - buf.add('sun\nson') - buf.add('\\N') - - buf.copy_out(temp_db_cursor, self.TABLE_NAME, - columns=['col_b']) - - assert self.table_rows(temp_db_cursor) == {(None, 'foo\tbar'), - (None, 'sun\nson'), - (None, '\\N')} - - - -class TestCopyBufferJson: - TABLE_NAME = 'copytable' - - @pytest.fixture(autouse=True) - def setup_test_table(self, table_factory): - table_factory(self.TABLE_NAME, 'col_a INT, col_b JSONB') - - - def table_rows(self, cursor): - cursor.execute('SELECT * FROM ' + self.TABLE_NAME) - results = {k: v for k,v in cursor} - - assert len(results) == cursor.rowcount - - return results - - - def test_json_object(self, temp_db_cursor): - with db_utils.CopyBuffer() as buf: - buf.add(1, json.dumps({'test': 'value', 'number': 1})) - - buf.copy_out(temp_db_cursor, self.TABLE_NAME) - - assert self.table_rows(temp_db_cursor) == \ - {1: {'test': 'value', 'number': 1}} - - - def test_json_object_special_chras(self, temp_db_cursor): - with db_utils.CopyBuffer() as buf: - buf.add(1, json.dumps({'te\tst': 'va\nlue', 'nu"mber': None})) - - buf.copy_out(temp_db_cursor, self.TABLE_NAME) - - assert self.table_rows(temp_db_cursor) == \ - {1: {'te\tst': 'va\nlue', 'nu"mber': None}} diff --git a/test/python/indexer/test_indexing.py b/test/python/indexer/test_indexing.py index 93481e8a..fe65b69c 100644 --- a/test/python/indexer/test_indexing.py +++ b/test/python/indexer/test_indexing.py @@ -9,6 +9,7 @@ Tests for running the indexing. 
""" import itertools import pytest +import pytest_asyncio from nominatim_db.indexer import indexer from nominatim_db.tokenizer import factory @@ -21,9 +22,8 @@ class IndexerTestDB: self.postcode_id = itertools.count(700000) self.conn = conn - self.conn.set_isolation_level(0) + self.conn.autocimmit = True with self.conn.cursor() as cur: - cur.execute('CREATE EXTENSION hstore') cur.execute("""CREATE TABLE placex (place_id BIGINT, name HSTORE, class TEXT, @@ -156,7 +156,8 @@ def test_tokenizer(tokenizer_mock, project_env): @pytest.mark.parametrize("threads", [1, 15]) -def test_index_all_by_rank(test_db, threads, test_tokenizer): +@pytest.mark.asyncio +async def test_index_all_by_rank(test_db, threads, test_tokenizer): for rank in range(31): test_db.add_place(rank_address=rank, rank_search=rank) test_db.add_osmline() @@ -165,7 +166,7 @@ def test_index_all_by_rank(test_db, threads, test_tokenizer): assert test_db.osmline_unindexed() == 1 idx = indexer.Indexer('dbname=test_nominatim_python_unittest', test_tokenizer, threads) - idx.index_by_rank(0, 30) + await idx.index_by_rank(0, 30) assert test_db.placex_unindexed() == 0 assert test_db.osmline_unindexed() == 0 @@ -190,7 +191,8 @@ def test_index_all_by_rank(test_db, threads, test_tokenizer): @pytest.mark.parametrize("threads", [1, 15]) -def test_index_partial_without_30(test_db, threads, test_tokenizer): +@pytest.mark.asyncio +async def test_index_partial_without_30(test_db, threads, test_tokenizer): for rank in range(31): test_db.add_place(rank_address=rank, rank_search=rank) test_db.add_osmline() @@ -200,7 +202,7 @@ def test_index_partial_without_30(test_db, threads, test_tokenizer): idx = indexer.Indexer('dbname=test_nominatim_python_unittest', test_tokenizer, threads) - idx.index_by_rank(4, 15) + await idx.index_by_rank(4, 15) assert test_db.placex_unindexed() == 19 assert test_db.osmline_unindexed() == 1 @@ -211,7 +213,8 @@ def test_index_partial_without_30(test_db, threads, test_tokenizer): 
@pytest.mark.parametrize("threads", [1, 15]) -def test_index_partial_with_30(test_db, threads, test_tokenizer): +@pytest.mark.asyncio +async def test_index_partial_with_30(test_db, threads, test_tokenizer): for rank in range(31): test_db.add_place(rank_address=rank, rank_search=rank) test_db.add_osmline() @@ -220,7 +223,7 @@ def test_index_partial_with_30(test_db, threads, test_tokenizer): assert test_db.osmline_unindexed() == 1 idx = indexer.Indexer('dbname=test_nominatim_python_unittest', test_tokenizer, threads) - idx.index_by_rank(28, 30) + await idx.index_by_rank(28, 30) assert test_db.placex_unindexed() == 27 assert test_db.osmline_unindexed() == 0 @@ -230,7 +233,8 @@ def test_index_partial_with_30(test_db, threads, test_tokenizer): WHERE indexed_status = 0 AND rank_address between 1 and 27""") == 0 @pytest.mark.parametrize("threads", [1, 15]) -def test_index_boundaries(test_db, threads, test_tokenizer): +@pytest.mark.asyncio +async def test_index_boundaries(test_db, threads, test_tokenizer): for rank in range(4, 10): test_db.add_admin(rank_address=rank, rank_search=rank) for rank in range(31): @@ -241,7 +245,7 @@ def test_index_boundaries(test_db, threads, test_tokenizer): assert test_db.osmline_unindexed() == 1 idx = indexer.Indexer('dbname=test_nominatim_python_unittest', test_tokenizer, threads) - idx.index_boundaries(0, 30) + await idx.index_boundaries(0, 30) assert test_db.placex_unindexed() == 31 assert test_db.osmline_unindexed() == 1 @@ -252,21 +256,23 @@ def test_index_boundaries(test_db, threads, test_tokenizer): @pytest.mark.parametrize("threads", [1, 15]) -def test_index_postcodes(test_db, threads, test_tokenizer): +@pytest.mark.asyncio +async def test_index_postcodes(test_db, threads, test_tokenizer): for postcode in range(1000): test_db.add_postcode('de', postcode) for postcode in range(32000, 33000): test_db.add_postcode('us', postcode) idx = indexer.Indexer('dbname=test_nominatim_python_unittest', test_tokenizer, threads) - 
idx.index_postcodes() + await idx.index_postcodes() assert test_db.scalar("""SELECT count(*) FROM location_postcode WHERE indexed_status != 0""") == 0 @pytest.mark.parametrize("analyse", [True, False]) -def test_index_full(test_db, analyse, test_tokenizer): +@pytest.mark.asyncio +async def test_index_full(test_db, analyse, test_tokenizer): for rank in range(4, 10): test_db.add_admin(rank_address=rank, rank_search=rank) for rank in range(31): @@ -276,22 +282,9 @@ def test_index_full(test_db, analyse, test_tokenizer): test_db.add_postcode('de', postcode) idx = indexer.Indexer('dbname=test_nominatim_python_unittest', test_tokenizer, 4) - idx.index_full(analyse=analyse) + await idx.index_full(analyse=analyse) assert test_db.placex_unindexed() == 0 assert test_db.osmline_unindexed() == 0 assert test_db.scalar("""SELECT count(*) FROM location_postcode WHERE indexed_status != 0""") == 0 - - -@pytest.mark.parametrize("threads", [1, 15]) -def test_index_reopen_connection(test_db, threads, monkeypatch, test_tokenizer): - monkeypatch.setattr(indexer.WorkerPool, "REOPEN_CONNECTIONS_AFTER", 15) - - for _ in range(1000): - test_db.add_place(rank_address=30, rank_search=30) - - idx = indexer.Indexer('dbname=test_nominatim_python_unittest', test_tokenizer, threads) - idx.index_by_rank(28, 30) - - assert test_db.placex_unindexed() == 0 diff --git a/test/python/mock_icu_word_table.py b/test/python/mock_icu_word_table.py index 67be1892..e8b4390f 100644 --- a/test/python/mock_icu_word_table.py +++ b/test/python/mock_icu_word_table.py @@ -36,9 +36,9 @@ class MockIcuWordTable: with self.conn.cursor() as cur: cur.execute("""INSERT INTO word (word_token, type, word, info) VALUES (%s, 'S', %s, - json_build_object('class', %s, - 'type', %s, - 'op', %s)) + json_build_object('class', %s::text, + 'type', %s::text, + 'op', %s::text)) """, (word_token, word, cls, typ, oper)) self.conn.commit() @@ -71,7 +71,7 @@ class MockIcuWordTable: word = word_tokens[0] for token in word_tokens: 
cur.execute("""INSERT INTO word (word_id, word_token, type, word, info) - VALUES (%s, %s, 'H', %s, jsonb_build_object('lookup', %s)) + VALUES (%s, %s, 'H', %s, jsonb_build_object('lookup', %s::text)) """, (word_id, token, word, word_tokens[0])) self.conn.commit() diff --git a/test/python/mock_legacy_word_table.py b/test/python/mock_legacy_word_table.py index d1c523eb..d3f81a4d 100644 --- a/test/python/mock_legacy_word_table.py +++ b/test/python/mock_legacy_word_table.py @@ -68,7 +68,7 @@ class MockLegacyWordTable: def get_special(self): with self.conn.cursor() as cur: - cur.execute("""SELECT word_token, word, class, type, operator + cur.execute("""SELECT word_token, word, class as cls, type, operator FROM word WHERE class != 'place'""") result = set((tuple(row) for row in cur)) assert len(result) == cur.rowcount, "Word table has duplicates." diff --git a/test/python/mocks.py b/test/python/mocks.py index 82e700bd..cde0b7bb 100644 --- a/test/python/mocks.py +++ b/test/python/mocks.py @@ -9,8 +9,6 @@ Custom mocks for testing. """ import itertools -import psycopg2.extras - from nominatim_db.db import properties # This must always point to the mock word table for the default tokenizer. @@ -56,7 +54,6 @@ class MockPlacexTable: admin_level=None, address=None, extratags=None, geom='POINT(10 4)', country=None, housenumber=None, rank_search=30): with self.conn.cursor() as cur: - psycopg2.extras.register_hstore(cur) cur.execute("""INSERT INTO placex (place_id, osm_type, osm_id, class, type, name, admin_level, address, housenumber, rank_search, diff --git a/test/python/tools/test_database_import.py b/test/python/tools/test_database_import.py index 9d56efa0..df204298 100644 --- a/test/python/tools/test_database_import.py +++ b/test/python/tools/test_database_import.py @@ -8,10 +8,11 @@ Tests for functions to import a new database. 
""" from pathlib import Path -from contextlib import closing import pytest -import psycopg2 +import pytest_asyncio +import psycopg +from psycopg import sql as pysql from nominatim_db.tools import database_import from nominatim_db.errors import UsageError @@ -21,10 +22,7 @@ class TestDatabaseSetup: @pytest.fixture(autouse=True) def setup_nonexistant_db(self): - conn = psycopg2.connect(database='postgres') - - try: - conn.set_isolation_level(0) + with psycopg.connect(dbname='postgres', autocommit=True) as conn: with conn.cursor() as cur: cur.execute(f'DROP DATABASE IF EXISTS {self.DBNAME}') @@ -32,22 +30,17 @@ class TestDatabaseSetup: with conn.cursor() as cur: cur.execute(f'DROP DATABASE IF EXISTS {self.DBNAME}') - finally: - conn.close() + @pytest.fixture def cursor(self): - conn = psycopg2.connect(database=self.DBNAME) - - try: + with psycopg.connect(dbname=self.DBNAME) as conn: with conn.cursor() as cur: yield cur - finally: - conn.close() def conn(self): - return closing(psycopg2.connect(database=self.DBNAME)) + return psycopg.connect(dbname=self.DBNAME) def test_setup_skeleton(self): @@ -178,18 +171,19 @@ def test_truncate_database_tables(temp_db_conn, temp_db_cursor, table_factory, w @pytest.mark.parametrize("threads", (1, 5)) -def test_load_data(dsn, place_row, placex_table, osmline_table, +@pytest.mark.asyncio +async def test_load_data(dsn, place_row, placex_table, osmline_table, temp_db_cursor, threads): for func in ('precompute_words', 'getorcreate_housenumber_id', 'make_standard_name'): - temp_db_cursor.execute(f"""CREATE FUNCTION {func} (src TEXT) - RETURNS TEXT AS $$ SELECT 'a'::TEXT $$ LANGUAGE SQL - """) + temp_db_cursor.execute(pysql.SQL("""CREATE FUNCTION {} (src TEXT) + RETURNS TEXT AS $$ SELECT 'a'::TEXT $$ LANGUAGE SQL + """).format(pysql.Identifier(func))) for oid in range(100, 130): place_row(osm_id=oid) place_row(osm_type='W', osm_id=342, cls='place', typ='houses', geom='SRID=4326;LINESTRING(0 0, 10 10)') - database_import.load_data(dsn, 
threads) + await database_import.load_data(dsn, threads) assert temp_db_cursor.table_rows('placex') == 30 assert temp_db_cursor.table_rows('location_property_osmline') == 1 @@ -241,11 +235,12 @@ class TestSetupSQL: @pytest.mark.parametrize("drop", [True, False]) - def test_create_search_indices(self, temp_db_conn, temp_db_cursor, drop): + @pytest.mark.asyncio + async def test_create_search_indices(self, temp_db_conn, temp_db_cursor, drop): self.write_sql('indices.sql', """CREATE FUNCTION test() RETURNS bool AS $$ SELECT {{drop}} $$ LANGUAGE SQL""") - database_import.create_search_indices(temp_db_conn, self.config, drop) + await database_import.create_search_indices(temp_db_conn, self.config, drop) temp_db_cursor.scalar('SELECT test()') == drop diff --git a/test/python/tools/test_migration.py b/test/python/tools/test_migration.py index 8821f694..2c7b2d56 100644 --- a/test/python/tools/test_migration.py +++ b/test/python/tools/test_migration.py @@ -8,7 +8,6 @@ Tests for migration functions """ import pytest -import psycopg2.extras from nominatim_db.tools import migration from nominatim_db.errors import UsageError @@ -44,7 +43,6 @@ def test_no_migration_old_versions(temp_db_with_extensions, table_factory, def_c def test_set_up_migration_for_36(temp_db_with_extensions, temp_db_cursor, table_factory, def_config, monkeypatch, postprocess_mock): - psycopg2.extras.register_hstore(temp_db_cursor) # don't actually run any migration, except the property table creation monkeypatch.setattr(migration, '_MIGRATION_FUNCTIONS', [((3, 5, 0, 99), migration.add_nominatim_property_table)]) diff --git a/test/python/tools/test_postcodes.py b/test/python/tools/test_postcodes.py index febb2271..f035bb19 100644 --- a/test/python/tools/test_postcodes.py +++ b/test/python/tools/test_postcodes.py @@ -47,7 +47,7 @@ class MockPostcodeTable: country_code, postcode, geometry) VALUES (nextval('seq_place'), 1, %s, %s, - 'SRID=4326;POINT(%s %s)')""", + ST_SetSRID(ST_MakePoint(%s, %s), 4326))""", 
(country, postcode, x, y)) self.conn.commit() diff --git a/test/python/tools/test_tiger_data.py b/test/python/tools/test_tiger_data.py index fc01f22f..5d65fafb 100644 --- a/test/python/tools/test_tiger_data.py +++ b/test/python/tools/test_tiger_data.py @@ -11,6 +11,7 @@ import tarfile from textwrap import dedent import pytest +import pytest_asyncio from nominatim_db.db.connection import execute_scalar from nominatim_db.tools import tiger_data, freeze @@ -76,82 +77,91 @@ def csv_factory(tmp_path): @pytest.mark.parametrize("threads", (1, 5)) -def test_add_tiger_data(def_config, src_dir, tiger_table, tokenizer_mock, threads): - tiger_data.add_tiger_data(str(src_dir / 'test' / 'testdb' / 'tiger'), - def_config, threads, tokenizer_mock()) +@pytest.mark.asyncio +async def test_add_tiger_data(def_config, src_dir, tiger_table, tokenizer_mock, threads): + await tiger_data.add_tiger_data(str(src_dir / 'test' / 'testdb' / 'tiger'), + def_config, threads, tokenizer_mock()) assert tiger_table.count() == 6213 -def test_add_tiger_data_database_frozen(def_config, temp_db_conn, tiger_table, tokenizer_mock, +@pytest.mark.asyncio +async def test_add_tiger_data_database_frozen(def_config, temp_db_conn, tiger_table, tokenizer_mock, tmp_path): freeze.drop_update_tables(temp_db_conn) with pytest.raises(UsageError) as excinfo: - tiger_data.add_tiger_data(str(tmp_path), def_config, 1, tokenizer_mock()) + await tiger_data.add_tiger_data(str(tmp_path), def_config, 1, tokenizer_mock()) assert "database frozen" in str(excinfo.value) assert tiger_table.count() == 0 -def test_add_tiger_data_no_files(def_config, tiger_table, tokenizer_mock, + +@pytest.mark.asyncio +async def test_add_tiger_data_no_files(def_config, tiger_table, tokenizer_mock, tmp_path): - tiger_data.add_tiger_data(str(tmp_path), def_config, 1, tokenizer_mock()) + await tiger_data.add_tiger_data(str(tmp_path), def_config, 1, tokenizer_mock()) assert tiger_table.count() == 0 -def test_add_tiger_data_bad_file(def_config, 
tiger_table, tokenizer_mock, +@pytest.mark.asyncio +async def test_add_tiger_data_bad_file(def_config, tiger_table, tokenizer_mock, tmp_path): sqlfile = tmp_path / '1010.csv' sqlfile.write_text("""Random text""") - tiger_data.add_tiger_data(str(tmp_path), def_config, 1, tokenizer_mock()) + await tiger_data.add_tiger_data(str(tmp_path), def_config, 1, tokenizer_mock()) assert tiger_table.count() == 0 -def test_add_tiger_data_hnr_nan(def_config, tiger_table, tokenizer_mock, +@pytest.mark.asyncio +async def test_add_tiger_data_hnr_nan(def_config, tiger_table, tokenizer_mock, csv_factory, tmp_path): csv_factory('file1', hnr_from=99) csv_factory('file2', hnr_from='L12') csv_factory('file3', hnr_to='12.4') - tiger_data.add_tiger_data(str(tmp_path), def_config, 1, tokenizer_mock()) + await tiger_data.add_tiger_data(str(tmp_path), def_config, 1, tokenizer_mock()) assert tiger_table.count() == 1 - assert tiger_table.row()['start'] == 99 + assert tiger_table.row().start == 99 @pytest.mark.parametrize("threads", (1, 5)) -def test_add_tiger_data_tarfile(def_config, tiger_table, tokenizer_mock, +@pytest.mark.asyncio +async def test_add_tiger_data_tarfile(def_config, tiger_table, tokenizer_mock, tmp_path, src_dir, threads): tar = tarfile.open(str(tmp_path / 'sample.tar.gz'), "w:gz") tar.add(str(src_dir / 'test' / 'testdb' / 'tiger' / '01001.csv')) tar.close() - tiger_data.add_tiger_data(str(tmp_path / 'sample.tar.gz'), def_config, threads, - tokenizer_mock()) + await tiger_data.add_tiger_data(str(tmp_path / 'sample.tar.gz'), def_config, threads, + tokenizer_mock()) assert tiger_table.count() == 6213 -def test_add_tiger_data_bad_tarfile(def_config, tiger_table, tokenizer_mock, +@pytest.mark.asyncio +async def test_add_tiger_data_bad_tarfile(def_config, tiger_table, tokenizer_mock, tmp_path): tarfile = tmp_path / 'sample.tar.gz' tarfile.write_text("""Random text""") with pytest.raises(UsageError): - tiger_data.add_tiger_data(str(tarfile), def_config, 1, tokenizer_mock()) + await 
tiger_data.add_tiger_data(str(tarfile), def_config, 1, tokenizer_mock()) -def test_add_tiger_data_empty_tarfile(def_config, tiger_table, tokenizer_mock, +@pytest.mark.asyncio +async def test_add_tiger_data_empty_tarfile(def_config, tiger_table, tokenizer_mock, tmp_path): tar = tarfile.open(str(tmp_path / 'sample.tar.gz'), "w:gz") tar.add(__file__) tar.close() - tiger_data.add_tiger_data(str(tmp_path / 'sample.tar.gz'), def_config, 1, - tokenizer_mock()) + await tiger_data.add_tiger_data(str(tmp_path / 'sample.tar.gz'), def_config, 1, + tokenizer_mock()) assert tiger_table.count() == 0