]> git.openstreetmap.org Git - nominatim.git/blob - src/nominatim_db/db/connection.py
make DB helper functions free functions
[nominatim.git] / src / nominatim_db / db / connection.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2024 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Specialised connection and cursor functions.
9 """
10 from typing import Optional, Any, Callable, ContextManager, Dict, cast, overload,\
11                    Tuple, Iterable
12 import contextlib
13 import logging
14 import os
15
16 import psycopg2
17 import psycopg2.extensions
18 import psycopg2.extras
19 from psycopg2 import sql as pysql
20
21 from ..typing import SysEnv, Query, T_cursor
22 from ..errors import UsageError
23
24 LOG = logging.getLogger()
25
26 class Cursor(psycopg2.extras.DictCursor):
27     """ A cursor returning dict-like objects and providing specialised
28         execution functions.
29     """
30     # pylint: disable=arguments-renamed,arguments-differ
31     def execute(self, query: Query, args: Any = None) -> None:
32         """ Query execution that logs the SQL query when debugging is enabled.
33         """
34         if LOG.isEnabledFor(logging.DEBUG):
35             LOG.debug(self.mogrify(query, args).decode('utf-8'))
36
37         super().execute(query, args)
38
39
40     def execute_values(self, sql: Query, argslist: Iterable[Tuple[Any, ...]],
41                        template: Optional[Query] = None) -> None:
42         """ Wrapper for the psycopg2 convenience function to execute
43             SQL for a list of values.
44         """
45         LOG.debug("SQL execute_values(%s, %s)", sql, argslist)
46
47         psycopg2.extras.execute_values(self, sql, argslist, template=template)
48
49
50 class Connection(psycopg2.extensions.connection):
51     """ A connection that provides the specialised cursor by default and
52         adds convenience functions for administrating the database.
53     """
54     @overload # type: ignore[override]
55     def cursor(self) -> Cursor:
56         ...
57
58     @overload
59     def cursor(self, name: str) -> Cursor:
60         ...
61
62     @overload
63     def cursor(self, cursor_factory: Callable[..., T_cursor]) -> T_cursor:
64         ...
65
66     def cursor(self, cursor_factory  = Cursor, **kwargs): # type: ignore
67         """ Return a new cursor. By default the specialised cursor is returned.
68         """
69         return super().cursor(cursor_factory=cursor_factory, **kwargs)
70
71
72 def execute_scalar(conn: Connection, sql: Query, args: Any = None) -> Any:
73     """ Execute query that returns a single value. The value is returned.
74         If the query yields more than one row, a ValueError is raised.
75     """
76     with conn.cursor() as cur:
77         cur.execute(sql, args)
78
79         if cur.rowcount != 1:
80             raise RuntimeError("Query did not return a single row.")
81
82         result = cur.fetchone()
83
84     assert result is not None
85     return result[0]
86
87
88 def table_exists(conn: Connection, table: str) -> bool:
89     """ Check that a table with the given name exists in the database.
90     """
91     num = execute_scalar(conn,
92             """SELECT count(*) FROM pg_tables
93                WHERE tablename = %s and schemaname = 'public'""", (table, ))
94     return num == 1 if isinstance(num, int) else False
95
96
97 def table_has_column(conn: Connection, table: str, column: str) -> bool:
98     """ Check if the table 'table' exists and has a column with name 'column'.
99     """
100     has_column = execute_scalar(conn,
101                     """SELECT count(*) FROM information_schema.columns
102                        WHERE table_name = %s and column_name = %s""",
103                     (table, column))
104     return has_column > 0 if isinstance(has_column, int) else False
105
106
107 def index_exists(conn: Connection, index: str, table: Optional[str] = None) -> bool:
108     """ Check that an index with the given name exists in the database.
109         If table is not None then the index must relate to the given
110         table.
111     """
112     with conn.cursor() as cur:
113         cur.execute("""SELECT tablename FROM pg_indexes
114                        WHERE indexname = %s and schemaname = 'public'""", (index, ))
115         if cur.rowcount == 0:
116             return False
117
118         if table is not None:
119             row = cur.fetchone()
120             if row is None or not isinstance(row[0], str):
121                 return False
122             return row[0] == table
123
124     return True
125
126 def drop_tables(conn: Connection, *names: str,
127                if_exists: bool = True, cascade: bool = False) -> None:
128     """ Drop one or more tables with the given names.
129         Set `if_exists` to False if a non-existent table should raise
130         an exception instead of just being ignored. `cascade` will cause
131         depended objects to be dropped as well.
132         The caller needs to take care of committing the change.
133     """
134     sql = pysql.SQL('DROP TABLE%s{}%s' % (
135                         ' IF EXISTS ' if if_exists else ' ',
136                         ' CASCADE' if cascade else ''))
137
138     with conn.cursor() as cur:
139         for name in names:
140             cur.execute(sql.format(pysql.Identifier(name)))
141
142
143 def server_version_tuple(conn: Connection) -> Tuple[int, int]:
144     """ Return the server version as a tuple of (major, minor).
145         Converts correctly for pre-10 and post-10 PostgreSQL versions.
146     """
147     version = conn.server_version
148     if version < 100000:
149         return (int(version / 10000), int((version % 10000) / 100))
150
151     return (int(version / 10000), version % 10000)
152
153
154 def postgis_version_tuple(conn: Connection) -> Tuple[int, int]:
155     """ Return the postgis version installed in the database as a
156         tuple of (major, minor). Assumes that the PostGIS extension
157         has been installed already.
158     """
159     version = execute_scalar(conn, 'SELECT postgis_lib_version()')
160
161     version_parts = version.split('.')
162     if len(version_parts) < 2:
163         raise UsageError(f"Error fetching Postgis version. Bad format: {version}")
164
165     return (int(version_parts[0]), int(version_parts[1]))
166
167 def register_hstore(conn: Connection) -> None:
168     """ Register the hstore type with psycopg for the connection.
169     """
170     psycopg2.extras.register_hstore(conn)
171
172
173 class ConnectionContext(ContextManager[Connection]):
174     """ Context manager of the connection that also provides direct access
175         to the underlying connection.
176     """
177     connection: Connection
178
179
180 def connect(dsn: str) -> ConnectionContext:
181     """ Open a connection to the database using the specialised connection
182         factory. The returned object may be used in conjunction with 'with'.
183         When used outside a context manager, use the `connection` attribute
184         to get the connection.
185     """
186     try:
187         conn = psycopg2.connect(dsn, connection_factory=Connection)
188         ctxmgr = cast(ConnectionContext, contextlib.closing(conn))
189         ctxmgr.connection = conn
190         return ctxmgr
191     except psycopg2.OperationalError as err:
192         raise UsageError(f"Cannot connect to database: {err}") from err
193
194
195 # Translation from PG connection string parameters to PG environment variables.
196 # Derived from https://www.postgresql.org/docs/current/libpq-envars.html.
197 _PG_CONNECTION_STRINGS = {
198     'host': 'PGHOST',
199     'hostaddr': 'PGHOSTADDR',
200     'port': 'PGPORT',
201     'dbname': 'PGDATABASE',
202     'user': 'PGUSER',
203     'password': 'PGPASSWORD',
204     'passfile': 'PGPASSFILE',
205     'channel_binding': 'PGCHANNELBINDING',
206     'service': 'PGSERVICE',
207     'options': 'PGOPTIONS',
208     'application_name': 'PGAPPNAME',
209     'sslmode': 'PGSSLMODE',
210     'requiressl': 'PGREQUIRESSL',
211     'sslcompression': 'PGSSLCOMPRESSION',
212     'sslcert': 'PGSSLCERT',
213     'sslkey': 'PGSSLKEY',
214     'sslrootcert': 'PGSSLROOTCERT',
215     'sslcrl': 'PGSSLCRL',
216     'requirepeer': 'PGREQUIREPEER',
217     'ssl_min_protocol_version': 'PGSSLMINPROTOCOLVERSION',
218     'ssl_max_protocol_version': 'PGSSLMAXPROTOCOLVERSION',
219     'gssencmode': 'PGGSSENCMODE',
220     'krbsrvname': 'PGKRBSRVNAME',
221     'gsslib': 'PGGSSLIB',
222     'connect_timeout': 'PGCONNECT_TIMEOUT',
223     'target_session_attrs': 'PGTARGETSESSIONATTRS',
224 }
225
226
227 def get_pg_env(dsn: str,
228                base_env: Optional[SysEnv] = None) -> Dict[str, str]:
229     """ Return a copy of `base_env` with the environment variables for
230         PostgreSQL set up from the given database connection string.
231         If `base_env` is None, then the OS environment is used as a base
232         environment.
233     """
234     env = dict(base_env if base_env is not None else os.environ)
235
236     for param, value in psycopg2.extensions.parse_dsn(dsn).items():
237         if param in _PG_CONNECTION_STRINGS:
238             env[_PG_CONNECTION_STRINGS[param]] = value
239         else:
240             LOG.error("Unknown connection parameter '%s' ignored.", param)
241
242     return env