2 Specialised connection and cursor functions.
9 import psycopg2.extensions
10 import psycopg2.extras
12 from ..errors import UsageError
# Module-level logger. getLogger() called without a name returns the root
# logger.
LOG = logging.getLogger(name=None)
class _Cursor(psycopg2.extras.DictCursor):
    """ A cursor returning dict-like objects and providing specialised
        execution functions.
    """

    def execute(self, query, args=None): # pylint: disable=W0221
        """ Query execution that logs the SQL query when debugging is enabled.

            `query` and `args` are passed unchanged to the parent cursor's
            execute().
        """
        # mogrify() interpolates the arguments into the query client-side;
        # only pay for that when the debug message will actually be emitted.
        if LOG.isEnabledFor(logging.DEBUG):
            LOG.debug(self.mogrify(query, args).decode('utf-8'))

        super().execute(query, args)

    def scalar(self, sql, args=None):
        """ Execute a query that returns a single value and return that value.

            The value is taken from the first column of the single result row.
            Raises RuntimeError if the query does not yield exactly one row
            (zero rows or more than one row).
        """
        self.execute(sql, args)

        if self.rowcount != 1:
            raise RuntimeError("Query did not return a single row.")

        return self.fetchone()[0]
class _Connection(psycopg2.extensions.connection):
    """ A connection that provides the specialised cursor by default and
        adds convenience functions for administrating the database.
    """

    def cursor(self, cursor_factory=_Cursor, **kwargs):
        """ Return a new cursor. By default the specialised cursor is returned.
            Any other keyword arguments are passed on to the parent class.
        """
        return super().cursor(cursor_factory=cursor_factory, **kwargs)

    def table_exists(self, table):
        """ Check that a table with the given name exists in the database.
            Only tables in the 'public' schema are considered.
        """
        with self.cursor() as cur:
            num = cur.scalar("""SELECT count(*) FROM pg_tables
                                WHERE tablename = %s and schemaname = 'public'""", (table, ))

        return num == 1

    def index_exists(self, index, table=None):
        """ Check that an index with the given name exists in the database.
            If table is not None then the index must relate to the given
            table. Only indexes in the 'public' schema are considered.
        """
        with self.cursor() as cur:
            cur.execute("""SELECT tablename FROM pg_indexes
                           WHERE indexname = %s and schemaname = 'public'""", (index, ))
            row = cur.fetchone()

        if row is None:
            # No index with that name in the public schema.
            return False

        # Without a table restriction, existence of the index is enough.
        return table is None or row[0] == table

    def server_version_tuple(self):
        """ Return the server version as a tuple of (major, minor).
            Converts correctly for pre-10 and post-10 PostgreSQL versions.
            Both elements are integers.
        """
        version = self.server_version
        if version < 100000:
            # Pre-10 encoding uses two digits per part, e.g. 90105 for 9.1.5;
            # the first two parts form the major version. Integer division
            # keeps the minor part an int (true division gave e.g. 1.05).
            return (version // 10000, (version % 10000) // 100)

        # From 10 on: major * 10000 + minor, e.g. 100001 for 10.1.
        return (version // 10000, version % 10000)

    def postgis_version_tuple(self):
        """ Return the postgis version installed in the database as a
            tuple of (major, minor). Assumes that the PostGIS extension
            has been installed already.
        """
        with self.cursor() as cur:
            version = cur.scalar('SELECT postgis_lib_version()')

        # Only the first two dot-separated components are used.
        return tuple((int(x) for x in version.split('.')[:2]))
101 """ Open a connection to the database using the specialised connection
102 factory. The returned object may be used in conjunction with 'with'.
103 When used outside a context manager, use the `connection` attribute
104 to get the connection.
107 conn = psycopg2.connect(dsn, connection_factory=_Connection)
108 ctxmgr = contextlib.closing(conn)
109 ctxmgr.connection = conn
111 except psycopg2.OperationalError as err:
112 raise UsageError("Cannot connect to database: {}".format(err)) from err
# Translation from PG connection string parameters to PG environment variables.
# Derived from https://www.postgresql.org/docs/current/libpq-envars.html.
# Connection parameters that have no environment-variable equivalent are
# intentionally absent from this table.
_PG_CONNECTION_STRINGS = {
    'host': 'PGHOST',
    'hostaddr': 'PGHOSTADDR',
    'port': 'PGPORT',
    'dbname': 'PGDATABASE',
    'user': 'PGUSER',
    'password': 'PGPASSWORD',
    'passfile': 'PGPASSFILE',
    'require_auth': 'PGREQUIREAUTH',
    'channel_binding': 'PGCHANNELBINDING',
    'service': 'PGSERVICE',
    'options': 'PGOPTIONS',
    'application_name': 'PGAPPNAME',
    'sslmode': 'PGSSLMODE',
    'sslnegotiation': 'PGSSLNEGOTIATION',
    'requiressl': 'PGREQUIRESSL',
    'sslcompression': 'PGSSLCOMPRESSION',
    'sslcert': 'PGSSLCERT',
    'sslkey': 'PGSSLKEY',
    'sslcertmode': 'PGSSLCERTMODE',
    'sslrootcert': 'PGSSLROOTCERT',
    'sslcrl': 'PGSSLCRL',
    'sslcrldir': 'PGSSLCRLDIR',
    'sslsni': 'PGSSLSNI',
    'requirepeer': 'PGREQUIREPEER',
    'ssl_min_protocol_version': 'PGSSLMINPROTOCOLVERSION',
    'ssl_max_protocol_version': 'PGSSLMAXPROTOCOLVERSION',
    'gssencmode': 'PGGSSENCMODE',
    'krbsrvname': 'PGKRBSRVNAME',
    'gsslib': 'PGGSSLIB',
    'gssdelegation': 'PGGSSDELEGATION',
    'connect_timeout': 'PGCONNECT_TIMEOUT',
    'client_encoding': 'PGCLIENTENCODING',
    'target_session_attrs': 'PGTARGETSESSIONATTRS',
    'load_balance_hosts': 'PGLOADBALANCEHOSTS',
}
def get_pg_env(dsn, base_env=None):
    """ Return a copy of `base_env` with the environment variables for
        PostgreSQL set up from the given database connection string.
        If `base_env` is None, then the OS environment is used as a base
        environment.

        Connection parameters without a known environment-variable
        equivalent are logged as errors and left out of the result.
        The input mapping itself is never modified.
    """
    env = dict(base_env if base_env is not None else os.environ)

    for param, value in psycopg2.extensions.parse_dsn(dsn).items():
        env_name = _PG_CONNECTION_STRINGS.get(param)
        if env_name is None:
            LOG.error("Unknown connection parameter '%s' ignored.", param)
        else:
            env[env_name] = value

    return env