1 # SPDX-License-Identifier: GPL-2.0-only
3 # This file is part of Nominatim. (https://nominatim.org)
5 # Copyright (C) 2022 by the Nominatim developer community.
6 # For a full list of authors see the git log.
8 Specialised connection and cursor functions.
15 import psycopg2.extensions
16 import psycopg2.extras
17 from psycopg2 import sql as pysql
19 from nominatim.errors import UsageError
# Module-wide logger. getLogger() is called without a name, so this is
# the root logger.
LOG = logging.getLogger(name=None)
class _Cursor(psycopg2.extras.DictCursor):
    """ A cursor returning dict-like objects and providing specialised
        execution functions.
    """

    # pylint: disable=arguments-renamed,arguments-differ
    def execute(self, query, args=None):
        """ Query execution that logs the SQL query when debugging is enabled.

            The query is only rendered through `mogrify()` when the logger
            actually accepts DEBUG messages, so no extra formatting work is
            done for every query in normal operation.
        """
        if LOG.isEnabledFor(logging.DEBUG):
            # NOTE(review): assumes the client encoding is UTF-8 — mogrify()
            # returns bytes in the connection's encoding.
            LOG.debug(self.mogrify(query, args).decode('utf-8'))

        super().execute(query, args)


    def execute_values(self, sql, argslist, template=None):
        """ Wrapper for the psycopg2 convenience function to execute
            SQL for a list of values.

            `sql`, `argslist` and `template` are passed through unchanged to
            `psycopg2.extras.execute_values()`.
        """
        LOG.debug("SQL execute_values(%s, %s)", sql, argslist)

        psycopg2.extras.execute_values(self, sql, argslist, template=template)


    def scalar(self, sql, args=None):
        """ Execute a query that is expected to return exactly one row and
            return the first column of that row.

            Raises a RuntimeError if the query yields no row or more than
            one row.
        """
        self.execute(sql, args)

        if self.rowcount != 1:
            raise RuntimeError("Query did not return a single row.")

        return self.fetchone()[0]


    def drop_table(self, name, if_exists=True, cascade=False):
        """ Drop the table with the given name.
            Set `if_exists` to False if a non-existent table should raise
            an exception instead of just being ignored. If `cascade` is set
            to True, objects that depend on the table are dropped as well.

            The name is quoted as an SQL identifier, not interpolated as raw
            SQL text.
        """
        sql = 'DROP TABLE '
        if if_exists:
            sql += 'IF EXISTS '
        sql += '{}'
        if cascade:
            sql += ' CASCADE'

        self.execute(pysql.SQL(sql).format(pysql.Identifier(name)))
