]> git.openstreetmap.org Git - nominatim.git/blobdiff - nominatim/db/connection.py
bdd tests: directly call python code for setup-website
[nominatim.git] / nominatim / db / connection.py
index c7e22c98e500c7d53936fad88fa77a8bdd3fad7f..b941f46f56c63c74444506dec457bf09a8999c07 100644 (file)
@@ -7,6 +7,8 @@ import psycopg2
 import psycopg2.extensions
 import psycopg2.extras
 
 import psycopg2.extensions
 import psycopg2.extras
 
+from ..errors import UsageError
+
 class _Cursor(psycopg2.extras.DictCursor):
     """ A cursor returning dict-like objects and providing specialised
         execution functions.
 class _Cursor(psycopg2.extras.DictCursor):
     """ A cursor returning dict-like objects and providing specialised
         execution functions.
@@ -42,14 +44,34 @@ class _Connection(psycopg2.extensions.connection):
         """
         return super().cursor(cursor_factory=cursor_factory, **kwargs)
 
         """
         return super().cursor(cursor_factory=cursor_factory, **kwargs)
 
+
     def table_exists(self, table):
         """ Check that a table with the given name exists in the database.
         """
         with self.cursor() as cur:
             num = cur.scalar("""SELECT count(*) FROM pg_tables
     def table_exists(self, table):
         """ Check that a table with the given name exists in the database.
         """
         with self.cursor() as cur:
             num = cur.scalar("""SELECT count(*) FROM pg_tables
-                                WHERE tablename = %s""", (table, ))
+                                WHERE tablename = %s and schemaname = 'public'""", (table, ))
             return num == 1
 
             return num == 1
 
+
+    def index_exists(self, index, table=None):
+        """ Check that an index with the given name exists in the database.
+            If table is not None then the index must relate to the given
+            table.
+        """
+        with self.cursor() as cur:
+            cur.execute("""SELECT tablename FROM pg_indexes
+                           WHERE indexname = %s and schemaname = 'public'""", (index, ))
+            if cur.rowcount == 0:
+                return False
+
+            if table is not None:
+                row = cur.fetchone()
+                return row[0] == table
+
+        return True
+
+
     def server_version_tuple(self):
         """ Return the server version as a tuple of (major, minor).
             Converts correctly for pre-10 and post-10 PostgreSQL versions.
     def server_version_tuple(self):
         """ Return the server version as a tuple of (major, minor).
             Converts correctly for pre-10 and post-10 PostgreSQL versions.
@@ -64,4 +86,7 @@ def connect(dsn):
     """ Open a connection to the database using the specialised connection
         factory.
     """
     """ Open a connection to the database using the specialised connection
         factory.
     """
-    return psycopg2.connect(dsn, connection_factory=_Connection)
+    try:
+        return psycopg2.connect(dsn, connection_factory=_Connection)
+    except psycopg2.OperationalError as err:
+        raise UsageError("Cannot connect to database: {}".format(err)) from err