]> git.openstreetmap.org Git - nominatim.git/commitdiff
make DB helper functions free functions
authorSarah Hoffmann <lonvia@denofr.de>
Tue, 2 Jul 2024 13:15:50 +0000 (15:15 +0200)
committerSarah Hoffmann <lonvia@denofr.de>
Mon, 29 Jul 2024 06:49:30 +0000 (08:49 +0200)
Also changes the drop function so that it can drop multiple tables
at once.

30 files changed:
src/nominatim_db/clicmd/admin.py
src/nominatim_db/clicmd/refresh.py
src/nominatim_db/data/country_info.py
src/nominatim_db/db/connection.py
src/nominatim_db/db/properties.py
src/nominatim_db/db/sql_preprocessor.py
src/nominatim_db/db/status.py
src/nominatim_db/indexer/indexer.py
src/nominatim_db/tokenizer/icu_tokenizer.py
src/nominatim_db/tokenizer/legacy_tokenizer.py
src/nominatim_db/tools/admin.py
src/nominatim_db/tools/check_database.py
src/nominatim_db/tools/collect_os_info.py
src/nominatim_db/tools/database_import.py
src/nominatim_db/tools/freeze.py
src/nominatim_db/tools/migration.py
src/nominatim_db/tools/postcodes.py
src/nominatim_db/tools/refresh.py
src/nominatim_db/tools/replication.py
src/nominatim_db/tools/special_phrases/sp_importer.py
test/python/db/test_async_connection.py
test/python/db/test_connection.py
test/python/mock_icu_word_table.py
test/python/mock_legacy_word_table.py
test/python/tokenizer/test_icu.py
test/python/tools/test_database_import.py
test/python/tools/test_import_special_phrases.py
test/python/tools/test_migration.py
test/python/tools/test_refresh.py
test/python/tools/test_tiger_data.py

index 7744595bbd3f99b449e0ea09a1e4ba2840edf710..1edff174dc37ad251de9d9d048736098d717d2f0 100644 (file)
@@ -12,7 +12,7 @@ import argparse
 import random
 
 from ..errors import UsageError
-from ..db.connection import connect
+from ..db.connection import connect, table_exists
 from .args import NominatimArgs
 
 # Do not repeat documentation of subcommand classes.
@@ -115,7 +115,7 @@ class AdminFuncs:
 
                 tokenizer = tokenizer_factory.get_tokenizer_for_db(args.config)
                 with connect(args.config.get_libpq_dsn()) as conn:
-                    if conn.table_exists('search_name'):
+                    if table_exists(conn, 'search_name'):
                         words = tokenizer.most_frequent_words(conn, 1000)
                     else:
                         words = []
index d5acf54b3a36a0e8b0832081a436ed4ce62c586f..363bad7853431851264f53eab83be7f277e46e07 100644 (file)
@@ -13,7 +13,7 @@ import logging
 from pathlib import Path
 
 from ..config import Configuration
-from ..db.connection import connect
+from ..db.connection import connect, table_exists
 from ..tokenizer.base import AbstractTokenizer
 from .args import NominatimArgs
 
@@ -124,7 +124,7 @@ class UpdateRefresh:
             with connect(args.config.get_libpq_dsn()) as conn:
                 # If the table did not exist before, then the importance code
                 # needs to be enabled.
-                if not conn.table_exists('secondary_importance'):
+                if not table_exists(conn, 'secondary_importance'):
                     args.functions = True
 
             LOG.warning('Import secondary importance raster data from %s', args.project_dir)
index c8002ee7023ca27a9e680003e92a66e3d8f0b428..e2bf5133f6403dd1acaf0a5e7069b2d3bac8ee7c 100644 (file)
@@ -9,10 +9,9 @@ Functions for importing and managing static country information.
 """
 from typing import Dict, Any, Iterable, Tuple, Optional, Container, overload
 from pathlib import Path
-import psycopg2.extras
 
 from ..db import utils as db_utils
-from ..db.connection import connect, Connection
+from ..db.connection import connect, Connection, register_hstore
 from ..errors import UsageError
 from ..config import Configuration
 from ..tokenizer.base import AbstractTokenizer
@@ -129,8 +128,8 @@ def setup_country_tables(dsn: str, sql_dir: Path, ignore_partitions: bool = Fals
 
             params.append((ccode, props['names'], lang, partition))
     with connect(dsn) as conn:
+        register_hstore(conn)
         with conn.cursor() as cur:
-            psycopg2.extras.register_hstore(cur)
             cur.execute(
                 """ CREATE TABLE public.country_name (
                         country_code character varying(2),
@@ -157,8 +156,8 @@ def create_country_names(conn: Connection, tokenizer: AbstractTokenizer,
         return ':' not in key or not languages or \
                key[key.index(':') + 1:] in languages
 
+    register_hstore(conn)
     with conn.cursor() as cur:
-        psycopg2.extras.register_hstore(cur)
         cur.execute("""SELECT country_code, name FROM country_name
                        WHERE country_code is not null""")
 
index 8faa3f93334dda7a219b4e5dcd1caf98052a6274..629fad6ac9598603be5819baea8abcc7c1ba8697 100644 (file)
@@ -7,7 +7,8 @@
 """
 Specialised connection and cursor functions.
 """
-from typing import Optional, Any, Callable, ContextManager, Dict, cast, overload, Tuple, Iterable
+from typing import Optional, Any, Callable, ContextManager, Dict, cast, overload,\
+                   Tuple, Iterable
 import contextlib
 import logging
 import os
@@ -46,37 +47,6 @@ class Cursor(psycopg2.extras.DictCursor):
         psycopg2.extras.execute_values(self, sql, argslist, template=template)
 
 
-    def scalar(self, sql: Query, args: Any = None) -> Any:
-        """ Execute query that returns a single value. The value is returned.
-            If the query yields more than one row, a ValueError is raised.
-        """
-        self.execute(sql, args)
-
-        if self.rowcount != 1:
-            raise RuntimeError("Query did not return a single row.")
-
-        result = self.fetchone()
-        assert result is not None
-
-        return result[0]
-
-
-    def drop_table(self, name: str, if_exists: bool = True, cascade: bool = False) -> None:
-        """ Drop the table with the given name.
-            Set `if_exists` to False if a non-existent table should raise
-            an exception instead of just being ignored. If 'cascade' is set
-            to True then all dependent tables are deleted as well.
-        """
-        sql = 'DROP TABLE '
-        if if_exists:
-            sql += 'IF EXISTS '
-        sql += '{}'
-        if cascade:
-            sql += ' CASCADE'
-
-        self.execute(pysql.SQL(sql).format(pysql.Identifier(name)))
-
-
 class Connection(psycopg2.extensions.connection):
     """ A connection that provides the specialised cursor by default and
         adds convenience functions for administrating the database.
@@ -99,80 +69,105 @@ class Connection(psycopg2.extensions.connection):
         return super().cursor(cursor_factory=cursor_factory, **kwargs)
 
 
-    def table_exists(self, table: str) -> bool:
-        """ Check that a table with the given name exists in the database.
-        """
-        with self.cursor() as cur:
-            num = cur.scalar("""SELECT count(*) FROM pg_tables
-                                WHERE tablename = %s and schemaname = 'public'""", (table, ))
-            return num == 1 if isinstance(num, int) else False
+def execute_scalar(conn: Connection, sql: Query, args: Any = None) -> Any:
+    """ Execute query that returns a single value. The value is returned.
+        If the query yields more than one row, a ValueError is raised.
+    """
+    with conn.cursor() as cur:
+        cur.execute(sql, args)
 
+        if cur.rowcount != 1:
+            raise RuntimeError("Query did not return a single row.")
 
-    def table_has_column(self, table: str, column: str) -> bool:
-        """ Check if the table 'table' exists and has a column with name 'column'.
-        """
-        with self.cursor() as cur:
-            has_column = cur.scalar("""SELECT count(*) FROM information_schema.columns
-                                       WHERE table_name = %s
-                                             and column_name = %s""",
-                                    (table, column))
-            return has_column > 0 if isinstance(has_column, int) else False
-
-
-    def index_exists(self, index: str, table: Optional[str] = None) -> bool:
-        """ Check that an index with the given name exists in the database.
-            If table is not None then the index must relate to the given
-            table.
-        """
-        with self.cursor() as cur:
-            cur.execute("""SELECT tablename FROM pg_indexes
-                           WHERE indexname = %s and schemaname = 'public'""", (index, ))
-            if cur.rowcount == 0:
+        result = cur.fetchone()
+
+    assert result is not None
+    return result[0]
+
+
+def table_exists(conn: Connection, table: str) -> bool:
+    """ Check that a table with the given name exists in the database.
+    """
+    num = execute_scalar(conn,
+            """SELECT count(*) FROM pg_tables
+               WHERE tablename = %s and schemaname = 'public'""", (table, ))
+    return num == 1 if isinstance(num, int) else False
+
+
+def table_has_column(conn: Connection, table: str, column: str) -> bool:
+    """ Check if the table 'table' exists and has a column with name 'column'.
+    """
+    has_column = execute_scalar(conn,
+                    """SELECT count(*) FROM information_schema.columns
+                       WHERE table_name = %s and column_name = %s""",
+                    (table, column))
+    return has_column > 0 if isinstance(has_column, int) else False
+
+
+def index_exists(conn: Connection, index: str, table: Optional[str] = None) -> bool:
+    """ Check that an index with the given name exists in the database.
+        If table is not None then the index must relate to the given
+        table.
+    """
+    with conn.cursor() as cur:
+        cur.execute("""SELECT tablename FROM pg_indexes
+                       WHERE indexname = %s and schemaname = 'public'""", (index, ))
+        if cur.rowcount == 0:
+            return False
+
+        if table is not None:
+            row = cur.fetchone()
+            if row is None or not isinstance(row[0], str):
                 return False
+            return row[0] == table
 
-            if table is not None:
-                row = cur.fetchone()
-                if row is None or not isinstance(row[0], str):
-                    return False
-                return row[0] == table
+    return True
 
-        return True
+def drop_tables(conn: Connection, *names: str,
+               if_exists: bool = True, cascade: bool = False) -> None:
+    """ Drop one or more tables with the given names.
+        Set `if_exists` to False if a non-existent table should raise
+        an exception instead of just being ignored. `cascade` will cause
+        depended objects to be dropped as well.
+        The caller needs to take care of committing the change.
+    """
+    sql = pysql.SQL('DROP TABLE%s{}%s' % (
+                        ' IF EXISTS ' if if_exists else ' ',
+                        ' CASCADE' if cascade else ''))
 
+    with conn.cursor() as cur:
+        for name in names:
+            cur.execute(sql.format(pysql.Identifier(name)))
 
-    def drop_table(self, name: str, if_exists: bool = True, cascade: bool = False) -> None:
-        """ Drop the table with the given name.
-            Set `if_exists` to False if a non-existent table should raise
-            an exception instead of just being ignored.
-        """
-        with self.cursor() as cur:
-            cur.drop_table(name, if_exists, cascade)
-        self.commit()
 
+def server_version_tuple(conn: Connection) -> Tuple[int, int]:
+    """ Return the server version as a tuple of (major, minor).
+        Converts correctly for pre-10 and post-10 PostgreSQL versions.
+    """
+    version = conn.server_version
+    if version < 100000:
+        return (int(version / 10000), int((version % 10000) / 100))
 
-    def server_version_tuple(self) -> Tuple[int, int]:
-        """ Return the server version as a tuple of (major, minor).
-            Converts correctly for pre-10 and post-10 PostgreSQL versions.
-        """
-        version = self.server_version
-        if version < 100000:
-            return (int(version / 10000), int((version % 10000) / 100))
+    return (int(version / 10000), version % 10000)
 
-        return (int(version / 10000), version % 10000)
 
+def postgis_version_tuple(conn: Connection) -> Tuple[int, int]:
+    """ Return the postgis version installed in the database as a
+        tuple of (major, minor). Assumes that the PostGIS extension
+        has been installed already.
+    """
+    version = execute_scalar(conn, 'SELECT postgis_lib_version()')
 
-    def postgis_version_tuple(self) -> Tuple[int, int]:
-        """ Return the postgis version installed in the database as a
-            tuple of (major, minor). Assumes that the PostGIS extension
-            has been installed already.
-        """
-        with self.cursor() as cur:
-            version = cur.scalar('SELECT postgis_lib_version()')
+    version_parts = version.split('.')
+    if len(version_parts) < 2:
+        raise UsageError(f"Error fetching Postgis version. Bad format: {version}")
 
-        version_parts = version.split('.')
-        if len(version_parts) < 2:
-            raise UsageError(f"Error fetching Postgis version. Bad format: {version}")
+    return (int(version_parts[0]), int(version_parts[1]))
 
-        return (int(version_parts[0]), int(version_parts[1]))
+def register_hstore(conn: Connection) -> None:
+    """ Register the hstore type with psycopg for the connection.
+    """
+    psycopg2.extras.register_hstore(conn)
 
 
 class ConnectionContext(ContextManager[Connection]):
index 3549382f944ad9e7f998b08190d2190b9c124d97..0e017ead900c0c2f6fa236e65d3ae927498e059b 100644 (file)
@@ -9,7 +9,7 @@ Query and access functions for the in-database property table.
 """
 from typing import Optional, cast
 
-from .connection import Connection
+from .connection import Connection, table_exists
 
 def set_property(conn: Connection, name: str, value: str) -> None:
     """ Add or replace the property with the given name.
@@ -31,7 +31,7 @@ def get_property(conn: Connection, name: str) -> Optional[str]:
     """ Return the current value of the given property or None if the property
         is not set.
     """
-    if not conn.table_exists('nominatim_properties'):
+    if not table_exists(conn, 'nominatim_properties'):
         return None
 
     with conn.cursor() as cur:
index 468f35107b8f5188a3cbc9b93c9fb3353ef5e322..691ab6c52cff562aafa8307ecea9d368ffa10c44 100644 (file)
@@ -10,7 +10,7 @@ Preprocessing of SQL files.
 from typing import Set, Dict, Any, cast
 import jinja2
 
-from .connection import Connection
+from .connection import Connection, server_version_tuple, postgis_version_tuple
 from .async_connection import WorkerPool
 from ..config import Configuration
 
@@ -66,8 +66,8 @@ def _setup_postgresql_features(conn: Connection) -> Dict[str, Any]:
     """ Set up a dictionary with various optional Postgresql/Postgis features that
         depend on the database version.
     """
-    pg_version = conn.server_version_tuple()
-    postgis_version = conn.postgis_version_tuple()
+    pg_version = server_version_tuple(conn)
+    postgis_version = postgis_version_tuple(conn)
     pg11plus = pg_version >= (11, 0, 0)
     ps3 = postgis_version >= (3, 0)
     return {
index 1278359cc02ebae2ac1513162a0b49a0ef37d211..1d2b3bec5e6bdebeab68d7c1d266cdf06f9257e1 100644 (file)
@@ -12,7 +12,7 @@ import datetime as dt
 import logging
 import re
 
-from .connection import Connection
+from .connection import Connection, table_exists, execute_scalar
 from ..utils.url_utils import get_url
 from ..errors import UsageError
 from ..typing import TypedDict
@@ -34,7 +34,7 @@ def compute_database_date(conn: Connection, offline: bool = False) -> dt.datetim
         data base.
     """
     # If there is a date from osm2pgsql available, use that.
-    if conn.table_exists('osm2pgsql_properties'):
+    if table_exists(conn, 'osm2pgsql_properties'):
         with conn.cursor() as cur:
             cur.execute(""" SELECT value FROM osm2pgsql_properties
                             WHERE property = 'current_timestamp' """)
@@ -47,15 +47,14 @@ def compute_database_date(conn: Connection, offline: bool = False) -> dt.datetim
         raise UsageError("Cannot determine database date from data in offline mode.")
 
     # Else, find the node with the highest ID in the database
-    with conn.cursor() as cur:
-        if conn.table_exists('place'):
-            osmid = cur.scalar("SELECT max(osm_id) FROM place WHERE osm_type='N'")
-        else:
-            osmid = cur.scalar("SELECT max(osm_id) FROM placex WHERE osm_type='N'")
-
-        if osmid is None:
-            LOG.fatal("No data found in the database.")
-            raise UsageError("No data found in the database.")
+    if table_exists(conn, 'place'):
+        osmid = execute_scalar(conn, "SELECT max(osm_id) FROM place WHERE osm_type='N'")
+    else:
+        osmid = execute_scalar(conn, "SELECT max(osm_id) FROM placex WHERE osm_type='N'")
+
+    if osmid is None:
+        LOG.fatal("No data found in the database.")
+        raise UsageError("No data found in the database.")
 
     LOG.info("Using node id %d for timestamp lookup", osmid)
     # Get the node from the API to find the timestamp when it was created.
index 5a219f6b2ecc3904b1129c9b4f9cff92c12132a5..b4c9732c624231df40215eb9e6ba543554c332e4 100644 (file)
@@ -15,7 +15,7 @@ import psycopg2.extras
 
 from ..typing import DictCursorResults
 from ..db.async_connection import DBConnection, WorkerPool
-from ..db.connection import connect, Connection, Cursor
+from ..db.connection import connect, Connection, Cursor, execute_scalar, register_hstore
 from ..tokenizer.base import AbstractTokenizer
 from .progress import ProgressLogger
 from . import runners
@@ -32,15 +32,15 @@ class PlaceFetcher:
         self.conn: Optional[DBConnection] = DBConnection(dsn,
                                                cursor_factory=psycopg2.extras.DictCursor)
 
-        with setup_conn.cursor() as cur:
-            # need to fetch those manually because register_hstore cannot
-            # fetch them on an asynchronous connection below.
-            hstore_oid = cur.scalar("SELECT 'hstore'::regtype::oid")
-            hstore_array_oid = cur.scalar("SELECT 'hstore[]'::regtype::oid")
+        # need to fetch those manually because register_hstore cannot
+        # fetch them on an asynchronous connection below.
+        hstore_oid = execute_scalar(setup_conn, "SELECT 'hstore'::regtype::oid")
+        hstore_array_oid = execute_scalar(setup_conn, "SELECT 'hstore[]'::regtype::oid")
 
         psycopg2.extras.register_hstore(self.conn.conn, oid=hstore_oid,
                                         array_oid=hstore_array_oid)
 
+
     def close(self) -> None:
         """ Close the underlying asynchronous connection.
         """
@@ -205,10 +205,9 @@ class Indexer:
         LOG.warning("Starting %s (using batch size %s)", runner.name(), batch)
 
         with connect(self.dsn) as conn:
-            psycopg2.extras.register_hstore(conn)
-            with conn.cursor() as cur:
-                total_tuples = cur.scalar(runner.sql_count_objects())
-                LOG.debug("Total number of rows: %i", total_tuples)
+            register_hstore(conn)
+            total_tuples = execute_scalar(conn, runner.sql_count_objects())
+            LOG.debug("Total number of rows: %i", total_tuples)
 
             conn.commit()
 
index 22e2d048291bcaa6d566e0f5fd3cf92fe9fcf870..70c5c27a096dd1d6c8764370fa7ec6d504a4e99d 100644 (file)
@@ -16,7 +16,8 @@ import logging
 from pathlib import Path
 from textwrap import dedent
 
-from ..db.connection import connect, Connection, Cursor
+from ..db.connection import connect, Connection, Cursor, server_version_tuple,\
+                            drop_tables, table_exists, execute_scalar
 from ..config import Configuration
 from ..db.utils import CopyBuffer
 from ..db.sql_preprocessor import SQLPreprocessor
@@ -108,7 +109,7 @@ class ICUTokenizer(AbstractTokenizer):
         """ Recompute frequencies for all name words.
         """
         with connect(self.dsn) as conn:
-            if not conn.table_exists('search_name'):
+            if not table_exists(conn, 'search_name'):
                 return
 
             with conn.cursor() as cur:
@@ -117,10 +118,9 @@ class ICUTokenizer(AbstractTokenizer):
                     cur.execute('SET max_parallel_workers_per_gather TO %s',
                                 (min(threads, 6),))
 
-                if conn.server_version_tuple() < (12, 0):
+                if server_version_tuple(conn) < (12, 0):
                     LOG.info('Computing word frequencies')
-                    cur.drop_table('word_frequencies')
-                    cur.drop_table('addressword_frequencies')
+                    drop_tables(conn, 'word_frequencies', 'addressword_frequencies')
                     cur.execute("""CREATE TEMP TABLE word_frequencies AS
                                      SELECT unnest(name_vector) as id, count(*)
                                      FROM search_name GROUP BY id""")
@@ -152,17 +152,16 @@ class ICUTokenizer(AbstractTokenizer):
                                    $$ LANGUAGE plpgsql IMMUTABLE;
                                 """)
                     LOG.info('Update word table with recomputed frequencies')
-                    cur.drop_table('tmp_word')
+                    drop_tables(conn, 'tmp_word')
                     cur.execute("""CREATE TABLE tmp_word AS
                                     SELECT word_id, word_token, type, word,
                                            word_freq_update(word_id, info) as info
                                     FROM word
                                 """)
-                    cur.drop_table('word_frequencies')
-                    cur.drop_table('addressword_frequencies')
+                    drop_tables(conn, 'word_frequencies', 'addressword_frequencies')
                 else:
                     LOG.info('Computing word frequencies')
-                    cur.drop_table('word_frequencies')
+                    drop_tables(conn, 'word_frequencies')
                     cur.execute("""
                       CREATE TEMP TABLE word_frequencies AS
                       WITH word_freq AS MATERIALIZED (
@@ -182,7 +181,7 @@ class ICUTokenizer(AbstractTokenizer):
                     cur.execute('CREATE UNIQUE INDEX ON word_frequencies(id) INCLUDE(info)')
                     cur.execute('ANALYSE word_frequencies')
                     LOG.info('Update word table with recomputed frequencies')
-                    cur.drop_table('tmp_word')
+                    drop_tables(conn, 'tmp_word')
                     cur.execute("""CREATE TABLE tmp_word AS
                                     SELECT word_id, word_token, type, word,
                                            (CASE WHEN wf.info is null THEN word.info
@@ -191,7 +190,7 @@ class ICUTokenizer(AbstractTokenizer):
                                     FROM word LEFT JOIN word_frequencies wf
                                          ON word.word_id = wf.id
                                 """)
-                    cur.drop_table('word_frequencies')
+                    drop_tables(conn, 'word_frequencies')
 
             with conn.cursor() as cur:
                 cur.execute('SET max_parallel_workers_per_gather TO 0')
@@ -210,7 +209,7 @@ class ICUTokenizer(AbstractTokenizer):
         """ Remove unused house numbers.
         """
         with connect(self.dsn) as conn:
-            if not conn.table_exists('search_name'):
+            if not table_exists(conn, 'search_name'):
                 return
             with conn.cursor(name="hnr_counter") as cur:
                 cur.execute("""SELECT DISTINCT word_id, coalesce(info->>'lookup', word_token)
@@ -311,8 +310,7 @@ class ICUTokenizer(AbstractTokenizer):
             frequencies.
         """
         with connect(self.dsn) as conn:
-            with conn.cursor() as cur:
-                cur.drop_table('word')
+            drop_tables(conn, 'word')
             sqlp = SQLPreprocessor(conn, config)
             sqlp.run_string(conn, """
                 CREATE TABLE word (
@@ -370,8 +368,8 @@ class ICUTokenizer(AbstractTokenizer):
         """ Rename all tables and indexes used by the tokenizer.
         """
         with connect(self.dsn) as conn:
+            drop_tables(conn, 'word')
             with conn.cursor() as cur:
-                cur.drop_table('word')
                 cur.execute(f"ALTER TABLE {old} RENAME TO word")
                 for idx in ('word_token', 'word_id'):
                     cur.execute(f"""ALTER INDEX idx_{old}_{idx}
@@ -733,11 +731,10 @@ class ICUNameAnalyzer(AbstractAnalyzer):
             if norm_name:
                 result = self._cache.housenumbers.get(norm_name, result)
                 if result[0] is None:
-                    with self.conn.cursor() as cur:
-                        hid = cur.scalar("SELECT getorcreate_hnr_id(%s)", (norm_name, ))
+                    hid = execute_scalar(self.conn, "SELECT getorcreate_hnr_id(%s)", (norm_name, ))
 
-                        result = hid, norm_name
-                        self._cache.housenumbers[norm_name] = result
+                    result = hid, norm_name
+                    self._cache.housenumbers[norm_name] = result
         else:
             # Otherwise use the analyzer to determine the canonical name.
             # Per convention we use the first variant as the 'lookup name', the
@@ -748,11 +745,10 @@ class ICUNameAnalyzer(AbstractAnalyzer):
                 if result[0] is None:
                     variants = analyzer.compute_variants(word_id)
                     if variants:
-                        with self.conn.cursor() as cur:
-                            hid = cur.scalar("SELECT create_analyzed_hnr_id(%s, %s)",
+                        hid = execute_scalar(self.conn, "SELECT create_analyzed_hnr_id(%s, %s)",
                                              (word_id, list(variants)))
-                            result = hid, variants[0]
-                            self._cache.housenumbers[word_id] = result
+                        result = hid, variants[0]
+                        self._cache.housenumbers[word_id] = result
 
         return result
 
index 136a733159cd3180e2e0558e48b97141887165c1..0e8dfcf97fc07b51216e88862b13f6018043f082 100644 (file)
@@ -18,10 +18,10 @@ from textwrap import dedent
 
 from icu import Transliterator
 import psycopg2
-import psycopg2.extras
 
 from ..errors import UsageError
-from ..db.connection import connect, Connection
+from ..db.connection import connect, Connection, drop_tables, table_exists,\
+                            execute_scalar, register_hstore
 from ..config import Configuration
 from ..db import properties
 from ..db import utils as db_utils
@@ -179,11 +179,10 @@ class LegacyTokenizer(AbstractTokenizer):
              * Can nominatim.so be accessed by the database user?
              """
         with connect(self.dsn) as conn:
-            with conn.cursor() as cur:
-                try:
-                    out = cur.scalar("SELECT make_standard_name('a')")
-                except psycopg2.Error as err:
-                    return hint.format(error=str(err))
+            try:
+                out = execute_scalar(conn, "SELECT make_standard_name('a')")
+            except psycopg2.Error as err:
+                return hint.format(error=str(err))
 
         if out != 'a':
             return hint.format(error='Unexpected result for make_standard_name()')
@@ -214,9 +213,9 @@ class LegacyTokenizer(AbstractTokenizer):
         """ Recompute the frequency of full words.
         """
         with connect(self.dsn) as conn:
-            if conn.table_exists('search_name'):
+            if table_exists(conn, 'search_name'):
+                drop_tables(conn, "word_frequencies")
                 with conn.cursor() as cur:
-                    cur.drop_table("word_frequencies")
                     LOG.info("Computing word frequencies")
                     cur.execute("""CREATE TEMP TABLE word_frequencies AS
                                      SELECT unnest(name_vector) as id, count(*)
@@ -226,7 +225,7 @@ class LegacyTokenizer(AbstractTokenizer):
                     cur.execute("""UPDATE word SET search_name_count = count
                                    FROM word_frequencies
                                    WHERE word_token like ' %' and word_id = id""")
-                    cur.drop_table("word_frequencies")
+                drop_tables(conn, "word_frequencies")
             conn.commit()
 
 
@@ -316,7 +315,7 @@ class LegacyNameAnalyzer(AbstractAnalyzer):
         self.conn: Optional[Connection] = connect(dsn).connection
         self.conn.autocommit = True
         self.normalizer = normalizer
-        psycopg2.extras.register_hstore(self.conn)
+        register_hstore(self.conn)
 
         self._cache = _TokenCache(self.conn)
 
@@ -536,9 +535,8 @@ class _TokenInfo:
     def add_names(self, conn: Connection, names: Mapping[str, str]) -> None:
         """ Add token information for the names of the place.
         """
-        with conn.cursor() as cur:
-            # Create the token IDs for all names.
-            self.data['names'] = cur.scalar("SELECT make_keywords(%s)::text",
+        # Create the token IDs for all names.
+        self.data['names'] = execute_scalar(conn, "SELECT make_keywords(%s)::text",
                                             (names, ))
 
 
@@ -576,9 +574,8 @@ class _TokenInfo:
         """ Add addr:street match terms.
         """
         def _get_street(name: str) -> Optional[str]:
-            with conn.cursor() as cur:
-                return cast(Optional[str],
-                            cur.scalar("SELECT word_ids_from_name(%s)::text", (name, )))
+            return cast(Optional[str],
+                        execute_scalar(conn, "SELECT word_ids_from_name(%s)::text", (name, )))
 
         tokens = self.cache.streets.get(street, _get_street)
         self.data['street'] = tokens or '{}'
index cea2ad664a7e339c2c66dd23790ab517dcd7b338..3e502199aa287cc6c3834f9c485c2aca7bebe91a 100644 (file)
@@ -10,12 +10,12 @@ Functions for database analysis and maintenance.
 from typing import Optional, Tuple, Any, cast
 import logging
 
-from psycopg2.extras import Json, register_hstore
+from psycopg2.extras import Json
 from psycopg2 import DataError
 
 from ..typing import DictCursorResult
 from ..config import Configuration
-from ..db.connection import connect, Cursor
+from ..db.connection import connect, Cursor, register_hstore
 from ..errors import UsageError
 from ..tokenizer import factory as tokenizer_factory
 from ..data.place_info import PlaceInfo
index ef28a0e5a61653b6a2f2c250137050130caff084..946f929138ca6453083d3eabf899a24d025b67ed 100644 (file)
@@ -12,7 +12,8 @@ from enum import Enum
 from textwrap import dedent
 
 from ..config import Configuration
-from ..db.connection import connect, Connection
+from ..db.connection import connect, Connection, server_version_tuple,\
+                            index_exists, table_exists, execute_scalar
 from ..db import properties
 from ..errors import UsageError
 from ..tokenizer import factory as tokenizer_factory
@@ -109,14 +110,14 @@ def _get_indexes(conn: Connection) -> List[str]:
                'idx_postcode_id',
                'idx_postcode_postcode'
               ]
-    if conn.table_exists('search_name'):
+    if table_exists(conn, 'search_name'):
         indexes.extend(('idx_search_name_nameaddress_vector',
                         'idx_search_name_name_vector',
                         'idx_search_name_centroid'))
-        if conn.server_version_tuple() >= (11, 0, 0):
+        if server_version_tuple(conn) >= (11, 0, 0):
             indexes.extend(('idx_placex_housenumber',
                             'idx_osmline_parent_osm_id_with_hnr'))
-    if conn.table_exists('place'):
+    if table_exists(conn, 'place'):
         indexes.extend(('idx_location_area_country_place_id',
                         'idx_place_osm_unique',
                         'idx_placex_rank_address_sector',
@@ -153,7 +154,7 @@ def check_connection(conn: Any, config: Configuration) -> CheckResult:
 
              Hints:
              * Are you connecting to the correct database?
-             
+
              {instruction}
 
              Check the Migration chapter of the Administration Guide.
@@ -165,7 +166,7 @@ def check_database_version(conn: Connection, config: Configuration) -> CheckResu
     """ Checking database_version matches Nominatim software version
     """
 
-    if conn.table_exists('nominatim_properties'):
+    if table_exists(conn, 'nominatim_properties'):
         db_version_str = properties.get_property(conn, 'database_version')
     else:
         db_version_str = None
@@ -202,7 +203,7 @@ def check_database_version(conn: Connection, config: Configuration) -> CheckResu
 def check_placex_table(conn: Connection, config: Configuration) -> CheckResult:
     """ Checking for placex table
     """
-    if conn.table_exists('placex'):
+    if table_exists(conn, 'placex'):
         return CheckState.OK
 
     return CheckState.FATAL, dict(config=config)
@@ -212,8 +213,7 @@ def check_placex_table(conn: Connection, config: Configuration) -> CheckResult:
 def check_placex_size(conn: Connection, _: Configuration) -> CheckResult:
     """ Checking for placex content
     """
-    with conn.cursor() as cur:
-        cnt = cur.scalar('SELECT count(*) FROM (SELECT * FROM placex LIMIT 100) x')
+    cnt = execute_scalar(conn, 'SELECT count(*) FROM (SELECT * FROM placex LIMIT 100) x')
 
     return CheckState.OK if cnt > 0 else CheckState.FATAL
 
@@ -244,16 +244,15 @@ def check_tokenizer(_: Connection, config: Configuration) -> CheckResult:
 def check_existance_wikipedia(conn: Connection, _: Configuration) -> CheckResult:
     """ Checking for wikipedia/wikidata data
     """
-    if not conn.table_exists('search_name') or not conn.table_exists('place'):
+    if not table_exists(conn, 'search_name') or not table_exists(conn, 'place'):
         return CheckState.NOT_APPLICABLE
 
-    with conn.cursor() as cur:
-        if conn.table_exists('wikimedia_importance'):
-            cnt = cur.scalar('SELECT count(*) FROM wikimedia_importance')
-        else:
-            cnt = cur.scalar('SELECT count(*) FROM wikipedia_article')
+    if table_exists(conn, 'wikimedia_importance'):
+        cnt = execute_scalar(conn, 'SELECT count(*) FROM wikimedia_importance')
+    else:
+        cnt = execute_scalar(conn, 'SELECT count(*) FROM wikipedia_article')
 
-        return CheckState.WARN if cnt == 0 else CheckState.OK
+    return CheckState.WARN if cnt == 0 else CheckState.OK
 
 
 @_check(hint="""\
@@ -264,8 +263,7 @@ def check_existance_wikipedia(conn: Connection, _: Configuration) -> CheckResult
 def check_indexing(conn: Connection, _: Configuration) -> CheckResult:
     """ Checking indexing status
     """
-    with conn.cursor() as cur:
-        cnt = cur.scalar('SELECT count(*) FROM placex WHERE indexed_status > 0')
+    cnt = execute_scalar(conn, 'SELECT count(*) FROM placex WHERE indexed_status > 0')
 
     if cnt == 0:
         return CheckState.OK
@@ -276,7 +274,7 @@ def check_indexing(conn: Connection, _: Configuration) -> CheckResult:
             Low counts of unindexed places are fine."""
         return CheckState.WARN, dict(count=cnt, index_cmd=index_cmd)
 
-    if conn.index_exists('idx_placex_rank_search'):
+    if index_exists(conn, 'idx_placex_rank_search'):
         # Likely just an interrupted update.
         index_cmd = 'nominatim index'
     else:
@@ -297,7 +295,7 @@ def check_database_indexes(conn: Connection, _: Configuration) -> CheckResult:
     """
     missing = []
     for index in _get_indexes(conn):
-        if not conn.index_exists(index):
+        if not index_exists(conn, index):
             missing.append(index)
 
     if missing:
@@ -340,11 +338,10 @@ def check_tiger_table(conn: Connection, config: Configuration) -> CheckResult:
     if not config.get_bool('USE_US_TIGER_DATA'):
         return CheckState.NOT_APPLICABLE
 
-    if not conn.table_exists('location_property_tiger'):
+    if not table_exists(conn, 'location_property_tiger'):
         return CheckState.FAIL, dict(error='TIGER data table not found.')
 
-    with conn.cursor() as cur:
-        if cur.scalar('SELECT count(*) FROM location_property_tiger') == 0:
-            return CheckState.FAIL, dict(error='TIGER data table is empty.')
+    if execute_scalar(conn, 'SELECT count(*) FROM location_property_tiger') == 0:
+        return CheckState.FAIL, dict(error='TIGER data table is empty.')
 
     return CheckState.OK
index e1f8b16637a4cd02de87fb582f8932b8695cd1b1..db3e773d8559470ce610fe2acfa286bb7bfd1b8f 100644 (file)
@@ -12,21 +12,16 @@ import os
 import subprocess
 import sys
 from pathlib import Path
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Union
 
 import psutil
-from psycopg2.extensions import make_dsn, parse_dsn
+from psycopg2.extensions import make_dsn
 
 from ..config import Configuration
-from ..db.connection import connect
+from ..db.connection import connect, server_version_tuple, execute_scalar
 from ..version import NOMINATIM_VERSION
 
 
-def convert_version(ver_tup: Tuple[int, int]) -> str:
-    """converts tuple version (ver_tup) to a string representation"""
-    return ".".join(map(str, ver_tup))
-
-
 def friendly_memory_string(mem: float) -> str:
     """Create a user friendly string for the amount of memory specified as mem"""
     mem_magnitude = ("bytes", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB")
@@ -103,16 +98,16 @@ def report_system_information(config: Configuration) -> None:
     storage, and database configuration."""
 
     with connect(make_dsn(config.get_libpq_dsn(), dbname='postgres')) as conn:
-        postgresql_ver: str = convert_version(conn.server_version_tuple())
+        postgresql_ver: str = '.'.join(map(str, server_version_tuple(conn)))
 
         with conn.cursor() as cur:
-            num = cur.scalar("SELECT count(*) FROM pg_catalog.pg_database WHERE datname=%s",
-                             (parse_dsn(config.get_libpq_dsn())['dbname'], ))
-            nominatim_db_exists = num == 1 if isinstance(num, int) else False
+            cur.execute("SELECT datname FROM pg_catalog.pg_database WHERE datname=%s",
+                        (config.get_database_params()['dbname'], ))
+            nominatim_db_exists = cur.rowcount > 0
 
     if nominatim_db_exists:
         with connect(config.get_libpq_dsn()) as conn:
-            postgis_ver: str = convert_version(conn.postgis_version_tuple())
+            postgis_ver: str = execute_scalar(conn, 'SELECT postgis_lib_version()')
     else:
         postgis_ver = "Unable to connect to database"
 
index c4b3023a8585d57d28ffee15f06d1c3ddc4f60f5..2398d4043a734ceb6b3cf45971e998133fe6456a 100644 (file)
@@ -19,7 +19,8 @@ from psycopg2 import sql as pysql
 
 from ..errors import UsageError
 from ..config import Configuration
-from ..db.connection import connect, get_pg_env, Connection
+from ..db.connection import connect, get_pg_env, Connection, server_version_tuple,\
+                            postgis_version_tuple, drop_tables, table_exists, execute_scalar
 from ..db.async_connection import DBConnection
 from ..db.sql_preprocessor import SQLPreprocessor
 from .exec_utils import run_osm2pgsql
@@ -51,10 +52,10 @@ def check_existing_database_plugins(dsn: str) -> None:
     """ Check that the database has the required plugins installed."""
     with connect(dsn) as conn:
         _require_version('PostgreSQL server',
-                         conn.server_version_tuple(),
+                         server_version_tuple(conn),
                          POSTGRESQL_REQUIRED_VERSION)
         _require_version('PostGIS',
-                         conn.postgis_version_tuple(),
+                         postgis_version_tuple(conn),
                          POSTGIS_REQUIRED_VERSION)
         _require_loaded('hstore', conn)
 
@@ -80,31 +81,30 @@ def setup_database_skeleton(dsn: str, rouser: Optional[str] = None) -> None:
 
     with connect(dsn) as conn:
         _require_version('PostgreSQL server',
-                         conn.server_version_tuple(),
+                         server_version_tuple(conn),
                          POSTGRESQL_REQUIRED_VERSION)
 
         if rouser is not None:
-            with conn.cursor() as cur:
-                cnt = cur.scalar('SELECT count(*) FROM pg_user where usename = %s',
+            cnt = execute_scalar(conn, 'SELECT count(*) FROM pg_user where usename = %s',
                                  (rouser, ))
-                if cnt == 0:
-                    LOG.fatal("Web user '%s' does not exist. Create it with:\n"
-                              "\n      createuser %s", rouser, rouser)
-                    raise UsageError('Missing read-only user.')
+            if cnt == 0:
+                LOG.fatal("Web user '%s' does not exist. Create it with:\n"
+                          "\n      createuser %s", rouser, rouser)
+                raise UsageError('Missing read-only user.')
 
         # Create extensions.
         with conn.cursor() as cur:
             cur.execute('CREATE EXTENSION IF NOT EXISTS hstore')
             cur.execute('CREATE EXTENSION IF NOT EXISTS postgis')
 
-            postgis_version = conn.postgis_version_tuple()
+            postgis_version = postgis_version_tuple(conn)
             if postgis_version[0] >= 3:
                 cur.execute('CREATE EXTENSION IF NOT EXISTS postgis_raster')
 
         conn.commit()
 
         _require_version('PostGIS',
-                         conn.postgis_version_tuple(),
+                         postgis_version_tuple(conn),
                          POSTGIS_REQUIRED_VERSION)
 
 
@@ -141,7 +141,8 @@ def import_osm_data(osm_files: Union[Path, Sequence[Path]],
                     raise UsageError('No data imported by osm2pgsql.')
 
         if drop:
-            conn.drop_table('planet_osm_nodes')
+            drop_tables(conn, 'planet_osm_nodes')
+            conn.commit()
 
     if drop and options['flatnode_file']:
         Path(options['flatnode_file']).unlink()
@@ -184,7 +185,7 @@ def truncate_data_tables(conn: Connection) -> None:
         cur.execute('TRUNCATE location_property_tiger')
         cur.execute('TRUNCATE location_property_osmline')
         cur.execute('TRUNCATE location_postcode')
-        if conn.table_exists('search_name'):
+        if table_exists(conn, 'search_name'):
             cur.execute('TRUNCATE search_name')
         cur.execute('DROP SEQUENCE IF EXISTS seq_place')
         cur.execute('CREATE SEQUENCE seq_place start 100000')
index bd52ba9a5a7ad194c353a721d1529ea4cefa20cb..e6d80e1e156c3342feaabe747ec4183dcef02cf6 100644 (file)
@@ -12,7 +12,7 @@ from pathlib import Path
 
 from psycopg2 import sql as pysql
 
-from ..db.connection import Connection
+from ..db.connection import Connection, drop_tables, table_exists
 
 UPDATE_TABLES = [
     'address_levels',
@@ -39,9 +39,7 @@ def drop_update_tables(conn: Connection) -> None:
                     + pysql.SQL(' or ').join(parts))
         tables = [r[0] for r in cur]
 
-        for table in tables:
-            cur.drop_table(table, cascade=True)
-
+    drop_tables(conn, *tables, cascade=True)
     conn.commit()
 
 
@@ -55,4 +53,4 @@ def is_frozen(conn: Connection) -> bool:
     """ Returns true if database is in a frozen state
     """
 
-    return conn.table_exists('place') is False
+    return table_exists(conn, 'place') is False
index e6803c7dea23521a805aa939584d028f66b4a7b5..46ba0125bc90a73ecafed806032aa5af13af16f2 100644 (file)
@@ -15,7 +15,8 @@ from psycopg2 import sql as pysql
 from ..errors import UsageError
 from ..config import Configuration
 from ..db import properties
-from ..db.connection import connect, Connection
+from ..db.connection import connect, Connection, server_version_tuple,\
+                            table_has_column, table_exists, execute_scalar, register_hstore
 from ..version import NominatimVersion, NOMINATIM_VERSION, parse_version
 from ..tokenizer import factory as tokenizer_factory
 from . import refresh
@@ -29,7 +30,8 @@ def migrate(config: Configuration, paths: Any) -> int:
         if necesssary.
     """
     with connect(config.get_libpq_dsn()) as conn:
-        if conn.table_exists('nominatim_properties'):
+        register_hstore(conn)
+        if table_exists(conn, 'nominatim_properties'):
             db_version_str = properties.get_property(conn, 'database_version')
         else:
             db_version_str = None
@@ -72,16 +74,15 @@ def _guess_version(conn: Connection) -> NominatimVersion:
         Only migrations for 3.6 and later are supported, so bail out
         when the version seems older.
     """
-    with conn.cursor() as cur:
-        # In version 3.6, the country_name table was updated. Check for that.
-        cnt = cur.scalar("""SELECT count(*) FROM
-                            (SELECT svals(name) FROM  country_name
-                             WHERE country_code = 'gb')x;
-                         """)
-        if cnt < 100:
-            LOG.fatal('It looks like your database was imported with a version '
-                      'prior to 3.6.0. Automatic migration not possible.')
-            raise UsageError('Migration not possible.')
+    # In version 3.6, the country_name table was updated. Check for that.
+    cnt = execute_scalar(conn, """SELECT count(*) FROM
+                                  (SELECT svals(name) FROM  country_name
+                                  WHERE country_code = 'gb')x;
+                               """)
+    if cnt < 100:
+        LOG.fatal('It looks like your database was imported with a version '
+                  'prior to 3.6.0. Automatic migration not possible.')
+        raise UsageError('Migration not possible.')
 
     return NominatimVersion(3, 5, 0, 99)
 
@@ -125,7 +126,7 @@ def import_status_timestamp_change(conn: Connection, **_: Any) -> None:
 def add_nominatim_property_table(conn: Connection, config: Configuration, **_: Any) -> None:
     """ Add nominatim_property table.
     """
-    if not conn.table_exists('nominatim_properties'):
+    if not table_exists(conn, 'nominatim_properties'):
         with conn.cursor() as cur:
             cur.execute(pysql.SQL("""CREATE TABLE nominatim_properties (
                                         property TEXT,
@@ -189,13 +190,9 @@ def install_legacy_tokenizer(conn: Connection, config: Configuration, **_: Any)
         configuration for the backwards-compatible legacy tokenizer
     """
     if properties.get_property(conn, 'tokenizer') is None:
-        with conn.cursor() as cur:
-            for table in ('placex', 'location_property_osmline'):
-                has_column = cur.scalar("""SELECT count(*) FROM information_schema.columns
-                                           WHERE table_name = %s
-                                           and column_name = 'token_info'""",
-                                        (table, ))
-                if has_column == 0:
+        for table in ('placex', 'location_property_osmline'):
+            if not table_has_column(conn, table, 'token_info'):
+                with conn.cursor() as cur:
                     cur.execute(pysql.SQL('ALTER TABLE {} ADD COLUMN token_info JSONB')
                                 .format(pysql.Identifier(table)))
         tokenizer = tokenizer_factory.create_tokenizer(config, init_db=False,
@@ -212,7 +209,7 @@ def create_tiger_housenumber_index(conn: Connection, **_: Any) -> None:
         The inclusion is needed for efficient lookup of housenumbers in
         full address searches.
     """
-    if conn.server_version_tuple() >= (11, 0, 0):
+    if server_version_tuple(conn) >= (11, 0, 0):
         with conn.cursor() as cur:
             cur.execute(""" CREATE INDEX IF NOT EXISTS
                                 idx_location_property_tiger_housenumber_migrated
@@ -239,7 +236,7 @@ def add_step_column_for_interpolation(conn: Connection, **_: Any) -> None:
         Also converts the data into the stricter format which requires that
         startnumbers comply with the odd/even requirements.
     """
-    if conn.table_has_column('location_property_osmline', 'step'):
+    if table_has_column(conn, 'location_property_osmline', 'step'):
         return
 
     with conn.cursor() as cur:
@@ -271,7 +268,7 @@ def add_step_column_for_interpolation(conn: Connection, **_: Any) -> None:
 def add_step_column_for_tiger(conn: Connection, **_: Any) -> None:
     """ Add a new column 'step' to the tiger data table.
     """
-    if conn.table_has_column('location_property_tiger', 'step'):
+    if table_has_column(conn, 'location_property_tiger', 'step'):
         return
 
     with conn.cursor() as cur:
@@ -287,7 +284,7 @@ def add_derived_name_column_for_country_names(conn: Connection, **_: Any) -> Non
     """ Add a new column 'derived_name' which in the future takes the
         country names as imported from OSM data.
     """
-    if not conn.table_has_column('country_name', 'derived_name'):
+    if not table_has_column(conn, 'country_name', 'derived_name'):
         with conn.cursor() as cur:
             cur.execute("ALTER TABLE country_name ADD COLUMN derived_name public.HSTORE")
 
@@ -297,12 +294,9 @@ def mark_internal_country_names(conn: Connection, config: Configuration, **_: An
     """ Names from the country table should be marked as internal to prevent
         them from being deleted. Only necessary for ICU tokenizer.
     """
-    import psycopg2.extras # pylint: disable=import-outside-toplevel
-
     tokenizer = tokenizer_factory.get_tokenizer_for_db(config)
     with tokenizer.name_analyzer() as analyzer:
         with conn.cursor() as cur:
-            psycopg2.extras.register_hstore(cur)
             cur.execute("SELECT country_code, name FROM country_name")
 
             for country_code, names in cur:
@@ -319,7 +313,7 @@ def add_place_deletion_todo_table(conn: Connection, **_: Any) -> None:
         The table is only necessary when updates are possible, i.e.
         the database is not in freeze mode.
     """
-    if conn.table_exists('place'):
+    if table_exists(conn, 'place'):
         with conn.cursor() as cur:
             cur.execute("""CREATE TABLE IF NOT EXISTS place_to_be_deleted (
                              osm_type CHAR(1),
@@ -333,7 +327,7 @@ def add_place_deletion_todo_table(conn: Connection, **_: Any) -> None:
 def split_pending_index(conn: Connection, **_: Any) -> None:
     """ Reorganise indexes for pending updates.
     """
-    if conn.table_exists('place'):
+    if table_exists(conn, 'place'):
         with conn.cursor() as cur:
             cur.execute("""CREATE INDEX IF NOT EXISTS idx_placex_rank_address_sector
                            ON placex USING BTREE (rank_address, geometry_sector)
@@ -349,7 +343,7 @@ def split_pending_index(conn: Connection, **_: Any) -> None:
 def enable_forward_dependencies(conn: Connection, **_: Any) -> None:
     """ Create indexes for updates with forward dependency tracking (long-running).
     """
-    if conn.table_exists('planet_osm_ways'):
+    if table_exists(conn, 'planet_osm_ways'):
         with conn.cursor() as cur:
             cur.execute("""SELECT * FROM pg_indexes
                            WHERE tablename = 'planet_osm_ways'
@@ -398,7 +392,7 @@ def create_postcode_area_lookup_index(conn: Connection, **_: Any) -> None:
 def create_postcode_parent_index(conn: Connection, **_: Any) -> None:
     """ Create index needed for updating postcodes when a parent changes.
     """
-    if conn.table_exists('planet_osm_ways'):
+    if table_exists(conn, 'planet_osm_ways'):
         with conn.cursor() as cur:
             cur.execute("""CREATE INDEX IF NOT EXISTS
                              idx_location_postcode_parent_place_id
index 8dc5bdbdb43cacb46fec70a86d553734a4ef92d7..a5d8ef8becf0ac53f562dbd45acc856cd4b8cfb9 100644 (file)
@@ -18,7 +18,7 @@ from math import isfinite
 
 from psycopg2 import sql as pysql
 
-from ..db.connection import connect, Connection
+from ..db.connection import connect, Connection, table_exists
 from ..utils.centroid import PointsCentroid
 from ..data.postcode_format import PostcodeFormatter, CountryPostcodeMatcher
 from ..tokenizer.base import AbstractAnalyzer, AbstractTokenizer
@@ -231,4 +231,4 @@ def can_compute(dsn: str) -> bool:
         postcodes can be computed.
     """
     with connect(dsn) as conn:
-        return conn.table_exists('place')
+        return table_exists(conn, 'place')
index 6a40c0a73a4c6427d11fe89b0db0552c19a1e7cf..2e2ffabd07c60ea959bf1cb9aeaf9efa4cfe2b21 100644 (file)
@@ -17,7 +17,8 @@ from pathlib import Path
 from psycopg2 import sql as pysql
 
 from ..config import Configuration
-from ..db.connection import Connection, connect
+from ..db.connection import Connection, connect, postgis_version_tuple,\
+                            drop_tables, table_exists
 from ..db.utils import execute_file, CopyBuffer
 from ..db.sql_preprocessor import SQLPreprocessor
 from ..version import NOMINATIM_VERSION
@@ -56,9 +57,9 @@ def load_address_levels(conn: Connection, table: str, levels: Sequence[Mapping[s
     for entry in levels:
         _add_address_level_rows_from_entry(rows, entry)
 
-    with conn.cursor() as cur:
-        cur.drop_table(table)
+    drop_tables(conn, table)
 
+    with conn.cursor() as cur:
         cur.execute(pysql.SQL("""CREATE TABLE {} (
                                         country_code varchar(2),
                                         class TEXT,
@@ -159,10 +160,8 @@ def import_importance_csv(dsn: str, data_file: Path) -> int:
     wd_done = set()
 
     with connect(dsn) as conn:
+        drop_tables(conn, 'wikipedia_article', 'wikipedia_redirect', 'wikimedia_importance')
         with conn.cursor() as cur:
-            cur.drop_table('wikipedia_article')
-            cur.drop_table('wikipedia_redirect')
-            cur.drop_table('wikimedia_importance')
             cur.execute("""CREATE TABLE wikimedia_importance (
                              language TEXT NOT NULL,
                              title TEXT NOT NULL,
@@ -228,7 +227,7 @@ def import_secondary_importance(dsn: str, data_path: Path, ignore_errors: bool =
         return 1
 
     with connect(dsn) as conn:
-        postgis_version = conn.postgis_version_tuple()
+        postgis_version = postgis_version_tuple(conn)
         if postgis_version[0] < 3:
             LOG.error('PostGIS version is too old for using OSM raster data.')
             return 2
@@ -309,7 +308,7 @@ def setup_website(basedir: Path, config: Configuration, conn: Connection) -> Non
 
     template = "\nrequire_once(CONST_LibDir.'/website/{}');\n"
 
-    search_name_table_exists = bool(conn and conn.table_exists('search_name'))
+    search_name_table_exists = bool(conn and table_exists(conn, 'search_name'))
 
     for script in WEBSITE_SCRIPTS:
         if not search_name_table_exists and script == 'search.php':
index bf1189df032a85a40a985c20e783edfd74410e29..2b1d444f0cb1e3cd08b6540273ac9b64d136e43f 100644 (file)
@@ -20,7 +20,7 @@ import requests
 
 from ..errors import UsageError
 from ..db import status
-from ..db.connection import Connection, connect
+from ..db.connection import Connection, connect, server_version_tuple
 from .exec_utils import run_osm2pgsql
 
 try:
@@ -155,7 +155,7 @@ def run_osm2pgsql_updates(conn: Connection, options: MutableMapping[str, Any]) -
 
     # Consume updates with osm2pgsql.
     options['append'] = True
-    options['disable_jit'] = conn.server_version_tuple() >= (11, 0)
+    options['disable_jit'] = server_version_tuple(conn) >= (11, 0)
     run_osm2pgsql(options)
 
     # Handle deletions
index 1bdcdaf133d5d3db10fdb18e23c2b3f6ea24abce..4a63ff149507641f1e3fa2597e38c6570d6d93d4 100644 (file)
@@ -21,7 +21,7 @@ from psycopg2.sql import Identifier, SQL
 
 from ...typing import Protocol
 from ...config import Configuration
-from ...db.connection import Connection
+from ...db.connection import Connection, drop_tables, index_exists
 from .importer_statistics import SpecialPhrasesImporterStatistics
 from .special_phrase import SpecialPhrase
 from ...tokenizer.base import AbstractTokenizer
@@ -233,7 +233,7 @@ class SPImporter():
         index_prefix = f'idx_place_classtype_{phrase_class}_{phrase_type}_'
         base_table = _classtype_table(phrase_class, phrase_type)
         # Index on centroid
-        if not self.db_connection.index_exists(index_prefix + 'centroid'):
+        if not index_exists(self.db_connection, index_prefix + 'centroid'):
             with self.db_connection.cursor() as db_cursor:
                 db_cursor.execute(SQL("CREATE INDEX {} ON {} USING GIST (centroid) {}")
                                   .format(Identifier(index_prefix + 'centroid'),
@@ -241,7 +241,7 @@ class SPImporter():
                                           SQL(sql_tablespace)))
 
         # Index on place_id
-        if not self.db_connection.index_exists(index_prefix + 'place_id'):
+        if not index_exists(self.db_connection, index_prefix + 'place_id'):
             with self.db_connection.cursor() as db_cursor:
                 db_cursor.execute(SQL("CREATE INDEX {} ON {} USING btree(place_id) {}")
                                   .format(Identifier(index_prefix + 'place_id'),
@@ -259,6 +259,7 @@ class SPImporter():
                               .format(Identifier(table_name),
                                       Identifier(self.config.DATABASE_WEBUSER)))
 
+
     def _remove_non_existent_tables_from_db(self) -> None:
         """
             Remove special phrases which doesn't exist on the wiki anymore.
@@ -268,7 +269,6 @@ class SPImporter():
 
         # Delete place_classtype tables corresponding to class/type which
         # are not on the wiki anymore.
-        with self.db_connection.cursor() as db_cursor:
-            for table in self.table_phrases_to_delete:
-                self.statistics_handler.notify_one_table_deleted()
-                db_cursor.drop_table(table)
+        drop_tables(self.db_connection, *self.table_phrases_to_delete)
+        for _ in self.table_phrases_to_delete:
+            self.statistics_handler.notify_one_table_deleted()
index fff695e53ebc97520ba9893f6fff470846bf7269..9647bedc802471bd41808d297f0698a29bafda50 100644 (file)
@@ -33,13 +33,13 @@ def simple_conns(temp_db):
     conn2.close()
 
 
-def test_simple_query(conn, temp_db_conn):
+def test_simple_query(conn, temp_db_cursor):
     conn.connect()
 
     conn.perform('CREATE TABLE foo (id INT)')
     conn.wait()
 
-    temp_db_conn.table_exists('foo')
+    assert temp_db_cursor.table_exists('foo')
 
 
 def test_wait_for_query(conn):
index 8b4cc62f0c46e6f8bd46db49bc5d3f3ab58cae58..9f1442f3aba6bcdbf523c618cd9184434cf99522 100644 (file)
@@ -10,61 +10,74 @@ Tests for specialised connection and cursor classes.
 import pytest
 import psycopg2
 
-from nominatim_db.db.connection import connect, get_pg_env
+import nominatim_db.db.connection as nc
 
 @pytest.fixture
 def db(dsn):
-    with connect(dsn) as conn:
+    with nc.connect(dsn) as conn:
         yield conn
 
 
 def test_connection_table_exists(db, table_factory):
-    assert not db.table_exists('foobar')
+    assert not nc.table_exists(db, 'foobar')
 
     table_factory('foobar')
 
-    assert db.table_exists('foobar')
+    assert nc.table_exists(db, 'foobar')
 
 
 def test_has_column_no_table(db):
-    assert not db.table_has_column('sometable', 'somecolumn')
+    assert not nc.table_has_column(db, 'sometable', 'somecolumn')
 
 
 @pytest.mark.parametrize('name,result', [('tram', True), ('car', False)])
 def test_has_column(db, table_factory, name, result):
     table_factory('stuff', 'tram TEXT')
 
-    assert db.table_has_column('stuff', name) == result
+    assert nc.table_has_column(db, 'stuff', name) == result
 
 def test_connection_index_exists(db, table_factory, temp_db_cursor):
-    assert not db.index_exists('some_index')
+    assert not nc.index_exists(db, 'some_index')
 
     table_factory('foobar')
     temp_db_cursor.execute('CREATE INDEX some_index ON foobar(id)')
 
-    assert db.index_exists('some_index')
-    assert db.index_exists('some_index', table='foobar')
-    assert not db.index_exists('some_index', table='bar')
+    assert nc.index_exists(db, 'some_index')
+    assert nc.index_exists(db, 'some_index', table='foobar')
+    assert not nc.index_exists(db, 'some_index', table='bar')
 
 
 def test_drop_table_existing(db, table_factory):
     table_factory('dummy')
-    assert db.table_exists('dummy')
+    assert nc.table_exists(db, 'dummy')
 
-    db.drop_table('dummy')
-    assert not db.table_exists('dummy')
+    nc.drop_tables(db, 'dummy')
+    assert not nc.table_exists(db, 'dummy')
 
 
-def test_drop_table_non_existsing(db):
-    db.drop_table('dfkjgjriogjigjgjrdghehtre')
+def test_drop_table_non_existing(db):
+    nc.drop_tables(db, 'dfkjgjriogjigjgjrdghehtre')
+
+
+def test_drop_many_tables(db, table_factory):
+    tables = [f'table{n}' for n in range(5)]
+
+    for t in tables:
+        table_factory(t)
+        assert nc.table_exists(db, t)
+
+    nc.drop_tables(db, *tables)
+
+    for t in tables:
+        assert not nc.table_exists(db, t)
 
 
 def test_drop_table_non_existing_force(db):
     with pytest.raises(psycopg2.ProgrammingError, match='.*does not exist.*'):
-        db.drop_table('dfkjgjriogjigjgjrdghehtre', if_exists=False)
+        nc.drop_tables(db, 'dfkjgjriogjigjgjrdghehtre', if_exists=False)
 
 def test_connection_server_version_tuple(db):
-    ver = db.server_version_tuple()
+    ver = nc.server_version_tuple(db)
 
     assert isinstance(ver, tuple)
     assert len(ver) == 2
@@ -72,7 +85,7 @@ def test_connection_server_version_tuple(db):
 
 
 def test_connection_postgis_version_tuple(db, temp_db_with_extensions):
-    ver = db.postgis_version_tuple()
+    ver = nc.postgis_version_tuple(db)
 
     assert isinstance(ver, tuple)
     assert len(ver) == 2
@@ -82,27 +95,24 @@ def test_connection_postgis_version_tuple(db, temp_db_with_extensions):
 def test_cursor_scalar(db, table_factory):
     table_factory('dummy')
 
-    with db.cursor() as cur:
-        assert cur.scalar('SELECT count(*) FROM dummy') == 0
+    assert nc.execute_scalar(db, 'SELECT count(*) FROM dummy') == 0
 
 
 def test_cursor_scalar_many_rows(db):
-    with db.cursor() as cur:
-        with pytest.raises(RuntimeError):
-            cur.scalar('SELECT * FROM pg_tables')
+    with pytest.raises(RuntimeError, match='Query did not return a single row.'):
+        nc.execute_scalar(db, 'SELECT * FROM pg_tables')
 
 
 def test_cursor_scalar_no_rows(db, table_factory):
     table_factory('dummy')
 
-    with db.cursor() as cur:
-        with pytest.raises(RuntimeError):
-            cur.scalar('SELECT id FROM dummy')
+    with pytest.raises(RuntimeError, match='Query did not return a single row.'):
+        nc.execute_scalar(db, 'SELECT id FROM dummy')
 
 
 def test_get_pg_env_add_variable(monkeypatch):
     monkeypatch.delenv('PGPASSWORD', raising=False)
-    env = get_pg_env('user=fooF')
+    env = nc.get_pg_env('user=fooF')
 
     assert env['PGUSER'] == 'fooF'
     assert 'PGPASSWORD' not in env
@@ -110,12 +120,12 @@ def test_get_pg_env_add_variable(monkeypatch):
 
 def test_get_pg_env_overwrite_variable(monkeypatch):
     monkeypatch.setenv('PGUSER', 'some default')
-    env = get_pg_env('user=overwriter')
+    env = nc.get_pg_env('user=overwriter')
 
     assert env['PGUSER'] == 'overwriter'
 
 
 def test_get_pg_env_ignore_unknown():
-    env = get_pg_env('client_encoding=stuff', base_env={})
+    env = nc.get_pg_env('client_encoding=stuff', base_env={})
 
     assert env == {}
index 5c465e8b7e988b6488756f9a8fb8a5c54278ba94..67be1892d4504e6b02e1afdf4848a89dfccf440a 100644 (file)
@@ -8,6 +8,7 @@
 Legacy word table for testing with functions to prefil and test contents
 of the table.
 """
+from nominatim_db.db.connection import execute_scalar
 
 class MockIcuWordTable:
     """ A word table for testing using legacy word table structure.
@@ -77,18 +78,15 @@ class MockIcuWordTable:
 
 
     def count(self):
-        with self.conn.cursor() as cur:
-            return cur.scalar("SELECT count(*) FROM word")
+        return execute_scalar(self.conn, "SELECT count(*) FROM word")
 
 
     def count_special(self):
-        with self.conn.cursor() as cur:
-            return cur.scalar("SELECT count(*) FROM word WHERE type = 'S'")
+        return execute_scalar(self.conn, "SELECT count(*) FROM word WHERE type = 'S'")
 
 
     def count_housenumbers(self):
-        with self.conn.cursor() as cur:
-            return cur.scalar("SELECT count(*) FROM word WHERE type = 'H'")
+        return execute_scalar(self.conn, "SELECT count(*) FROM word WHERE type = 'H'")
 
 
     def get_special(self):
index 9804341f41c33fd91df794e15fcf935f3f00d395..d1c523eb7d9a4c766f8d4057518c5c3945d91b52 100644 (file)
@@ -8,6 +8,7 @@
 Legacy word table for testing with functions to prefil and test contents
 of the table.
 """
+from nominatim_db.db.connection import execute_scalar
 
 class MockLegacyWordTable:
     """ A word table for testing using legacy word table structure.
@@ -58,13 +59,11 @@ class MockLegacyWordTable:
 
 
     def count(self):
-        with self.conn.cursor() as cur:
-            return cur.scalar("SELECT count(*) FROM word")
+        return execute_scalar(self.conn, "SELECT count(*) FROM word")
 
 
     def count_special(self):
-        with self.conn.cursor() as cur:
-            return cur.scalar("SELECT count(*) FROM word WHERE class != 'place'")
+        return execute_scalar(self.conn, "SELECT count(*) FROM word WHERE class != 'place'")
 
 
     def get_special(self):
index 357b7d4ae4cde571315c2a7d8109595911e343f4..a2bf676699ec9a326f59957740740899da8340dc 100644 (file)
@@ -199,16 +199,16 @@ def test_update_sql_functions(db_prop, temp_db_cursor,
     assert test_content == set((('1133', ), ))
 
 
-def test_finalize_import(tokenizer_factory, temp_db_conn,
-                         temp_db_cursor, test_config, sql_preprocessor_cfg):
+def test_finalize_import(tokenizer_factory, temp_db_cursor,
+                         test_config, sql_preprocessor_cfg):
     tok = tokenizer_factory()
     tok.init_new_db(test_config)
 
-    assert not temp_db_conn.index_exists('idx_word_word_id')
+    assert not temp_db_cursor.index_exists('word', 'idx_word_word_id')
 
     tok.finalize_import(test_config)
 
-    assert temp_db_conn.index_exists('idx_word_word_id')
+    assert temp_db_cursor.index_exists('word', 'idx_word_word_id')
 
 
 def test_check_database(test_config, tokenizer_factory,
index 548ec800b6b8f966427911dd2d91a90d19c8cb64..9d56efa0fd744a5a736649d0f2970b5a95ea794b 100644 (file)
@@ -132,7 +132,7 @@ def test_import_osm_data_simple_ignore_no_data(table_factory, osm2pgsql_options)
                                     ignore_errors=True)
 
 
-def test_import_osm_data_drop(table_factory, temp_db_conn, tmp_path, osm2pgsql_options):
+def test_import_osm_data_drop(table_factory, temp_db_cursor, tmp_path, osm2pgsql_options):
     table_factory('place', content=((1, ), ))
     table_factory('planet_osm_nodes')
 
@@ -144,7 +144,7 @@ def test_import_osm_data_drop(table_factory, temp_db_conn, tmp_path, osm2pgsql_o
     database_import.import_osm_data(Path('file.pbf'), osm2pgsql_options, drop=True)
 
     assert not flatfile.exists()
-    assert not temp_db_conn.table_exists('planet_osm_nodes')
+    assert not temp_db_cursor.table_exists('planet_osm_nodes')
 
 
 def test_import_osm_data_default_cache(table_factory, osm2pgsql_options, capfd):
index 64eb7b1856b0996c3958ae71e51b3ecd3fbbdbb4..0d33e6e0f30e4a0ab47bfb02a4cd63f2ad4c55a9 100644 (file)
@@ -75,7 +75,8 @@ def test_load_white_and_black_lists(sp_importer):
     assert isinstance(black_list, dict) and isinstance(white_list, dict)
 
 
-def test_create_place_classtype_indexes(temp_db_with_extensions, temp_db_conn,
+def test_create_place_classtype_indexes(temp_db_with_extensions,
+                                        temp_db_conn, temp_db_cursor,
                                         table_factory, sp_importer):
     """
         Test that _create_place_classtype_indexes() create the
@@ -88,10 +89,11 @@ def test_create_place_classtype_indexes(temp_db_with_extensions, temp_db_conn,
     table_factory(table_name, 'place_id BIGINT, centroid GEOMETRY')
 
     sp_importer._create_place_classtype_indexes('', phrase_class, phrase_type)
+    temp_db_conn.commit()
 
-    assert check_placeid_and_centroid_indexes(temp_db_conn, phrase_class, phrase_type)
+    assert check_placeid_and_centroid_indexes(temp_db_cursor, phrase_class, phrase_type)
 
-def test_create_place_classtype_table(temp_db_conn, placex_table, sp_importer):
+def test_create_place_classtype_table(temp_db_conn, temp_db_cursor, placex_table, sp_importer):
     """
         Test that _create_place_classtype_table() create
         the right place_classtype table.
@@ -99,10 +101,12 @@ def test_create_place_classtype_table(temp_db_conn, placex_table, sp_importer):
     phrase_class = 'class'
     phrase_type = 'type'
     sp_importer._create_place_classtype_table('', phrase_class, phrase_type)
+    temp_db_conn.commit()
 
-    assert check_table_exist(temp_db_conn, phrase_class, phrase_type)
+    assert check_table_exist(temp_db_cursor, phrase_class, phrase_type)
 
-def test_grant_access_to_web_user(temp_db_conn, table_factory, def_config, sp_importer):
+def test_grant_access_to_web_user(temp_db_conn, temp_db_cursor, table_factory,
+                                  def_config, sp_importer):
     """
         Test that _grant_access_to_webuser() give
         right access to the web user.
@@ -114,12 +118,13 @@ def test_grant_access_to_web_user(temp_db_conn, table_factory, def_config, sp_im
     table_factory(table_name)
 
     sp_importer._grant_access_to_webuser(phrase_class, phrase_type)
+    temp_db_conn.commit()
 
-    assert check_grant_access(temp_db_conn, def_config.DATABASE_WEBUSER, phrase_class, phrase_type)
+    assert check_grant_access(temp_db_cursor, def_config.DATABASE_WEBUSER, phrase_class, phrase_type)
 
 def test_create_place_classtype_table_and_indexes(
-        temp_db_conn, def_config, placex_table,
-        sp_importer):
+        temp_db_cursor, def_config, placex_table,
+        sp_importer, temp_db_conn):
     """
         Test that _create_place_classtype_table_and_indexes()
         create the right place_classtype tables and place_id indexes
@@ -129,14 +134,15 @@ def test_create_place_classtype_table_and_indexes(
     pairs = set([('class1', 'type1'), ('class2', 'type2')])
 
     sp_importer._create_classtype_table_and_indexes(pairs)
+    temp_db_conn.commit()
 
     for pair in pairs:
-        assert check_table_exist(temp_db_conn, pair[0], pair[1])
-        assert check_placeid_and_centroid_indexes(temp_db_conn, pair[0], pair[1])
-        assert check_grant_access(temp_db_conn, def_config.DATABASE_WEBUSER, pair[0], pair[1])
+        assert check_table_exist(temp_db_cursor, pair[0], pair[1])
+        assert check_placeid_and_centroid_indexes(temp_db_cursor, pair[0], pair[1])
+        assert check_grant_access(temp_db_cursor, def_config.DATABASE_WEBUSER, pair[0], pair[1])
 
 def test_remove_non_existent_tables_from_db(sp_importer, default_phrases,
-                                            temp_db_conn):
+                                            temp_db_conn, temp_db_cursor):
     """
         Check for the remove_non_existent_phrases_from_db() method.
 
@@ -159,15 +165,14 @@ def test_remove_non_existent_tables_from_db(sp_importer, default_phrases,
     """
 
     sp_importer._remove_non_existent_tables_from_db()
+    temp_db_conn.commit()
 
-    # Changes are not committed yet. Use temp_db_conn for checking results.
-    with temp_db_conn.cursor(cursor_factory=CursorForTesting) as cur:
-        assert cur.row_set(query_tables) \
+    assert temp_db_cursor.row_set(query_tables) \
                  == {('place_classtype_testclasstypetable_to_keep', )}
 
 
 @pytest.mark.parametrize("should_replace", [(True), (False)])
-def test_import_phrases(monkeypatch, temp_db_conn, def_config, sp_importer,
+def test_import_phrases(monkeypatch, temp_db_cursor, def_config, sp_importer,
                         placex_table, table_factory, tokenizer_mock,
                         xml_wiki_content, should_replace):
     """
@@ -193,49 +198,49 @@ def test_import_phrases(monkeypatch, temp_db_conn, def_config, sp_importer,
     class_test = 'aerialway'
     type_test = 'zip_line'
 
-    assert check_table_exist(temp_db_conn, class_test, type_test)
-    assert check_placeid_and_centroid_indexes(temp_db_conn, class_test, type_test)
-    assert check_grant_access(temp_db_conn, def_config.DATABASE_WEBUSER, class_test, type_test)
-    assert check_table_exist(temp_db_conn, 'amenity', 'animal_shelter')
+    assert check_table_exist(temp_db_cursor, class_test, type_test)
+    assert check_placeid_and_centroid_indexes(temp_db_cursor, class_test, type_test)
+    assert check_grant_access(temp_db_cursor, def_config.DATABASE_WEBUSER, class_test, type_test)
+    assert check_table_exist(temp_db_cursor, 'amenity', 'animal_shelter')
     if should_replace:
-        assert not check_table_exist(temp_db_conn, 'wrong_class', 'wrong_type')
+        assert not check_table_exist(temp_db_cursor, 'wrong_class', 'wrong_type')
 
-    assert temp_db_conn.table_exists('place_classtype_amenity_animal_shelter')
+    assert temp_db_cursor.table_exists('place_classtype_amenity_animal_shelter')
     if should_replace:
-        assert not temp_db_conn.table_exists('place_classtype_wrongclass_wrongtype')
+        assert not temp_db_cursor.table_exists('place_classtype_wrongclass_wrongtype')
 
-def check_table_exist(temp_db_conn, phrase_class, phrase_type):
+def check_table_exist(temp_db_cursor, phrase_class, phrase_type):
     """
         Verify that the place_classtype table exists for the given
         phrase_class and phrase_type.
     """
-    return temp_db_conn.table_exists('place_classtype_{}_{}'.format(phrase_class, phrase_type))
+    return temp_db_cursor.table_exists('place_classtype_{}_{}'.format(phrase_class, phrase_type))
 
 
-def check_grant_access(temp_db_conn, user, phrase_class, phrase_type):
+def check_grant_access(temp_db_cursor, user, phrase_class, phrase_type):
     """
         Check that the web user has been granted right access to the
         place_classtype table of the given phrase_class and phrase_type.
     """
     table_name = 'place_classtype_{}_{}'.format(phrase_class, phrase_type)
 
-    with temp_db_conn.cursor() as temp_db_cursor:
-        temp_db_cursor.execute("""
-                SELECT * FROM information_schema.role_table_grants
-                WHERE table_name='{}'
-                AND grantee='{}'
-                AND privilege_type='SELECT'""".format(table_name, user))
-        return temp_db_cursor.fetchone()
+    temp_db_cursor.execute("""
+            SELECT * FROM information_schema.role_table_grants
+            WHERE table_name='{}'
+            AND grantee='{}'
+            AND privilege_type='SELECT'""".format(table_name, user))
+    return temp_db_cursor.fetchone()
 
-def check_placeid_and_centroid_indexes(temp_db_conn, phrase_class, phrase_type):
+def check_placeid_and_centroid_indexes(temp_db_cursor, phrase_class, phrase_type):
     """
         Check that the place_id index and centroid index exist for the
         place_classtype table of the given phrase_class and phrase_type.
     """
+    table_name = 'place_classtype_{}_{}'.format(phrase_class, phrase_type)
     index_prefix = 'idx_place_classtype_{}_{}_'.format(phrase_class, phrase_type)
 
     return (
-        temp_db_conn.index_exists(index_prefix + 'centroid')
+        temp_db_cursor.index_exists(table_name, index_prefix + 'centroid')
         and
-        temp_db_conn.index_exists(index_prefix + 'place_id')
+        temp_db_cursor.index_exists(table_name, index_prefix + 'place_id')
     )
index 3a849adbf956ed8469730c9cb6eb2d2d377eff0c..8821f694c927d5becec13db65972782f80fee4e8 100644 (file)
@@ -12,6 +12,7 @@ import psycopg2.extras
 
 from nominatim_db.tools import migration
 from nominatim_db.errors import UsageError
+from nominatim_db.db.connection import server_version_tuple
 import nominatim_db.version
 
 from mock_legacy_word_table import MockLegacyWordTable
@@ -63,7 +64,7 @@ def test_set_up_migration_for_36(temp_db_with_extensions, temp_db_cursor,
                                           WHERE property = 'database_version'""")
 
 
-def test_already_at_version(def_config, property_table):
+def test_already_at_version(temp_db_with_extensions, def_config, property_table):
 
     property_table.set('database_version',
                        str(nominatim_db.version.NOMINATIM_VERSION))
@@ -71,8 +72,8 @@ def test_already_at_version(def_config, property_table):
     assert migration.migrate(def_config, {}) == 0
 
 
-def test_run_single_migration(def_config, temp_db_cursor, property_table,
-                              monkeypatch, postprocess_mock):
+def test_run_single_migration(temp_db_with_extensions, def_config, temp_db_cursor,
+                              property_table, monkeypatch, postprocess_mock):
     oldversion = [x for x in nominatim_db.version.NOMINATIM_VERSION]
     oldversion[0] -= 1
     property_table.set('database_version',
@@ -226,7 +227,7 @@ def test_create_tiger_housenumber_index(temp_db_conn, temp_db_cursor, table_fact
     migration.create_tiger_housenumber_index(temp_db_conn)
     temp_db_conn.commit()
 
-    if temp_db_conn.server_version_tuple() >= (11, 0, 0):
+    if server_version_tuple(temp_db_conn) >= (11, 0, 0):
         assert temp_db_cursor.index_exists('location_property_tiger',
                                            'idx_location_property_tiger_housenumber_migrated')
 
index 50ff6398b5673cdc073b4d4c62af6831729f4def..1f1968cfac7fe0c1a0e214037aa097728cfe0efe 100644 (file)
@@ -12,6 +12,7 @@ from pathlib import Path
 import pytest
 
 from nominatim_db.tools import refresh
+from nominatim_db.db.connection import postgis_version_tuple
 
 def test_refresh_import_wikipedia_not_existing(dsn):
     assert refresh.import_wikipedia_articles(dsn, Path('.')) == 1
@@ -23,13 +24,13 @@ def test_refresh_import_secondary_importance_non_existing(dsn):
 def test_refresh_import_secondary_importance_testdb(dsn, src_dir, temp_db_conn, temp_db_cursor):
     temp_db_cursor.execute('CREATE EXTENSION postgis')
 
-    if temp_db_conn.postgis_version_tuple()[0] < 3:
+    if postgis_version_tuple(temp_db_conn)[0] < 3:
         assert refresh.import_secondary_importance(dsn, src_dir / 'test' / 'testdb') > 0
     else:
         temp_db_cursor.execute('CREATE EXTENSION postgis_raster')
         assert refresh.import_secondary_importance(dsn, src_dir / 'test' / 'testdb') == 0
 
-        assert temp_db_conn.table_exists('secondary_importance')
+        assert temp_db_cursor.table_exists('secondary_importance')
 
 
 @pytest.mark.parametrize("replace", (True, False))
index 7ef6a1e63306020f41d1dcb91863a4a01be6c51e..fc01f22ffb905bd1ff15435633ec6f730225fefc 100644 (file)
@@ -12,6 +12,7 @@ from textwrap import dedent
 
 import pytest
 
+from nominatim_db.db.connection import execute_scalar
 from nominatim_db.tools import tiger_data, freeze
 from nominatim_db.errors import UsageError
 
@@ -31,8 +32,7 @@ class MockTigerTable:
             cur.execute("CREATE TABLE place (number INTEGER)")
 
     def count(self):
-        with self.conn.cursor() as cur:
-            return cur.scalar("SELECT count(*) FROM tiger")
+        return execute_scalar(self.conn, "SELECT count(*) FROM tiger")
 
     def row(self):
         with self.conn.cursor() as cur: