X-Git-Url: https://git.openstreetmap.org./nominatim.git/blobdiff_plain/f22fa992f7975757d84ae17e42ff98c15f7a572b..f923304eead3cb9e9cfad8f41c33df1fdc1a16fd:/nominatim/db/connection.py
diff --git a/nominatim/db/connection.py b/nominatim/db/connection.py
index 729e8a70..d6860836 100644
--- a/nominatim/db/connection.py
+++ b/nominatim/db/connection.py
@@ -7,7 +7,7 @@
 """
 Specialised connection and cursor functions.
 """
-from typing import List, Optional, Any, Callable, ContextManager, Mapping, cast, overload, Tuple
+from typing import Optional, Any, Callable, ContextManager, Dict, cast, overload, Tuple, Iterable
 import contextlib
 import logging
 import os
@@ -17,12 +17,12 @@ import psycopg2.extensions
 import psycopg2.extras
 from psycopg2 import sql as pysql
 
-from nominatim.typing import Query, T_cursor
+from nominatim.typing import SysEnv, Query, T_cursor
 from nominatim.errors import UsageError
 
 LOG = logging.getLogger()
 
-class _Cursor(psycopg2.extras.DictCursor):
+class Cursor(psycopg2.extras.DictCursor):
     """ A cursor returning dict-like objects and providing specialised
         execution functions.
     """
@@ -31,13 +31,13 @@ class _Cursor(psycopg2.extras.DictCursor):
         """ Query execution that logs the SQL query when debugging is enabled.
         """
         if LOG.isEnabledFor(logging.DEBUG):
-            LOG.debug(self.mogrify(query, args).decode('utf-8')) # type: ignore
+            LOG.debug(self.mogrify(query, args).decode('utf-8'))
 
         super().execute(query, args)
 
 
-    def execute_values(self, sql: Query, argslist: List[Any],
-                       template: Optional[str] = None) -> None:
+    def execute_values(self, sql: Query, argslist: Iterable[Tuple[Any, ...]],
+                       template: Optional[Query] = None) -> None:
         """ Wrapper for the psycopg2 convenience function to execute
             SQL for a list of values.
         """
@@ -55,7 +55,7 @@ class _Cursor(psycopg2.extras.DictCursor):
         if self.rowcount != 1:
             raise RuntimeError("Query did not return a single row.")
 
-        result = self.fetchone() # type: ignore
+        result = self.fetchone()
         assert result is not None
         return result[0]
 
@@ -63,7 +63,7 @@ class _Cursor(psycopg2.extras.DictCursor):
     def drop_table(self, name: str, if_exists: bool = True,
                    cascade: bool = False) -> None:
         """ Drop the table with the given name.
-            Set `if_exists` to False if a non-existant table should raise
+            Set `if_exists` to False if a non-existent table should raise
             an exception instead of just being ignored. If 'cascade' is set
             to True then all dependent tables are deleted as well.
         """
@@ -74,26 +74,26 @@ class _Cursor(psycopg2.extras.DictCursor):
         if cascade:
             sql += ' CASCADE'
 
-        self.execute(pysql.SQL(sql).format(pysql.Identifier(name))) # type: ignore
+        self.execute(pysql.SQL(sql).format(pysql.Identifier(name)))
 
 
-class _Connection(psycopg2.extensions.connection):
+class Connection(psycopg2.extensions.connection):
     """ A connection that provides the specialised cursor by default and
         adds convenience functions for administrating the database.
     """
     @overload # type: ignore[override]
-    def cursor(self) -> _Cursor:
+    def cursor(self) -> Cursor:
         ...
 
     @overload
-    def cursor(self, name: str) -> _Cursor:
+    def cursor(self, name: str) -> Cursor:
         ...
 
     @overload
     def cursor(self, cursor_factory: Callable[..., T_cursor]) -> T_cursor:
         ...
 
-    def cursor(self, cursor_factory = _Cursor, **kwargs): # type: ignore
+    def cursor(self, cursor_factory = Cursor, **kwargs): # type: ignore
         """ Return a new cursor. By default the specialised cursor is returned.
         """
         return super().cursor(cursor_factory=cursor_factory, **kwargs)
@@ -131,7 +131,7 @@ class _Connection(psycopg2.extensions.connection):
                 return False
 
             if table is not None:
-                row = cur.fetchone() # type: ignore
+                row = cur.fetchone()
                 if row is None or not isinstance(row[0], str):
                     return False
                 return row[0] == table
@@ -141,7 +141,7 @@ class _Connection(psycopg2.extensions.connection):
     def drop_table(self, name: str, if_exists: bool = True,
                    cascade: bool = False) -> None:
         """ Drop the table with the given name.
-            Set `if_exists` to False if a non-existant table should raise
+            Set `if_exists` to False if a non-existent table should raise
             an exception instead of just being ignored.
         """
         with self.cursor() as cur:
@@ -174,19 +174,31 @@ class _Connection(psycopg2.extensions.connection):
         return (int(version_parts[0]), int(version_parts[1]))
 
 
-class _ConnectionContext(ContextManager[_Connection]):
-    connection: _Connection
 
-def connect(dsn: str) -> _ConnectionContext:
+    def extension_loaded(self, extension_name: str) -> bool:
+        """ Return True if the extension with the given name is loaded in the database.
+        """
+        with self.cursor() as cur:
+            cur.execute('SELECT extname FROM pg_extension WHERE extname = %s', (extension_name, ))
+            return cur.rowcount > 0
+
+
+class ConnectionContext(ContextManager[Connection]):
+    """ Context manager of the connection that also provides direct access
+        to the underlying connection.
+    """
+    connection: Connection
+
+def connect(dsn: str) -> ConnectionContext:
     """ Open a connection to the database using the specialised connection
         factory. The returned object may be used in conjunction with 'with'.
         When used outside a context manager, use the `connection` attribute
         to get the connection.
     """
     try:
-        conn = psycopg2.connect(dsn, connection_factory=_Connection)
-        ctxmgr = cast(_ConnectionContext, contextlib.closing(conn))
-        ctxmgr.connection = cast(_Connection, conn)
+        conn = psycopg2.connect(dsn, connection_factory=Connection)
+        ctxmgr = cast(ConnectionContext, contextlib.closing(conn))
+        ctxmgr.connection = conn
         return ctxmgr
     except psycopg2.OperationalError as err:
         raise UsageError(f"Cannot connect to database: {err}") from err
@@ -225,15 +237,15 @@ _PG_CONNECTION_STRINGS = {
 
 
 def get_pg_env(dsn: str,
-               base_env: Optional[Mapping[str, str]] = None) -> Mapping[str, str]:
+               base_env: Optional[SysEnv] = None) -> Dict[str, str]:
     """ Return a copy of `base_env` with the environment variables for
-        PostgresSQL set up from the given database connection string.
+        PostgreSQL set up from the given database connection string.
         If `base_env` is None, then the OS environment is used as a base
         environment.
     """
     env = dict(base_env if base_env is not None else os.environ)
 
-    for param, value in psycopg2.extensions.parse_dsn(dsn).items(): # type: ignore
+    for param, value in psycopg2.extensions.parse_dsn(dsn).items():
         if param in _PG_CONNECTION_STRINGS:
             env[_PG_CONNECTION_STRINGS[param]] = value
         else: