X-Git-Url: https://git.openstreetmap.org./nominatim.git/blobdiff_plain/af7226393a45a0ea5b87967c3231392b0e12da64..3bcd32ca2006cb2184f290e18f4e054f8aca8bf4:/nominatim/db/connection.py?ds=inline diff --git a/nominatim/db/connection.py b/nominatim/db/connection.py index 68e988f6..c60bcfdd 100644 --- a/nominatim/db/connection.py +++ b/nominatim/db/connection.py @@ -1,3 +1,9 @@ +# SPDX-License-Identifier: GPL-2.0-only +# +# This file is part of Nominatim. (https://nominatim.org) +# +# Copyright (C) 2022 by the Nominatim developer community. +# For a full list of authors see the git log. """ Specialised connection and cursor functions. """ @@ -8,8 +14,9 @@ import os import psycopg2 import psycopg2.extensions import psycopg2.extras +from psycopg2 import sql as pysql -from ..errors import UsageError +from nominatim.errors import UsageError LOG = logging.getLogger() @@ -18,13 +25,24 @@ class _Cursor(psycopg2.extras.DictCursor): execution functions. """ - def execute(self, query, args=None): # pylint: disable=W0221 + # pylint: disable=arguments-renamed,arguments-differ + def execute(self, query, args=None): """ Query execution that logs the SQL query when debugging is enabled. """ LOG.debug(self.mogrify(query, args).decode('utf-8')) super().execute(query, args) + + def execute_values(self, sql, argslist, template=None): + """ Wrapper for the psycopg2 convenience function to execute + SQL for a list of values. + """ + LOG.debug("SQL execute_values(%s, %s)", sql, argslist) + + psycopg2.extras.execute_values(self, sql, argslist, template=template) + + def scalar(self, sql, args=None): """ Execute query that returns a single value. The value is returned. If the query yields more than one row, a ValueError is raised. @@ -37,6 +55,22 @@ class _Cursor(psycopg2.extras.DictCursor): return self.fetchone()[0] + def drop_table(self, name, if_exists=True, cascade=False): + """ Drop the table with the given name. + Set `if_exists` to False if a non-existant table should raise + an exception instead of just being ignored. If 'cascade' is set + to True then all dependent tables are deleted as well. + """ + sql = 'DROP TABLE ' + if if_exists: + sql += 'IF EXISTS ' + sql += '{}' + if cascade: + sql += ' CASCADE' + + self.execute(pysql.SQL(sql).format(pysql.Identifier(name))) + + class _Connection(psycopg2.extensions.connection): """ A connection that provides the specialised cursor by default and adds convenience functions for administrating the database. @@ -57,6 +91,17 @@ class _Connection(psycopg2.extensions.connection): return num == 1 + def table_has_column(self, table, column): + """ Check if the table 'table' exists and has a column with name 'column'. + """ + with self.cursor() as cur: + has_column = cur.scalar("""SELECT count(*) FROM information_schema.columns + WHERE table_name = %s + and column_name = %s""", + (table, column)) + return has_column > 0 + + def index_exists(self, index, table=None): """ Check that an index with the given name exists in the database. If table is not None then the index must relate to the given @@ -75,15 +120,37 @@ class _Connection(psycopg2.extensions.connection): return True + def drop_table(self, name, if_exists=True, cascade=False): + """ Drop the table with the given name. + Set `if_exists` to False if a non-existant table should raise + an exception instead of just being ignored. + """ + with self.cursor() as cur: + cur.drop_table(name, if_exists, cascade) + self.commit() + + def server_version_tuple(self): """ Return the server version as a tuple of (major, minor). Converts correctly for pre-10 and post-10 PostgreSQL versions. """ version = self.server_version if version < 100000: - return (version / 10000, (version % 10000) / 100) + return (int(version / 10000), (version % 10000) / 100) + + return (int(version / 10000), version % 10000) + + + def postgis_version_tuple(self): + """ Return the postgis version installed in the database as a + tuple of (major, minor). Assumes that the PostGIS extension + has been installed already. + """ + with self.cursor() as cur: + version = cur.scalar('SELECT postgis_lib_version()') + + return tuple((int(x) for x in version.split('.')[:2])) - return (version / 10000, version % 10000) def connect(dsn): """ Open a connection to the database using the specialised connection @@ -97,7 +164,7 @@ def connect(dsn): ctxmgr.connection = conn return ctxmgr except psycopg2.OperationalError as err: - raise UsageError("Cannot connect to database: {}".format(err)) from err + raise UsageError(f"Cannot connect to database: {err}") from err # Translation from PG connection string parameters to PG environment variables. @@ -123,7 +190,7 @@ _PG_CONNECTION_STRINGS = { 'sslcrl': 'PGSSLCRL', 'requirepeer': 'PGREQUIREPEER', 'ssl_min_protocol_version': 'PGSSLMINPROTOCOLVERSION', - 'ssl_min_protocol_version': 'PGSSLMAXPROTOCOLVERSION', + 'ssl_max_protocol_version': 'PGSSLMAXPROTOCOLVERSION', 'gssencmode': 'PGGSSENCMODE', 'krbsrvname': 'PGKRBSRVNAME', 'gsslib': 'PGGSSLIB', @@ -138,7 +205,7 @@ def get_pg_env(dsn, base_env=None): If `base_env` is None, then the OS environment is used as a base environment. """ - env = base_env if base_env is not None else os.environ + env = dict(base_env if base_env is not None else os.environ) for param, value in psycopg2.extensions.parse_dsn(dsn).items(): if param in _PG_CONNECTION_STRINGS: