]> git.openstreetmap.org Git - nominatim.git/blob - src/nominatim_db/db/connection.py
8faa3f93334dda7a219b4e5dcd1caf98052a6274
[nominatim.git] / src / nominatim_db / db / connection.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2024 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Specialised connection and cursor functions.
9 """
10 from typing import Optional, Any, Callable, ContextManager, Dict, cast, overload, Tuple, Iterable
11 import contextlib
12 import logging
13 import os
14
15 import psycopg2
16 import psycopg2.extensions
17 import psycopg2.extras
18 from psycopg2 import sql as pysql
19
20 from ..typing import SysEnv, Query, T_cursor
21 from ..errors import UsageError
22
23 LOG = logging.getLogger()
24
25 class Cursor(psycopg2.extras.DictCursor):
26     """ A cursor returning dict-like objects and providing specialised
27         execution functions.
28     """
29     # pylint: disable=arguments-renamed,arguments-differ
30     def execute(self, query: Query, args: Any = None) -> None:
31         """ Query execution that logs the SQL query when debugging is enabled.
32         """
33         if LOG.isEnabledFor(logging.DEBUG):
34             LOG.debug(self.mogrify(query, args).decode('utf-8'))
35
36         super().execute(query, args)
37
38
39     def execute_values(self, sql: Query, argslist: Iterable[Tuple[Any, ...]],
40                        template: Optional[Query] = None) -> None:
41         """ Wrapper for the psycopg2 convenience function to execute
42             SQL for a list of values.
43         """
44         LOG.debug("SQL execute_values(%s, %s)", sql, argslist)
45
46         psycopg2.extras.execute_values(self, sql, argslist, template=template)
47
48
49     def scalar(self, sql: Query, args: Any = None) -> Any:
50         """ Execute query that returns a single value. The value is returned.
51             If the query yields more than one row, a ValueError is raised.
52         """
53         self.execute(sql, args)
54
55         if self.rowcount != 1:
56             raise RuntimeError("Query did not return a single row.")
57
58         result = self.fetchone()
59         assert result is not None
60
61         return result[0]
62
63
64     def drop_table(self, name: str, if_exists: bool = True, cascade: bool = False) -> None:
65         """ Drop the table with the given name.
66             Set `if_exists` to False if a non-existent table should raise
67             an exception instead of just being ignored. If 'cascade' is set
68             to True then all dependent tables are deleted as well.
69         """
70         sql = 'DROP TABLE '
71         if if_exists:
72             sql += 'IF EXISTS '
73         sql += '{}'
74         if cascade:
75             sql += ' CASCADE'
76
77         self.execute(pysql.SQL(sql).format(pysql.Identifier(name)))
78
79
80 class Connection(psycopg2.extensions.connection):
81     """ A connection that provides the specialised cursor by default and
82         adds convenience functions for administrating the database.
83     """
84     @overload # type: ignore[override]
85     def cursor(self) -> Cursor:
86         ...
87
88     @overload
89     def cursor(self, name: str) -> Cursor:
90         ...
91
92     @overload
93     def cursor(self, cursor_factory: Callable[..., T_cursor]) -> T_cursor:
94         ...
95
96     def cursor(self, cursor_factory  = Cursor, **kwargs): # type: ignore
97         """ Return a new cursor. By default the specialised cursor is returned.
98         """
99         return super().cursor(cursor_factory=cursor_factory, **kwargs)
100
101
102     def table_exists(self, table: str) -> bool:
103         """ Check that a table with the given name exists in the database.
104         """
105         with self.cursor() as cur:
106             num = cur.scalar("""SELECT count(*) FROM pg_tables
107                                 WHERE tablename = %s and schemaname = 'public'""", (table, ))
108             return num == 1 if isinstance(num, int) else False
109
110
111     def table_has_column(self, table: str, column: str) -> bool:
112         """ Check if the table 'table' exists and has a column with name 'column'.
113         """
114         with self.cursor() as cur:
115             has_column = cur.scalar("""SELECT count(*) FROM information_schema.columns
116                                        WHERE table_name = %s
117                                              and column_name = %s""",
118                                     (table, column))
119             return has_column > 0 if isinstance(has_column, int) else False
120
121
122     def index_exists(self, index: str, table: Optional[str] = None) -> bool:
123         """ Check that an index with the given name exists in the database.
124             If table is not None then the index must relate to the given
125             table.
126         """
127         with self.cursor() as cur:
128             cur.execute("""SELECT tablename FROM pg_indexes
129                            WHERE indexname = %s and schemaname = 'public'""", (index, ))
130             if cur.rowcount == 0:
131                 return False
132
133             if table is not None:
134                 row = cur.fetchone()
135                 if row is None or not isinstance(row[0], str):
136                     return False
137                 return row[0] == table
138
139         return True
140
141
142     def drop_table(self, name: str, if_exists: bool = True, cascade: bool = False) -> None:
143         """ Drop the table with the given name.
144             Set `if_exists` to False if a non-existent table should raise
145             an exception instead of just being ignored.
146         """
147         with self.cursor() as cur:
148             cur.drop_table(name, if_exists, cascade)
149         self.commit()
150
151
152     def server_version_tuple(self) -> Tuple[int, int]:
153         """ Return the server version as a tuple of (major, minor).
154             Converts correctly for pre-10 and post-10 PostgreSQL versions.
155         """
156         version = self.server_version
157         if version < 100000:
158             return (int(version / 10000), int((version % 10000) / 100))
159
160         return (int(version / 10000), version % 10000)
161
162
163     def postgis_version_tuple(self) -> Tuple[int, int]:
164         """ Return the postgis version installed in the database as a
165             tuple of (major, minor). Assumes that the PostGIS extension
166             has been installed already.
167         """
168         with self.cursor() as cur:
169             version = cur.scalar('SELECT postgis_lib_version()')
170
171         version_parts = version.split('.')
172         if len(version_parts) < 2:
173             raise UsageError(f"Error fetching Postgis version. Bad format: {version}")
174
175         return (int(version_parts[0]), int(version_parts[1]))
176
177
178 class ConnectionContext(ContextManager[Connection]):
179     """ Context manager of the connection that also provides direct access
180         to the underlying connection.
181     """
182     connection: Connection
183
184
185 def connect(dsn: str) -> ConnectionContext:
186     """ Open a connection to the database using the specialised connection
187         factory. The returned object may be used in conjunction with 'with'.
188         When used outside a context manager, use the `connection` attribute
189         to get the connection.
190     """
191     try:
192         conn = psycopg2.connect(dsn, connection_factory=Connection)
193         ctxmgr = cast(ConnectionContext, contextlib.closing(conn))
194         ctxmgr.connection = conn
195         return ctxmgr
196     except psycopg2.OperationalError as err:
197         raise UsageError(f"Cannot connect to database: {err}") from err
198
199
200 # Translation from PG connection string parameters to PG environment variables.
201 # Derived from https://www.postgresql.org/docs/current/libpq-envars.html.
202 _PG_CONNECTION_STRINGS = {
203     'host': 'PGHOST',
204     'hostaddr': 'PGHOSTADDR',
205     'port': 'PGPORT',
206     'dbname': 'PGDATABASE',
207     'user': 'PGUSER',
208     'password': 'PGPASSWORD',
209     'passfile': 'PGPASSFILE',
210     'channel_binding': 'PGCHANNELBINDING',
211     'service': 'PGSERVICE',
212     'options': 'PGOPTIONS',
213     'application_name': 'PGAPPNAME',
214     'sslmode': 'PGSSLMODE',
215     'requiressl': 'PGREQUIRESSL',
216     'sslcompression': 'PGSSLCOMPRESSION',
217     'sslcert': 'PGSSLCERT',
218     'sslkey': 'PGSSLKEY',
219     'sslrootcert': 'PGSSLROOTCERT',
220     'sslcrl': 'PGSSLCRL',
221     'requirepeer': 'PGREQUIREPEER',
222     'ssl_min_protocol_version': 'PGSSLMINPROTOCOLVERSION',
223     'ssl_max_protocol_version': 'PGSSLMAXPROTOCOLVERSION',
224     'gssencmode': 'PGGSSENCMODE',
225     'krbsrvname': 'PGKRBSRVNAME',
226     'gsslib': 'PGGSSLIB',
227     'connect_timeout': 'PGCONNECT_TIMEOUT',
228     'target_session_attrs': 'PGTARGETSESSIONATTRS',
229 }
230
231
232 def get_pg_env(dsn: str,
233                base_env: Optional[SysEnv] = None) -> Dict[str, str]:
234     """ Return a copy of `base_env` with the environment variables for
235         PostgreSQL set up from the given database connection string.
236         If `base_env` is None, then the OS environment is used as a base
237         environment.
238     """
239     env = dict(base_env if base_env is not None else os.environ)
240
241     for param, value in psycopg2.extensions.parse_dsn(dsn).items():
242         if param in _PG_CONNECTION_STRINGS:
243             env[_PG_CONNECTION_STRINGS[param]] = value
244         else:
245             LOG.error("Unknown connection parameter '%s' ignored.", param)
246
247     return env