]> git.openstreetmap.org Git - nominatim.git/blobdiff - nominatim/db/connection.py
avoid issues with Python < 3.9 and linting
[nominatim.git] / nominatim / db / connection.py
index 1c1152079af9e8437c4f685f1d52bd4afcd40bd3..729e8a700761843ea33e36f7630482cc063faa72 100644 (file)
@@ -7,6 +7,7 @@
 """
 Specialised connection and cursor functions.
 """
+from typing import List, Optional, Any, Callable, ContextManager, Mapping, cast, overload, Tuple
 import contextlib
 import logging
 import os
@@ -16,6 +17,7 @@ import psycopg2.extensions
 import psycopg2.extras
 from psycopg2 import sql as pysql
 
+from nominatim.typing import Query, T_cursor
 from nominatim.errors import UsageError
 
 LOG = logging.getLogger()
@@ -24,16 +26,18 @@ class _Cursor(psycopg2.extras.DictCursor):
     """ A cursor returning dict-like objects and providing specialised
         execution functions.
     """
-
-    def execute(self, query, args=None): # pylint: disable=W0221
+    # pylint: disable=arguments-renamed,arguments-differ
+    def execute(self, query: Query, args: Any = None) -> None:
         """ Query execution that logs the SQL query when debugging is enabled.
         """
-        LOG.debug(self.mogrify(query, args).decode('utf-8'))
+        if LOG.isEnabledFor(logging.DEBUG):
+            LOG.debug(self.mogrify(query, args).decode('utf-8')) # type: ignore
 
         super().execute(query, args)
 
 
-    def execute_values(self, sql, argslist, template=None):
+    def execute_values(self, sql: Query, argslist: List[Any],
+                       template: Optional[str] = None) -> None:
         """ Wrapper for the psycopg2 convenience function to execute
             SQL for a list of values.
         """
@@ -42,7 +46,7 @@ class _Cursor(psycopg2.extras.DictCursor):
         psycopg2.extras.execute_values(self, sql, argslist, template=template)
 
 
-    def scalar(self, sql, args=None):
+    def scalar(self, sql: Query, args: Any = None) -> Any:
         """ Execute query that returns a single value. The value is returned.
             If the query yields more than one row, a ValueError is raised.
         """
@@ -51,10 +55,13 @@ class _Cursor(psycopg2.extras.DictCursor):
         if self.rowcount != 1:
             raise RuntimeError("Query did not return a single row.")
 
-        return self.fetchone()[0]
+        result = self.fetchone() # type: ignore
+        assert result is not None
+
+        return result[0]
 
 
-    def drop_table(self, name, if_exists=True, cascade=False):
+    def drop_table(self, name: str, if_exists: bool = True, cascade: bool = False) -> None:
         """ Drop the table with the given name.
             Set `if_exists` to False if a non-existant table should raise
             an exception instead of just being ignored. If 'cascade' is set
@@ -67,30 +74,52 @@ class _Cursor(psycopg2.extras.DictCursor):
         if cascade:
             sql += ' CASCADE'
 
-        self.execute(pysql.SQL(sql).format(pysql.Identifier(name)))
+        self.execute(pysql.SQL(sql).format(pysql.Identifier(name))) # type: ignore
 
 
 class _Connection(psycopg2.extensions.connection):
     """ A connection that provides the specialised cursor by default and
         adds convenience functions for administrating the database.
     """
+    @overload # type: ignore[override]
+    def cursor(self) -> _Cursor:
+        ...
+
+    @overload
+    def cursor(self, name: str) -> _Cursor:
+        ...
 
-    def cursor(self, cursor_factory=_Cursor, **kwargs):
+    @overload
+    def cursor(self, cursor_factory: Callable[..., T_cursor]) -> T_cursor:
+        ...
+
+    def cursor(self, cursor_factory  = _Cursor, **kwargs): # type: ignore
         """ Return a new cursor. By default the specialised cursor is returned.
         """
         return super().cursor(cursor_factory=cursor_factory, **kwargs)
 
 
-    def table_exists(self, table):
+    def table_exists(self, table: str) -> bool:
         """ Check that a table with the given name exists in the database.
         """
         with self.cursor() as cur:
             num = cur.scalar("""SELECT count(*) FROM pg_tables
                                 WHERE tablename = %s and schemaname = 'public'""", (table, ))
-            return num == 1
+            return num == 1 if isinstance(num, int) else False
 
 
-    def index_exists(self, index, table=None):
+    def table_has_column(self, table: str, column: str) -> bool:
+        """ Check if the table 'table' exists and has a column with name 'column'.
+        """
+        with self.cursor() as cur:
+            has_column = cur.scalar("""SELECT count(*) FROM information_schema.columns
+                                       WHERE table_name = %s
+                                             and column_name = %s""",
+                                    (table, column))
+            return has_column > 0 if isinstance(has_column, int) else False
+
+
+    def index_exists(self, index: str, table: Optional[str] = None) -> bool:
         """ Check that an index with the given name exists in the database.
             If table is not None then the index must relate to the given
             table.
