]> git.openstreetmap.org Git - nominatim.git/blobdiff - nominatim/db/connection.py
Merge remote-tracking branch 'upstream/master'
[nominatim.git] / nominatim / db / connection.py
index 5a1b46959d457b6ce3f9dd7e4d3e6b00590f9434..82801ae7995c9d1e5527baec0d9dd89c85e70e4d 100644 (file)
@@ -7,7 +7,7 @@
 """
 Specialised connection and cursor functions.
 """
 """
 Specialised connection and cursor functions.
 """
-from typing import List, Optional, Any, Callable, ContextManager, Dict, cast, overload, Tuple
+from typing import Optional, Any, Callable, ContextManager, Dict, cast, overload, Tuple, Iterable
 import contextlib
 import logging
 import os
 import contextlib
 import logging
 import os
@@ -31,13 +31,13 @@ class Cursor(psycopg2.extras.DictCursor):
         """ Query execution that logs the SQL query when debugging is enabled.
         """
         if LOG.isEnabledFor(logging.DEBUG):
         """ Query execution that logs the SQL query when debugging is enabled.
         """
         if LOG.isEnabledFor(logging.DEBUG):
-            LOG.debug(self.mogrify(query, args).decode('utf-8')) # type: ignore[no-untyped-call]
+            LOG.debug(self.mogrify(query, args).decode('utf-8'))
 
         super().execute(query, args)
 
 
 
         super().execute(query, args)
 
 
-    def execute_values(self, sql: Query, argslist: List[Any],
-                       template: Optional[str] = None) -> None:
+    def execute_values(self, sql: Query, argslist: Iterable[Tuple[Any, ...]],
+                       template: Optional[Query] = None) -> None:
         """ Wrapper for the psycopg2 convenience function to execute
             SQL for a list of values.
         """
         """ Wrapper for the psycopg2 convenience function to execute
             SQL for a list of values.
         """
@@ -55,7 +55,7 @@ class Cursor(psycopg2.extras.DictCursor):
         if self.rowcount != 1:
             raise RuntimeError("Query did not return a single row.")
 
         if self.rowcount != 1:
             raise RuntimeError("Query did not return a single row.")
 
-        result = self.fetchone() # type: ignore[no-untyped-call]
+        result = self.fetchone()
         assert result is not None
 
         return result[0]
         assert result is not None
 
         return result[0]
@@ -63,7 +63,7 @@ class Cursor(psycopg2.extras.DictCursor):
 
     def drop_table(self, name: str, if_exists: bool = True, cascade: bool = False) -> None:
         """ Drop the table with the given name.
 
     def drop_table(self, name: str, if_exists: bool = True, cascade: bool = False) -> None:
         """ Drop the table with the given name.
-            Set `if_exists` to False if a non-existant table should raise
+            Set `if_exists` to False if a non-existent table should raise
             an exception instead of just being ignored. If 'cascade' is set
             to True then all dependent tables are deleted as well.
         """
             an exception instead of just being ignored. If 'cascade' is set
             to True then all dependent tables are deleted as well.
         """
@@ -74,7 +74,7 @@ class Cursor(psycopg2.extras.DictCursor):
         if cascade:
             sql += ' CASCADE'
 
         if cascade:
             sql += ' CASCADE'
 
-        self.execute(pysql.SQL(sql).format(pysql.Identifier(name))) # type: ignore[no-untyped-call]
+        self.execute(pysql.SQL(sql).format(pysql.Identifier(name)))
 
 
 class Connection(psycopg2.extensions.connection):
 
 
 class Connection(psycopg2.extensions.connection):
@@ -131,7 +131,7 @@ class Connection(psycopg2.extensions.connection):
                 return False
 
             if table is not None:
                 return False
 
             if table is not None:
-                row = cur.fetchone() # type: ignore[no-untyped-call]
+                row = cur.fetchone()
                 if row is None or not isinstance(row[0], str):
                     return False
                 return row[0] == table
                 if row is None or not isinstance(row[0], str):
                     return False
                 return row[0] == table
@@ -141,7 +141,7 @@ class Connection(psycopg2.extensions.connection):
 
     def drop_table(self, name: str, if_exists: bool = True, cascade: bool = False) -> None:
         """ Drop the table with the given name.
 
     def drop_table(self, name: str, if_exists: bool = True, cascade: bool = False) -> None:
         """ Drop the table with the given name.
-            Set `if_exists` to False if a non-existant table should raise
+            Set `if_exists` to False if a non-existent table should raise
             an exception instead of just being ignored.
         """
         with self.cursor() as cur:
             an exception instead of just being ignored.
         """
         with self.cursor() as cur:
@@ -174,6 +174,15 @@ class Connection(psycopg2.extensions.connection):
 
         return (int(version_parts[0]), int(version_parts[1]))
 
 
         return (int(version_parts[0]), int(version_parts[1]))
 
+
+    def extension_loaded(self, extension_name: str) -> bool:
+        """ Return True if the hstore extension is loaded in the database.
+        """
+        with self.cursor() as cur:
+            cur.execute('SELECT extname FROM pg_extension WHERE extname = %s', (extension_name, ))
+            return cur.rowcount > 0
+
+
 class ConnectionContext(ContextManager[Connection]):
     """ Context manager of the connection that also provides direct access
         to the underlying connection.
 class ConnectionContext(ContextManager[Connection]):
     """ Context manager of the connection that also provides direct access
         to the underlying connection.
@@ -189,7 +198,7 @@ def connect(dsn: str) -> ConnectionContext:
     try:
         conn = psycopg2.connect(dsn, connection_factory=Connection)
         ctxmgr = cast(ConnectionContext, contextlib.closing(conn))
     try:
         conn = psycopg2.connect(dsn, connection_factory=Connection)
         ctxmgr = cast(ConnectionContext, contextlib.closing(conn))
-        ctxmgr.connection = cast(Connection, conn)
+        ctxmgr.connection = conn
         return ctxmgr
     except psycopg2.OperationalError as err:
         raise UsageError(f"Cannot connect to database: {err}") from err
         return ctxmgr
     except psycopg2.OperationalError as err:
         raise UsageError(f"Cannot connect to database: {err}") from err
@@ -236,7 +245,7 @@ def get_pg_env(dsn: str,
     """
     env = dict(base_env if base_env is not None else os.environ)
 
     """
     env = dict(base_env if base_env is not None else os.environ)
 
-    for param, value in psycopg2.extensions.parse_dsn(dsn).items(): # type: ignore
+    for param, value in psycopg2.extensions.parse_dsn(dsn).items():
         if param in _PG_CONNECTION_STRINGS:
             env[_PG_CONNECTION_STRINGS[param]] = value
         else:
         if param in _PG_CONNECTION_STRINGS:
             env[_PG_CONNECTION_STRINGS[param]] = value
         else: