X-Git-Url: https://git.openstreetmap.org./nominatim.git/blobdiff_plain/71249bd94a1bd698a937983663f06a9376629ae6..7cf21f8caf60a0992dd0f7e1ab51d673ca8429e2:/src/nominatim_db/db/connection.py diff --git a/src/nominatim_db/db/connection.py b/src/nominatim_db/db/connection.py index 8faa3f93..6c7e843f 100644 --- a/src/nominatim_db/db/connection.py +++ b/src/nominatim_db/db/connection.py @@ -7,193 +7,136 @@ """ Specialised connection and cursor functions. """ -from typing import Optional, Any, Callable, ContextManager, Dict, cast, overload, Tuple, Iterable -import contextlib +from typing import Optional, Any, Dict, Tuple import logging import os -import psycopg2 -import psycopg2.extensions -import psycopg2.extras -from psycopg2 import sql as pysql +import psycopg +import psycopg.types.hstore +from psycopg import sql as pysql -from ..typing import SysEnv, Query, T_cursor +from ..typing import SysEnv from ..errors import UsageError LOG = logging.getLogger() -class Cursor(psycopg2.extras.DictCursor): - """ A cursor returning dict-like objects and providing specialised - execution functions. - """ - # pylint: disable=arguments-renamed,arguments-differ - def execute(self, query: Query, args: Any = None) -> None: - """ Query execution that logs the SQL query when debugging is enabled. - """ - if LOG.isEnabledFor(logging.DEBUG): - LOG.debug(self.mogrify(query, args).decode('utf-8')) - - super().execute(query, args) - - - def execute_values(self, sql: Query, argslist: Iterable[Tuple[Any, ...]], - template: Optional[Query] = None) -> None: - """ Wrapper for the psycopg2 convenience function to execute - SQL for a list of values. - """ - LOG.debug("SQL execute_values(%s, %s)", sql, argslist) - - psycopg2.extras.execute_values(self, sql, argslist, template=template) +Cursor = psycopg.Cursor[Any] +Connection = psycopg.Connection[Any] +def execute_scalar(conn: Connection, sql: psycopg.abc.Query, args: Any = None) -> Any: + """ Execute query that returns a single value. The value is returned. + If the query yields more than one row, a ValueError is raised. + """ + with conn.cursor(row_factory=psycopg.rows.tuple_row) as cur: + cur.execute(sql, args) - def scalar(self, sql: Query, args: Any = None) -> Any: - """ Execute query that returns a single value. The value is returned. - If the query yields more than one row, a ValueError is raised. - """ - self.execute(sql, args) - - if self.rowcount != 1: + if cur.rowcount != 1: raise RuntimeError("Query did not return a single row.") - result = self.fetchone() - assert result is not None + result = cur.fetchone() - return result[0] + assert result is not None + return result[0] - def drop_table(self, name: str, if_exists: bool = True, cascade: bool = False) -> None: - """ Drop the table with the given name. - Set `if_exists` to False if a non-existent table should raise - an exception instead of just being ignored. If 'cascade' is set - to True then all dependent tables are deleted as well. - """ - sql = 'DROP TABLE ' - if if_exists: - sql += 'IF EXISTS ' - sql += '{}' - if cascade: - sql += ' CASCADE' +def table_exists(conn: Connection, table: str) -> bool: + """ Check that a table with the given name exists in the database. + """ + num = execute_scalar(conn, + """SELECT count(*) FROM pg_tables + WHERE tablename = %s and schemaname = 'public'""", (table, )) + return num == 1 if isinstance(num, int) else False - self.execute(pysql.SQL(sql).format(pysql.Identifier(name))) +def table_has_column(conn: Connection, table: str, column: str) -> bool: + """ Check if the table 'table' exists and has a column with name 'column'. + """ + has_column = execute_scalar(conn, + """SELECT count(*) FROM information_schema.columns + WHERE table_name = %s and column_name = %s""", + (table, column)) + return has_column > 0 if isinstance(has_column, int) else False -class Connection(psycopg2.extensions.connection): - """ A connection that provides the specialised cursor by default and - adds convenience functions for administrating the database. + +def index_exists(conn: Connection, index: str, table: Optional[str] = None) -> bool: + """ Check that an index with the given name exists in the database. + If table is not None then the index must relate to the given + table. """ - @overload # type: ignore[override] - def cursor(self) -> Cursor: - ... - - @overload - def cursor(self, name: str) -> Cursor: - ... - - @overload - def cursor(self, cursor_factory: Callable[..., T_cursor]) -> T_cursor: - ... - - def cursor(self, cursor_factory = Cursor, **kwargs): # type: ignore - """ Return a new cursor. By default the specialised cursor is returned. - """ - return super().cursor(cursor_factory=cursor_factory, **kwargs) - - - def table_exists(self, table: str) -> bool: - """ Check that a table with the given name exists in the database. - """ - with self.cursor() as cur: - num = cur.scalar("""SELECT count(*) FROM pg_tables - WHERE tablename = %s and schemaname = 'public'""", (table, )) - return num == 1 if isinstance(num, int) else False - - - def table_has_column(self, table: str, column: str) -> bool: - """ Check if the table 'table' exists and has a column with name 'column'. - """ - with self.cursor() as cur: - has_column = cur.scalar("""SELECT count(*) FROM information_schema.columns - WHERE table_name = %s - and column_name = %s""", - (table, column)) - return has_column > 0 if isinstance(has_column, int) else False - - - def index_exists(self, index: str, table: Optional[str] = None) -> bool: - """ Check that an index with the given name exists in the database. - If table is not None then the index must relate to the given - table. - """ - with self.cursor() as cur: - cur.execute("""SELECT tablename FROM pg_indexes - WHERE indexname = %s and schemaname = 'public'""", (index, )) - if cur.rowcount == 0: + with conn.cursor() as cur: + cur.execute("""SELECT tablename FROM pg_indexes + WHERE indexname = %s and schemaname = 'public'""", (index, )) + if cur.rowcount == 0: + return False + + if table is not None: + row = cur.fetchone() + if row is None or not isinstance(row[0], str): return False + return row[0] == table - if table is not None: - row = cur.fetchone() - if row is None or not isinstance(row[0], str): - return False - return row[0] == table - - return True + return True +def drop_tables(conn: Connection, *names: str, + if_exists: bool = True, cascade: bool = False) -> None: + """ Drop one or more tables with the given names. + Set `if_exists` to False if a non-existent table should raise + an exception instead of just being ignored. `cascade` will cause + depended objects to be dropped as well. + The caller needs to take care of committing the change. + """ + sql = pysql.SQL('DROP TABLE%s{}%s' % ( + ' IF EXISTS ' if if_exists else ' ', + ' CASCADE' if cascade else '')) - def drop_table(self, name: str, if_exists: bool = True, cascade: bool = False) -> None: - """ Drop the table with the given name. - Set `if_exists` to False if a non-existent table should raise - an exception instead of just being ignored. - """ - with self.cursor() as cur: - cur.drop_table(name, if_exists, cascade) - self.commit() + with conn.cursor() as cur: + for name in names: + cur.execute(sql.format(pysql.Identifier(name))) - def server_version_tuple(self) -> Tuple[int, int]: - """ Return the server version as a tuple of (major, minor). - Converts correctly for pre-10 and post-10 PostgreSQL versions. - """ - version = self.server_version - if version < 100000: - return (int(version / 10000), int((version % 10000) / 100)) +def server_version_tuple(conn: Connection) -> Tuple[int, int]: + """ Return the server version as a tuple of (major, minor). + Converts correctly for pre-10 and post-10 PostgreSQL versions. + """ + version = conn.info.server_version + if version < 100000: + return (int(version / 10000), int((version % 10000) / 100)) - return (int(version / 10000), version % 10000) + return (int(version / 10000), version % 10000) - def postgis_version_tuple(self) -> Tuple[int, int]: - """ Return the postgis version installed in the database as a - tuple of (major, minor). Assumes that the PostGIS extension - has been installed already. - """ - with self.cursor() as cur: - version = cur.scalar('SELECT postgis_lib_version()') +def postgis_version_tuple(conn: Connection) -> Tuple[int, int]: + """ Return the postgis version installed in the database as a + tuple of (major, minor). Assumes that the PostGIS extension + has been installed already. + """ + version = execute_scalar(conn, 'SELECT postgis_lib_version()') - version_parts = version.split('.') - if len(version_parts) < 2: - raise UsageError(f"Error fetching Postgis version. Bad format: {version}") + version_parts = version.split('.') + if len(version_parts) < 2: + raise UsageError(f"Error fetching Postgis version. Bad format: {version}") - return (int(version_parts[0]), int(version_parts[1])) + return (int(version_parts[0]), int(version_parts[1])) -class ConnectionContext(ContextManager[Connection]): - """ Context manager of the connection that also provides direct access - to the underlying connection. +def register_hstore(conn: Connection) -> None: + """ Register the hstore type with psycopg for the connection. """ - connection: Connection + info = psycopg.types.TypeInfo.fetch(conn, "hstore") + if info is None: + raise RuntimeError('Hstore extension is requested but not installed.') + psycopg.types.hstore.register_hstore(info, conn) -def connect(dsn: str) -> ConnectionContext: +def connect(dsn: str, **kwargs: Any) -> Connection: """ Open a connection to the database using the specialised connection factory. The returned object may be used in conjunction with 'with'. When used outside a context manager, use the `connection` attribute to get the connection. """ try: - conn = psycopg2.connect(dsn, connection_factory=Connection) - ctxmgr = cast(ConnectionContext, contextlib.closing(conn)) - ctxmgr.connection = conn - return ctxmgr - except psycopg2.OperationalError as err: + return psycopg.connect(dsn, row_factory=psycopg.rows.namedtuple_row, **kwargs) + except psycopg.OperationalError as err: raise UsageError(f"Cannot connect to database: {err}") from err @@ -238,10 +181,18 @@ def get_pg_env(dsn: str, """ env = dict(base_env if base_env is not None else os.environ) - for param, value in psycopg2.extensions.parse_dsn(dsn).items(): + for param, value in psycopg.conninfo.conninfo_to_dict(dsn).items(): if param in _PG_CONNECTION_STRINGS: - env[_PG_CONNECTION_STRINGS[param]] = value + env[_PG_CONNECTION_STRINGS[param]] = str(value) else: LOG.error("Unknown connection parameter '%s' ignored.", param) return env + + +async def run_async_query(dsn: str, query: psycopg.abc.Query) -> None: + """ Open a connection to the database and run a single query + asynchronously. + """ + async with await psycopg.AsyncConnection.connect(dsn) as aconn: + await aconn.execute(query)