]> git.openstreetmap.org Git - nominatim.git/blob - nominatim/db/connection.py
6bd81a2ff53d2025e164169a12ddd9bbaa88edab
[nominatim.git] / nominatim / db / connection.py
1 """
2 Specialised connection and cursor functions.
3 """
4 import contextlib
5 import logging
6
7 import psycopg2
8 import psycopg2.extensions
9 import psycopg2.extras
10
11 from ..errors import UsageError
12
13 class _Cursor(psycopg2.extras.DictCursor):
14     """ A cursor returning dict-like objects and providing specialised
15         execution functions.
16     """
17
18     def execute(self, query, args=None): # pylint: disable=W0221
19         """ Query execution that logs the SQL query when debugging is enabled.
20         """
21         logger = logging.getLogger()
22         logger.debug(self.mogrify(query, args).decode('utf-8'))
23
24         super().execute(query, args)
25
26     def scalar(self, sql, args=None):
27         """ Execute query that returns a single value. The value is returned.
28             If the query yields more than one row, a ValueError is raised.
29         """
30         self.execute(sql, args)
31
32         if self.rowcount != 1:
33             raise RuntimeError("Query did not return a single row.")
34
35         return self.fetchone()[0]
36
37
38 class _Connection(psycopg2.extensions.connection):
39     """ A connection that provides the specialised cursor by default and
40         adds convenience functions for administrating the database.
41     """
42
43     def cursor(self, cursor_factory=_Cursor, **kwargs):
44         """ Return a new cursor. By default the specialised cursor is returned.
45         """
46         return super().cursor(cursor_factory=cursor_factory, **kwargs)
47
48
49     def table_exists(self, table):
50         """ Check that a table with the given name exists in the database.
51         """
52         with self.cursor() as cur:
53             num = cur.scalar("""SELECT count(*) FROM pg_tables
54                                 WHERE tablename = %s and schemaname = 'public'""", (table, ))
55             return num == 1
56
57
58     def index_exists(self, index, table=None):
59         """ Check that an index with the given name exists in the database.
60             If table is not None then the index must relate to the given
61             table.
62         """
63         with self.cursor() as cur:
64             cur.execute("""SELECT tablename FROM pg_indexes
65                            WHERE indexname = %s and schemaname = 'public'""", (index, ))
66             if cur.rowcount == 0:
67                 return False
68
69             if table is not None:
70                 row = cur.fetchone()
71                 return row[0] == table
72
73         return True
74
75
76     def server_version_tuple(self):
77         """ Return the server version as a tuple of (major, minor).
78             Converts correctly for pre-10 and post-10 PostgreSQL versions.
79         """
80         version = self.server_version
81         if version < 100000:
82             return (version / 10000, (version % 10000) / 100)
83
84         return (version / 10000, version % 10000)
85
86 def connect(dsn):
87     """ Open a connection to the database using the specialised connection
88         factory. The returned object may be used in conjunction with 'with'.
89         When used outside a context manager, use the `connection` attribute
90         to get the connection.
91     """
92     try:
93         conn = psycopg2.connect(dsn, connection_factory=_Connection)
94         ctxmgr = contextlib.closing(conn)
95         ctxmgr.connection = conn
96         return ctxmgr
97     except psycopg2.OperationalError as err:
98         raise UsageError("Cannot connect to database: {}".format(err)) from err