1 # SPDX-License-Identifier: GPL-3.0-or-later
3 # This file is part of Nominatim. (https://nominatim.org)
5 # Copyright (C) 2024 by the Nominatim developer community.
6 # For a full list of authors see the git log.
8 Specialised connection and cursor functions.
10 from typing import Optional, Any, Callable, ContextManager, Dict, cast, overload,\
17 import psycopg2.extensions
18 import psycopg2.extras
19 from psycopg2 import sql as pysql
21 from ..typing import SysEnv, Query, T_cursor
22 from ..errors import UsageError
24 LOG = logging.getLogger()
class Cursor(psycopg2.extras.DictCursor):
    """ A cursor returning dict-like objects and providing specialised
        execution functions.
    """
    # pylint: disable=arguments-renamed,arguments-differ
    def execute(self, query: Query, args: Any = None) -> None:
        """ Query execution that logs the SQL query when debugging is enabled.

            `query` and `args` are handed on unchanged to the execute()
            function of the parent cursor.
        """
        # Rendering the final SQL with mogrify() is only worth the effort
        # when the message is actually going to be emitted.
        if LOG.isEnabledFor(logging.DEBUG):
            LOG.debug(self.mogrify(query, args).decode('utf-8'))

        super().execute(query, args)


    def execute_values(self, sql: Query, argslist: Iterable[Tuple[Any, ...]],
                       template: Optional[Query] = None) -> None:
        """ Wrapper for the psycopg2 convenience function to execute
            SQL for a list of values.

            The SQL and the argument list are written to the debug log
            before being passed to psycopg2.extras.execute_values().
        """
        LOG.debug("SQL execute_values(%s, %s)", sql, argslist)

        psycopg2.extras.execute_values(self, sql, argslist, template=template)
class Connection(psycopg2.extensions.connection):
    """ A connection that provides the specialised cursor by default and
        adds convenience functions for administrating the database.
    """
    @overload # type: ignore[override]
    def cursor(self) -> Cursor:
        ...

    @overload
    def cursor(self, name: str) -> Cursor:
        ...

    @overload
    def cursor(self, cursor_factory: Callable[..., T_cursor]) -> T_cursor:
        ...

    def cursor(self, cursor_factory = Cursor, **kwargs): # type: ignore
        """ Return a new cursor. By default the specialised cursor is returned.

            The first positional parameter is the cursor factory. All
            other options (e.g. `name`) are taken as keyword arguments and
            forwarded unchanged to the cursor() function of the parent class.
        """
        return super().cursor(cursor_factory=cursor_factory, **kwargs)
def execute_scalar(conn: Connection, sql: Query, args: Any = None) -> Any:
    """ Execute query that returns a single value. The value is returned.
        If the query does not yield exactly one row, a RuntimeError
        is raised.
    """
    with conn.cursor() as cur:
        cur.execute(sql, args)

        # Zero rows are as much an error as multiple rows.
        if cur.rowcount != 1:
            raise RuntimeError("Query did not return a single row.")

        result = cur.fetchone()

    # The row count was checked above, so this only narrows the type.
    assert result is not None
    return result[0]
def table_exists(conn: Connection, table: str) -> bool:
    """ Check that a table with the given name exists in the database.
        Only tables in the 'public' schema are taken into account.
    """
    num = execute_scalar(conn,
                         """SELECT count(*) FROM pg_tables
                            WHERE tablename = %s and schemaname = 'public'""", (table, ))
    # Anything that is not an integer count is treated as 'not found'.
    return num == 1 if isinstance(num, int) else False
def table_has_column(conn: Connection, table: str, column: str) -> bool:
    """ Check if the table 'table' exists and has a column with name 'column'.
        The lookup goes against information_schema.columns and does not
        filter by schema.
    """
    has_column = execute_scalar(conn,
                                """SELECT count(*) FROM information_schema.columns
                                   WHERE table_name = %s and column_name = %s""",
                                (table, column))
    # Anything that is not an integer count is treated as 'not found'.
    return has_column > 0 if isinstance(has_column, int) else False
def index_exists(conn: Connection, index: str, table: Optional[str] = None) -> bool:
    """ Check that an index with the given name exists in the database.
        If table is not None then the index must relate to the given
        table. Only indexes in the 'public' schema are taken into account.
    """
    with conn.cursor() as cur:
        cur.execute("""SELECT tablename FROM pg_indexes
                       WHERE indexname = %s and schemaname = 'public'""", (index, ))
        if cur.rowcount == 0:
            return False

        if table is not None:
            # The index exists, now make sure it belongs to the right table.
            row = cur.fetchone()
            if row is None or not isinstance(row[0], str):
                return False
            return row[0] == table

    return True
def drop_tables(conn: Connection, *names: str,
                if_exists: bool = True, cascade: bool = False) -> None:
    """ Drop one or more tables with the given names.
        Set `if_exists` to False if a non-existent table should raise
        an exception instead of just being ignored. `cascade` will cause
        dependent objects to be dropped as well.
        The caller needs to take care of committing the change.
    """
    # Only fixed keywords go into the statement text. The table name is
    # inserted as a properly quoted identifier further down.
    exists_clause = ' IF EXISTS ' if if_exists else ' '
    cascade_clause = ' CASCADE' if cascade else ''
    sql = pysql.SQL('DROP TABLE' + exists_clause + '{}' + cascade_clause)

    with conn.cursor() as cur:
        for name in names:
            cur.execute(sql.format(pysql.Identifier(name)))
