]> git.openstreetmap.org Git - nominatim.git/blob - src/nominatim_db/db/connection.py
Merge pull request #3582 from lonvia/switch-to-flake
[nominatim.git] / src / nominatim_db / db / connection.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2024 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Specialised connection and cursor functions.
9 """
10 from typing import Optional, Any, Dict, Tuple
11 import logging
12 import os
13
14 import psycopg
15 import psycopg.types.hstore
16 from psycopg import sql as pysql
17
18 from ..typing import SysEnv
19 from ..errors import UsageError
20
21 LOG = logging.getLogger()
22
23 Cursor = psycopg.Cursor[Any]
24 Connection = psycopg.Connection[Any]
25
26
27 def execute_scalar(conn: Connection, sql: psycopg.abc.Query, args: Any = None) -> Any:
28     """ Execute query that returns a single value. The value is returned.
29         If the query yields more than one row, a ValueError is raised.
30     """
31     with conn.cursor(row_factory=psycopg.rows.tuple_row) as cur:
32         cur.execute(sql, args)
33
34         if cur.rowcount != 1:
35             raise RuntimeError("Query did not return a single row.")
36
37         result = cur.fetchone()
38
39     assert result is not None
40     return result[0]
41
42
43 def table_exists(conn: Connection, table: str) -> bool:
44     """ Check that a table with the given name exists in the database.
45     """
46     num = execute_scalar(
47         conn,
48         """SELECT count(*) FROM pg_tables
49            WHERE tablename = %s and schemaname = 'public'""", (table, ))
50     return num == 1 if isinstance(num, int) else False
51
52
53 def table_has_column(conn: Connection, table: str, column: str) -> bool:
54     """ Check if the table 'table' exists and has a column with name 'column'.
55     """
56     has_column = execute_scalar(conn,
57                                 """SELECT count(*) FROM information_schema.columns
58                                    WHERE table_name = %s and column_name = %s""",
59                                 (table, column))
60     return has_column > 0 if isinstance(has_column, int) else False
61
62
63 def index_exists(conn: Connection, index: str, table: Optional[str] = None) -> bool:
64     """ Check that an index with the given name exists in the database.
65         If table is not None then the index must relate to the given
66         table.
67     """
68     with conn.cursor() as cur:
69         cur.execute("""SELECT tablename FROM pg_indexes
70                        WHERE indexname = %s and schemaname = 'public'""", (index, ))
71         if cur.rowcount == 0:
72             return False
73
74         if table is not None:
75             row = cur.fetchone()
76             if row is None or not isinstance(row[0], str):
77                 return False
78             return row[0] == table
79
80     return True
81
82
83 def drop_tables(conn: Connection, *names: str,
84                 if_exists: bool = True, cascade: bool = False) -> None:
85     """ Drop one or more tables with the given names.
86         Set `if_exists` to False if a non-existent table should raise
87         an exception instead of just being ignored. `cascade` will cause
88         depended objects to be dropped as well.
89         The caller needs to take care of committing the change.
90     """
91     sql = pysql.SQL('DROP TABLE%s{}%s' % (
92                         ' IF EXISTS ' if if_exists else ' ',
93                         ' CASCADE' if cascade else ''))
94
95     with conn.cursor() as cur:
96         for name in names:
97             cur.execute(sql.format(pysql.Identifier(name)))
98
99
100 def server_version_tuple(conn: Connection) -> Tuple[int, int]:
101     """ Return the server version as a tuple of (major, minor).
102         Converts correctly for pre-10 and post-10 PostgreSQL versions.
103     """
104     version = conn.info.server_version
105     if version < 100000:
106         return (int(version / 10000), int((version % 10000) / 100))
107
108     return (int(version / 10000), version % 10000)
109
110
111 def postgis_version_tuple(conn: Connection) -> Tuple[int, int]:
112     """ Return the postgis version installed in the database as a
113         tuple of (major, minor). Assumes that the PostGIS extension
114         has been installed already.
115     """
116     version = execute_scalar(conn, 'SELECT postgis_lib_version()')
117
118     version_parts = version.split('.')
119     if len(version_parts) < 2:
120         raise UsageError(f"Error fetching Postgis version. Bad format: {version}")
121
122     return (int(version_parts[0]), int(version_parts[1]))
123
124
125 def register_hstore(conn: Connection) -> None:
126     """ Register the hstore type with psycopg for the connection.
127     """
128     info = psycopg.types.TypeInfo.fetch(conn, "hstore")
129     if info is None:
130         raise RuntimeError('Hstore extension is requested but not installed.')
131     psycopg.types.hstore.register_hstore(info, conn)
132
133
134 def connect(dsn: str, **kwargs: Any) -> Connection:
135     """ Open a connection to the database using the specialised connection
136         factory. The returned object may be used in conjunction with 'with'.
137         When used outside a context manager, use the `connection` attribute
138         to get the connection.
139     """
140     try:
141         return psycopg.connect(dsn, row_factory=psycopg.rows.namedtuple_row, **kwargs)
142     except psycopg.OperationalError as err:
143         raise UsageError(f"Cannot connect to database: {err}") from err
144
145
146 # Translation from PG connection string parameters to PG environment variables.
147 # Derived from https://www.postgresql.org/docs/current/libpq-envars.html.
148 _PG_CONNECTION_STRINGS = {
149     'host': 'PGHOST',
150     'hostaddr': 'PGHOSTADDR',
151     'port': 'PGPORT',
152     'dbname': 'PGDATABASE',
153     'user': 'PGUSER',
154     'password': 'PGPASSWORD',
155     'passfile': 'PGPASSFILE',
156     'channel_binding': 'PGCHANNELBINDING',
157     'service': 'PGSERVICE',
158     'options': 'PGOPTIONS',
159     'application_name': 'PGAPPNAME',
160     'sslmode': 'PGSSLMODE',
161     'requiressl': 'PGREQUIRESSL',
162     'sslcompression': 'PGSSLCOMPRESSION',
163     'sslcert': 'PGSSLCERT',
164     'sslkey': 'PGSSLKEY',
165     'sslrootcert': 'PGSSLROOTCERT',
166     'sslcrl': 'PGSSLCRL',
167     'requirepeer': 'PGREQUIREPEER',
168     'ssl_min_protocol_version': 'PGSSLMINPROTOCOLVERSION',
169     'ssl_max_protocol_version': 'PGSSLMAXPROTOCOLVERSION',
170     'gssencmode': 'PGGSSENCMODE',
171     'krbsrvname': 'PGKRBSRVNAME',
172     'gsslib': 'PGGSSLIB',
173     'connect_timeout': 'PGCONNECT_TIMEOUT',
174     'target_session_attrs': 'PGTARGETSESSIONATTRS',
175 }
176
177
178 def get_pg_env(dsn: str,
179                base_env: Optional[SysEnv] = None) -> Dict[str, str]:
180     """ Return a copy of `base_env` with the environment variables for
181         PostgreSQL set up from the given database connection string.
182         If `base_env` is None, then the OS environment is used as a base
183         environment.
184     """
185     env = dict(base_env if base_env is not None else os.environ)
186
187     for param, value in psycopg.conninfo.conninfo_to_dict(dsn).items():
188         if param in _PG_CONNECTION_STRINGS:
189             env[_PG_CONNECTION_STRINGS[param]] = str(value)
190         else:
191             LOG.error("Unknown connection parameter '%s' ignored.", param)
192
193     return env
194
195
196 async def run_async_query(dsn: str, query: psycopg.abc.Query) -> None:
197     """ Open a connection to the database and run a single query
198         asynchronously.
199     """
200     async with await psycopg.AsyncConnection.connect(dsn) as aconn:
201         await aconn.execute(query)