X-Git-Url: https://git.openstreetmap.org./nominatim.git/blobdiff_plain/efafa5271957fb54b356ec1c90e8613f14de40d4..681aad7e0dc099658eea15b769fdefd44cd8c484:/nominatim/db/connection.py diff --git a/nominatim/db/connection.py b/nominatim/db/connection.py index 1c115207..729e8a70 100644 --- a/nominatim/db/connection.py +++ b/nominatim/db/connection.py @@ -7,6 +7,7 @@ """ Specialised connection and cursor functions. """ +from typing import List, Optional, Any, Callable, ContextManager, Mapping, cast, overload, Tuple import contextlib import logging import os @@ -16,6 +17,7 @@ import psycopg2.extensions import psycopg2.extras from psycopg2 import sql as pysql +from nominatim.typing import Query, T_cursor from nominatim.errors import UsageError LOG = logging.getLogger() @@ -24,16 +26,18 @@ class _Cursor(psycopg2.extras.DictCursor): """ A cursor returning dict-like objects and providing specialised execution functions. """ - - def execute(self, query, args=None): # pylint: disable=W0221 + # pylint: disable=arguments-renamed,arguments-differ + def execute(self, query: Query, args: Any = None) -> None: """ Query execution that logs the SQL query when debugging is enabled. """ - LOG.debug(self.mogrify(query, args).decode('utf-8')) + if LOG.isEnabledFor(logging.DEBUG): + LOG.debug(self.mogrify(query, args).decode('utf-8')) # type: ignore super().execute(query, args) - def execute_values(self, sql, argslist, template=None): + def execute_values(self, sql: Query, argslist: List[Any], + template: Optional[str] = None) -> None: """ Wrapper for the psycopg2 convenience function to execute SQL for a list of values. """ @@ -42,7 +46,7 @@ class _Cursor(psycopg2.extras.DictCursor): psycopg2.extras.execute_values(self, sql, argslist, template=template) - def scalar(self, sql, args=None): + def scalar(self, sql: Query, args: Any = None) -> Any: """ Execute query that returns a single value. The value is returned. If the query yields more than one row, a ValueError is raised. """ @@ -51,10 +55,13 @@ class _Cursor(psycopg2.extras.DictCursor): if self.rowcount != 1: raise RuntimeError("Query did not return a single row.") - return self.fetchone()[0] + result = self.fetchone() # type: ignore + assert result is not None + + return result[0] - def drop_table(self, name, if_exists=True, cascade=False): + def drop_table(self, name: str, if_exists: bool = True, cascade: bool = False) -> None: """ Drop the table with the given name. Set `if_exists` to False if a non-existant table should raise an exception instead of just being ignored. If 'cascade' is set @@ -67,30 +74,52 @@ class _Cursor(psycopg2.extras.DictCursor): if cascade: sql += ' CASCADE' - self.execute(pysql.SQL(sql).format(pysql.Identifier(name))) + self.execute(pysql.SQL(sql).format(pysql.Identifier(name))) # type: ignore class _Connection(psycopg2.extensions.connection): """ A connection that provides the specialised cursor by default and adds convenience functions for administrating the database. """ + @overload # type: ignore[override] + def cursor(self) -> _Cursor: + ... + + @overload + def cursor(self, name: str) -> _Cursor: + ... - def cursor(self, cursor_factory=_Cursor, **kwargs): + @overload + def cursor(self, cursor_factory: Callable[..., T_cursor]) -> T_cursor: + ... + + def cursor(self, cursor_factory = _Cursor, **kwargs): # type: ignore """ Return a new cursor. By default the specialised cursor is returned. """ return super().cursor(cursor_factory=cursor_factory, **kwargs) - def table_exists(self, table): + def table_exists(self, table: str) -> bool: """ Check that a table with the given name exists in the database. """ with self.cursor() as cur: num = cur.scalar("""SELECT count(*) FROM pg_tables WHERE tablename = %s and schemaname = 'public'""", (table, )) - return num == 1 + return num == 1 if isinstance(num, int) else False - def index_exists(self, index, table=None): + def table_has_column(self, table: str, column: str) -> bool: + """ Check if the table 'table' exists and has a column with name 'column'. + """ + with self.cursor() as cur: + has_column = cur.scalar("""SELECT count(*) FROM information_schema.columns + WHERE table_name = %s + and column_name = %s""", + (table, column)) + return has_column > 0 if isinstance(has_column, int) else False + + + def index_exists(self, index: str, table: Optional[str] = None) -> bool: """ Check that an index with the given name exists in the database. If table is not None then the index must relate to the given table. @@ -102,13 +131,15 @@ class _Connection(psycopg2.extensions.connection): return False if table is not None: - row = cur.fetchone() + row = cur.fetchone() # type: ignore + if row is None or not isinstance(row[0], str): + return False return row[0] == table return True - def drop_table(self, name, if_exists=True, cascade=False): + def drop_table(self, name: str, if_exists: bool = True, cascade: bool = False) -> None: """ Drop the table with the given name. Set `if_exists` to False if a non-existant table should raise an exception instead of just being ignored. @@ -118,18 +149,18 @@ class _Connection(psycopg2.extensions.connection): self.commit() - def server_version_tuple(self): + def server_version_tuple(self) -> Tuple[int, int]: """ Return the server version as a tuple of (major, minor). Converts correctly for pre-10 and post-10 PostgreSQL versions. """ version = self.server_version if version < 100000: - return (int(version / 10000), (version % 10000) / 100) + return (int(version / 10000), int((version % 10000) / 100)) return (int(version / 10000), version % 10000) - def postgis_version_tuple(self): + def postgis_version_tuple(self) -> Tuple[int, int]: """ Return the postgis version installed in the database as a tuple of (major, minor). Assumes that the PostGIS extension has been installed already. @@ -137,10 +168,16 @@ class _Connection(psycopg2.extensions.connection): with self.cursor() as cur: version = cur.scalar('SELECT postgis_lib_version()') - return tuple((int(x) for x in version.split('.')[:2])) + version_parts = version.split('.') + if len(version_parts) < 2: + raise UsageError(f"Error fetching Postgis version. Bad format: {version}") + + return (int(version_parts[0]), int(version_parts[1])) +class _ConnectionContext(ContextManager[_Connection]): + connection: _Connection -def connect(dsn): +def connect(dsn: str) -> _ConnectionContext: """ Open a connection to the database using the specialised connection factory. The returned object may be used in conjunction with 'with'. When used outside a context manager, use the `connection` attribute @@ -148,11 +185,11 @@ def connect(dsn): """ try: conn = psycopg2.connect(dsn, connection_factory=_Connection) - ctxmgr = contextlib.closing(conn) - ctxmgr.connection = conn + ctxmgr = cast(_ConnectionContext, contextlib.closing(conn)) + ctxmgr.connection = cast(_Connection, conn) return ctxmgr except psycopg2.OperationalError as err: - raise UsageError("Cannot connect to database: {}".format(err)) from err + raise UsageError(f"Cannot connect to database: {err}") from err # Translation from PG connection string parameters to PG environment variables. @@ -187,7 +224,8 @@ _PG_CONNECTION_STRINGS = { } -def get_pg_env(dsn, base_env=None): +def get_pg_env(dsn: str, + base_env: Optional[Mapping[str, str]] = None) -> Mapping[str, str]: """ Return a copy of `base_env` with the environment variables for PostgresSQL set up from the given database connection string. If `base_env` is None, then the OS environment is used as a base @@ -195,7 +233,7 @@ def get_pg_env(dsn, base_env=None): """ env = dict(base_env if base_env is not None else os.environ) - for param, value in psycopg2.extensions.parse_dsn(dsn).items(): + for param, value in psycopg2.extensions.parse_dsn(dsn).items(): # type: ignore if param in _PG_CONNECTION_STRINGS: env[_PG_CONNECTION_STRINGS[param]] = value else: