]> git.openstreetmap.org Git - nominatim.git/commitdiff
make DB helper functions free functions
authorSarah Hoffmann <lonvia@denofr.de>
Tue, 2 Jul 2024 13:15:50 +0000 (15:15 +0200)
committerSarah Hoffmann <lonvia@denofr.de>
Mon, 29 Jul 2024 06:49:30 +0000 (08:49 +0200)
Also changes the drop function so that it can drop multiple tables
at once.

30 files changed:
src/nominatim_db/clicmd/admin.py
src/nominatim_db/clicmd/refresh.py
src/nominatim_db/data/country_info.py
src/nominatim_db/db/connection.py
src/nominatim_db/db/properties.py
src/nominatim_db/db/sql_preprocessor.py
src/nominatim_db/db/status.py
src/nominatim_db/indexer/indexer.py
src/nominatim_db/tokenizer/icu_tokenizer.py
src/nominatim_db/tokenizer/legacy_tokenizer.py
src/nominatim_db/tools/admin.py
src/nominatim_db/tools/check_database.py
src/nominatim_db/tools/collect_os_info.py
src/nominatim_db/tools/database_import.py
src/nominatim_db/tools/freeze.py
src/nominatim_db/tools/migration.py
src/nominatim_db/tools/postcodes.py
src/nominatim_db/tools/refresh.py
src/nominatim_db/tools/replication.py
src/nominatim_db/tools/special_phrases/sp_importer.py
test/python/db/test_async_connection.py
test/python/db/test_connection.py
test/python/mock_icu_word_table.py
test/python/mock_legacy_word_table.py
test/python/tokenizer/test_icu.py
test/python/tools/test_database_import.py
test/python/tools/test_import_special_phrases.py
test/python/tools/test_migration.py
test/python/tools/test_refresh.py
test/python/tools/test_tiger_data.py

index 7744595bbd3f99b449e0ea09a1e4ba2840edf710..1edff174dc37ad251de9d9d048736098d717d2f0 100644 (file)
@@ -12,7 +12,7 @@ import argparse
 import random
 
 from ..errors import UsageError
 import random
 
 from ..errors import UsageError
-from ..db.connection import connect
+from ..db.connection import connect, table_exists
 from .args import NominatimArgs
 
 # Do not repeat documentation of subcommand classes.
 from .args import NominatimArgs
 
 # Do not repeat documentation of subcommand classes.
@@ -115,7 +115,7 @@ class AdminFuncs:
 
                 tokenizer = tokenizer_factory.get_tokenizer_for_db(args.config)
                 with connect(args.config.get_libpq_dsn()) as conn:
 
                 tokenizer = tokenizer_factory.get_tokenizer_for_db(args.config)
                 with connect(args.config.get_libpq_dsn()) as conn:
-                    if conn.table_exists('search_name'):
+                    if table_exists(conn, 'search_name'):
                         words = tokenizer.most_frequent_words(conn, 1000)
                     else:
                         words = []
                         words = tokenizer.most_frequent_words(conn, 1000)
                     else:
                         words = []
index d5acf54b3a36a0e8b0832081a436ed4ce62c586f..363bad7853431851264f53eab83be7f277e46e07 100644 (file)
@@ -13,7 +13,7 @@ import logging
 from pathlib import Path
 
 from ..config import Configuration
 from pathlib import Path
 
 from ..config import Configuration
-from ..db.connection import connect
+from ..db.connection import connect, table_exists
 from ..tokenizer.base import AbstractTokenizer
 from .args import NominatimArgs
 
 from ..tokenizer.base import AbstractTokenizer
 from .args import NominatimArgs
 
@@ -124,7 +124,7 @@ class UpdateRefresh:
             with connect(args.config.get_libpq_dsn()) as conn:
                 # If the table did not exist before, then the importance code
                 # needs to be enabled.
             with connect(args.config.get_libpq_dsn()) as conn:
                 # If the table did not exist before, then the importance code
                 # needs to be enabled.
-                if not conn.table_exists('secondary_importance'):
+                if not table_exists(conn, 'secondary_importance'):
                     args.functions = True
 
             LOG.warning('Import secondary importance raster data from %s', args.project_dir)
                     args.functions = True
 
             LOG.warning('Import secondary importance raster data from %s', args.project_dir)
index c8002ee7023ca27a9e680003e92a66e3d8f0b428..e2bf5133f6403dd1acaf0a5e7069b2d3bac8ee7c 100644 (file)
@@ -9,10 +9,9 @@ Functions for importing and managing static country information.
 """
 from typing import Dict, Any, Iterable, Tuple, Optional, Container, overload
 from pathlib import Path
 """
 from typing import Dict, Any, Iterable, Tuple, Optional, Container, overload
 from pathlib import Path
-import psycopg2.extras
 
 from ..db import utils as db_utils
 
 from ..db import utils as db_utils
-from ..db.connection import connect, Connection
+from ..db.connection import connect, Connection, register_hstore
 from ..errors import UsageError
 from ..config import Configuration
 from ..tokenizer.base import AbstractTokenizer
 from ..errors import UsageError
 from ..config import Configuration
 from ..tokenizer.base import AbstractTokenizer
@@ -129,8 +128,8 @@ def setup_country_tables(dsn: str, sql_dir: Path, ignore_partitions: bool = Fals
 
             params.append((ccode, props['names'], lang, partition))
     with connect(dsn) as conn:
 
             params.append((ccode, props['names'], lang, partition))
     with connect(dsn) as conn:
+        register_hstore(conn)
         with conn.cursor() as cur:
         with conn.cursor() as cur:
-            psycopg2.extras.register_hstore(cur)
             cur.execute(
                 """ CREATE TABLE public.country_name (
                         country_code character varying(2),
             cur.execute(
                 """ CREATE TABLE public.country_name (
                         country_code character varying(2),
@@ -157,8 +156,8 @@ def create_country_names(conn: Connection, tokenizer: AbstractTokenizer,
         return ':' not in key or not languages or \
                key[key.index(':') + 1:] in languages
 
         return ':' not in key or not languages or \
                key[key.index(':') + 1:] in languages
 
+    register_hstore(conn)
     with conn.cursor() as cur:
     with conn.cursor() as cur:
-        psycopg2.extras.register_hstore(cur)
         cur.execute("""SELECT country_code, name FROM country_name
                        WHERE country_code is not null""")
 
         cur.execute("""SELECT country_code, name FROM country_name
                        WHERE country_code is not null""")
 
index 8faa3f93334dda7a219b4e5dcd1caf98052a6274..629fad6ac9598603be5819baea8abcc7c1ba8697 100644 (file)
@@ -7,7 +7,8 @@
 """
 Specialised connection and cursor functions.
 """
 """
 Specialised connection and cursor functions.
 """
-from typing import Optional, Any, Callable, ContextManager, Dict, cast, overload, Tuple, Iterable
+from typing import Optional, Any, Callable, ContextManager, Dict, cast, overload,\
+                   Tuple, Iterable
 import contextlib
 import logging
 import os
 import contextlib
 import logging
 import os
@@ -46,37 +47,6 @@ class Cursor(psycopg2.extras.DictCursor):
         psycopg2.extras.execute_values(self, sql, argslist, template=template)
 
 
         psycopg2.extras.execute_values(self, sql, argslist, template=template)
 
 
-    def scalar(self, sql: Query, args: Any = None) -> Any:
-        """ Execute query that returns a single value. The value is returned.
-            If the query yields more than one row, a ValueError is raised.
-        """
-        self.execute(sql, args)
-
-        if self.rowcount != 1:
-            raise RuntimeError("Query did not return a single row.")
-
-        result = self.fetchone()
-        assert result is not None
-
-        return result[0]
-
-
-    def drop_table(self, name: str, if_exists: bool = True, cascade: bool = False) -> None:
-        """ Drop the table with the given name.
-            Set `if_exists` to False if a non-existent table should raise
-            an exception instead of just being ignored. If 'cascade' is set
-            to True then all dependent tables are deleted as well.
-        """
-        sql = 'DROP TABLE '
-        if if_exists:
-            sql += 'IF EXISTS '
-        sql += '{}'
-        if cascade:
-            sql += ' CASCADE'
-
-        self.execute(pysql.SQL(sql).format(pysql.Identifier(name)))
-
-
 class Connection(psycopg2.extensions.connection):
     """ A connection that provides the specialised cursor by default and
         adds convenience functions for administrating the database.
 class Connection(psycopg2.extensions.connection):
     """ A connection that provides the specialised cursor by default and
         adds convenience functions for administrating the database.
@@ -99,80 +69,105 @@ class Connection(psycopg2.extensions.connection):
         return super().cursor(cursor_factory=cursor_factory, **kwargs)
 
 
         return super().cursor(cursor_factory=cursor_factory, **kwargs)
 
 
-    def table_exists(self, table: str) -> bool:
-        """ Check that a table with the given name exists in the database.
-        """
-        with self.cursor() as cur:
-            num = cur.scalar("""SELECT count(*) FROM pg_tables
-                                WHERE tablename = %s and schemaname = 'public'""", (table, ))
-            return num == 1 if isinstance(num, int) else False
+def execute_scalar(conn: Connection, sql: Query, args: Any = None) -> Any:
+    """ Execute query that returns a single value. The value is returned.
+        If the query yields more than one row, a ValueError is raised.
+    """
+    with conn.cursor() as cur:
+        cur.execute(sql, args)
 
 
+        if cur.rowcount != 1:
+            raise RuntimeError("Query did not return a single row.")
 
 
-    def table_has_column(self, table: str, column: str) -> bool:
-        """ Check if the table 'table' exists and has a column with name 'column'.
-        """
-        with self.cursor() as cur:
-            has_column = cur.scalar("""SELECT count(*) FROM information_schema.columns
-                                       WHERE table_name = %s
-                                             and column_name = %s""",
-                                    (table, column))
-            return has_column > 0 if isinstance(has_column, int) else False
-
-
-    def index_exists(self, index: str, table: Optional[str] = None) -> bool:
-        """ Check that an index with the given name exists in the database.
-            If table is not None then the index must relate to the given
-            table.
-        """
-        with self.cursor() as cur:
-            cur.execute("""SELECT tablename FROM pg_indexes
-                           WHERE indexname = %s and schemaname = 'public'""", (index, ))
-            if cur.rowcount == 0:
+        result = cur.fetchone()
+
+    assert result is not None
+    return result[0]
+
+
+def table_exists(conn: Connection, table: str) -> bool:
+    """ Check that a table with the given name exists in the database.
+    """
+    num = execute_scalar(conn,
+            """SELECT count(*) FROM pg_tables
+               WHERE tablename = %s and schemaname = 'public'""", (table, ))
+    return num == 1 if isinstance(num, int) else False
+
+
+def table_has_column(conn: Connection, table: str, column: str) -> bool:
+    """ Check if the table 'table' exists and has a column with name 'column'.
+    """
+    has_column = execute_scalar(conn,
+                    """SELECT count(*) FROM information_schema.columns
+                       WHERE table_name = %s and column_name = %s""",
+                    (table, column))
+    return has_column > 0 if isinstance(has_column, int) else False
+
+
+def index_exists(conn: Connection, index: str, table: Optional[str] = None) -> bool:
+    """ Check that an index with the given name exists in the database.
+        If table is not None then the index must relate to the given
+        table.
+    """
+    with conn.cursor() as cur:
+        cur.execute("""SELECT tablename FROM pg_indexes
+                       WHERE indexname = %s and schemaname = 'public'""", (index, ))
+        if cur.rowcount == 0:
+            return False
+
+        if table is not None:
+            row = cur.fetchone()
+            if row is None or not isinstance(row[0], str):
                 return False
                 return False
+            return row[0] == table
 
 
-            if table is not None:
-                row = cur.fetchone()
-                if row is None or not isinstance(row[0], str):
-                    return False
-                return row[0] == table
+    return True
 
 
-        return True
+def drop_tables(conn: Connection, *names: str,
+               if_exists: bool = True, cascade: bool = False) -> None:
+    """ Drop one or more tables with the given names.
+        Set `if_exists` to False if a non-existent table should raise
+        an exception instead of just being ignored. `cascade` will cause
+        depended objects to be dropped as well.
+        The caller needs to take care of committing the change.
+    """
+    sql = pysql.SQL('DROP TABLE%s{}%s' % (
+                        ' IF EXISTS ' if if_exists else ' ',
+                        ' CASCADE' if cascade else ''))
 
 
+    with conn.cursor() as cur:
+        for name in names:
+            cur.execute(sql.format(pysql.Identifier(name)))
 
 
-    def drop_table(self, name: str, if_exists: bool = True, cascade: bool = False) -> None:
-        """ Drop the table with the given name.
-            Set `if_exists` to False if a non-existent table should raise
-            an exception instead of just being ignored.
-        """
-        with self.cursor() as cur:
-            cur.drop_table(name, if_exists, cascade)
-        self.commit()
 
 
+def server_version_tuple(conn: Connection) -> Tuple[int, int]:
+    """ Return the server version as a tuple of (major, minor).
+        Converts correctly for pre-10 and post-10 PostgreSQL versions.
+    """
+    version = conn.server_version
+    if version < 100000:
+        return (int(version / 10000), int((version % 10000) / 100))
 
 
-    def server_version_tuple(self) -> Tuple[int, int]:
-        """ Return the server version as a tuple of (major, minor).
-            Converts correctly for pre-10 and post-10 PostgreSQL versions.
-        """
-        version = self.server_version
-        if version < 100000:
-            return (int(version / 10000), int((version % 10000) / 100))
+    return (int(version / 10000), version % 10000)
 
 
-        return (int(version / 10000), version % 10000)
 
 
+def postgis_version_tuple(conn: Connection) -> Tuple[int, int]:
+    """ Return the postgis version installed in the database as a
+        tuple of (major, minor). Assumes that the PostGIS extension
+        has been installed already.
+    """
+    version = execute_scalar(conn, 'SELECT postgis_lib_version()')
 
 
-    def postgis_version_tuple(self) -> Tuple[int, int]:
-        """ Return the postgis version installed in the database as a
-            tuple of (major, minor). Assumes that the PostGIS extension
-            has been installed already.
-        """
-        with self.cursor() as cur:
-            version = cur.scalar('SELECT postgis_lib_version()')
+    version_parts = version.split('.')
+    if len(version_parts) < 2:
+        raise UsageError(f"Error fetching Postgis version. Bad format: {version}")
 
 
-        version_parts = version.split('.')
-        if len(version_parts) < 2:
-            raise UsageError(f"Error fetching Postgis version. Bad format: {version}")
+    return (int(version_parts[0]), int(version_parts[1]))
 
 
-        return (int(version_parts[0]), int(version_parts[1]))
+def register_hstore(conn: Connection) -> None:
+    """ Register the hstore type with psycopg for the connection.
+    """
+    psycopg2.extras.register_hstore(conn)
 
 
 class ConnectionContext(ContextManager[Connection]):
 
 
 class ConnectionContext(ContextManager[Connection]):
index 3549382f944ad9e7f998b08190d2190b9c124d97..0e017ead900c0c2f6fa236e65d3ae927498e059b 100644 (file)
@@ -9,7 +9,7 @@ Query and access functions for the in-database property table.
 """
 from typing import Optional, cast
 
 """
 from typing import Optional, cast
 
-from .connection import Connection
+from .connection import Connection, table_exists
 
 def set_property(conn: Connection, name: str, value: str) -> None:
     """ Add or replace the property with the given name.
 
 def set_property(conn: Connection, name: str, value: str) -> None:
     """ Add or replace the property with the given name.
@@ -31,7 +31,7 @@ def get_property(conn: Connection, name: str) -> Optional[str]:
     """ Return the current value of the given property or None if the property
         is not set.
     """
     """ Return the current value of the given property or None if the property
         is not set.
     """
-    if not conn.table_exists('nominatim_properties'):
+    if not table_exists(conn, 'nominatim_properties'):
         return None
 
     with conn.cursor() as cur:
         return None
 
     with conn.cursor() as cur:
index 468f35107b8f5188a3cbc9b93c9fb3353ef5e322..691ab6c52cff562aafa8307ecea9d368ffa10c44 100644 (file)
@@ -10,7 +10,7 @@ Preprocessing of SQL files.
 from typing import Set, Dict, Any, cast
 import jinja2
 
 from typing import Set, Dict, Any, cast
 import jinja2
 
-from .connection import Connection
+from .connection import Connection, server_version_tuple, postgis_version_tuple
 from .async_connection import WorkerPool
 from ..config import Configuration
 
 from .async_connection import WorkerPool
 from ..config import Configuration
 
@@ -66,8 +66,8 @@ def _setup_postgresql_features(conn: Connection) -> Dict[str, Any]:
     """ Set up a dictionary with various optional Postgresql/Postgis features that
         depend on the database version.
     """
     """ Set up a dictionary with various optional Postgresql/Postgis features that
         depend on the database version.
     """
-    pg_version = conn.server_version_tuple()
-    postgis_version = conn.postgis_version_tuple()
+    pg_version = server_version_tuple(conn)
+    postgis_version = postgis_version_tuple(conn)
     pg11plus = pg_version >= (11, 0, 0)
     ps3 = postgis_version >= (3, 0)
     return {
     pg11plus = pg_version >= (11, 0, 0)
     ps3 = postgis_version >= (3, 0)
     return {
index 1278359cc02ebae2ac1513162a0b49a0ef37d211..1d2b3bec5e6bdebeab68d7c1d266cdf06f9257e1 100644 (file)
@@ -12,7 +12,7 @@ import datetime as dt
 import logging
 import re
 
 import logging
 import re
 
-from .connection import Connection
+from .connection import Connection, table_exists, execute_scalar
 from ..utils.url_utils import get_url
 from ..errors import UsageError
 from ..typing import TypedDict
 from ..utils.url_utils import get_url
 from ..errors import UsageError
 from ..typing import TypedDict
@@ -34,7 +34,7 @@ def compute_database_date(conn: Connection, offline: bool = False) -> dt.datetim
         data base.
     """
     # If there is a date from osm2pgsql available, use that.
         data base.
     """
     # If there is a date from osm2pgsql available, use that.
-    if conn.table_exists('osm2pgsql_properties'):
+    if table_exists(conn, 'osm2pgsql_properties'):
         with conn.cursor() as cur:
             cur.execute(""" SELECT value FROM osm2pgsql_properties
                             WHERE property = 'current_timestamp' """)
         with conn.cursor() as cur:
             cur.execute(""" SELECT value FROM osm2pgsql_properties
                             WHERE property = 'current_timestamp' """)
@@ -47,15 +47,14 @@ def compute_database_date(conn: Connection, offline: bool = False) -> dt.datetim
         raise UsageError("Cannot determine database date from data in offline mode.")
 
     # Else, find the node with the highest ID in the database
         raise UsageError("Cannot determine database date from data in offline mode.")
 
     # Else, find the node with the highest ID in the database
-    with conn.cursor() as cur:
-        if conn.table_exists('place'):
-            osmid = cur.scalar("SELECT max(osm_id) FROM place WHERE osm_type='N'")
-        else:
-            osmid = cur.scalar("SELECT max(osm_id) FROM placex WHERE osm_type='N'")
-
-        if osmid is None:
-            LOG.fatal("No data found in the database.")
-            raise UsageError("No data found in the database.")
+    if table_exists(conn, 'place'):
+        osmid = execute_scalar(conn, "SELECT max(osm_id) FROM place WHERE osm_type='N'")
+    else:
+        osmid = execute_scalar(conn, "SELECT max(osm_id) FROM placex WHERE osm_type='N'")
+
+    if osmid is None:
+        LOG.fatal("No data found in the database.")
+        raise UsageError("No data found in the database.")
 
     LOG.info("Using node id %d for timestamp lookup", osmid)
     # Get the node from the API to find the timestamp when it was created.
 
     LOG.info("Using node id %d for timestamp lookup", osmid)
     # Get the node from the API to find the timestamp when it was created.
index 5a219f6b2ecc3904b1129c9b4f9cff92c12132a5..b4c9732c624231df40215eb9e6ba543554c332e4 100644 (file)
@@ -15,7 +15,7 @@ import psycopg2.extras
 
 from ..typing import DictCursorResults
 from ..db.async_connection import DBConnection, WorkerPool
 
 from ..typing import DictCursorResults
 from ..db.async_connection import DBConnection, WorkerPool
-from ..db.connection import connect, Connection, Cursor
+from ..db.connection import connect, Connection, Cursor, execute_scalar, register_hstore
 from ..tokenizer.base import AbstractTokenizer
 from .progress import ProgressLogger
 from . import runners
 from ..tokenizer.base import AbstractTokenizer
 from .progress import ProgressLogger
 from . import runners
@@ -32,15 +32,15 @@ class PlaceFetcher:
         self.conn: Optional[DBConnection] = DBConnection(dsn,
                                                cursor_factory=psycopg2.extras.DictCursor)
 
         self.conn: Optional[DBConnection] = DBConnection(dsn,
                                                cursor_factory=psycopg2.extras.DictCursor)
 
-        with setup_conn.cursor() as cur:
-            # need to fetch those manually because register_hstore cannot
-            # fetch them on an asynchronous connection below.
-            hstore_oid = cur.scalar("SELECT 'hstore'::regtype::oid")
-            hstore_array_oid = cur.scalar("SELECT 'hstore[]'::regtype::oid")
+        # need to fetch those manually because register_hstore cannot
+        # fetch them on an asynchronous connection below.
+        hstore_oid = execute_scalar(setup_conn, "SELECT 'hstore'::regtype::oid")
+        hstore_array_oid = execute_scalar(setup_conn, "SELECT 'hstore[]'::regtype::oid")
 
         psycopg2.extras.register_hstore(self.conn.conn, oid=hstore_oid,
                                         array_oid=hstore_array_oid)
 
 
         psycopg2.extras.register_hstore(self.conn.conn, oid=hstore_oid,
                                         array_oid=hstore_array_oid)
 
+
     def close(self) -> None:
         """ Close the underlying asynchronous connection.
         """
     def close(self) -> None:
         """ Close the underlying asynchronous connection.
         """
@@ -205,10 +205,9 @@ class Indexer:
         LOG.warning("Starting %s (using batch size %s)", runner.name(), batch)
 
         with connect(self.dsn) as conn:
         LOG.warning("Starting %s (using batch size %s)", runner.name(), batch)
 
         with connect(self.dsn) as conn:
-            psycopg2.extras.register_hstore(conn)
-            with conn.cursor() as cur:
-                total_tuples = cur.scalar(runner.sql_count_objects())
-                LOG.debug("Total number of rows: %i", total_tuples)
+            register_hstore(conn)
+            total_tuples = execute_scalar(conn, runner.sql_count_objects())
+            LOG.debug("Total number of rows: %i", total_tuples)
 
             conn.commit()
 
 
             conn.commit()
 
index 22e2d048291bcaa6d566e0f5fd3cf92fe9fcf870..70c5c27a096dd1d6c8764370fa7ec6d504a4e99d 100644 (file)
@@ -16,7 +16,8 @@ import logging
 from pathlib import Path
 from textwrap import dedent
 
 from pathlib import Path
 from textwrap import dedent
 
-from ..db.connection import connect, Connection, Cursor
+from ..db.connection import connect, Connection, Cursor, server_version_tuple,\
+                            drop_tables, table_exists, execute_scalar
 from ..config import Configuration
 from ..db.utils import CopyBuffer
 from ..db.sql_preprocessor import SQLPreprocessor
 from ..config import Configuration
 from ..db.utils import CopyBuffer
 from ..db.sql_preprocessor import SQLPreprocessor
@@ -108,7 +109,7 @@ class ICUTokenizer(AbstractTokenizer):
         """ Recompute frequencies for all name words.
         """
         with connect(self.dsn) as conn:
         """ Recompute frequencies for all name words.
         """
         with connect(self.dsn) as conn:
-            if not conn.table_exists('search_name'):
+            if not table_exists(conn, 'search_name'):
                 return
 
             with conn.cursor() as cur:
                 return
 
             with conn.cursor() as cur:
@@ -117,10 +118,9 @@ class ICUTokenizer(AbstractTokenizer):
                     cur.execute('SET max_parallel_workers_per_gather TO %s',
                                 (min(threads, 6),))
 
                     cur.execute('SET max_parallel_workers_per_gather TO %s',
                                 (min(threads, 6),))
 
-                if conn.server_version_tuple() < (12, 0):
+                if server_version_tuple(conn) < (12, 0):
                     LOG.info('Computing word frequencies')
                     LOG.info('Computing word frequencies')
-                    cur.drop_table('word_frequencies')
-                    cur.drop_table('addressword_frequencies')
+                    drop_tables(conn, 'word_frequencies', 'addressword_frequencies')
                     cur.execute("""CREATE TEMP TABLE word_frequencies AS
                                      SELECT unnest(name_vector) as id, count(*)
                                      FROM search_name GROUP BY id""")
                     cur.execute("""CREATE TEMP TABLE word_frequencies AS
                                      SELECT unnest(name_vector) as id, count(*)
                                      FROM search_name GROUP BY id""")
@@ -152,17 +152,16 @@ class ICUTokenizer(AbstractTokenizer):
                                    $$ LANGUAGE plpgsql IMMUTABLE;
                                 """)
                     LOG.info('Update word table with recomputed frequencies')
                                    $$ LANGUAGE plpgsql IMMUTABLE;
                                 """)
                     LOG.info('Update word table with recomputed frequencies')
-                    cur.drop_table('tmp_word')
+                    drop_tables(conn, 'tmp_word')
                     cur.execute("""CREATE TABLE tmp_word AS
                                     SELECT word_id, word_token, type, word,
                                            word_freq_update(word_id, info) as info
                                     FROM word
                                 """)
                     cur.execute("""CREATE TABLE tmp_word AS
                                     SELECT word_id, word_token, type, word,
                                            word_freq_update(word_id, info) as info
                                     FROM word
                                 """)
-                    cur.drop_table('word_frequencies')
-                    cur.drop_table('addressword_frequencies')
+                    drop_tables(conn, 'word_frequencies', 'addressword_frequencies')
                 else:
                     LOG.info('Computing word frequencies')
                 else:
                     LOG.info('Computing word frequencies')
-                    cur.drop_table('word_frequencies')
+                    drop_tables(conn, 'word_frequencies')
                     cur.execute("""
                       CREATE TEMP TABLE word_frequencies AS
                       WITH word_freq AS MATERIALIZED (
                     cur.execute("""
                       CREATE TEMP TABLE word_frequencies AS
                       WITH word_freq AS MATERIALIZED (
@@ -182,7 +181,7 @@ class ICUTokenizer(AbstractTokenizer):
                     cur.execute('CREATE UNIQUE INDEX ON word_frequencies(id) INCLUDE(info)')
                     cur.execute('ANALYSE word_frequencies')
                     LOG.info('Update word table with recomputed frequencies')
                     cur.execute('CREATE UNIQUE INDEX ON word_frequencies(id) INCLUDE(info)')
                     cur.execute('ANALYSE word_frequencies')
                     LOG.info('Update word table with recomputed frequencies')
-                    cur.drop_table('tmp_word')
+                    drop_tables(conn, 'tmp_word')
                     cur.execute("""CREATE TABLE tmp_word AS
                                     SELECT word_id, word_token, type, word,
                                            (CASE WHEN wf.info is null THEN word.info
                     cur.execute("""CREATE TABLE tmp_word AS
                                     SELECT word_id, word_token, type, word,
                                            (CASE WHEN wf.info is null THEN word.info
@@ -191,7 +190,7 @@ class ICUTokenizer(AbstractTokenizer):
                                     FROM word LEFT JOIN word_frequencies wf
                                          ON word.word_id = wf.id
                                 """)
                                     FROM word LEFT JOIN word_frequencies wf
                                          ON word.word_id = wf.id
                                 """)
-                    cur.drop_table('word_frequencies')
+                    drop_tables(conn, 'word_frequencies')
 
             with conn.cursor() as cur:
                 cur.execute('SET max_parallel_workers_per_gather TO 0')
 
             with conn.cursor() as cur:
                 cur.execute('SET max_parallel_workers_per_gather TO 0')
@@ -210,7 +209,7 @@ class ICUTokenizer(AbstractTokenizer):
         """ Remove unused house numbers.
         """
         with connect(self.dsn) as conn:
         """ Remove unused house numbers.
         """
         with connect(self.dsn) as conn:
-            if not conn.table_exists('search_name'):
+            if not table_exists(conn, 'search_name'):
                 return
             with conn.cursor(name="hnr_counter") as cur:
                 cur.execute("""SELECT DISTINCT word_id, coalesce(info->>'lookup', word_token)
                 return
             with conn.cursor(name="hnr_counter") as cur:
                 cur.execute("""SELECT DISTINCT word_id, coalesce(info->>'lookup', word_token)
@@ -311,8 +310,7 @@ class ICUTokenizer(AbstractTokenizer):
             frequencies.
         """
         with connect(self.dsn) as conn:
             frequencies.
         """
         with connect(self.dsn) as conn:
-            with conn.cursor() as cur:
-                cur.drop_table('word')
+            drop_tables(conn, 'word')
             sqlp = SQLPreprocessor(conn, config)
             sqlp.run_string(conn, """
                 CREATE TABLE word (
             sqlp = SQLPreprocessor(conn, config)
             sqlp.run_string(conn, """
                 CREATE TABLE word (
@@ -370,8 +368,8 @@ class ICUTokenizer(AbstractTokenizer):
         """ Rename all tables and indexes used by the tokenizer.
         """
         with connect(self.dsn) as conn:
         """ Rename all tables and indexes used by the tokenizer.
         """
         with connect(self.dsn) as conn:
+            drop_tables(conn, 'word')
             with conn.cursor() as cur:
             with conn.cursor() as cur:
-                cur.drop_table('word')
                 cur.execute(f"ALTER TABLE {old} RENAME TO word")
                 for idx in ('word_token', 'word_id'):
                     cur.execute(f"""ALTER INDEX idx_{old}_{idx}
                 cur.execute(f"ALTER TABLE {old} RENAME TO word")
                 for idx in ('word_token', 'word_id'):
                     cur.execute(f"""ALTER INDEX idx_{old}_{idx}
@@ -733,11 +731,10 @@ class ICUNameAnalyzer(AbstractAnalyzer):
             if norm_name:
                 result = self._cache.housenumbers.get(norm_name, result)
                 if result[0] is None:
             if norm_name:
                 result = self._cache.housenumbers.get(norm_name, result)
                 if result[0] is None:
-                    with self.conn.cursor() as cur:
-                        hid = cur.scalar("SELECT getorcreate_hnr_id(%s)", (norm_name, ))
+                    hid = execute_scalar(self.conn, "SELECT getorcreate_hnr_id(%s)", (norm_name, ))
 
 
-                        result = hid, norm_name
-                        self._cache.housenumbers[norm_name] = result
+                    result = hid, norm_name
+                    self._cache.housenumbers[norm_name] = result
         else:
             # Otherwise use the analyzer to determine the canonical name.
             # Per convention we use the first variant as the 'lookup name', the
         else:
             # Otherwise use the analyzer to determine the canonical name.
             # Per convention we use the first variant as the 'lookup name', the
@@ -748,11 +745,10 @@ class ICUNameAnalyzer(AbstractAnalyzer):
                 if result[0] is None:
                     variants = analyzer.compute_variants(word_id)
                     if variants:
                 if result[0] is None:
                     variants = analyzer.compute_variants(word_id)
                     if variants:
-                        with self.conn.cursor() as cur:
-                            hid = cur.scalar("SELECT create_analyzed_hnr_id(%s, %s)",
+                        hid = execute_scalar(self.conn, "SELECT create_analyzed_hnr_id(%s, %s)",
                                              (word_id, list(variants)))
                                              (word_id, list(variants)))
-                            result = hid, variants[0]
-                            self._cache.housenumbers[word_id] = result
+                        result = hid, variants[0]
+                        self._cache.housenumbers[word_id] = result
 
         return result
 
 
         return result
 
index 136a733159cd3180e2e0558e48b97141887165c1..0e8dfcf97fc07b51216e88862b13f6018043f082 100644 (file)
@@ -18,10 +18,10 @@ from textwrap import dedent
 
 from icu import Transliterator
 import psycopg2
 
 from icu import Transliterator
 import psycopg2
-import psycopg2.extras
 
 from ..errors import UsageError
 
 from ..errors import UsageError
-from ..db.connection import connect, Connection
+from ..db.connection import connect, Connection, drop_tables, table_exists,\
+                            execute_scalar, register_hstore
 from ..config import Configuration
 from ..db import properties
 from ..db import utils as db_utils
 from ..config import Configuration
 from ..db import properties
 from ..db import utils as db_utils
@@ -179,11 +179,10 @@ class LegacyTokenizer(AbstractTokenizer):
              * Can nominatim.so be accessed by the database user?
              """
         with connect(self.dsn) as conn:
              * Can nominatim.so be accessed by the database user?
              """
         with connect(self.dsn) as conn:
-            with conn.cursor() as cur:
-                try:
-                    out = cur.scalar("SELECT make_standard_name('a')")
-                except psycopg2.Error as err:
-                    return hint.format(error=str(err))
+            try:
+                out = execute_scalar(conn, "SELECT make_standard_name('a')")
+            except psycopg2.Error as err:
+                return hint.format(error=str(err))
 
         if out != 'a':
             return hint.format(error='Unexpected result for make_standard_name()')
 
         if out != 'a':
             return hint.format(error='Unexpected result for make_standard_name()')
@@ -214,9 +213,9 @@ class LegacyTokenizer(AbstractTokenizer):
         """ Recompute the frequency of full words.
         """
         with connect(self.dsn) as conn:
         """ Recompute the frequency of full words.
         """
         with connect(self.dsn) as conn:
-            if conn.table_exists('search_name'):
+            if table_exists(conn, 'search_name'):
+                drop_tables(conn, "word_frequencies")
                 with conn.cursor() as cur:
                 with conn.cursor() as cur:
-                    cur.drop_table("word_frequencies")
                     LOG.info("Computing word frequencies")
                     cur.execute("""CREATE TEMP TABLE word_frequencies AS
                                      SELECT unnest(name_vector) as id, count(*)
                     LOG.info("Computing word frequencies")
                     cur.execute("""CREATE TEMP TABLE word_frequencies AS
                                      SELECT unnest(name_vector) as id, count(*)
@@ -226,7 +225,7 @@ class LegacyTokenizer(AbstractTokenizer):
                     cur.execute("""UPDATE word SET search_name_count = count
                                    FROM word_frequencies
                                    WHERE word_token like ' %' and word_id = id""")
                     cur.execute("""UPDATE word SET search_name_count = count
                                    FROM word_frequencies
                                    WHERE word_token like ' %' and word_id = id""")
-                    cur.drop_table("word_frequencies")
+                drop_tables(conn, "word_frequencies")
             conn.commit()
 
 
             conn.commit()
 
 
@@ -316,7 +315,7 @@ class LegacyNameAnalyzer(AbstractAnalyzer):
         self.conn: Optional[Connection] = connect(dsn).connection
         self.conn.autocommit = True
         self.normalizer = normalizer
         self.conn: Optional[Connection] = connect(dsn).connection
         self.conn.autocommit = True
         self.normalizer = normalizer
-        psycopg2.extras.register_hstore(self.conn)
+        register_hstore(self.conn)
 
         self._cache = _TokenCache(self.conn)
 
 
         self._cache = _TokenCache(self.conn)
 
@@ -536,9 +535,8 @@ class _TokenInfo:
     def add_names(self, conn: Connection, names: Mapping[str, str]) -> None:
         """ Add token information for the names of the place.
         """
     def add_names(self, conn: Connection, names: Mapping[str, str]) -> None:
         """ Add token information for the names of the place.
         """
-        with conn.cursor() as cur:
-            # Create the token IDs for all names.
-            self.data['names'] = cur.scalar("SELECT make_keywords(%s)::text",
+        # Create the token IDs for all names.
+        self.data['names'] = execute_scalar(conn, "SELECT make_keywords(%s)::text",
                                             (names, ))
 
 
                                             (names, ))
 
 
@@ -576,9 +574,8 @@ class _TokenInfo:
         """ Add addr:street match terms.
         """
         def _get_street(name: str) -> Optional[str]:
         """ Add addr:street match terms.
         """
         def _get_street(name: str) -> Optional[str]:
-            with conn.cursor() as cur:
-                return cast(Optional[str],
-                            cur.scalar("SELECT word_ids_from_name(%s)::text", (name, )))
+            return cast(Optional[str],
+                        execute_scalar(conn, "SELECT word_ids_from_name(%s)::text", (name, )))
 
         tokens = self.cache.streets.get(street, _get_street)
         self.data['street'] = tokens or '{}'
 
         tokens = self.cache.streets.get(street, _get_street)
         self.data['street'] = tokens or '{}'
index cea2ad664a7e339c2c66dd23790ab517dcd7b338..3e502199aa287cc6c3834f9c485c2aca7bebe91a 100644 (file)
@@ -10,12 +10,12 @@ Functions for database analysis and maintenance.
 from typing import Optional, Tuple, Any, cast
 import logging
 
 from typing import Optional, Tuple, Any, cast
 import logging
 
-from psycopg2.extras import Json, register_hstore
+from psycopg2.extras import Json
 from psycopg2 import DataError
 
 from ..typing import DictCursorResult
 from ..config import Configuration
 from psycopg2 import DataError
 
 from ..typing import DictCursorResult
 from ..config import Configuration
-from ..db.connection import connect, Cursor
+from ..db.connection import connect, Cursor, register_hstore
 from ..errors import UsageError
 from ..tokenizer import factory as tokenizer_factory
 from ..data.place_info import PlaceInfo
 from ..errors import UsageError
 from ..tokenizer import factory as tokenizer_factory
 from ..data.place_info import PlaceInfo
index ef28a0e5a61653b6a2f2c250137050130caff084..946f929138ca6453083d3eabf899a24d025b67ed 100644 (file)
@@ -12,7 +12,8 @@ from enum import Enum
 from textwrap import dedent
 
 from ..config import Configuration
 from textwrap import dedent
 
 from ..config import Configuration
-from ..db.connection import connect, Connection
+from ..db.connection import connect, Connection, server_version_tuple,\
+                            index_exists, table_exists, execute_scalar
 from ..db import properties
 from ..errors import UsageError
 from ..tokenizer import factory as tokenizer_factory
 from ..db import properties
 from ..errors import UsageError
 from ..tokenizer import factory as tokenizer_factory
@@ -109,14 +110,14 @@ def _get_indexes(conn: Connection) -> List[str]:
                'idx_postcode_id',
                'idx_postcode_postcode'
               ]
                'idx_postcode_id',
                'idx_postcode_postcode'
               ]
-    if conn.table_exists('search_name'):
+    if table_exists(conn, 'search_name'):
         indexes.extend(('idx_search_name_nameaddress_vector',
                         'idx_search_name_name_vector',
                         'idx_search_name_centroid'))
         indexes.extend(('idx_search_name_nameaddress_vector',
                         'idx_search_name_name_vector',
                         'idx_search_name_centroid'))
-        if conn.server_version_tuple() >= (11, 0, 0):
+        if server_version_tuple(conn) >= (11, 0, 0):
             indexes.extend(('idx_placex_housenumber',
                             'idx_osmline_parent_osm_id_with_hnr'))
             indexes.extend(('idx_placex_housenumber',
                             'idx_osmline_parent_osm_id_with_hnr'))
-    if conn.table_exists('place'):
+    if table_exists(conn, 'place'):
         indexes.extend(('idx_location_area_country_place_id',
                         'idx_place_osm_unique',
                         'idx_placex_rank_address_sector',
         indexes.extend(('idx_location_area_country_place_id',
                         'idx_place_osm_unique',
                         'idx_placex_rank_address_sector',
@@ -153,7 +154,7 @@ def check_connection(conn: Any, config: Configuration) -> CheckResult:
 
              Hints:
              * Are you connecting to the correct database?
 
              Hints:
              * Are you connecting to the correct database?
-             
+
              {instruction}
 
              Check the Migration chapter of the Administration Guide.
              {instruction}
 
              Check the Migration chapter of the Administration Guide.
@@ -165,7 +166,7 @@ def check_database_version(conn: Connection, config: Configuration) -> CheckResu
     """ Checking database_version matches Nominatim software version
     """
 
     """ Checking database_version matches Nominatim software version
     """
 
-    if conn.table_exists('nominatim_properties'):
+    if table_exists(conn, 'nominatim_properties'):
         db_version_str = properties.get_property(conn, 'database_version')
     else:
         db_version_str = None
         db_version_str = properties.get_property(conn, 'database_version')
     else:
         db_version_str = None
@@ -202,7 +203,7 @@ def check_database_version(conn: Connection, config: Configuration) -> CheckResu
 def check_placex_table(conn: Connection, config: Configuration) -> CheckResult:
     """ Checking for placex table
     """
 def check_placex_table(conn: Connection, config: Configuration) -> CheckResult:
     """ Checking for placex table
     """
-    if conn.table_exists('placex'):
+    if table_exists(conn, 'placex'):
         return CheckState.OK
 
     return CheckState.FATAL, dict(config=config)
         return CheckState.OK
 
     return CheckState.FATAL, dict(config=config)
@@ -212,8 +213,7 @@ def check_placex_table(conn: Connection, config: Configuration) -> CheckResult:
 def check_placex_size(conn: Connection, _: Configuration) -> CheckResult:
     """ Checking for placex content
     """
 def check_placex_size(conn: Connection, _: Configuration) -> CheckResult:
     """ Checking for placex content
     """
-    with conn.cursor() as cur:
-        cnt = cur.scalar('SELECT count(*) FROM (SELECT * FROM placex LIMIT 100) x')
+    cnt = execute_scalar(conn, 'SELECT count(*) FROM (SELECT * FROM placex LIMIT 100) x')
 
     return CheckState.OK if cnt > 0 else CheckState.FATAL
 
 
     return CheckState.OK if cnt > 0 else CheckState.FATAL
 
@@ -244,16 +244,15 @@ def check_tokenizer(_: Connection, config: Configuration) -> CheckResult:
 def check_existance_wikipedia(conn: Connection, _: Configuration) -> CheckResult:
     """ Checking for wikipedia/wikidata data
     """
 def check_existance_wikipedia(conn: Connection, _: Configuration) -> CheckResult:
     """ Checking for wikipedia/wikidata data
     """
-    if not conn.table_exists('search_name') or not conn.table_exists('place'):
+    if not table_exists(conn, 'search_name') or not table_exists(conn, 'place'):
         return CheckState.NOT_APPLICABLE
 
         return CheckState.NOT_APPLICABLE
 
-    with conn.cursor() as cur:
-        if conn.table_exists('wikimedia_importance'):
-            cnt = cur.scalar('SELECT count(*) FROM wikimedia_importance')
-        else:
-            cnt = cur.scalar('SELECT count(*) FROM wikipedia_article')
+    if table_exists(conn, 'wikimedia_importance'):
+        cnt = execute_scalar(conn, 'SELECT count(*) FROM wikimedia_importance')
+    else:
+        cnt = execute_scalar(conn, 'SELECT count(*) FROM wikipedia_article')
 
 
-        return CheckState.WARN if cnt == 0 else CheckState.OK
+    return CheckState.WARN if cnt == 0 else CheckState.OK
 
 
 @_check(hint="""\
 
 
 @_check(hint="""\
@@ -264,8 +263,7 @@ def check_existance_wikipedia(conn: Connection, _: Configuration) -> CheckResult
 def check_indexing(conn: Connection, _: Configuration) -> CheckResult:
     """ Checking indexing status
     """
 def check_indexing(conn: Connection, _: Configuration) -> CheckResult:
     """ Checking indexing status
     """
-    with conn.cursor() as cur:
-        cnt = cur.scalar('SELECT count(*) FROM placex WHERE indexed_status > 0')
+    cnt = execute_scalar(conn, 'SELECT count(*) FROM placex WHERE indexed_status > 0')
 
     if cnt == 0:
         return CheckState.OK
 
     if cnt == 0:
         return CheckState.OK
@@ -276,7 +274,7 @@ def check_indexing(conn: Connection, _: Configuration) -> CheckResult:
             Low counts of unindexed places are fine."""
         return CheckState.WARN, dict(count=cnt, index_cmd=index_cmd)
 
             Low counts of unindexed places are fine."""
         return CheckState.WARN, dict(count=cnt, index_cmd=index_cmd)
 
-    if conn.index_exists('idx_placex_rank_search'):
+    if index_exists(conn, 'idx_placex_rank_search'):
         # Likely just an interrupted update.
         index_cmd = 'nominatim index'
     else:
         # Likely just an interrupted update.
         index_cmd = 'nominatim index'
     else:
@@ -297,7 +295,7 @@ def check_database_indexes(conn: Connection, _: Configuration) -> CheckResult:
     """
     missing = []
     for index in _get_indexes(conn):
     """
     missing = []
     for index in _get_indexes(conn):
-        if not conn.index_exists(index):
+        if not index_exists(conn, index):
             missing.append(index)
 
     if missing:
             missing.append(index)
 
     if missing:
@@ -340,11 +338,10 @@ def check_tiger_table(conn: Connection, config: Configuration) -> CheckResult:
     if not config.get_bool('USE_US_TIGER_DATA'):
         return CheckState.NOT_APPLICABLE
 
     if not config.get_bool('USE_US_TIGER_DATA'):
         return CheckState.NOT_APPLICABLE
 
-    if not conn.table_exists('location_property_tiger'):
+    if not table_exists(conn, 'location_property_tiger'):
         return CheckState.FAIL, dict(error='TIGER data table not found.')
 
         return CheckState.FAIL, dict(error='TIGER data table not found.')
 
-    with conn.cursor() as cur:
-        if cur.scalar('SELECT count(*) FROM location_property_tiger') == 0:
-            return CheckState.FAIL, dict(error='TIGER data table is empty.')
+    if execute_scalar(conn, 'SELECT count(*) FROM location_property_tiger') == 0:
+        return CheckState.FAIL, dict(error='TIGER data table is empty.')
 
     return CheckState.OK
 
     return CheckState.OK
index e1f8b16637a4cd02de87fb582f8932b8695cd1b1..db3e773d8559470ce610fe2acfa286bb7bfd1b8f 100644 (file)
@@ -12,21 +12,16 @@ import os
 import subprocess
 import sys
 from pathlib import Path
 import subprocess
 import sys
 from pathlib import Path
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Union
 
 import psutil
 
 import psutil
-from psycopg2.extensions import make_dsn, parse_dsn
+from psycopg2.extensions import make_dsn
 
 from ..config import Configuration
 
 from ..config import Configuration
-from ..db.connection import connect
+from ..db.connection import connect, server_version_tuple, execute_scalar
 from ..version import NOMINATIM_VERSION
 
 
 from ..version import NOMINATIM_VERSION
 
 
-def convert_version(ver_tup: Tuple[int, int]) -> str:
-    """converts tuple version (ver_tup) to a string representation"""
-    return ".".join(map(str, ver_tup))
-
-
 def friendly_memory_string(mem: float) -> str:
     """Create a user friendly string for the amount of memory specified as mem"""
     mem_magnitude = ("bytes", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB")
 def friendly_memory_string(mem: float) -> str:
     """Create a user friendly string for the amount of memory specified as mem"""
     mem_magnitude = ("bytes", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB")
@@ -103,16 +98,16 @@ def report_system_information(config: Configuration) -> None:
     storage, and database configuration."""
 
     with connect(make_dsn(config.get_libpq_dsn(), dbname='postgres')) as conn:
     storage, and database configuration."""
 
     with connect(make_dsn(config.get_libpq_dsn(), dbname='postgres')) as conn:
-        postgresql_ver: str = convert_version(conn.server_version_tuple())
+        postgresql_ver: str = '.'.join(map(str, server_version_tuple(conn)))
 
         with conn.cursor() as cur:
 
         with conn.cursor() as cur:
-            num = cur.scalar("SELECT count(*) FROM pg_catalog.pg_database WHERE datname=%s",
-                             (parse_dsn(config.get_libpq_dsn())['dbname'], ))
-            nominatim_db_exists = num == 1 if isinstance(num, int) else False
+            cur.execute("SELECT datname FROM pg_catalog.pg_database WHERE datname=%s",
+                        (config.get_database_params()['dbname'], ))
+            nominatim_db_exists = cur.rowcount > 0
 
     if nominatim_db_exists:
         with connect(config.get_libpq_dsn()) as conn:
 
     if nominatim_db_exists:
         with connect(config.get_libpq_dsn()) as conn:
-            postgis_ver: str = convert_version(conn.postgis_version_tuple())
+            postgis_ver: str = execute_scalar(conn, 'SELECT postgis_lib_version()')
     else:
         postgis_ver = "Unable to connect to database"
 
     else:
         postgis_ver = "Unable to connect to database"
 
index c4b3023a8585d57d28ffee15f06d1c3ddc4f60f5..2398d4043a734ceb6b3cf45971e998133fe6456a 100644 (file)
@@ -19,7 +19,8 @@ from psycopg2 import sql as pysql
 
 from ..errors import UsageError
 from ..config import Configuration
 
 from ..errors import UsageError
 from ..config import Configuration
-from ..db.connection import connect, get_pg_env, Connection
+from ..db.connection import connect, get_pg_env, Connection, server_version_tuple,\
+                            postgis_version_tuple, drop_tables, table_exists, execute_scalar
 from ..db.async_connection import DBConnection
 from ..db.sql_preprocessor import SQLPreprocessor
 from .exec_utils import run_osm2pgsql
 from ..db.async_connection import DBConnection
 from ..db.sql_preprocessor import SQLPreprocessor
 from .exec_utils import run_osm2pgsql
@@ -51,10 +52,10 @@ def check_existing_database_plugins(dsn: str) -> None:
     """ Check that the database has the required plugins installed."""
     with connect(dsn) as conn:
         _require_version('PostgreSQL server',
     """ Check that the database has the required plugins installed."""
     with connect(dsn) as conn:
         _require_version('PostgreSQL server',
-                         conn.server_version_tuple(),
+                         server_version_tuple(conn),
                          POSTGRESQL_REQUIRED_VERSION)
         _require_version('PostGIS',
                          POSTGRESQL_REQUIRED_VERSION)
         _require_version('PostGIS',
-                         conn.postgis_version_tuple(),
+                         postgis_version_tuple(conn),
                          POSTGIS_REQUIRED_VERSION)
         _require_loaded('hstore', conn)
 
                          POSTGIS_REQUIRED_VERSION)
         _require_loaded('hstore', conn)
 
@@ -80,31 +81,30 @@ def setup_database_skeleton(dsn: str, rouser: Optional[str] = None) -> None:
 
     with connect(dsn) as conn:
         _require_version('PostgreSQL server',
 
     with connect(dsn) as conn:
         _require_version('PostgreSQL server',
-                         conn.server_version_tuple(),
+                         server_version_tuple(conn),
                          POSTGRESQL_REQUIRED_VERSION)
 
         if rouser is not None:
                          POSTGRESQL_REQUIRED_VERSION)
 
         if rouser is not None:
-            with conn.cursor() as cur:
-                cnt = cur.scalar('SELECT count(*) FROM pg_user where usename = %s',
+            cnt = execute_scalar(conn, 'SELECT count(*) FROM pg_user where usename = %s',
                                  (rouser, ))
                                  (rouser, ))
-                if cnt == 0:
-                    LOG.fatal("Web user '%s' does not exist. Create it with:\n"
-                              "\n      createuser %s", rouser, rouser)
-                    raise UsageError('Missing read-only user.')
+            if cnt == 0:
+                LOG.fatal("Web user '%s' does not exist. Create it with:\n"
+                          "\n      createuser %s", rouser, rouser)
+                raise UsageError('Missing read-only user.')
 
         # Create extensions.
         with conn.cursor() as cur:
             cur.execute('CREATE EXTENSION IF NOT EXISTS hstore')
             cur.execute('CREATE EXTENSION IF NOT EXISTS postgis')
 
 
         # Create extensions.
         with conn.cursor() as cur:
             cur.execute('CREATE EXTENSION IF NOT EXISTS hstore')
             cur.execute('CREATE EXTENSION IF NOT EXISTS postgis')
 
-            postgis_version = conn.postgis_version_tuple()
+            postgis_version = postgis_version_tuple(conn)
             if postgis_version[0] >= 3:
                 cur.execute('CREATE EXTENSION IF NOT EXISTS postgis_raster')
 
         conn.commit()
 
         _require_version('PostGIS',
             if postgis_version[0] >= 3:
                 cur.execute('CREATE EXTENSION IF NOT EXISTS postgis_raster')
 
         conn.commit()
 
         _require_version('PostGIS',
-                         conn.postgis_version_tuple(),
+                         postgis_version_tuple(conn),
                          POSTGIS_REQUIRED_VERSION)
 
 
                          POSTGIS_REQUIRED_VERSION)
 
 
@@ -141,7 +141,8 @@ def import_osm_data(osm_files: Union[Path, Sequence[Path]],
                     raise UsageError('No data imported by osm2pgsql.')
 
         if drop:
                     raise UsageError('No data imported by osm2pgsql.')
 
         if drop:
-            conn.drop_table('planet_osm_nodes')
+            drop_tables(conn, 'planet_osm_nodes')
+            conn.commit()
 
     if drop and options['flatnode_file']:
         Path(options['flatnode_file']).unlink()
 
     if drop and options['flatnode_file']:
         Path(options['flatnode_file']).unlink()
@@ -184,7 +185,7 @@ def truncate_data_tables(conn: Connection) -> None:
         cur.execute('TRUNCATE location_property_tiger')
         cur.execute('TRUNCATE location_property_osmline')
         cur.execute('TRUNCATE location_postcode')
         cur.execute('TRUNCATE location_property_tiger')
         cur.execute('TRUNCATE location_property_osmline')
         cur.execute('TRUNCATE location_postcode')
-        if conn.table_exists('search_name'):
+        if table_exists(conn, 'search_name'):
             cur.execute('TRUNCATE search_name')
         cur.execute('DROP SEQUENCE IF EXISTS seq_place')
         cur.execute('CREATE SEQUENCE seq_place start 100000')
             cur.execute('TRUNCATE search_name')
         cur.execute('DROP SEQUENCE IF EXISTS seq_place')
         cur.execute('CREATE SEQUENCE seq_place start 100000')
index bd52ba9a5a7ad194c353a721d1529ea4cefa20cb..e6d80e1e156c3342feaabe747ec4183dcef02cf6 100644 (file)
@@ -12,7 +12,7 @@ from pathlib import Path
 
 from psycopg2 import sql as pysql
 
 
 from psycopg2 import sql as pysql
 
-from ..db.connection import Connection
+from ..db.connection import Connection, drop_tables, table_exists
 
 UPDATE_TABLES = [
     'address_levels',
 
 UPDATE_TABLES = [
     'address_levels',
@@ -39,9 +39,7 @@ def drop_update_tables(conn: Connection) -> None:
                     + pysql.SQL(' or ').join(parts))
         tables = [r[0] for r in cur]
 
                     + pysql.SQL(' or ').join(parts))
         tables = [r[0] for r in cur]
 
-        for table in tables:
-            cur.drop_table(table, cascade=True)
-
+    drop_tables(conn, *tables, cascade=True)
     conn.commit()
 
 
     conn.commit()
 
 
@@ -55,4 +53,4 @@ def is_frozen(conn: Connection) -> bool:
     """ Returns true if database is in a frozen state
     """
 
     """ Returns true if database is in a frozen state
     """
 
-    return conn.table_exists('place') is False
+    return table_exists(conn, 'place') is False
index e6803c7dea23521a805aa939584d028f66b4a7b5..46ba0125bc90a73ecafed806032aa5af13af16f2 100644 (file)
@@ -15,7 +15,8 @@ from psycopg2 import sql as pysql
 from ..errors import UsageError
 from ..config import Configuration
 from ..db import properties
 from ..errors import UsageError
 from ..config import Configuration
 from ..db import properties
-from ..db.connection import connect, Connection
+from ..db.connection import connect, Connection, server_version_tuple,\
+                            table_has_column, table_exists, execute_scalar, register_hstore
 from ..version import NominatimVersion, NOMINATIM_VERSION, parse_version
 from ..tokenizer import factory as tokenizer_factory
 from . import refresh
 from ..version import NominatimVersion, NOMINATIM_VERSION, parse_version
 from ..tokenizer import factory as tokenizer_factory
 from . import refresh
@@ -29,7 +30,8 @@ def migrate(config: Configuration, paths: Any) -> int:
         if necesssary.
     """
     with connect(config.get_libpq_dsn()) as conn:
         if necesssary.
     """
     with connect(config.get_libpq_dsn()) as conn:
-        if conn.table_exists('nominatim_properties'):
+        register_hstore(conn)
+        if table_exists(conn, 'nominatim_properties'):
             db_version_str = properties.get_property(conn, 'database_version')
         else:
             db_version_str = None
             db_version_str = properties.get_property(conn, 'database_version')
         else:
             db_version_str = None
@@ -72,16 +74,15 @@ def _guess_version(conn: Connection) -> NominatimVersion:
         Only migrations for 3.6 and later are supported, so bail out
         when the version seems older.
     """
         Only migrations for 3.6 and later are supported, so bail out
         when the version seems older.
     """
-    with conn.cursor() as cur:
-        # In version 3.6, the country_name table was updated. Check for that.
-        cnt = cur.scalar("""SELECT count(*) FROM
-                            (SELECT svals(name) FROM  country_name
-                             WHERE country_code = 'gb')x;
-                         """)
-        if cnt < 100:
-            LOG.fatal('It looks like your database was imported with a version '
-                      'prior to 3.6.0. Automatic migration not possible.')
-            raise UsageError('Migration not possible.')
+    # In version 3.6, the country_name table was updated. Check for that.
+    cnt = execute_scalar(conn, """SELECT count(*) FROM
+                                  (SELECT svals(name) FROM  country_name
+                                  WHERE country_code = 'gb')x;
+                               """)
+    if cnt < 100:
+        LOG.fatal('It looks like your database was imported with a version '
+                  'prior to 3.6.0. Automatic migration not possible.')
+        raise UsageError('Migration not possible.')
 
     return NominatimVersion(3, 5, 0, 99)
 
 
     return NominatimVersion(3, 5, 0, 99)
 
@@ -125,7 +126,7 @@ def import_status_timestamp_change(conn: Connection, **_: Any) -> None:
 def add_nominatim_property_table(conn: Connection, config: Configuration, **_: Any) -> None:
     """ Add nominatim_property table.
     """
 def add_nominatim_property_table(conn: Connection, config: Configuration, **_: Any) -> None:
     """ Add nominatim_property table.
     """
-    if not conn.table_exists('nominatim_properties'):
+    if not table_exists(conn, 'nominatim_properties'):
         with conn.cursor() as cur:
             cur.execute(pysql.SQL("""CREATE TABLE nominatim_properties (
                                         property TEXT,
         with conn.cursor() as cur:
             cur.execute(pysql.SQL("""CREATE TABLE nominatim_properties (
                                         property TEXT,
@@ -189,13 +190,9 @@ def install_legacy_tokenizer(conn: Connection, config: Configuration, **_: Any)
         configuration for the backwards-compatible legacy tokenizer
     """
     if properties.get_property(conn, 'tokenizer') is None:
         configuration for the backwards-compatible legacy tokenizer
     """
     if properties.get_property(conn, 'tokenizer') is None:
-        with conn.cursor() as cur:
-            for table in ('placex', 'location_property_osmline'):
-                has_column = cur.scalar("""SELECT count(*) FROM information_schema.columns
-                                           WHERE table_name = %s
-                                           and column_name = 'token_info'""",
-                                        (table, ))
-                if has_column == 0:
+        for table in ('placex', 'location_property_osmline'):
+            if not table_has_column(conn, table, 'token_info'):
+                with conn.cursor() as cur:
                     cur.execute(pysql.SQL('ALTER TABLE {} ADD COLUMN token_info JSONB')
                                 .format(pysql.Identifier(table)))
         tokenizer = tokenizer_factory.create_tokenizer(config, init_db=False,
                     cur.execute(pysql.SQL('ALTER TABLE {} ADD COLUMN token_info JSONB')
                                 .format(pysql.Identifier(table)))
         tokenizer = tokenizer_factory.create_tokenizer(config, init_db=False,
@@ -212,7 +209,7 @@ def create_tiger_housenumber_index(conn: Connection, **_: Any) -> None:
         The inclusion is needed for efficient lookup of housenumbers in
         full address searches.
     """
         The inclusion is needed for efficient lookup of housenumbers in
         full address searches.
     """
-    if conn.server_version_tuple() >= (11, 0, 0):
+    if server_version_tuple(conn) >= (11, 0, 0):
         with conn.cursor() as cur:
             cur.execute(""" CREATE INDEX IF NOT EXISTS
                                 idx_location_property_tiger_housenumber_migrated
         with conn.cursor() as cur:
             cur.execute(""" CREATE INDEX IF NOT EXISTS
                                 idx_location_property_tiger_housenumber_migrated
@@ -239,7 +236,7 @@ def add_step_column_for_interpolation(conn: Connection, **_: Any) -> None:
         Also converts the data into the stricter format which requires that
         startnumbers comply with the odd/even requirements.
     """
         Also converts the data into the stricter format which requires that
         startnumbers comply with the odd/even requirements.
     """
-    if conn.table_has_column('location_property_osmline', 'step'):
+    if table_has_column(conn, 'location_property_osmline', 'step'):
         return
 
     with conn.cursor() as cur:
         return
 
     with conn.cursor() as cur:
@@ -271,7 +268,7 @@ def add_step_column_for_interpolation(conn: Connection, **_: Any) -> None:
 def add_step_column_for_tiger(conn: Connection, **_: Any) -> None:
     """ Add a new column 'step' to the tiger data table.
     """
 def add_step_column_for_tiger(conn: Connection, **_: Any) -> None:
     """ Add a new column 'step' to the tiger data table.
     """
-    if conn.table_has_column('location_property_tiger', 'step'):
+    if table_has_column(conn, 'location_property_tiger', 'step'):
         return
 
     with conn.cursor() as cur:
         return
 
     with conn.cursor() as cur:
@@ -287,7 +284,7 @@ def add_derived_name_column_for_country_names(conn: Connection, **_: Any) -> Non
     """ Add a new column 'derived_name' which in the future takes the
         country names as imported from OSM data.
     """
     """ Add a new column 'derived_name' which in the future takes the
         country names as imported from OSM data.
     """
-    if not conn.table_has_column('country_name', 'derived_name'):
+    if not table_has_column(conn, 'country_name', 'derived_name'):
         with conn.cursor() as cur:
             cur.execute("ALTER TABLE country_name ADD COLUMN derived_name public.HSTORE")
 
         with conn.cursor() as cur:
             cur.execute("ALTER TABLE country_name ADD COLUMN derived_name public.HSTORE")
 
@@ -297,12 +294,9 @@ def mark_internal_country_names(conn: Connection, config: Configuration, **_: An
     """ Names from the country table should be marked as internal to prevent
         them from being deleted. Only necessary for ICU tokenizer.
     """
     """ Names from the country table should be marked as internal to prevent
         them from being deleted. Only necessary for ICU tokenizer.
     """
-    import psycopg2.extras # pylint: disable=import-outside-toplevel
-
     tokenizer = tokenizer_factory.get_tokenizer_for_db(config)
     with tokenizer.name_analyzer() as analyzer:
         with conn.cursor() as cur:
     tokenizer = tokenizer_factory.get_tokenizer_for_db(config)
     with tokenizer.name_analyzer() as analyzer:
         with conn.cursor() as cur:
-            psycopg2.extras.register_hstore(cur)
             cur.execute("SELECT country_code, name FROM country_name")
 
             for country_code, names in cur:
             cur.execute("SELECT country_code, name FROM country_name")
 
             for country_code, names in cur:
@@ -319,7 +313,7 @@ def add_place_deletion_todo_table(conn: Connection, **_: Any) -> None:
         The table is only necessary when updates are possible, i.e.
         the database is not in freeze mode.
     """
         The table is only necessary when updates are possible, i.e.
         the database is not in freeze mode.
     """
-    if conn.table_exists('place'):
+    if table_exists(conn, 'place'):
         with conn.cursor() as cur:
             cur.execute("""CREATE TABLE IF NOT EXISTS place_to_be_deleted (
                              osm_type CHAR(1),
         with conn.cursor() as cur:
             cur.execute("""CREATE TABLE IF NOT EXISTS place_to_be_deleted (
                              osm_type CHAR(1),
@@ -333,7 +327,7 @@ def add_place_deletion_todo_table(conn: Connection, **_: Any) -> None:
 def split_pending_index(conn: Connection, **_: Any) -> None:
     """ Reorganise indexes for pending updates.
     """
 def split_pending_index(conn: Connection, **_: Any) -> None:
     """ Reorganise indexes for pending updates.
     """
-    if conn.table_exists('place'):
+    if table_exists(conn, 'place'):
         with conn.cursor() as cur:
             cur.execute("""CREATE INDEX IF NOT EXISTS idx_placex_rank_address_sector
                            ON placex USING BTREE (rank_address, geometry_sector)
         with conn.cursor() as cur:
             cur.execute("""CREATE INDEX IF NOT EXISTS idx_placex_rank_address_sector
                            ON placex USING BTREE (rank_address, geometry_sector)
@@ -349,7 +343,7 @@ def split_pending_index(conn: Connection, **_: Any) -> None:
 def enable_forward_dependencies(conn: Connection, **_: Any) -> None:
     """ Create indexes for updates with forward dependency tracking (long-running).
     """
 def enable_forward_dependencies(conn: Connection, **_: Any) -> None:
     """ Create indexes for updates with forward dependency tracking (long-running).
     """
-    if conn.table_exists('planet_osm_ways'):
+    if table_exists(conn, 'planet_osm_ways'):
         with conn.cursor() as cur:
             cur.execute("""SELECT * FROM pg_indexes
                            WHERE tablename = 'planet_osm_ways'
         with conn.cursor() as cur:
             cur.execute("""SELECT * FROM pg_indexes
                            WHERE tablename = 'planet_osm_ways'
@@ -398,7 +392,7 @@ def create_postcode_area_lookup_index(conn: Connection, **_: Any) -> None:
 def create_postcode_parent_index(conn: Connection, **_: Any) -> None:
     """ Create index needed for updating postcodes when a parent changes.
     """
 def create_postcode_parent_index(conn: Connection, **_: Any) -> None:
     """ Create index needed for updating postcodes when a parent changes.
     """
-    if conn.table_exists('planet_osm_ways'):
+    if table_exists(conn, 'planet_osm_ways'):
         with conn.cursor() as cur:
             cur.execute("""CREATE INDEX IF NOT EXISTS
                              idx_location_postcode_parent_place_id
         with conn.cursor() as cur:
             cur.execute("""CREATE INDEX IF NOT EXISTS
                              idx_location_postcode_parent_place_id
index 8dc5bdbdb43cacb46fec70a86d553734a4ef92d7..a5d8ef8becf0ac53f562dbd45acc856cd4b8cfb9 100644 (file)
@@ -18,7 +18,7 @@ from math import isfinite
 
 from psycopg2 import sql as pysql
 
 
 from psycopg2 import sql as pysql
 
-from ..db.connection import connect, Connection
+from ..db.connection import connect, Connection, table_exists
 from ..utils.centroid import PointsCentroid
 from ..data.postcode_format import PostcodeFormatter, CountryPostcodeMatcher
 from ..tokenizer.base import AbstractAnalyzer, AbstractTokenizer
 from ..utils.centroid import PointsCentroid
 from ..data.postcode_format import PostcodeFormatter, CountryPostcodeMatcher
 from ..tokenizer.base import AbstractAnalyzer, AbstractTokenizer
@@ -231,4 +231,4 @@ def can_compute(dsn: str) -> bool:
         postcodes can be computed.
     """
     with connect(dsn) as conn:
         postcodes can be computed.
     """
     with connect(dsn) as conn:
-        return conn.table_exists('place')
+        return table_exists(conn, 'place')
index 6a40c0a73a4c6427d11fe89b0db0552c19a1e7cf..2e2ffabd07c60ea959bf1cb9aeaf9efa4cfe2b21 100644 (file)
@@ -17,7 +17,8 @@ from pathlib import Path
 from psycopg2 import sql as pysql
 
 from ..config import Configuration
 from psycopg2 import sql as pysql
 
 from ..config import Configuration
-from ..db.connection import Connection, connect
+from ..db.connection import Connection, connect, postgis_version_tuple,\
+                            drop_tables, table_exists
 from ..db.utils import execute_file, CopyBuffer
 from ..db.sql_preprocessor import SQLPreprocessor
 from ..version import NOMINATIM_VERSION
 from ..db.utils import execute_file, CopyBuffer
 from ..db.sql_preprocessor import SQLPreprocessor
 from ..version import NOMINATIM_VERSION
@@ -56,9 +57,9 @@ def load_address_levels(conn: Connection, table: str, levels: Sequence[Mapping[s
     for entry in levels:
         _add_address_level_rows_from_entry(rows, entry)
 
     for entry in levels:
         _add_address_level_rows_from_entry(rows, entry)
 
-    with conn.cursor() as cur:
-        cur.drop_table(table)
+    drop_tables(conn, table)
 
 
+    with conn.cursor() as cur:
         cur.execute(pysql.SQL("""CREATE TABLE {} (
                                         country_code varchar(2),
                                         class TEXT,
         cur.execute(pysql.SQL("""CREATE TABLE {} (
                                         country_code varchar(2),
                                         class TEXT,
@@ -159,10 +160,8 @@ def import_importance_csv(dsn: str, data_file: Path) -> int:
     wd_done = set()
 
     with connect(dsn) as conn:
     wd_done = set()
 
     with connect(dsn) as conn:
+        drop_tables(conn, 'wikipedia_article', 'wikipedia_redirect', 'wikimedia_importance')
         with conn.cursor() as cur:
         with conn.cursor() as cur:
-            cur.drop_table('wikipedia_article')
-            cur.drop_table('wikipedia_redirect')
-            cur.drop_table('wikimedia_importance')
             cur.execute("""CREATE TABLE wikimedia_importance (
                              language TEXT NOT NULL,
                              title TEXT NOT NULL,
             cur.execute("""CREATE TABLE wikimedia_importance (
                              language TEXT NOT NULL,
                              title TEXT NOT NULL,
@@ -228,7 +227,7 @@ def import_secondary_importance(dsn: str, data_path: Path, ignore_errors: bool =
         return 1
 
     with connect(dsn) as conn:
         return 1
 
     with connect(dsn) as conn:
-        postgis_version = conn.postgis_version_tuple()
+        postgis_version = postgis_version_tuple(conn)
         if postgis_version[0] < 3:
             LOG.error('PostGIS version is too old for using OSM raster data.')
             return 2
         if postgis_version[0] < 3:
             LOG.error('PostGIS version is too old for using OSM raster data.')
             return 2
@@ -309,7 +308,7 @@ def setup_website(basedir: Path, config: Configuration, conn: Connection) -> Non
 
     template = "\nrequire_once(CONST_LibDir.'/website/{}');\n"
 
 
     template = "\nrequire_once(CONST_LibDir.'/website/{}');\n"
 
-    search_name_table_exists = bool(conn and conn.table_exists('search_name'))
+    search_name_table_exists = bool(conn and table_exists(conn, 'search_name'))
 
     for script in WEBSITE_SCRIPTS:
         if not search_name_table_exists and script == 'search.php':
 
     for script in WEBSITE_SCRIPTS:
         if not search_name_table_exists and script == 'search.php':
index bf1189df032a85a40a985c20e783edfd74410e29..2b1d444f0cb1e3cd08b6540273ac9b64d136e43f 100644 (file)
@@ -20,7 +20,7 @@ import requests
 
 from ..errors import UsageError
 from ..db import status
 
 from ..errors import UsageError
 from ..db import status
-from ..db.connection import Connection, connect
+from ..db.connection import Connection, connect, server_version_tuple
 from .exec_utils import run_osm2pgsql
 
 try:
 from .exec_utils import run_osm2pgsql
 
 try:
@@ -155,7 +155,7 @@ def run_osm2pgsql_updates(conn: Connection, options: MutableMapping[str, Any]) -
 
     # Consume updates with osm2pgsql.
     options['append'] = True
 
     # Consume updates with osm2pgsql.
     options['append'] = True
-    options['disable_jit'] = conn.server_version_tuple() >= (11, 0)
+    options['disable_jit'] = server_version_tuple(conn) >= (11, 0)
     run_osm2pgsql(options)
 
     # Handle deletions
     run_osm2pgsql(options)
 
     # Handle deletions
index 1bdcdaf133d5d3db10fdb18e23c2b3f6ea24abce..4a63ff149507641f1e3fa2597e38c6570d6d93d4 100644 (file)
@@ -21,7 +21,7 @@ from psycopg2.sql import Identifier, SQL
 
 from ...typing import Protocol
 from ...config import Configuration
 
 from ...typing import Protocol
 from ...config import Configuration
-from ...db.connection import Connection
+from ...db.connection import Connection, drop_tables, index_exists
 from .importer_statistics import SpecialPhrasesImporterStatistics
 from .special_phrase import SpecialPhrase
 from ...tokenizer.base import AbstractTokenizer
 from .importer_statistics import SpecialPhrasesImporterStatistics
 from .special_phrase import SpecialPhrase
 from ...tokenizer.base import AbstractTokenizer
@@ -233,7 +233,7 @@ class SPImporter():
         index_prefix = f'idx_place_classtype_{phrase_class}_{phrase_type}_'
         base_table = _classtype_table(phrase_class, phrase_type)
         # Index on centroid
         index_prefix = f'idx_place_classtype_{phrase_class}_{phrase_type}_'
         base_table = _classtype_table(phrase_class, phrase_type)
         # Index on centroid
-        if not self.db_connection.index_exists(index_prefix + 'centroid'):
+        if not index_exists(self.db_connection, index_prefix + 'centroid'):
             with self.db_connection.cursor() as db_cursor:
                 db_cursor.execute(SQL("CREATE INDEX {} ON {} USING GIST (centroid) {}")
                                   .format(Identifier(index_prefix + 'centroid'),
             with self.db_connection.cursor() as db_cursor:
                 db_cursor.execute(SQL("CREATE INDEX {} ON {} USING GIST (centroid) {}")
                                   .format(Identifier(index_prefix + 'centroid'),
@@ -241,7 +241,7 @@ class SPImporter():
                                           SQL(sql_tablespace)))
 
         # Index on place_id
                                           SQL(sql_tablespace)))
 
         # Index on place_id
-        if not self.db_connection.index_exists(index_prefix + 'place_id'):
+        if not index_exists(self.db_connection, index_prefix + 'place_id'):
             with self.db_connection.cursor() as db_cursor:
                 db_cursor.execute(SQL("CREATE INDEX {} ON {} USING btree(place_id) {}")
                                   .format(Identifier(index_prefix + 'place_id'),
             with self.db_connection.cursor() as db_cursor:
                 db_cursor.execute(SQL("CREATE INDEX {} ON {} USING btree(place_id) {}")
                                   .format(Identifier(index_prefix + 'place_id'),
@@ -259,6 +259,7 @@ class SPImporter():
                               .format(Identifier(table_name),
                                       Identifier(self.config.DATABASE_WEBUSER)))
 
                               .format(Identifier(table_name),
                                       Identifier(self.config.DATABASE_WEBUSER)))
 
+
     def _remove_non_existent_tables_from_db(self) -> None:
         """
             Remove special phrases which doesn't exist on the wiki anymore.
     def _remove_non_existent_tables_from_db(self) -> None:
         """
             Remove special phrases which doesn't exist on the wiki anymore.
@@ -268,7 +269,6 @@ class SPImporter():
 
         # Delete place_classtype tables corresponding to class/type which
         # are not on the wiki anymore.
 
         # Delete place_classtype tables corresponding to class/type which
         # are not on the wiki anymore.
-        with self.db_connection.cursor() as db_cursor:
-            for table in self.table_phrases_to_delete:
-                self.statistics_handler.notify_one_table_deleted()
-                db_cursor.drop_table(table)
+        drop_tables(self.db_connection, *self.table_phrases_to_delete)
+        for _ in self.table_phrases_to_delete:
+            self.statistics_handler.notify_one_table_deleted()
index fff695e53ebc97520ba9893f6fff470846bf7269..9647bedc802471bd41808d297f0698a29bafda50 100644 (file)
@@ -33,13 +33,13 @@ def simple_conns(temp_db):
     conn2.close()
 
 
     conn2.close()
 
 
-def test_simple_query(conn, temp_db_conn):
+def test_simple_query(conn, temp_db_cursor):
     conn.connect()
 
     conn.perform('CREATE TABLE foo (id INT)')
     conn.wait()
 
     conn.connect()
 
     conn.perform('CREATE TABLE foo (id INT)')
     conn.wait()
 
-    temp_db_conn.table_exists('foo')
+    assert temp_db_cursor.table_exists('foo')
 
 
 def test_wait_for_query(conn):
 
 
 def test_wait_for_query(conn):
index 8b4cc62f0c46e6f8bd46db49bc5d3f3ab58cae58..9f1442f3aba6bcdbf523c618cd9184434cf99522 100644 (file)
@@ -10,61 +10,74 @@ Tests for specialised connection and cursor classes.
 import pytest
 import psycopg2
 
 import pytest
 import psycopg2
 
-from nominatim_db.db.connection import connect, get_pg_env
+import nominatim_db.db.connection as nc
 
 @pytest.fixture
 def db(dsn):
 
 @pytest.fixture
 def db(dsn):
-    with connect(dsn) as conn:
+    with nc.connect(dsn) as conn:
         yield conn
 
 
 def test_connection_table_exists(db, table_factory):
         yield conn
 
 
 def test_connection_table_exists(db, table_factory):
-    assert not db.table_exists('foobar')
+    assert not nc.table_exists(db, 'foobar')
 
     table_factory('foobar')
 
 
     table_factory('foobar')
 
-    assert db.table_exists('foobar')
+    assert nc.table_exists(db, 'foobar')
 
 
 def test_has_column_no_table(db):
 
 
 def test_has_column_no_table(db):
-    assert not db.table_has_column('sometable', 'somecolumn')
+    assert not nc.table_has_column(db, 'sometable', 'somecolumn')
 
 
 @pytest.mark.parametrize('name,result', [('tram', True), ('car', False)])
 def test_has_column(db, table_factory, name, result):
     table_factory('stuff', 'tram TEXT')
 
 
 
 @pytest.mark.parametrize('name,result', [('tram', True), ('car', False)])
 def test_has_column(db, table_factory, name, result):
     table_factory('stuff', 'tram TEXT')
 
-    assert db.table_has_column('stuff', name) == result
+    assert nc.table_has_column(db, 'stuff', name) == result
 
 def test_connection_index_exists(db, table_factory, temp_db_cursor):
 
 def test_connection_index_exists(db, table_factory, temp_db_cursor):
-    assert not db.index_exists('some_index')
+    assert not nc.index_exists(db, 'some_index')
 
     table_factory('foobar')
     temp_db_cursor.execute('CREATE INDEX some_index ON foobar(id)')
 
 
     table_factory('foobar')
     temp_db_cursor.execute('CREATE INDEX some_index ON foobar(id)')
 
-    assert db.index_exists('some_index')
-    assert db.index_exists('some_index', table='foobar')
-    assert not db.index_exists('some_index', table='bar')
+    assert nc.index_exists(db, 'some_index')
+    assert nc.index_exists(db, 'some_index', table='foobar')
+    assert not nc.index_exists(db, 'some_index', table='bar')
 
 
 def test_drop_table_existing(db, table_factory):
     table_factory('dummy')
 
 
 def test_drop_table_existing(db, table_factory):
     table_factory('dummy')
-    assert db.table_exists('dummy')
+    assert nc.table_exists(db, 'dummy')
 
 
-    db.drop_table('dummy')
-    assert not db.table_exists('dummy')
+    nc.drop_tables(db, 'dummy')
+    assert not nc.table_exists(db, 'dummy')
 
 
 
 
-def test_drop_table_non_existsing(db):
-    db.drop_table('dfkjgjriogjigjgjrdghehtre')
+def test_drop_table_non_existing(db):
+    nc.drop_tables(db, 'dfkjgjriogjigjgjrdghehtre')
+
+
+def test_drop_many_tables(db, table_factory):
+    tables = [f'table{n}' for n in range(5)]
+
+    for t in tables:
+        table_factory(t)
+        assert nc.table_exists(db, t)
+
+    nc.drop_tables(db, *tables)
+
+    for t in tables:
+        assert not nc.table_exists(db, t)
 
 
 def test_drop_table_non_existing_force(db):
     with pytest.raises(psycopg2.ProgrammingError, match='.*does not exist.*'):
 
 
 def test_drop_table_non_existing_force(db):
     with pytest.raises(psycopg2.ProgrammingError, match='.*does not exist.*'):
-        db.drop_table('dfkjgjriogjigjgjrdghehtre', if_exists=False)
+        nc.drop_tables(db, 'dfkjgjriogjigjgjrdghehtre', if_exists=False)
 
 def test_connection_server_version_tuple(db):
 
 def test_connection_server_version_tuple(db):
-    ver = db.server_version_tuple()
+    ver = nc.server_version_tuple(db)
 
     assert isinstance(ver, tuple)
     assert len(ver) == 2
 
     assert isinstance(ver, tuple)
     assert len(ver) == 2
@@ -72,7 +85,7 @@ def test_connection_server_version_tuple(db):
 
 
 def test_connection_postgis_version_tuple(db, temp_db_with_extensions):
 
 
 def test_connection_postgis_version_tuple(db, temp_db_with_extensions):
-    ver = db.postgis_version_tuple()
+    ver = nc.postgis_version_tuple(db)
 
     assert isinstance(ver, tuple)
     assert len(ver) == 2
 
     assert isinstance(ver, tuple)
     assert len(ver) == 2
@@ -82,27 +95,24 @@ def test_connection_postgis_version_tuple(db, temp_db_with_extensions):
 def test_cursor_scalar(db, table_factory):
     table_factory('dummy')
 
 def test_cursor_scalar(db, table_factory):
     table_factory('dummy')
 
-    with db.cursor() as cur:
-        assert cur.scalar('SELECT count(*) FROM dummy') == 0
+    assert nc.execute_scalar(db, 'SELECT count(*) FROM dummy') == 0
 
 
 def test_cursor_scalar_many_rows(db):
 
 
 def test_cursor_scalar_many_rows(db):
-    with db.cursor() as cur:
-        with pytest.raises(RuntimeError):
-            cur.scalar('SELECT * FROM pg_tables')
+    with pytest.raises(RuntimeError, match='Query did not return a single row.'):
+        nc.execute_scalar(db, 'SELECT * FROM pg_tables')
 
 
 def test_cursor_scalar_no_rows(db, table_factory):
     table_factory('dummy')
 
 
 
 def test_cursor_scalar_no_rows(db, table_factory):
     table_factory('dummy')
 
-    with db.cursor() as cur:
-        with pytest.raises(RuntimeError):
-            cur.scalar('SELECT id FROM dummy')
+    with pytest.raises(RuntimeError, match='Query did not return a single row.'):
+        nc.execute_scalar(db, 'SELECT id FROM dummy')
 
 
 def test_get_pg_env_add_variable(monkeypatch):
     monkeypatch.delenv('PGPASSWORD', raising=False)
 
 
 def test_get_pg_env_add_variable(monkeypatch):
     monkeypatch.delenv('PGPASSWORD', raising=False)
-    env = get_pg_env('user=fooF')
+    env = nc.get_pg_env('user=fooF')
 
     assert env['PGUSER'] == 'fooF'
     assert 'PGPASSWORD' not in env
 
     assert env['PGUSER'] == 'fooF'
     assert 'PGPASSWORD' not in env
@@ -110,12 +120,12 @@ def test_get_pg_env_add_variable(monkeypatch):
 
 def test_get_pg_env_overwrite_variable(monkeypatch):
     monkeypatch.setenv('PGUSER', 'some default')
 
 def test_get_pg_env_overwrite_variable(monkeypatch):
     monkeypatch.setenv('PGUSER', 'some default')
-    env = get_pg_env('user=overwriter')
+    env = nc.get_pg_env('user=overwriter')
 
     assert env['PGUSER'] == 'overwriter'
 
 
 def test_get_pg_env_ignore_unknown():
 
     assert env['PGUSER'] == 'overwriter'
 
 
 def test_get_pg_env_ignore_unknown():
-    env = get_pg_env('client_encoding=stuff', base_env={})
+    env = nc.get_pg_env('client_encoding=stuff', base_env={})
 
     assert env == {}
 
     assert env == {}
index 5c465e8b7e988b6488756f9a8fb8a5c54278ba94..67be1892d4504e6b02e1afdf4848a89dfccf440a 100644 (file)
@@ -8,6 +8,7 @@
 Legacy word table for testing with functions to prefil and test contents
 of the table.
 """
 Legacy word table for testing with functions to prefil and test contents
 of the table.
 """
+from nominatim_db.db.connection import execute_scalar
 
 class MockIcuWordTable:
     """ A word table for testing using legacy word table structure.
 
 class MockIcuWordTable:
     """ A word table for testing using legacy word table structure.
@@ -77,18 +78,15 @@ class MockIcuWordTable:
 
 
     def count(self):
 
 
     def count(self):
-        with self.conn.cursor() as cur:
-            return cur.scalar("SELECT count(*) FROM word")
+        return execute_scalar(self.conn, "SELECT count(*) FROM word")
 
 
     def count_special(self):
 
 
     def count_special(self):
-        with self.conn.cursor() as cur:
-            return cur.scalar("SELECT count(*) FROM word WHERE type = 'S'")
+        return execute_scalar(self.conn, "SELECT count(*) FROM word WHERE type = 'S'")
 
 
     def count_housenumbers(self):
 
 
     def count_housenumbers(self):
-        with self.conn.cursor() as cur:
-            return cur.scalar("SELECT count(*) FROM word WHERE type = 'H'")
+        return execute_scalar(self.conn, "SELECT count(*) FROM word WHERE type = 'H'")
 
 
     def get_special(self):
 
 
     def get_special(self):
index 9804341f41c33fd91df794e15fcf935f3f00d395..d1c523eb7d9a4c766f8d4057518c5c3945d91b52 100644 (file)
@@ -8,6 +8,7 @@
 Legacy word table for testing with functions to prefil and test contents
 of the table.
 """
 Legacy word table for testing with functions to prefil and test contents
 of the table.
 """
+from nominatim_db.db.connection import execute_scalar
 
 class MockLegacyWordTable:
     """ A word table for testing using legacy word table structure.
 
 class MockLegacyWordTable:
     """ A word table for testing using legacy word table structure.
@@ -58,13 +59,11 @@ class MockLegacyWordTable:
 
 
     def count(self):
 
 
     def count(self):
-        with self.conn.cursor() as cur:
-            return cur.scalar("SELECT count(*) FROM word")
+        return execute_scalar(self.conn, "SELECT count(*) FROM word")
 
 
     def count_special(self):
 
 
     def count_special(self):
-        with self.conn.cursor() as cur:
-            return cur.scalar("SELECT count(*) FROM word WHERE class != 'place'")
+        return execute_scalar(self.conn, "SELECT count(*) FROM word WHERE class != 'place'")
 
 
     def get_special(self):
 
 
     def get_special(self):
index 357b7d4ae4cde571315c2a7d8109595911e343f4..a2bf676699ec9a326f59957740740899da8340dc 100644 (file)
@@ -199,16 +199,16 @@ def test_update_sql_functions(db_prop, temp_db_cursor,
     assert test_content == set((('1133', ), ))
 
 
     assert test_content == set((('1133', ), ))
 
 
-def test_finalize_import(tokenizer_factory, temp_db_conn,
-                         temp_db_cursor, test_config, sql_preprocessor_cfg):
+def test_finalize_import(tokenizer_factory, temp_db_cursor,
+                         test_config, sql_preprocessor_cfg):
     tok = tokenizer_factory()
     tok.init_new_db(test_config)
 
     tok = tokenizer_factory()
     tok.init_new_db(test_config)
 
-    assert not temp_db_conn.index_exists('idx_word_word_id')
+    assert not temp_db_cursor.index_exists('word', 'idx_word_word_id')
 
     tok.finalize_import(test_config)
 
 
     tok.finalize_import(test_config)
 
-    assert temp_db_conn.index_exists('idx_word_word_id')
+    assert temp_db_cursor.index_exists('word', 'idx_word_word_id')
 
 
 def test_check_database(test_config, tokenizer_factory,
 
 
 def test_check_database(test_config, tokenizer_factory,
index 548ec800b6b8f966427911dd2d91a90d19c8cb64..9d56efa0fd744a5a736649d0f2970b5a95ea794b 100644 (file)
@@ -132,7 +132,7 @@ def test_import_osm_data_simple_ignore_no_data(table_factory, osm2pgsql_options)
                                     ignore_errors=True)
 
 
                                     ignore_errors=True)
 
 
-def test_import_osm_data_drop(table_factory, temp_db_conn, tmp_path, osm2pgsql_options):
+def test_import_osm_data_drop(table_factory, temp_db_cursor, tmp_path, osm2pgsql_options):
     table_factory('place', content=((1, ), ))
     table_factory('planet_osm_nodes')
 
     table_factory('place', content=((1, ), ))
     table_factory('planet_osm_nodes')
 
@@ -144,7 +144,7 @@ def test_import_osm_data_drop(table_factory, temp_db_conn, tmp_path, osm2pgsql_o
     database_import.import_osm_data(Path('file.pbf'), osm2pgsql_options, drop=True)
 
     assert not flatfile.exists()
     database_import.import_osm_data(Path('file.pbf'), osm2pgsql_options, drop=True)
 
     assert not flatfile.exists()
-    assert not temp_db_conn.table_exists('planet_osm_nodes')
+    assert not temp_db_cursor.table_exists('planet_osm_nodes')
 
 
 def test_import_osm_data_default_cache(table_factory, osm2pgsql_options, capfd):
 
 
 def test_import_osm_data_default_cache(table_factory, osm2pgsql_options, capfd):
index 64eb7b1856b0996c3958ae71e51b3ecd3fbbdbb4..0d33e6e0f30e4a0ab47bfb02a4cd63f2ad4c55a9 100644 (file)
@@ -75,7 +75,8 @@ def test_load_white_and_black_lists(sp_importer):
     assert isinstance(black_list, dict) and isinstance(white_list, dict)
 
 
     assert isinstance(black_list, dict) and isinstance(white_list, dict)
 
 
-def test_create_place_classtype_indexes(temp_db_with_extensions, temp_db_conn,
+def test_create_place_classtype_indexes(temp_db_with_extensions,
+                                        temp_db_conn, temp_db_cursor,
                                         table_factory, sp_importer):
     """
         Test that _create_place_classtype_indexes() create the
                                         table_factory, sp_importer):
     """
         Test that _create_place_classtype_indexes() create the
@@ -88,10 +89,11 @@ def test_create_place_classtype_indexes(temp_db_with_extensions, temp_db_conn,
     table_factory(table_name, 'place_id BIGINT, centroid GEOMETRY')
 
     sp_importer._create_place_classtype_indexes('', phrase_class, phrase_type)
     table_factory(table_name, 'place_id BIGINT, centroid GEOMETRY')
 
     sp_importer._create_place_classtype_indexes('', phrase_class, phrase_type)
+    temp_db_conn.commit()
 
 
-    assert check_placeid_and_centroid_indexes(temp_db_conn, phrase_class, phrase_type)
+    assert check_placeid_and_centroid_indexes(temp_db_cursor, phrase_class, phrase_type)
 
 
-def test_create_place_classtype_table(temp_db_conn, placex_table, sp_importer):
+def test_create_place_classtype_table(temp_db_conn, temp_db_cursor, placex_table, sp_importer):
     """
         Test that _create_place_classtype_table() create
         the right place_classtype table.
     """
         Test that _create_place_classtype_table() create
         the right place_classtype table.
@@ -99,10 +101,12 @@ def test_create_place_classtype_table(temp_db_conn, placex_table, sp_importer):
     phrase_class = 'class'
     phrase_type = 'type'
     sp_importer._create_place_classtype_table('', phrase_class, phrase_type)
     phrase_class = 'class'
     phrase_type = 'type'
     sp_importer._create_place_classtype_table('', phrase_class, phrase_type)
+    temp_db_conn.commit()
 
 
-    assert check_table_exist(temp_db_conn, phrase_class, phrase_type)
+    assert check_table_exist(temp_db_cursor, phrase_class, phrase_type)
 
 
-def test_grant_access_to_web_user(temp_db_conn, table_factory, def_config, sp_importer):
+def test_grant_access_to_web_user(temp_db_conn, temp_db_cursor, table_factory,
+                                  def_config, sp_importer):
     """
         Test that _grant_access_to_webuser() give
         right access to the web user.
     """
         Test that _grant_access_to_webuser() give
         right access to the web user.
@@ -114,12 +118,13 @@ def test_grant_access_to_web_user(temp_db_conn, table_factory, def_config, sp_im
     table_factory(table_name)
 
     sp_importer._grant_access_to_webuser(phrase_class, phrase_type)
     table_factory(table_name)
 
     sp_importer._grant_access_to_webuser(phrase_class, phrase_type)
+    temp_db_conn.commit()
 
 
-    assert check_grant_access(temp_db_conn, def_config.DATABASE_WEBUSER, phrase_class, phrase_type)
+    assert check_grant_access(temp_db_cursor, def_config.DATABASE_WEBUSER, phrase_class, phrase_type)
 
 def test_create_place_classtype_table_and_indexes(
 
 def test_create_place_classtype_table_and_indexes(
-        temp_db_conn, def_config, placex_table,
-        sp_importer):
+        temp_db_cursor, def_config, placex_table,
+        sp_importer, temp_db_conn):
     """
         Test that _create_place_classtype_table_and_indexes()
         create the right place_classtype tables and place_id indexes
     """
         Test that _create_place_classtype_table_and_indexes()
         create the right place_classtype tables and place_id indexes
@@ -129,14 +134,15 @@ def test_create_place_classtype_table_and_indexes(
     pairs = set([('class1', 'type1'), ('class2', 'type2')])
 
     sp_importer._create_classtype_table_and_indexes(pairs)
     pairs = set([('class1', 'type1'), ('class2', 'type2')])
 
     sp_importer._create_classtype_table_and_indexes(pairs)
+    temp_db_conn.commit()
 
     for pair in pairs:
 
     for pair in pairs:
-        assert check_table_exist(temp_db_conn, pair[0], pair[1])
-        assert check_placeid_and_centroid_indexes(temp_db_conn, pair[0], pair[1])
-        assert check_grant_access(temp_db_conn, def_config.DATABASE_WEBUSER, pair[0], pair[1])
+        assert check_table_exist(temp_db_cursor, pair[0], pair[1])
+        assert check_placeid_and_centroid_indexes(temp_db_cursor, pair[0], pair[1])
+        assert check_grant_access(temp_db_cursor, def_config.DATABASE_WEBUSER, pair[0], pair[1])
 
 def test_remove_non_existent_tables_from_db(sp_importer, default_phrases,
 
 def test_remove_non_existent_tables_from_db(sp_importer, default_phrases,
-                                            temp_db_conn):
+                                            temp_db_conn, temp_db_cursor):
     """
         Check for the remove_non_existent_phrases_from_db() method.
 
     """
         Check for the remove_non_existent_phrases_from_db() method.
 
@@ -159,15 +165,14 @@ def test_remove_non_existent_tables_from_db(sp_importer, default_phrases,
     """
 
     sp_importer._remove_non_existent_tables_from_db()
     """
 
     sp_importer._remove_non_existent_tables_from_db()
+    temp_db_conn.commit()
 
 
-    # Changes are not committed yet. Use temp_db_conn for checking results.
-    with temp_db_conn.cursor(cursor_factory=CursorForTesting) as cur:
-        assert cur.row_set(query_tables) \
+    assert temp_db_cursor.row_set(query_tables) \
                  == {('place_classtype_testclasstypetable_to_keep', )}
 
 
 @pytest.mark.parametrize("should_replace", [(True), (False)])
                  == {('place_classtype_testclasstypetable_to_keep', )}
 
 
 @pytest.mark.parametrize("should_replace", [(True), (False)])
-def test_import_phrases(monkeypatch, temp_db_conn, def_config, sp_importer,
+def test_import_phrases(monkeypatch, temp_db_cursor, def_config, sp_importer,
                         placex_table, table_factory, tokenizer_mock,
                         xml_wiki_content, should_replace):
     """
                         placex_table, table_factory, tokenizer_mock,
                         xml_wiki_content, should_replace):
     """
@@ -193,49 +198,49 @@ def test_import_phrases(monkeypatch, temp_db_conn, def_config, sp_importer,
     class_test = 'aerialway'
     type_test = 'zip_line'
 
     class_test = 'aerialway'
     type_test = 'zip_line'
 
-    assert check_table_exist(temp_db_conn, class_test, type_test)
-    assert check_placeid_and_centroid_indexes(temp_db_conn, class_test, type_test)
-    assert check_grant_access(temp_db_conn, def_config.DATABASE_WEBUSER, class_test, type_test)
-    assert check_table_exist(temp_db_conn, 'amenity', 'animal_shelter')
+    assert check_table_exist(temp_db_cursor, class_test, type_test)
+    assert check_placeid_and_centroid_indexes(temp_db_cursor, class_test, type_test)
+    assert check_grant_access(temp_db_cursor, def_config.DATABASE_WEBUSER, class_test, type_test)
+    assert check_table_exist(temp_db_cursor, 'amenity', 'animal_shelter')
     if should_replace:
     if should_replace:
-        assert not check_table_exist(temp_db_conn, 'wrong_class', 'wrong_type')
+        assert not check_table_exist(temp_db_cursor, 'wrong_class', 'wrong_type')
 
 
-    assert temp_db_conn.table_exists('place_classtype_amenity_animal_shelter')
+    assert temp_db_cursor.table_exists('place_classtype_amenity_animal_shelter')
     if should_replace:
     if should_replace:
-        assert not temp_db_conn.table_exists('place_classtype_wrongclass_wrongtype')
+        assert not temp_db_cursor.table_exists('place_classtype_wrongclass_wrongtype')
 
 
-def check_table_exist(temp_db_conn, phrase_class, phrase_type):
+def check_table_exist(temp_db_cursor, phrase_class, phrase_type):
     """
         Verify that the place_classtype table exists for the given
         phrase_class and phrase_type.
     """
     """
         Verify that the place_classtype table exists for the given
         phrase_class and phrase_type.
     """
-    return temp_db_conn.table_exists('place_classtype_{}_{}'.format(phrase_class, phrase_type))
+    return temp_db_cursor.table_exists('place_classtype_{}_{}'.format(phrase_class, phrase_type))
 
 
 
 
-def check_grant_access(temp_db_conn, user, phrase_class, phrase_type):
+def check_grant_access(temp_db_cursor, user, phrase_class, phrase_type):
     """
         Check that the web user has been granted right access to the
         place_classtype table of the given phrase_class and phrase_type.
     """
     table_name = 'place_classtype_{}_{}'.format(phrase_class, phrase_type)
 
     """
         Check that the web user has been granted right access to the
         place_classtype table of the given phrase_class and phrase_type.
     """
     table_name = 'place_classtype_{}_{}'.format(phrase_class, phrase_type)
 
-    with temp_db_conn.cursor() as temp_db_cursor:
-        temp_db_cursor.execute("""
-                SELECT * FROM information_schema.role_table_grants
-                WHERE table_name='{}'
-                AND grantee='{}'
-                AND privilege_type='SELECT'""".format(table_name, user))
-        return temp_db_cursor.fetchone()
+    temp_db_cursor.execute("""
+            SELECT * FROM information_schema.role_table_grants
+            WHERE table_name='{}'
+            AND grantee='{}'
+            AND privilege_type='SELECT'""".format(table_name, user))
+    return temp_db_cursor.fetchone()
 
 
-def check_placeid_and_centroid_indexes(temp_db_conn, phrase_class, phrase_type):
+def check_placeid_and_centroid_indexes(temp_db_cursor, phrase_class, phrase_type):
     """
         Check that the place_id index and centroid index exist for the
         place_classtype table of the given phrase_class and phrase_type.
     """
     """
         Check that the place_id index and centroid index exist for the
         place_classtype table of the given phrase_class and phrase_type.
     """
+    table_name = 'place_classtype_{}_{}'.format(phrase_class, phrase_type)
     index_prefix = 'idx_place_classtype_{}_{}_'.format(phrase_class, phrase_type)
 
     return (
     index_prefix = 'idx_place_classtype_{}_{}_'.format(phrase_class, phrase_type)
 
     return (
-        temp_db_conn.index_exists(index_prefix + 'centroid')
+        temp_db_cursor.index_exists(table_name, index_prefix + 'centroid')
         and
         and
-        temp_db_conn.index_exists(index_prefix + 'place_id')
+        temp_db_cursor.index_exists(table_name, index_prefix + 'place_id')
     )
     )
index 3a849adbf956ed8469730c9cb6eb2d2d377eff0c..8821f694c927d5becec13db65972782f80fee4e8 100644 (file)
@@ -12,6 +12,7 @@ import psycopg2.extras
 
 from nominatim_db.tools import migration
 from nominatim_db.errors import UsageError
 
 from nominatim_db.tools import migration
 from nominatim_db.errors import UsageError
+from nominatim_db.db.connection import server_version_tuple
 import nominatim_db.version
 
 from mock_legacy_word_table import MockLegacyWordTable
 import nominatim_db.version
 
 from mock_legacy_word_table import MockLegacyWordTable
@@ -63,7 +64,7 @@ def test_set_up_migration_for_36(temp_db_with_extensions, temp_db_cursor,
                                           WHERE property = 'database_version'""")
 
 
                                           WHERE property = 'database_version'""")
 
 
-def test_already_at_version(def_config, property_table):
+def test_already_at_version(temp_db_with_extensions, def_config, property_table):
 
     property_table.set('database_version',
                        str(nominatim_db.version.NOMINATIM_VERSION))
 
     property_table.set('database_version',
                        str(nominatim_db.version.NOMINATIM_VERSION))
@@ -71,8 +72,8 @@ def test_already_at_version(def_config, property_table):
     assert migration.migrate(def_config, {}) == 0
 
 
     assert migration.migrate(def_config, {}) == 0
 
 
-def test_run_single_migration(def_config, temp_db_cursor, property_table,
-                              monkeypatch, postprocess_mock):
+def test_run_single_migration(temp_db_with_extensions, def_config, temp_db_cursor,
+                              property_table, monkeypatch, postprocess_mock):
     oldversion = [x for x in nominatim_db.version.NOMINATIM_VERSION]
     oldversion[0] -= 1
     property_table.set('database_version',
     oldversion = [x for x in nominatim_db.version.NOMINATIM_VERSION]
     oldversion[0] -= 1
     property_table.set('database_version',
@@ -226,7 +227,7 @@ def test_create_tiger_housenumber_index(temp_db_conn, temp_db_cursor, table_fact
     migration.create_tiger_housenumber_index(temp_db_conn)
     temp_db_conn.commit()
 
     migration.create_tiger_housenumber_index(temp_db_conn)
     temp_db_conn.commit()
 
-    if temp_db_conn.server_version_tuple() >= (11, 0, 0):
+    if server_version_tuple(temp_db_conn) >= (11, 0, 0):
         assert temp_db_cursor.index_exists('location_property_tiger',
                                            'idx_location_property_tiger_housenumber_migrated')
 
         assert temp_db_cursor.index_exists('location_property_tiger',
                                            'idx_location_property_tiger_housenumber_migrated')
 
index 50ff6398b5673cdc073b4d4c62af6831729f4def..1f1968cfac7fe0c1a0e214037aa097728cfe0efe 100644 (file)
@@ -12,6 +12,7 @@ from pathlib import Path
 import pytest
 
 from nominatim_db.tools import refresh
 import pytest
 
 from nominatim_db.tools import refresh
+from nominatim_db.db.connection import postgis_version_tuple
 
 def test_refresh_import_wikipedia_not_existing(dsn):
     assert refresh.import_wikipedia_articles(dsn, Path('.')) == 1
 
 def test_refresh_import_wikipedia_not_existing(dsn):
     assert refresh.import_wikipedia_articles(dsn, Path('.')) == 1
@@ -23,13 +24,13 @@ def test_refresh_import_secondary_importance_non_existing(dsn):
 def test_refresh_import_secondary_importance_testdb(dsn, src_dir, temp_db_conn, temp_db_cursor):
     temp_db_cursor.execute('CREATE EXTENSION postgis')
 
 def test_refresh_import_secondary_importance_testdb(dsn, src_dir, temp_db_conn, temp_db_cursor):
     temp_db_cursor.execute('CREATE EXTENSION postgis')
 
-    if temp_db_conn.postgis_version_tuple()[0] < 3:
+    if postgis_version_tuple(temp_db_conn)[0] < 3:
         assert refresh.import_secondary_importance(dsn, src_dir / 'test' / 'testdb') > 0
     else:
         temp_db_cursor.execute('CREATE EXTENSION postgis_raster')
         assert refresh.import_secondary_importance(dsn, src_dir / 'test' / 'testdb') == 0
 
         assert refresh.import_secondary_importance(dsn, src_dir / 'test' / 'testdb') > 0
     else:
         temp_db_cursor.execute('CREATE EXTENSION postgis_raster')
         assert refresh.import_secondary_importance(dsn, src_dir / 'test' / 'testdb') == 0
 
-        assert temp_db_conn.table_exists('secondary_importance')
+        assert temp_db_cursor.table_exists('secondary_importance')
 
 
 @pytest.mark.parametrize("replace", (True, False))
 
 
 @pytest.mark.parametrize("replace", (True, False))
index 7ef6a1e63306020f41d1dcb91863a4a01be6c51e..fc01f22ffb905bd1ff15435633ec6f730225fefc 100644 (file)
@@ -12,6 +12,7 @@ from textwrap import dedent
 
 import pytest
 
 
 import pytest
 
+from nominatim_db.db.connection import execute_scalar
 from nominatim_db.tools import tiger_data, freeze
 from nominatim_db.errors import UsageError
 
 from nominatim_db.tools import tiger_data, freeze
 from nominatim_db.errors import UsageError
 
@@ -31,8 +32,7 @@ class MockTigerTable:
             cur.execute("CREATE TABLE place (number INTEGER)")
 
     def count(self):
             cur.execute("CREATE TABLE place (number INTEGER)")
 
     def count(self):
-        with self.conn.cursor() as cur:
-            return cur.scalar("SELECT count(*) FROM tiger")
+        return execute_scalar(self.conn, "SELECT count(*) FROM tiger")
 
     def row(self):
         with self.conn.cursor() as cur:
 
     def row(self):
         with self.conn.cursor() as cur: