X-Git-Url: https://git.openstreetmap.org./nominatim.git/blobdiff_plain/d78f0ba80470a33a7a76edfe3ace5108684873cd..f08078ccca8077b09f4b1cbb8991a4ecf7257cc4:/nominatim/db/connection.py

diff --git a/nominatim/db/connection.py b/nominatim/db/connection.py
index 8e75d7a2..b941f46f 100644
--- a/nominatim/db/connection.py
+++ b/nominatim/db/connection.py
@@ -7,6 +7,8 @@ import psycopg2
 import psycopg2.extensions
 import psycopg2.extras
 
+from ..errors import UsageError
+
 class _Cursor(psycopg2.extras.DictCursor):
     """ A cursor returning dict-like objects and providing specialised
         execution functions.
@@ -27,7 +29,7 @@ class _Cursor(psycopg2.extras.DictCursor):
         self.execute(sql, args)
 
         if self.rowcount != 1:
-            raise ValueError("Query did not return a single row.")
+            raise RuntimeError("Query did not return a single row.")
 
         return self.fetchone()[0]
 
@@ -42,17 +44,49 @@ class _Connection(psycopg2.extensions.connection):
         """
         return super().cursor(cursor_factory=cursor_factory, **kwargs)
 
+
     def table_exists(self, table):
         """ Check that a table with the given name exists in the database.
         """
         with self.cursor() as cur:
             num = cur.scalar("""SELECT count(*) FROM pg_tables
-                                WHERE tablename = %s""", (table, ))
+                                WHERE tablename = %s and schemaname = 'public'""", (table, ))
             return num == 1
 
 
+    def index_exists(self, index, table=None):
+        """ Check that an index with the given name exists in the database.
+            If table is not None then the index must relate to the given
+            table.
+        """
+        with self.cursor() as cur:
+            cur.execute("""SELECT tablename FROM pg_indexes
+                           WHERE indexname = %s and schemaname = 'public'""", (index, ))
+            if cur.rowcount == 0:
+                return False
+
+            if table is not None:
+                row = cur.fetchone()
+                return row[0] == table
+
+        return True
+
+
+    def server_version_tuple(self):
+        """ Return the server version as a tuple of integers (major, minor).
+            Converts correctly for pre-10 and post-10 PostgreSQL versions.
+        """
+        version = self.server_version
+        if version < 100000:
+            return (version // 10000, (version % 10000) // 100)
+
+        return (version // 10000, version % 10000)
+
 def connect(dsn):
     """ Open a connection to the database using the specialised connection
         factory.
     """
-    return psycopg2.connect(dsn, connection_factory=_Connection)
+    try:
+        return psycopg2.connect(dsn, connection_factory=_Connection)
+    except psycopg2.OperationalError as err:
+        raise UsageError("Cannot connect to database: {}".format(err)) from err