2 Specialised connection and cursor functions.
9 import psycopg2.extensions
10 import psycopg2.extras
11 from psycopg2 import sql as pysql
13 from nominatim.errors import UsageError
# Module-wide logger. No logger name is given, so this is the root logger
# and its output follows the application's root logging configuration.
LOG = logging.getLogger(name=None)
class _Cursor(psycopg2.extras.DictCursor):
    """ A cursor returning dict-like objects and providing specialised
        execution functions: debug logging of executed SQL, bulk execution
        for lists of values, single-value queries and table dropping.
    """

    def execute(self, query, args=None): # pylint: disable=W0221
        """ Query execution that logs the SQL query when debugging is enabled.
        """
        # Only render the full statement when it will actually be logged:
        # mogrify() formats the query with all its arguments, which would
        # otherwise be wasted work on every single query.
        if LOG.isEnabledFor(logging.DEBUG):
            LOG.debug(self.mogrify(query, args).decode('utf-8'))

        super().execute(query, args)


    def execute_values(self, sql, argslist, template=None):
        """ Wrapper for the psycopg2 convenience function to execute
            SQL for a list of values.
        """
        LOG.debug("SQL execute_values(%s, %s)", sql, argslist)

        psycopg2.extras.execute_values(self, sql, argslist, template=template)


    def scalar(self, sql, args=None):
        """ Execute a query that returns a single value and return that
            value, i.e. the first column of the only result row.

            Raises RuntimeError if the query does not yield exactly one row
            (no rows or more than one row).
        """
        self.execute(sql, args)

        if self.rowcount != 1:
            raise RuntimeError("Query did not return a single row.")

        return self.fetchone()[0]


    def drop_table(self, name, if_exists=True, cascade=False):
        """ Drop the table with the given name.
            Set `if_exists` to False if a non-existent table should raise
            an exception instead of just being ignored. If `cascade` is set
            to True, the statement is issued with CASCADE so that objects
            depending on the table (e.g. views) are dropped as well.
        """
        parts = ['DROP TABLE']
        if if_exists:
            parts.append('IF EXISTS')
        parts.append('{}')
        if cascade:
            parts.append('CASCADE')

        # The name is inserted as a quoted identifier, never as raw text,
        # so arbitrary table names cannot inject SQL.
        self.execute(pysql.SQL(' '.join(parts)).format(pysql.Identifier(name)))
class _Connection(psycopg2.extensions.connection):
    """ A connection that provides the specialised cursor by default and
        adds convenience functions for administrating the database.
    """

    def cursor(self, cursor_factory=_Cursor, **kwargs):
        """ Return a new cursor. By default the specialised cursor is returned.
            All other keyword arguments are passed on unchanged.
        """
        return super().cursor(cursor_factory=cursor_factory, **kwargs)


    def table_exists(self, table):
        """ Check that a table with the given name exists in the 'public'
            schema of the database. Returns True or False.
        """
        with self.cursor() as cur:
            num = cur.scalar("""SELECT count(*) FROM pg_tables
                                WHERE tablename = %s and schemaname = 'public'""", (table, ))
            return num == 1


    def index_exists(self, index, table=None):
        """ Check that an index with the given name exists in the 'public'
            schema of the database. If table is not None then the index must
            also relate to the given table. Returns True or False.
        """
        with self.cursor() as cur:
            cur.execute("""SELECT tablename FROM pg_indexes
                           WHERE indexname = %s and schemaname = 'public'""", (index, ))
            row = cur.fetchone()

        if row is None:
            return False

        # The query selects the name of the table the index belongs to.
        return table is None or row[0] == table


    def drop_table(self, name, if_exists=True, cascade=False):
        """ Drop the table with the given name.
            Set `if_exists` to False if a non-existent table should raise
            an exception instead of just being ignored. `cascade` is passed
            on to the cursor's drop_table() to drop dependent objects too.
        """
        with self.cursor() as cur:
            cur.drop_table(name, if_exists, cascade)


    def server_version_tuple(self):
        """ Return the server version as a tuple of integers (major, minor).
            Converts correctly for pre-10 and post-10 PostgreSQL versions.
        """
        version = self.server_version
        if version < 100000:
            # Pre-10 numbering is MMmmpp, e.g. 90605 for 9.6.5 -> (9, 6).
            # Integer division keeps the minor part an int (not 6.05).
            return (version // 10000, (version % 10000) // 100)

        # Since 10 the numbering is MM00mm, e.g. 120004 for 12.4 -> (12, 4).
        return (version // 10000, version % 10000)


    def postgis_version_tuple(self):
        """ Return the postgis version installed in the database as a
            tuple of (major, minor). Assumes that the PostGIS extension
            has been installed already.
        """
        with self.cursor() as cur:
            version = cur.scalar('SELECT postgis_lib_version()')

        # Only the first two dot-separated components are converted.
        return tuple(int(x) for x in version.split('.')[:2])
138 """ Open a connection to the database using the specialised connection
139 factory. The returned object may be used in conjunction with 'with'.
140 When used outside a context manager, use the `connection` attribute
141 to get the connection.
144 conn = psycopg2.connect(dsn, connection_factory=_Connection)
145 ctxmgr = contextlib.closing(conn)
146 ctxmgr.connection = conn
148 except psycopg2.OperationalError as err:
149 raise UsageError("Cannot connect to database: {}".format(err)) from err
152 # Translation from PG connection string parameters to PG environment variables.
153 # Derived from https://www.postgresql.org/docs/current/libpq-envars.html.
154 _PG_CONNECTION_STRINGS = {
156 'hostaddr': 'PGHOSTADDR',
158 'dbname': 'PGDATABASE',
160 'password': 'PGPASSWORD',
161 'passfile': 'PGPASSFILE',
162 'channel_binding': 'PGCHANNELBINDING',
163 'service': 'PGSERVICE',
164 'options': 'PGOPTIONS',
165 'application_name': 'PGAPPNAME',
166 'sslmode': 'PGSSLMODE',
167 'requiressl': 'PGREQUIRESSL',
168 'sslcompression': 'PGSSLCOMPRESSION',
169 'sslcert': 'PGSSLCERT',
170 'sslkey': 'PGSSLKEY',
171 'sslrootcert': 'PGSSLROOTCERT',
172 'sslcrl': 'PGSSLCRL',
173 'requirepeer': 'PGREQUIREPEER',
174 'ssl_min_protocol_version': 'PGSSLMINPROTOCOLVERSION',
175 'ssl_max_protocol_version': 'PGSSLMAXPROTOCOLVERSION',
176 'gssencmode': 'PGGSSENCMODE',
177 'krbsrvname': 'PGKRBSRVNAME',
178 'gsslib': 'PGGSSLIB',
179 'connect_timeout': 'PGCONNECT_TIMEOUT',
180 'target_session_attrs': 'PGTARGETSESSIONATTRS',
184 def get_pg_env(dsn, base_env=None):
185 """ Return a copy of `base_env` with the environment variables for
186 PostgresSQL set up from the given database connection string.
187 If `base_env` is None, then the OS environment is used as a base
190 env = dict(base_env if base_env is not None else os.environ)
192 for param, value in psycopg2.extensions.parse_dsn(dsn).items():
193 if param in _PG_CONNECTION_STRINGS:
194 env[_PG_CONNECTION_STRINGS[param]] = value
196 LOG.error("Unknown connection parameter '%s' ignored.", param)