]> git.openstreetmap.org Git - nominatim.git/blobdiff - nominatim/db/connection.py
Added missing return types to functions
[nominatim.git] / nominatim / db / connection.py
index b941f46f56c63c74444506dec457bf09a8999c07..82801ae7995c9d1e5527baec0d9dd89c85e70e4d 100644 (file)
@@ -1,28 +1,52 @@
+# SPDX-License-Identifier: GPL-2.0-only
+#
+# This file is part of Nominatim. (https://nominatim.org)
+#
+# Copyright (C) 2022 by the Nominatim developer community.
+# For a full list of authors see the git log.
 """
 Specialised connection and cursor functions.
 """
 """
 Specialised connection and cursor functions.
 """
+from typing import Optional, Any, Callable, ContextManager, Dict, cast, overload, Tuple, Iterable
+import contextlib
 import logging
 import logging
+import os
 
 import psycopg2
 import psycopg2.extensions
 import psycopg2.extras
 
 import psycopg2
 import psycopg2.extensions
 import psycopg2.extras
+from psycopg2 import sql as pysql
 
 
-from ..errors import UsageError
+from nominatim.typing import SysEnv, Query, T_cursor
+from nominatim.errors import UsageError
 
 
-class _Cursor(psycopg2.extras.DictCursor):
+LOG = logging.getLogger()
+
+class Cursor(psycopg2.extras.DictCursor):
     """ A cursor returning dict-like objects and providing specialised
         execution functions.
     """
     """ A cursor returning dict-like objects and providing specialised
         execution functions.
     """
-
-    def execute(self, query, args=None): # pylint: disable=W0221
+    # pylint: disable=arguments-renamed,arguments-differ
+    def execute(self, query: Query, args: Any = None) -> None:
         """ Query execution that logs the SQL query when debugging is enabled.
         """
         """ Query execution that logs the SQL query when debugging is enabled.
         """
-        logger = logging.getLogger()
-        logger.debug(self.mogrify(query, args).decode('utf-8'))
+        if LOG.isEnabledFor(logging.DEBUG):
+            LOG.debug(self.mogrify(query, args).decode('utf-8'))
 
         super().execute(query, args)
 
 
         super().execute(query, args)
 
-    def scalar(self, sql, args=None):
+
+    def execute_values(self, sql: Query, argslist: Iterable[Tuple[Any, ...]],
+                       template: Optional[Query] = None) -> None:
+        """ Wrapper for the psycopg2 convenience function to execute
+            SQL for a list of values.
+        """
+        LOG.debug("SQL execute_values(%s, %s)", sql, argslist)
+
+        psycopg2.extras.execute_values(self, sql, argslist, template=template)
+
+
+    def scalar(self, sql: Query, args: Any = None) -> Any:
         """ Execute query that returns a single value. The value is returned.
             If the query yields more than one row, a ValueError is raised.
         """
         """ Execute query that returns a single value. The value is returned.
             If the query yields more than one row, a ValueError is raised.
         """
@@ -31,30 +55,71 @@ class _Cursor(psycopg2.extras.DictCursor):
         if self.rowcount != 1:
             raise RuntimeError("Query did not return a single row.")
 
         if self.rowcount != 1:
             raise RuntimeError("Query did not return a single row.")
 
-        return self.fetchone()[0]
+        result = self.fetchone()
+        assert result is not None
+
+        return result[0]
 
 
 
 
-class _Connection(psycopg2.extensions.connection):
+    def drop_table(self, name: str, if_exists: bool = True, cascade: bool = False) -> None:
+        """ Drop the table with the given name.
+            Set `if_exists` to False if a non-existent table should raise
+            an exception instead of just being ignored. If 'cascade' is set
+            to True then all dependent tables are deleted as well.
+        """
+        sql = 'DROP TABLE '
+        if if_exists:
+            sql += 'IF EXISTS '
+        sql += '{}'
+        if cascade:
+            sql += ' CASCADE'
+
+        self.execute(pysql.SQL(sql).format(pysql.Identifier(name)))
+
+
+class Connection(psycopg2.extensions.connection):
     """ A connection that provides the specialised cursor by default and
         adds convenience functions for administrating the database.
     """
     """ A connection that provides the specialised cursor by default and
         adds convenience functions for administrating the database.
     """
+    @overload # type: ignore[override]
+    def cursor(self) -> Cursor:
+        ...
+
+    @overload
+    def cursor(self, name: str) -> Cursor:
+        ...
 
 
-    def cursor(self, cursor_factory=_Cursor, **kwargs):
+    @overload
+    def cursor(self, cursor_factory: Callable[..., T_cursor]) -> T_cursor:
+        ...
+
+    def cursor(self, cursor_factory  = Cursor, **kwargs): # type: ignore
         """ Return a new cursor. By default the specialised cursor is returned.
         """
         return super().cursor(cursor_factory=cursor_factory, **kwargs)
 
 
         """ Return a new cursor. By default the specialised cursor is returned.
         """
         return super().cursor(cursor_factory=cursor_factory, **kwargs)
 
 
-    def table_exists(self, table):
+    def table_exists(self, table: str) -> bool:
         """ Check that a table with the given name exists in the database.
         """
         with self.cursor() as cur:
             num = cur.scalar("""SELECT count(*) FROM pg_tables
                                 WHERE tablename = %s and schemaname = 'public'""", (table, ))
         """ Check that a table with the given name exists in the database.
         """
         with self.cursor() as cur:
             num = cur.scalar("""SELECT count(*) FROM pg_tables
                                 WHERE tablename = %s and schemaname = 'public'""", (table, ))
-            return num == 1
+            return num == 1 if isinstance(num, int) else False
 
 
 
 
-    def index_exists(self, index, table=None):
+    def table_has_column(self, table: str, column: str) -> bool:
+        """ Check if the table 'table' exists and has a column with name 'column'.
+        """
+        with self.cursor() as cur:
+            has_column = cur.scalar("""SELECT count(*) FROM information_schema.columns
+                                       WHERE table_name = %s
+                                             and column_name = %s""",
+                                    (table, column))
+            return has_column > 0 if isinstance(has_column, int) else False
+
+
+    def index_exists(self, index: str, table: Optional[str] = None) -> bool:
         """ Check that an index with the given name exists in the database.
             If table is not None then the index must relate to the given
             table.
         """ Check that an index with the given name exists in the database.
             If table is not None then the index must relate to the given
             table.
@@ -67,26 +132,123 @@ class _Connection(psycopg2.extensions.connection):
 
             if table is not None:
                 row = cur.fetchone()
 
             if table is not None:
                 row = cur.fetchone()
+                if row is None or not isinstance(row[0], str):
+                    return False
                 return row[0] == table
 
         return True
 
 
                 return row[0] == table
 
         return True
 
 
-    def server_version_tuple(self):
+    def drop_table(self, name: str, if_exists: bool = True, cascade: bool = False) -> None:
+        """ Drop the table with the given name.
+            Set `if_exists` to False if a non-existent table should raise
+            an exception instead of just being ignored.
+        """
+        with self.cursor() as cur:
+            cur.drop_table(name, if_exists, cascade)
+        self.commit()
+
+
+    def server_version_tuple(self) -> Tuple[int, int]:
         """ Return the server version as a tuple of (major, minor).
             Converts correctly for pre-10 and post-10 PostgreSQL versions.
         """
         version = self.server_version
         if version < 100000:
         """ Return the server version as a tuple of (major, minor).
             Converts correctly for pre-10 and post-10 PostgreSQL versions.
         """
         version = self.server_version
         if version < 100000:
-            return (version / 10000, (version % 10000) / 100)
+            return (int(version / 10000), int((version % 10000) / 100))
+
+        return (int(version / 10000), version % 10000)
+
+
+    def postgis_version_tuple(self) -> Tuple[int, int]:
+        """ Return the postgis version installed in the database as a
+            tuple of (major, minor). Assumes that the PostGIS extension
+            has been installed already.
+        """
+        with self.cursor() as cur:
+            version = cur.scalar('SELECT postgis_lib_version()')
+
+        version_parts = version.split('.')
+        if len(version_parts) < 2:
+            raise UsageError(f"Error fetching Postgis version. Bad format: {version}")
+
+        return (int(version_parts[0]), int(version_parts[1]))
 
 
-        return (version / 10000, version % 10000)
 
 
-def connect(dsn):
+    def extension_loaded(self, extension_name: str) -> bool:
+        """ Return True if the hstore extension is loaded in the database.
+        """
+        with self.cursor() as cur:
+            cur.execute('SELECT extname FROM pg_extension WHERE extname = %s', (extension_name, ))
+            return cur.rowcount > 0
+
+
+class ConnectionContext(ContextManager[Connection]):
+    """ Context manager of the connection that also provides direct access
+        to the underlying connection.
+    """
+    connection: Connection
+
+def connect(dsn: str) -> ConnectionContext:
     """ Open a connection to the database using the specialised connection
     """ Open a connection to the database using the specialised connection
-        factory.
+        factory. The returned object may be used in conjunction with 'with'.
+        When used outside a context manager, use the `connection` attribute
+        to get the connection.
     """
     try:
     """
     try:
-        return psycopg2.connect(dsn, connection_factory=_Connection)
+        conn = psycopg2.connect(dsn, connection_factory=Connection)
+        ctxmgr = cast(ConnectionContext, contextlib.closing(conn))
+        ctxmgr.connection = conn
+        return ctxmgr
     except psycopg2.OperationalError as err:
     except psycopg2.OperationalError as err:
-        raise UsageError("Cannot connect to database: {}".format(err)) from err
+        raise UsageError(f"Cannot connect to database: {err}") from err
+
+
+# Translation from PG connection string parameters to PG environment variables.
+# Derived from https://www.postgresql.org/docs/current/libpq-envars.html.
+_PG_CONNECTION_STRINGS = {
+    'host': 'PGHOST',
+    'hostaddr': 'PGHOSTADDR',
+    'port': 'PGPORT',
+    'dbname': 'PGDATABASE',
+    'user': 'PGUSER',
+    'password': 'PGPASSWORD',
+    'passfile': 'PGPASSFILE',
+    'channel_binding': 'PGCHANNELBINDING',
+    'service': 'PGSERVICE',
+    'options': 'PGOPTIONS',
+    'application_name': 'PGAPPNAME',
+    'sslmode': 'PGSSLMODE',
+    'requiressl': 'PGREQUIRESSL',
+    'sslcompression': 'PGSSLCOMPRESSION',
+    'sslcert': 'PGSSLCERT',
+    'sslkey': 'PGSSLKEY',
+    'sslrootcert': 'PGSSLROOTCERT',
+    'sslcrl': 'PGSSLCRL',
+    'requirepeer': 'PGREQUIREPEER',
+    'ssl_min_protocol_version': 'PGSSLMINPROTOCOLVERSION',
+    'ssl_max_protocol_version': 'PGSSLMAXPROTOCOLVERSION',
+    'gssencmode': 'PGGSSENCMODE',
+    'krbsrvname': 'PGKRBSRVNAME',
+    'gsslib': 'PGGSSLIB',
+    'connect_timeout': 'PGCONNECT_TIMEOUT',
+    'target_session_attrs': 'PGTARGETSESSIONATTRS',
+}
+
+
+def get_pg_env(dsn: str,
+               base_env: Optional[SysEnv] = None) -> Dict[str, str]:
+    """ Return a copy of `base_env` with the environment variables for
+        PostgresSQL set up from the given database connection string.
+        If `base_env` is None, then the OS environment is used as a base
+        environment.
+    """
+    env = dict(base_env if base_env is not None else os.environ)
+
+    for param, value in psycopg2.extensions.parse_dsn(dsn).items():
+        if param in _PG_CONNECTION_STRINGS:
+            env[_PG_CONNECTION_STRINGS[param]] = value
+        else:
+            LOG.error("Unknown connection parameter '%s' ignored.", param)
+
+    return env