import psycopg2.extensions
import psycopg2.extras
+from ..errors import UsageError
+
class _Cursor(psycopg2.extras.DictCursor):
""" A cursor returning dict-like objects and providing specialised
execution functions.
"""
return super().cursor(cursor_factory=cursor_factory, **kwargs)
+
def table_exists(self, table, *, schema='public'):
    """ Check that a table with the given name exists in the database.

        Only tables inside `schema` (default: 'public') are considered,
        so a table of the same name living in another schema is not
        reported as existing.

        Returns True when exactly one matching table was found,
        False otherwise.
    """
    with self.cursor() as cur:
        # Both name and schema are passed as bound parameters rather than
        # interpolated into the SQL text.
        num = cur.scalar("""SELECT count(*) FROM pg_tables
                            WHERE tablename = %s and schemaname = %s""",
                         (table, schema))
    return num == 1
+
def index_exists(self, index, table=None, *, schema='public'):
    """ Check that an index with the given name exists in the database.

        Only indexes inside `schema` (default: 'public') are considered.
        If `table` is not None, the index must additionally belong to
        the table with that name.

        Returns True if a matching index was found, False otherwise.
    """
    with self.cursor() as cur:
        cur.execute("""SELECT tablename FROM pg_indexes
                       WHERE indexname = %s and schemaname = %s""",
                    (index, schema))
        # Fetch directly instead of consulting cur.rowcount: DB-API allows
        # rowcount to be -1 when the count cannot be determined, while
        # fetchone() reliably returns None when there is no row.
        row = cur.fetchone()

    if row is None:
        return False

    return table is None or row[0] == table
+
+
def server_version_tuple(self):
""" Return the server version as a tuple of (major, minor).
Converts correctly for pre-10 and post-10 PostgreSQL versions.
""" Open a connection to the database using the specialised connection
factory.
"""
- return psycopg2.connect(dsn, connection_factory=_Connection)
+ try:
+ return psycopg2.connect(dsn, connection_factory=_Connection)
+ except psycopg2.OperationalError as err:
+ raise UsageError("Cannot connect to database: {}".format(err)) from err