]> git.openstreetmap.org Git - nominatim.git/blob - nominatim/db/connection.py
factor out async connection handling into separate class
[nominatim.git] / nominatim / db / connection.py
1 """
2 Specialised connection and cursor functions.
3 """
4 import contextlib
5 import logging
6 import os
7
8 import psycopg2
9 import psycopg2.extensions
10 import psycopg2.extras
11
12 from nominatim.errors import UsageError
13
14 LOG = logging.getLogger()
15
16 class _Cursor(psycopg2.extras.DictCursor):
17     """ A cursor returning dict-like objects and providing specialised
18         execution functions.
19     """
20
21     def execute(self, query, args=None): # pylint: disable=W0221
22         """ Query execution that logs the SQL query when debugging is enabled.
23         """
24         LOG.debug(self.mogrify(query, args).decode('utf-8'))
25
26         super().execute(query, args)
27
28     def scalar(self, sql, args=None):
29         """ Execute query that returns a single value. The value is returned.
30             If the query yields more than one row, a ValueError is raised.
31         """
32         self.execute(sql, args)
33
34         if self.rowcount != 1:
35             raise RuntimeError("Query did not return a single row.")
36
37         return self.fetchone()[0]
38
39
40 class _Connection(psycopg2.extensions.connection):
41     """ A connection that provides the specialised cursor by default and
42         adds convenience functions for administrating the database.
43     """
44
45     def cursor(self, cursor_factory=_Cursor, **kwargs):
46         """ Return a new cursor. By default the specialised cursor is returned.
47         """
48         return super().cursor(cursor_factory=cursor_factory, **kwargs)
49
50
51     def table_exists(self, table):
52         """ Check that a table with the given name exists in the database.
53         """
54         with self.cursor() as cur:
55             num = cur.scalar("""SELECT count(*) FROM pg_tables
56                                 WHERE tablename = %s and schemaname = 'public'""", (table, ))
57             return num == 1
58
59
60     def index_exists(self, index, table=None):
61         """ Check that an index with the given name exists in the database.
62             If table is not None then the index must relate to the given
63             table.
64         """
65         with self.cursor() as cur:
66             cur.execute("""SELECT tablename FROM pg_indexes
67                            WHERE indexname = %s and schemaname = 'public'""", (index, ))
68             if cur.rowcount == 0:
69                 return False
70
71             if table is not None:
72                 row = cur.fetchone()
73                 return row[0] == table
74
75         return True
76
77
78     def drop_table(self, name, if_exists=True):
79         """ Drop the table with the given name.
80             Set `if_exists` to False if a non-existant table should raise
81             an exception instead of just being ignored.
82         """
83         with self.cursor() as cur:
84             cur.execute("""DROP TABLE {} "{}"
85                         """.format('IF EXISTS' if if_exists else '', name))
86         self.commit()
87
88
89     def server_version_tuple(self):
90         """ Return the server version as a tuple of (major, minor).
91             Converts correctly for pre-10 and post-10 PostgreSQL versions.
92         """
93         version = self.server_version
94         if version < 100000:
95             return (int(version / 10000), (version % 10000) / 100)
96
97         return (int(version / 10000), version % 10000)
98
99
100     def postgis_version_tuple(self):
101         """ Return the postgis version installed in the database as a
102             tuple of (major, minor). Assumes that the PostGIS extension
103             has been installed already.
104         """
105         with self.cursor() as cur:
106             version = cur.scalar('SELECT postgis_lib_version()')
107
108         return tuple((int(x) for x in version.split('.')[:2]))
109
110
111 def connect(dsn):
112     """ Open a connection to the database using the specialised connection
113         factory. The returned object may be used in conjunction with 'with'.
114         When used outside a context manager, use the `connection` attribute
115         to get the connection.
116     """
117     try:
118         conn = psycopg2.connect(dsn, connection_factory=_Connection)
119         ctxmgr = contextlib.closing(conn)
120         ctxmgr.connection = conn
121         return ctxmgr
122     except psycopg2.OperationalError as err:
123         raise UsageError("Cannot connect to database: {}".format(err)) from err
124
125
126 # Translation from PG connection string parameters to PG environment variables.
127 # Derived from https://www.postgresql.org/docs/current/libpq-envars.html.
128 _PG_CONNECTION_STRINGS = {
129     'host': 'PGHOST',
130     'hostaddr': 'PGHOSTADDR',
131     'port': 'PGPORT',
132     'dbname': 'PGDATABASE',
133     'user': 'PGUSER',
134     'password': 'PGPASSWORD',
135     'passfile': 'PGPASSFILE',
136     'channel_binding': 'PGCHANNELBINDING',
137     'service': 'PGSERVICE',
138     'options': 'PGOPTIONS',
139     'application_name': 'PGAPPNAME',
140     'sslmode': 'PGSSLMODE',
141     'requiressl': 'PGREQUIRESSL',
142     'sslcompression': 'PGSSLCOMPRESSION',
143     'sslcert': 'PGSSLCERT',
144     'sslkey': 'PGSSLKEY',
145     'sslrootcert': 'PGSSLROOTCERT',
146     'sslcrl': 'PGSSLCRL',
147     'requirepeer': 'PGREQUIREPEER',
148     'ssl_min_protocol_version': 'PGSSLMINPROTOCOLVERSION',
149     'ssl_max_protocol_version': 'PGSSLMAXPROTOCOLVERSION',
150     'gssencmode': 'PGGSSENCMODE',
151     'krbsrvname': 'PGKRBSRVNAME',
152     'gsslib': 'PGGSSLIB',
153     'connect_timeout': 'PGCONNECT_TIMEOUT',
154     'target_session_attrs': 'PGTARGETSESSIONATTRS',
155 }
156
157
158 def get_pg_env(dsn, base_env=None):
159     """ Return a copy of `base_env` with the environment variables for
160         PostgresSQL set up from the given database connection string.
161         If `base_env` is None, then the OS environment is used as a base
162         environment.
163     """
164     env = dict(base_env if base_env is not None else os.environ)
165
166     for param, value in psycopg2.extensions.parse_dsn(dsn).items():
167         if param in _PG_CONNECTION_STRINGS:
168             env[_PG_CONNECTION_STRINGS[param]] = value
169         else:
170             LOG.error("Unknown connection parameter '%s' ignored.", param)
171
172     return env