git.openstreetmap.org Git - nominatim.git/commitdiff
Merge pull request #3487 from lonvia/port-to-psycopg3
author Sarah Hoffmann <lonvia@denofr.de>
Mon, 29 Jul 2024 14:52:07 +0000 (16:52 +0200)
committer GitHub <noreply@github.com>
Mon, 29 Jul 2024 14:52:07 +0000 (16:52 +0200)
Move importer code to psycopg3
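
In broad strokes, the port swaps psycopg2 calls for their psycopg 3 equivalents.
A minimal before/after sketch (the table and values are illustrative, not taken
from this changeset):

    # psycopg2, as used before this change
    import psycopg2
    import psycopg2.extras

    conn = psycopg2.connect('dbname=nominatim')
    with conn.cursor() as cur:
        psycopg2.extras.execute_values(
            cur, 'INSERT INTO demo (a, b) VALUES %s', [(1, 2), (3, 4)])
    conn.commit()

    # psycopg 3, as used after this change
    import psycopg

    with psycopg.connect('dbname=nominatim') as conn:
        with conn.cursor() as cur:
            cur.executemany(
                'INSERT INTO demo (a, b) VALUES (%s, %s)', [(1, 2), (3, 4)])

The asynchronous machinery changes as well: the hand-rolled non-blocking
connections in db/async_connection.py are dropped in favour of asyncio plus
psycopg.AsyncConnection in the new db/query_pool.py, which is why many CLI
entry points below now wrap their work in asyncio.run().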

66 files changed:
.github/actions/build-nominatim/action.yml
docs/admin/Installation.md
lib-sql/indices.sql
packaging/nominatim-db/pyproject.toml
src/nominatim_api/core.py
src/nominatim_db/cli.py
src/nominatim_db/clicmd/add_data.py
src/nominatim_db/clicmd/admin.py
src/nominatim_db/clicmd/index.py
src/nominatim_db/clicmd/refresh.py
src/nominatim_db/clicmd/replication.py
src/nominatim_db/clicmd/setup.py
src/nominatim_db/config.py
src/nominatim_db/data/country_info.py
src/nominatim_db/db/async_connection.py [deleted file]
src/nominatim_db/db/connection.py
src/nominatim_db/db/properties.py
src/nominatim_db/db/query_pool.py [new file with mode: 0644]
src/nominatim_db/db/sql_preprocessor.py
src/nominatim_db/db/status.py
src/nominatim_db/db/utils.py
src/nominatim_db/indexer/indexer.py
src/nominatim_db/indexer/runners.py
src/nominatim_db/tokenizer/icu_tokenizer.py
src/nominatim_db/tokenizer/legacy_tokenizer.py
src/nominatim_db/tools/admin.py
src/nominatim_db/tools/check_database.py
src/nominatim_db/tools/collect_os_info.py
src/nominatim_db/tools/database_import.py
src/nominatim_db/tools/freeze.py
src/nominatim_db/tools/migration.py
src/nominatim_db/tools/postcodes.py
src/nominatim_db/tools/refresh.py
src/nominatim_db/tools/replication.py
src/nominatim_db/tools/special_phrases/sp_importer.py
src/nominatim_db/tools/tiger_data.py
src/nominatim_db/typing.py
src/nominatim_db/version.py
test/bdd/steps/nominatim_environment.py
test/bdd/steps/steps_db_ops.py
test/bdd/steps/table_compare.py
test/python/api/test_api_deletable_v1.py
test/python/api/test_api_polygons_v1.py
test/python/cli/conftest.py
test/python/cli/test_cli.py
test/python/cli/test_cmd_import.py
test/python/cli/test_cmd_refresh.py
test/python/cli/test_cmd_replication.py
test/python/conftest.py
test/python/cursor.py
test/python/db/test_async_connection.py [deleted file]
test/python/db/test_connection.py
test/python/db/test_sql_preprocessor.py
test/python/db/test_utils.py
test/python/indexer/test_indexing.py
test/python/mock_icu_word_table.py
test/python/mock_legacy_word_table.py
test/python/mocks.py
test/python/tokenizer/test_icu.py
test/python/tools/test_database_import.py
test/python/tools/test_import_special_phrases.py
test/python/tools/test_migration.py
test/python/tools/test_postcodes.py
test/python/tools/test_refresh.py
test/python/tools/test_tiger_data.py
vagrant/Install-on-Ubuntu-22.sh

index 17ff0ccfc14d391958309bae4020e598523b69e5..d601fc7b7880534eec5934e4bb0e7ab57a2c03c1 100644 (file)
@@ -27,9 +27,9 @@ runs:
           run: |
             sudo apt-get install -y -qq libboost-system-dev libboost-filesystem-dev libexpat1-dev zlib1g-dev libbz2-dev libpq-dev libproj-dev libicu-dev liblua${LUA_VERSION}-dev lua${LUA_VERSION} lua-dkjson nlohmann-json3-dev libspatialite7 libsqlite3-mod-spatialite
             if [ "$FLAVOUR" == "oldstuff" ]; then
-                pip3 install MarkupSafe==2.0.1 python-dotenv psycopg2==2.7.7 jinja2==2.8 psutil==5.4.2 pyicu==2.9 osmium PyYAML==5.1 sqlalchemy==1.4.31 datrie asyncpg aiosqlite
+                pip3 install MarkupSafe==2.0.1 python-dotenv jinja2==2.8 psutil==5.4.2 pyicu==2.9 osmium PyYAML==5.1 sqlalchemy==1.4.31 psycopg==3.1.7 datrie asyncpg aiosqlite
             else
-                sudo apt-get install -y -qq python3-icu python3-datrie python3-pyosmium python3-jinja2 python3-psutil python3-psycopg2 python3-dotenv python3-yaml
+                sudo apt-get install -y -qq python3-icu python3-datrie python3-pyosmium python3-jinja2 python3-psutil python3-dotenv python3-yaml
                 pip3 install sqlalchemy psycopg aiosqlite
             fi
           shell: bash
index a2f1a084c5678a0d997dc02e4abf2a3298f364d4..9159ac62666287b81f70b754a8b2caf985917065 100644 (file)
@@ -36,19 +36,15 @@ For running Nominatim:
 
 Furthermore the following Python libraries are required:
 
-  * [Psycopg2](https://www.psycopg.org) (2.7+)
+  * [Psycopg3](https://www.psycopg.org)
   * [Python Dotenv](https://github.com/theskumar/python-dotenv)
   * [psutil](https://github.com/giampaolo/psutil)
   * [Jinja2](https://palletsprojects.com/p/jinja/)
-  * [SQLAlchemy](https://www.sqlalchemy.org/) (1.4.31+ with greenlet support)
-  * one of
-    * [psycopg3](https://www.psycopg.org)
-    * [asyncpg](https://magicstack.github.io/asyncpg) (0.8+)
   * [PyICU](https://pypi.org/project/PyICU/)
   * [PyYaml](https://pyyaml.org/) (5.1+)
   * [datrie](https://github.com/pytries/datrie)
 
-These will be installed automatically, when using pip installation.
+These will be installed automatically when using pip installation.
 
 When using legacy CMake-based installation:
 
@@ -69,6 +65,8 @@ For running continuous updates:
 
 For running the Python frontend:
 
+  * [SQLAlchemy](https://www.sqlalchemy.org/) (1.4.31+ with greenlet support)
+  * [asyncpg](https://magicstack.github.io/asyncpg) (0.8+, only when using SQLAlchemy < 2.0)
   * one of the following web frameworks:
     * [falcon](https://falconframework.org/) (3.0+)
     * [starlette](https://www.starlette.io/)
index 8c176fdf8199b55503a9775f5cf8e78e973ce969..4d92452d1f588bdc8a0be63cd6d522e96bdab25d 100644 (file)
@@ -31,6 +31,7 @@ CREATE INDEX IF NOT EXISTS idx_placex_geometry ON placex
 -- Index is needed during import but can be dropped as soon as a full
 -- geometry index is in place. The partial index is almost as big as the full
 -- index.
+---
 DROP INDEX IF EXISTS idx_placex_geometry_lower_rank_ways;
 ---
 CREATE INDEX IF NOT EXISTS idx_placex_geometry_reverse_lookupPolygon
@@ -60,7 +61,6 @@ CREATE INDEX IF NOT EXISTS idx_postcode_postcode
 ---
   DROP INDEX IF EXISTS idx_placex_geometry_address_area_candidates;
   DROP INDEX IF EXISTS idx_placex_geometry_buildings;
-  DROP INDEX IF EXISTS idx_placex_geometry_lower_rank_ways;
   DROP INDEX IF EXISTS idx_placex_wikidata;
   DROP INDEX IF EXISTS idx_placex_rank_address_sector;
   DROP INDEX IF EXISTS idx_placex_rank_boundaries_sector;
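
Note on the added '---' line: the SQL preprocessor (see the sql_preprocessor.py
hunk further down) splits index files on '---' separator lines and submits each
part as an independent statement to the parallel query pool, so the extra
separator lets the DROP INDEX run on its own connection instead of being bundled
with the preceding CREATE INDEX statement.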
index 652f683f095e6bbeec5fcd718ab8c267538b21f9..841845f036f20e3ad5db19ba9f8c62fcfe44b4c4 100644 (file)
@@ -15,7 +15,7 @@ classifiers = [
     "Operating System :: OS Independent",
 ]
 dependencies = [
-    "psycopg2-binary",
+    "psycopg",
     "python-dotenv",
     "jinja2",
     "pyYAML>=5.1",
index 632c97a7a6af93d387a28bc070f22d41f67d3f33..c460d98c3ce099817138e30d3dfc646ff9235905 100644 (file)
@@ -7,7 +7,7 @@
 """
 Implementation of classes for API access via libraries.
 """
-from typing import Mapping, Optional, Any, AsyncIterator, Dict, Sequence, List, Tuple
+from typing import Mapping, Optional, Any, AsyncIterator, Dict, Sequence, List, Tuple, cast
 import asyncio
 import sys
 import contextlib
@@ -107,16 +107,16 @@ class NominatimAPIAsync: #pylint: disable=too-many-instance-attributes
                     raise UsageError(f"SQlite database '{params.get('dbname')}' does not exist.")
             else:
                 dsn = self.config.get_database_params()
-                query = {k: v for k, v in dsn.items()
+                query = {k: str(v) for k, v in dsn.items()
                          if k not in ('user', 'password', 'dbname', 'host', 'port')}
 
                 dburl = sa.engine.URL.create(
                            f'postgresql+{PGCORE_LIB}',
-                           database=dsn.get('dbname'),
-                           username=dsn.get('user'),
-                           password=dsn.get('password'),
-                           host=dsn.get('host'),
-                           port=int(dsn['port']) if 'port' in dsn else None,
+                           database=cast(str, dsn.get('dbname')),
+                           username=cast(str, dsn.get('user')),
+                           password=cast(str, dsn.get('password')),
+                           host=cast(str, dsn.get('host')),
+                           port=int(cast(str, dsn['port'])) if 'port' in dsn else None,
                            query=query)
 
             engine = sa_asyncio.create_async_engine(dburl, **extra_args)
index 41684fa1dbe9e37bf9c1d3e1a2c9ded9e7767715..932786688243f10ff1185221ffe55812ae020fd6 100644 (file)
@@ -14,6 +14,7 @@ import logging
 import os
 import sys
 import argparse
+import asyncio
 from pathlib import Path
 
 from .config import Configuration
@@ -170,22 +171,30 @@ class AdminServe:
                 raise UsageError("PHP frontend not configured.")
             run_php_server(args.server, args.project_dir / 'website')
         else:
-            import uvicorn # pylint: disable=import-outside-toplevel
-            server_info = args.server.split(':', 1)
-            host = server_info[0]
-            if len(server_info) > 1:
-                if not server_info[1].isdigit():
-                    raise UsageError('Invalid format for --server parameter. Use <host>:<port>')
-                port = int(server_info[1])
-            else:
-                port = 8088
+            asyncio.run(self.run_uvicorn(args))
 
-            server_module = importlib.import_module(f'nominatim_api.server.{args.engine}.server')
+        return 0
 
-            app = server_module.get_application(args.project_dir)
-            uvicorn.run(app, host=host, port=port)
 
-        return 0
+    async def run_uvicorn(self, args: NominatimArgs) -> None:
+        import uvicorn # pylint: disable=import-outside-toplevel
+
+        server_info = args.server.split(':', 1)
+        host = server_info[0]
+        if len(server_info) > 1:
+            if not server_info[1].isdigit():
+                raise UsageError('Invalid format for --server parameter. Use <host>:<port>')
+            port = int(server_info[1])
+        else:
+            port = 8088
+
+        server_module = importlib.import_module(f'nominatim_api.server.{args.engine}.server')
+
+        app = server_module.get_application(args.project_dir)
+
+        config = uvicorn.Config(app, host=host, port=port)
+        server = uvicorn.Server(config)
+        await server.serve()
 
 
 def get_set_parser() -> CommandlineParser:
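
Background for this hunk: uvicorn.run() creates and runs its own event loop, so
it cannot be called from inside asyncio.run(); building a uvicorn.Config and
awaiting uvicorn.Server.serve() instead lets the serve command drive uvicorn
from the already running loop.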
index eced99070ba0c42aa2322111d310b103a3876f8e..a690435c52ccf1ff8b51f9fc30d3b2e41ca0cb5a 100644 (file)
@@ -10,6 +10,7 @@ Implementation of the 'add-data' subcommand.
 from typing import cast
 import argparse
 import logging
+import asyncio
 
 import psutil
 
@@ -64,15 +65,10 @@ class UpdateAddData:
 
 
     def run(self, args: NominatimArgs) -> int:
-        from ..tokenizer import factory as tokenizer_factory
-        from ..tools import tiger_data, add_osm_data
+        from ..tools import add_osm_data
 
         if args.tiger_data:
-            tokenizer = tokenizer_factory.get_tokenizer_for_db(args.config)
-            return tiger_data.add_tiger_data(args.tiger_data,
-                                             args.config,
-                                             args.threads or psutil.cpu_count()  or 1,
-                                             tokenizer)
+            return asyncio.run(self._add_tiger_data(args))
 
         osm2pgsql_params = args.osm2pgsql_options(default_cache=1000, default_threads=1)
         if args.file or args.diff:
@@ -99,3 +95,16 @@ class UpdateAddData:
                                                osm2pgsql_params)
 
         return 0
+
+
+    async def _add_tiger_data(self, args: NominatimArgs) -> int:
+        from ..tokenizer import factory as tokenizer_factory
+        from ..tools import tiger_data
+
+        assert args.tiger_data
+
+        tokenizer = tokenizer_factory.get_tokenizer_for_db(args.config)
+        return await tiger_data.add_tiger_data(args.tiger_data,
+                                               args.config,
+                                               args.threads or psutil.cpu_count()  or 1,
+                                               tokenizer)
index 7744595bbd3f99b449e0ea09a1e4ba2840edf710..1edff174dc37ad251de9d9d048736098d717d2f0 100644 (file)
@@ -12,7 +12,7 @@ import argparse
 import random
 
 from ..errors import UsageError
-from ..db.connection import connect
+from ..db.connection import connect, table_exists
 from .args import NominatimArgs
 
 # Do not repeat documentation of subcommand classes.
@@ -115,7 +115,7 @@ class AdminFuncs:
 
                 tokenizer = tokenizer_factory.get_tokenizer_for_db(args.config)
                 with connect(args.config.get_libpq_dsn()) as conn:
-                    if conn.table_exists('search_name'):
+                    if table_exists(conn, 'search_name'):
                         words = tokenizer.most_frequent_words(conn, 1000)
                     else:
                         words = []
index 87e0fc0333270d586140b9e255ee27fe737cfb86..c0619f34db410e7ffc618a06dea12130782ff684 100644 (file)
@@ -8,6 +8,7 @@
 Implementation of the 'index' subcommand.
 """
 import argparse
+import asyncio
 
 import psutil
 
@@ -44,19 +45,7 @@ class UpdateIndex:
 
 
     def run(self, args: NominatimArgs) -> int:
-        from ..indexer.indexer import Indexer
-        from ..tokenizer import factory as tokenizer_factory
-
-        tokenizer = tokenizer_factory.get_tokenizer_for_db(args.config)
-
-        indexer = Indexer(args.config.get_libpq_dsn(), tokenizer,
-                          args.threads or psutil.cpu_count() or 1)
-
-        if not args.no_boundaries:
-            indexer.index_boundaries(args.minrank, args.maxrank)
-        if not args.boundaries_only:
-            indexer.index_by_rank(args.minrank, args.maxrank)
-            indexer.index_postcodes()
+        asyncio.run(self._do_index(args))
 
         if not args.no_boundaries and not args.boundaries_only \
            and args.minrank == 0 and args.maxrank == 30:
@@ -64,3 +53,22 @@ class UpdateIndex:
                 status.set_indexed(conn, True)
 
         return 0
+
+
+    async def _do_index(self, args: NominatimArgs) -> None:
+        from ..tokenizer import factory as tokenizer_factory
+
+        tokenizer = tokenizer_factory.get_tokenizer_for_db(args.config)
+        from ..indexer.indexer import Indexer
+
+        indexer = Indexer(args.config.get_libpq_dsn(), tokenizer,
+                          args.threads or psutil.cpu_count() or 1)
+
+        has_pending = True # run at least once
+        while has_pending:
+            if not args.no_boundaries:
+                await indexer.index_boundaries(args.minrank, args.maxrank)
+            if not args.boundaries_only:
+                await indexer.index_by_rank(args.minrank, args.maxrank)
+                await indexer.index_postcodes()
+            has_pending = indexer.has_pending()
index d5acf54b3a36a0e8b0832081a436ed4ce62c586f..adc7ee656caa67004a402aadf66e467f477d4041 100644 (file)
@@ -11,9 +11,10 @@ from typing import Tuple, Optional
 import argparse
 import logging
 from pathlib import Path
+import asyncio
 
 from ..config import Configuration
-from ..db.connection import connect
+from ..db.connection import connect, table_exists
 from ..tokenizer.base import AbstractTokenizer
 from .args import NominatimArgs
 
@@ -99,7 +100,7 @@ class UpdateRefresh:
                                            args.project_dir, tokenizer)
                 indexer = Indexer(args.config.get_libpq_dsn(), tokenizer,
                                   args.threads or 1)
-                indexer.index_postcodes()
+                asyncio.run(indexer.index_postcodes())
             else:
                 LOG.error("The place table doesn't exist. "
                           "Postcode updates on a frozen database is not possible.")
@@ -124,7 +125,7 @@ class UpdateRefresh:
             with connect(args.config.get_libpq_dsn()) as conn:
                 # If the table did not exist before, then the importance code
                 # needs to be enabled.
-                if not conn.table_exists('secondary_importance'):
+                if not table_exists(conn, 'secondary_importance'):
                     args.functions = True
 
             LOG.warning('Import secondary importance raster data from %s', args.project_dir)
index 581c731ea6556eec95897f9c30d22fdc38ad2ea6..ba4c7730b44a6ef6782bd651a25b53abb56748fa 100644 (file)
@@ -13,6 +13,7 @@ import datetime as dt
 import logging
 import socket
 import time
+import asyncio
 
 from ..db import status
 from ..db.connection import connect
@@ -123,7 +124,7 @@ class UpdateReplication:
         return update_interval
 
 
-    def _update(self, args: NominatimArgs) -> None:
+    async def _update(self, args: NominatimArgs) -> None:
         # pylint: disable=too-many-locals
         from ..tools import replication
         from ..indexer.indexer import Indexer
@@ -161,7 +162,7 @@ class UpdateReplication:
 
             if state is not replication.UpdateState.NO_CHANGES and args.do_index:
                 index_start = dt.datetime.now(dt.timezone.utc)
-                indexer.index_full(analyse=False)
+                await indexer.index_full(analyse=False)
 
                 with connect(dsn) as conn:
                     status.set_indexed(conn, True)
@@ -172,8 +173,7 @@ class UpdateReplication:
 
             if state is replication.UpdateState.NO_CHANGES and \
                args.catch_up or update_interval > 40*60:
-                while indexer.has_pending():
-                    indexer.index_full(analyse=False)
+                await indexer.index_full(analyse=False)
 
             if LOG.isEnabledFor(logging.WARNING):
                 assert batchdate is not None
@@ -196,5 +196,5 @@ class UpdateReplication:
         if args.check_for_updates:
             return self._check_for_updates(args)
 
-        self._update(args)
+        asyncio.run(self._update(args))
         return 0
index f516ba0c0db7248bdc09d0c2a014ed7081c9b050..07a76f59036d35a4595adbafcd07f357aa2baa87 100644 (file)
@@ -11,6 +11,7 @@ from typing import Optional
 import argparse
 import logging
 from pathlib import Path
+import asyncio
 
 import psutil
 
@@ -71,14 +72,6 @@ class SetupAll:
 
 
     def run(self, args: NominatimArgs) -> int: # pylint: disable=too-many-statements, too-many-branches
-        from ..data import country_info
-        from ..tools import database_import, refresh, postcodes, freeze
-        from ..indexer.indexer import Indexer
-
-        num_threads = args.threads or psutil.cpu_count() or 1
-
-        country_info.setup_country_config(args.config)
-
         if args.osm_file is None and args.continue_at is None and not args.prepare_database:
             raise UsageError("No input files (use --osm-file).")
 
@@ -90,6 +83,16 @@ class SetupAll:
                 "Cannot use --continue and --prepare-database together."
             )
 
+        return asyncio.run(self.async_run(args))
+
+
+    async def async_run(self, args: NominatimArgs) -> int:
+        from ..data import country_info
+        from ..tools import database_import, refresh, postcodes, freeze
+        from ..indexer.indexer import Indexer
+
+        num_threads = args.threads or psutil.cpu_count() or 1
+        country_info.setup_country_config(args.config)
 
         if args.prepare_database or args.continue_at is None:
             LOG.warning('Creating database')
@@ -99,39 +102,7 @@ class SetupAll:
                 return 0
 
         if args.continue_at in (None, 'import-from-file'):
-            files = args.get_osm_file_list()
-            if not files:
-                raise UsageError("No input files (use --osm-file).")
-
-            if args.continue_at in ('import-from-file', None):
-                # Check if the correct plugins are installed
-                database_import.check_existing_database_plugins(args.config.get_libpq_dsn())
-                LOG.warning('Setting up country tables')
-                country_info.setup_country_tables(args.config.get_libpq_dsn(),
-                                                args.config.lib_dir.data,
-                                                args.no_partitions)
-
-                LOG.warning('Importing OSM data file')
-                database_import.import_osm_data(files,
-                                                args.osm2pgsql_options(0, 1),
-                                                drop=args.no_updates,
-                                                ignore_errors=args.ignore_errors)
-
-                LOG.warning('Importing wikipedia importance data')
-                data_path = Path(args.config.WIKIPEDIA_DATA_PATH or args.project_dir)
-                if refresh.import_wikipedia_articles(args.config.get_libpq_dsn(),
-                                                    data_path) > 0:
-                    LOG.error('Wikipedia importance dump file not found. '
-                            'Calculating importance values of locations will not '
-                            'use Wikipedia importance data.')
-
-                LOG.warning('Importing secondary importance raster data')
-                if refresh.import_secondary_importance(args.config.get_libpq_dsn(),
-                                                    args.project_dir) != 0:
-                    LOG.error('Secondary importance file not imported. '
-                            'Falling back to default ranking.')
-
-                self._setup_tables(args.config, args.reverse_only)
+            self._base_import(args)
 
         if args.continue_at in ('import-from-file', 'load-data', None):
             LOG.warning('Initialise tables')
@@ -139,7 +110,7 @@ class SetupAll:
                 database_import.truncate_data_tables(conn)
 
             LOG.warning('Load data into placex table')
-            database_import.load_data(args.config.get_libpq_dsn(), num_threads)
+            await database_import.load_data(args.config.get_libpq_dsn(), num_threads)
 
         LOG.warning("Setting up tokenizer")
         tokenizer = self._get_tokenizer(args.continue_at, args.config)
@@ -153,13 +124,13 @@ class SetupAll:
             ('import-from-file', 'load-data', 'indexing', None):
             LOG.warning('Indexing places')
             indexer = Indexer(args.config.get_libpq_dsn(), tokenizer, num_threads)
-            indexer.index_full(analyse=not args.index_noanalyse)
+            await indexer.index_full(analyse=not args.index_noanalyse)
 
         LOG.warning('Post-process tables')
         with connect(args.config.get_libpq_dsn()) as conn:
-            database_import.create_search_indices(conn, args.config,
-                                                  drop=args.no_updates,
-                                                  threads=num_threads)
+            await database_import.create_search_indices(conn, args.config,
+                                                        drop=args.no_updates,
+                                                        threads=num_threads)
             LOG.warning('Create search index for default country names.')
             country_info.create_country_names(conn, tokenizer,
                                               args.config.get_str_list('LANGUAGES'))
@@ -180,6 +151,45 @@ class SetupAll:
         return 0
 
 
+    def _base_import(self, args: NominatimArgs) -> None:
+        from ..tools import database_import, refresh
+        from ..data import country_info
+
+        files = args.get_osm_file_list()
+        if not files:
+            raise UsageError("No input files (use --osm-file).")
+
+        if args.continue_at in ('import-from-file', None):
+            # Check if the correct plugins are installed
+            database_import.check_existing_database_plugins(args.config.get_libpq_dsn())
+            LOG.warning('Setting up country tables')
+            country_info.setup_country_tables(args.config.get_libpq_dsn(),
+                                            args.config.lib_dir.data,
+                                            args.no_partitions)
+
+            LOG.warning('Importing OSM data file')
+            database_import.import_osm_data(files,
+                                            args.osm2pgsql_options(0, 1),
+                                            drop=args.no_updates,
+                                            ignore_errors=args.ignore_errors)
+
+            LOG.warning('Importing wikipedia importance data')
+            data_path = Path(args.config.WIKIPEDIA_DATA_PATH or args.project_dir)
+            if refresh.import_wikipedia_articles(args.config.get_libpq_dsn(),
+                                                data_path) > 0:
+                LOG.error('Wikipedia importance dump file not found. '
+                        'Calculating importance values of locations will not '
+                        'use Wikipedia importance data.')
+
+            LOG.warning('Importing secondary importance raster data')
+            if refresh.import_secondary_importance(args.config.get_libpq_dsn(),
+                                                args.project_dir) != 0:
+                LOG.error('Secondary importance file not imported. '
+                        'Falling back to default ranking.')
+
+            self._setup_tables(args.config, args.reverse_only)
+
+
     def _setup_tables(self, config: Configuration, reverse_only: bool) -> None:
         """ Set up the basic database layout: tables, indexes and functions.
         """
index c4264f0d68ef17c0a182a1063c0a970862cf5aec..5ae3dea3b3ea8b14dfef655c0edf45e31aab6bdf 100644 (file)
@@ -7,7 +7,7 @@
 """
 Nominatim configuration accessor.
 """
-from typing import Dict, Any, List, Mapping, Optional
+from typing import Union, Dict, Any, List, Mapping, Optional
 import importlib.util
 import logging
 import os
@@ -18,10 +18,7 @@ import yaml
 
 from dotenv import dotenv_values
 
-try:
-    from psycopg2.extensions import parse_dsn
-except ModuleNotFoundError:
-    from psycopg.conninfo import conninfo_to_dict as parse_dsn # type: ignore[assignment]
+from psycopg.conninfo import conninfo_to_dict
 
 from .typing import StrPath
 from .errors import UsageError
@@ -198,7 +195,7 @@ class Configuration:
         return dsn
 
 
-    def get_database_params(self) -> Mapping[str, str]:
+    def get_database_params(self) -> Mapping[str, Union[str, int, None]]:
         """ Get the configured parameters for the database connection
             as a mapping.
         """
@@ -207,7 +204,7 @@ class Configuration:
         if dsn.startswith('pgsql:'):
             return dict((p.split('=', 1) for p in dsn[6:].split(';')))
 
-        return parse_dsn(dsn)
+        return conninfo_to_dict(dsn)
 
 
     def get_import_style_file(self) -> Path:
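
For reference, psycopg 3's replacement for psycopg2's parse_dsn lives in
psycopg.conninfo. A minimal sketch (connection string and output are
illustrative):

    from psycopg.conninfo import conninfo_to_dict

    params = conninfo_to_dict('dbname=nominatim host=localhost port=5433')
    # roughly {'dbname': 'nominatim', 'host': 'localhost', 'port': '5433'}

The loosely typed mapping it returns is why get_database_params() now declares
Union[str, int, None] values and why core.py above casts them to str before
building the SQLAlchemy URL.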
index c8002ee7023ca27a9e680003e92a66e3d8f0b428..9b71405993f35f855e57465564a2b1eebd29ee9e 100644 (file)
@@ -9,10 +9,9 @@ Functions for importing and managing static country information.
 """
 from typing import Dict, Any, Iterable, Tuple, Optional, Container, overload
 from pathlib import Path
-import psycopg2.extras
 
 from ..db import utils as db_utils
-from ..db.connection import connect, Connection
+from ..db.connection import connect, Connection, register_hstore
 from ..errors import UsageError
 from ..config import Configuration
 from ..tokenizer.base import AbstractTokenizer
@@ -129,8 +128,8 @@ def setup_country_tables(dsn: str, sql_dir: Path, ignore_partitions: bool = Fals
 
             params.append((ccode, props['names'], lang, partition))
     with connect(dsn) as conn:
+        register_hstore(conn)
         with conn.cursor() as cur:
-            psycopg2.extras.register_hstore(cur)
             cur.execute(
                 """ CREATE TABLE public.country_name (
                         country_code character varying(2),
@@ -139,9 +138,10 @@ def setup_country_tables(dsn: str, sql_dir: Path, ignore_partitions: bool = Fals
                         country_default_language_code text,
                         partition integer
                     ); """)
-            cur.execute_values(
+            cur.executemany(
                 """ INSERT INTO public.country_name
-                    (country_code, name, country_default_language_code, partition) VALUES %s
+                    (country_code, name, country_default_language_code, partition)
+                    VALUES (%s, %s, %s, %s)
                 """, params)
         conn.commit()
 
@@ -157,8 +157,8 @@ def create_country_names(conn: Connection, tokenizer: AbstractTokenizer,
         return ':' not in key or not languages or \
                key[key.index(':') + 1:] in languages
 
+    register_hstore(conn)
     with conn.cursor() as cur:
-        psycopg2.extras.register_hstore(cur)
         cur.execute("""SELECT country_code, name FROM country_name
                        WHERE country_code is not null""")
 
diff --git a/src/nominatim_db/db/async_connection.py b/src/nominatim_db/db/async_connection.py
deleted file mode 100644 (file)
index 83e4c86..0000000
+++ /dev/null
@@ -1,236 +0,0 @@
-# SPDX-License-Identifier: GPL-3.0-or-later
-#
-# This file is part of Nominatim. (https://nominatim.org)
-#
-# Copyright (C) 2024 by the Nominatim developer community.
-# For a full list of authors see the git log.
-""" Non-blocking database connections.
-"""
-from typing import Callable, Any, Optional, Iterator, Sequence
-import logging
-import select
-import time
-
-import psycopg2
-from psycopg2.extras import wait_select
-
-# psycopg2 emits different exceptions pre and post 2.8. Detect if the new error
-# module is available and adapt the error handling accordingly.
-try:
-    import psycopg2.errors # pylint: disable=no-name-in-module,import-error
-    __has_psycopg2_errors__ = True
-except ImportError:
-    __has_psycopg2_errors__ = False
-
-from ..typing import T_cursor, Query
-
-LOG = logging.getLogger()
-
-class DeadlockHandler:
-    """ Context manager that catches deadlock exceptions and calls
-        the given handler function. All other exceptions are passed on
-        normally.
-    """
-
-    def __init__(self, handler: Callable[[], None], ignore_sql_errors: bool = False) -> None:
-        self.handler = handler
-        self.ignore_sql_errors = ignore_sql_errors
-
-    def __enter__(self) -> 'DeadlockHandler':
-        return self
-
-    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> bool:
-        if __has_psycopg2_errors__:
-            if exc_type == psycopg2.errors.DeadlockDetected: # pylint: disable=E1101
-                self.handler()
-                return True
-        elif exc_type == psycopg2.extensions.TransactionRollbackError \
-             and exc_value.pgcode == '40P01':
-            self.handler()
-            return True
-
-        if self.ignore_sql_errors and isinstance(exc_value, psycopg2.Error):
-            LOG.info("SQL error ignored: %s", exc_value)
-            return True
-
-        return False
-
-
-class DBConnection:
-    """ A single non-blocking database connection.
-    """
-
-    def __init__(self, dsn: str,
-                 cursor_factory: Optional[Callable[..., T_cursor]] = None,
-                 ignore_sql_errors: bool = False) -> None:
-        self.dsn = dsn
-
-        self.current_query: Optional[Query] = None
-        self.current_params: Optional[Sequence[Any]] = None
-        self.ignore_sql_errors = ignore_sql_errors
-
-        self.conn: Optional['psycopg2._psycopg.connection'] = None
-        self.cursor: Optional['psycopg2._psycopg.cursor'] = None
-        self.connect(cursor_factory=cursor_factory)
-
-    def close(self) -> None:
-        """ Close all open connections. Does not wait for pending requests.
-        """
-        if self.conn is not None:
-            if self.cursor is not None:
-                self.cursor.close()
-                self.cursor = None
-            self.conn.close()
-
-        self.conn = None
-
-    def connect(self, cursor_factory: Optional[Callable[..., T_cursor]] = None) -> None:
-        """ (Re)connect to the database. Creates an asynchronous connection
-            with JIT and parallel processing disabled. If a connection was
-            already open, it is closed and a new connection established.
-            The caller must ensure that no query is pending before reconnecting.
-        """
-        self.close()
-
-        # Use a dict to hand in the parameters because async is a reserved
-        # word in Python3.
-        self.conn = psycopg2.connect(**{'dsn': self.dsn, 'async': True}) # type: ignore
-        assert self.conn
-        self.wait()
-
-        if cursor_factory is not None:
-            self.cursor = self.conn.cursor(cursor_factory=cursor_factory)
-        else:
-            self.cursor = self.conn.cursor()
-        # Disable JIT and parallel workers as they are known to cause problems.
-        # Update pg_settings instead of using SET because it does not yield
-        # errors on older versions of Postgres where the settings are not
-        # implemented.
-        self.perform(
-            """ UPDATE pg_settings SET setting = -1 WHERE name = 'jit_above_cost';
-                UPDATE pg_settings SET setting = 0
-                   WHERE name = 'max_parallel_workers_per_gather';""")
-        self.wait()
-
-    def _deadlock_handler(self) -> None:
-        LOG.info("Deadlock detected (params = %s), retry.", str(self.current_params))
-        assert self.cursor is not None
-        assert self.current_query is not None
-        assert self.current_params is not None
-
-        self.cursor.execute(self.current_query, self.current_params)
-
-    def wait(self) -> None:
-        """ Block until any pending operation is done.
-        """
-        while True:
-            with DeadlockHandler(self._deadlock_handler, self.ignore_sql_errors):
-                wait_select(self.conn)
-                self.current_query = None
-                return
-
-    def perform(self, sql: Query, args: Optional[Sequence[Any]] = None) -> None:
-        """ Send SQL query to the server. Returns immediately without
-            blocking.
-        """
-        assert self.cursor is not None
-        self.current_query = sql
-        self.current_params = args
-        self.cursor.execute(sql, args)
-
-    def fileno(self) -> int:
-        """ File descriptor to wait for. (Makes this class select()able.)
-        """
-        assert self.conn is not None
-        return self.conn.fileno()
-
-    def is_done(self) -> bool:
-        """ Check if the connection is available for a new query.
-
-            Also checks if the previous query has run into a deadlock.
-            If so, then the previous query is repeated.
-        """
-        assert self.conn is not None
-
-        if self.current_query is None:
-            return True
-
-        with DeadlockHandler(self._deadlock_handler, self.ignore_sql_errors):
-            if self.conn.poll() == psycopg2.extensions.POLL_OK:
-                self.current_query = None
-                return True
-
-        return False
-
-
-class WorkerPool:
-    """ A pool of asynchronous database connections.
-
-        The pool may be used as a context manager.
-    """
-    REOPEN_CONNECTIONS_AFTER = 100000
-
-    def __init__(self, dsn: str, pool_size: int, ignore_sql_errors: bool = False) -> None:
-        self.threads = [DBConnection(dsn, ignore_sql_errors=ignore_sql_errors)
-                        for _ in range(pool_size)]
-        self.free_workers = self._yield_free_worker()
-        self.wait_time = 0.0
-
-
-    def finish_all(self) -> None:
-        """ Wait for all connection to finish.
-        """
-        for thread in self.threads:
-            while not thread.is_done():
-                thread.wait()
-
-        self.free_workers = self._yield_free_worker()
-
-    def close(self) -> None:
-        """ Close all connections and clear the pool.
-        """
-        for thread in self.threads:
-            thread.close()
-        self.threads = []
-        self.free_workers = iter([])
-
-
-    def next_free_worker(self) -> DBConnection:
-        """ Get the next free connection.
-        """
-        return next(self.free_workers)
-
-
-    def _yield_free_worker(self) -> Iterator[DBConnection]:
-        ready = self.threads
-        command_stat = 0
-        while True:
-            for thread in ready:
-                if thread.is_done():
-                    command_stat += 1
-                    yield thread
-
-            if command_stat > self.REOPEN_CONNECTIONS_AFTER:
-                self._reconnect_threads()
-                ready = self.threads
-                command_stat = 0
-            else:
-                tstart = time.time()
-                _, ready, _ = select.select([], self.threads, [])
-                self.wait_time += time.time() - tstart
-
-
-    def _reconnect_threads(self) -> None:
-        for thread in self.threads:
-            while not thread.is_done():
-                thread.wait()
-            thread.connect()
-
-
-    def __enter__(self) -> 'WorkerPool':
-        return self
-
-
-    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
-        self.finish_all()
-        self.close()
index 19fcddd44c7bd790c2b77224a7dbad58a48f3063..6c7e843fdd34dddfdb6fdf4db9decbd00ec8b145 100644 (file)
 """
 Specialised connection and cursor functions.
 """
-from typing import Optional, Any, Callable, ContextManager, Dict, cast, overload, Tuple, Iterable
-import contextlib
+from typing import Optional, Any, Dict, Tuple
 import logging
 import os
 
-import psycopg2
-import psycopg2.extensions
-import psycopg2.extras
-from psycopg2 import sql as pysql
+import psycopg
+import psycopg.types.hstore
+from psycopg import sql as pysql
 
-from ..typing import SysEnv, Query, T_cursor
+from ..typing import SysEnv
 from ..errors import UsageError
 
 LOG = logging.getLogger()
 
-class Cursor(psycopg2.extras.DictCursor):
-    """ A cursor returning dict-like objects and providing specialised
-        execution functions.
-    """
-    # pylint: disable=arguments-renamed,arguments-differ
-    def execute(self, query: Query, args: Any = None) -> None:
-        """ Query execution that logs the SQL query when debugging is enabled.
-        """
-        if LOG.isEnabledFor(logging.DEBUG):
-            LOG.debug(self.mogrify(query, args).decode('utf-8'))
-
-        super().execute(query, args)
-
-
-    def execute_values(self, sql: Query, argslist: Iterable[Tuple[Any, ...]],
-                       template: Optional[Query] = None) -> None:
-        """ Wrapper for the psycopg2 convenience function to execute
-            SQL for a list of values.
-        """
-        LOG.debug("SQL execute_values(%s, %s)", sql, argslist)
-
-        psycopg2.extras.execute_values(self, sql, argslist, template=template)
+Cursor = psycopg.Cursor[Any]
+Connection = psycopg.Connection[Any]
 
+def execute_scalar(conn: Connection, sql: psycopg.abc.Query, args: Any = None) -> Any:
+    """ Execute query that returns a single value. The value is returned.
+        If the query yields more than one row, a ValueError is raised.
+    """
+    with conn.cursor(row_factory=psycopg.rows.tuple_row) as cur:
+        cur.execute(sql, args)
 
-    def scalar(self, sql: Query, args: Any = None) -> Any:
-        """ Execute query that returns a single value. The value is returned.
-            If the query yields more than one row, a ValueError is raised.
-        """
-        self.execute(sql, args)
-
-        if self.rowcount != 1:
+        if cur.rowcount != 1:
             raise RuntimeError("Query did not return a single row.")
 
-        result = self.fetchone()
-        assert result is not None
-
-        return result[0]
+        result = cur.fetchone()
 
+    assert result is not None
+    return result[0]
 
-    def drop_table(self, name: str, if_exists: bool = True, cascade: bool = False) -> None:
-        """ Drop the table with the given name.
-            Set `if_exists` to False if a non-existent table should raise
-            an exception instead of just being ignored. If 'cascade' is set
-            to True then all dependent tables are deleted as well.
-        """
-        sql = 'DROP TABLE '
-        if if_exists:
-            sql += 'IF EXISTS '
-        sql += '{}'
-        if cascade:
-            sql += ' CASCADE'
 
-        self.execute(pysql.SQL(sql).format(pysql.Identifier(name)))
+def table_exists(conn: Connection, table: str) -> bool:
+    """ Check that a table with the given name exists in the database.
+    """
+    num = execute_scalar(conn,
+            """SELECT count(*) FROM pg_tables
+               WHERE tablename = %s and schemaname = 'public'""", (table, ))
+    return num == 1 if isinstance(num, int) else False
 
 
-class Connection(psycopg2.extensions.connection):
-    """ A connection that provides the specialised cursor by default and
-        adds convenience functions for administrating the database.
+def table_has_column(conn: Connection, table: str, column: str) -> bool:
+    """ Check if the table 'table' exists and has a column with name 'column'.
     """
-    @overload # type: ignore[override]
-    def cursor(self) -> Cursor:
-        ...
-
-    @overload
-    def cursor(self, name: str) -> Cursor:
-        ...
-
-    @overload
-    def cursor(self, cursor_factory: Callable[..., T_cursor]) -> T_cursor:
-        ...
-
-    def cursor(self, cursor_factory  = Cursor, **kwargs): # type: ignore
-        """ Return a new cursor. By default the specialised cursor is returned.
-        """
-        return super().cursor(cursor_factory=cursor_factory, **kwargs)
-
-
-    def table_exists(self, table: str) -> bool:
-        """ Check that a table with the given name exists in the database.
-        """
-        with self.cursor() as cur:
-            num = cur.scalar("""SELECT count(*) FROM pg_tables
-                                WHERE tablename = %s and schemaname = 'public'""", (table, ))
-            return num == 1 if isinstance(num, int) else False
-
-
-    def table_has_column(self, table: str, column: str) -> bool:
-        """ Check if the table 'table' exists and has a column with name 'column'.
-        """
-        with self.cursor() as cur:
-            has_column = cur.scalar("""SELECT count(*) FROM information_schema.columns
-                                       WHERE table_name = %s
-                                             and column_name = %s""",
-                                    (table, column))
-            return has_column > 0 if isinstance(has_column, int) else False
-
-
-    def index_exists(self, index: str, table: Optional[str] = None) -> bool:
-        """ Check that an index with the given name exists in the database.
-            If table is not None then the index must relate to the given
-            table.
-        """
-        with self.cursor() as cur:
-            cur.execute("""SELECT tablename FROM pg_indexes
-                           WHERE indexname = %s and schemaname = 'public'""", (index, ))
-            if cur.rowcount == 0:
-                return False
+    has_column = execute_scalar(conn,
+                    """SELECT count(*) FROM information_schema.columns
+                       WHERE table_name = %s and column_name = %s""",
+                    (table, column))
+    return has_column > 0 if isinstance(has_column, int) else False
 
-            if table is not None:
-                row = cur.fetchone()
-                if row is None or not isinstance(row[0], str):
-                    return False
-                return row[0] == table
-
-        return True
 
+def index_exists(conn: Connection, index: str, table: Optional[str] = None) -> bool:
+    """ Check that an index with the given name exists in the database.
+        If table is not None then the index must relate to the given
+        table.
+    """
+    with conn.cursor() as cur:
+        cur.execute("""SELECT tablename FROM pg_indexes
+                       WHERE indexname = %s and schemaname = 'public'""", (index, ))
+        if cur.rowcount == 0:
+            return False
+
+        if table is not None:
+            row = cur.fetchone()
+            if row is None or not isinstance(row[0], str):
+                return False
+            return row[0] == table
 
-    def drop_table(self, name: str, if_exists: bool = True, cascade: bool = False) -> None:
-        """ Drop the table with the given name.
-            Set `if_exists` to False if a non-existent table should raise
-            an exception instead of just being ignored.
-        """
-        with self.cursor() as cur:
-            cur.drop_table(name, if_exists, cascade)
-        self.commit()
+    return True
 
+def drop_tables(conn: Connection, *names: str,
+               if_exists: bool = True, cascade: bool = False) -> None:
+    """ Drop one or more tables with the given names.
+        Set `if_exists` to False if a non-existent table should raise
+        an exception instead of just being ignored. `cascade` will cause
+        depended objects to be dropped as well.
+        The caller needs to take care of committing the change.
+    """
+    sql = pysql.SQL('DROP TABLE%s{}%s' % (
+                        ' IF EXISTS ' if if_exists else ' ',
+                        ' CASCADE' if cascade else ''))
 
-    def server_version_tuple(self) -> Tuple[int, int]:
-        """ Return the server version as a tuple of (major, minor).
-            Converts correctly for pre-10 and post-10 PostgreSQL versions.
-        """
-        version = self.server_version
-        if version < 100000:
-            return (int(version / 10000), int((version % 10000) / 100))
+    with conn.cursor() as cur:
+        for name in names:
+            cur.execute(sql.format(pysql.Identifier(name)))
 
-        return (int(version / 10000), version % 10000)
 
+def server_version_tuple(conn: Connection) -> Tuple[int, int]:
+    """ Return the server version as a tuple of (major, minor).
+        Converts correctly for pre-10 and post-10 PostgreSQL versions.
+    """
+    version = conn.info.server_version
+    if version < 100000:
+        return (int(version / 10000), int((version % 10000) / 100))
 
-    def postgis_version_tuple(self) -> Tuple[int, int]:
-        """ Return the postgis version installed in the database as a
-            tuple of (major, minor). Assumes that the PostGIS extension
-            has been installed already.
-        """
-        with self.cursor() as cur:
-            version = cur.scalar('SELECT postgis_lib_version()')
+    return (int(version / 10000), version % 10000)
 
-        version_parts = version.split('.')
-        if len(version_parts) < 2:
-            raise UsageError(f"Error fetching Postgis version. Bad format: {version}")
 
-        return (int(version_parts[0]), int(version_parts[1]))
+def postgis_version_tuple(conn: Connection) -> Tuple[int, int]:
+    """ Return the postgis version installed in the database as a
+        tuple of (major, minor). Assumes that the PostGIS extension
+        has been installed already.
+    """
+    version = execute_scalar(conn, 'SELECT postgis_lib_version()')
 
+    version_parts = version.split('.')
+    if len(version_parts) < 2:
+        raise UsageError(f"Error fetching Postgis version. Bad format: {version}")
 
-    def extension_loaded(self, extension_name: str) -> bool:
-        """ Return True if the hstore extension is loaded in the database.
-        """
-        with self.cursor() as cur:
-            cur.execute('SELECT extname FROM pg_extension WHERE extname = %s', (extension_name, ))
-            return cur.rowcount > 0
+    return (int(version_parts[0]), int(version_parts[1]))
 
 
-class ConnectionContext(ContextManager[Connection]):
-    """ Context manager of the connection that also provides direct access
-        to the underlying connection.
+def register_hstore(conn: Connection) -> None:
+    """ Register the hstore type with psycopg for the connection.
     """
-    connection: Connection
+    info = psycopg.types.TypeInfo.fetch(conn, "hstore")
+    if info is None:
+        raise RuntimeError('Hstore extension is requested but not installed.')
+    psycopg.types.hstore.register_hstore(info, conn)
+
 
-def connect(dsn: str) -> ConnectionContext:
+def connect(dsn: str, **kwargs: Any) -> Connection:
     """ Open a connection to the database using the specialised connection
         factory. The returned object may be used in conjunction with 'with'.
         When used outside a context manager, use the `connection` attribute
         to get the connection.
     """
     try:
-        conn = psycopg2.connect(dsn, connection_factory=Connection)
-        ctxmgr = cast(ConnectionContext, contextlib.closing(conn))
-        ctxmgr.connection = conn
-        return ctxmgr
-    except psycopg2.OperationalError as err:
+        return psycopg.connect(dsn, row_factory=psycopg.rows.namedtuple_row, **kwargs)
+    except psycopg.OperationalError as err:
         raise UsageError(f"Cannot connect to database: {err}") from err
 
 
@@ -245,10 +181,18 @@ def get_pg_env(dsn: str,
     """
     env = dict(base_env if base_env is not None else os.environ)
 
-    for param, value in psycopg2.extensions.parse_dsn(dsn).items():
+    for param, value in psycopg.conninfo.conninfo_to_dict(dsn).items():
         if param in _PG_CONNECTION_STRINGS:
-            env[_PG_CONNECTION_STRINGS[param]] = value
+            env[_PG_CONNECTION_STRINGS[param]] = str(value)
         else:
             LOG.error("Unknown connection parameter '%s' ignored.", param)
 
     return env
+
+
+async def run_async_query(dsn: str, query: psycopg.abc.Query) -> None:
+    """ Open a connection to the database and run a single query
+        asynchronously.
+    """
+    async with await psycopg.AsyncConnection.connect(dsn) as aconn:
+        await aconn.execute(query)
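
With the Connection subclass gone, callers switch from methods on the connection
object to the module-level helpers above. A usage sketch (DSN and table names
are only examples):

    from nominatim_db.db.connection import (connect, execute_scalar,
                                             table_exists, drop_tables,
                                             register_hstore)

    with connect('dbname=nominatim') as conn:
        register_hstore(conn)   # only needed where hstore columns are read
        if table_exists(conn, 'placex'):
            total = execute_scalar(conn, 'SELECT count(*) FROM placex')
        drop_tables(conn, 'tmp_word', 'tmp_place', cascade=True)
        conn.commit()           # drop_tables leaves committing to the caller

Because connect() installs psycopg's namedtuple_row row factory, result rows can
be read by attribute (see the status.py hunk below, which switches from
row['lastimportdate'] to row.lastimportdate).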
index 3549382f944ad9e7f998b08190d2190b9c124d97..0e017ead900c0c2f6fa236e65d3ae927498e059b 100644 (file)
@@ -9,7 +9,7 @@ Query and access functions for the in-database property table.
 """
 from typing import Optional, cast
 
-from .connection import Connection
+from .connection import Connection, table_exists
 
 def set_property(conn: Connection, name: str, value: str) -> None:
     """ Add or replace the property with the given name.
@@ -31,7 +31,7 @@ def get_property(conn: Connection, name: str) -> Optional[str]:
     """ Return the current value of the given property or None if the property
         is not set.
     """
-    if not conn.table_exists('nominatim_properties'):
+    if not table_exists(conn, 'nominatim_properties'):
         return None
 
     with conn.cursor() as cur:
diff --git a/src/nominatim_db/db/query_pool.py b/src/nominatim_db/db/query_pool.py
new file mode 100644 (file)
index 0000000..2828937
--- /dev/null
@@ -0,0 +1,87 @@
+# SPDX-License-Identifier: GPL-3.0-or-later
+#
+# This file is part of Nominatim. (https://nominatim.org)
+#
+# Copyright (C) 2024 by the Nominatim developer community.
+# For a full list of authors see the git log.
+"""
+A connection pool that executes incoming queries in parallel.
+"""
+from typing import Any, Tuple, Optional
+import asyncio
+import logging
+import time
+
+import psycopg
+
+LOG = logging.getLogger()
+
+QueueItem = Optional[Tuple[psycopg.abc.Query, Any]]
+
+class QueryPool:
+    """ Pool to run SQL queries in parallel asynchronous execution.
+
+        All queries are run in autocommit mode. If parallel execution leads
+        to a deadlock, then the query is repeated.
+        The results of the queries is discarded.
+    """
+    def __init__(self, dsn: str, pool_size: int = 1, **conn_args: Any) -> None:
+        self.wait_time = 0.0
+        self.query_queue: 'asyncio.Queue[QueueItem]' = asyncio.Queue(maxsize=2 * pool_size)
+
+        self.pool = [asyncio.create_task(self._worker_loop(dsn, **conn_args))
+                     for _ in range(pool_size)]
+
+
+    async def put_query(self, query: psycopg.abc.Query, params: Any) -> None:
+        """ Schedule a query for execution.
+        """
+        tstart = time.time()
+        await self.query_queue.put((query, params))
+        self.wait_time += time.time() - tstart
+        await asyncio.sleep(0)
+
+
+    async def finish(self) -> None:
+        """ Wait for all queries to finish and close the pool.
+        """
+        for _ in self.pool:
+            await self.query_queue.put(None)
+
+        tstart = time.time()
+        await asyncio.wait(self.pool)
+        self.wait_time += time.time() - tstart
+
+        for task in self.pool:
+            excp = task.exception()
+            if excp is not None:
+                raise excp
+
+
+    async def _worker_loop(self, dsn: str, **conn_args: Any) -> None:
+        conn_args['autocommit'] = True
+        aconn = await psycopg.AsyncConnection.connect(dsn, **conn_args)
+        async with aconn:
+            async with aconn.cursor() as cur:
+                item = await self.query_queue.get()
+                while item is not None:
+                    try:
+                        if item[1] is None:
+                            await cur.execute(item[0])
+                        else:
+                            await cur.execute(item[0], item[1])
+
+                        item = await self.query_queue.get()
+                    except psycopg.errors.DeadlockDetected:
+                        assert item is not None
+                        LOG.info("Deadlock detected (sql = %s, params = %s), retry.",
+                                 str(item[0]), str(item[1]))
+                        # item is still valid here, causing a retry
+
+
+    async def __aenter__(self) -> 'QueryPool':
+        return self
+
+
+    async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
+        await self.finish()
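
A usage sketch of the new pool (the DSN and the UPDATE statement are
illustrative; callers in this changeset include the SQL preprocessor below and
the indexer):

    import asyncio

    from nominatim_db.db.query_pool import QueryPool

    async def mark_for_reindex(dsn, place_ids):
        async with QueryPool(dsn, pool_size=4) as pool:
            for pid in place_ids:
                # queries run on whichever worker connection becomes free;
                # statements that hit a deadlock are retried by the worker
                await pool.put_query(
                    'UPDATE placex SET indexed_status = 2 WHERE place_id = %s',
                    (pid,))
        # leaving the async-with block awaits finish(), so every queued query
        # has completed (or re-raised its error) by this point

    asyncio.run(mark_for_reindex('dbname=nominatim', [1, 2, 3]))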
index 468f35107b8f5188a3cbc9b93c9fb3353ef5e322..25faead4d7fcb75a1bd12fd866f96a82e88db76e 100644 (file)
@@ -8,11 +8,12 @@
 Preprocessing of SQL files.
 """
 from typing import Set, Dict, Any, cast
+
 import jinja2
 
-from .connection import Connection
-from .async_connection import WorkerPool
+from .connection import Connection, server_version_tuple, postgis_version_tuple
 from ..config import Configuration
+from ..db.query_pool import QueryPool
 
 def _get_partitions(conn: Connection) -> Set[int]:
     """ Get the set of partitions currently in use.
@@ -66,8 +67,8 @@ def _setup_postgresql_features(conn: Connection) -> Dict[str, Any]:
     """ Set up a dictionary with various optional Postgresql/Postgis features that
         depend on the database version.
     """
-    pg_version = conn.server_version_tuple()
-    postgis_version = conn.postgis_version_tuple()
+    pg_version = server_version_tuple(conn)
+    postgis_version = postgis_version_tuple(conn)
     pg11plus = pg_version >= (11, 0, 0)
     ps3 = postgis_version >= (3, 0)
     return {
@@ -125,8 +126,8 @@ class SQLPreprocessor:
         conn.commit()
 
 
-    def run_parallel_sql_file(self, dsn: str, name: str, num_threads: int = 1,
-                              **kwargs: Any) -> None:
+    async def run_parallel_sql_file(self, dsn: str, name: str, num_threads: int = 1,
+                                    **kwargs: Any) -> None:
         """ Execute the given SQL files using parallel asynchronous connections.
             The keyword arguments may supply additional parameters for
             preprocessing.
@@ -138,6 +139,6 @@ class SQLPreprocessor:
 
         parts = sql.split('\n---\n')
 
-        with WorkerPool(dsn, num_threads) as pool:
+        async with QueryPool(dsn, num_threads) as pool:
             for part in parts:
-                pool.next_free_worker().perform(part)
+                await pool.put_query(part, None)
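
run_parallel_sql_file() is now a coroutine, so its callers await it (see the
awaited create_search_indices call in the setup.py hunk above). A sketch of a
direct call, assuming the usual SQLPreprocessor(conn, config) construction:

    proc = SQLPreprocessor(conn, config)
    await proc.run_parallel_sql_file(config.get_libpq_dsn(), 'indices.sql',
                                     num_threads=4)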
index 1278359cc02ebae2ac1513162a0b49a0ef37d211..4fe9f4449c1ce24e6af45feb34e8fad1bde9d9f8 100644 (file)
@@ -7,34 +7,25 @@
 """
 Access and helper functions for the status and status log table.
 """
-from typing import Optional, Tuple, cast
+from typing import Optional, Tuple
 import datetime as dt
 import logging
 import re
 
-from .connection import Connection
+from .connection import Connection, table_exists, execute_scalar
 from ..utils.url_utils import get_url
 from ..errors import UsageError
-from ..typing import TypedDict
 
 LOG = logging.getLogger()
 ISODATE_FORMAT = '%Y-%m-%dT%H:%M:%S'
 
 
-class StatusRow(TypedDict):
-    """ Dictionary of columns of the import_status table.
-    """
-    lastimportdate: dt.datetime
-    sequence_id: Optional[int]
-    indexed: Optional[bool]
-
-
 def compute_database_date(conn: Connection, offline: bool = False) -> dt.datetime:
     """ Determine the date of the database from the newest object in the
         data base.
     """
     # If there is a date from osm2pgsql available, use that.
-    if conn.table_exists('osm2pgsql_properties'):
+    if table_exists(conn, 'osm2pgsql_properties'):
         with conn.cursor() as cur:
             cur.execute(""" SELECT value FROM osm2pgsql_properties
                             WHERE property = 'current_timestamp' """)
@@ -47,15 +38,14 @@ def compute_database_date(conn: Connection, offline: bool = False) -> dt.datetim
         raise UsageError("Cannot determine database date from data in offline mode.")
 
     # Else, find the node with the highest ID in the database
-    with conn.cursor() as cur:
-        if conn.table_exists('place'):
-            osmid = cur.scalar("SELECT max(osm_id) FROM place WHERE osm_type='N'")
-        else:
-            osmid = cur.scalar("SELECT max(osm_id) FROM placex WHERE osm_type='N'")
+    if table_exists(conn, 'place'):
+        osmid = execute_scalar(conn, "SELECT max(osm_id) FROM place WHERE osm_type='N'")
+    else:
+        osmid = execute_scalar(conn, "SELECT max(osm_id) FROM placex WHERE osm_type='N'")
 
-        if osmid is None:
-            LOG.fatal("No data found in the database.")
-            raise UsageError("No data found in the database.")
+    if osmid is None:
+        LOG.fatal("No data found in the database.")
+        raise UsageError("No data found in the database.")
 
     LOG.info("Using node id %d for timestamp lookup", osmid)
     # Get the node from the API to find the timestamp when it was created.
@@ -103,8 +93,9 @@ def get_status(conn: Connection) -> Tuple[Optional[dt.datetime], Optional[int],
         if cur.rowcount < 1:
             return None, None, None
 
-        row = cast(StatusRow, cur.fetchone())
-        return row['lastimportdate'], row['sequence_id'], row['indexed']
+        row = cur.fetchone()
+        assert row
+        return row.lastimportdate, row.sequence_id, row.indexed
 
 
 def set_indexed(conn: Connection, state: bool) -> None:
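
get_status() now reads columns as attributes (row.lastimportdate) rather than as dictionary keys, which suggests the ported Connection class configures one of psycopg 3's row factories. A standalone sketch of the same pattern, assuming a namedtuple row factory and the import_status columns listed above (the DSN is a placeholder):

    import psycopg
    from psycopg.rows import namedtuple_row

    with psycopg.connect('dbname=nominatim', row_factory=namedtuple_row) as conn:
        with conn.cursor() as cur:
            cur.execute('SELECT lastimportdate, sequence_id, indexed FROM import_status')
            row = cur.fetchone()
            if row is not None:
                print(row.lastimportdate, row.sequence_id, row.indexed)
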
index 32bf79acd93f04e951e56fa0861eea1445f24090..02e5bd2d35ed825c322d07c203e64795de308f5b 100644 (file)
@@ -7,14 +7,13 @@
 """
 Helper functions for handling DB accesses.
 """
-from typing import IO, Optional, Union, Any, Iterable
+from typing import IO, Optional, Union
 import subprocess
 import logging
 import gzip
-import io
 from pathlib import Path
 
-from .connection import get_pg_env, Cursor
+from .connection import get_pg_env
 from ..errors import UsageError
 
 LOG = logging.getLogger()
@@ -72,58 +71,3 @@ def execute_file(dsn: str, fname: Path,
 
     if ret != 0 or remain > 0:
         raise UsageError("Failed to execute SQL file.")
-
-
-# List of characters that need to be quoted for the copy command.
-_SQL_TRANSLATION = {ord('\\'): '\\\\',
-                    ord('\t'): '\\t',
-                    ord('\n'): '\\n'}
-
-
-class CopyBuffer:
-    """ Data collector for the copy_from command.
-    """
-
-    def __init__(self) -> None:
-        self.buffer = io.StringIO()
-
-
-    def __enter__(self) -> 'CopyBuffer':
-        return self
-
-
-    def size(self) -> int:
-        """ Return the number of bytes the buffer currently contains.
-        """
-        return self.buffer.tell()
-
-    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
-        if self.buffer is not None:
-            self.buffer.close()
-
-
-    def add(self, *data: Any) -> None:
-        """ Add another row of data to the copy buffer.
-        """
-        first = True
-        for column in data:
-            if first:
-                first = False
-            else:
-                self.buffer.write('\t')
-            if column is None:
-                self.buffer.write('\\N')
-            else:
-                self.buffer.write(str(column).translate(_SQL_TRANSLATION))
-        self.buffer.write('\n')
-
-
-    def copy_out(self, cur: Cursor, table: str, columns: Optional[Iterable[str]] = None) -> None:
-        """ Copy all collected data into the given table.
-
-            The buffer is empty and reusable after this operation.
-        """
-        if self.buffer.tell() > 0:
-            self.buffer.seek(0)
-            cur.copy_from(self.buffer, table, columns=columns)
-            self.buffer = io.StringIO()
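
CopyBuffer can go because psycopg 3 exposes COPY directly on the cursor and handles NULLs and escaping itself; the icu_tokenizer hunks further down use that interface via copy.write_row(). A minimal sketch (DSN, table and columns are illustrative):

    import psycopg

    with psycopg.connect('dbname=nominatim') as conn:
        with conn.cursor() as cur:
            with cur.copy('COPY word (word_token, type, word) FROM STDIN') as copy:
                copy.write_row(('foo', 'W', 'Foo'))
                copy.write_row(('bar', 'W', None))  # None becomes NULL, no manual \N escaping
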
index 5a219f6b2ecc3904b1129c9b4f9cff92c12132a5..9680e5a9db14d382e84f385c7ffeaf6cb1c3dd13 100644 (file)
@@ -7,92 +7,20 @@
 """
 Main work horse for indexing (computing addresses) the database.
 """
-from typing import Optional, Any, cast
+from typing import cast, List, Any
 import logging
 import time
 
-import psycopg2.extras
+import psycopg
 
-from ..typing import DictCursorResults
-from ..db.async_connection import DBConnection, WorkerPool
-from ..db.connection import connect, Connection, Cursor
+from ..db.connection import connect, execute_scalar
+from ..db.query_pool import QueryPool
 from ..tokenizer.base import AbstractTokenizer
 from .progress import ProgressLogger
 from . import runners
 
 LOG = logging.getLogger()
 
-
-class PlaceFetcher:
-    """ Asynchronous connection that fetches place details for processing.
-    """
-    def __init__(self, dsn: str, setup_conn: Connection) -> None:
-        self.wait_time = 0.0
-        self.current_ids: Optional[DictCursorResults] = None
-        self.conn: Optional[DBConnection] = DBConnection(dsn,
-                                               cursor_factory=psycopg2.extras.DictCursor)
-
-        with setup_conn.cursor() as cur:
-            # need to fetch those manually because register_hstore cannot
-            # fetch them on an asynchronous connection below.
-            hstore_oid = cur.scalar("SELECT 'hstore'::regtype::oid")
-            hstore_array_oid = cur.scalar("SELECT 'hstore[]'::regtype::oid")
-
-        psycopg2.extras.register_hstore(self.conn.conn, oid=hstore_oid,
-                                        array_oid=hstore_array_oid)
-
-    def close(self) -> None:
-        """ Close the underlying asynchronous connection.
-        """
-        if self.conn:
-            self.conn.close()
-            self.conn = None
-
-
-    def fetch_next_batch(self, cur: Cursor, runner: runners.Runner) -> bool:
-        """ Send a request for the next batch of places.
-            If details for the places are required, they will be fetched
-            asynchronously.
-
-            Returns true if there is still data available.
-        """
-        ids = cast(Optional[DictCursorResults], cur.fetchmany(100))
-
-        if not ids:
-            self.current_ids = None
-            return False
-
-        assert self.conn is not None
-        self.current_ids = runner.get_place_details(self.conn, ids)
-
-        return True
-
-    def get_batch(self) -> DictCursorResults:
-        """ Get the next batch of data, previously requested with
-            `fetch_next_batch`.
-        """
-        assert self.conn is not None
-        assert self.conn.cursor is not None
-
-        if self.current_ids is not None and not self.current_ids:
-            tstart = time.time()
-            self.conn.wait()
-            self.wait_time += time.time() - tstart
-            self.current_ids = cast(Optional[DictCursorResults],
-                                    self.conn.cursor.fetchall())
-
-        return self.current_ids if self.current_ids is not None else []
-
-    def __enter__(self) -> 'PlaceFetcher':
-        return self
-
-
-    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
-        assert self.conn is not None
-        self.conn.wait()
-        self.close()
-
-
 class Indexer:
     """ Main indexing routine.
     """
@@ -114,7 +42,7 @@ class Indexer:
                 return cur.rowcount > 0
 
 
-    def index_full(self, analyse: bool = True) -> None:
+    async def index_full(self, analyse: bool = True) -> None:
         """ Index the complete database. This will first index boundaries
             followed by all other objects. When `analyse` is True, then the
             database will be analysed at the appropriate places to
@@ -128,23 +56,27 @@ class Indexer:
                     with conn.cursor() as cur:
                         cur.execute('ANALYZE')
 
-            if self.index_by_rank(0, 4) > 0:
-                _analyze()
+            while True:
+                if await self.index_by_rank(0, 4) > 0:
+                    _analyze()
+
+                if await self.index_boundaries(0, 30) > 100:
+                    _analyze()
 
-            if self.index_boundaries(0, 30) > 100:
-                _analyze()
+                if await self.index_by_rank(5, 25) > 100:
+                    _analyze()
 
-            if self.index_by_rank(5, 25) > 100:
-                _analyze()
+                if await self.index_by_rank(26, 30) > 1000:
+                    _analyze()
 
-            if self.index_by_rank(26, 30) > 1000:
-                _analyze()
+                if await self.index_postcodes() > 100:
+                    _analyze()
 
-            if self.index_postcodes() > 100:
-                _analyze()
+                if not self.has_pending():
+                    break
 
 
-    def index_boundaries(self, minrank: int, maxrank: int) -> int:
+    async def index_boundaries(self, minrank: int, maxrank: int) -> int:
         """ Index only administrative boundaries within the given rank range.
         """
         total = 0
@@ -153,11 +85,11 @@ class Indexer:
 
         with self.tokenizer.name_analyzer() as analyzer:
             for rank in range(max(minrank, 4), min(maxrank, 26)):
-                total += self._index(runners.BoundaryRunner(rank, analyzer))
+                total += await self._index(runners.BoundaryRunner(rank, analyzer))
 
         return total
 
-    def index_by_rank(self, minrank: int, maxrank: int) -> int:
+    async def index_by_rank(self, minrank: int, maxrank: int) -> int:
         """ Index all entries of placex in the given rank range (inclusive)
             in order of their address rank.
 
@@ -171,21 +103,27 @@ class Indexer:
 
         with self.tokenizer.name_analyzer() as analyzer:
             for rank in range(max(1, minrank), maxrank + 1):
-                total += self._index(runners.RankRunner(rank, analyzer), 20 if rank == 30 else 1)
+                if rank >= 30:
+                    batch = 20
+                elif rank >= 26:
+                    batch = 5
+                else:
+                    batch = 1
+                total += await self._index(runners.RankRunner(rank, analyzer), batch)
 
             if maxrank == 30:
-                total += self._index(runners.RankRunner(0, analyzer))
-                total += self._index(runners.InterpolationRunner(analyzer), 20)
+                total += await self._index(runners.RankRunner(0, analyzer))
+                total += await self._index(runners.InterpolationRunner(analyzer), 20)
 
         return total
 
 
-    def index_postcodes(self) -> int:
+    async def index_postcodes(self) -> int:
         """Index the entries of the location_postcode table.
         """
         LOG.warning("Starting indexing postcodes using %s threads", self.num_threads)
 
-        return self._index(runners.PostcodeRunner(), 20)
+        return await self._index(runners.PostcodeRunner(), 20)
 
 
     def update_status_table(self) -> None:
@@ -197,46 +135,58 @@ class Indexer:
 
             conn.commit()
 
-    def _index(self, runner: runners.Runner, batch: int = 1) -> int:
+    async def _index(self, runner: runners.Runner, batch: int = 1) -> int:
         """ Index a single rank or table. `runner` describes the SQL to use
             for indexing. `batch` describes the number of objects that
             should be processed with a single SQL statement
         """
         LOG.warning("Starting %s (using batch size %s)", runner.name(), batch)
 
-        with connect(self.dsn) as conn:
-            psycopg2.extras.register_hstore(conn)
-            with conn.cursor() as cur:
-                total_tuples = cur.scalar(runner.sql_count_objects())
-                LOG.debug("Total number of rows: %i", total_tuples)
-
-            conn.commit()
+        total_tuples = self._prepare_indexing(runner)
 
-            progress = ProgressLogger(runner.name(), total_tuples)
+        progress = ProgressLogger(runner.name(), total_tuples)
 
-            if total_tuples > 0:
-                with conn.cursor(name='places') as cur:
-                    cur.execute(runner.sql_get_objects())
+        if total_tuples > 0:
+            async with await psycopg.AsyncConnection.connect(
+                                 self.dsn, row_factory=psycopg.rows.dict_row) as aconn,\
+                       QueryPool(self.dsn, self.num_threads, autocommit=True) as pool:
+                fetcher_time = 0.0
+                tstart = time.time()
+                async with aconn.cursor(name='places') as cur:
+                    query = runner.index_places_query(batch)
+                    params: List[Any] = []
+                    num_places = 0
+                    async for place in cur.stream(runner.sql_get_objects()):
+                        fetcher_time += time.time() - tstart
 
-                    with PlaceFetcher(self.dsn, conn) as fetcher:
-                        with WorkerPool(self.dsn, self.num_threads) as pool:
-                            has_more = fetcher.fetch_next_batch(cur, runner)
-                            while has_more:
-                                places = fetcher.get_batch()
+                        params.extend(runner.index_places_params(place))
+                        num_places += 1
 
-                                # asynchronously get the next batch
-                                has_more = fetcher.fetch_next_batch(cur, runner)
+                        if num_places >= batch:
+                            LOG.debug("Processing places: %s", str(params))
+                            await pool.put_query(query, params)
+                            progress.add(num_places)
+                            params = []
+                            num_places = 0
 
-                                # And insert the current batch
-                                for idx in range(0, len(places), batch):
-                                    part = places[idx:idx + batch]
-                                    LOG.debug("Processing places: %s", str(part))
-                                    runner.index_places(pool.next_free_worker(), part)
-                                    progress.add(len(part))
+                        tstart = time.time()
 
-                            LOG.info("Wait time: fetcher: %.2fs,  pool: %.2fs",
-                                     fetcher.wait_time, pool.wait_time)
+                if num_places > 0:
+                    await pool.put_query(runner.index_places_query(num_places), params)
 
-                conn.commit()
+            LOG.info("Wait time: fetcher: %.2fs,  pool: %.2fs",
+                     fetcher_time, pool.wait_time)
 
         return progress.done()
+
+
+    def _prepare_indexing(self, runner: runners.Runner) -> int:
+        with connect(self.dsn) as conn:
+            hstore_info = psycopg.types.TypeInfo.fetch(conn, "hstore")
+            if hstore_info is None:
+                raise RuntimeError('Hstore extension is requested but not installed.')
+            psycopg.types.hstore.register_hstore(hstore_info)
+
+            total_tuples = execute_scalar(conn, runner.sql_count_objects())
+            LOG.debug("Total number of rows: %i", total_tuples)
+        return cast(int, total_tuples)
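
_prepare_indexing() replaces psycopg2.extras.register_hstore() with a TypeInfo lookup; in the commit the adapter is registered globally so that the asynchronous connections opened later pick it up. A standalone sketch of the same step, registering on a single connection for the demo (DSN is a placeholder):

    import psycopg
    from psycopg.types import TypeInfo
    from psycopg.types.hstore import register_hstore

    with psycopg.connect('dbname=nominatim') as conn:
        info = TypeInfo.fetch(conn, 'hstore')
        if info is None:
            raise RuntimeError('hstore extension is not installed')
        register_hstore(info, conn)   # the commit omits the context argument to register globally
        with conn.cursor() as cur:
            cur.execute("SELECT 'a=>1, b=>NULL'::hstore")
            print(cur.fetchone()[0])  # {'a': '1', 'b': None}
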
index 7b98e2401beab3b820941d76218b1137f8afb97f..d7737c07230c3d7df4b1579c986ba68d046333df 100644 (file)
@@ -8,14 +8,14 @@
 Mix-ins that provide the actual commands for the indexer for various indexing
 tasks.
 """
-from typing import Any, List
-import functools
+from typing import Any, Sequence
 
-from psycopg2 import sql as pysql
-import psycopg2.extras
+from psycopg import sql as pysql
+from psycopg.abc import Query
+from psycopg.rows import DictRow
+from psycopg.types.json import Json
 
-from ..typing import Query, DictCursorResult, DictCursorResults, Protocol
-from ..db.async_connection import DBConnection
+from ..typing import Protocol
 from ..data.place_info import PlaceInfo
 from ..tokenizer.base import AbstractAnalyzer
 
@@ -24,58 +24,48 @@ from ..tokenizer.base import AbstractAnalyzer
 def _mk_valuelist(template: str, num: int) -> pysql.Composed:
     return pysql.SQL(',').join([pysql.SQL(template)] * num)
 
-def _analyze_place(place: DictCursorResult, analyzer: AbstractAnalyzer) -> psycopg2.extras.Json:
-    return psycopg2.extras.Json(analyzer.process_place(PlaceInfo(place)))
+def _analyze_place(place: DictRow, analyzer: AbstractAnalyzer) -> Json:
+    return Json(analyzer.process_place(PlaceInfo(place)))
 
 
 class Runner(Protocol):
     def name(self) -> str: ...
     def sql_count_objects(self) -> Query: ...
     def sql_get_objects(self) -> Query: ...
-    def get_place_details(self, worker: DBConnection,
-                          ids: DictCursorResults) -> DictCursorResults: ...
-    def index_places(self, worker: DBConnection, places: DictCursorResults) -> None: ...
+    def index_places_query(self, batch_size: int) -> Query: ...
+    def index_places_params(self, place: DictRow) -> Sequence[Any]: ...
 
 
+SELECT_SQL = pysql.SQL("""SELECT place_id, extra.*
+                          FROM (SELECT * FROM placex {}) as px,
+                          LATERAL placex_indexing_prepare(px) as extra """)
+UPDATE_LINE = "(%s, %s::hstore, %s::hstore, %s::int, %s::jsonb)"
+
 class AbstractPlacexRunner:
     """ Returns SQL commands for indexing of the placex table.
     """
-    SELECT_SQL = pysql.SQL('SELECT place_id FROM placex ')
-    UPDATE_LINE = "(%s, %s::hstore, %s::hstore, %s::int, %s::jsonb)"
 
     def __init__(self, rank: int, analyzer: AbstractAnalyzer) -> None:
         self.rank = rank
         self.analyzer = analyzer
 
 
-    @functools.lru_cache(maxsize=1)
-    def _index_sql(self, num_places: int) -> pysql.Composed:
+    def index_places_query(self, batch_size: int) -> Query:
         return pysql.SQL(
             """ UPDATE placex
                 SET indexed_status = 0, address = v.addr, token_info = v.ti,
                     name = v.name, linked_place_id = v.linked_place_id
                 FROM (VALUES {}) as v(id, name, addr, linked_place_id, ti)
                 WHERE place_id = v.id
-            """).format(_mk_valuelist(AbstractPlacexRunner.UPDATE_LINE, num_places))
-
-
-    def get_place_details(self, worker: DBConnection, ids: DictCursorResults) -> DictCursorResults:
-        worker.perform("""SELECT place_id, extra.*
-                          FROM placex, LATERAL placex_indexing_prepare(placex) as extra
-                          WHERE place_id IN %s""",
-                       (tuple((p[0] for p in ids)), ))
+            """).format(_mk_valuelist(UPDATE_LINE, batch_size))
 
-        return []
 
-
-    def index_places(self, worker: DBConnection, places: DictCursorResults) -> None:
-        values: List[Any] = []
-        for place in places:
-            for field in ('place_id', 'name', 'address', 'linked_place_id'):
-                values.append(place[field])
-            values.append(_analyze_place(place, self.analyzer))
-
-        worker.perform(self._index_sql(len(places)), values)
+    def index_places_params(self, place: DictRow) -> Sequence[Any]:
+        return (place['place_id'],
+                place['name'],
+                place['address'],
+                place['linked_place_id'],
+                _analyze_place(place, self.analyzer))
 
 
 class RankRunner(AbstractPlacexRunner):
@@ -91,10 +81,10 @@ class RankRunner(AbstractPlacexRunner):
                          """).format(pysql.Literal(self.rank))
 
     def sql_get_objects(self) -> pysql.Composed:
-        return self.SELECT_SQL + pysql.SQL(
-            """WHERE indexed_status > 0 and rank_address = {}
-               ORDER BY geometry_sector
-            """).format(pysql.Literal(self.rank))
+        return SELECT_SQL.format(pysql.SQL(
+                """WHERE placex.indexed_status > 0 and placex.rank_address = {}
+                   ORDER BY placex.geometry_sector
+                """).format(pysql.Literal(self.rank)))
 
 
 class BoundaryRunner(AbstractPlacexRunner):
@@ -105,19 +95,19 @@ class BoundaryRunner(AbstractPlacexRunner):
     def name(self) -> str:
         return f"boundaries rank {self.rank}"
 
-    def sql_count_objects(self) -> pysql.Composed:
+    def sql_count_objects(self) -> Query:
         return pysql.SQL("""SELECT count(*) FROM placex
                             WHERE indexed_status > 0
                               AND rank_search = {}
                               AND class = 'boundary' and type = 'administrative'
                          """).format(pysql.Literal(self.rank))
 
-    def sql_get_objects(self) -> pysql.Composed:
-        return self.SELECT_SQL + pysql.SQL(
-            """WHERE indexed_status > 0 and rank_search = {}
-                     and class = 'boundary' and type = 'administrative'
-               ORDER BY partition, admin_level
-            """).format(pysql.Literal(self.rank))
+    def sql_get_objects(self) -> Query:
+        return SELECT_SQL.format(pysql.SQL(
+                """WHERE placex.indexed_status > 0 and placex.rank_search = {}
+                         and placex.class = 'boundary' and placex.type = 'administrative'
+                   ORDER BY placex.partition, placex.admin_level
+                """).format(pysql.Literal(self.rank)))
 
 
 class InterpolationRunner:
@@ -132,40 +122,29 @@ class InterpolationRunner:
     def name(self) -> str:
         return "interpolation lines (location_property_osmline)"
 
-    def sql_count_objects(self) -> str:
+    def sql_count_objects(self) -> Query:
         return """SELECT count(*) FROM location_property_osmline
                   WHERE indexed_status > 0"""
 
-    def sql_get_objects(self) -> str:
-        return """SELECT place_id
+
+    def sql_get_objects(self) -> Query:
+        return """SELECT place_id, get_interpolation_address(address, osm_id) as address
                   FROM location_property_osmline
                   WHERE indexed_status > 0
                   ORDER BY geometry_sector"""
 
 
-    def get_place_details(self, worker: DBConnection, ids: DictCursorResults) -> DictCursorResults:
-        worker.perform("""SELECT place_id, get_interpolation_address(address, osm_id) as address
-                          FROM location_property_osmline WHERE place_id IN %s""",
-                       (tuple((p[0] for p in ids)), ))
-        return []
-
-
-    @functools.lru_cache(maxsize=1)
-    def _index_sql(self, num_places: int) -> pysql.Composed:
+    def index_places_query(self, batch_size: int) -> Query:
         return pysql.SQL("""UPDATE location_property_osmline
                             SET indexed_status = 0, address = v.addr, token_info = v.ti
                             FROM (VALUES {}) as v(id, addr, ti)
                             WHERE place_id = v.id
-                         """).format(_mk_valuelist("(%s, %s::hstore, %s::jsonb)", num_places))
-
+                         """).format(_mk_valuelist("(%s, %s::hstore, %s::jsonb)", batch_size))
 
-    def index_places(self, worker: DBConnection, places: DictCursorResults) -> None:
-        values: List[Any] = []
-        for place in places:
-            values.extend((place[x] for x in ('place_id', 'address')))
-            values.append(_analyze_place(place, self.analyzer))
 
-        worker.perform(self._index_sql(len(places)), values)
+    def index_places_params(self, place: DictRow) -> Sequence[Any]:
+        return (place['place_id'], place['address'],
+                _analyze_place(place, self.analyzer))
 
 
 
@@ -177,20 +156,21 @@ class PostcodeRunner(Runner):
         return "postcodes (location_postcode)"
 
 
-    def sql_count_objects(self) -> str:
+    def sql_count_objects(self) -> Query:
         return 'SELECT count(*) FROM location_postcode WHERE indexed_status > 0'
 
 
-    def sql_get_objects(self) -> str:
+    def sql_get_objects(self) -> Query:
         return """SELECT place_id FROM location_postcode
                   WHERE indexed_status > 0
                   ORDER BY country_code, postcode"""
 
 
-    def get_place_details(self, worker: DBConnection, ids: DictCursorResults) -> DictCursorResults:
-        return ids
+    def index_places_query(self, batch_size: int) -> Query:
+        return pysql.SQL("""UPDATE location_postcode SET indexed_status = 0
+                                    WHERE place_id IN ({})""")\
+                    .format(pysql.SQL(',').join((pysql.Placeholder() for _ in range(batch_size))))
+
 
-    def index_places(self, worker: DBConnection, places: DictCursorResults) -> None:
-        worker.perform(pysql.SQL("""UPDATE location_postcode SET indexed_status = 0
-                                    WHERE place_id IN ({})""")
-                       .format(pysql.SQL(',').join((pysql.Literal(i[0]) for i in places))))
+    def index_places_params(self, place: DictRow) -> Sequence[Any]:
+        return (place['place_id'], )
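
The runners now compose their UPDATE statements with psycopg.sql and plain %s placeholders instead of interpolated literals. A small sketch of what _mk_valuelist() expands to; it needs no database connection:

    from psycopg import sql

    UPDATE_LINE = "(%s, %s::hstore, %s::hstore, %s::int, %s::jsonb)"

    def _mk_valuelist(template: str, num: int) -> sql.Composed:
        return sql.SQL(',').join([sql.SQL(template)] * num)

    query = sql.SQL("""UPDATE placex
                       SET indexed_status = 0, address = v.addr, token_info = v.ti,
                           name = v.name, linked_place_id = v.linked_place_id
                       FROM (VALUES {}) as v(id, name, addr, linked_place_id, ti)
                       WHERE place_id = v.id""").format(_mk_valuelist(UPDATE_LINE, 2))

    # batch_size=2 yields two value rows and ten %s placeholders, matching the
    # flat parameter list that Indexer._index() collects per batch.
    print(query.as_string(None))
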
index 22e2d048291bcaa6d566e0f5fd3cf92fe9fcf870..7cd96d591fcec97483a721542829f4da90e52430 100644 (file)
@@ -11,14 +11,16 @@ libICU instead of the PostgreSQL module.
 from typing import Optional, Sequence, List, Tuple, Mapping, Any, cast, \
                    Dict, Set, Iterable
 import itertools
-import json
 import logging
 from pathlib import Path
 from textwrap import dedent
 
-from ..db.connection import connect, Connection, Cursor
+from psycopg.types.json import Jsonb
+from psycopg import sql as pysql
+
+from ..db.connection import connect, Connection, Cursor, server_version_tuple,\
+                            drop_tables, table_exists, execute_scalar
 from ..config import Configuration
-from ..db.utils import CopyBuffer
 from ..db.sql_preprocessor import SQLPreprocessor
 from ..data.place_info import PlaceInfo
 from ..data.place_name import PlaceName
@@ -108,19 +110,18 @@ class ICUTokenizer(AbstractTokenizer):
         """ Recompute frequencies for all name words.
         """
         with connect(self.dsn) as conn:
-            if not conn.table_exists('search_name'):
+            if not table_exists(conn, 'search_name'):
                 return
 
             with conn.cursor() as cur:
                 cur.execute('ANALYSE search_name')
                 if threads > 1:
-                    cur.execute('SET max_parallel_workers_per_gather TO %s',
-                                (min(threads, 6),))
+                    cur.execute(pysql.SQL('SET max_parallel_workers_per_gather TO {}')
+                                     .format(pysql.Literal(min(threads, 6),)))
 
-                if conn.server_version_tuple() < (12, 0):
+                if server_version_tuple(conn) < (12, 0):
                     LOG.info('Computing word frequencies')
-                    cur.drop_table('word_frequencies')
-                    cur.drop_table('addressword_frequencies')
+                    drop_tables(conn, 'word_frequencies', 'addressword_frequencies')
                     cur.execute("""CREATE TEMP TABLE word_frequencies AS
                                      SELECT unnest(name_vector) as id, count(*)
                                      FROM search_name GROUP BY id""")
@@ -152,17 +153,16 @@ class ICUTokenizer(AbstractTokenizer):
                                    $$ LANGUAGE plpgsql IMMUTABLE;
                                 """)
                     LOG.info('Update word table with recomputed frequencies')
-                    cur.drop_table('tmp_word')
+                    drop_tables(conn, 'tmp_word')
                     cur.execute("""CREATE TABLE tmp_word AS
                                     SELECT word_id, word_token, type, word,
                                            word_freq_update(word_id, info) as info
                                     FROM word
                                 """)
-                    cur.drop_table('word_frequencies')
-                    cur.drop_table('addressword_frequencies')
+                    drop_tables(conn, 'word_frequencies', 'addressword_frequencies')
                 else:
                     LOG.info('Computing word frequencies')
-                    cur.drop_table('word_frequencies')
+                    drop_tables(conn, 'word_frequencies')
                     cur.execute("""
                       CREATE TEMP TABLE word_frequencies AS
                       WITH word_freq AS MATERIALIZED (
@@ -182,7 +182,7 @@ class ICUTokenizer(AbstractTokenizer):
                     cur.execute('CREATE UNIQUE INDEX ON word_frequencies(id) INCLUDE(info)')
                     cur.execute('ANALYSE word_frequencies')
                     LOG.info('Update word table with recomputed frequencies')
-                    cur.drop_table('tmp_word')
+                    drop_tables(conn, 'tmp_word')
                     cur.execute("""CREATE TABLE tmp_word AS
                                     SELECT word_id, word_token, type, word,
                                            (CASE WHEN wf.info is null THEN word.info
@@ -191,7 +191,7 @@ class ICUTokenizer(AbstractTokenizer):
                                     FROM word LEFT JOIN word_frequencies wf
                                          ON word.word_id = wf.id
                                 """)
-                    cur.drop_table('word_frequencies')
+                    drop_tables(conn, 'word_frequencies')
 
             with conn.cursor() as cur:
                 cur.execute('SET max_parallel_workers_per_gather TO 0')
@@ -210,7 +210,7 @@ class ICUTokenizer(AbstractTokenizer):
         """ Remove unused house numbers.
         """
         with connect(self.dsn) as conn:
-            if not conn.table_exists('search_name'):
+            if not table_exists(conn, 'search_name'):
                 return
             with conn.cursor(name="hnr_counter") as cur:
                 cur.execute("""SELECT DISTINCT word_id, coalesce(info->>'lookup', word_token)
@@ -311,8 +311,7 @@ class ICUTokenizer(AbstractTokenizer):
             frequencies.
         """
         with connect(self.dsn) as conn:
-            with conn.cursor() as cur:
-                cur.drop_table('word')
+            drop_tables(conn, 'word')
             sqlp = SQLPreprocessor(conn, config)
             sqlp.run_string(conn, """
                 CREATE TABLE word (
@@ -370,8 +369,8 @@ class ICUTokenizer(AbstractTokenizer):
         """ Rename all tables and indexes used by the tokenizer.
         """
         with connect(self.dsn) as conn:
+            drop_tables(conn, 'word')
             with conn.cursor() as cur:
-                cur.drop_table('word')
                 cur.execute(f"ALTER TABLE {old} RENAME TO word")
                 for idx in ('word_token', 'word_id'):
                     cur.execute(f"""ALTER INDEX idx_{old}_{idx}
@@ -393,7 +392,7 @@ class ICUNameAnalyzer(AbstractAnalyzer):
 
     def __init__(self, dsn: str, sanitizer: PlaceSanitizer,
                  token_analysis: ICUTokenAnalysis) -> None:
-        self.conn: Optional[Connection] = connect(dsn).connection
+        self.conn: Optional[Connection] = connect(dsn)
         self.conn.autocommit = True
         self.sanitizer = sanitizer
         self.token_analysis = token_analysis
@@ -535,9 +534,7 @@ class ICUNameAnalyzer(AbstractAnalyzer):
 
         if terms:
             with self.conn.cursor() as cur:
-                cur.execute_values("""SELECT create_postcode_word(pc, var)
-                                      FROM (VALUES %s) AS v(pc, var)""",
-                                   terms)
+                cur.executemany("""SELECT create_postcode_word(%s, %s)""", terms)
 
 
 
@@ -580,18 +577,15 @@ class ICUNameAnalyzer(AbstractAnalyzer):
         to_add = new_phrases - existing_phrases
 
         added = 0
-        with CopyBuffer() as copystr:
+        with cursor.copy('COPY word(word_token, type, word, info) FROM STDIN') as copy:
             for word, cls, typ, oper in to_add:
                 term = self._search_normalized(word)
                 if term:
-                    copystr.add(term, 'S', word,
-                                json.dumps({'class': cls, 'type': typ,
-                                            'op': oper if oper in ('in', 'near') else None}))
+                    copy.write_row((term, 'S', word,
+                                    Jsonb({'class': cls, 'type': typ,
+                                           'op': oper if oper in ('in', 'near') else None})))
                     added += 1
 
-            copystr.copy_out(cursor, 'word',
-                             columns=['word_token', 'type', 'word', 'info'])
-
         return added
 
 
@@ -604,11 +598,11 @@ class ICUNameAnalyzer(AbstractAnalyzer):
         to_delete = existing_phrases - new_phrases
 
         if to_delete:
-            cursor.execute_values(
-                """ DELETE FROM word USING (VALUES %s) as v(name, in_class, in_type, op)
-                    WHERE type = 'S' and word = name
-                          and info->>'class' = in_class and info->>'type' = in_type
-                          and ((op = '-' and info->>'op' is null) or op = info->>'op')
+            cursor.executemany(
+                """ DELETE FROM word
+                      WHERE type = 'S' and word = %s
+                            and info->>'class' = %s and info->>'type' = %s
+                            and %s = coalesce(info->>'op', '-')
                 """, to_delete)
 
         return len(to_delete)
@@ -655,7 +649,7 @@ class ICUNameAnalyzer(AbstractAnalyzer):
                 gone_tokens.update(existing_tokens[False] & word_tokens)
             if gone_tokens:
                 cur.execute("""DELETE FROM word
-                               USING unnest(%s) as token
+                               USING unnest(%s::text[]) as token
                                WHERE type = 'C' and word = %s
                                      and word_token = token""",
                             (list(gone_tokens), country_code))
@@ -668,12 +662,12 @@ class ICUNameAnalyzer(AbstractAnalyzer):
                 if internal:
                     sql = """INSERT INTO word (word_token, type, word, info)
                                (SELECT token, 'C', %s, '{"internal": "yes"}'
-                                  FROM unnest(%s) as token)
+                                  FROM unnest(%s::text[]) as token)
                            """
                 else:
                     sql = """INSERT INTO word (word_token, type, word)
                                    (SELECT token, 'C', %s
-                                    FROM unnest(%s) as token)
+                                    FROM unnest(%s::text[]) as token)
                           """
                 cur.execute(sql, (country_code, list(new_tokens)))
 
@@ -733,11 +727,10 @@ class ICUNameAnalyzer(AbstractAnalyzer):
             if norm_name:
                 result = self._cache.housenumbers.get(norm_name, result)
                 if result[0] is None:
-                    with self.conn.cursor() as cur:
-                        hid = cur.scalar("SELECT getorcreate_hnr_id(%s)", (norm_name, ))
+                    hid = execute_scalar(self.conn, "SELECT getorcreate_hnr_id(%s)", (norm_name, ))
 
-                        result = hid, norm_name
-                        self._cache.housenumbers[norm_name] = result
+                    result = hid, norm_name
+                    self._cache.housenumbers[norm_name] = result
         else:
             # Otherwise use the analyzer to determine the canonical name.
             # Per convention we use the first variant as the 'lookup name', the
@@ -748,11 +741,10 @@ class ICUNameAnalyzer(AbstractAnalyzer):
                 if result[0] is None:
                     variants = analyzer.compute_variants(word_id)
                     if variants:
-                        with self.conn.cursor() as cur:
-                            hid = cur.scalar("SELECT create_analyzed_hnr_id(%s, %s)",
+                        hid = execute_scalar(self.conn, "SELECT create_analyzed_hnr_id(%s, %s)",
                                              (word_id, list(variants)))
-                            result = hid, variants[0]
-                            self._cache.housenumbers[word_id] = result
+                        result = hid, variants[0]
+                        self._cache.housenumbers[word_id] = result
 
         return result
 
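
json.dumps() drops out of the imports because psycopg 3 wants JSON parameters wrapped explicitly, hence the Jsonb objects passed to copy.write_row() above. A tiny standalone sketch of the wrapper (DSN and query are placeholders):

    import psycopg
    from psycopg.types.json import Jsonb

    info = {'class': 'amenity', 'type': 'restaurant', 'op': None}

    with psycopg.connect('dbname=nominatim') as conn:
        with conn.cursor() as cur:
            # Jsonb() marks the dict for jsonb adaptation; a bare dict cannot be
            # adapted unless some other adapter (e.g. hstore) has been registered.
            cur.execute('SELECT %s ->> %s', (Jsonb(info), 'type'))
            print(cur.fetchone()[0])  # restaurant
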
index 136a733159cd3180e2e0558e48b97141887165c1..fa4b3b99ca61f7e1bb9fed8539d95464ae17e9c9 100644 (file)
@@ -17,11 +17,12 @@ import shutil
 from textwrap import dedent
 
 from icu import Transliterator
-import psycopg2
-import psycopg2.extras
+import psycopg
+from psycopg import sql as pysql
 
 from ..errors import UsageError
-from ..db.connection import connect, Connection
+from ..db.connection import connect, Connection, drop_tables, table_exists,\
+                            execute_scalar, register_hstore
 from ..config import Configuration
 from ..db import properties
 from ..db import utils as db_utils
@@ -78,12 +79,12 @@ def _check_module(module_dir: str, conn: Connection) -> None:
     """
     with conn.cursor() as cur:
         try:
-            cur.execute("""CREATE FUNCTION nominatim_test_import_func(text)
-                           RETURNS text AS %s, 'transliteration'
-                           LANGUAGE c IMMUTABLE STRICT;
-                           DROP FUNCTION nominatim_test_import_func(text)
-                        """, (f'{module_dir}/nominatim.so', ))
-        except psycopg2.DatabaseError as err:
+            cur.execute(pysql.SQL("""CREATE FUNCTION nominatim_test_import_func(text)
+                                     RETURNS text AS {}, 'transliteration'
+                                     LANGUAGE c IMMUTABLE STRICT;
+                                     DROP FUNCTION nominatim_test_import_func(text)
+                                 """).format(pysql.Literal(f'{module_dir}/nominatim.so')))
+        except psycopg.DatabaseError as err:
             LOG.fatal("Error accessing database module: %s", err)
             raise UsageError("Database module cannot be accessed.") from err
 
@@ -179,11 +180,10 @@ class LegacyTokenizer(AbstractTokenizer):
              * Can nominatim.so be accessed by the database user?
              """
         with connect(self.dsn) as conn:
-            with conn.cursor() as cur:
-                try:
-                    out = cur.scalar("SELECT make_standard_name('a')")
-                except psycopg2.Error as err:
-                    return hint.format(error=str(err))
+            try:
+                out = execute_scalar(conn, "SELECT make_standard_name('a')")
+            except psycopg.Error as err:
+                return hint.format(error=str(err))
 
         if out != 'a':
             return hint.format(error='Unexpected result for make_standard_name()')
@@ -214,9 +214,9 @@ class LegacyTokenizer(AbstractTokenizer):
         """ Recompute the frequency of full words.
         """
         with connect(self.dsn) as conn:
-            if conn.table_exists('search_name'):
+            if table_exists(conn, 'search_name'):
+                drop_tables(conn, "word_frequencies")
                 with conn.cursor() as cur:
-                    cur.drop_table("word_frequencies")
                     LOG.info("Computing word frequencies")
                     cur.execute("""CREATE TEMP TABLE word_frequencies AS
                                      SELECT unnest(name_vector) as id, count(*)
@@ -226,7 +226,7 @@ class LegacyTokenizer(AbstractTokenizer):
                     cur.execute("""UPDATE word SET search_name_count = count
                                    FROM word_frequencies
                                    WHERE word_token like ' %' and word_id = id""")
-                    cur.drop_table("word_frequencies")
+                drop_tables(conn, "word_frequencies")
             conn.commit()
 
 
@@ -313,10 +313,10 @@ class LegacyNameAnalyzer(AbstractAnalyzer):
     """
 
     def __init__(self, dsn: str, normalizer: Any):
-        self.conn: Optional[Connection] = connect(dsn).connection
+        self.conn: Optional[Connection] = connect(dsn)
         self.conn.autocommit = True
         self.normalizer = normalizer
-        psycopg2.extras.register_hstore(self.conn)
+        register_hstore(self.conn)
 
         self._cache = _TokenCache(self.conn)
 
@@ -406,7 +406,7 @@ class LegacyNameAnalyzer(AbstractAnalyzer):
                             """, (to_delete, ))
             if to_add:
                 cur.execute("""SELECT count(create_postcode_id(pc))
-                               FROM unnest(%s) as pc
+                               FROM unnest(%s::text[]) as pc
                             """, (to_add, ))
 
 
@@ -423,7 +423,7 @@ class LegacyNameAnalyzer(AbstractAnalyzer):
         with self.conn.cursor() as cur:
             # Get the old phrases.
             existing_phrases = set()
-            cur.execute("""SELECT word, class, type, operator FROM word
+            cur.execute("""SELECT word, class as cls, type, operator FROM word
                            WHERE class != 'place'
                                  OR (type != 'house' AND type != 'postcode')""")
             for label, cls, typ, oper in cur:
@@ -433,18 +433,19 @@ class LegacyNameAnalyzer(AbstractAnalyzer):
             to_delete = existing_phrases - norm_phrases
 
             if to_add:
-                cur.execute_values(
+                cur.executemany(
                     """ INSERT INTO word (word_id, word_token, word, class, type,
                                           search_name_count, operator)
                         (SELECT nextval('seq_word'), ' ' || make_standard_name(name), name,
                                 class, type, 0,
                                 CASE WHEN op in ('in', 'near') THEN op ELSE null END
-                           FROM (VALUES %s) as v(name, class, type, op))""",
+                           FROM (VALUES (%s, %s, %s, %s)) as v(name, class, type, op))""",
                     to_add)
 
             if to_delete and should_replace:
-                cur.execute_values(
-                    """ DELETE FROM word USING (VALUES %s) as v(name, in_class, in_type, op)
+                cur.executemany(
+                    """ DELETE FROM word
+                          USING (VALUES (%s, %s, %s, %s)) as v(name, in_class, in_type, op)
                         WHERE word = name and class = in_class and type = in_type
                               and ((op = '-' and operator is null) or op = operator)""",
                     to_delete)
@@ -463,7 +464,7 @@ class LegacyNameAnalyzer(AbstractAnalyzer):
                 """INSERT INTO word (word_id, word_token, country_code)
                    (SELECT nextval('seq_word'), lookup_token, %s
                       FROM (SELECT DISTINCT ' ' || make_standard_name(n) as lookup_token
-                            FROM unnest(%s)n) y
+                            FROM unnest(%s::TEXT[])n) y
                       WHERE NOT EXISTS(SELECT * FROM word
                                        WHERE word_token = lookup_token and country_code = %s))
                 """, (country_code, list(names.values()), country_code))
@@ -536,9 +537,8 @@ class _TokenInfo:
     def add_names(self, conn: Connection, names: Mapping[str, str]) -> None:
         """ Add token information for the names of the place.
         """
-        with conn.cursor() as cur:
-            # Create the token IDs for all names.
-            self.data['names'] = cur.scalar("SELECT make_keywords(%s)::text",
+        # Create the token IDs for all names.
+        self.data['names'] = execute_scalar(conn, "SELECT make_keywords(%s)::text",
                                             (names, ))
 
 
@@ -576,9 +576,8 @@ class _TokenInfo:
         """ Add addr:street match terms.
         """
         def _get_street(name: str) -> Optional[str]:
-            with conn.cursor() as cur:
-                return cast(Optional[str],
-                            cur.scalar("SELECT word_ids_from_name(%s)::text", (name, )))
+            return cast(Optional[str],
+                        execute_scalar(conn, "SELECT word_ids_from_name(%s)::text", (name, )))
 
         tokens = self.cache.streets.get(street, _get_street)
         self.data['street'] = tokens or '{}'
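
psycopg 3 has no execute_values(), so the legacy tokenizer switches to executemany() with one parameter tuple per row, and several queries gain explicit ::text[] casts, presumably so the server can resolve the element type of array parameters inside unnest(). A hedged sketch of both patterns (DSN and the phrases table are placeholders, not part of Nominatim):

    import psycopg

    rows = [('Old Post Office', 'amenity', 'pub', 'in'),
            ('Corner Bakery', 'shop', 'bakery', '-')]

    with psycopg.connect('dbname=nominatim') as conn:
        with conn.cursor() as cur:
            # executemany(): the statement runs once per parameter tuple
            cur.executemany('INSERT INTO phrases (name, class, type, op) VALUES (%s, %s, %s, %s)',
                            rows)

            # array parameter with an explicit element type
            cur.execute('SELECT count(*) FROM unnest(%s::text[]) as pc',
                        (['12345', '99999'], ))
            print(cur.fetchone()[0])  # 2
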
index cea2ad664a7e339c2c66dd23790ab517dcd7b338..e70a7e50508abb8b5fcba78e5b147338e1205e51 100644 (file)
@@ -10,12 +10,12 @@ Functions for database analysis and maintenance.
 from typing import Optional, Tuple, Any, cast
 import logging
 
-from psycopg2.extras import Json, register_hstore
-from psycopg2 import DataError
+import psycopg
+from psycopg.types.json import Json
 
 from ..typing import DictCursorResult
 from ..config import Configuration
-from ..db.connection import connect, Cursor
+from ..db.connection import connect, Cursor, register_hstore
 from ..errors import UsageError
 from ..tokenizer import factory as tokenizer_factory
 from ..data.place_info import PlaceInfo
@@ -59,7 +59,7 @@ def analyse_indexing(config: Configuration, osm_id: Optional[str] = None,
     """
     with connect(config.get_libpq_dsn()) as conn:
         register_hstore(conn)
-        with conn.cursor() as cur:
+        with conn.cursor(row_factory=psycopg.rows.dict_row) as cur:
             place = _get_place_info(cur, osm_id, place_id)
 
             cur.execute("update placex set indexed_status = 2 where place_id = %s",
@@ -74,6 +74,9 @@ def analyse_indexing(config: Configuration, osm_id: Optional[str] = None,
 
             tokenizer = tokenizer_factory.get_tokenizer_for_db(config)
 
+            # Enable printing of messages.
+            conn.add_notice_handler(lambda diag: print(diag.message_primary))
+
             with tokenizer.name_analyzer() as analyzer:
                 cur.execute("""UPDATE placex
                                SET indexed_status = 0, address = %s, token_info = %s,
@@ -86,9 +89,6 @@ def analyse_indexing(config: Configuration, osm_id: Optional[str] = None,
         # we do not want to keep the results
         conn.rollback()
 
-        for msg in conn.notices:
-            print(msg)
-
 
 def clean_deleted_relations(config: Configuration, age: str) -> None:
     """ Clean deleted relations older than a given age
@@ -101,6 +101,6 @@ def clean_deleted_relations(config: Configuration, age: str) -> None:
                             WHERE p.osm_type = d.osm_type AND p.osm_id = d.osm_id
                             AND age(p.indexed_date) > %s::interval""",
                             (age, ))
-            except DataError as exc:
+            except psycopg.DataError as exc:
                 raise UsageError('Invalid PostgreSQL time interval format') from exc
         conn.commit()
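
conn.notices does not exist in psycopg 3; server messages are delivered to callbacks registered with add_notice_handler(), which is why the handler is attached before the place is re-indexed. A minimal standalone sketch (DSN is a placeholder):

    import psycopg

    with psycopg.connect('dbname=nominatim') as conn:
        conn.add_notice_handler(lambda diag: print('server:', diag.message_primary))
        with conn.cursor() as cur:
            cur.execute("DO $$ BEGIN RAISE NOTICE 'indexing place 42'; END $$")
            # prints: server: indexing place 42
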
index ef28a0e5a61653b6a2f2c250137050130caff084..7389c9a2c8be49bdf7a43f9984427e5f52f00430 100644 (file)
@@ -12,7 +12,8 @@ from enum import Enum
 from textwrap import dedent
 
 from ..config import Configuration
-from ..db.connection import connect, Connection
+from ..db.connection import connect, Connection, server_version_tuple,\
+                            index_exists, table_exists, execute_scalar
 from ..db import properties
 from ..errors import UsageError
 from ..tokenizer import factory as tokenizer_factory
@@ -80,7 +81,7 @@ def check_database(config: Configuration) -> int:
     """ Run a number of checks on the database and return the status.
     """
     try:
-        conn = connect(config.get_libpq_dsn()).connection
+        conn = connect(config.get_libpq_dsn())
     except UsageError as err:
         conn = _BadConnection(str(err)) # type: ignore[assignment]
 
@@ -109,14 +110,14 @@ def _get_indexes(conn: Connection) -> List[str]:
                'idx_postcode_id',
                'idx_postcode_postcode'
               ]
-    if conn.table_exists('search_name'):
+    if table_exists(conn, 'search_name'):
         indexes.extend(('idx_search_name_nameaddress_vector',
                         'idx_search_name_name_vector',
                         'idx_search_name_centroid'))
-        if conn.server_version_tuple() >= (11, 0, 0):
+        if server_version_tuple(conn) >= (11, 0, 0):
             indexes.extend(('idx_placex_housenumber',
                             'idx_osmline_parent_osm_id_with_hnr'))
-    if conn.table_exists('place'):
+    if table_exists(conn, 'place'):
         indexes.extend(('idx_location_area_country_place_id',
                         'idx_place_osm_unique',
                         'idx_placex_rank_address_sector',
@@ -153,7 +154,7 @@ def check_connection(conn: Any, config: Configuration) -> CheckResult:
 
              Hints:
              * Are you connecting to the correct database?
-             
+
              {instruction}
 
              Check the Migration chapter of the Administration Guide.
@@ -165,7 +166,7 @@ def check_database_version(conn: Connection, config: Configuration) -> CheckResu
     """ Checking database_version matches Nominatim software version
     """
 
-    if conn.table_exists('nominatim_properties'):
+    if table_exists(conn, 'nominatim_properties'):
         db_version_str = properties.get_property(conn, 'database_version')
     else:
         db_version_str = None
@@ -202,7 +203,7 @@ def check_database_version(conn: Connection, config: Configuration) -> CheckResu
 def check_placex_table(conn: Connection, config: Configuration) -> CheckResult:
     """ Checking for placex table
     """
-    if conn.table_exists('placex'):
+    if table_exists(conn, 'placex'):
         return CheckState.OK
 
     return CheckState.FATAL, dict(config=config)
@@ -212,8 +213,7 @@ def check_placex_table(conn: Connection, config: Configuration) -> CheckResult:
 def check_placex_size(conn: Connection, _: Configuration) -> CheckResult:
     """ Checking for placex content
     """
-    with conn.cursor() as cur:
-        cnt = cur.scalar('SELECT count(*) FROM (SELECT * FROM placex LIMIT 100) x')
+    cnt = execute_scalar(conn, 'SELECT count(*) FROM (SELECT * FROM placex LIMIT 100) x')
 
     return CheckState.OK if cnt > 0 else CheckState.FATAL
 
@@ -244,16 +244,15 @@ def check_tokenizer(_: Connection, config: Configuration) -> CheckResult:
 def check_existance_wikipedia(conn: Connection, _: Configuration) -> CheckResult:
     """ Checking for wikipedia/wikidata data
     """
-    if not conn.table_exists('search_name') or not conn.table_exists('place'):
+    if not table_exists(conn, 'search_name') or not table_exists(conn, 'place'):
         return CheckState.NOT_APPLICABLE
 
-    with conn.cursor() as cur:
-        if conn.table_exists('wikimedia_importance'):
-            cnt = cur.scalar('SELECT count(*) FROM wikimedia_importance')
-        else:
-            cnt = cur.scalar('SELECT count(*) FROM wikipedia_article')
+    if table_exists(conn, 'wikimedia_importance'):
+        cnt = execute_scalar(conn, 'SELECT count(*) FROM wikimedia_importance')
+    else:
+        cnt = execute_scalar(conn, 'SELECT count(*) FROM wikipedia_article')
 
-        return CheckState.WARN if cnt == 0 else CheckState.OK
+    return CheckState.WARN if cnt == 0 else CheckState.OK
 
 
 @_check(hint="""\
@@ -264,8 +263,7 @@ def check_existance_wikipedia(conn: Connection, _: Configuration) -> CheckResult
 def check_indexing(conn: Connection, _: Configuration) -> CheckResult:
     """ Checking indexing status
     """
-    with conn.cursor() as cur:
-        cnt = cur.scalar('SELECT count(*) FROM placex WHERE indexed_status > 0')
+    cnt = execute_scalar(conn, 'SELECT count(*) FROM placex WHERE indexed_status > 0')
 
     if cnt == 0:
         return CheckState.OK
@@ -276,7 +274,7 @@ def check_indexing(conn: Connection, _: Configuration) -> CheckResult:
             Low counts of unindexed places are fine."""
         return CheckState.WARN, dict(count=cnt, index_cmd=index_cmd)
 
-    if conn.index_exists('idx_placex_rank_search'):
+    if index_exists(conn, 'idx_placex_rank_search'):
         # Likely just an interrupted update.
         index_cmd = 'nominatim index'
     else:
@@ -297,7 +295,7 @@ def check_database_indexes(conn: Connection, _: Configuration) -> CheckResult:
     """
     missing = []
     for index in _get_indexes(conn):
-        if not conn.index_exists(index):
+        if not index_exists(conn, index):
             missing.append(index)
 
     if missing:
@@ -340,11 +338,10 @@ def check_tiger_table(conn: Connection, config: Configuration) -> CheckResult:
     if not config.get_bool('USE_US_TIGER_DATA'):
         return CheckState.NOT_APPLICABLE
 
-    if not conn.table_exists('location_property_tiger'):
+    if not table_exists(conn, 'location_property_tiger'):
         return CheckState.FAIL, dict(error='TIGER data table not found.')
 
-    with conn.cursor() as cur:
-        if cur.scalar('SELECT count(*) FROM location_property_tiger') == 0:
-            return CheckState.FAIL, dict(error='TIGER data table is empty.')
+    if execute_scalar(conn, 'SELECT count(*) FROM location_property_tiger') == 0:
+        return CheckState.FAIL, dict(error='TIGER data table is empty.')
 
     return CheckState.OK
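
The checks now call free functions (table_exists, index_exists, execute_scalar) from nominatim_db.db.connection instead of methods on the old connection wrapper. Judging from its use here, execute_scalar() runs a query and returns the first column of the first row; a stand-in sketch of that assumed behaviour:

    from typing import Any, Optional

    import psycopg

    def execute_scalar(conn: psycopg.Connection[Any], sql: str,
                       args: Optional[Any] = None) -> Any:
        """ Stand-in mirroring how execute_scalar() is used in this diff:
            execute the query and return the single value it produces.
        """
        with conn.cursor() as cur:
            cur.execute(sql, args)
            row = cur.fetchone()
        return None if row is None else row[0]

    # cnt = execute_scalar(conn, 'SELECT count(*) FROM placex WHERE indexed_status > 0')
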
index e1f8b16637a4cd02de87fb582f8932b8695cd1b1..d054ef006741729e0bf5f541b977e083f695bad5 100644 (file)
@@ -12,21 +12,15 @@ import os
 import subprocess
 import sys
 from pathlib import Path
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Union
 
 import psutil
-from psycopg2.extensions import make_dsn, parse_dsn
 
 from ..config import Configuration
-from ..db.connection import connect
+from ..db.connection import connect, server_version_tuple, execute_scalar
 from ..version import NOMINATIM_VERSION
 
 
-def convert_version(ver_tup: Tuple[int, int]) -> str:
-    """converts tuple version (ver_tup) to a string representation"""
-    return ".".join(map(str, ver_tup))
-
-
 def friendly_memory_string(mem: float) -> str:
     """Create a user friendly string for the amount of memory specified as mem"""
     mem_magnitude = ("bytes", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB")
@@ -102,17 +96,17 @@ def report_system_information(config: Configuration) -> None:
     """Generate a report about the host system including software versions, memory,
     storage, and database configuration."""
 
-    with connect(make_dsn(config.get_libpq_dsn(), dbname='postgres')) as conn:
-        postgresql_ver: str = convert_version(conn.server_version_tuple())
+    with connect(config.get_libpq_dsn(), dbname='postgres') as conn:
+        postgresql_ver: str = '.'.join(map(str, server_version_tuple(conn)))
 
         with conn.cursor() as cur:
-            num = cur.scalar("SELECT count(*) FROM pg_catalog.pg_database WHERE datname=%s",
-                             (parse_dsn(config.get_libpq_dsn())['dbname'], ))
-            nominatim_db_exists = num == 1 if isinstance(num, int) else False
+            cur.execute("SELECT datname FROM pg_catalog.pg_database WHERE datname=%s",
+                        (config.get_database_params()['dbname'], ))
+            nominatim_db_exists = cur.rowcount > 0
 
     if nominatim_db_exists:
         with connect(config.get_libpq_dsn()) as conn:
-            postgis_ver: str = convert_version(conn.postgis_version_tuple())
+            postgis_ver: str = execute_scalar(conn, 'SELECT postgis_lib_version()')
     else:
         postgis_ver = "Unable to connect to database"
 
index d07febc8a3da97fc5e41b8b243e3e8ba0f703de1..e96954ddd2df605e59618fd23e204c955649ef66 100644 (file)
@@ -10,18 +10,20 @@ Functions for setting up and importing a new Nominatim database.
 from typing import Tuple, Optional, Union, Sequence, MutableMapping, Any
 import logging
 import os
-import selectors
 import subprocess
+import asyncio
 from pathlib import Path
 
 import psutil
-from psycopg2 import sql as pysql
+import psycopg
+from psycopg import sql as pysql
 
 from ..errors import UsageError
 from ..config import Configuration
-from ..db.connection import connect, get_pg_env, Connection
-from ..db.async_connection import DBConnection
+from ..db.connection import connect, get_pg_env, Connection, server_version_tuple,\
+                            postgis_version_tuple, drop_tables, table_exists, execute_scalar
 from ..db.sql_preprocessor import SQLPreprocessor
+from ..db.query_pool import QueryPool
 from .exec_utils import run_osm2pgsql
 from ..version import POSTGRESQL_REQUIRED_VERSION, POSTGIS_REQUIRED_VERSION
 
@@ -40,19 +42,21 @@ def _require_version(module: str, actual: Tuple[int, int], expected: Tuple[int,
 
 def _require_loaded(extension_name: str, conn: Connection) -> None:
     """ Check that the given extension is loaded. """
-    if not conn.extension_loaded(extension_name):
-        LOG.fatal('Required module %s is not loaded.', extension_name)
-        raise UsageError(f'{extension_name} is not loaded.')
+    with conn.cursor() as cur:
+        cur.execute('SELECT * FROM pg_extension WHERE extname = %s', (extension_name, ))
+        if cur.rowcount <= 0:
+            LOG.fatal('Required module %s is not loaded.', extension_name)
+            raise UsageError(f'{extension_name} is not loaded.')
 
 
 def check_existing_database_plugins(dsn: str) -> None:
     """ Check that the database has the required plugins installed."""
     with connect(dsn) as conn:
         _require_version('PostgreSQL server',
-                         conn.server_version_tuple(),
+                         server_version_tuple(conn),
                          POSTGRESQL_REQUIRED_VERSION)
         _require_version('PostGIS',
-                         conn.postgis_version_tuple(),
+                         postgis_version_tuple(conn),
                          POSTGIS_REQUIRED_VERSION)
         _require_loaded('hstore', conn)
 
@@ -78,31 +82,30 @@ def setup_database_skeleton(dsn: str, rouser: Optional[str] = None) -> None:
 
     with connect(dsn) as conn:
         _require_version('PostgreSQL server',
-                         conn.server_version_tuple(),
+                         server_version_tuple(conn),
                          POSTGRESQL_REQUIRED_VERSION)
 
         if rouser is not None:
-            with conn.cursor() as cur:
-                cnt = cur.scalar('SELECT count(*) FROM pg_user where usename = %s',
+            cnt = execute_scalar(conn, 'SELECT count(*) FROM pg_user where usename = %s',
                                  (rouser, ))
-                if cnt == 0:
-                    LOG.fatal("Web user '%s' does not exist. Create it with:\n"
-                              "\n      createuser %s", rouser, rouser)
-                    raise UsageError('Missing read-only user.')
+            if cnt == 0:
+                LOG.fatal("Web user '%s' does not exist. Create it with:\n"
+                          "\n      createuser %s", rouser, rouser)
+                raise UsageError('Missing read-only user.')
 
         # Create extensions.
         with conn.cursor() as cur:
             cur.execute('CREATE EXTENSION IF NOT EXISTS hstore')
             cur.execute('CREATE EXTENSION IF NOT EXISTS postgis')
 
-            postgis_version = conn.postgis_version_tuple()
+            postgis_version = postgis_version_tuple(conn)
             if postgis_version[0] >= 3:
                 cur.execute('CREATE EXTENSION IF NOT EXISTS postgis_raster')
 
         conn.commit()
 
         _require_version('PostGIS',
-                         conn.postgis_version_tuple(),
+                         postgis_version_tuple(conn),
                          POSTGIS_REQUIRED_VERSION)
 
 
@@ -134,12 +137,13 @@ def import_osm_data(osm_files: Union[Path, Sequence[Path]],
     with connect(options['dsn']) as conn:
         if not ignore_errors:
             with conn.cursor() as cur:
-                cur.execute('SELECT * FROM place LIMIT 1')
+                cur.execute('SELECT true FROM place LIMIT 1')
                 if cur.rowcount == 0:
                     raise UsageError('No data imported by osm2pgsql.')
 
         if drop:
-            conn.drop_table('planet_osm_nodes')
+            drop_tables(conn, 'planet_osm_nodes')
+            conn.commit()
 
     if drop and options['flatnode_file']:
         Path(options['flatnode_file']).unlink()
@@ -182,7 +186,7 @@ def truncate_data_tables(conn: Connection) -> None:
         cur.execute('TRUNCATE location_property_tiger')
         cur.execute('TRUNCATE location_property_osmline')
         cur.execute('TRUNCATE location_postcode')
-        if conn.table_exists('search_name'):
+        if table_exists(conn, 'search_name'):
             cur.execute('TRUNCATE search_name')
         cur.execute('DROP SEQUENCE IF EXISTS seq_place')
         cur.execute('CREATE SEQUENCE seq_place start 100000')
@@ -202,54 +206,51 @@ _COPY_COLUMNS = pysql.SQL(',').join(map(pysql.Identifier,
                                          'extratags', 'geometry')))
 
 
-def load_data(dsn: str, threads: int) -> None:
+async def load_data(dsn: str, threads: int) -> None:
     """ Copy data into the word and placex table.
     """
-    sel = selectors.DefaultSelector()
-    # Then copy data from place to placex in <threads - 1> chunks.
-    place_threads = max(1, threads - 1)
-    for imod in range(place_threads):
-        conn = DBConnection(dsn)
-        conn.connect()
-        conn.perform(
-            pysql.SQL("""INSERT INTO placex ({columns})
-                           SELECT {columns} FROM place
-                           WHERE osm_id % {total} = {mod}
-                             AND NOT (class='place' and (type='houses' or type='postcode'))
-                             AND ST_IsValid(geometry)
-                      """).format(columns=_COPY_COLUMNS,
-                                  total=pysql.Literal(place_threads),
-                                  mod=pysql.Literal(imod)))
-        sel.register(conn, selectors.EVENT_READ, conn)
-
-    # Address interpolations go into another table.
-    conn = DBConnection(dsn)
-    conn.connect()
-    conn.perform("""INSERT INTO location_property_osmline (osm_id, address, linegeo)
-                      SELECT osm_id, address, geometry FROM place
-                      WHERE class='place' and type='houses' and osm_type='W'
-                            and ST_GeometryType(geometry) = 'ST_LineString'
-                 """)
-    sel.register(conn, selectors.EVENT_READ, conn)
-
-    # Now wait for all of them to finish.
-    todo = place_threads + 1
-    while todo > 0:
-        for key, _ in sel.select(1):
-            conn = key.data
-            sel.unregister(conn)
-            conn.wait()
-            conn.close()
-            todo -= 1
+    placex_threads = max(1, threads - 1)
+
+    progress = asyncio.create_task(_progress_print())
+
+    async with QueryPool(dsn, placex_threads + 1) as pool:
+        # Copy data from place to placex in <threads - 1> chunks.
+        for imod in range(placex_threads):
+            await pool.put_query(
+                pysql.SQL("""INSERT INTO placex ({columns})
+                               SELECT {columns} FROM place
+                                WHERE osm_id % {total} = {mod}
+                                  AND NOT (class='place'
+                                           and (type='houses' or type='postcode'))
+                                  AND ST_IsValid(geometry)
+                          """).format(columns=_COPY_COLUMNS,
+                                      total=pysql.Literal(placex_threads),
+                                      mod=pysql.Literal(imod)), None)
+
+        # Interpolations need to be copied separately
+        await pool.put_query("""
+                INSERT INTO location_property_osmline (osm_id, address, linegeo)
+                  SELECT osm_id, address, geometry FROM place
+                  WHERE class='place' and type='houses' and osm_type='W'
+                        and ST_GeometryType(geometry) = 'ST_LineString' """, None)
+
+    progress.cancel()
+
+    async with await psycopg.AsyncConnection.connect(dsn) as aconn:
+        await aconn.execute('ANALYSE')
+
+
+async def _progress_print() -> None:
+    while True:
+        try:
+            await asyncio.sleep(1)
+        except asyncio.CancelledError:
+            print('', flush=True)
+            break
         print('.', end='', flush=True)
-    print('\n')
-
-    with connect(dsn) as syn_conn:
-        with syn_conn.cursor() as cur:
-            cur.execute('ANALYSE')
 
 
-def create_search_indices(conn: Connection, config: Configuration,
+async def create_search_indices(conn: Connection, config: Configuration,
                           drop: bool = False, threads: int = 1) -> None:
     """ Create tables that have explicit partitioning.
     """
@@ -268,5 +269,5 @@ def create_search_indices(conn: Connection, config: Configuration,
 
     sql = SQLPreprocessor(conn, config)
 
-    sql.run_parallel_sql_file(config.get_libpq_dsn(),
-                              'indices.sql', min(8, threads), drop=drop)
+    await sql.run_parallel_sql_file(config.get_libpq_dsn(),
+                                    'indices.sql', min(8, threads), drop=drop)
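
load_data() now queues its INSERT statements on the new QueryPool, which fans them out over several asynchronous connections, and create_search_indices() has become a coroutine for the same reason. db/query_pool.py itself is not part of this excerpt, so the following is only a rough sketch of that worker-pool pattern built directly on asyncio and psycopg; the real class may differ in names and details:

    # Rough sketch of the pattern only; error handling and shutdown details omitted.
    import asyncio
    import psycopg

    class SketchQueryPool:
        def __init__(self, dsn: str, size: int) -> None:
            self.queue: asyncio.Queue = asyncio.Queue(maxsize=2 * size)
            self.workers = [asyncio.create_task(self._worker(dsn)) for _ in range(size)]

        async def put_query(self, query, params) -> None:
            # Hand a statement to the next free worker (blocks while the queue is full).
            await self.queue.put((query, params))

        async def _worker(self, dsn: str) -> None:
            async with await psycopg.AsyncConnection.connect(dsn) as conn:
                while True:
                    query, params = await self.queue.get()
                    await conn.execute(query, params)
                    await conn.commit()
                    self.queue.task_done()

        async def __aenter__(self) -> 'SketchQueryPool':
            return self

        async def __aexit__(self, *_) -> None:
            await self.queue.join()          # wait until all queued statements are done
            for worker in self.workers:
                worker.cancel()
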
index bd52ba9a5a7ad194c353a721d1529ea4cefa20cb..c4eedb43262029424b120072576de0e061b32d86 100644 (file)
@@ -10,9 +10,9 @@ Functions for removing unnecessary data from the database.
 from typing import Optional
 from pathlib import Path
 
-from psycopg2 import sql as pysql
+from psycopg import sql as pysql
 
-from ..db.connection import Connection
+from ..db.connection import Connection, drop_tables, table_exists
 
 UPDATE_TABLES = [
     'address_levels',
@@ -39,9 +39,7 @@ def drop_update_tables(conn: Connection) -> None:
                     + pysql.SQL(' or ').join(parts))
         tables = [r[0] for r in cur]
 
-        for table in tables:
-            cur.drop_table(table, cascade=True)
-
+    drop_tables(conn, *tables, cascade=True)
     conn.commit()
 
 
@@ -55,4 +53,4 @@ def is_frozen(conn: Connection) -> bool:
     """ Returns true if database is in a frozen state
     """
 
-    return conn.table_exists('place') is False
+    return table_exists(conn, 'place') is False
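
freeze.py now goes through the same module-level connection helpers. A small sketch with the signatures used above (DSN and table names are placeholders):

    from nominatim_db.db.connection import connect, drop_tables, table_exists

    with connect('dbname=nominatim') as conn:
        # drop_tables() takes any number of table names plus an optional cascade flag.
        drop_tables(conn, 'tmp_import_a', 'tmp_import_b', cascade=True)
        conn.commit()
        assert not table_exists(conn, 'tmp_import_a')
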
index e6803c7dea23521a805aa939584d028f66b4a7b5..5483653230bb6c2eff94109a0c87e6539727c0d9 100644 (file)
@@ -10,12 +10,13 @@ Functions for database migration to newer software versions.
 from typing import List, Tuple, Callable, Any
 import logging
 
-from psycopg2 import sql as pysql
+from psycopg import sql as pysql
 
 from ..errors import UsageError
 from ..config import Configuration
 from ..db import properties
-from ..db.connection import connect, Connection
+from ..db.connection import connect, Connection, server_version_tuple,\
+                            table_has_column, table_exists, execute_scalar, register_hstore
 from ..version import NominatimVersion, NOMINATIM_VERSION, parse_version
 from ..tokenizer import factory as tokenizer_factory
 from . import refresh
@@ -29,7 +30,8 @@ def migrate(config: Configuration, paths: Any) -> int:
         if necessary.
     """
     with connect(config.get_libpq_dsn()) as conn:
-        if conn.table_exists('nominatim_properties'):
+        register_hstore(conn)
+        if table_exists(conn, 'nominatim_properties'):
             db_version_str = properties.get_property(conn, 'database_version')
         else:
             db_version_str = None
@@ -72,16 +74,15 @@ def _guess_version(conn: Connection) -> NominatimVersion:
         Only migrations for 3.6 and later are supported, so bail out
         when the version seems older.
     """
-    with conn.cursor() as cur:
-        # In version 3.6, the country_name table was updated. Check for that.
-        cnt = cur.scalar("""SELECT count(*) FROM
-                            (SELECT svals(name) FROM  country_name
-                             WHERE country_code = 'gb')x;
-                         """)
-        if cnt < 100:
-            LOG.fatal('It looks like your database was imported with a version '
-                      'prior to 3.6.0. Automatic migration not possible.')
-            raise UsageError('Migration not possible.')
+    # In version 3.6, the country_name table was updated. Check for that.
+    cnt = execute_scalar(conn, """SELECT count(*) FROM
+                                  (SELECT svals(name) FROM  country_name
+                                  WHERE country_code = 'gb')x;
+                               """)
+    if cnt < 100:
+        LOG.fatal('It looks like your database was imported with a version '
+                  'prior to 3.6.0. Automatic migration not possible.')
+        raise UsageError('Migration not possible.')
 
     return NominatimVersion(3, 5, 0, 99)
 
@@ -125,7 +126,7 @@ def import_status_timestamp_change(conn: Connection, **_: Any) -> None:
 def add_nominatim_property_table(conn: Connection, config: Configuration, **_: Any) -> None:
     """ Add nominatim_property table.
     """
-    if not conn.table_exists('nominatim_properties'):
+    if not table_exists(conn, 'nominatim_properties'):
         with conn.cursor() as cur:
             cur.execute(pysql.SQL("""CREATE TABLE nominatim_properties (
                                         property TEXT,
@@ -189,13 +190,9 @@ def install_legacy_tokenizer(conn: Connection, config: Configuration, **_: Any)
         configuration for the backwards-compatible legacy tokenizer
     """
     if properties.get_property(conn, 'tokenizer') is None:
-        with conn.cursor() as cur:
-            for table in ('placex', 'location_property_osmline'):
-                has_column = cur.scalar("""SELECT count(*) FROM information_schema.columns
-                                           WHERE table_name = %s
-                                           and column_name = 'token_info'""",
-                                        (table, ))
-                if has_column == 0:
+        for table in ('placex', 'location_property_osmline'):
+            if not table_has_column(conn, table, 'token_info'):
+                with conn.cursor() as cur:
                     cur.execute(pysql.SQL('ALTER TABLE {} ADD COLUMN token_info JSONB')
                                 .format(pysql.Identifier(table)))
         tokenizer = tokenizer_factory.create_tokenizer(config, init_db=False,
@@ -212,7 +209,7 @@ def create_tiger_housenumber_index(conn: Connection, **_: Any) -> None:
         The inclusion is needed for efficient lookup of housenumbers in
         full address searches.
     """
-    if conn.server_version_tuple() >= (11, 0, 0):
+    if server_version_tuple(conn) >= (11, 0, 0):
         with conn.cursor() as cur:
             cur.execute(""" CREATE INDEX IF NOT EXISTS
                                 idx_location_property_tiger_housenumber_migrated
@@ -239,7 +236,7 @@ def add_step_column_for_interpolation(conn: Connection, **_: Any) -> None:
         Also converts the data into the stricter format which requires that
         startnumbers comply with the odd/even requirements.
     """
-    if conn.table_has_column('location_property_osmline', 'step'):
+    if table_has_column(conn, 'location_property_osmline', 'step'):
         return
 
     with conn.cursor() as cur:
@@ -271,7 +268,7 @@ def add_step_column_for_interpolation(conn: Connection, **_: Any) -> None:
 def add_step_column_for_tiger(conn: Connection, **_: Any) -> None:
     """ Add a new column 'step' to the tiger data table.
     """
-    if conn.table_has_column('location_property_tiger', 'step'):
+    if table_has_column(conn, 'location_property_tiger', 'step'):
         return
 
     with conn.cursor() as cur:
@@ -287,7 +284,7 @@ def add_derived_name_column_for_country_names(conn: Connection, **_: Any) -> Non
     """ Add a new column 'derived_name' which in the future takes the
         country names as imported from OSM data.
     """
-    if not conn.table_has_column('country_name', 'derived_name'):
+    if not table_has_column(conn, 'country_name', 'derived_name'):
         with conn.cursor() as cur:
             cur.execute("ALTER TABLE country_name ADD COLUMN derived_name public.HSTORE")
 
@@ -297,12 +294,9 @@ def mark_internal_country_names(conn: Connection, config: Configuration, **_: An
     """ Names from the country table should be marked as internal to prevent
         them from being deleted. Only necessary for ICU tokenizer.
     """
-    import psycopg2.extras # pylint: disable=import-outside-toplevel
-
     tokenizer = tokenizer_factory.get_tokenizer_for_db(config)
     with tokenizer.name_analyzer() as analyzer:
         with conn.cursor() as cur:
-            psycopg2.extras.register_hstore(cur)
             cur.execute("SELECT country_code, name FROM country_name")
 
             for country_code, names in cur:
@@ -319,7 +313,7 @@ def add_place_deletion_todo_table(conn: Connection, **_: Any) -> None:
         The table is only necessary when updates are possible, i.e.
         the database is not in freeze mode.
     """
-    if conn.table_exists('place'):
+    if table_exists(conn, 'place'):
         with conn.cursor() as cur:
             cur.execute("""CREATE TABLE IF NOT EXISTS place_to_be_deleted (
                              osm_type CHAR(1),
@@ -333,7 +327,7 @@ def add_place_deletion_todo_table(conn: Connection, **_: Any) -> None:
 def split_pending_index(conn: Connection, **_: Any) -> None:
     """ Reorganise indexes for pending updates.
     """
-    if conn.table_exists('place'):
+    if table_exists(conn, 'place'):
         with conn.cursor() as cur:
             cur.execute("""CREATE INDEX IF NOT EXISTS idx_placex_rank_address_sector
                            ON placex USING BTREE (rank_address, geometry_sector)
@@ -349,7 +343,7 @@ def split_pending_index(conn: Connection, **_: Any) -> None:
 def enable_forward_dependencies(conn: Connection, **_: Any) -> None:
     """ Create indexes for updates with forward dependency tracking (long-running).
     """
-    if conn.table_exists('planet_osm_ways'):
+    if table_exists(conn, 'planet_osm_ways'):
         with conn.cursor() as cur:
             cur.execute("""SELECT * FROM pg_indexes
                            WHERE tablename = 'planet_osm_ways'
@@ -398,7 +392,7 @@ def create_postcode_area_lookup_index(conn: Connection, **_: Any) -> None:
 def create_postcode_parent_index(conn: Connection, **_: Any) -> None:
     """ Create index needed for updating postcodes when a parent changes.
     """
-    if conn.table_exists('planet_osm_ways'):
+    if table_exists(conn, 'planet_osm_ways'):
         with conn.cursor() as cur:
             cur.execute("""CREATE INDEX IF NOT EXISTS
                              idx_location_postcode_parent_place_id
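
The migration code now registers the hstore adapter once per connection through the project's register_hstore() helper instead of psycopg2.extras.register_hstore(). Under psycopg 3 such a registration presumably boils down to the following (plain psycopg shown, DSN is a placeholder):

    import psycopg
    from psycopg.types import TypeInfo
    from psycopg.types.hstore import register_hstore

    with psycopg.connect('dbname=nominatim') as conn:
        info = TypeInfo.fetch(conn, 'hstore')      # requires the hstore extension
        register_hstore(info, conn)
        # hstore columns now come back as Python dicts
        row = conn.execute("SELECT 'a=>1, b=>NULL'::hstore").fetchone()
        assert row[0] == {'a': '1', 'b': None}
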
index 8dc5bdbdb43cacb46fec70a86d553734a4ef92d7..357b2bae027bef4be2917fe5952e03ca01b2ab75 100644 (file)
@@ -16,9 +16,9 @@ import gzip
 import logging
 from math import isfinite
 
-from psycopg2 import sql as pysql
+from psycopg import sql as pysql
 
-from ..db.connection import connect, Connection
+from ..db.connection import connect, Connection, table_exists
 from ..utils.centroid import PointsCentroid
 from ..data.postcode_format import PostcodeFormatter, CountryPostcodeMatcher
 from ..tokenizer.base import AbstractAnalyzer, AbstractTokenizer
@@ -76,30 +76,30 @@ class _PostcodeCollector:
 
         with conn.cursor() as cur:
             if to_add:
-                cur.execute_values(
+                cur.executemany(pysql.SQL(
                     """INSERT INTO location_postcode
                          (place_id, indexed_status, country_code,
-                          postcode, geometry) VALUES %s""",
-                    to_add,
-                    template=pysql.SQL("""(nextval('seq_place'), 1, {},
-                                          %s, 'SRID=4326;POINT(%s %s)')
-                                       """).format(pysql.Literal(self.country)))
+                          postcode, geometry)
+                       VALUES (nextval('seq_place'), 1, {}, %s,
+                               ST_SetSRID(ST_MakePoint(%s, %s), 4326))
+                    """).format(pysql.Literal(self.country)),
+                    to_add)
             if to_delete:
                 cur.execute("""DELETE FROM location_postcode
                                WHERE country_code = %s and postcode = any(%s)
                             """, (self.country, to_delete))
             if to_update:
-                cur.execute_values(
+                cur.executemany(
                     pysql.SQL("""UPDATE location_postcode
                                  SET indexed_status = 2,
-                                     geometry = ST_SetSRID(ST_Point(v.x, v.y), 4326)
-                                 FROM (VALUES %s) AS v (pc, x, y)
-                                 WHERE country_code = {} and postcode = pc
-                              """).format(pysql.Literal(self.country)), to_update)
+                                     geometry = ST_SetSRID(ST_Point(%s, %s), 4326)
+                                 WHERE country_code = {} and postcode = %s
+                              """).format(pysql.Literal(self.country)),
+                    to_update)
 
 
     def _compute_changes(self, conn: Connection) \
-          -> Tuple[List[Tuple[str, float, float]], List[str], List[Tuple[str, float, float]]]:
+          -> Tuple[List[Tuple[str, float, float]], List[str], List[Tuple[float, float, str]]]:
         """ Compute which postcodes from the collected postcodes have to be
             added or modified and which from the location_postcode table
             have to be deleted.
@@ -116,7 +116,7 @@ class _PostcodeCollector:
                 if pcobj:
                     newx, newy = pcobj.centroid()
                     if (x - newx) > 0.0000001 or (y - newy) > 0.0000001:
-                        to_update.append((postcode, newx, newy))
+                        to_update.append((newx, newy, postcode))
                 else:
                     to_delete.append(postcode)
 
@@ -231,4 +231,4 @@ def can_compute(dsn: str) -> bool:
         postcodes can be computed.
     """
     with connect(dsn) as conn:
-        return conn.table_exists('place')
+        return table_exists(conn, 'place')
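
psycopg 3 drops execute_values(), so the postcode writer switches to executemany() with one set of ordinary placeholders per row (and reorders the update tuples to match). A self-contained illustration of the pattern with an invented table:

    import psycopg

    rows = [('12345', 8.1, 52.3), ('54321', 9.0, 50.1)]

    with psycopg.connect('dbname=test') as conn, conn.cursor() as cur:   # placeholder DSN
        cur.execute('CREATE TABLE IF NOT EXISTS demo_postcode (postcode TEXT, x FLOAT, y FLOAT)')
        # executemany() runs the statement once per parameter tuple and is batched
        # internally by psycopg 3, taking over the role of execute_values().
        cur.executemany('INSERT INTO demo_postcode VALUES (%s, %s, %s)', rows)
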
index 6a40c0a73a4c6427d11fe89b0db0552c19a1e7cf..d48c4e45a01dd68ee71fd338b4f06373d53ffe0c 100644 (file)
@@ -14,11 +14,12 @@ import logging
 from textwrap import dedent
 from pathlib import Path
 
-from psycopg2 import sql as pysql
+from psycopg import sql as pysql
 
 from ..config import Configuration
-from ..db.connection import Connection, connect
-from ..db.utils import execute_file, CopyBuffer
+from ..db.connection import Connection, connect, postgis_version_tuple,\
+                            drop_tables, table_exists
+from ..db.utils import execute_file
 from ..db.sql_preprocessor import SQLPreprocessor
 from ..version import NOMINATIM_VERSION
 
@@ -56,9 +57,9 @@ def load_address_levels(conn: Connection, table: str, levels: Sequence[Mapping[s
     for entry in levels:
         _add_address_level_rows_from_entry(rows, entry)
 
-    with conn.cursor() as cur:
-        cur.drop_table(table)
+    drop_tables(conn, table)
 
+    with conn.cursor() as cur:
         cur.execute(pysql.SQL("""CREATE TABLE {} (
                                         country_code varchar(2),
                                         class TEXT,
@@ -67,8 +68,8 @@ def load_address_levels(conn: Connection, table: str, levels: Sequence[Mapping[s
                                         rank_address SMALLINT)
                               """).format(pysql.Identifier(table)))
 
-        cur.execute_values(pysql.SQL("INSERT INTO {} VALUES %s")
-                           .format(pysql.Identifier(table)), rows)
+        cur.executemany(pysql.SQL("INSERT INTO {} VALUES (%s, %s, %s, %s, %s)")
+                             .format(pysql.Identifier(table)), rows)
 
         cur.execute(pysql.SQL('CREATE UNIQUE INDEX ON {} (country_code, class, type)')
                     .format(pysql.Identifier(table)))
@@ -154,15 +155,13 @@ def import_importance_csv(dsn: str, data_file: Path) -> int:
     if not data_file.exists():
         return 1
 
-    # Only import the first occurence of a wikidata ID.
+    # Only import the first occurrence of a wikidata ID.
     # This keeps indexes and table small.
     wd_done = set()
 
     with connect(dsn) as conn:
+        drop_tables(conn, 'wikipedia_article', 'wikipedia_redirect', 'wikimedia_importance')
         with conn.cursor() as cur:
-            cur.drop_table('wikipedia_article')
-            cur.drop_table('wikipedia_redirect')
-            cur.drop_table('wikimedia_importance')
             cur.execute("""CREATE TABLE wikimedia_importance (
                              language TEXT NOT NULL,
                              title TEXT NOT NULL,
@@ -170,24 +169,17 @@ def import_importance_csv(dsn: str, data_file: Path) -> int:
                              wikidata TEXT
                            ) """)
 
-        with gzip.open(str(data_file), 'rt') as fd, CopyBuffer() as buf:
-            for row in csv.DictReader(fd, delimiter='\t', quotechar='|'):
-                wd_id = int(row['wikidata_id'][1:])
-                buf.add(row['language'], row['title'], row['importance'],
-                        None if wd_id in wd_done else row['wikidata_id'])
-                wd_done.add(wd_id)
-
-                if buf.size() > 10000000:
-                    with conn.cursor() as cur:
-                        buf.copy_out(cur, 'wikimedia_importance',
-                                     columns=['language', 'title', 'importance',
-                                              'wikidata'])
+            copy_cmd = """COPY wikimedia_importance(language, title, importance, wikidata)
+                          FROM STDIN"""
+            with gzip.open(str(data_file), 'rt') as fd, cur.copy(copy_cmd) as copy:
+                for row in csv.DictReader(fd, delimiter='\t', quotechar='|'):
+                    wd_id = int(row['wikidata_id'][1:])
+                    copy.write_row((row['language'],
+                                    row['title'],
+                                    row['importance'],
+                                    None if wd_id in wd_done else row['wikidata_id']))
+                    wd_done.add(wd_id)
 
-            with conn.cursor() as cur:
-                buf.copy_out(cur, 'wikimedia_importance',
-                             columns=['language', 'title', 'importance', 'wikidata'])
-
-        with conn.cursor() as cur:
             cur.execute("""CREATE INDEX IF NOT EXISTS idx_wikimedia_importance_title
                            ON wikimedia_importance (title)""")
             cur.execute("""CREATE INDEX IF NOT EXISTS idx_wikimedia_importance_wikidata
@@ -228,7 +220,7 @@ def import_secondary_importance(dsn: str, data_path: Path, ignore_errors: bool =
         return 1
 
     with connect(dsn) as conn:
-        postgis_version = conn.postgis_version_tuple()
+        postgis_version = postgis_version_tuple(conn)
         if postgis_version[0] < 3:
             LOG.error('PostGIS version is too old for using OSM raster data.')
             return 2
@@ -309,7 +301,7 @@ def setup_website(basedir: Path, config: Configuration, conn: Connection) -> Non
 
     template = "\nrequire_once(CONST_LibDir.'/website/{}');\n"
 
-    search_name_table_exists = bool(conn and conn.table_exists('search_name'))
+    search_name_table_exists = bool(conn and table_exists(conn, 'search_name'))
 
     for script in WEBSITE_SCRIPTS:
         if not search_name_table_exists and script == 'search.php':
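
The wikimedia importance import now streams rows through psycopg 3's COPY protocol instead of the removed CopyBuffer helper. The API in isolation, with a made-up table and DSN:

    import psycopg

    data = [('en', 'Berlin', 0.93, 'Q64'), ('de', 'Berlin', 0.93, None)]

    with psycopg.connect('dbname=test') as conn, conn.cursor() as cur:
        cur.execute("""CREATE TABLE IF NOT EXISTS demo_importance
                       (language TEXT, title TEXT, importance FLOAT, wikidata TEXT)""")
        # cursor.copy() opens a COPY ... FROM STDIN stream; write_row() takes care of
        # escaping and NULL handling for each Python tuple.
        copy_sql = 'COPY demo_importance (language, title, importance, wikidata) FROM STDIN'
        with cur.copy(copy_sql) as copy:
            for row in data:
                copy.write_row(row)
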
index bf1189df032a85a40a985c20e783edfd74410e29..2b1d444f0cb1e3cd08b6540273ac9b64d136e43f 100644 (file)
@@ -20,7 +20,7 @@ import requests
 
 from ..errors import UsageError
 from ..db import status
-from ..db.connection import Connection, connect
+from ..db.connection import Connection, connect, server_version_tuple
 from .exec_utils import run_osm2pgsql
 
 try:
@@ -155,7 +155,7 @@ def run_osm2pgsql_updates(conn: Connection, options: MutableMapping[str, Any]) -
 
     # Consume updates with osm2pgsql.
     options['append'] = True
-    options['disable_jit'] = conn.server_version_tuple() >= (11, 0)
+    options['disable_jit'] = server_version_tuple(conn) >= (11, 0)
     run_osm2pgsql(options)
 
     # Handle deletions
index 1bdcdaf133d5d3db10fdb18e23c2b3f6ea24abce..311e37e2010125571449dd4143adb0dd8c16a8a5 100644 (file)
@@ -17,11 +17,11 @@ from typing import Iterable, Tuple, Mapping, Sequence, Optional, Set
 import logging
 import re
 
-from psycopg2.sql import Identifier, SQL
+from psycopg.sql import Identifier, SQL
 
 from ...typing import Protocol
 from ...config import Configuration
-from ...db.connection import Connection
+from ...db.connection import Connection, drop_tables, index_exists
 from .importer_statistics import SpecialPhrasesImporterStatistics
 from .special_phrase import SpecialPhrase
 from ...tokenizer.base import AbstractTokenizer
@@ -233,7 +233,7 @@ class SPImporter():
         index_prefix = f'idx_place_classtype_{phrase_class}_{phrase_type}_'
         base_table = _classtype_table(phrase_class, phrase_type)
         # Index on centroid
-        if not self.db_connection.index_exists(index_prefix + 'centroid'):
+        if not index_exists(self.db_connection, index_prefix + 'centroid'):
             with self.db_connection.cursor() as db_cursor:
                 db_cursor.execute(SQL("CREATE INDEX {} ON {} USING GIST (centroid) {}")
                                   .format(Identifier(index_prefix + 'centroid'),
@@ -241,7 +241,7 @@ class SPImporter():
                                           SQL(sql_tablespace)))
 
         # Index on place_id
-        if not self.db_connection.index_exists(index_prefix + 'place_id'):
+        if not index_exists(self.db_connection, index_prefix + 'place_id'):
             with self.db_connection.cursor() as db_cursor:
                 db_cursor.execute(SQL("CREATE INDEX {} ON {} USING btree(place_id) {}")
                                   .format(Identifier(index_prefix + 'place_id'),
@@ -259,6 +259,7 @@ class SPImporter():
                               .format(Identifier(table_name),
                                       Identifier(self.config.DATABASE_WEBUSER)))
 
+
     def _remove_non_existent_tables_from_db(self) -> None:
         """
             Remove special phrases which no longer exist on the wiki.
@@ -268,7 +269,6 @@ class SPImporter():
 
         # Delete place_classtype tables corresponding to class/type which
         # are not on the wiki anymore.
-        with self.db_connection.cursor() as db_cursor:
-            for table in self.table_phrases_to_delete:
-                self.statistics_handler.notify_one_table_deleted()
-                db_cursor.drop_table(table)
+        drop_tables(self.db_connection, *self.table_phrases_to_delete)
+        for _ in self.table_phrases_to_delete:
+            self.statistics_handler.notify_one_table_deleted()
index 7c52b7102a68707f4764a129957e84b71aab2f69..f4a7eba770b618df52b91f33946a4b761110b52e 100644 (file)
@@ -7,22 +7,22 @@
 """
 Functions for importing tiger data and handling tarball and directory files
 """
-from typing import Any, TextIO, List, Union, cast
+from typing import Any, TextIO, List, Union, cast, Iterator, Dict
 import csv
 import io
 import logging
 import os
 import tarfile
 
-from psycopg2.extras import Json
+from psycopg.types.json import Json
 
 from ..config import Configuration
 from ..db.connection import connect
-from ..db.async_connection import WorkerPool
 from ..db.sql_preprocessor import SQLPreprocessor
 from ..errors import UsageError
+from ..db.query_pool import QueryPool
 from ..data.place_info import PlaceInfo
-from ..tokenizer.base import AbstractAnalyzer, AbstractTokenizer
+from ..tokenizer.base import AbstractTokenizer
 from . import freeze
 
 LOG = logging.getLogger()
@@ -63,13 +63,13 @@ class TigerInput:
             self.tar_handle.close()
             self.tar_handle = None
 
+    def __bool__(self) -> bool:
+        return bool(self.files)
 
-    def next_file(self) -> TextIO:
+    def get_file(self, fname: Union[str, tarfile.TarInfo]) -> TextIO:
         """ Return a file handle to the next file to be processed.
             Raises an IndexError if there is no file left.
         """
-        fname = self.files.pop(0)
-
         if self.tar_handle is not None:
             extracted = self.tar_handle.extractfile(fname)
             assert extracted is not None
@@ -78,47 +78,22 @@ class TigerInput:
         return open(cast(str, fname), encoding='utf-8')
 
 
-    def __len__(self) -> int:
-        return len(self.files)
-
-
-def handle_threaded_sql_statements(pool: WorkerPool, fd: TextIO,
-                                   analyzer: AbstractAnalyzer) -> None:
-    """ Handles sql statement with multiplexing
-    """
-    lines = 0
-    # Using pool of database connections to execute sql statements
-
-    sql = "SELECT tiger_line_import(%s, %s, %s, %s, %s, %s)"
-
-    for row in csv.DictReader(fd, delimiter=';'):
-        try:
-            address = dict(street=row['street'], postcode=row['postcode'])
-            args = ('SRID=4326;' + row['geometry'],
-                    int(row['from']), int(row['to']), row['interpolation'],
-                    Json(analyzer.process_place(PlaceInfo({'address': address}))),
-                    analyzer.normalize_postcode(row['postcode']))
-        except ValueError:
-            continue
-        pool.next_free_worker().perform(sql, args=args)
-
-        lines += 1
-        if lines == 1000:
-            print('.', end='', flush=True)
-            lines = 0
+    def __iter__(self) -> Iterator[Dict[str, Any]]:
+        """ Iterate over the lines in each file.
+        """
+        for fname in self.files:
+            fd = self.get_file(fname)
+            yield from csv.DictReader(fd, delimiter=';')
 
 
-def add_tiger_data(data_dir: str, config: Configuration, threads: int,
+async def add_tiger_data(data_dir: str, config: Configuration, threads: int,
                    tokenizer: AbstractTokenizer) -> int:
     """ Import tiger data from directory or tar file `data dir`.
     """
     dsn = config.get_libpq_dsn()
 
     with connect(dsn) as conn:
-        is_frozen = freeze.is_frozen(conn)
-        conn.close()
-
-        if is_frozen:
+        if freeze.is_frozen(conn):
             raise UsageError("Tiger cannot be imported when database frozen (Github issue #3048)")
 
     with TigerInput(data_dir) as tar:
@@ -133,13 +108,30 @@ def add_tiger_data(data_dir: str, config: Configuration, threads: int,
         # sql_query in <threads - 1> chunks.
         place_threads = max(1, threads - 1)
 
-        with WorkerPool(dsn, place_threads, ignore_sql_errors=True) as pool:
+        async with QueryPool(dsn, place_threads, autocommit=True) as pool:
             with tokenizer.name_analyzer() as analyzer:
-                while tar:
-                    with tar.next_file() as fd:
-                        handle_threaded_sql_statements(pool, fd, analyzer)
-
-        print('\n')
+                lines = 0
+                for row in tar:
+                    try:
+                        address = dict(street=row['street'], postcode=row['postcode'])
+                        args = ('SRID=4326;' + row['geometry'],
+                                int(row['from']), int(row['to']), row['interpolation'],
+                                Json(analyzer.process_place(PlaceInfo({'address': address}))),
+                                analyzer.normalize_postcode(row['postcode']))
+                    except ValueError:
+                        continue
+
+                    await pool.put_query(
+                        """SELECT tiger_line_import(%s::GEOMETRY, %s::INT,
+                                                    %s::INT, %s::TEXT, %s::JSONB, %s::TEXT)""",
+                        args)
+
+                    lines += 1
+                    if lines == 1000:
+                        print('.', end='', flush=True)
+                        lines = 0
+
+        print('', flush=True)
 
     LOG.warning("Creating indexes on Tiger data")
     with connect(dsn) as conn:
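
add_tiger_data() is now a coroutine, so synchronous callers have to drive it through an event loop. A hedged sketch of such a call site; the actual wiring in the add-data CLI command is not shown in this excerpt:

    import asyncio

    from nominatim_db.tools import tiger_data

    def import_tiger(config, tokenizer):
        # Hypothetical synchronous wrapper; path and thread count are placeholders.
        return asyncio.run(tiger_data.add_tiger_data('tiger-data/', config,
                                                     threads=4, tokenizer=tokenizer))
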
index f1abee8280691ea9e5737a9a46553b81e2edc01b..6f0145c36e4d8753561476dcf61d3b55ad42f220 100644 (file)
@@ -16,18 +16,13 @@ from typing import Any, Union, Mapping, TypeVar, Sequence, TYPE_CHECKING
 # pylint: disable=missing-class-docstring,useless-import-alias
 
 if TYPE_CHECKING:
-    import psycopg2.sql
-    import psycopg2.extensions
-    import psycopg2.extras
     import os
 
 StrPath = Union[str, 'os.PathLike[str]']
 
 SysEnv = Mapping[str, str]
 
-# psycopg2-related types
-
-Query = Union[str, bytes, 'psycopg2.sql.Composable']
+# psycopg-related types
 
 T_ResultKey = TypeVar('T_ResultKey', int, str)
 
@@ -36,8 +31,6 @@ class DictCursorResult(Mapping[str, Any]):
 
 DictCursorResults = Sequence[DictCursorResult]
 
-T_cursor = TypeVar('T_cursor', bound='psycopg2.extensions.cursor')
-
 # The following typing features require typing_extensions to work
 # on all supported Python versions.
 # Only require this for type checking but not for normal operations.
index 70e1ac14ac9a6bc13969ea0e0f059ceb0b078f6c..fceee5d04f6d961daa07ada58629cc928c643679 100644 (file)
@@ -31,7 +31,7 @@ class NominatimVersion(NamedTuple):
     major: int
     minor: int
     patch_level: int
-    db_patch_level: Optional[int]
+    db_patch_level: int
 
     def __str__(self) -> str:
         if self.db_patch_level is None:
@@ -47,6 +47,7 @@ class NominatimVersion(NamedTuple):
         return f"{self.major}.{self.minor}.{self.patch_level}"
 
 
+
 def parse_version(version: str) -> NominatimVersion:
     """ Parse a version string into a version consisting of a tuple of
         four ints: major, minor, patch level, database patch level
index dfbbee2874f56cf3918982aa45201b0ede3c1557..17a7674507b844ae1983524aa04e8d772cf849fb 100644 (file)
@@ -9,14 +9,14 @@ import importlib
 import sys
 import tempfile
 
-import psycopg2
-import psycopg2.extras
+import psycopg
+from psycopg import sql as pysql
 
 sys.path.insert(1, str((Path(__file__) / '..' / '..' / '..' / '..'/ 'src').resolve()))
 
 from nominatim_db import cli
 from nominatim_db.config import Configuration
-from nominatim_db.db.connection import Connection
+from nominatim_db.db.connection import Connection, register_hstore, execute_scalar
 from nominatim_db.tools import refresh
 from nominatim_db.tokenizer import factory as tokenizer_factory
 from steps.utils import run_script
@@ -60,7 +60,7 @@ class NominatimEnvironment:
         """ Return a connection to the database with the given name.
             Uses configured host, user and port.
         """
-        dbargs = {'database': dbname}
+        dbargs = {'dbname': dbname, 'row_factory': psycopg.rows.dict_row}
         if self.db_host:
             dbargs['host'] = self.db_host
         if self.db_port:
@@ -69,8 +69,7 @@ class NominatimEnvironment:
             dbargs['user'] = self.db_user
         if self.db_pass:
             dbargs['password'] = self.db_pass
-        conn = psycopg2.connect(connection_factory=Connection, **dbargs)
-        return conn
+        return psycopg.connect(**dbargs)
 
     def next_code_coverage_file(self):
         """ Generate the next name for a coverage file.
@@ -132,6 +131,8 @@ class NominatimEnvironment:
             conn = False
         refresh.setup_website(Path(self.website_dir.name) / 'website',
                               self.get_test_config(), conn)
+        if conn:
+            conn.close()
 
 
     def get_test_config(self):
@@ -160,11 +161,10 @@ class NominatimEnvironment:
     def db_drop_database(self, name):
         """ Drop the database with the given name.
         """
-        conn = self.connect_database('postgres')
-        conn.set_isolation_level(0)
-        cur = conn.cursor()
-        cur.execute('DROP DATABASE IF EXISTS {}'.format(name))
-        conn.close()
+        with self.connect_database('postgres') as conn:
+            conn.autocommit = True
+            conn.execute(pysql.SQL('DROP DATABASE IF EXISTS')
+                         +  pysql.Identifier(name))
 
     def setup_template_db(self):
         """ Setup a template database that already contains common test data.
@@ -249,16 +249,18 @@ class NominatimEnvironment:
         """ Setup a test against a fresh, empty test database.
         """
         self.setup_template_db()
-        conn = self.connect_database(self.template_db)
-        conn.set_isolation_level(0)
-        cur = conn.cursor()
-        cur.execute('DROP DATABASE IF EXISTS {}'.format(self.test_db))
-        cur.execute('CREATE DATABASE {} TEMPLATE = {}'.format(self.test_db, self.template_db))
-        conn.close()
+        with self.connect_database(self.template_db) as conn:
+            conn.autocommit = True
+            conn.execute(pysql.SQL('DROP DATABASE IF EXISTS')
+                                   + pysql.Identifier(self.test_db))
+            conn.execute(pysql.SQL('CREATE DATABASE {} TEMPLATE = {}').format(
+                           pysql.Identifier(self.test_db),
+                           pysql.Identifier(self.template_db)))
+
         self.write_nominatim_config(self.test_db)
         context.db = self.connect_database(self.test_db)
         context.db.autocommit = True
-        psycopg2.extras.register_hstore(context.db, globally=False)
+        register_hstore(context.db)
 
     def teardown_db(self, context, force_drop=False):
         """ Remove the test database, if it exists.
@@ -276,31 +278,26 @@ class NominatimEnvironment:
             dropped and always false returned.
         """
         if self.reuse_template:
-            conn = self.connect_database('postgres')
-            with conn.cursor() as cur:
-                cur.execute('select count(*) from pg_database where datname = %s',
-                            (name,))
-                if cur.fetchone()[0] == 1:
+            with self.connect_database('postgres') as conn:
+                num = execute_scalar(conn,
+                                     'select count(*) from pg_database where datname = %s',
+                                     (name,))
+                if num == 1:
                     return True
-            conn.close()
         else:
             self.db_drop_database(name)
 
         return False
 
+
     def reindex_placex(self, db):
         """ Run the indexing step until all data in the placex has
             been processed. Indexing during updates can produce more data
             to index under some circumstances. That is why indexing may have
             to be run multiple times.
         """
-        with db.cursor() as cur:
-            while True:
-                self.run_nominatim('index')
+        self.run_nominatim('index')
 
-                cur.execute("SELECT 'a' FROM placex WHERE indexed_status != 0 LIMIT 1")
-                if cur.rowcount == 0:
-                    return
 
     def run_nominatim(self, *cmdline):
         """ Run the nominatim command-line tool via the library.
index 441198fdd4dfe46391d8fdcffad199d39f74efb5..a0dd9b348e7f60681ad56ed9618f45972d0da945 100644 (file)
@@ -7,7 +7,8 @@
 import logging
 from itertools import chain
 
-import psycopg2.extras
+import psycopg
+from psycopg import sql as pysql
 
 from place_inserter import PlaceColumn
 from table_compare import NominatimID, DBRow
@@ -18,7 +19,7 @@ from nominatim_db.tokenizer import factory as tokenizer_factory
 def check_database_integrity(context):
     """ Check some generic constraints on the tables.
     """
-    with context.db.cursor() as cur:
+    with context.db.cursor(row_factory=psycopg.rows.tuple_row) as cur:
         # place_addressline should not have duplicate (place_id, address_place_id)
         cur.execute("""SELECT count(*) FROM
                         (SELECT place_id, address_place_id, count(*) as c
@@ -54,7 +55,7 @@ def add_data_to_planet_relations(context):
     with context.db.cursor() as cur:
         cur.execute("SELECT value FROM osm2pgsql_properties WHERE property = 'db_format'")
         row = cur.fetchone()
-        if row is None or row[0] == '1':
+        if row is None or row['value'] == '1':
             for r in context.table:
                 last_node = 0
                 last_way = 0
@@ -96,8 +97,8 @@ def add_data_to_planet_relations(context):
 
                 cur.execute("""INSERT INTO planet_osm_rels (id, tags, members)
                                VALUES (%s, %s, %s)""",
-                            (r['id'], psycopg2.extras.Json(tags),
-                             psycopg2.extras.Json(members)))
+                            (r['id'], psycopg.types.json.Json(tags),
+                             psycopg.types.json.Json(members)))
 
 @given("the ways")
 def add_data_to_planet_ways(context):
@@ -107,10 +108,10 @@ def add_data_to_planet_ways(context):
     with context.db.cursor() as cur:
         cur.execute("SELECT value FROM osm2pgsql_properties WHERE property = 'db_format'")
         row = cur.fetchone()
-        json_tags = row is not None and row[0] != '1'
+        json_tags = row is not None and row['value'] != '1'
         for r in context.table:
             if json_tags:
-                tags = psycopg2.extras.Json({h[5:]: r[h] for h in r.headings if h.startswith("tags+")})
+                tags = psycopg.types.json.Json({h[5:]: r[h] for h in r.headings if h.startswith("tags+")})
             else:
                 tags = list(chain.from_iterable([(h[5:], r[h])
                                                  for h in r.headings if h.startswith("tags+")]))
@@ -197,7 +198,7 @@ def check_place_contents(context, table, exact):
         expected rows are expected to be present with at least one database row.
         When 'exactly' is given, there must not be additional rows in the database.
     """
-    with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+    with context.db.cursor() as cur:
         expected_content = set()
         for row in context.table:
             nid = NominatimID(row['object'])
@@ -215,8 +216,9 @@ def check_place_contents(context, table, exact):
                 DBRow(nid, res, context).assert_row(row, ['object'])
 
         if exact:
-            cur.execute('SELECT osm_type, osm_id, class from {}'.format(table))
-            actual = set([(r[0], r[1], r[2]) for r in cur])
+            cur.execute(pysql.SQL('SELECT osm_type, osm_id, class from')
+                        + pysql.Identifier(table))
+            actual = set([(r['osm_type'], r['osm_id'], r['class']) for r in cur])
             assert expected_content == actual, \
                    f"Missing entries: {expected_content - actual}\n" \
                    f"Not expected in table: {actual - expected_content}"
@@ -227,7 +229,7 @@ def check_place_has_entry(context, table, oid):
     """ Ensure that no database row for the given object exists. The ID
         must be of the form '<NRW><osm id>[:<class>]'.
     """
-    with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+    with context.db.cursor() as cur:
         NominatimID(oid).query_osm_id(cur, "SELECT * FROM %s where {}" % table)
         assert cur.rowcount == 0, \
                "Found {} entries for ID {}".format(cur.rowcount, oid)
@@ -244,7 +246,7 @@ def check_search_name_contents(context, exclude):
     tokenizer = tokenizer_factory.get_tokenizer_for_db(context.nominatim.get_test_config())
 
     with tokenizer.name_analyzer() as analyzer:
-        with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+        with context.db.cursor() as cur:
             for row in context.table:
                 nid = NominatimID(row['object'])
                 nid.row_by_place_id(cur, 'search_name',
@@ -276,7 +278,7 @@ def check_search_name_has_entry(context, oid):
     """ Check that there is noentry in the search_name table for the given
         objects. IDs are in format '<NRW><osm id>[:<class>]'.
     """
-    with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+    with context.db.cursor() as cur:
         NominatimID(oid).row_by_place_id(cur, 'search_name')
 
         assert cur.rowcount == 0, \
@@ -290,7 +292,7 @@ def check_location_postcode(context):
         All rows must be present as expected and there must not be additional
         rows.
     """
-    with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+    with context.db.cursor() as cur:
         cur.execute("SELECT *, ST_AsText(geometry) as geomtxt FROM location_postcode")
         assert cur.rowcount == len(list(context.table)), \
             "Postcode table has {} rows, expected {}.".format(cur.rowcount, len(list(context.table)))
@@ -321,7 +323,7 @@ def check_word_table_for_postcodes(context, exclude, postcodes):
 
     plist.sort()
 
-    with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+    with context.db.cursor() as cur:
         if nctx.tokenizer != 'legacy':
             cur.execute("SELECT word FROM word WHERE type = 'P' and word = any(%s)",
                         (plist,))
@@ -330,7 +332,7 @@ def check_word_table_for_postcodes(context, exclude, postcodes):
                              and class = 'place' and type = 'postcode'""",
                         (plist,))
 
-        found = [row[0] for row in cur]
+        found = [row['word'] for row in cur]
         assert len(found) == len(set(found)), f"Duplicate rows for postcodes: {found}"
 
     if exclude:
@@ -347,7 +349,7 @@ def check_place_addressline(context):
         representing the addressee and the 'address' column, representing the
         address item.
     """
-    with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+    with context.db.cursor() as cur:
         for row in context.table:
             nid = NominatimID(row['object'])
             pid = nid.get_place_id(cur)
@@ -366,7 +368,7 @@ def check_place_addressline_exclude(context):
     """ Check that the place_addressline doesn't contain any entries for the
         given addressee/address item pairs.
     """
-    with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+    with context.db.cursor() as cur:
         for row in context.table:
             pid = NominatimID(row['object']).get_place_id(cur)
             apid = NominatimID(row['address']).get_place_id(cur, allow_empty=True)
@@ -381,7 +383,7 @@ def check_place_addressline_exclude(context):
 def check_location_property_osmline(context, oid, neg):
     """ Check that the given way is present in the interpolation table.
     """
-    with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+    with context.db.cursor() as cur:
         cur.execute("""SELECT *, ST_AsText(linegeo) as geomtxt
                        FROM location_property_osmline
                        WHERE osm_id = %s AND startnumber IS NOT NULL""",
@@ -417,7 +419,7 @@ def check_place_contents(context, exact):
         expected rows are expected to be present with at least one database row.
         When 'exactly' is given, there must not be additional rows in the database.
     """
-    with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+    with context.db.cursor() as cur:
         expected_content = set()
         for row in context.table:
             if ':' in row['object']:
@@ -447,7 +449,7 @@ def check_place_contents(context, exact):
 
         if exact:
             cur.execute('SELECT osm_id, startnumber from location_property_osmline')
-            actual = set([(r[0], r[1]) for r in cur])
+            actual = set([(r['osm_id'], r['startnumber']) for r in cur])
             assert expected_content == actual, \
                    f"Missing entries: {expected_content - actual}\n" \
                    f"Not expected in table: {actual - expected_content}"
index cf2e12f127871390126379fb348f2eb2cfeceec1..4284fad962607796c59560dbda2bceb687375a1e 100644 (file)
@@ -10,6 +10,9 @@ Functions to facilitate accessing and comparing the content of DB tables.
 import re
 import json
 
+import psycopg
+from psycopg import sql as pysql
+
 from steps.check_functions import Almost
 
 ID_REGEX = re.compile(r"(?P<typ>[NRW])(?P<oid>\d+)(:(?P<cls>\w+))?")
@@ -73,7 +76,7 @@ class NominatimID:
         assert cur.rowcount == 1, \
                "Place ID {!s} not unique. Found {} entries.".format(self, cur.rowcount)
 
-        return cur.fetchone()[0]
+        return cur.fetchone()['place_id']
 
 
 class DBRow:
@@ -152,9 +155,10 @@ class DBRow:
 
     def _has_centroid(self, expected):
         if expected == 'in geometry':
-            with self.context.db.cursor() as cur:
-                cur.execute("""SELECT ST_Within(ST_SetSRID(ST_Point({cx}, {cy}), 4326),
-                                        ST_SetSRID('{geomtxt}'::geometry, 4326))""".format(**self.db_row))
+            with self.context.db.cursor(row_factory=psycopg.rows.tuple_row) as cur:
+                cur.execute("""SELECT ST_Within(ST_SetSRID(ST_Point(%(cx)s, %(cy)s), 4326),
+                                        ST_SetSRID(%(geomtxt)s::geometry, 4326))""",
+                            (self.db_row))
                 return cur.fetchone()[0]
 
         if ' ' in expected:
@@ -166,10 +170,11 @@ class DBRow:
 
     def _has_geometry(self, expected):
         geom = self.context.osm.parse_geometry(expected)
-        with self.context.db.cursor() as cur:
-            cur.execute("""SELECT ST_Equals(ST_SnapToGrid({}, 0.00001, 0.00001),
-                                   ST_SnapToGrid(ST_SetSRID('{}'::geometry, 4326), 0.00001, 0.00001))""".format(
-                            geom, self.db_row['geomtxt']))
+        with self.context.db.cursor(row_factory=psycopg.rows.tuple_row) as cur:
+            cur.execute(pysql.SQL("""SELECT ST_Equals(ST_SnapToGrid({}, 0.00001, 0.00001),
+                                   ST_SnapToGrid(ST_SetSRID({}::geometry, 4326), 0.00001, 0.00001))""")
+                             .format(pysql.SQL(geom),
+                                     pysql.Literal(self.db_row['geomtxt'])))
             return cur.fetchone()[0]
 
     def assert_msg(self, name, value):
@@ -209,7 +214,7 @@ class DBRow:
             if actual == 0:
                 return "place ID 0"
 
-            with self.context.db.cursor() as cur:
+            with self.context.db.cursor(row_factory=psycopg.rows.tuple_row) as cur:
                 cur.execute("""SELECT osm_type, osm_id, class
                                FROM placex WHERE place_id = %s""",
                             (actual, ))
index 1b8dc34d3b1b962f13c30f930531529e21bdd6d4..649dd8fc44f4619838ee7730da2e163698c0f462 100644 (file)
@@ -13,8 +13,6 @@ from pathlib import Path
 import pytest
 import pytest_asyncio
 
-import psycopg2.extras
-
 from fake_adaptor import FakeAdaptor, FakeError, FakeResponse
 
 import nominatim_api.v1.server_glue as glue
@@ -31,7 +29,6 @@ class TestDeletableEndPoint:
 
     @pytest.fixture(autouse=True)
     def setup_deletable_table(self, temp_db_cursor, table_factory, temp_db_with_extensions):
-        psycopg2.extras.register_hstore(temp_db_cursor)
         table_factory('import_polygon_delete',
                       definition='osm_id bigint, osm_type char(1), class text, type text',
                       content=[(345, 'N', 'boundary', 'administrative'),
index bf51cd17ff2864c7e37eb60b6058286a10e82cfb..558be813e4d1b8c8183cd7ae5ae1826bb4d39149 100644 (file)
@@ -14,8 +14,6 @@ from pathlib import Path
 import pytest
 import pytest_asyncio
 
-import psycopg2.extras
-
 from fake_adaptor import FakeAdaptor, FakeError, FakeResponse
 
 import nominatim_api.v1.server_glue as glue
@@ -32,8 +30,6 @@ class TestPolygonsEndPoint:
 
     @pytest.fixture(autouse=True)
     def setup_deletable_table(self, temp_db_cursor, table_factory, temp_db_with_extensions):
-        psycopg2.extras.register_hstore(temp_db_cursor)
-
         self.now = dt.datetime.now()
         self.recent = dt.datetime.now() - dt.timedelta(days=3)
 
index 1e3ca8abbe8b02a66361b26e7e640345cdb0be9b..d5ade22350a9a7764bfd26678fd1b97d559f22bf 100644 (file)
@@ -25,6 +25,23 @@ class MockParamCapture:
         return self.return_value
 
 
+class AsyncMockParamCapture:
+    """ Mock that records the parameters with which a function was called
+        as well as the number of calls.
+    """
+    def __init__(self, retval=0):
+        self.called = 0
+        self.return_value = retval
+        self.last_args = None
+        self.last_kwargs = None
+
+    async def __call__(self, *args, **kwargs):
+        self.called += 1
+        self.last_args = args
+        self.last_kwargs = kwargs
+        return self.return_value
+
+
 class DummyTokenizer:
     def __init__(self, *args, **kwargs):
         self.update_sql_functions_called = False
@@ -69,6 +86,17 @@ def mock_func_factory(monkeypatch):
     return get_mock
 
 
+@pytest.fixture
+def async_mock_func_factory(monkeypatch):
+    def get_mock(module, func):
+        mock = AsyncMockParamCapture()
+        mock.func_name = func
+        monkeypatch.setattr(module, func, mock)
+        return mock
+
+    return get_mock
+
+
 @pytest.fixture
 def cli_tokenizer_mock(monkeypatch):
     tok = DummyTokenizer()
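
The new async_mock_func_factory fixture mirrors mock_func_factory but returns an awaitable capture object, so coroutine entry points can be monkeypatched like their synchronous counterparts. A stand-alone illustration, assuming the AsyncMockParamCapture class defined above:

    import asyncio

    capture = AsyncMockParamCapture(retval=42)

    async def caller():
        return await capture('somefile.osm.pbf', threads=2)

    # The capture records arguments and call count, just like MockParamCapture.
    assert asyncio.run(caller()) == 42
    assert capture.called == 1
    assert capture.last_args == ('somefile.osm.pbf',)
    assert capture.last_kwargs == {'threads': 2}
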
index 688afb7c463addea7c42c6259c2b4c0deefed273..6586c5ec745457ddf2a0482faa4a3a3d205d45d5 100644 (file)
@@ -17,6 +17,7 @@ import pytest
 import nominatim_db.indexer.indexer
 import nominatim_db.tools.add_osm_data
 import nominatim_db.tools.freeze
+import nominatim_db.tools.tiger_data
 
 
 def test_cli_help(cli_call, capsys):
@@ -52,8 +53,8 @@ def test_cli_add_data_object_command(cli_call, mock_func_factory, name, oid):
 
 
 
-def test_cli_add_data_tiger_data(cli_call, cli_tokenizer_mock, mock_func_factory):
-    mock = mock_func_factory(nominatim_db.tools.tiger_data, 'add_tiger_data')
+def test_cli_add_data_tiger_data(cli_call, cli_tokenizer_mock, async_mock_func_factory):
+    mock = async_mock_func_factory(nominatim_db.tools.tiger_data, 'add_tiger_data')
 
     assert cli_call('add-data', '--tiger-data', 'somewhere') == 0
 
@@ -68,38 +69,6 @@ def test_cli_serve_php(cli_call, mock_func_factory):
     assert func.called == 1
 
 
-def test_cli_serve_starlette_custom_server(cli_call, mock_func_factory):
-    pytest.importorskip("starlette")
-    mod = pytest.importorskip("uvicorn")
-    func = mock_func_factory(mod, "run")
-
-    cli_call('serve', '--engine', 'starlette', '--server', 'foobar:4545') == 0
-
-    assert func.called == 1
-    assert func.last_kwargs['host'] == 'foobar'
-    assert func.last_kwargs['port'] == 4545
-
-
-def test_cli_serve_starlette_custom_server_bad_port(cli_call, mock_func_factory):
-    pytest.importorskip("starlette")
-    mod = pytest.importorskip("uvicorn")
-    func = mock_func_factory(mod, "run")
-
-    cli_call('serve', '--engine', 'starlette', '--server', 'foobar:45:45') == 1
-
-
-@pytest.mark.parametrize("engine", ['falcon', 'starlette'])
-def test_cli_serve_uvicorn_based(cli_call, engine, mock_func_factory):
-    pytest.importorskip(engine)
-    mod = pytest.importorskip("uvicorn")
-    func = mock_func_factory(mod, "run")
-
-    cli_call('serve', '--engine', engine) == 0
-
-    assert func.called == 1
-    assert func.last_kwargs['host'] == '127.0.0.1'
-    assert func.last_kwargs['port'] == 8088
-
 
 class TestCliWithDb:
 
@@ -120,16 +89,19 @@ class TestCliWithDb:
 
 
     @pytest.mark.parametrize("params,do_bnds,do_ranks", [
-                              ([], 1, 1),
-                              (['--boundaries-only'], 1, 0),
-                              (['--no-boundaries'], 0, 1),
+                              ([], 2, 2),
+                              (['--boundaries-only'], 2, 0),
+                              (['--no-boundaries'], 0, 2),
                               (['--boundaries-only', '--no-boundaries'], 0, 0)])
-    def test_index_command(self, mock_func_factory, table_factory,
+    def test_index_command(self, monkeypatch, async_mock_func_factory, table_factory,
                            params, do_bnds, do_ranks):
         table_factory('import_status', 'indexed bool')
-        bnd_mock = mock_func_factory(nominatim_db.indexer.indexer.Indexer, 'index_boundaries')
-        rank_mock = mock_func_factory(nominatim_db.indexer.indexer.Indexer, 'index_by_rank')
-        postcode_mock = mock_func_factory(nominatim_db.indexer.indexer.Indexer, 'index_postcodes')
+        bnd_mock = async_mock_func_factory(nominatim_db.indexer.indexer.Indexer, 'index_boundaries')
+        rank_mock = async_mock_func_factory(nominatim_db.indexer.indexer.Indexer, 'index_by_rank')
+        postcode_mock = async_mock_func_factory(nominatim_db.indexer.indexer.Indexer, 'index_postcodes')
+
+        monkeypatch.setattr(nominatim_db.indexer.indexer.Indexer, 'has_pending', 
+                            [False, True].pop)
 
         assert self.call_nominatim('index', *params) == 0
 
index 85235e1e2e1e123ed132d7a3c81a6f211fe20706..e47d713c1f32ab04574eb510fa1677450b9d9132 100644 (file)
@@ -34,7 +34,8 @@ class TestCliImportWithDb:
 
 
     @pytest.mark.parametrize('with_updates', [True, False])
-    def test_import_full(self, mock_func_factory, with_updates, place_table, property_table):
+    def test_import_full(self, mock_func_factory, async_mock_func_factory,
+                         with_updates, place_table, property_table):
         mocks = [
             mock_func_factory(nominatim_db.tools.database_import, 'setup_database_skeleton'),
             mock_func_factory(nominatim_db.data.country_info, 'setup_country_tables'),
@@ -42,15 +43,15 @@ class TestCliImportWithDb:
             mock_func_factory(nominatim_db.tools.refresh, 'import_wikipedia_articles'),
             mock_func_factory(nominatim_db.tools.refresh, 'import_secondary_importance'),
             mock_func_factory(nominatim_db.tools.database_import, 'truncate_data_tables'),
-            mock_func_factory(nominatim_db.tools.database_import, 'load_data'),
+            async_mock_func_factory(nominatim_db.tools.database_import, 'load_data'),
             mock_func_factory(nominatim_db.tools.database_import, 'create_tables'),
             mock_func_factory(nominatim_db.tools.database_import, 'create_table_triggers'),
             mock_func_factory(nominatim_db.tools.database_import, 'create_partition_tables'),
-            mock_func_factory(nominatim_db.tools.database_import, 'create_search_indices'),
+            async_mock_func_factory(nominatim_db.tools.database_import, 'create_search_indices'),
             mock_func_factory(nominatim_db.data.country_info, 'create_country_names'),
             mock_func_factory(nominatim_db.tools.refresh, 'load_address_levels_from_config'),
             mock_func_factory(nominatim_db.tools.postcodes, 'update_postcodes'),
-            mock_func_factory(nominatim_db.indexer.indexer.Indexer, 'index_full'),
+            async_mock_func_factory(nominatim_db.indexer.indexer.Indexer, 'index_full'),
             mock_func_factory(nominatim_db.tools.refresh, 'setup_website'),
         ]
 
@@ -72,14 +73,14 @@ class TestCliImportWithDb:
             assert mock.called == 1, "Mock '{}' not called".format(mock.func_name)
 
 
-    def test_import_continue_load_data(self, mock_func_factory):
+    def test_import_continue_load_data(self, mock_func_factory, async_mock_func_factory):
         mocks = [
             mock_func_factory(nominatim_db.tools.database_import, 'truncate_data_tables'),
-            mock_func_factory(nominatim_db.tools.database_import, 'load_data'),
-            mock_func_factory(nominatim_db.tools.database_import, 'create_search_indices'),
+            async_mock_func_factory(nominatim_db.tools.database_import, 'load_data'),
+            async_mock_func_factory(nominatim_db.tools.database_import, 'create_search_indices'),
             mock_func_factory(nominatim_db.data.country_info, 'create_country_names'),
             mock_func_factory(nominatim_db.tools.postcodes, 'update_postcodes'),
-            mock_func_factory(nominatim_db.indexer.indexer.Indexer, 'index_full'),
+            async_mock_func_factory(nominatim_db.indexer.indexer.Indexer, 'index_full'),
             mock_func_factory(nominatim_db.tools.refresh, 'setup_website'),
             mock_func_factory(nominatim_db.db.properties, 'set_property')
         ]
@@ -91,12 +92,12 @@ class TestCliImportWithDb:
             assert mock.called == 1, "Mock '{}' not called".format(mock.func_name)
 
 
-    def test_import_continue_indexing(self, mock_func_factory, placex_table,
-                                      temp_db_conn):
+    def test_import_continue_indexing(self, mock_func_factory, async_mock_func_factory,
+                                      placex_table, temp_db_conn):
         mocks = [
-            mock_func_factory(nominatim_db.tools.database_import, 'create_search_indices'),
+            async_mock_func_factory(nominatim_db.tools.database_import, 'create_search_indices'),
             mock_func_factory(nominatim_db.data.country_info, 'create_country_names'),
-            mock_func_factory(nominatim_db.indexer.indexer.Indexer, 'index_full'),
+            async_mock_func_factory(nominatim_db.indexer.indexer.Indexer, 'index_full'),
             mock_func_factory(nominatim_db.tools.refresh, 'setup_website'),
             mock_func_factory(nominatim_db.db.properties, 'set_property')
         ]
@@ -110,9 +111,9 @@ class TestCliImportWithDb:
         assert self.call_nominatim('import', '--continue', 'indexing') == 0
 
 
-    def test_import_continue_postprocess(self, mock_func_factory):
+    def test_import_continue_postprocess(self, mock_func_factory, async_mock_func_factory):
         mocks = [
-            mock_func_factory(nominatim_db.tools.database_import, 'create_search_indices'),
+            async_mock_func_factory(nominatim_db.tools.database_import, 'create_search_indices'),
             mock_func_factory(nominatim_db.data.country_info, 'create_country_names'),
             mock_func_factory(nominatim_db.tools.refresh, 'setup_website'),
             mock_func_factory(nominatim_db.db.properties, 'set_property')
index 3939681720c6142cf1fd17804a8b127da81d0727..9074b2cc3d44ae4434f68fdef8ecaf778366468c 100644 (file)
@@ -45,9 +45,9 @@ class TestRefresh:
         assert self.tokenizer_mock.update_word_tokens_called
 
 
-    def test_refresh_postcodes(self, mock_func_factory, place_table):
+    def test_refresh_postcodes(self, async_mock_func_factory, mock_func_factory, place_table):
         func_mock = mock_func_factory(nominatim_db.tools.postcodes, 'update_postcodes')
-        idx_mock = mock_func_factory(nominatim_db.indexer.indexer.Indexer, 'index_postcodes')
+        idx_mock = async_mock_func_factory(nominatim_db.indexer.indexer.Indexer, 'index_postcodes')
 
         assert self.call_nominatim('refresh', '--postcodes') == 0
         assert func_mock.called == 1
index 8c1e8ea6440a5d74da03084d57b6543458d89507..21c6350d4e762c62f66de007cd147da95d347671 100644 (file)
@@ -47,8 +47,8 @@ def init_status(temp_db_conn, status_table):
 
 
 @pytest.fixture
-def index_mock(mock_func_factory, tokenizer_mock, init_status):
-    return mock_func_factory(nominatim_db.indexer.indexer.Indexer, 'index_full')
+def index_mock(async_mock_func_factory, tokenizer_mock, init_status):
+    return async_mock_func_factory(nominatim_db.indexer.indexer.Indexer, 'index_full')
 
 
 @pytest.fixture
index d301c9736eb68c9f616a635c92693e2e7335b282..3ced320558347685543e246dd90b5741c8716625 100644 (file)
@@ -8,7 +8,8 @@ import itertools
 import sys
 from pathlib import Path
 
-import psycopg2
+import psycopg
+from psycopg import sql as pysql
 import pytest
 
 # always test against the source
@@ -36,26 +37,23 @@ def temp_db(monkeypatch):
         exported into NOMINATIM_DATABASE_DSN.
     """
     name = 'test_nominatim_python_unittest'
-    conn = psycopg2.connect(database='postgres')
 
-    conn.set_isolation_level(0)
-    with conn.cursor() as cur:
-        cur.execute('DROP DATABASE IF EXISTS {}'.format(name))
-        cur.execute('CREATE DATABASE {}'.format(name))
-
-    conn.close()
+    with psycopg.connect(dbname='postgres', autocommit=True) as conn:
+        with conn.cursor() as cur:
+            cur.execute(pysql.SQL('DROP DATABASE IF EXISTS') + pysql.Identifier(name))
+            cur.execute(pysql.SQL('CREATE DATABASE') + pysql.Identifier(name))
 
     monkeypatch.setenv('NOMINATIM_DATABASE_DSN', 'dbname=' + name)
 
-    yield name
-
-    conn = psycopg2.connect(database='postgres')
+    with psycopg.connect(dbname=name) as conn:
+        with conn.cursor() as cur:
+            cur.execute('CREATE EXTENSION hstore')
 
-    conn.set_isolation_level(0)
-    with conn.cursor() as cur:
-        cur.execute('DROP DATABASE IF EXISTS {}'.format(name))
+    yield name
 
-    conn.close()
+    with psycopg.connect(dbname='postgres', autocommit=True) as conn:
+        with conn.cursor() as cur:
+            cur.execute('DROP DATABASE IF EXISTS {}'.format(name))
 
 
 @pytest.fixture
@@ -65,11 +63,9 @@ def dsn(temp_db):
 
 @pytest.fixture
 def temp_db_with_extensions(temp_db):
-    conn = psycopg2.connect(database=temp_db)
-    with conn.cursor() as cur:
-        cur.execute('CREATE EXTENSION hstore; CREATE EXTENSION postgis;')
-    conn.commit()
-    conn.close()
+    with psycopg.connect(dbname=temp_db) as conn:
+        with conn.cursor() as cur:
+            cur.execute('CREATE EXTENSION postgis')
 
     return temp_db
 
@@ -77,7 +73,8 @@ def temp_db_with_extensions(temp_db):
 def temp_db_conn(temp_db):
     """ Connection to the test database.
     """
-    with connection.connect('dbname=' + temp_db) as conn:
+    with connection.connect('', autocommit=True, dbname=temp_db) as conn:
+        connection.register_hstore(conn)
         yield conn
 
 
@@ -86,22 +83,25 @@ def temp_db_cursor(temp_db):
     """ Connection and cursor towards the test database. The connection will
         be in auto-commit mode.
     """
-    conn = psycopg2.connect('dbname=' + temp_db)
-    conn.set_isolation_level(0)
-    with conn.cursor(cursor_factory=CursorForTesting) as cur:
-        yield cur
-    conn.close()
+    with psycopg.connect(dbname=temp_db, autocommit=True, cursor_factory=CursorForTesting) as conn:
+        connection.register_hstore(conn)
+        with conn.cursor() as cur:
+            yield cur
 
 
 @pytest.fixture
-def table_factory(temp_db_cursor):
+def table_factory(temp_db_conn):
     """ A fixture that creates new SQL tables, potentially filled with
         content.
     """
     def mk_table(name, definition='id INT', content=None):
-        temp_db_cursor.execute('CREATE TABLE {} ({})'.format(name, definition))
-        if content is not None:
-            temp_db_cursor.execute_values("INSERT INTO {} VALUES %s".format(name), content)
+        with psycopg.ClientCursor(temp_db_conn) as cur:
+            cur.execute('CREATE TABLE {} ({})'.format(name, definition))
+            if content:
+                sql = pysql.SQL("INSERT INTO {} VALUES ({})")\
+                           .format(pysql.Identifier(name),
+                                   pysql.SQL(',').join([pysql.Placeholder() for _ in range(len(content[0]))]))
+                cur.executemany(sql , content)
+                cur.executemany(sql, content)
 
     return mk_table
 
@@ -168,7 +168,6 @@ def place_row(place_table, temp_db_cursor):
     """ A factory for rows in the place table. The table is created as a
         prerequisite to the fixture.
     """
-    psycopg2.extras.register_hstore(temp_db_cursor)
     idseq = itertools.count(1001)
     def _insert(osm_type='N', osm_id=None, cls='amenity', typ='cafe', names=None,
                 admin_level=None, address=None, extratags=None, geom=None):
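
Editor's note: the fixture rewrite collects the recurring psycopg 3 idioms of this port: autocommit is a connect() argument instead of set_isolation_level(0), identifiers are composed with psycopg.sql, and execute_values() gives way to executemany() with plain placeholders. A condensed sketch, not part of the commit (database and table names are hypothetical):

    import psycopg
    from psycopg import sql

    with psycopg.connect(dbname='test_db', autocommit=True) as conn:
        with conn.cursor() as cur:
            cur.execute(sql.SQL('CREATE TABLE {} (id INT, name TEXT)')
                           .format(sql.Identifier('demo')))
            # executemany() with one placeholder per column replaces
            # psycopg2.extras.execute_values().
            cur.executemany('INSERT INTO demo VALUES (%s, %s)',
                            [(1, 'a'), (2, 'b')])
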
index 7d586b3c1520d9e0fea75cbd144c6111b3c0e58a..b3fc260a2526d016ae2607e681e702baaddb591e 100644 (file)
@@ -5,11 +5,11 @@
 # Copyright (C) 2024 by the Nominatim developer community.
 # For a full list of authors see the git log.
 """
-Specialised psycopg2 cursor with shortcut functions useful for testing.
+Specialised psycopg cursor with shortcut functions useful for testing.
 """
-import psycopg2.extras
+import psycopg
 
-class CursorForTesting(psycopg2.extras.DictCursor):
+class CursorForTesting(psycopg.Cursor):
     """ Extension to the DictCursor class that provides execution
         short-cuts that simplify writing assertions.
     """
@@ -59,9 +59,3 @@ class CursorForTesting(psycopg2.extras.DictCursor):
             return self.scalar('SELECT count(*) FROM ' + table)
 
         return self.scalar('SELECT count(*) FROM {} WHERE {}'.format(table, where))
-
-
-    def execute_values(self, *args, **kwargs):
-        """ Execute the execute_values() function on the cursor.
-        """
-        psycopg2.extras.execute_values(self, *args, **kwargs)
diff --git a/test/python/db/test_async_connection.py b/test/python/db/test_async_connection.py
deleted file mode 100644 (file)
index fff695e..0000000
+++ /dev/null
@@ -1,113 +0,0 @@
-# SPDX-License-Identifier: GPL-3.0-or-later
-#
-# This file is part of Nominatim. (https://nominatim.org)
-#
-# Copyright (C) 2024 by the Nominatim developer community.
-# For a full list of authors see the git log.
-"""
-Tests for function providing a non-blocking query interface towards PostgreSQL.
-"""
-from contextlib import closing
-import concurrent.futures
-
-import pytest
-import psycopg2
-
-from nominatim_db.db.async_connection import DBConnection, DeadlockHandler
-
-
-@pytest.fixture
-def conn(temp_db):
-    with closing(DBConnection('dbname=' + temp_db)) as connection:
-        yield connection
-
-
-@pytest.fixture
-def simple_conns(temp_db):
-    conn1 = psycopg2.connect('dbname=' + temp_db)
-    conn2 = psycopg2.connect('dbname=' + temp_db)
-
-    yield conn1.cursor(), conn2.cursor()
-
-    conn1.close()
-    conn2.close()
-
-
-def test_simple_query(conn, temp_db_conn):
-    conn.connect()
-
-    conn.perform('CREATE TABLE foo (id INT)')
-    conn.wait()
-
-    temp_db_conn.table_exists('foo')
-
-
-def test_wait_for_query(conn):
-    conn.connect()
-
-    conn.perform('SELECT pg_sleep(1)')
-
-    assert not conn.is_done()
-
-    conn.wait()
-
-
-def test_bad_query(conn):
-    conn.connect()
-
-    conn.perform('SELECT efasfjsea')
-
-    with pytest.raises(psycopg2.ProgrammingError):
-        conn.wait()
-
-
-def test_bad_query_ignore(temp_db):
-    with closing(DBConnection('dbname=' + temp_db, ignore_sql_errors=True)) as conn:
-        conn.connect()
-
-        conn.perform('SELECT efasfjsea')
-
-        conn.wait()
-
-
-def exec_with_deadlock(cur, sql, detector):
-    with DeadlockHandler(lambda *args: detector.append(1)):
-        cur.execute(sql)
-
-
-def test_deadlock(simple_conns):
-    cur1, cur2 = simple_conns
-
-    cur1.execute("""CREATE TABLE t1 (id INT PRIMARY KEY, t TEXT);
-                    INSERT into t1 VALUES (1, 'a'), (2, 'b')""")
-    cur1.connection.commit()
-
-    cur1.execute("UPDATE t1 SET t = 'x' WHERE id = 1")
-    cur2.execute("UPDATE t1 SET t = 'x' WHERE id = 2")
-
-    # This is the tricky part of the test. The first SQL command runs into
-    # a lock and blocks, so we have to run it in a separate thread. When the
-    # second deadlocking SQL statement is issued, Postgresql will abort one of
-    # the two transactions that cause the deadlock. There is no way to tell
-    # which one of the two. Therefore wrap both in a DeadlockHandler and
-    # expect that exactly one of the two triggers.
-    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
-        deadlock_check = []
-        try:
-            future = executor.submit(exec_with_deadlock, cur2,
-                                     "UPDATE t1 SET t = 'y' WHERE id = 1",
-                                     deadlock_check)
-
-            while not future.running():
-                pass
-
-
-            exec_with_deadlock(cur1, "UPDATE t1 SET t = 'y' WHERE id = 2",
-                               deadlock_check)
-        finally:
-            # Whatever happens, make sure the deadlock gets resolved.
-            cur1.connection.rollback()
-
-        future.result()
-
-        assert len(deadlock_check) == 1
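
Editor's note: the hand-rolled non-blocking DBConnection/DeadlockHandler wrapper that this deleted file exercised is gone; psycopg 3 ships native asyncio connections, which presumably take over that role. A minimal sketch of the native API (hypothetical DSN):

    import asyncio
    import psycopg

    async def main():
        async with await psycopg.AsyncConnection.connect('dbname=test_db') as conn:
            async with conn.cursor() as cur:
                await cur.execute('SELECT pg_sleep(0.1)')

    asyncio.run(main())
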
index 8b4cc62f0c46e6f8bd46db49bc5d3f3ab58cae58..a8b5d677ce22e43f8e9b8e69ff222dc08202e560 100644 (file)
@@ -8,63 +8,76 @@
 Tests for specialised connection and cursor classes.
 """
 import pytest
-import psycopg2
+import psycopg
 
-from nominatim_db.db.connection import connect, get_pg_env
+import nominatim_db.db.connection as nc
 
 @pytest.fixture
 def db(dsn):
-    with connect(dsn) as conn:
+    with nc.connect(dsn) as conn:
         yield conn
 
 
 def test_connection_table_exists(db, table_factory):
-    assert not db.table_exists('foobar')
+    assert not nc.table_exists(db, 'foobar')
 
     table_factory('foobar')
 
-    assert db.table_exists('foobar')
+    assert nc.table_exists(db, 'foobar')
 
 
 def test_has_column_no_table(db):
-    assert not db.table_has_column('sometable', 'somecolumn')
+    assert not nc.table_has_column(db, 'sometable', 'somecolumn')
 
 
 @pytest.mark.parametrize('name,result', [('tram', True), ('car', False)])
 def test_has_column(db, table_factory, name, result):
     table_factory('stuff', 'tram TEXT')
 
-    assert db.table_has_column('stuff', name) == result
+    assert nc.table_has_column(db, 'stuff', name) == result
 
 def test_connection_index_exists(db, table_factory, temp_db_cursor):
-    assert not db.index_exists('some_index')
+    assert not nc.index_exists(db, 'some_index')
 
     table_factory('foobar')
     temp_db_cursor.execute('CREATE INDEX some_index ON foobar(id)')
 
-    assert db.index_exists('some_index')
-    assert db.index_exists('some_index', table='foobar')
-    assert not db.index_exists('some_index', table='bar')
+    assert nc.index_exists(db, 'some_index')
+    assert nc.index_exists(db, 'some_index', table='foobar')
+    assert not nc.index_exists(db, 'some_index', table='bar')
 
 
 def test_drop_table_existing(db, table_factory):
     table_factory('dummy')
-    assert db.table_exists('dummy')
+    assert nc.table_exists(db, 'dummy')
 
-    db.drop_table('dummy')
-    assert not db.table_exists('dummy')
+    nc.drop_tables(db, 'dummy')
+    assert not nc.table_exists(db, 'dummy')
 
 
-def test_drop_table_non_existsing(db):
-    db.drop_table('dfkjgjriogjigjgjrdghehtre')
+def test_drop_table_non_existing(db):
+    nc.drop_tables(db, 'dfkjgjriogjigjgjrdghehtre')
+
+
+def test_drop_many_tables(db, table_factory):
+    tables = [f'table{n}' for n in range(5)]
+
+    for t in tables:
+        table_factory(t)
+        assert nc.table_exists(db, t)
+
+    nc.drop_tables(db, *tables)
+
+    for t in tables:
+        assert not nc.table_exists(db, t)
 
 
 def test_drop_table_non_existing_force(db):
-    with pytest.raises(psycopg2.ProgrammingError, match='.*does not exist.*'):
-        db.drop_table('dfkjgjriogjigjgjrdghehtre', if_exists=False)
+    with pytest.raises(psycopg.ProgrammingError, match='.*does not exist.*'):
+        nc.drop_tables(db, 'dfkjgjriogjigjgjrdghehtre', if_exists=False)
 
 def test_connection_server_version_tuple(db):
-    ver = db.server_version_tuple()
+    ver = nc.server_version_tuple(db)
 
     assert isinstance(ver, tuple)
     assert len(ver) == 2
@@ -72,7 +85,7 @@ def test_connection_server_version_tuple(db):
 
 
 def test_connection_postgis_version_tuple(db, temp_db_with_extensions):
-    ver = db.postgis_version_tuple()
+    ver = nc.postgis_version_tuple(db)
 
     assert isinstance(ver, tuple)
     assert len(ver) == 2
@@ -82,27 +95,24 @@ def test_connection_postgis_version_tuple(db, temp_db_with_extensions):
 def test_cursor_scalar(db, table_factory):
     table_factory('dummy')
 
-    with db.cursor() as cur:
-        assert cur.scalar('SELECT count(*) FROM dummy') == 0
+    assert nc.execute_scalar(db, 'SELECT count(*) FROM dummy') == 0
 
 
 def test_cursor_scalar_many_rows(db):
-    with db.cursor() as cur:
-        with pytest.raises(RuntimeError):
-            cur.scalar('SELECT * FROM pg_tables')
+    with pytest.raises(RuntimeError, match='Query did not return a single row.'):
+        nc.execute_scalar(db, 'SELECT * FROM pg_tables')
 
 
 def test_cursor_scalar_no_rows(db, table_factory):
     table_factory('dummy')
 
-    with db.cursor() as cur:
-        with pytest.raises(RuntimeError):
-            cur.scalar('SELECT id FROM dummy')
+    with pytest.raises(RuntimeError, match='Query did not return a single row.'):
+        nc.execute_scalar(db, 'SELECT id FROM dummy')
 
 
 def test_get_pg_env_add_variable(monkeypatch):
     monkeypatch.delenv('PGPASSWORD', raising=False)
-    env = get_pg_env('user=fooF')
+    env = nc.get_pg_env('user=fooF')
 
     assert env['PGUSER'] == 'fooF'
     assert 'PGPASSWORD' not in env
@@ -110,12 +120,12 @@ def test_get_pg_env_add_variable(monkeypatch):
 
 def test_get_pg_env_overwrite_variable(monkeypatch):
     monkeypatch.setenv('PGUSER', 'some default')
-    env = get_pg_env('user=overwriter')
+    env = nc.get_pg_env('user=overwriter')
 
     assert env['PGUSER'] == 'overwriter'
 
 
 def test_get_pg_env_ignore_unknown():
-    env = get_pg_env('client_encoding=stuff', base_env={})
+    env = nc.get_pg_env('client_encoding=stuff', base_env={})
 
     assert env == {}
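
Editor's note: these tests reflect an API change in nominatim_db.db.connection: what used to be methods on the connection object (table_exists, drop_table, server_version_tuple, the cursor's scalar()) are now module-level functions taking the connection as first argument. A short usage sketch built only from calls visible in this diff (DSN and table names are hypothetical):

    import nominatim_db.db.connection as nc

    with nc.connect('dbname=test_db') as conn:
        nc.drop_tables(conn, 'scratch', 'scratch2')          # was conn.drop_table(...)
        if not nc.table_exists(conn, 'scratch'):
            print(nc.server_version_tuple(conn))
            print(nc.execute_scalar(conn, 'SELECT count(*) FROM pg_tables'))
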
index 114c5244e14f77c040434c127c6ecf20ae743ec7..45109c70c759452f0c972964944c40e79a322ff1 100644 (file)
@@ -8,6 +8,7 @@
 Tests for SQL preprocessing.
 """
 import pytest
+import pytest_asyncio
 
 from nominatim_db.db.sql_preprocessor import SQLPreprocessor
 
@@ -54,3 +55,17 @@ def test_load_file_with_params(sql_preprocessor, sql_factory, temp_db_conn, temp
     sql_preprocessor.run_sql_file(temp_db_conn, sqlfile, bar='XX', foo='ZZ')
 
     assert temp_db_cursor.scalar('SELECT test()') == 'ZZ XX'
+
+
+@pytest.mark.asyncio
+async def test_load_parallel_file(dsn, sql_preprocessor, tmp_path, temp_db_cursor):
+    (tmp_path / 'test.sql').write_text("""
+        CREATE TABLE foo (a TEXT);
+        CREATE TABLE foo2(a TEXT);""" + 
+        "\n---\nCREATE TABLE bar (b INT);")
+
+    await sql_preprocessor.run_parallel_sql_file(dsn, 'test.sql', num_threads=4)
+
+    assert temp_db_cursor.table_exists('foo')
+    assert temp_db_cursor.table_exists('foo2')
+    assert temp_db_cursor.table_exists('bar')
index b4335ab039d6dc473b7f0cd24ad51c98ec922569..7c46846dbdd4fb0ba5c2fc2ae436f7df76005411 100644 (file)
@@ -58,103 +58,3 @@ def test_execute_file_with_post_code(dsn, tmp_path, temp_db_cursor):
     db_utils.execute_file(dsn, tmpfile, post_code='INSERT INTO test VALUES(23)')
 
     assert temp_db_cursor.row_set('SELECT * FROM test') == {(23, )}
-
-
-class TestCopyBuffer:
-    TABLE_NAME = 'copytable'
-
-    @pytest.fixture(autouse=True)
-    def setup_test_table(self, table_factory):
-        table_factory(self.TABLE_NAME, 'col_a INT, col_b TEXT')
-
-
-    def table_rows(self, cursor):
-        return cursor.row_set('SELECT * FROM ' + self.TABLE_NAME)
-
-
-    def test_copybuffer_empty(self):
-        with db_utils.CopyBuffer() as buf:
-            buf.copy_out(None, "dummy")
-
-
-    def test_all_columns(self, temp_db_cursor):
-        with db_utils.CopyBuffer() as buf:
-            buf.add(3, 'hum')
-            buf.add(None, 'f\\t')
-
-            buf.copy_out(temp_db_cursor, self.TABLE_NAME)
-
-        assert self.table_rows(temp_db_cursor) == {(3, 'hum'), (None, 'f\\t')}
-
-
-    def test_selected_columns(self, temp_db_cursor):
-        with db_utils.CopyBuffer() as buf:
-            buf.add('foo')
-
-            buf.copy_out(temp_db_cursor, self.TABLE_NAME,
-                         columns=['col_b'])
-
-        assert self.table_rows(temp_db_cursor) == {(None, 'foo')}
-
-
-    def test_reordered_columns(self, temp_db_cursor):
-        with db_utils.CopyBuffer() as buf:
-            buf.add('one', 1)
-            buf.add(' two ', 2)
-
-            buf.copy_out(temp_db_cursor, self.TABLE_NAME,
-                         columns=['col_b', 'col_a'])
-
-        assert self.table_rows(temp_db_cursor) == {(1, 'one'), (2, ' two ')}
-
-
-    def test_special_characters(self, temp_db_cursor):
-        with db_utils.CopyBuffer() as buf:
-            buf.add('foo\tbar')
-            buf.add('sun\nson')
-            buf.add('\\N')
-
-            buf.copy_out(temp_db_cursor, self.TABLE_NAME,
-                         columns=['col_b'])
-
-        assert self.table_rows(temp_db_cursor) == {(None, 'foo\tbar'),
-                                                   (None, 'sun\nson'),
-                                                   (None, '\\N')}
-
-
-
-class TestCopyBufferJson:
-    TABLE_NAME = 'copytable'
-
-    @pytest.fixture(autouse=True)
-    def setup_test_table(self, table_factory):
-        table_factory(self.TABLE_NAME, 'col_a INT, col_b JSONB')
-
-
-    def table_rows(self, cursor):
-        cursor.execute('SELECT * FROM ' + self.TABLE_NAME)
-        results = {k: v for k,v in cursor}
-
-        assert len(results) == cursor.rowcount
-
-        return results
-
-
-    def test_json_object(self, temp_db_cursor):
-        with db_utils.CopyBuffer() as buf:
-            buf.add(1, json.dumps({'test': 'value', 'number': 1}))
-
-            buf.copy_out(temp_db_cursor, self.TABLE_NAME)
-
-        assert self.table_rows(temp_db_cursor) == \
-                   {1: {'test': 'value', 'number': 1}}
-
-
-    def test_json_object_special_chras(self, temp_db_cursor):
-        with db_utils.CopyBuffer() as buf:
-            buf.add(1, json.dumps({'te\tst': 'va\nlue', 'nu"mber': None}))
-
-            buf.copy_out(temp_db_cursor, self.TABLE_NAME)
-
-        assert self.table_rows(temp_db_cursor) == \
-                   {1: {'te\tst': 'va\nlue', 'nu"mber': None}}
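
Editor's note: the CopyBuffer helper and its tests could be dropped because the psycopg 3 cursor exposes the COPY protocol directly, making a hand-written buffer with manual escaping unnecessary. A sketch of the native replacement, not part of the commit (hypothetical DSN and table):

    import psycopg

    with psycopg.connect('dbname=test_db') as conn:
        with conn.cursor() as cur:
            cur.execute('CREATE TABLE IF NOT EXISTS copytable (col_a INT, col_b TEXT)')
            with cur.copy('COPY copytable (col_a, col_b) FROM STDIN') as copy:
                copy.write_row((3, 'hum'))
                copy.write_row((None, 'f\\t'))   # None becomes SQL NULL, no manual escaping
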
index 93481e8a8a0bc8bd6b7978c07a99cc206367f466..fe65b69c4dff04d39ebb04ceb133a42bc2a21f82 100644 (file)
@@ -9,6 +9,7 @@ Tests for running the indexing.
 """
 import itertools
 import pytest
+import pytest_asyncio
 
 from nominatim_db.indexer import indexer
 from nominatim_db.tokenizer import factory
@@ -21,9 +22,8 @@ class IndexerTestDB:
         self.postcode_id = itertools.count(700000)
 
         self.conn = conn
-        self.conn.set_isolation_level(0)
+        self.conn.autocommit = True
         with self.conn.cursor() as cur:
-            cur.execute('CREATE EXTENSION hstore')
             cur.execute("""CREATE TABLE placex (place_id BIGINT,
                                                 name HSTORE,
                                                 class TEXT,
@@ -156,7 +156,8 @@ def test_tokenizer(tokenizer_mock, project_env):
 
 
 @pytest.mark.parametrize("threads", [1, 15])
-def test_index_all_by_rank(test_db, threads, test_tokenizer):
+@pytest.mark.asyncio
+async def test_index_all_by_rank(test_db, threads, test_tokenizer):
     for rank in range(31):
         test_db.add_place(rank_address=rank, rank_search=rank)
     test_db.add_osmline()
@@ -165,7 +166,7 @@ def test_index_all_by_rank(test_db, threads, test_tokenizer):
     assert test_db.osmline_unindexed() == 1
 
     idx = indexer.Indexer('dbname=test_nominatim_python_unittest', test_tokenizer, threads)
-    idx.index_by_rank(0, 30)
+    await idx.index_by_rank(0, 30)
 
     assert test_db.placex_unindexed() == 0
     assert test_db.osmline_unindexed() == 0
@@ -190,7 +191,8 @@ def test_index_all_by_rank(test_db, threads, test_tokenizer):
 
 
 @pytest.mark.parametrize("threads", [1, 15])
-def test_index_partial_without_30(test_db, threads, test_tokenizer):
+@pytest.mark.asyncio
+async def test_index_partial_without_30(test_db, threads, test_tokenizer):
     for rank in range(31):
         test_db.add_place(rank_address=rank, rank_search=rank)
     test_db.add_osmline()
@@ -200,7 +202,7 @@ def test_index_partial_without_30(test_db, threads, test_tokenizer):
 
     idx = indexer.Indexer('dbname=test_nominatim_python_unittest',
                           test_tokenizer, threads)
-    idx.index_by_rank(4, 15)
+    await idx.index_by_rank(4, 15)
 
     assert test_db.placex_unindexed() == 19
     assert test_db.osmline_unindexed() == 1
@@ -211,7 +213,8 @@ def test_index_partial_without_30(test_db, threads, test_tokenizer):
 
 
 @pytest.mark.parametrize("threads", [1, 15])
-def test_index_partial_with_30(test_db, threads, test_tokenizer):
+@pytest.mark.asyncio
+async def test_index_partial_with_30(test_db, threads, test_tokenizer):
     for rank in range(31):
         test_db.add_place(rank_address=rank, rank_search=rank)
     test_db.add_osmline()
@@ -220,7 +223,7 @@ def test_index_partial_with_30(test_db, threads, test_tokenizer):
     assert test_db.osmline_unindexed() == 1
 
     idx = indexer.Indexer('dbname=test_nominatim_python_unittest', test_tokenizer, threads)
-    idx.index_by_rank(28, 30)
+    await idx.index_by_rank(28, 30)
 
     assert test_db.placex_unindexed() == 27
     assert test_db.osmline_unindexed() == 0
@@ -230,7 +233,8 @@ def test_index_partial_with_30(test_db, threads, test_tokenizer):
                       WHERE indexed_status = 0 AND rank_address between 1 and 27""") == 0
 
 @pytest.mark.parametrize("threads", [1, 15])
-def test_index_boundaries(test_db, threads, test_tokenizer):
+@pytest.mark.asyncio
+async def test_index_boundaries(test_db, threads, test_tokenizer):
     for rank in range(4, 10):
         test_db.add_admin(rank_address=rank, rank_search=rank)
     for rank in range(31):
@@ -241,7 +245,7 @@ def test_index_boundaries(test_db, threads, test_tokenizer):
     assert test_db.osmline_unindexed() == 1
 
     idx = indexer.Indexer('dbname=test_nominatim_python_unittest', test_tokenizer, threads)
-    idx.index_boundaries(0, 30)
+    await idx.index_boundaries(0, 30)
 
     assert test_db.placex_unindexed() == 31
     assert test_db.osmline_unindexed() == 1
@@ -252,21 +256,23 @@ def test_index_boundaries(test_db, threads, test_tokenizer):
 
 
 @pytest.mark.parametrize("threads", [1, 15])
-def test_index_postcodes(test_db, threads, test_tokenizer):
+@pytest.mark.asyncio
+async def test_index_postcodes(test_db, threads, test_tokenizer):
     for postcode in range(1000):
         test_db.add_postcode('de', postcode)
     for postcode in range(32000, 33000):
         test_db.add_postcode('us', postcode)
 
     idx = indexer.Indexer('dbname=test_nominatim_python_unittest', test_tokenizer, threads)
-    idx.index_postcodes()
+    await idx.index_postcodes()
 
     assert test_db.scalar("""SELECT count(*) FROM location_postcode
                                   WHERE indexed_status != 0""") == 0
 
 
 @pytest.mark.parametrize("analyse", [True, False])
-def test_index_full(test_db, analyse, test_tokenizer):
+@pytest.mark.asyncio
+async def test_index_full(test_db, analyse, test_tokenizer):
     for rank in range(4, 10):
         test_db.add_admin(rank_address=rank, rank_search=rank)
     for rank in range(31):
@@ -276,22 +282,9 @@ def test_index_full(test_db, analyse, test_tokenizer):
         test_db.add_postcode('de', postcode)
 
     idx = indexer.Indexer('dbname=test_nominatim_python_unittest', test_tokenizer, 4)
-    idx.index_full(analyse=analyse)
+    await idx.index_full(analyse=analyse)
 
     assert test_db.placex_unindexed() == 0
     assert test_db.osmline_unindexed() == 0
     assert test_db.scalar("""SELECT count(*) FROM location_postcode
                              WHERE indexed_status != 0""") == 0
-
-
-@pytest.mark.parametrize("threads", [1, 15])
-def test_index_reopen_connection(test_db, threads, monkeypatch, test_tokenizer):
-    monkeypatch.setattr(indexer.WorkerPool, "REOPEN_CONNECTIONS_AFTER", 15)
-
-    for _ in range(1000):
-        test_db.add_place(rank_address=30, rank_search=30)
-
-    idx = indexer.Indexer('dbname=test_nominatim_python_unittest', test_tokenizer, threads)
-    idx.index_by_rank(28, 30)
-
-    assert test_db.placex_unindexed() == 0
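
Editor's note: the indexer entry points are coroutines after the port, so these tests run under pytest-asyncio and await each call. Outside of pytest the same calls need an event loop; a sketch using only signatures visible above (DSN and tokenizer are hypothetical):

    import asyncio
    from nominatim_db.indexer.indexer import Indexer

    async def reindex(dsn, tokenizer):
        # Sketch only: same calls as in the tests above, driven by asyncio.run().
        idx = Indexer(dsn, tokenizer, 4)
        await idx.index_boundaries(0, 30)
        await idx.index_by_rank(0, 30)
        await idx.index_postcodes()

    # asyncio.run(reindex('dbname=nominatim', tokenizer))   # tokenizer: hypothetical
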
index 5c465e8b7e988b6488756f9a8fb8a5c54278ba94..e8b4390f5fc38c4c6c55c5417f3cc36373a8dc7b 100644 (file)
@@ -8,6 +8,7 @@
 Legacy word table for testing with functions to prefill and test contents
 of the table.
 """
+from nominatim_db.db.connection import execute_scalar
 
 class MockIcuWordTable:
     """ A word table for testing using legacy word table structure.
@@ -35,9 +36,9 @@ class MockIcuWordTable:
         with self.conn.cursor() as cur:
             cur.execute("""INSERT INTO word (word_token, type, word, info)
                               VALUES (%s, 'S', %s,
-                                      json_build_object('class', %s,
-                                                        'type', %s,
-                                                        'op', %s))
+                                      json_build_object('class', %s::text,
+                                                        'type', %s::text,
+                                                        'op', %s::text))
                         """, (word_token, word, cls, typ, oper))
         self.conn.commit()
 
@@ -70,25 +71,22 @@ class MockIcuWordTable:
                     word = word_tokens[0]
                 for token in word_tokens:
                     cur.execute("""INSERT INTO word (word_id, word_token, type, word, info)
-                                      VALUES (%s, %s, 'H', %s, jsonb_build_object('lookup', %s))
+                                      VALUES (%s, %s, 'H', %s, jsonb_build_object('lookup', %s::text))
                                 """, (word_id, token, word, word_tokens[0]))
 
         self.conn.commit()
 
 
     def count(self):
-        with self.conn.cursor() as cur:
-            return cur.scalar("SELECT count(*) FROM word")
+        return execute_scalar(self.conn, "SELECT count(*) FROM word")
 
 
     def count_special(self):
-        with self.conn.cursor() as cur:
-            return cur.scalar("SELECT count(*) FROM word WHERE type = 'S'")
+        return execute_scalar(self.conn, "SELECT count(*) FROM word WHERE type = 'S'")
 
 
     def count_housenumbers(self):
-        with self.conn.cursor() as cur:
-            return cur.scalar("SELECT count(*) FROM word WHERE type = 'H'")
+        return execute_scalar(self.conn, "SELECT count(*) FROM word WHERE type = 'H'")
 
 
     def get_special(self):
index 9804341f41c33fd91df794e15fcf935f3f00d395..d3f81a4db3948bcb95d267676e236d392987bf71 100644 (file)
@@ -8,6 +8,7 @@
 Legacy word table for testing with functions to prefill and test contents
 of the table.
 """
+from nominatim_db.db.connection import execute_scalar
 
 class MockLegacyWordTable:
     """ A word table for testing using legacy word table structure.
@@ -58,18 +59,16 @@ class MockLegacyWordTable:
 
 
     def count(self):
-        with self.conn.cursor() as cur:
-            return cur.scalar("SELECT count(*) FROM word")
+        return execute_scalar(self.conn, "SELECT count(*) FROM word")
 
 
     def count_special(self):
-        with self.conn.cursor() as cur:
-            return cur.scalar("SELECT count(*) FROM word WHERE class != 'place'")
+        return execute_scalar(self.conn, "SELECT count(*) FROM word WHERE class != 'place'")
 
 
     def get_special(self):
         with self.conn.cursor() as cur:
-            cur.execute("""SELECT word_token, word, class, type, operator
+            cur.execute("""SELECT word_token, word, class as cls, type, operator
                            FROM word WHERE class != 'place'""")
             result = set((tuple(row) for row in cur))
             assert len(result) == cur.rowcount, "Word table has duplicates."
index 82e700bd4456f71fc8c839f0741d450decdc958c..cde0b7bb58e8fdab6208ca6cef05d10c179535ff 100644 (file)
@@ -9,8 +9,6 @@ Custom mocks for testing.
 """
 import itertools
 
-import psycopg2.extras
-
 from nominatim_db.db import properties
 
 # This must always point to the mock word table for the default tokenizer.
@@ -56,7 +54,6 @@ class MockPlacexTable:
             admin_level=None, address=None, extratags=None, geom='POINT(10 4)',
             country=None, housenumber=None, rank_search=30):
         with self.conn.cursor() as cur:
-            psycopg2.extras.register_hstore(cur)
             cur.execute("""INSERT INTO placex (place_id, osm_type, osm_id, class,
                                                type, name, admin_level, address,
                                                housenumber, rank_search,
index 357b7d4ae4cde571315c2a7d8109595911e343f4..a2bf676699ec9a326f59957740740899da8340dc 100644 (file)
@@ -199,16 +199,16 @@ def test_update_sql_functions(db_prop, temp_db_cursor,
     assert test_content == set((('1133', ), ))
 
 
-def test_finalize_import(tokenizer_factory, temp_db_conn,
-                         temp_db_cursor, test_config, sql_preprocessor_cfg):
+def test_finalize_import(tokenizer_factory, temp_db_cursor,
+                         test_config, sql_preprocessor_cfg):
     tok = tokenizer_factory()
     tok.init_new_db(test_config)
 
-    assert not temp_db_conn.index_exists('idx_word_word_id')
+    assert not temp_db_cursor.index_exists('word', 'idx_word_word_id')
 
     tok.finalize_import(test_config)
 
-    assert temp_db_conn.index_exists('idx_word_word_id')
+    assert temp_db_cursor.index_exists('word', 'idx_word_word_id')
 
 
 def test_check_database(test_config, tokenizer_factory,
index 548ec800b6b8f966427911dd2d91a90d19c8cb64..df2042982c98f090dc24174e38064873c15b09b3 100644 (file)
@@ -8,10 +8,11 @@
 Tests for functions to import a new database.
 """
 from pathlib import Path
-from contextlib import closing
 
 import pytest
-import psycopg2
+import pytest_asyncio
+import psycopg
+from psycopg import sql as pysql
 
 from nominatim_db.tools import database_import
 from nominatim_db.errors import UsageError
@@ -21,10 +22,7 @@ class TestDatabaseSetup:
 
     @pytest.fixture(autouse=True)
     def setup_nonexistant_db(self):
-        conn = psycopg2.connect(database='postgres')
-
-        try:
-            conn.set_isolation_level(0)
+        with psycopg.connect(dbname='postgres', autocommit=True) as conn:
             with conn.cursor() as cur:
                 cur.execute(f'DROP DATABASE IF EXISTS {self.DBNAME}')
 
@@ -32,22 +30,17 @@ class TestDatabaseSetup:
 
             with conn.cursor() as cur:
                 cur.execute(f'DROP DATABASE IF EXISTS {self.DBNAME}')
-        finally:
-            conn.close()
+
 
     @pytest.fixture
     def cursor(self):
-        conn = psycopg2.connect(database=self.DBNAME)
-
-        try:
+        with psycopg.connect(dbname=self.DBNAME) as conn:
             with conn.cursor() as cur:
                 yield cur
-        finally:
-            conn.close()
 
 
     def conn(self):
-        return closing(psycopg2.connect(database=self.DBNAME))
+        return psycopg.connect(dbname=self.DBNAME)
 
 
     def test_setup_skeleton(self):
@@ -132,7 +125,7 @@ def test_import_osm_data_simple_ignore_no_data(table_factory, osm2pgsql_options)
                                     ignore_errors=True)
 
 
-def test_import_osm_data_drop(table_factory, temp_db_conn, tmp_path, osm2pgsql_options):
+def test_import_osm_data_drop(table_factory, temp_db_cursor, tmp_path, osm2pgsql_options):
     table_factory('place', content=((1, ), ))
     table_factory('planet_osm_nodes')
 
@@ -144,7 +137,7 @@ def test_import_osm_data_drop(table_factory, temp_db_conn, tmp_path, osm2pgsql_o
     database_import.import_osm_data(Path('file.pbf'), osm2pgsql_options, drop=True)
 
     assert not flatfile.exists()
-    assert not temp_db_conn.table_exists('planet_osm_nodes')
+    assert not temp_db_cursor.table_exists('planet_osm_nodes')
 
 
 def test_import_osm_data_default_cache(table_factory, osm2pgsql_options, capfd):
@@ -178,18 +171,19 @@ def test_truncate_database_tables(temp_db_conn, temp_db_cursor, table_factory, w
 
 
 @pytest.mark.parametrize("threads", (1, 5))
-def test_load_data(dsn, place_row, placex_table, osmline_table,
+@pytest.mark.asyncio
+async def test_load_data(dsn, place_row, placex_table, osmline_table,
                    temp_db_cursor, threads):
     for func in ('precompute_words', 'getorcreate_housenumber_id', 'make_standard_name'):
-        temp_db_cursor.execute(f"""CREATE FUNCTION {func} (src TEXT)
-                                  RETURNS TEXT AS $$ SELECT 'a'::TEXT $$ LANGUAGE SQL
-                               """)
+        temp_db_cursor.execute(pysql.SQL("""CREATE FUNCTION {} (src TEXT)
+                                            RETURNS TEXT AS $$ SELECT 'a'::TEXT $$ LANGUAGE SQL
+                                         """).format(pysql.Identifier(func)))
     for oid in range(100, 130):
         place_row(osm_id=oid)
     place_row(osm_type='W', osm_id=342, cls='place', typ='houses',
               geom='SRID=4326;LINESTRING(0 0, 10 10)')
 
-    database_import.load_data(dsn, threads)
+    await database_import.load_data(dsn, threads)
 
     assert temp_db_cursor.table_rows('placex') == 30
     assert temp_db_cursor.table_rows('location_property_osmline') == 1
@@ -241,11 +235,12 @@ class TestSetupSQL:
 
 
     @pytest.mark.parametrize("drop", [True, False])
-    def test_create_search_indices(self, temp_db_conn, temp_db_cursor, drop):
+    @pytest.mark.asyncio
+    async def test_create_search_indices(self, temp_db_conn, temp_db_cursor, drop):
         self.write_sql('indices.sql',
                        """CREATE FUNCTION test() RETURNS bool
                           AS $$ SELECT {{drop}} $$ LANGUAGE SQL""")
 
-        database_import.create_search_indices(temp_db_conn, self.config, drop)
+        await database_import.create_search_indices(temp_db_conn, self.config, drop)
 
         assert temp_db_cursor.scalar('SELECT test()') == drop
index 64eb7b1856b0996c3958ae71e51b3ecd3fbbdbb4..0d33e6e0f30e4a0ab47bfb02a4cd63f2ad4c55a9 100644 (file)
@@ -75,7 +75,8 @@ def test_load_white_and_black_lists(sp_importer):
     assert isinstance(black_list, dict) and isinstance(white_list, dict)
 
 
-def test_create_place_classtype_indexes(temp_db_with_extensions, temp_db_conn,
+def test_create_place_classtype_indexes(temp_db_with_extensions,
+                                        temp_db_conn, temp_db_cursor,
                                         table_factory, sp_importer):
     """
         Test that _create_place_classtype_indexes() create the
@@ -88,10 +89,11 @@ def test_create_place_classtype_indexes(temp_db_with_extensions, temp_db_conn,
     table_factory(table_name, 'place_id BIGINT, centroid GEOMETRY')
 
     sp_importer._create_place_classtype_indexes('', phrase_class, phrase_type)
+    temp_db_conn.commit()
 
-    assert check_placeid_and_centroid_indexes(temp_db_conn, phrase_class, phrase_type)
+    assert check_placeid_and_centroid_indexes(temp_db_cursor, phrase_class, phrase_type)
 
-def test_create_place_classtype_table(temp_db_conn, placex_table, sp_importer):
+def test_create_place_classtype_table(temp_db_conn, temp_db_cursor, placex_table, sp_importer):
     """
         Test that _create_place_classtype_table() create
         the right place_classtype table.
@@ -99,10 +101,12 @@ def test_create_place_classtype_table(temp_db_conn, placex_table, sp_importer):
     phrase_class = 'class'
     phrase_type = 'type'
     sp_importer._create_place_classtype_table('', phrase_class, phrase_type)
+    temp_db_conn.commit()
 
-    assert check_table_exist(temp_db_conn, phrase_class, phrase_type)
+    assert check_table_exist(temp_db_cursor, phrase_class, phrase_type)
 
-def test_grant_access_to_web_user(temp_db_conn, table_factory, def_config, sp_importer):
+def test_grant_access_to_web_user(temp_db_conn, temp_db_cursor, table_factory,
+                                  def_config, sp_importer):
     """
         Test that _grant_access_to_webuser() give
         right access to the web user.
@@ -114,12 +118,13 @@ def test_grant_access_to_web_user(temp_db_conn, table_factory, def_config, sp_im
     table_factory(table_name)
 
     sp_importer._grant_access_to_webuser(phrase_class, phrase_type)
+    temp_db_conn.commit()
 
-    assert check_grant_access(temp_db_conn, def_config.DATABASE_WEBUSER, phrase_class, phrase_type)
+    assert check_grant_access(temp_db_cursor, def_config.DATABASE_WEBUSER, phrase_class, phrase_type)
 
 def test_create_place_classtype_table_and_indexes(
-        temp_db_conn, def_config, placex_table,
-        sp_importer):
+        temp_db_cursor, def_config, placex_table,
+        sp_importer, temp_db_conn):
     """
         Test that _create_place_classtype_table_and_indexes()
         create the right place_classtype tables and place_id indexes
@@ -129,14 +134,15 @@ def test_create_place_classtype_table_and_indexes(
     pairs = set([('class1', 'type1'), ('class2', 'type2')])
 
     sp_importer._create_classtype_table_and_indexes(pairs)
+    temp_db_conn.commit()
 
     for pair in pairs:
-        assert check_table_exist(temp_db_conn, pair[0], pair[1])
-        assert check_placeid_and_centroid_indexes(temp_db_conn, pair[0], pair[1])
-        assert check_grant_access(temp_db_conn, def_config.DATABASE_WEBUSER, pair[0], pair[1])
+        assert check_table_exist(temp_db_cursor, pair[0], pair[1])
+        assert check_placeid_and_centroid_indexes(temp_db_cursor, pair[0], pair[1])
+        assert check_grant_access(temp_db_cursor, def_config.DATABASE_WEBUSER, pair[0], pair[1])
 
 def test_remove_non_existent_tables_from_db(sp_importer, default_phrases,
-                                            temp_db_conn):
+                                            temp_db_conn, temp_db_cursor):
     """
         Check for the remove_non_existent_phrases_from_db() method.
 
@@ -159,15 +165,14 @@ def test_remove_non_existent_tables_from_db(sp_importer, default_phrases,
     """
 
     sp_importer._remove_non_existent_tables_from_db()
+    temp_db_conn.commit()
 
-    # Changes are not committed yet. Use temp_db_conn for checking results.
-    with temp_db_conn.cursor(cursor_factory=CursorForTesting) as cur:
-        assert cur.row_set(query_tables) \
+    assert temp_db_cursor.row_set(query_tables) \
                  == {('place_classtype_testclasstypetable_to_keep', )}
 
 
 @pytest.mark.parametrize("should_replace", [(True), (False)])
-def test_import_phrases(monkeypatch, temp_db_conn, def_config, sp_importer,
+def test_import_phrases(monkeypatch, temp_db_cursor, def_config, sp_importer,
                         placex_table, table_factory, tokenizer_mock,
                         xml_wiki_content, should_replace):
     """
@@ -193,49 +198,49 @@ def test_import_phrases(monkeypatch, temp_db_conn, def_config, sp_importer,
     class_test = 'aerialway'
     type_test = 'zip_line'
 
-    assert check_table_exist(temp_db_conn, class_test, type_test)
-    assert check_placeid_and_centroid_indexes(temp_db_conn, class_test, type_test)
-    assert check_grant_access(temp_db_conn, def_config.DATABASE_WEBUSER, class_test, type_test)
-    assert check_table_exist(temp_db_conn, 'amenity', 'animal_shelter')
+    assert check_table_exist(temp_db_cursor, class_test, type_test)
+    assert check_placeid_and_centroid_indexes(temp_db_cursor, class_test, type_test)
+    assert check_grant_access(temp_db_cursor, def_config.DATABASE_WEBUSER, class_test, type_test)
+    assert check_table_exist(temp_db_cursor, 'amenity', 'animal_shelter')
     if should_replace:
-        assert not check_table_exist(temp_db_conn, 'wrong_class', 'wrong_type')
+        assert not check_table_exist(temp_db_cursor, 'wrong_class', 'wrong_type')
 
-    assert temp_db_conn.table_exists('place_classtype_amenity_animal_shelter')
+    assert temp_db_cursor.table_exists('place_classtype_amenity_animal_shelter')
     if should_replace:
-        assert not temp_db_conn.table_exists('place_classtype_wrongclass_wrongtype')
+        assert not temp_db_cursor.table_exists('place_classtype_wrongclass_wrongtype')
 
-def check_table_exist(temp_db_conn, phrase_class, phrase_type):
+def check_table_exist(temp_db_cursor, phrase_class, phrase_type):
     """
         Verify that the place_classtype table exists for the given
         phrase_class and phrase_type.
     """
-    return temp_db_conn.table_exists('place_classtype_{}_{}'.format(phrase_class, phrase_type))
+    return temp_db_cursor.table_exists('place_classtype_{}_{}'.format(phrase_class, phrase_type))
 
 
-def check_grant_access(temp_db_conn, user, phrase_class, phrase_type):
+def check_grant_access(temp_db_cursor, user, phrase_class, phrase_type):
     """
         Check that the web user has been granted right access to the
         place_classtype table of the given phrase_class and phrase_type.
     """
     table_name = 'place_classtype_{}_{}'.format(phrase_class, phrase_type)
 
-    with temp_db_conn.cursor() as temp_db_cursor:
-        temp_db_cursor.execute("""
-                SELECT * FROM information_schema.role_table_grants
-                WHERE table_name='{}'
-                AND grantee='{}'
-                AND privilege_type='SELECT'""".format(table_name, user))
-        return temp_db_cursor.fetchone()
+    temp_db_cursor.execute("""
+            SELECT * FROM information_schema.role_table_grants
+            WHERE table_name='{}'
+            AND grantee='{}'
+            AND privilege_type='SELECT'""".format(table_name, user))
+    return temp_db_cursor.fetchone()
 
-def check_placeid_and_centroid_indexes(temp_db_conn, phrase_class, phrase_type):
+def check_placeid_and_centroid_indexes(temp_db_cursor, phrase_class, phrase_type):
     """
         Check that the place_id index and centroid index exist for the
         place_classtype table of the given phrase_class and phrase_type.
     """
+    table_name = 'place_classtype_{}_{}'.format(phrase_class, phrase_type)
     index_prefix = 'idx_place_classtype_{}_{}_'.format(phrase_class, phrase_type)
 
     return (
-        temp_db_conn.index_exists(index_prefix + 'centroid')
+        temp_db_cursor.index_exists(table_name, index_prefix + 'centroid')
         and
-        temp_db_conn.index_exists(index_prefix + 'place_id')
+        temp_db_cursor.index_exists(table_name, index_prefix + 'place_id')
     )
index 3a849adbf956ed8469730c9cb6eb2d2d377eff0c..2c7b2d5664cbf16591a848934e87ac9bf72a680e 100644 (file)
@@ -8,10 +8,10 @@
 Tests for migration functions
 """
 import pytest
-import psycopg2.extras
 
 from nominatim_db.tools import migration
 from nominatim_db.errors import UsageError
+from nominatim_db.db.connection import server_version_tuple
 import nominatim_db.version
 
 from mock_legacy_word_table import MockLegacyWordTable
@@ -43,7 +43,6 @@ def test_no_migration_old_versions(temp_db_with_extensions, table_factory, def_c
 def test_set_up_migration_for_36(temp_db_with_extensions, temp_db_cursor,
                                  table_factory, def_config, monkeypatch,
                                  postprocess_mock):
-    psycopg2.extras.register_hstore(temp_db_cursor)
     # don't actually run any migration, except the property table creation
     monkeypatch.setattr(migration, '_MIGRATION_FUNCTIONS',
                         [((3, 5, 0, 99), migration.add_nominatim_property_table)])
@@ -63,7 +62,7 @@ def test_set_up_migration_for_36(temp_db_with_extensions, temp_db_cursor,
                                           WHERE property = 'database_version'""")
 
 
-def test_already_at_version(def_config, property_table):
+def test_already_at_version(temp_db_with_extensions, def_config, property_table):
 
     property_table.set('database_version',
                        str(nominatim_db.version.NOMINATIM_VERSION))
@@ -71,8 +70,8 @@ def test_already_at_version(def_config, property_table):
     assert migration.migrate(def_config, {}) == 0
 
 
-def test_run_single_migration(def_config, temp_db_cursor, property_table,
-                              monkeypatch, postprocess_mock):
+def test_run_single_migration(temp_db_with_extensions, def_config, temp_db_cursor,
+                              property_table, monkeypatch, postprocess_mock):
     oldversion = [x for x in nominatim_db.version.NOMINATIM_VERSION]
     oldversion[0] -= 1
     property_table.set('database_version',
@@ -226,7 +225,7 @@ def test_create_tiger_housenumber_index(temp_db_conn, temp_db_cursor, table_fact
     migration.create_tiger_housenumber_index(temp_db_conn)
     temp_db_conn.commit()
 
-    if temp_db_conn.server_version_tuple() >= (11, 0, 0):
+    if server_version_tuple(temp_db_conn) >= (11, 0, 0):
         assert temp_db_cursor.index_exists('location_property_tiger',
                                            'idx_location_property_tiger_housenumber_migrated')
 
index febb2271da289a76266fab6b5443f95df871a6a8..f035bb19affa5c1195cb8999f61897cdd68c00de 100644 (file)
@@ -47,7 +47,7 @@ class MockPostcodeTable:
                                                           country_code, postcode,
                                                           geometry)
                            VALUES (nextval('seq_place'), 1, %s, %s,
-                                   'SRID=4326;POINT(%s %s)')""",
+                                   ST_SetSRID(ST_MakePoint(%s, %s), 4326))""",
                         (country, postcode, x, y))
         self.conn.commit()
 
index 50ff6398b5673cdc073b4d4c62af6831729f4def..1f1968cfac7fe0c1a0e214037aa097728cfe0efe 100644 (file)
@@ -12,6 +12,7 @@ from pathlib import Path
 import pytest
 
 from nominatim_db.tools import refresh
+from nominatim_db.db.connection import postgis_version_tuple
 
 def test_refresh_import_wikipedia_not_existing(dsn):
     assert refresh.import_wikipedia_articles(dsn, Path('.')) == 1
@@ -23,13 +24,13 @@ def test_refresh_import_secondary_importance_non_existing(dsn):
 def test_refresh_import_secondary_importance_testdb(dsn, src_dir, temp_db_conn, temp_db_cursor):
     temp_db_cursor.execute('CREATE EXTENSION postgis')
 
-    if temp_db_conn.postgis_version_tuple()[0] < 3:
+    if postgis_version_tuple(temp_db_conn)[0] < 3:
         assert refresh.import_secondary_importance(dsn, src_dir / 'test' / 'testdb') > 0
     else:
         temp_db_cursor.execute('CREATE EXTENSION postgis_raster')
         assert refresh.import_secondary_importance(dsn, src_dir / 'test' / 'testdb') == 0
 
-        assert temp_db_conn.table_exists('secondary_importance')
+        assert temp_db_cursor.table_exists('secondary_importance')
 
 
 @pytest.mark.parametrize("replace", (True, False))
index 7ef6a1e63306020f41d1dcb91863a4a01be6c51e..5d65fafb3b5c4d272472269ece832ea7bb220be6 100644 (file)
@@ -11,7 +11,9 @@ import tarfile
 from textwrap import dedent
 
 import pytest
+import pytest_asyncio
 
+from nominatim_db.db.connection import execute_scalar
 from nominatim_db.tools import tiger_data, freeze
 from nominatim_db.errors import UsageError
 
@@ -31,8 +33,7 @@ class MockTigerTable:
             cur.execute("CREATE TABLE place (number INTEGER)")
 
     def count(self):
-        with self.conn.cursor() as cur:
-            return cur.scalar("SELECT count(*) FROM tiger")
+        return execute_scalar(self.conn, "SELECT count(*) FROM tiger")
 
     def row(self):
         with self.conn.cursor() as cur:
@@ -76,82 +77,91 @@ def csv_factory(tmp_path):
 
 
 @pytest.mark.parametrize("threads", (1, 5))
-def test_add_tiger_data(def_config, src_dir, tiger_table, tokenizer_mock, threads):
-    tiger_data.add_tiger_data(str(src_dir / 'test' / 'testdb' / 'tiger'),
-                              def_config, threads, tokenizer_mock())
+@pytest.mark.asyncio
+async def test_add_tiger_data(def_config, src_dir, tiger_table, tokenizer_mock, threads):
+    await tiger_data.add_tiger_data(str(src_dir / 'test' / 'testdb' / 'tiger'),
+                                    def_config, threads, tokenizer_mock())
 
     assert tiger_table.count() == 6213
 
 
-def test_add_tiger_data_database_frozen(def_config, temp_db_conn, tiger_table, tokenizer_mock,
+@pytest.mark.asyncio
+async def test_add_tiger_data_database_frozen(def_config, temp_db_conn, tiger_table, tokenizer_mock,
                                  tmp_path):
     freeze.drop_update_tables(temp_db_conn)
 
     with pytest.raises(UsageError) as excinfo:
-        tiger_data.add_tiger_data(str(tmp_path), def_config, 1, tokenizer_mock())
+        await tiger_data.add_tiger_data(str(tmp_path), def_config, 1, tokenizer_mock())
 
         assert "database frozen" in str(excinfo.value)
 
     assert tiger_table.count() == 0
 
-def test_add_tiger_data_no_files(def_config, tiger_table, tokenizer_mock,
+
+@pytest.mark.asyncio
+async def test_add_tiger_data_no_files(def_config, tiger_table, tokenizer_mock,
                                  tmp_path):
-    tiger_data.add_tiger_data(str(tmp_path), def_config, 1, tokenizer_mock())
+    await tiger_data.add_tiger_data(str(tmp_path), def_config, 1, tokenizer_mock())
 
     assert tiger_table.count() == 0
 
 
-def test_add_tiger_data_bad_file(def_config, tiger_table, tokenizer_mock,
+@pytest.mark.asyncio
+async def test_add_tiger_data_bad_file(def_config, tiger_table, tokenizer_mock,
                                  tmp_path):
     sqlfile = tmp_path / '1010.csv'
     sqlfile.write_text("""Random text""")
 
-    tiger_data.add_tiger_data(str(tmp_path), def_config, 1, tokenizer_mock())
+    await tiger_data.add_tiger_data(str(tmp_path), def_config, 1, tokenizer_mock())
 
     assert tiger_table.count() == 0
 
 
-def test_add_tiger_data_hnr_nan(def_config, tiger_table, tokenizer_mock,
+@pytest.mark.asyncio
+async def test_add_tiger_data_hnr_nan(def_config, tiger_table, tokenizer_mock,
                                 csv_factory, tmp_path):
     csv_factory('file1', hnr_from=99)
     csv_factory('file2', hnr_from='L12')
     csv_factory('file3', hnr_to='12.4')
 
-    tiger_data.add_tiger_data(str(tmp_path), def_config, 1, tokenizer_mock())
+    await tiger_data.add_tiger_data(str(tmp_path), def_config, 1, tokenizer_mock())
 
     assert tiger_table.count() == 1
-    assert tiger_table.row()['start'] == 99
+    assert tiger_table.row().start == 99
 
 
 @pytest.mark.parametrize("threads", (1, 5))
-def test_add_tiger_data_tarfile(def_config, tiger_table, tokenizer_mock,
+@pytest.mark.asyncio
+async def test_add_tiger_data_tarfile(def_config, tiger_table, tokenizer_mock,
                                 tmp_path, src_dir, threads):
     tar = tarfile.open(str(tmp_path / 'sample.tar.gz'), "w:gz")
     tar.add(str(src_dir / 'test' / 'testdb' / 'tiger' / '01001.csv'))
     tar.close()
 
-    tiger_data.add_tiger_data(str(tmp_path / 'sample.tar.gz'), def_config, threads,
-                              tokenizer_mock())
+    await tiger_data.add_tiger_data(str(tmp_path / 'sample.tar.gz'), def_config, threads,
+                                    tokenizer_mock())
 
     assert tiger_table.count() == 6213
 
 
-def test_add_tiger_data_bad_tarfile(def_config, tiger_table, tokenizer_mock,
+@pytest.mark.asyncio
+async def test_add_tiger_data_bad_tarfile(def_config, tiger_table, tokenizer_mock,
                                     tmp_path):
     tarfile = tmp_path / 'sample.tar.gz'
     tarfile.write_text("""Random text""")
 
     with pytest.raises(UsageError):
-        tiger_data.add_tiger_data(str(tarfile), def_config, 1, tokenizer_mock())
+        await tiger_data.add_tiger_data(str(tarfile), def_config, 1, tokenizer_mock())
 
 
-def test_add_tiger_data_empty_tarfile(def_config, tiger_table, tokenizer_mock,
+@pytest.mark.asyncio
+async def test_add_tiger_data_empty_tarfile(def_config, tiger_table, tokenizer_mock,
                                       tmp_path):
     tar = tarfile.open(str(tmp_path / 'sample.tar.gz'), "w:gz")
     tar.add(__file__)
     tar.close()
 
-    tiger_data.add_tiger_data(str(tmp_path / 'sample.tar.gz'), def_config, 1,
-                              tokenizer_mock())
+    await tiger_data.add_tiger_data(str(tmp_path / 'sample.tar.gz'), def_config, 1,
+                                    tokenizer_mock())
 
     assert tiger_table.count() == 0
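
Two patterns recur throughout this file: add_tiger_data() is now a coroutine, so every test is declared async and marked with @pytest.mark.asyncio (hence the new pytest_asyncio import), and the old cursor.scalar() call is replaced by execute_scalar() imported from nominatim_db.db.connection. A sketch of what such a scalar helper might look like; the actual implementation may handle the empty-result case differently:

def execute_scalar(conn, sql, args=None):
    """Execute a query that yields a single value and return that value."""
    with conn.cursor() as cur:
        cur.execute(sql, args)
        row = cur.fetchone()
    return row[0] if row is not None else None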
index 9d8642a69de397f34b3ee4a4dd7add593b6342df..d02ea0a74e36c4582a8070c71b88e2deeaa23e83 100755 (executable)
@@ -26,10 +26,14 @@ export DEBIAN_FRONTEND=noninteractive #DOCS:
                         nlohmann-json3-dev postgresql-14-postgis-3 \
                         postgresql-contrib-14 postgresql-14-postgis-3-scripts \
                         libicu-dev python3-dotenv \
-                        python3-psycopg2 python3-psutil python3-jinja2 \
+                        python3-pip python3-psutil python3-jinja2 \
                         python3-sqlalchemy python3-asyncpg \
                         python3-icu python3-datrie python3-yaml git
 
+# Some of the Python packages that come with Ubuntu 22.04 are too old,
+# so install the latest version from pip:
+
+    pip3 install --user psycopg[binary]
 
 #
 # System Configuration
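
The pip-installed psycopg[binary] above ships both the blocking and the asyncio API; the now-awaitable importer entry points (such as add_tiger_data() in the tests earlier) presumably run on the latter. A minimal smoke test of the async side, with an illustrative DSN:

import asyncio
import psycopg

async def main():
    # Connect asynchronously and check that PostGIS answers; the DSN is a placeholder.
    async with await psycopg.AsyncConnection.connect("dbname=nominatim") as conn:
        async with conn.cursor() as cur:
            await cur.execute("SELECT postgis_lib_version()")
            print(await cur.fetchone())

asyncio.run(main())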