1 # SPDX-License-Identifier: GPL-2.0-only
3 # This file is part of Nominatim. (https://nominatim.org)
5 # Copyright (C) 2022 by the Nominatim developer community.
6 # For a full list of authors see the git log.
8 Specialised connection and cursor functions.
10 from typing import Union, List, Optional, Any, Callable, ContextManager, Mapping, cast, TypeVar, overload, Tuple, Sequence
16 import psycopg2.extensions
17 import psycopg2.extras
18 from psycopg2 import sql as pysql
20 from nominatim.errors import UsageError
# Type of the SQL statements accepted by the cursor helpers in this module:
# a plain string, raw bytes or a statement composed with psycopg2.sql.
Query = Union[str, bytes, pysql.Composable]
# Cursor type variable, bound to psycopg2's base cursor class. It allows
# _Connection.cursor() to report the type produced by a custom cursor_factory.
T = TypeVar('T', bound=psycopg2.extensions.cursor)

# Logger used by this module. getLogger() is called without a name, so
# this is the root logger.
LOG = logging.getLogger()
27 class _Cursor(psycopg2.extras.DictCursor):
28 """ A cursor returning dict-like objects and providing specialised
31 # pylint: disable=arguments-renamed,arguments-differ
def execute(self, query: Query, args: Any = None) -> None:
    """ Send `query` with the bound `args` to the database.

        If LOG is enabled for DEBUG, the statement with its arguments
        already interpolated (as produced by mogrify()) is logged first.
    """
    debugging = LOG.isEnabledFor(logging.DEBUG)
    if debugging:
        statement = self.mogrify(query, args) # type: ignore
        LOG.debug(statement.decode('utf-8'))

    super().execute(query, args)
def execute_values(self, sql: Query, argslist: Sequence[Any],
                   template: Optional[Query] = None) -> None:
    """ Wrapper for the psycopg2 convenience function to execute
        SQL for a list of values.

        The statement and its arguments are logged at DEBUG level first.
        All parameters are then passed on unchanged to
        psycopg2.extras.execute_values(), which defines how `sql`,
        `argslist` and `template` are combined.
    """
    # argslist/template are annotated as broadly as the wrapped function
    # accepts them: any sequence of rows, and the same query types as `sql`.
    LOG.debug("SQL execute_values(%s, %s)", sql, argslist)

    psycopg2.extras.execute_values(self, sql, argslist, template=template)
def scalar(self, sql: Query, args: Any = None) -> Any:
    """ Execute a query that is expected to return a single value.
        The value is returned.
        If the query does not yield exactly one row (no row at all or
        more than one row), a RuntimeError is raised.
    """
    self.execute(sql, args)

    if self.rowcount != 1:
        raise RuntimeError("Query did not return a single row.")

    # rowcount was checked to be 1 above, so fetchone() is expected to
    # deliver a row here; the assert guards that assumption.
    result = self.fetchone() # type: ignore
    assert result is not None
65 def drop_table(self, name: str, if_exists: bool = True, cascade: bool = False) -> None:
66 """ Drop the table with the given name.
67 Set `if_exists` to False if a non-existant table should raise
68 an exception instead of just being ignored. If 'cascade' is set
69 to True then all dependent tables are deleted as well.
78 self.execute(pysql.SQL(sql).format(pysql.Identifier(name))) # type: ignore
81 class _Connection(psycopg2.extensions.connection):
82 """ A connection that provides the specialised cursor by default and
83 adds convenience functions for administrating the database.
85 @overload # type: ignore[override]
86 def cursor(self) -> _Cursor:
90 def cursor(self, name: str) -> _Cursor:
94 def cursor(self, cursor_factory: Callable[..., T]) -> T:
def cursor(self, cursor_factory = _Cursor, **kwargs): # type: ignore
    """ Open and return a new cursor on this connection.

        Without an explicit `cursor_factory` the specialised _Cursor is
        used. Any further keyword arguments (such as `name`) are handed
        to psycopg2 unchanged. A positional first argument always binds
        to `cursor_factory`, so a cursor name must be passed by keyword.
    """
    kwargs['cursor_factory'] = cursor_factory
    return super().cursor(**kwargs)
def table_exists(self, table: str) -> bool:
    """ Tell whether a table called `table` exists in the 'public'
        schema of the database.
    """
    with self.cursor() as cur:
        num = cur.scalar("""SELECT count(*) FROM pg_tables
                                WHERE tablename = %s and schemaname = 'public'""", (table, ))

    # Anything other than an integer count is treated as "not found".
    if not isinstance(num, int):
        return False

    return num == 1
112 def table_has_column(self, table: str, column: str) -> bool:
113 """ Check if the table 'table' exists and has a column with name 'column'.
115 with self.cursor() as cur:
116 has_column = cur.scalar("""SELECT count(*) FROM information_schema.columns
117 WHERE table_name = %s
118 and column_name = %s""",
120 return has_column > 0 if isinstance(has_column, int) else False
123 def index_exists(self, index: str, table: Optional[str] = None) -> bool:
124 """ Check that an index with the given name exists in the database.
125 If table is not None then the index must relate to the given
128 with self.cursor() as cur:
129 cur.execute("""SELECT tablename FROM pg_indexes
130 WHERE indexname = %s and schemaname = 'public'""", (index, ))
131 if cur.rowcount == 0:
134 if table is not None:
135 row = cur.fetchone() # type: ignore
136 if row is None or not isinstance(row[0], str):
138 return row[0] == table
def drop_table(self, name: str, if_exists: bool = True, cascade: bool = False) -> None:
    """ Remove the table `name` from the database.

        When `if_exists` is True (the default), a table that does not
        exist is ignored; set it to False to get an exception instead.
        Set `cascade` to True to delete all dependent tables as well.
    """
    with self.cursor() as cur:
        cur.drop_table(name, if_exists=if_exists, cascade=cascade)
153 def server_version_tuple(self) -> Tuple[int, int]:
154 """ Return the server version as a tuple of (major, minor).
155 Converts correctly for pre-10 and post-10 PostgreSQL versions.
157 version = self.server_version
159 return (int(version / 10000), int((version % 10000) / 100))
161 return (int(version / 10000), version % 10000)
def postgis_version_tuple(self) -> Tuple[int, int]:
    """ Return the postgis version installed in the database as a
        tuple of (major, minor). Assumes that the PostGIS extension
        has been installed already.

        Raises UsageError when the version string reported by
        postgis_lib_version() does not start with two dot-separated
        integer components.
    """
    with self.cursor() as cur:
        version = cur.scalar('SELECT postgis_lib_version()')

    # str() so that an unexpected non-string result is reported as a bad
    # format below instead of failing with AttributeError on split().
    version_parts = str(version).split('.')
    if len(version_parts) < 2:
        raise UsageError(f"Error fetching Postgis version. Bad format: {version}")

    try:
        return (int(version_parts[0]), int(version_parts[1]))
    except ValueError as err:
        # A component int() cannot parse is just another malformed version
        # string; report it the same way instead of leaking ValueError.
        raise UsageError(f"Error fetching Postgis version. Bad format: {version}") from err
class _ConnectionContext(ContextManager[_Connection]):
    """ Static type of the object returned by connect(): a context manager
        for a _Connection that additionally exposes the connection through
        its `connection` attribute.
    """
    # The open connection, for callers that do not use a 'with' statement.
    connection: _Connection
181 def connect(dsn: str) -> _ConnectionContext:
182 """ Open a connection to the database using the specialised connection
183 factory. The returned object may be used in conjunction with 'with'.
184 When used outside a context manager, use the `connection` attribute
185 to get the connection.
188 conn = psycopg2.connect(dsn, connection_factory=_Connection)
189 ctxmgr = cast(_ConnectionContext, contextlib.closing(conn))
190 ctxmgr.connection = cast(_Connection, conn)
192 except psycopg2.OperationalError as err:
193 raise UsageError(f"Cannot connect to database: {err}") from err
196 # Translation from PG connection string parameters to PG environment variables.
197 # Derived from https://www.postgresql.org/docs/current/libpq-envars.html.
198 _PG_CONNECTION_STRINGS = {
200 'hostaddr': 'PGHOSTADDR',
202 'dbname': 'PGDATABASE',
204 'password': 'PGPASSWORD',
205 'passfile': 'PGPASSFILE',
206 'channel_binding': 'PGCHANNELBINDING',
207 'service': 'PGSERVICE',
208 'options': 'PGOPTIONS',
209 'application_name': 'PGAPPNAME',
210 'sslmode': 'PGSSLMODE',
211 'requiressl': 'PGREQUIRESSL',
212 'sslcompression': 'PGSSLCOMPRESSION',
213 'sslcert': 'PGSSLCERT',
214 'sslkey': 'PGSSLKEY',
215 'sslrootcert': 'PGSSLROOTCERT',
216 'sslcrl': 'PGSSLCRL',
217 'requirepeer': 'PGREQUIREPEER',
218 'ssl_min_protocol_version': 'PGSSLMINPROTOCOLVERSION',
219 'ssl_max_protocol_version': 'PGSSLMAXPROTOCOLVERSION',
220 'gssencmode': 'PGGSSENCMODE',
221 'krbsrvname': 'PGKRBSRVNAME',
222 'gsslib': 'PGGSSLIB',
223 'connect_timeout': 'PGCONNECT_TIMEOUT',
224 'target_session_attrs': 'PGTARGETSESSIONATTRS',
228 def get_pg_env(dsn: str,
229 base_env: Optional[Mapping[str, str]] = None) -> Mapping[str, str]:
230 """ Return a copy of `base_env` with the environment variables for
231 PostgresSQL set up from the given database connection string.
232 If `base_env` is None, then the OS environment is used as a base
235 env = dict(base_env if base_env is not None else os.environ)
237 for param, value in psycopg2.extensions.parse_dsn(dsn).items(): # type: ignore
238 if param in _PG_CONNECTION_STRINGS:
239 env[_PG_CONNECTION_STRINGS[param]] = value
241 LOG.error("Unknown connection parameter '%s' ignored.", param)