]> git.openstreetmap.org Git - nominatim.git/blobdiff - nominatim/db/connection.py
type annotations for DB connection
[nominatim.git] / nominatim / db / connection.py
index c60bcfddb8af99faf3bc2ba4e7f7a48b7baf731f..cbdbbfef323ccf94abc9917ddbc6322739dceb81 100644 (file)
@@ -7,6 +7,7 @@
 """
 Specialised connection and cursor functions.
 """
 """
 Specialised connection and cursor functions.
 """
+from typing import Union, List, Optional, Any, Callable, ContextManager, Mapping, cast, TypeVar, overload, Tuple, Sequence
 import contextlib
 import logging
 import os
 import contextlib
 import logging
 import os
@@ -18,23 +19,26 @@ from psycopg2 import sql as pysql
 
 from nominatim.errors import UsageError
 
 
 from nominatim.errors import UsageError
 
+Query = Union[str, bytes, pysql.Composable]
+T = TypeVar('T', bound=psycopg2.extensions.cursor)
+
 LOG = logging.getLogger()
 
 class _Cursor(psycopg2.extras.DictCursor):
     """ A cursor returning dict-like objects and providing specialised
         execution functions.
     """
 LOG = logging.getLogger()
 
 class _Cursor(psycopg2.extras.DictCursor):
     """ A cursor returning dict-like objects and providing specialised
         execution functions.
     """
-
     # pylint: disable=arguments-renamed,arguments-differ
     # pylint: disable=arguments-renamed,arguments-differ
-    def execute(self, query, args=None):
+    def execute(self, query: Query, args: Any = None) -> None:
         """ Query execution that logs the SQL query when debugging is enabled.
         """
         """ Query execution that logs the SQL query when debugging is enabled.
         """
-        LOG.debug(self.mogrify(query, args).decode('utf-8'))
+        if LOG.isEnabledFor(logging.DEBUG):
+            LOG.debug(self.mogrify(query, args).decode('utf-8')) # type: ignore
 
         super().execute(query, args)
 
 
 
         super().execute(query, args)
 
 
-    def execute_values(self, sql, argslist, template=None):
+    def execute_values(self, sql: Query, argslist: List[Any], template: Optional[str] = None) -> None:
         """ Wrapper for the psycopg2 convenience function to execute
             SQL for a list of values.
         """
         """ Wrapper for the psycopg2 convenience function to execute
             SQL for a list of values.
         """
@@ -43,7 +47,7 @@ class _Cursor(psycopg2.extras.DictCursor):
         psycopg2.extras.execute_values(self, sql, argslist, template=template)
 
 
         psycopg2.extras.execute_values(self, sql, argslist, template=template)
 
 
-    def scalar(self, sql, args=None):
+    def scalar(self, sql: Query, args: Any = None) -> Any:
         """ Execute query that returns a single value. The value is returned.
             If the query yields more than one row, a ValueError is raised.
         """
         """ Execute query that returns a single value. The value is returned.
             If the query yields more than one row, a ValueError is raised.
         """
@@ -52,10 +56,13 @@ class _Cursor(psycopg2.extras.DictCursor):
         if self.rowcount != 1:
             raise RuntimeError("Query did not return a single row.")
 
         if self.rowcount != 1:
             raise RuntimeError("Query did not return a single row.")
 
-        return self.fetchone()[0]
+        result = self.fetchone() # type: ignore
+        assert result is not None
+
+        return result[0]
 
 
 
 
-    def drop_table(self, name, if_exists=True, cascade=False):
+    def drop_table(self, name: str, if_exists: bool = True, cascade: bool = False) -> None:
         """ Drop the table with the given name.
             Set `if_exists` to False if a non-existant table should raise
             an exception instead of just being ignored. If 'cascade' is set
         """ Drop the table with the given name.
             Set `if_exists` to False if a non-existant table should raise
             an exception instead of just being ignored. If 'cascade' is set
@@ -68,30 +75,41 @@ class _Cursor(psycopg2.extras.DictCursor):
         if cascade:
             sql += ' CASCADE'
 
         if cascade:
             sql += ' CASCADE'
 