class _Connection(psycopg2.extensions.connection):
    """ A connection that provides the specialised cursor by default and
        adds convenience functions for administrating the database.
    """

    def cursor(self, cursor_factory=_Cursor, **kwargs):
        """ Return a new cursor. By default the specialised cursor is returned.
        """
        return super().cursor(cursor_factory=cursor_factory, **kwargs)


    def table_exists(self, table):
        """ Check that a table with the given name exists in the 'public'
            schema of the database.
        """
        with self.cursor() as cur:
            num = cur.scalar("""SELECT count(*) FROM pg_tables
                                WHERE tablename = %s and schemaname = 'public'""", (table, ))
            return num == 1


    def table_has_column(self, table, column):
        """ Check if the table 'table' exists and has a column with name 'column'.
        """
        with self.cursor() as cur:
            has_column = cur.scalar("""SELECT count(*) FROM information_schema.columns
                                       WHERE table_name = %s
                                             and column_name = %s""",
                                    (table, column))
            return has_column > 0


    def index_exists(self, index, table=None):
        """ Check that an index with the given name exists in the 'public'
            schema of the database. If table is not None then the index
            must also relate to the given table.
        """
        with self.cursor() as cur:
            cur.execute("""SELECT tablename FROM pg_indexes
                           WHERE indexname = %s and schemaname = 'public'""", (index, ))
            if cur.rowcount == 0:
                return False

            if table is not None:
                row = cur.fetchone()
                return row[0] == table

        return True


    def drop_table(self, name, if_exists=True, cascade=False):
        """ Drop the table with the given name and commit the transaction.
            Set `if_exists` to False if a non-existent table should raise
            an exception instead of just being ignored. If `cascade` is True,
            objects that depend on the table are dropped as well.
        """
        with self.cursor() as cur:
            cur.drop_table(name, if_exists, cascade)
        self.commit()


    def server_version_tuple(self):
        """ Return the server version as a tuple of (major, minor) integers.
            Converts correctly for pre-10 and post-10 PostgreSQL versions.
        """
        version = self.server_version
        if version < 100000:
            # Pre-10 encoding: major * 10000 + minor * 100 + patch,
            # e.g. 90605 for 9.6.5. Integer division keeps 'minor' an int
            # (6) instead of the float 6.05 that true division produced.
            return (version // 10000, (version % 10000) // 100)

        # From PostgreSQL 10 on: major * 10000 + minor, e.g. 100002 for 10.2.
        return (version // 10000, version % 10000)


    def postgis_version_tuple(self):
        """ Return the postgis version installed in the database as a
            tuple of (major, minor). Assumes that the PostGIS extension
            has been installed already.
        """
        with self.cursor() as cur:
            version = cur.scalar('SELECT postgis_lib_version()')

        # Only the first two dot-separated components are converted.
        return tuple(int(x) for x in version.split('.')[:2])
def connect(dsn):
    """ Open a database connection for the connection string `dsn`, using
        the specialised connection factory.

        The connection is returned wrapped in `contextlib.closing`, so it can
        be used in a 'with' statement and is closed on exit. Outside of a
        context manager, the raw connection is available through the
        `connection` attribute of the returned object.

        A psycopg2 OperationalError is re-raised as a UsageError.
    """
    try:
        conn = psycopg2.connect(dsn, connection_factory=_Connection)
    except psycopg2.OperationalError as err:
        raise UsageError(f"Cannot connect to database: {err}") from err

    wrapper = contextlib.closing(conn)
    wrapper.connection = conn
    return wrapper
# Translation from PG connection string parameters to PG environment variables.
# Derived from https://www.postgresql.org/docs/current/libpq-envars.html.
# A connection parameter missing from this table cannot be forwarded through
# the environment: get_pg_env() logs it as unknown and skips it.
_PG_CONNECTION_STRINGS = {
    'host': 'PGHOST',
    'hostaddr': 'PGHOSTADDR',
    'port': 'PGPORT',
    'dbname': 'PGDATABASE',
    'user': 'PGUSER',
    'password': 'PGPASSWORD',
    'passfile': 'PGPASSFILE',
    'require_auth': 'PGREQUIREAUTH',
    'channel_binding': 'PGCHANNELBINDING',
    'service': 'PGSERVICE',
    'options': 'PGOPTIONS',
    'application_name': 'PGAPPNAME',
    'sslmode': 'PGSSLMODE',
    'requiressl': 'PGREQUIRESSL',
    'sslcompression': 'PGSSLCOMPRESSION',
    'sslcert': 'PGSSLCERT',
    'sslkey': 'PGSSLKEY',
    'sslcertmode': 'PGSSLCERTMODE',
    'sslrootcert': 'PGSSLROOTCERT',
    'sslcrl': 'PGSSLCRL',
    'sslcrldir': 'PGSSLCRLDIR',
    'sslsni': 'PGSSLSNI',
    'requirepeer': 'PGREQUIREPEER',
    'ssl_min_protocol_version': 'PGSSLMINPROTOCOLVERSION',
    'ssl_max_protocol_version': 'PGSSLMAXPROTOCOLVERSION',
    'gssencmode': 'PGGSSENCMODE',
    'krbsrvname': 'PGKRBSRVNAME',
    'gsslib': 'PGGSSLIB',
    'connect_timeout': 'PGCONNECT_TIMEOUT',
    'client_encoding': 'PGCLIENTENCODING',
    'target_session_attrs': 'PGTARGETSESSIONATTRS',
    'load_balance_hosts': 'PGLOADBALANCEHOSTS',
}
def get_pg_env(dsn, base_env=None):
    """ Return an environment dictionary with the PostgreSQL (libpq)
        variables set up from the connection string `dsn`.

        The result starts as a copy of `base_env`, or of `os.environ` when
        `base_env` is None; the passed-in mapping itself is not modified.
        Connection parameters without a known environment variable are
        logged as errors and left out.
    """
    env = dict(os.environ if base_env is None else base_env)

    for name, setting in psycopg2.extensions.parse_dsn(dsn).items():
        env_var = _PG_CONNECTION_STRINGS.get(name)
        if env_var is None:
            LOG.error("Unknown connection parameter '%s' ignored.", name)
            continue
        env[env_var] = setting

    return env