@@ -102,13 +131,15 @@ class _Connection(psycopg2.extensions.connection):
                 return False
 
             if table is not None:
-                row = cur.fetchone()
+                row = cur.fetchone() # type: ignore
+                if row is None or not isinstance(row[0], str):
+                    return False
                 return row[0] == table
 
         return True
 
 
-    def drop_table(self, name, if_exists=True, cascade=False):
+    def drop_table(self, name: str, if_exists: bool = True, cascade: bool = False) -> None:
         """ Drop the table with the given name.
             Set `if_exists` to False if a non-existant table should raise
             an exception instead of just being ignored.
@@ -118,18 +149,18 @@ class _Connection(psycopg2.extensions.connection):
         self.commit()
 
 
-    def server_version_tuple(self):
+    def server_version_tuple(self) -> Tuple[int, int]:
         """ Return the server version as a tuple of (major, minor).
             Converts correctly for pre-10 and post-10 PostgreSQL versions.
         """
         version = self.server_version
         if version < 100000:
-            return (int(version / 10000), (version % 10000) / 100)
+            return (int(version / 10000), int((version % 10000) / 100))
 
         return (int(version / 10000), version % 10000)
 
 
-    def postgis_version_tuple(self):
+    def postgis_version_tuple(self) -> Tuple[int, int]:
         """ Return the postgis version installed in the database as a
             tuple of (major, minor). Assumes that the PostGIS extension
             has been installed already.
@@ -137,10 +168,16 @@ class _Connection(psycopg2.extensions.connection):
         with self.cursor() as cur:
             version = cur.scalar('SELECT postgis_lib_version()')
 
-        return tuple((int(x) for x in version.split('.')[:2]))
+        version_parts = version.split('.')
+        if len(version_parts) < 2:
+            raise UsageError(f"Error fetching Postgis version. Bad format: {version}")
+
+        return (int(version_parts[0]), int(version_parts[1]))
 
+class _ConnectionContext(ContextManager[_Connection]):
+    connection: _Connection
 
-def connect(dsn):
+def connect(dsn: str) -> _ConnectionContext:
     """ Open a connection to the database using the specialised connection
         factory. The returned object may be used in conjunction with 'with'.
         When used outside a context manager, use the `connection` attribute
@@ -148,11 +185,11 @@ def connect(dsn):
     """
     try:
         conn = psycopg2.connect(dsn, connection_factory=_Connection)
-        ctxmgr = contextlib.closing(conn)
-        ctxmgr.connection = conn
+        ctxmgr = cast(_ConnectionContext, contextlib.closing(conn))
+        ctxmgr.connection = cast(_Connection, conn)
         return ctxmgr
     except psycopg2.OperationalError as err:
-        raise UsageError("Cannot connect to database: {}".format(err)) from err
+        raise UsageError(f"Cannot connect to database: {err}") from err
 
 
 # Translation from PG connection string parameters to PG environment variables.
@@ -187,7 +224,8 @@ _PG_CONNECTION_STRINGS = {
 }
 
 
-def get_pg_env(dsn, base_env=None):
+def get_pg_env(dsn: str,
+               base_env: Optional[Mapping[str, str]] = None) -> Mapping[str, str]:
     """ Return a copy of `base_env` with the environment variables for
         PostgresSQL set up from the given database connection string.
         If `base_env` is None, then the OS environment is used as a base
@@ -195,7 +233,7 @@ def get_pg_env(dsn, base_env=None):
     """
     env = dict(base_env if base_env is not None else os.environ)
 
-    for param, value in psycopg2.extensions.parse_dsn(dsn).items():
+    for param, value in psycopg2.extensions.parse_dsn(dsn).items(): # type: ignore
         if param in _PG_CONNECTION_STRINGS:
             env[_PG_CONNECTION_STRINGS[param]] = value
         else: