X-Git-Url: https://git.openstreetmap.org./nominatim.git/blobdiff_plain/72b01148d2d12f71c12440c15fa078b55e1c8f86..3c9e09545e4a3955e0454097df9e5adcaa212d3d:/nominatim/db/connection.py

Review revision of the blobdiff referenced above: line breaks restored and the
added code revised (integer division in server_version_tuple, identifier
quoting via psycopg2.sql in drop_table, docstring typos). The index line below
still carries the original blob hashes.

diff --git a/nominatim/db/connection.py b/nominatim/db/connection.py
index b941f46f..5aa05ced 100644
--- a/nominatim/db/connection.py
+++ b/nominatim/db/connection.py
@@ -1,7 +1,9 @@
 """
 Specialised connection and cursor functions.
 """
+import contextlib
 import logging
+import os
 
 import psycopg2
 import psycopg2.extensions
@@ -9,6 +11,8 @@ import psycopg2.extras
 
 from ..errors import UsageError
 
+LOG = logging.getLogger()
+
 class _Cursor(psycopg2.extras.DictCursor):
     """ A cursor returning dict-like objects and providing specialised
         execution functions.
@@ -17,8 +21,7 @@ class _Cursor(psycopg2.extras.DictCursor):
     def execute(self, query, args=None): # pylint: disable=W0221
         """ Query execution that logs the SQL query when debugging is enabled.
         """
-        logger = logging.getLogger()
-        logger.debug(self.mogrify(query, args).decode('utf-8'))
+        LOG.debug(self.mogrify(query, args).decode('utf-8'))
 
         super().execute(query, args)
 
@@ -72,21 +75,100 @@ class _Connection(psycopg2.extensions.connection):
 
         return True
 
+    def drop_table(self, name, if_exists=True):
+        """ Drop the table with the given name.
+            Set `if_exists` to False if a non-existent table should raise
+            an exception instead of just being ignored.
+        """
+        from psycopg2 import sql as pysql # pylint: disable=import-outside-toplevel
+
+        query = 'DROP TABLE IF EXISTS {}' if if_exists else 'DROP TABLE {}'
+        with self.cursor() as cur:
+            cur.execute(pysql.SQL(query).format(pysql.Identifier(name)))
+        self.commit()
+
+
     def server_version_tuple(self):
         """ Return the server version as a tuple of (major, minor).
             Converts correctly for pre-10 and post-10 PostgreSQL versions.
         """
         version = self.server_version
         if version < 100000:
-            return (version / 10000, (version % 10000) / 100)
+            return (version // 10000, (version % 10000) // 100)
+
+        return (version // 10000, version % 10000)
+
+
+    def postgis_version_tuple(self):
+        """ Return the postgis version installed in the database as a
+            tuple of (major, minor). Assumes that the PostGIS extension
+            has been installed already.
+        """
+        with self.cursor() as cur:
+            version = cur.scalar('SELECT postgis_lib_version()')
+
+        return tuple(int(x) for x in version.split('.')[:2])
 
-        return (version / 10000, version % 10000)
 
 def connect(dsn):
     """ Open a connection to the database using the specialised connection
-        factory.
+        factory. The returned object may be used in conjunction with 'with'.
+        When used outside a context manager, use the `connection` attribute
+        to get the connection.
     """
     try:
-        return psycopg2.connect(dsn, connection_factory=_Connection)
+        conn = psycopg2.connect(dsn, connection_factory=_Connection)
+        ctxmgr = contextlib.closing(conn)
+        ctxmgr.connection = conn
+        return ctxmgr
     except psycopg2.OperationalError as err:
         raise UsageError("Cannot connect to database: {}".format(err)) from err
+
+
+# Translation from PG connection string parameters to PG environment variables.
+# Derived from https://www.postgresql.org/docs/current/libpq-envars.html.
+_PG_CONNECTION_STRINGS = {
+    'host': 'PGHOST',
+    'hostaddr': 'PGHOSTADDR',
+    'port': 'PGPORT',
+    'dbname': 'PGDATABASE',
+    'user': 'PGUSER',
+    'password': 'PGPASSWORD',
+    'passfile': 'PGPASSFILE',
+    'channel_binding': 'PGCHANNELBINDING',
+    'service': 'PGSERVICE',
+    'options': 'PGOPTIONS',
+    'application_name': 'PGAPPNAME',
+    'sslmode': 'PGSSLMODE',
+    'requiressl': 'PGREQUIRESSL',
+    'sslcompression': 'PGSSLCOMPRESSION',
+    'sslcert': 'PGSSLCERT',
+    'sslkey': 'PGSSLKEY',
+    'sslrootcert': 'PGSSLROOTCERT',
+    'sslcrl': 'PGSSLCRL',
+    'requirepeer': 'PGREQUIREPEER',
+    'ssl_min_protocol_version': 'PGSSLMINPROTOCOLVERSION',
+    'ssl_max_protocol_version': 'PGSSLMAXPROTOCOLVERSION',
+    'gssencmode': 'PGGSSENCMODE',
+    'krbsrvname': 'PGKRBSRVNAME',
+    'gsslib': 'PGGSSLIB',
+    'connect_timeout': 'PGCONNECT_TIMEOUT',
+    'target_session_attrs': 'PGTARGETSESSIONATTRS',
+}
+
+
+def get_pg_env(dsn, base_env=None):
+    """ Return a copy of `base_env` with the environment variables for
+        PostgreSQL set up from the given database connection string.
+        If `base_env` is None, then the OS environment is used as a base
+        environment.
+    """
+    env = dict(base_env if base_env is not None else os.environ)
+
+    for param, value in psycopg2.extensions.parse_dsn(dsn).items():
+        if param in _PG_CONNECTION_STRINGS:
+            env[_PG_CONNECTION_STRINGS[param]] = value
+        else:
+            LOG.error("Unknown connection parameter '%s' ignored.", param)
+
+    return env