-        self.execute(pysql.SQL(sql).format(pysql.Identifier(name)))
+        self.execute(pysql.SQL(sql).format(pysql.Identifier(name))) # type: ignore
 
 
 class _Connection(psycopg2.extensions.connection):
     """ A connection that provides the specialised cursor by default and
         adds convenience functions for administrating the database.
     """
 
 
 class _Connection(psycopg2.extensions.connection):
     """ A connection that provides the specialised cursor by default and
         adds convenience functions for administrating the database.
     """
+    @overload # type: ignore[override]
+    def cursor(self) -> _Cursor:
+        ...
 
 
-    def cursor(self, cursor_factory=_Cursor, **kwargs):
+    @overload
+    def cursor(self, name: str) -> _Cursor:
+        ...
+
+    @overload
+    def cursor(self, cursor_factory: Callable[..., T]) -> T:
+        ...
+
+    def cursor(self, cursor_factory  = _Cursor, **kwargs): # type: ignore
         """ Return a new cursor. By default the specialised cursor is returned.
         """
         return super().cursor(cursor_factory=cursor_factory, **kwargs)
 
 
         """ Return a new cursor. By default the specialised cursor is returned.
         """
         return super().cursor(cursor_factory=cursor_factory, **kwargs)
 
 
-    def table_exists(self, table):
+    def table_exists(self, table: str) -> bool:
         """ Check that a table with the given name exists in the database.
         """
         with self.cursor() as cur:
             num = cur.scalar("""SELECT count(*) FROM pg_tables
                                 WHERE tablename = %s and schemaname = 'public'""", (table, ))
         """ Check that a table with the given name exists in the database.
         """
         with self.cursor() as cur:
             num = cur.scalar("""SELECT count(*) FROM pg_tables
                                 WHERE tablename = %s and schemaname = 'public'""", (table, ))
-            return num == 1
+            return num == 1 if isinstance(num, int) else False
 
 
 
 
-    def table_has_column(self, table, column):
+    def table_has_column(self, table: str, column: str) -> bool:
         """ Check if the table 'table' exists and has a column with name 'column'.
         """
         with self.cursor() as cur:
         """ Check if the table 'table' exists and has a column with name 'column'.
         """
         with self.cursor() as cur:
@@ -99,10 +117,10 @@ class _Connection(psycopg2.extensions.connection):
                                        WHERE table_name = %s
                                              and column_name = %s""",
                                     (table, column))
                                        WHERE table_name = %s
                                              and column_name = %s""",
                                     (table, column))
-            return has_column > 0
+            return has_column > 0 if isinstance(has_column, int) else False
 
 
 
 
-    def index_exists(self, index, table=None):
+    def index_exists(self, index: str, table: Optional[str] = None) -> bool:
         """ Check that an index with the given name exists in the database.
             If table is not None then the index must relate to the given
             table.
         """ Check that an index with the given name exists in the database.
             If table is not None then the index must relate to the given
             table.
@@ -114,13 +132,15 @@ class _Connection(psycopg2.extensions.connection):
                 return False
 
             if table is not None:
                 return False
 
             if table is not None:
-                row = cur.fetchone()
+                row = cur.fetchone() # type: ignore
+                if row is None or not isinstance(row[0], str):
+                    return False
                 return row[0] == table
 
         return True
 
 
                 return row[0] == table
 
         return True
 
 
-    def drop_table(self, name, if_exists=True, cascade=False):
+    def drop_table(self, name: str, if_exists: bool = True, cascade: bool = False) -> None:
         """ Drop the table with the given name.
             Set `if_exists` to False if a non-existant table should raise
             an exception instead of just being ignored.
         """ Drop the table with the given name.
             Set `if_exists` to False if a non-existant table should raise
             an exception instead of just being ignored.
@@ -130,18 +150,18 @@ class _Connection(psycopg2.extensions.connection):
         self.commit()
 
 
         self.commit()
 
 
