]> git.openstreetmap.org Git - nominatim.git/blobdiff - nominatim/db/connection.py
Added missing return types to functions
[nominatim.git] / nominatim / db / connection.py
index 10327725aa20b8d0fc2b2b973324706b31e09284..82801ae7995c9d1e5527baec0d9dd89c85e70e4d 100644 (file)
@@ -7,7 +7,7 @@
 """
 Specialised connection and cursor functions.
 """
-from typing import List, Optional, Any, Callable, ContextManager, Dict, cast, overload, Tuple
+from typing import Optional, Any, Callable, ContextManager, Dict, cast, overload, Tuple, Iterable
 import contextlib
 import logging
 import os
@@ -22,7 +22,7 @@ from nominatim.errors import UsageError
 
 LOG = logging.getLogger()
 
-class _Cursor(psycopg2.extras.DictCursor):
+class Cursor(psycopg2.extras.DictCursor):
     """ A cursor returning dict-like objects and providing specialised
         execution functions.
     """
@@ -31,13 +31,13 @@ class _Cursor(psycopg2.extras.DictCursor):
         """ Query execution that logs the SQL query when debugging is enabled.
         """
         if LOG.isEnabledFor(logging.DEBUG):
-            LOG.debug(self.mogrify(query, args).decode('utf-8')) # type: ignore
+            LOG.debug(self.mogrify(query, args).decode('utf-8'))
 
         super().execute(query, args)
 
 
-    def execute_values(self, sql: Query, argslist: List[Any],
-                       template: Optional[str] = None) -> None:
+    def execute_values(self, sql: Query, argslist: Iterable[Tuple[Any, ...]],
+                       template: Optional[Query] = None) -> None:
         """ Wrapper for the psycopg2 convenience function to execute
             SQL for a list of values.
         """
@@ -55,7 +55,7 @@ class _Cursor(psycopg2.extras.DictCursor):
         if self.rowcount != 1:
             raise RuntimeError("Query did not return a single row.")
 
-        result = self.fetchone() # type: ignore
+        result = self.fetchone()
         assert result is not None
 
         return result[0]
@@ -63,7 +63,7 @@ class _Cursor(psycopg2.extras.DictCursor):
 
     def drop_table(self, name: str, if_exists: bool = True, cascade: bool = False) -> None:
         """ Drop the table with the given name.
-            Set `if_exists` to False if a non-existant table should raise
+            Set `if_exists` to False if a non-existent table should raise
             an exception instead of just being ignored. If 'cascade' is set
             to True then all dependent tables are deleted as well.
         """
@@ -74,26 +74,26 @@ class _Cursor(psycopg2.extras.DictCursor):
         if cascade:
             sql += ' CASCADE'
 
-        self.execute(pysql.SQL(sql).format(pysql.Identifier(name))) # type: ignore
+        self.execute(pysql.SQL(sql).format(pysql.Identifier(name)))
 
 
-class _Connection(psycopg2.extensions.connection):
+class Connection(psycopg2.extensions.connection):
     """ A connection that provides the specialised cursor by default and
         adds convenience functions for administrating the database.
     """
     @overload # type: ignore[override]
-    def cursor(self) -> _Cursor:
+    def cursor(self) -> Cursor:
         ...
 
     @overload
-    def cursor(self, name: str) -> _Cursor:
+    def cursor(self, name: str) -> Cursor:
         ...
 
     @overload
     def cursor(self, cursor_factory: Callable[..., T_cursor]) -> T_cursor:
         ...
 
-    def cursor(self, cursor_factory  = _Cursor, **kwargs): # type: ignore
+    def cursor(self, cursor_factory  = Cursor, **kwargs): # type: ignore
         """ Return a new cursor. By default the specialised cursor is returned.
         """
         return super().cursor(cursor_factory=cursor_factory, **kwargs)
@@ -131,7 +131,7 @@ class _Connection(psycopg2.extensions.connection):
                 return False
 
             if table is not None:
-                row = cur.fetchone() # type: ignore
+                row = cur.fetchone()
                 if row is None or not isinstance(row[0], str):
                     return False
                 return row[0] == table
@@ -141,7 +141,7 @@ class _Connection(psycopg2.extensions.connection):
 
     def drop_table(self, name: str, if_exists: bool = True, cascade: bool = False) -> None:
         """ Drop the table with the given name.
-            Set `if_exists` to False if a non-existant table should raise
+            Set `if_exists` to False if a non-existent table should raise
             an exception instead of just being ignored.
         """
         with self.cursor() as cur:
@@ -174,19 +174,31 @@ class _Connection(psycopg2.extensions.connection):
 
         return (int(version_parts[0]), int(version_parts[1]))
 
-class _ConnectionContext(ContextManager[_Connection]):
-    connection: _Connection
 
-def connect(dsn: str) -> _ConnectionContext:
+    def extension_loaded(self, extension_name: str) -> bool:
+        """ Return True if the hstore extension is loaded in the database.
+        """
+        with self.cursor() as cur:
+            cur.execute('SELECT extname FROM pg_extension WHERE extname = %s', (extension_name, ))
+            return cur.rowcount > 0
+
+
+class ConnectionContext(ContextManager[Connection]):
+    """ Context manager of the connection that also provides direct access
+        to the underlying connection.
+    """
+    connection: Connection
+
+def connect(dsn: str) -> ConnectionContext:
     """ Open a connection to the database using the specialised connection
         factory. The returned object may be used in conjunction with 'with'.
         When used outside a context manager, use the `connection` attribute
         to get the connection.
     """
     try:
-        conn = psycopg2.connect(dsn, connection_factory=_Connection)
-        ctxmgr = cast(_ConnectionContext, contextlib.closing(conn))
-        ctxmgr.connection = cast(_Connection, conn)
+        conn = psycopg2.connect(dsn, connection_factory=Connection)
+        ctxmgr = cast(ConnectionContext, contextlib.closing(conn))
+        ctxmgr.connection = conn
         return ctxmgr
     except psycopg2.OperationalError as err:
         raise UsageError(f"Cannot connect to database: {err}") from err
@@ -233,7 +245,7 @@ def get_pg_env(dsn: str,
     """
     env = dict(base_env if base_env is not None else os.environ)
 
-    for param, value in psycopg2.extensions.parse_dsn(dsn).items(): # type: ignore
+    for param, value in psycopg2.extensions.parse_dsn(dsn).items():
         if param in _PG_CONNECTION_STRINGS:
             env[_PG_CONNECTION_STRINGS[param]] = value
         else: