1 # SPDX-License-Identifier: GPL-2.0-only
3 # This file is part of Nominatim. (https://nominatim.org)
5 # Copyright (C) 2022 by the Nominatim developer community.
6 # For a full list of authors see the git log.
"""
Specialised connection and cursor functions.
"""
from typing import List, Optional, Any, Callable, ContextManager, Mapping, cast, overload, Tuple
import contextlib
import logging
import os

import psycopg2
import psycopg2.extensions
import psycopg2.extras
from psycopg2 import sql as pysql

from nominatim.typing import Query, T_cursor
from nominatim.errors import UsageError
# Logger shared by all functions in this module. getLogger() is called
# without a name, so messages go to the root logger.
LOG = logging.getLogger()
class _Cursor(psycopg2.extras.DictCursor):
    """ A cursor returning dict-like objects and providing specialised
        execution functions.
    """
    # pylint: disable=arguments-renamed,arguments-differ
    def execute(self, query: Query, args: Any = None) -> None:
        """ Query execution that logs the SQL query when debugging is enabled.
        """
        # Only render the final query with mogrify() when the debug
        # message will actually be emitted.
        if LOG.isEnabledFor(logging.DEBUG):
            LOG.debug(self.mogrify(query, args).decode('utf-8')) # type: ignore

        super().execute(query, args)


    def execute_values(self, sql: Query, argslist: List[Any],
                       template: Optional[str] = None) -> None:
        """ Wrapper for the psycopg2 convenience function to execute
            SQL for a list of values.
        """
        LOG.debug("SQL execute_values(%s, %s)", sql, argslist)

        psycopg2.extras.execute_values(self, sql, argslist, template=template)


    def scalar(self, sql: Query, args: Any = None) -> Any:
        """ Execute query that returns a single value. The value (the first
            column of the result row) is returned.
            If the query does not yield exactly one row, a RuntimeError
            is raised.
        """
        self.execute(sql, args)

        if self.rowcount != 1:
            raise RuntimeError("Query did not return a single row.")

        result = self.fetchone() # type: ignore
        # rowcount was 1, so a row must be available. The assert also
        # narrows the Optional type for the type checker.
        assert result is not None

        return result[0]


    def drop_table(self, name: str, if_exists: bool = True, cascade: bool = False) -> None:
        """ Drop the table with the given name.
            Set `if_exists` to False if a non-existent table should raise
            an exception instead of just being ignored. If 'cascade' is set
            to True then all dependent tables are deleted as well.
        """
        sql = 'DROP TABLE '
        if if_exists:
            sql += 'IF EXISTS '
        # The placeholder is filled with a properly quoted identifier below.
        sql += '{}'
        if cascade:
            sql += ' CASCADE'

        self.execute(pysql.SQL(sql).format(pysql.Identifier(name))) # type: ignore
class _Connection(psycopg2.extensions.connection):
    """ A connection that provides the specialised cursor by default and
        adds convenience functions for administrating the database.
    """
    @overload # type: ignore[override]
    def cursor(self) -> _Cursor:
        ...

    @overload
    def cursor(self, name: str) -> _Cursor:
        ...

    @overload
    def cursor(self, cursor_factory: Callable[..., T_cursor]) -> T_cursor:
        ...

    def cursor(self, cursor_factory = _Cursor, **kwargs): # type: ignore
        """ Return a new cursor. By default the specialised cursor is returned.
        """
        return super().cursor(cursor_factory=cursor_factory, **kwargs)


    def table_exists(self, table: str) -> bool:
        """ Check that a table with the given name exists in the database.
            Only tables in the 'public' schema are considered.
        """
        with self.cursor() as cur:
            num = cur.scalar("""SELECT count(*) FROM pg_tables
                                WHERE tablename = %s and schemaname = 'public'""", (table, ))
            return num == 1 if isinstance(num, int) else False


    def table_has_column(self, table: str, column: str) -> bool:
        """ Check if the table 'table' exists and has a column with name 'column'.
        """
        with self.cursor() as cur:
            has_column = cur.scalar("""SELECT count(*) FROM information_schema.columns
                                       WHERE table_name = %s
                                             and column_name = %s""",
                                    (table, column))
            return has_column > 0 if isinstance(has_column, int) else False


    def index_exists(self, index: str, table: Optional[str] = None) -> bool:
        """ Check that an index with the given name exists in the database.
            If table is not None then the index must relate to the given
            table. Only indexes in the 'public' schema are considered.
        """
        with self.cursor() as cur:
            cur.execute("""SELECT tablename FROM pg_indexes
                           WHERE indexname = %s and schemaname = 'public'""", (index, ))
            if cur.rowcount == 0:
                return False

            if table is not None:
                row = cur.fetchone() # type: ignore
                if row is None or not isinstance(row[0], str):
                    return False
                return row[0] == table

        # The index exists and no particular table was requested.
        return True


    def drop_table(self, name: str, if_exists: bool = True, cascade: bool = False) -> None:
        """ Drop the table with the given name.
            Set `if_exists` to False if a non-existent table should raise
            an exception instead of just being ignored. If 'cascade' is set
            to True then all dependent tables are deleted as well.
        """
        with self.cursor() as cur:
            cur.drop_table(name, if_exists, cascade)
        # NOTE(review): commit restored from upstream, not visible in this
        # listing - confirm before applying.
        self.commit()


    def server_version_tuple(self) -> Tuple[int, int]:
        """ Return the server version as a tuple of (major, minor).
            Converts correctly for pre-10 and post-10 PostgreSQL versions.
        """
        version = self.server_version
        # Before PostgreSQL 10 the number is encoded with two digits each
        # for major, minor and patch level (e.g. 90603 for 9.6.3).
        if version < 100000:
            return (version // 10000, (version % 10000) // 100)

        # From version 10 on the lower four digits hold the minor version.
        return (version // 10000, version % 10000)


    def postgis_version_tuple(self) -> Tuple[int, int]:
        """ Return the postgis version installed in the database as a
            tuple of (major, minor). Assumes that the PostGIS extension
            has been installed already.

            Raises a UsageError when the version string does not contain
            at least a major and a minor part.
        """
        with self.cursor() as cur:
            version = cur.scalar('SELECT postgis_lib_version()')

        version_parts = version.split('.')
        if len(version_parts) < 2:
            raise UsageError(f"Error fetching Postgis version. Bad format: {version}")

        return (int(version_parts[0]), int(version_parts[1]))
class _ConnectionContext(ContextManager[_Connection]):
    """ Context manager over a `_Connection` that additionally exposes
        the connection through the `connection` attribute.
    """
    connection: _Connection
def connect(dsn: str) -> _ConnectionContext:
    """ Open a connection to the database using the specialised connection
        factory. The returned object may be used in conjunction with 'with'.
        When used outside a context manager, use the `connection` attribute
        to get the connection.

        Raises a UsageError when the connection cannot be established.
    """
    # Only the connect call itself can fail with an OperationalError,
    # so keep the guarded section limited to it.
    try:
        conn = psycopg2.connect(dsn, connection_factory=_Connection)
    except psycopg2.OperationalError as err:
        raise UsageError(f"Cannot connect to database: {err}") from err

    # contextlib.closing() closes the connection when the 'with' block is
    # left. The connection is attached as an extra attribute so that it
    # stays reachable when the result is not used in a 'with' statement.
    ctxmgr = cast(_ConnectionContext, contextlib.closing(conn))
    ctxmgr.connection = cast(_Connection, conn)

    return ctxmgr
# Translation from PG connection string parameters to PG environment variables.
# Derived from https://www.postgresql.org/docs/current/libpq-envars.html.
_PG_CONNECTION_STRINGS = {
    'host': 'PGHOST',
    'hostaddr': 'PGHOSTADDR',
    'port': 'PGPORT',
    'dbname': 'PGDATABASE',
    'user': 'PGUSER',
    'password': 'PGPASSWORD',
    'passfile': 'PGPASSFILE',
    'channel_binding': 'PGCHANNELBINDING',
    'service': 'PGSERVICE',
    'options': 'PGOPTIONS',
    'application_name': 'PGAPPNAME',
    'sslmode': 'PGSSLMODE',
    'requiressl': 'PGREQUIRESSL',
    'sslcompression': 'PGSSLCOMPRESSION',
    'sslcert': 'PGSSLCERT',
    'sslkey': 'PGSSLKEY',
    'sslrootcert': 'PGSSLROOTCERT',
    'sslcrl': 'PGSSLCRL',
    'requirepeer': 'PGREQUIREPEER',
    'ssl_min_protocol_version': 'PGSSLMINPROTOCOLVERSION',
    'ssl_max_protocol_version': 'PGSSLMAXPROTOCOLVERSION',
    'gssencmode': 'PGGSSENCMODE',
    'krbsrvname': 'PGKRBSRVNAME',
    'gsslib': 'PGGSSLIB',
    'connect_timeout': 'PGCONNECT_TIMEOUT',
    'target_session_attrs': 'PGTARGETSESSIONATTRS',
}
def get_pg_env(dsn: str,
               base_env: Optional[Mapping[str, str]] = None) -> Mapping[str, str]:
    """ Return a copy of `base_env` with the environment variables for
        PostgreSQL set up from the given database connection string.
        If `base_env` is None, then the OS environment is used as a base
        environment.

        Connection parameters without a known environment variable are
        skipped and reported in the log.
    """
    # dict() creates a copy, so neither base_env nor os.environ is modified.
    env = dict(base_env if base_env is not None else os.environ)

    for param, value in psycopg2.extensions.parse_dsn(dsn).items(): # type: ignore
        if param in _PG_CONNECTION_STRINGS:
            env[_PG_CONNECTION_STRINGS[param]] = value
        else:
            LOG.error("Unknown connection parameter '%s' ignored.", param)

    return env