From af7226393a45a0ea5b87967c3231392b0e12da64 Mon Sep 17 00:00:00 2001 From: Sarah Hoffmann Date: Tue, 23 Feb 2021 14:11:11 +0100 Subject: [PATCH] add function to set up libpq environment Instead of parsing the DSN for each external libpq program we are going to execute, provide a function that feeds them all necessary parameters through the environment. osm2pgsql is the first user. --- nominatim/db/connection.py | 55 +++++++++++++++++++++++++++++-- nominatim/tools/exec_utils.py | 14 ++------ test/python/test_db_connection.py | 23 ++++++++++++- 3 files changed, 77 insertions(+), 15 deletions(-) diff --git a/nominatim/db/connection.py b/nominatim/db/connection.py index 6bd81a2f..68e988f6 100644 --- a/nominatim/db/connection.py +++ b/nominatim/db/connection.py @@ -3,6 +3,7 @@ Specialised connection and cursor functions. """ import contextlib import logging +import os import psycopg2 import psycopg2.extensions @@ -10,6 +11,8 @@ import psycopg2.extras from ..errors import UsageError +LOG = logging.getLogger() + class _Cursor(psycopg2.extras.DictCursor): """ A cursor returning dict-like objects and providing specialised execution functions. @@ -18,8 +21,7 @@ class _Cursor(psycopg2.extras.DictCursor): def execute(self, query, args=None): # pylint: disable=W0221 """ Query execution that logs the SQL query when debugging is enabled. """ - logger = logging.getLogger() - logger.debug(self.mogrify(query, args).decode('utf-8')) + LOG.debug(self.mogrify(query, args).decode('utf-8')) super().execute(query, args) @@ -96,3 +98,52 @@ def connect(dsn): return ctxmgr except psycopg2.OperationalError as err: raise UsageError("Cannot connect to database: {}".format(err)) from err + + +# Translation from PG connection string parameters to PG environment variables. +# Derived from https://www.postgresql.org/docs/current/libpq-envars.html. +_PG_CONNECTION_STRINGS = { + 'host': 'PGHOST', + 'hostaddr': 'PGHOSTADDR', + 'port': 'PGPORT', + 'dbname': 'PGDATABASE', + 'user': 'PGUSER', + 'password': 'PGPASSWORD', + 'passfile': 'PGPASSFILE', + 'channel_binding': 'PGCHANNELBINDING', + 'service': 'PGSERVICE', + 'options': 'PGOPTIONS', + 'application_name': 'PGAPPNAME', + 'sslmode': 'PGSSLMODE', + 'requiressl': 'PGREQUIRESSL', + 'sslcompression': 'PGSSLCOMPRESSION', + 'sslcert': 'PGSSLCERT', + 'sslkey': 'PGSSLKEY', + 'sslrootcert': 'PGSSLROOTCERT', + 'sslcrl': 'PGSSLCRL', + 'requirepeer': 'PGREQUIREPEER', + 'ssl_min_protocol_version': 'PGSSLMINPROTOCOLVERSION', + 'ssl_min_protocol_version': 'PGSSLMAXPROTOCOLVERSION', + 'gssencmode': 'PGGSSENCMODE', + 'krbsrvname': 'PGKRBSRVNAME', + 'gsslib': 'PGGSSLIB', + 'connect_timeout': 'PGCONNECT_TIMEOUT', + 'target_session_attrs': 'PGTARGETSESSIONATTRS', +} + + +def get_pg_env(dsn, base_env=None): + """ Return a copy of `base_env` with the environment variables for + PostgresSQL set up from the given database connection string. + If `base_env` is None, then the OS environment is used as a base + environment. + """ + env = base_env if base_env is not None else os.environ + + for param, value in psycopg2.extensions.parse_dsn(dsn).items(): + if param in _PG_CONNECTION_STRINGS: + env[_PG_CONNECTION_STRINGS[param]] = value + else: + LOG.error("Unknown connection parameter '%s' ignored.", param) + + return env diff --git a/nominatim/tools/exec_utils.py b/nominatim/tools/exec_utils.py index f373f347..004a821c 100644 --- a/nominatim/tools/exec_utils.py +++ b/nominatim/tools/exec_utils.py @@ -10,6 +10,7 @@ from urllib.parse import urlencode from psycopg2.extensions import parse_dsn from ..version import NOMINATIM_VERSION +from ..db.connection import get_pg_env LOG = logging.getLogger() @@ -100,7 +101,7 @@ def run_php_server(server_address, base_dir): def run_osm2pgsql(options): """ Run osm2pgsql with the given options. """ - env = os.environ + env = get_pg_env(options['dsn']) cmd = [options['osm2pgsql'], '--hstore', '--latlon', '--slim', '--with-forward-dependencies', 'false', @@ -116,17 +117,6 @@ def run_osm2pgsql(options): if options['flatnode_file']: cmd.extend(('--flat-nodes', options['flatnode_file'])) - dsn = parse_dsn(options['dsn']) - if 'password' in dsn: - env['PGPASSWORD'] = dsn['password'] - if 'dbname' in dsn: - cmd.extend(('-d', dsn['dbname'])) - if 'user' in dsn: - cmd.extend(('--username', dsn['user'])) - for param in ('host', 'port'): - if param in dsn: - cmd.extend(('--' + param, dsn[param])) - if options.get('disable_jit', False): env['PGOPTIONS'] = '-c jit=off -c max_parallel_workers_per_gather=0' diff --git a/test/python/test_db_connection.py b/test/python/test_db_connection.py index 846ef864..fd5da754 100644 --- a/test/python/test_db_connection.py +++ b/test/python/test_db_connection.py @@ -3,7 +3,7 @@ Tests for specialised conenction and cursor classes. """ import pytest -from nominatim.db.connection import connect +from nominatim.db.connection import connect, get_pg_env @pytest.fixture def db(temp_db): @@ -48,3 +48,24 @@ def test_cursor_scalar_many_rows(db): with db.cursor() as cur: with pytest.raises(RuntimeError): cur.scalar('SELECT * FROM pg_tables') + + +def test_get_pg_env_add_variable(monkeypatch): + monkeypatch.delenv('PGPASSWORD', raising=False) + env = get_pg_env('user=fooF') + + assert env['PGUSER'] == 'fooF' + assert 'PGPASSWORD' not in env + + +def test_get_pg_env_overwrite_variable(monkeypatch): + monkeypatch.setenv('PGUSER', 'some default') + env = get_pg_env('user=overwriter') + + assert env['PGUSER'] == 'overwriter' + + +def test_get_pg_env_ignore_unknown(): + env = get_pg_env('tty=stuff', base_env={}) + + assert env == {} -- 2.39.5