]> git.openstreetmap.org Git - nominatim.git/blob - src/nominatim_db/db/connection.py
add forgotten BDD test
[nominatim.git] / src / nominatim_db / db / connection.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2024 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Specialised connection and cursor functions.
9 """
10 from typing import Optional, Any, Dict, Tuple
11 import logging
12 import os
13
14 import psycopg
15 import psycopg.types.hstore
16 from psycopg import sql as pysql
17
18 from ..typing import SysEnv
19 from ..errors import UsageError
20
21 LOG = logging.getLogger()
22
23 Cursor = psycopg.Cursor[Any]
24 Connection = psycopg.Connection[Any]
25
26 def execute_scalar(conn: Connection, sql: psycopg.abc.Query, args: Any = None) -> Any:
27     """ Execute query that returns a single value. The value is returned.
28         If the query yields more than one row, a ValueError is raised.
29     """
30     with conn.cursor(row_factory=psycopg.rows.tuple_row) as cur:
31         cur.execute(sql, args)
32
33         if cur.rowcount != 1:
34             raise RuntimeError("Query did not return a single row.")
35
36         result = cur.fetchone()
37
38     assert result is not None
39     return result[0]
40
41
42 def table_exists(conn: Connection, table: str) -> bool:
43     """ Check that a table with the given name exists in the database.
44     """
45     num = execute_scalar(conn,
46             """SELECT count(*) FROM pg_tables
47                WHERE tablename = %s and schemaname = 'public'""", (table, ))
48     return num == 1 if isinstance(num, int) else False
49
50
51 def table_has_column(conn: Connection, table: str, column: str) -> bool:
52     """ Check if the table 'table' exists and has a column with name 'column'.
53     """
54     has_column = execute_scalar(conn,
55                     """SELECT count(*) FROM information_schema.columns
56                        WHERE table_name = %s and column_name = %s""",
57                     (table, column))
58     return has_column > 0 if isinstance(has_column, int) else False
59
60
61 def index_exists(conn: Connection, index: str, table: Optional[str] = None) -> bool:
62     """ Check that an index with the given name exists in the database.
63         If table is not None then the index must relate to the given
64         table.
65     """
66     with conn.cursor() as cur:
67         cur.execute("""SELECT tablename FROM pg_indexes
68                        WHERE indexname = %s and schemaname = 'public'""", (index, ))
69         if cur.rowcount == 0:
70             return False
71
72         if table is not None:
73             row = cur.fetchone()
74             if row is None or not isinstance(row[0], str):
75                 return False
76             return row[0] == table
77
78     return True
79
80 def drop_tables(conn: Connection, *names: str,
81                if_exists: bool = True, cascade: bool = False) -> None:
82     """ Drop one or more tables with the given names.
83         Set `if_exists` to False if a non-existent table should raise
84         an exception instead of just being ignored. `cascade` will cause
85         depended objects to be dropped as well.
86         The caller needs to take care of committing the change.
87     """
88     sql = pysql.SQL('DROP TABLE%s{}%s' % (
89                         ' IF EXISTS ' if if_exists else ' ',
90                         ' CASCADE' if cascade else ''))
91
92     with conn.cursor() as cur:
93         for name in names:
94             cur.execute(sql.format(pysql.Identifier(name)))
95
96
97 def server_version_tuple(conn: Connection) -> Tuple[int, int]:
98     """ Return the server version as a tuple of (major, minor).
99         Converts correctly for pre-10 and post-10 PostgreSQL versions.
100     """
101     version = conn.info.server_version
102     if version < 100000:
103         return (int(version / 10000), int((version % 10000) / 100))
104
105     return (int(version / 10000), version % 10000)
106
107
108 def postgis_version_tuple(conn: Connection) -> Tuple[int, int]:
109     """ Return the postgis version installed in the database as a
110         tuple of (major, minor). Assumes that the PostGIS extension
111         has been installed already.
112     """
113     version = execute_scalar(conn, 'SELECT postgis_lib_version()')
114
115     version_parts = version.split('.')
116     if len(version_parts) < 2:
117         raise UsageError(f"Error fetching Postgis version. Bad format: {version}")
118
119     return (int(version_parts[0]), int(version_parts[1]))
120
121
122 def register_hstore(conn: Connection) -> None:
123     """ Register the hstore type with psycopg for the connection.
124     """
125     info = psycopg.types.TypeInfo.fetch(conn, "hstore")
126     if info is None:
127         raise RuntimeError('Hstore extension is requested but not installed.')
128     psycopg.types.hstore.register_hstore(info, conn)
129
130
131 def connect(dsn: str, **kwargs: Any) -> Connection:
132     """ Open a connection to the database using the specialised connection
133         factory. The returned object may be used in conjunction with 'with'.
134         When used outside a context manager, use the `connection` attribute
135         to get the connection.
136     """
137     try:
138         return psycopg.connect(dsn, row_factory=psycopg.rows.namedtuple_row, **kwargs)
139     except psycopg.OperationalError as err:
140         raise UsageError(f"Cannot connect to database: {err}") from err
141
142
143 # Translation from PG connection string parameters to PG environment variables.
144 # Derived from https://www.postgresql.org/docs/current/libpq-envars.html.
145 _PG_CONNECTION_STRINGS = {
146     'host': 'PGHOST',
147     'hostaddr': 'PGHOSTADDR',
148     'port': 'PGPORT',
149     'dbname': 'PGDATABASE',
150     'user': 'PGUSER',
151     'password': 'PGPASSWORD',
152     'passfile': 'PGPASSFILE',
153     'channel_binding': 'PGCHANNELBINDING',
154     'service': 'PGSERVICE',
155     'options': 'PGOPTIONS',
156     'application_name': 'PGAPPNAME',
157     'sslmode': 'PGSSLMODE',
158     'requiressl': 'PGREQUIRESSL',
159     'sslcompression': 'PGSSLCOMPRESSION',
160     'sslcert': 'PGSSLCERT',
161     'sslkey': 'PGSSLKEY',
162     'sslrootcert': 'PGSSLROOTCERT',
163     'sslcrl': 'PGSSLCRL',
164     'requirepeer': 'PGREQUIREPEER',
165     'ssl_min_protocol_version': 'PGSSLMINPROTOCOLVERSION',
166     'ssl_max_protocol_version': 'PGSSLMAXPROTOCOLVERSION',
167     'gssencmode': 'PGGSSENCMODE',
168     'krbsrvname': 'PGKRBSRVNAME',
169     'gsslib': 'PGGSSLIB',
170     'connect_timeout': 'PGCONNECT_TIMEOUT',
171     'target_session_attrs': 'PGTARGETSESSIONATTRS',
172 }
173
174
175 def get_pg_env(dsn: str,
176                base_env: Optional[SysEnv] = None) -> Dict[str, str]:
177     """ Return a copy of `base_env` with the environment variables for
178         PostgreSQL set up from the given database connection string.
179         If `base_env` is None, then the OS environment is used as a base
180         environment.
181     """
182     env = dict(base_env if base_env is not None else os.environ)
183
184     for param, value in psycopg.conninfo.conninfo_to_dict(dsn).items():
185         if param in _PG_CONNECTION_STRINGS:
186             env[_PG_CONNECTION_STRINGS[param]] = str(value)
187         else:
188             LOG.error("Unknown connection parameter '%s' ignored.", param)
189
190     return env
191
192
193 async def run_async_query(dsn: str, query: psycopg.abc.Query) -> None:
194     """ Open a connection to the database and run a single query
195         asynchronously.
196     """
197     async with await psycopg.AsyncConnection.connect(dsn) as aconn:
198         await aconn.execute(query)