]> git.openstreetmap.org Git - nominatim.git/blob - nominatim/db/connection.py
Merge remote-tracking branch 'upstream/master'
[nominatim.git] / nominatim / db / connection.py
1 """
2 Specialised connection and cursor functions.
3 """
4 import logging
5
6 import psycopg2
7 import psycopg2.extensions
8 import psycopg2.extras
9
10 class _Cursor(psycopg2.extras.DictCursor):
11     """ A cursor returning dict-like objects and providing specialised
12         execution functions.
13     """
14
15     def execute(self, query, args=None): # pylint: disable=W0221
16         """ Query execution that logs the SQL query when debugging is enabled.
17         """
18         logger = logging.getLogger()
19         logger.debug(self.mogrify(query, args).decode('utf-8'))
20
21         super().execute(query, args)
22
23     def scalar(self, sql, args=None):
24         """ Execute query that returns a single value. The value is returned.
25             If the query yields more than one row, a ValueError is raised.
26         """
27         self.execute(sql, args)
28
29         if self.rowcount != 1:
30             raise RuntimeError("Query did not return a single row.")
31
32         return self.fetchone()[0]
33
34
35 class _Connection(psycopg2.extensions.connection):
36     """ A connection that provides the specialised cursor by default and
37         adds convenience functions for administrating the database.
38     """
39
40     def cursor(self, cursor_factory=_Cursor, **kwargs):
41         """ Return a new cursor. By default the specialised cursor is returned.
42         """
43         return super().cursor(cursor_factory=cursor_factory, **kwargs)
44
45     def table_exists(self, table):
46         """ Check that a table with the given name exists in the database.
47         """
48         with self.cursor() as cur:
49             num = cur.scalar("""SELECT count(*) FROM pg_tables
50                                 WHERE tablename = %s""", (table, ))
51             return num == 1
52
53     def server_version_tuple(self):
54         """ Return the server version as a tuple of (major, minor).
55             Converts correctly for pre-10 and post-10 PostgreSQL versions.
56         """
57         version = self.server_version
58         if version < 100000:
59             return (version / 10000, (version % 10000) / 100)
60
61         return (version / 10000, version % 10000)
62
63 def connect(dsn):
64     """ Open a connection to the database using the specialised connection
65         factory.
66     """
67     return psycopg2.connect(dsn, connection_factory=_Connection)