]> git.openstreetmap.org Git - nominatim.git/blob - nominatim/db/connection.py
Merge remote-tracking branch 'upstream/master'
[nominatim.git] / nominatim / db / connection.py
1 # SPDX-License-Identifier: GPL-2.0-only
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2022 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Specialised connection and cursor functions.
9 """
10 import contextlib
11 import logging
12 import os
13
14 import psycopg2
15 import psycopg2.extensions
16 import psycopg2.extras
17 from psycopg2 import sql as pysql
18
19 from nominatim.errors import UsageError
20
21 LOG = logging.getLogger()
22
23 class _Cursor(psycopg2.extras.DictCursor):
24     """ A cursor returning dict-like objects and providing specialised
25         execution functions.
26     """
27
28     # pylint: disable=arguments-renamed,arguments-differ
29     def execute(self, query, args=None):
30         """ Query execution that logs the SQL query when debugging is enabled.
31         """
32         LOG.debug(self.mogrify(query, args).decode('utf-8'))
33
34         super().execute(query, args)
35
36
37     def execute_values(self, sql, argslist, template=None):
38         """ Wrapper for the psycopg2 convenience function to execute
39             SQL for a list of values.
40         """
41         LOG.debug("SQL execute_values(%s, %s)", sql, argslist)
42
43         psycopg2.extras.execute_values(self, sql, argslist, template=template)
44
45
46     def scalar(self, sql, args=None):
47         """ Execute query that returns a single value. The value is returned.
48             If the query yields more than one row, a ValueError is raised.
49         """
50         self.execute(sql, args)
51
52         if self.rowcount != 1:
53             raise RuntimeError("Query did not return a single row.")
54
55         return self.fetchone()[0]
56
57
58     def drop_table(self, name, if_exists=True, cascade=False):
59         """ Drop the table with the given name.
60             Set `if_exists` to False if a non-existant table should raise
61             an exception instead of just being ignored. If 'cascade' is set
62             to True then all dependent tables are deleted as well.
63         """
64         sql = 'DROP TABLE '
65         if if_exists:
66             sql += 'IF EXISTS '
67         sql += '{}'
68         if cascade:
69             sql += ' CASCADE'
70
71         self.execute(pysql.SQL(sql).format(pysql.Identifier(name)))
72
73
74 class _Connection(psycopg2.extensions.connection):
75     """ A connection that provides the specialised cursor by default and
76         adds convenience functions for administrating the database.
77     """
78
79     def cursor(self, cursor_factory=_Cursor, **kwargs):
80         """ Return a new cursor. By default the specialised cursor is returned.
81         """
82         return super().cursor(cursor_factory=cursor_factory, **kwargs)
83
84
85     def table_exists(self, table):
86         """ Check that a table with the given name exists in the database.
87         """
88         with self.cursor() as cur:
89             num = cur.scalar("""SELECT count(*) FROM pg_tables
90                                 WHERE tablename = %s and schemaname = 'public'""", (table, ))
91             return num == 1
92
93
94     def table_has_column(self, table, column):
95         """ Check if the table 'table' exists and has a column with name 'column'.
96         """
97         with self.cursor() as cur:
98             has_column = cur.scalar("""SELECT count(*) FROM information_schema.columns
99                                        WHERE table_name = %s
100                                              and column_name = %s""",
101                                     (table, column))
102             return has_column > 0
103
104
105     def index_exists(self, index, table=None):
106         """ Check that an index with the given name exists in the database.
107             If table is not None then the index must relate to the given
108             table.
109         """
110         with self.cursor() as cur:
111             cur.execute("""SELECT tablename FROM pg_indexes
112                            WHERE indexname = %s and schemaname = 'public'""", (index, ))
113             if cur.rowcount == 0:
114                 return False
115
116             if table is not None:
117                 row = cur.fetchone()
118                 return row[0] == table
119
120         return True
121
122
123     def drop_table(self, name, if_exists=True, cascade=False):
124         """ Drop the table with the given name.
125             Set `if_exists` to False if a non-existant table should raise
126             an exception instead of just being ignored.
127         """
128         with self.cursor() as cur:
129             cur.drop_table(name, if_exists, cascade)
130         self.commit()
131
132
133     def server_version_tuple(self):
134         """ Return the server version as a tuple of (major, minor).
135             Converts correctly for pre-10 and post-10 PostgreSQL versions.
136         """
137         version = self.server_version
138         if version < 100000:
139             return (int(version / 10000), (version % 10000) / 100)
140
141         return (int(version / 10000), version % 10000)
142
143
144     def postgis_version_tuple(self):
145         """ Return the postgis version installed in the database as a
146             tuple of (major, minor). Assumes that the PostGIS extension
147             has been installed already.
148         """
149         with self.cursor() as cur:
150             version = cur.scalar('SELECT postgis_lib_version()')
151
152         return tuple((int(x) for x in version.split('.')[:2]))
153
154
155 def connect(dsn):
156     """ Open a connection to the database using the specialised connection
157         factory. The returned object may be used in conjunction with 'with'.
158         When used outside a context manager, use the `connection` attribute
159         to get the connection.
160     """
161     try:
162         conn = psycopg2.connect(dsn, connection_factory=_Connection)
163         ctxmgr = contextlib.closing(conn)
164         ctxmgr.connection = conn
165         return ctxmgr
166     except psycopg2.OperationalError as err:
167         raise UsageError(f"Cannot connect to database: {err}") from err
168
169
170 # Translation from PG connection string parameters to PG environment variables.
171 # Derived from https://www.postgresql.org/docs/current/libpq-envars.html.
172 _PG_CONNECTION_STRINGS = {
173     'host': 'PGHOST',
174     'hostaddr': 'PGHOSTADDR',
175     'port': 'PGPORT',
176     'dbname': 'PGDATABASE',
177     'user': 'PGUSER',
178     'password': 'PGPASSWORD',
179     'passfile': 'PGPASSFILE',
180     'channel_binding': 'PGCHANNELBINDING',
181     'service': 'PGSERVICE',
182     'options': 'PGOPTIONS',
183     'application_name': 'PGAPPNAME',
184     'sslmode': 'PGSSLMODE',
185     'requiressl': 'PGREQUIRESSL',
186     'sslcompression': 'PGSSLCOMPRESSION',
187     'sslcert': 'PGSSLCERT',
188     'sslkey': 'PGSSLKEY',
189     'sslrootcert': 'PGSSLROOTCERT',
190     'sslcrl': 'PGSSLCRL',
191     'requirepeer': 'PGREQUIREPEER',
192     'ssl_min_protocol_version': 'PGSSLMINPROTOCOLVERSION',
193     'ssl_max_protocol_version': 'PGSSLMAXPROTOCOLVERSION',
194     'gssencmode': 'PGGSSENCMODE',
195     'krbsrvname': 'PGKRBSRVNAME',
196     'gsslib': 'PGGSSLIB',
197     'connect_timeout': 'PGCONNECT_TIMEOUT',
198     'target_session_attrs': 'PGTARGETSESSIONATTRS',
199 }
200
201
202 def get_pg_env(dsn, base_env=None):
203     """ Return a copy of `base_env` with the environment variables for
204         PostgresSQL set up from the given database connection string.
205         If `base_env` is None, then the OS environment is used as a base
206         environment.
207     """
208     env = dict(base_env if base_env is not None else os.environ)
209
210     for param, value in psycopg2.extensions.parse_dsn(dsn).items():
211         if param in _PG_CONNECTION_STRINGS:
212             env[_PG_CONNECTION_STRINGS[param]] = value
213         else:
214             LOG.error("Unknown connection parameter '%s' ignored.", param)
215
216     return env