From 71249bd94a1bd698a937983663f06a9376629ae6 Mon Sep 17 00:00:00 2001 From: Sarah Hoffmann Date: Tue, 2 Jul 2024 14:52:57 +0200 Subject: [PATCH 1/1] remove extension existence helper This is only used in one place. --- src/nominatim_db/db/connection.py | 9 +-------- src/nominatim_db/tools/database_import.py | 8 +++++--- 2 files changed, 6 insertions(+), 11 deletions(-) diff --git a/src/nominatim_db/db/connection.py b/src/nominatim_db/db/connection.py index 19fcddd4..8faa3f93 100644 --- a/src/nominatim_db/db/connection.py +++ b/src/nominatim_db/db/connection.py @@ -175,20 +175,13 @@ class Connection(psycopg2.extensions.connection): return (int(version_parts[0]), int(version_parts[1])) - def extension_loaded(self, extension_name: str) -> bool: - """ Return True if the hstore extension is loaded in the database. - """ - with self.cursor() as cur: - cur.execute('SELECT extname FROM pg_extension WHERE extname = %s', (extension_name, )) - return cur.rowcount > 0 - - class ConnectionContext(ContextManager[Connection]): """ Context manager of the connection that also provides direct access to the underlying connection. """ connection: Connection + def connect(dsn: str) -> ConnectionContext: """ Open a connection to the database using the specialised connection factory. The returned object may be used in conjunction with 'with'. diff --git a/src/nominatim_db/tools/database_import.py b/src/nominatim_db/tools/database_import.py index d07febc8..c4b3023a 100644 --- a/src/nominatim_db/tools/database_import.py +++ b/src/nominatim_db/tools/database_import.py @@ -40,9 +40,11 @@ def _require_version(module: str, actual: Tuple[int, int], expected: Tuple[int, def _require_loaded(extension_name: str, conn: Connection) -> None: """ Check that the given extension is loaded. """ - if not conn.extension_loaded(extension_name): - LOG.fatal('Required module %s is not loaded.', extension_name) - raise UsageError(f'{extension_name} is not loaded.') + with conn.cursor() as cur: + cur.execute('SELECT * FROM pg_extension WHERE extname = %s', (extension_name, )) + if cur.rowcount <= 0: + LOG.fatal('Required module %s is not loaded.', extension_name) + raise UsageError(f'{extension_name} is not loaded.') def check_existing_database_plugins(dsn: str) -> None: -- 2.39.5