def server_version_tuple(conn: Connection) -> Tuple[int, int]:
    """ Return the server version as a tuple of (major, minor).
        Converts correctly for pre-10 and post-10 PostgreSQL versions.
    """
    version = conn.server_version

    # Integer division keeps the computation exact and avoids the detour
    # through floats.
    if version < 100000:
        # Before version 10 the number is encoded as MMmmpp, e.g. 90605.
        return (version // 10000, (version % 10000) // 100)

    # From version 10 on the number is encoded as MM00mm, e.g. 120004.
    return (version // 10000, version % 10000)
def postgis_version_tuple(conn: Connection) -> Tuple[int, int]:
    """ Return the postgis version installed in the database as a
        tuple of (major, minor). Assumes that the PostGIS extension
        has been installed already.

        Raises a UsageError when the version string reported by the
        database has fewer than two dot-separated parts.
    """
    version = execute_scalar(conn, 'SELECT postgis_lib_version()')

    version_parts = version.split('.')
    if len(version_parts) < 2:
        raise UsageError(f"Error fetching Postgis version. Bad format: {version}")

    # Any parts after major and minor are ignored.
    return (int(version_parts[0]), int(version_parts[1]))
def register_hstore(conn: Connection) -> None:
    """ Register the hstore type with psycopg for the connection.
    """
    psycopg2.extras.register_hstore(conn)
class ConnectionContext(ContextManager[Connection]):
    """ Context manager of the connection that also provides direct access
        to the underlying connection.
    """
    # The connection that is managed by the context manager.
    connection: Connection
def connect(dsn: str) -> ConnectionContext:
    """ Open a connection to the database using the specialised connection
        factory. The returned object may be used in conjunction with 'with'.
        When used outside a context manager, use the `connection` attribute
        to get the connection.

        Raises a UsageError when the connection cannot be established.
    """
    # Only the connect call itself is expected to fail with an
    # OperationalError, so keep the guarded section that small.
    try:
        conn = psycopg2.connect(dsn, connection_factory=Connection)
    except psycopg2.OperationalError as err:
        raise UsageError(f"Cannot connect to database: {err}") from err

    # contextlib.closing() supplies the context manager behaviour. The
    # connection is attached as an extra attribute for direct access.
    ctxmgr = cast(ConnectionContext, contextlib.closing(conn))
    ctxmgr.connection = conn
    return ctxmgr
# Translation from PG connection string parameters to PG environment variables.
# Derived from https://www.postgresql.org/docs/current/libpq-envars.html.
_PG_CONNECTION_STRINGS = {
    'host': 'PGHOST',
    'hostaddr': 'PGHOSTADDR',
    'port': 'PGPORT',
    'dbname': 'PGDATABASE',
    'user': 'PGUSER',
    'password': 'PGPASSWORD',
    'passfile': 'PGPASSFILE',
    'channel_binding': 'PGCHANNELBINDING',
    'service': 'PGSERVICE',
    'options': 'PGOPTIONS',
    'application_name': 'PGAPPNAME',
    'sslmode': 'PGSSLMODE',
    'requiressl': 'PGREQUIRESSL',
    'sslcompression': 'PGSSLCOMPRESSION',
    'sslcert': 'PGSSLCERT',
    'sslkey': 'PGSSLKEY',
    'sslrootcert': 'PGSSLROOTCERT',
    'sslcrl': 'PGSSLCRL',
    'requirepeer': 'PGREQUIREPEER',
    'ssl_min_protocol_version': 'PGSSLMINPROTOCOLVERSION',
    'ssl_max_protocol_version': 'PGSSLMAXPROTOCOLVERSION',
    'gssencmode': 'PGGSSENCMODE',
    'krbsrvname': 'PGKRBSRVNAME',
    'gsslib': 'PGGSSLIB',
    'connect_timeout': 'PGCONNECT_TIMEOUT',
    'target_session_attrs': 'PGTARGETSESSIONATTRS',
}
def get_pg_env(dsn: str,
               base_env: Optional[SysEnv] = None) -> Dict[str, str]:
    """ Return a copy of `base_env` with the environment variables for
        PostgreSQL set up from the given database connection string.
        If `base_env` is None, then the OS environment is used as a base
        environment.

        Connection parameters without a known environment variable are
        reported in the error log and otherwise ignored.
    """
    # Always work on a copy, so that the environment handed in is not modified.
    env = dict(base_env if base_env is not None else os.environ)

    for param, value in psycopg2.extensions.parse_dsn(dsn).items():
        if param in _PG_CONNECTION_STRINGS:
            env[_PG_CONNECTION_STRINGS[param]] = value
        else:
            LOG.error("Unknown connection parameter '%s' ignored.", param)

    return env