From e5206133628c0ab1cacd6c5a04a2a9a973bfc86c Mon Sep 17 00:00:00 2001 From: Sarah Hoffmann Date: Tue, 23 Feb 2021 10:11:21 +0100 Subject: [PATCH] convert connect() into a context manager --- nominatim/clicmd/admin.py | 5 +-- nominatim/clicmd/freeze.py | 5 +-- nominatim/clicmd/index.py | 5 +-- nominatim/clicmd/refresh.py | 22 +++++------ nominatim/clicmd/replication.py | 39 ++++++++----------- nominatim/db/connection.py | 10 ++++- nominatim/tools/check_database.py | 2 +- test/python/conftest.py | 5 +-- test/python/test_db_connection.py | 5 +-- test/python/test_tools_admin.py | 5 +-- test/python/test_tools_check_database.py | 4 ++ .../test_tools_refresh_create_functions.py | 5 +-- 12 files changed, 53 insertions(+), 59 deletions(-) diff --git a/nominatim/clicmd/admin.py b/nominatim/clicmd/admin.py index e5863575..fd9382eb 100644 --- a/nominatim/clicmd/admin.py +++ b/nominatim/clicmd/admin.py @@ -54,9 +54,8 @@ class AdminFuncs: if args.analyse_indexing: LOG.warning('Analysing performance of indexing function') from ..tools import admin - conn = connect(args.config.get_libpq_dsn()) - admin.analyse_indexing(conn, osm_id=args.osm_id, place_id=args.place_id) - conn.close() + with connect(args.config.get_libpq_dsn()) as conn: + admin.analyse_indexing(conn, osm_id=args.osm_id, place_id=args.place_id) return 0 diff --git a/nominatim/clicmd/freeze.py b/nominatim/clicmd/freeze.py index 8bca04b9..1b311e97 100644 --- a/nominatim/clicmd/freeze.py +++ b/nominatim/clicmd/freeze.py @@ -29,9 +29,8 @@ class SetupFreeze: def run(args): from ..tools import freeze - conn = connect(args.config.get_libpq_dsn()) - freeze.drop_update_tables(conn) + with connect(args.config.get_libpq_dsn()) as conn: + freeze.drop_update_tables(conn) freeze.drop_flatnode_file(args.config.FLATNODE_FILE) - conn.close() return 0 diff --git a/nominatim/clicmd/index.py b/nominatim/clicmd/index.py index ca3f9dee..96a69396 100644 --- a/nominatim/clicmd/index.py +++ b/nominatim/clicmd/index.py @@ -51,8 +51,7 @@ class UpdateIndex: if not args.no_boundaries and not args.boundaries_only \ and args.minrank == 0 and args.maxrank == 30: - conn = connect(args.config.get_libpq_dsn()) - status.set_indexed(conn, True) - conn.close() + with connect(args.config.get_libpq_dsn()) as conn: + status.set_indexed(conn, True) return 0 diff --git a/nominatim/clicmd/refresh.py b/nominatim/clicmd/refresh.py index ffbe628b..f68e185a 100644 --- a/nominatim/clicmd/refresh.py +++ b/nominatim/clicmd/refresh.py @@ -50,29 +50,25 @@ class UpdateRefresh: if args.postcodes: LOG.warning("Update postcodes centroid") - conn = connect(args.config.get_libpq_dsn()) - refresh.update_postcodes(conn, args.sqllib_dir) - conn.close() + with connect(args.config.get_libpq_dsn()) as conn: + refresh.update_postcodes(conn, args.sqllib_dir) if args.word_counts: LOG.warning('Recompute frequency of full-word search terms') - conn = connect(args.config.get_libpq_dsn()) - refresh.recompute_word_counts(conn, args.sqllib_dir) - conn.close() + with connect(args.config.get_libpq_dsn()) as conn: + refresh.recompute_word_counts(conn, args.sqllib_dir) if args.address_levels: cfg = Path(args.config.ADDRESS_LEVEL_CONFIG) LOG.warning('Updating address levels from %s', cfg) - conn = connect(args.config.get_libpq_dsn()) - refresh.load_address_levels_from_file(conn, cfg) - conn.close() + with connect(args.config.get_libpq_dsn()) as conn: + refresh.load_address_levels_from_file(conn, cfg) if args.functions: LOG.warning('Create functions') - conn = connect(args.config.get_libpq_dsn()) - refresh.create_functions(conn, args.config, args.sqllib_dir, - args.diffs, args.enable_debug_statements) - conn.close() + with connect(args.config.get_libpq_dsn()) as conn: + refresh.create_functions(conn, args.config, args.sqllib_dir, + args.diffs, args.enable_debug_statements) if args.wiki_data: run_legacy_script('setup.php', '--import-wikipedia-articles', diff --git a/nominatim/clicmd/replication.py b/nominatim/clicmd/replication.py index 2a19e6cd..e766be2b 100644 --- a/nominatim/clicmd/replication.py +++ b/nominatim/clicmd/replication.py @@ -62,13 +62,12 @@ class UpdateReplication: from ..tools import replication, refresh LOG.warning("Initialising replication updates") - conn = connect(args.config.get_libpq_dsn()) - replication.init_replication(conn, base_url=args.config.REPLICATION_URL) - if args.update_functions: - LOG.warning("Create functions") - refresh.create_functions(conn, args.config, args.sqllib_dir, - True, False) - conn.close() + with connect(args.config.get_libpq_dsn()) as conn: + replication.init_replication(conn, base_url=args.config.REPLICATION_URL) + if args.update_functions: + LOG.warning("Create functions") + refresh.create_functions(conn, args.config, args.sqllib_dir, + True, False) return 0 @@ -76,10 +75,8 @@ class UpdateReplication: def _check_for_updates(args): from ..tools import replication - conn = connect(args.config.get_libpq_dsn()) - ret = replication.check_for_updates(conn, base_url=args.config.REPLICATION_URL) - conn.close() - return ret + with connect(args.config.get_libpq_dsn()) as conn: + return replication.check_for_updates(conn, base_url=args.config.REPLICATION_URL) @staticmethod def _report_update(batchdate, start_import, start_index): @@ -122,13 +119,12 @@ class UpdateReplication: recheck_interval = args.config.get_int('REPLICATION_RECHECK_INTERVAL') while True: - conn = connect(args.config.get_libpq_dsn()) - start = dt.datetime.now(dt.timezone.utc) - state = replication.update(conn, params) - if state is not replication.UpdateState.NO_CHANGES: - status.log_status(conn, start, 'import') - batchdate, _, _ = status.get_status(conn) - conn.close() + with connect(args.config.get_libpq_dsn()) as conn: + start = dt.datetime.now(dt.timezone.utc) + state = replication.update(conn, params) + if state is not replication.UpdateState.NO_CHANGES: + status.log_status(conn, start, 'import') + batchdate, _, _ = status.get_status(conn) if state is not replication.UpdateState.NO_CHANGES and args.do_index: index_start = dt.datetime.now(dt.timezone.utc) @@ -137,10 +133,9 @@ class UpdateReplication: indexer.index_boundaries(0, 30) indexer.index_by_rank(0, 30) - conn = connect(args.config.get_libpq_dsn()) - status.set_indexed(conn, True) - status.log_status(conn, index_start, 'index') - conn.close() + with connect(args.config.get_libpq_dsn()) as conn: + status.set_indexed(conn, True) + status.log_status(conn, index_start, 'index') else: index_start = None diff --git a/nominatim/db/connection.py b/nominatim/db/connection.py index b941f46f..6bd81a2f 100644 --- a/nominatim/db/connection.py +++ b/nominatim/db/connection.py @@ -1,6 +1,7 @@ """ Specialised connection and cursor functions. """ +import contextlib import logging import psycopg2 @@ -84,9 +85,14 @@ class _Connection(psycopg2.extensions.connection): def connect(dsn): """ Open a connection to the database using the specialised connection - factory. + factory. The returned object may be used in conjunction with 'with'. + When used outside a context manager, use the `connection` attribute + to get the connection. """ try: - return psycopg2.connect(dsn, connection_factory=_Connection) + conn = psycopg2.connect(dsn, connection_factory=_Connection) + ctxmgr = contextlib.closing(conn) + ctxmgr.connection = conn + return ctxmgr except psycopg2.OperationalError as err: raise UsageError("Cannot connect to database: {}".format(err)) from err diff --git a/nominatim/tools/check_database.py b/nominatim/tools/check_database.py index 7b8da200..d8ab08cc 100644 --- a/nominatim/tools/check_database.py +++ b/nominatim/tools/check_database.py @@ -60,7 +60,7 @@ def check_database(config): """ Run a number of checks on the database and return the status. """ try: - conn = connect(config.get_libpq_dsn()) + conn = connect(config.get_libpq_dsn()).connection except UsageError as err: conn = _BadConnection(str(err)) diff --git a/test/python/conftest.py b/test/python/conftest.py index 72a56dcf..0e0e808c 100644 --- a/test/python/conftest.py +++ b/test/python/conftest.py @@ -85,9 +85,8 @@ def temp_db_with_extensions(temp_db): def temp_db_conn(temp_db): """ Connection to the test database. """ - conn = connection.connect('dbname=' + temp_db) - yield conn - conn.close() + with connection.connect('dbname=' + temp_db) as conn: + yield conn @pytest.fixture diff --git a/test/python/test_db_connection.py b/test/python/test_db_connection.py index 11ad691a..846ef864 100644 --- a/test/python/test_db_connection.py +++ b/test/python/test_db_connection.py @@ -7,9 +7,8 @@ from nominatim.db.connection import connect @pytest.fixture def db(temp_db): - conn = connect('dbname=' + temp_db) - yield conn - conn.close() + with connect('dbname=' + temp_db) as conn: + yield conn def test_connection_table_exists(db, temp_db_cursor): diff --git a/test/python/test_tools_admin.py b/test/python/test_tools_admin.py index a40a17db..36c7d6ff 100644 --- a/test/python/test_tools_admin.py +++ b/test/python/test_tools_admin.py @@ -9,9 +9,8 @@ from nominatim.tools import admin @pytest.fixture def db(temp_db, placex_table): - conn = connect('dbname=' + temp_db) - yield conn - conn.close() + with connect('dbname=' + temp_db) as conn: + yield conn def test_analyse_indexing_no_objects(db): with pytest.raises(UsageError): diff --git a/test/python/test_tools_check_database.py b/test/python/test_tools_check_database.py index 0b5c23a6..3787c3be 100644 --- a/test/python/test_tools_check_database.py +++ b/test/python/test_tools_check_database.py @@ -10,6 +10,10 @@ def test_check_database_unknown_db(def_config, monkeypatch): assert 1 == chkdb.check_database(def_config) +def test_check_database_fatal_test(def_config, temp_db): + assert 1 == chkdb.check_database(def_config) + + def test_check_conection_good(temp_db_conn, def_config): assert chkdb.check_connection(temp_db_conn, def_config) == chkdb.CheckState.OK diff --git a/test/python/test_tools_refresh_create_functions.py b/test/python/test_tools_refresh_create_functions.py index d219d748..ac2f2211 100644 --- a/test/python/test_tools_refresh_create_functions.py +++ b/test/python/test_tools_refresh_create_functions.py @@ -11,9 +11,8 @@ SQL_DIR = (Path(__file__) / '..' / '..' / '..' / 'lib-sql').resolve() @pytest.fixture def db(temp_db): - conn = connect('dbname=' + temp_db) - yield conn - conn.close() + with connect('dbname=' + temp_db) as conn: + yield conn @pytest.fixture def db_with_tables(db): -- 2.39.5