]> git.openstreetmap.org Git - nominatim.git/blob - nominatim/db/connection.py
Merge pull request #2539 from lonvia/clean-up-python-tests
[nominatim.git] / nominatim / db / connection.py
1 """
2 Specialised connection and cursor functions.
3 """
4 import contextlib
5 import logging
6 import os
7
8 import psycopg2
9 import psycopg2.extensions
10 import psycopg2.extras
11 from psycopg2 import sql as pysql
12
13 from nominatim.errors import UsageError
14
15 LOG = logging.getLogger()
16
17 class _Cursor(psycopg2.extras.DictCursor):
18     """ A cursor returning dict-like objects and providing specialised
19         execution functions.
20     """
21
22     def execute(self, query, args=None): # pylint: disable=W0221
23         """ Query execution that logs the SQL query when debugging is enabled.
24         """
25         LOG.debug(self.mogrify(query, args).decode('utf-8'))
26
27         super().execute(query, args)
28
29
30     def execute_values(self, sql, argslist, template=None):
31         """ Wrapper for the psycopg2 convenience function to execute
32             SQL for a list of values.
33         """
34         LOG.debug("SQL execute_values(%s, %s)", sql, argslist)
35
36         psycopg2.extras.execute_values(self, sql, argslist, template=template)
37
38
39     def scalar(self, sql, args=None):
40         """ Execute query that returns a single value. The value is returned.
41             If the query yields more than one row, a ValueError is raised.
42         """
43         self.execute(sql, args)
44
45         if self.rowcount != 1:
46             raise RuntimeError("Query did not return a single row.")
47
48         return self.fetchone()[0]
49
50
51     def drop_table(self, name, if_exists=True, cascade=False):
52         """ Drop the table with the given name.
53             Set `if_exists` to False if a non-existant table should raise
54             an exception instead of just being ignored. If 'cascade' is set
55             to True then all dependent tables are deleted as well.
56         """
57         sql = 'DROP TABLE '
58         if if_exists:
59             sql += 'IF EXISTS '
60         sql += '{}'
61         if cascade:
62             sql += ' CASCADE'
63
64         self.execute(pysql.SQL(sql).format(pysql.Identifier(name)))
65
66
67 class _Connection(psycopg2.extensions.connection):
68     """ A connection that provides the specialised cursor by default and
69         adds convenience functions for administrating the database.
70     """
71
72     def cursor(self, cursor_factory=_Cursor, **kwargs):
73         """ Return a new cursor. By default the specialised cursor is returned.
74         """
75         return super().cursor(cursor_factory=cursor_factory, **kwargs)
76
77
78     def table_exists(self, table):
79         """ Check that a table with the given name exists in the database.
80         """
81         with self.cursor() as cur:
82             num = cur.scalar("""SELECT count(*) FROM pg_tables
83                                 WHERE tablename = %s and schemaname = 'public'""", (table, ))
84             return num == 1
85
86
87     def index_exists(self, index, table=None):
88         """ Check that an index with the given name exists in the database.
89             If table is not None then the index must relate to the given
90             table.
91         """
92         with self.cursor() as cur:
93             cur.execute("""SELECT tablename FROM pg_indexes
94                            WHERE indexname = %s and schemaname = 'public'""", (index, ))
95             if cur.rowcount == 0:
96                 return False
97
98             if table is not None:
99                 row = cur.fetchone()
100                 return row[0] == table
101
102         return True
103
104
105     def drop_table(self, name, if_exists=True, cascade=False):
106         """ Drop the table with the given name.
107             Set `if_exists` to False if a non-existant table should raise
108             an exception instead of just being ignored.
109         """
110         with self.cursor() as cur:
111             cur.drop_table(name, if_exists, cascade)
112         self.commit()
113
114
115     def server_version_tuple(self):
116         """ Return the server version as a tuple of (major, minor).
117             Converts correctly for pre-10 and post-10 PostgreSQL versions.
118         """
119         version = self.server_version
120         if version < 100000:
121             return (int(version / 10000), (version % 10000) / 100)
122
123         return (int(version / 10000), version % 10000)
124
125
126     def postgis_version_tuple(self):
127         """ Return the postgis version installed in the database as a
128             tuple of (major, minor). Assumes that the PostGIS extension
129             has been installed already.
130         """
131         with self.cursor() as cur:
132             version = cur.scalar('SELECT postgis_lib_version()')
133
134         return tuple((int(x) for x in version.split('.')[:2]))
135
136
137 def connect(dsn):
138     """ Open a connection to the database using the specialised connection
139         factory. The returned object may be used in conjunction with 'with'.
140         When used outside a context manager, use the `connection` attribute
141         to get the connection.
142     """
143     try:
144         conn = psycopg2.connect(dsn, connection_factory=_Connection)
145         ctxmgr = contextlib.closing(conn)
146         ctxmgr.connection = conn
147         return ctxmgr
148     except psycopg2.OperationalError as err:
149         raise UsageError("Cannot connect to database: {}".format(err)) from err
150
151
152 # Translation from PG connection string parameters to PG environment variables.
153 # Derived from https://www.postgresql.org/docs/current/libpq-envars.html.
154 _PG_CONNECTION_STRINGS = {
155     'host': 'PGHOST',
156     'hostaddr': 'PGHOSTADDR',
157     'port': 'PGPORT',
158     'dbname': 'PGDATABASE',
159     'user': 'PGUSER',
160     'password': 'PGPASSWORD',
161     'passfile': 'PGPASSFILE',
162     'channel_binding': 'PGCHANNELBINDING',
163     'service': 'PGSERVICE',
164     'options': 'PGOPTIONS',
165     'application_name': 'PGAPPNAME',
166     'sslmode': 'PGSSLMODE',
167     'requiressl': 'PGREQUIRESSL',
168     'sslcompression': 'PGSSLCOMPRESSION',
169     'sslcert': 'PGSSLCERT',
170     'sslkey': 'PGSSLKEY',
171     'sslrootcert': 'PGSSLROOTCERT',
172     'sslcrl': 'PGSSLCRL',
173     'requirepeer': 'PGREQUIREPEER',
174     'ssl_min_protocol_version': 'PGSSLMINPROTOCOLVERSION',
175     'ssl_max_protocol_version': 'PGSSLMAXPROTOCOLVERSION',
176     'gssencmode': 'PGGSSENCMODE',
177     'krbsrvname': 'PGKRBSRVNAME',
178     'gsslib': 'PGGSSLIB',
179     'connect_timeout': 'PGCONNECT_TIMEOUT',
180     'target_session_attrs': 'PGTARGETSESSIONATTRS',
181 }
182
183
184 def get_pg_env(dsn, base_env=None):
185     """ Return a copy of `base_env` with the environment variables for
186         PostgresSQL set up from the given database connection string.
187         If `base_env` is None, then the OS environment is used as a base
188         environment.
189     """
190     env = dict(base_env if base_env is not None else os.environ)
191
192     for param, value in psycopg2.extensions.parse_dsn(dsn).items():
193         if param in _PG_CONNECTION_STRINGS:
194             env[_PG_CONNECTION_STRINGS[param]] = value
195         else:
196             LOG.error("Unknown connection parameter '%s' ignored.", param)
197
198     return env