-    def server_version_tuple(self):
+    def server_version_tuple(self) -> Tuple[int, int]:
         """ Return the server version as a tuple of (major, minor).
             Converts correctly for pre-10 and post-10 PostgreSQL versions.
         """
         version = self.server_version
         if version < 100000:
         """ Return the server version as a tuple of (major, minor).
             Converts correctly for pre-10 and post-10 PostgreSQL versions.
         """
         version = self.server_version
         if version < 100000:
-            return (int(version / 10000), (version % 10000) / 100)
+            return (int(version / 10000), int((version % 10000) / 100))
 
         return (int(version / 10000), version % 10000)
 
 
 
         return (int(version / 10000), version % 10000)
 
 
-    def postgis_version_tuple(self):
+    def postgis_version_tuple(self) -> Tuple[int, int]:
         """ Return the postgis version installed in the database as a
             tuple of (major, minor). Assumes that the PostGIS extension
             has been installed already.
         """ Return the postgis version installed in the database as a
             tuple of (major, minor). Assumes that the PostGIS extension
             has been installed already.
@@ -149,10 +169,16 @@ class _Connection(psycopg2.extensions.connection):
         with self.cursor() as cur:
             version = cur.scalar('SELECT postgis_lib_version()')
 
         with self.cursor() as cur:
             version = cur.scalar('SELECT postgis_lib_version()')
 
-        return tuple((int(x) for x in version.split('.')[:2]))
+        version_parts = version.split('.')
+        if len(version_parts) < 2:
+            raise UsageError(f"Error fetching Postgis version. Bad format: {version}")
+
+        return (int(version_parts[0]), int(version_parts[1]))
 
 
+class _ConnectionContext(ContextManager[_Connection]):
+    connection: _Connection
 
 
-def connect(dsn):
+def connect(dsn: str) -> _ConnectionContext:
     """ Open a connection to the database using the specialised connection
         factory. The returned object may be used in conjunction with 'with'.
         When used outside a context manager, use the `connection` attribute
     """ Open a connection to the database using the specialised connection
         factory. The returned object may be used in conjunction with 'with'.
         When used outside a context manager, use the `connection` attribute
@@ -160,8 +186,8 @@ def connect(dsn):
     """
     try:
         conn = psycopg2.connect(dsn, connection_factory=_Connection)
     """
     try:
         conn = psycopg2.connect(dsn, connection_factory=_Connection)
-        ctxmgr = contextlib.closing(conn)
-        ctxmgr.connection = conn
+        ctxmgr = cast(_ConnectionContext, contextlib.closing(conn))
+        ctxmgr.connection = cast(_Connection, conn)
         return ctxmgr
     except psycopg2.OperationalError as err:
         raise UsageError(f"Cannot connect to database: {err}") from err
         return ctxmgr
     except psycopg2.OperationalError as err:
         raise UsageError(f"Cannot connect to database: {err}") from err
@@ -199,7 +225,8 @@ _PG_CONNECTION_STRINGS = {
 }
 
 
 }
 
 
-def get_pg_env(dsn, base_env=None):
+def get_pg_env(dsn: str,
+               base_env: Optional[Mapping[str, Optional[str]]] = None) -> Mapping[str, Optional[str]]:
     """ Return a copy of `base_env` with the environment variables for
         PostgresSQL set up from the given database connection string.
         If `base_env` is None, then the OS environment is used as a base
     """ Return a copy of `base_env` with the environment variables for
         PostgresSQL set up from the given database connection string.
         If `base_env` is None, then the OS environment is used as a base
@@ -207,7 +234,7 @@ def get_pg_env(dsn, base_env=None):
     """
     env = dict(base_env if base_env is not None else os.environ)
 
     """
     env = dict(base_env if base_env is not None else os.environ)
 
-    for param, value in psycopg2.extensions.parse_dsn(dsn).items():
+    for param, value in psycopg2.extensions.parse_dsn(dsn).items(): # type: ignore
         if param in _PG_CONNECTION_STRINGS:
             env[_PG_CONNECTION_STRINGS[param]] = value
         else:
         if param in _PG_CONNECTION_STRINGS:
             env[_PG_CONNECTION_STRINGS[param]] = value
         else: