# SPDX-License-Identifier: GPL-2.0-only
#
# This file is part of Nominatim. (https://nominatim.org)
#
# Copyright (C) 2022 by the Nominatim developer community.
# For a full list of authors see the git log.
"""
Specialised connection and cursor functions.
"""
# NOTE(review): reconstructed from the blobdiff e629a175..c41f2fed of
# nominatim/db/connection.py (post-image: '+' and context lines only).
from typing import Optional, Any, Callable, ContextManager, Dict, cast, overload, Tuple, Iterable
import contextlib
import logging
import os

import psycopg2
import psycopg2.extensions
import psycopg2.extras
from psycopg2 import sql as pysql

from nominatim.typing import SysEnv, Query, T_cursor
from nominatim.errors import UsageError

LOG = logging.getLogger()


class Cursor(psycopg2.extras.DictCursor):
    """ A cursor returning dict-like objects and providing specialised
        execution functions.
    """
    # pylint: disable=arguments-renamed,arguments-differ
    def execute(self, query: Query, args: Any = None) -> None:
        """ Query execution that logs the SQL query when debugging is enabled.
        """
        # mogrify() renders the final statement; only pay for that when
        # the debug output is actually going to be emitted.
        if LOG.isEnabledFor(logging.DEBUG):
            LOG.debug(self.mogrify(query, args).decode('utf-8'))

        super().execute(query, args)


    def execute_values(self, sql: Query, argslist: Iterable[Tuple[Any, ...]],
                       template: Optional[Query] = None) -> None:
        """ Wrapper for the psycopg2 convenience function to execute
            SQL for a list of values.
        """
        LOG.debug("SQL execute_values(%s, %s)", sql, argslist)

        psycopg2.extras.execute_values(self, sql, argslist, template=template)


    def scalar(self, sql: Query, args: Any = None) -> Any:
        """ Execute query that returns a single value. The value is returned.
            If the query does not yield exactly one row, a RuntimeError
            is raised.
        """
        # NOTE(review): this call lies in the gap between the two diff hunks
        # and is reconstructed -- confirm against the full file.
        self.execute(sql, args)

        if self.rowcount != 1:
            raise RuntimeError("Query did not return a single row.")

        result = self.fetchone()
        # Explicit check instead of `assert`, which is stripped under -O.
        if result is None:
            raise RuntimeError("Query did not return a single row.")

        return result[0]


    def drop_table(self, name: str, if_exists: bool = True, cascade: bool = False) -> None:
        """ Drop the table with the given name.
            Set `if_exists` to False if a non-existent table should raise
            an exception instead of just being ignored. If 'cascade' is set
            to True then all dependent tables are deleted as well.
        """
        sql = 'DROP TABLE '
        if if_exists:
            sql += 'IF EXISTS '
        sql += '{}'
        if cascade:
            sql += ' CASCADE'

        # The table name is injected as a quoted identifier, never as raw text.
        self.execute(pysql.SQL(sql).format(pysql.Identifier(name)))


class Connection(psycopg2.extensions.connection):
    """ A connection that provides the specialised cursor by default and
        adds convenience functions for administrating the database.
    """
    @overload # type: ignore[override]
    def cursor(self) -> Cursor:
        ...

    @overload
    def cursor(self, name: str) -> Cursor:
        ...

    @overload
    def cursor(self, cursor_factory: Callable[..., T_cursor]) -> T_cursor:
        ...

    def cursor(self, cursor_factory = Cursor, **kwargs): # type: ignore
        """ Return a new cursor. By default the specialised cursor is returned.
        """
        return super().cursor(cursor_factory=cursor_factory, **kwargs)


    def table_exists(self, table: str) -> bool:
        """ Check that a table with the given name exists in the database.
            Only tables in the 'public' schema are considered.
        """
        with self.cursor() as cur:
            num = cur.scalar("""SELECT count(*) FROM pg_tables
                                WHERE tablename = %s and schemaname = 'public'""", (table, ))
            return isinstance(num, int) and num == 1


    def table_has_column(self, table: str, column: str) -> bool:
        """ Check if the table 'table' exists and has a column with name 'column'.
        """
        with self.cursor() as cur:
            has_column = cur.scalar("""SELECT count(*) FROM information_schema.columns
                                       WHERE table_name = %s
                                             and column_name = %s""",
                                    (table, column))
            return isinstance(has_column, int) and has_column > 0


    def index_exists(self, index: str, table: Optional[str] = None) -> bool:
        """ Check that an index with the given name exists in the database.
            If table is not None then the index must relate to the given
            table. Only indexes in the 'public' schema are considered.
        """
        with self.cursor() as cur:
            cur.execute("""SELECT tablename FROM pg_indexes
                           WHERE indexname = %s and schemaname = 'public'""", (index, ))
            if cur.rowcount == 0:
                return False

            if table is not None:
                row = cur.fetchone()
                if row is None or not isinstance(row[0], str):
                    return False
                return row[0] == table

        return True


    def drop_table(self, name: str, if_exists: bool = True, cascade: bool = False) -> None:
        """ Drop the table with the given name and commit the change.
            Set `if_exists` to False if a non-existent table should raise
            an exception instead of just being ignored. If 'cascade' is set
            to True, the table is dropped with the CASCADE option.
        """
        with self.cursor() as cur:
            cur.drop_table(name, if_exists, cascade)
        self.commit()


    def server_version_tuple(self) -> Tuple[int, int]:
        """ Return the server version as a tuple of (major, minor).
            Converts correctly for pre-10 and post-10 PostgreSQL versions.
        """
        version = self.server_version
        major, remainder = divmod(version, 10000)
        if version < 100000:
            # Pre-10 numbering: the minor version sits in the hundreds
            # digits, the last two digits are dropped.
            return (major, remainder // 100)

        return (major, remainder)


    def postgis_version_tuple(self) -> Tuple[int, int]:
        """ Return the postgis version installed in the database as a
            tuple of (major, minor). Assumes that the PostGIS extension
            has been installed already.

            Raises a UsageError when the reported version string has
            fewer than two dot-separated parts.
        """
        with self.cursor() as cur:
            version = cur.scalar('SELECT postgis_lib_version()')

        version_parts = version.split('.')
        if len(version_parts) < 2:
            raise UsageError(f"Error fetching Postgis version. Bad format: {version}")

        return (int(version_parts[0]), int(version_parts[1]))


    def extension_loaded(self, extension_name: str) -> bool:
        """ Return True if the extension with the given name is loaded
            in the database.
        """
        with self.cursor() as cur:
            cur.execute('SELECT extname FROM pg_extension WHERE extname = %s', (extension_name, ))
            return cur.rowcount > 0


class ConnectionContext(ContextManager[Connection]):
    """ Context manager of the connection that also provides direct access
        to the underlying connection.
    """
    connection: Connection


def connect(dsn: str) -> ConnectionContext:
    """ Open a connection to the database using the specialised connection
        factory. The returned object may be used in conjunction with 'with'.
        When used outside a context manager, use the `connection` attribute
        to get the connection.

        Raises a UsageError when psycopg2 reports an OperationalError
        while connecting.
    """
    # Keep the try body down to the one call that can fail to connect.
    try:
        conn = psycopg2.connect(dsn, connection_factory=Connection)
    except psycopg2.OperationalError as err:
        raise UsageError(f"Cannot connect to database: {err}") from err

    # contextlib.closing() closes the connection when the 'with' block ends;
    # the connection is additionally exposed for use without 'with'.
    ctxmgr = cast(ConnectionContext, contextlib.closing(conn))
    ctxmgr.connection = conn
    return ctxmgr


# Translation from PG connection string parameters to PG environment variables.
# Derived from https://www.postgresql.org/docs/current/libpq-envars.html.
_PG_CONNECTION_STRINGS = {
    'host': 'PGHOST',
    'hostaddr': 'PGHOSTADDR',
    'port': 'PGPORT',
    'dbname': 'PGDATABASE',
    'user': 'PGUSER',
    'password': 'PGPASSWORD',
    'passfile': 'PGPASSFILE',
    'channel_binding': 'PGCHANNELBINDING',
    'service': 'PGSERVICE',
    'options': 'PGOPTIONS',
    'application_name': 'PGAPPNAME',
    'sslmode': 'PGSSLMODE',
    'requiressl': 'PGREQUIRESSL',
    'sslcompression': 'PGSSLCOMPRESSION',
    'sslcert': 'PGSSLCERT',
    'sslkey': 'PGSSLKEY',
    'sslrootcert': 'PGSSLROOTCERT',
    'sslcrl': 'PGSSLCRL',
    'requirepeer': 'PGREQUIREPEER',
    'ssl_min_protocol_version': 'PGSSLMINPROTOCOLVERSION',
    'ssl_max_protocol_version': 'PGSSLMAXPROTOCOLVERSION',
    'gssencmode': 'PGGSSENCMODE',
    'krbsrvname': 'PGKRBSRVNAME',
    'gsslib': 'PGGSSLIB',
    'connect_timeout': 'PGCONNECT_TIMEOUT',
    'target_session_attrs': 'PGTARGETSESSIONATTRS',
}


def get_pg_env(dsn: str,
               base_env: Optional[SysEnv] = None) -> Dict[str, str]:
    """ Return a copy of `base_env` with the environment variables for
        PostgreSQL set up from the given database connection string.
        If `base_env` is None, then the OS environment is used as a base
        environment.

        Connection parameters without a known environment variable are
        logged as an error and otherwise ignored.
    """
    # dict() makes a copy, so neither `base_env` nor os.environ is modified.
    env = dict(base_env if base_env is not None else os.environ)

    for param, value in psycopg2.extensions.parse_dsn(dsn).items():
        env_name = _PG_CONNECTION_STRINGS.get(param)
        if env_name is None:
            LOG.error("Unknown connection parameter '%s' ignored.", param)
        else:
            env[env_name] = value

    return env