]> git.openstreetmap.org Git - nominatim.git/blob - test/python/cursor.py
release 4.5.0.post1
[nominatim.git] / test / python / cursor.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2024 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Specialised psycopg cursor with shortcut functions useful for testing.
9 """
10 import psycopg
11
12 class CursorForTesting(psycopg.Cursor):
13     """ Extension to the DictCursor class that provides execution
14         short-cuts that simplify writing assertions.
15     """
16
17     def scalar(self, sql, params=None):
18         """ Execute a query with a single return value and return this value.
19             Raises an assertion when not exactly one row is returned.
20         """
21         self.execute(sql, params)
22         assert self.rowcount == 1
23         return self.fetchone()[0]
24
25
26     def row_set(self, sql, params=None):
27         """ Execute a query and return the result as a set of tuples.
28             Fails when the SQL command returns duplicate rows.
29         """
30         self.execute(sql, params)
31
32         result = set((tuple(row) for row in self))
33         assert len(result) == self.rowcount
34
35         return result
36
37
38     def table_exists(self, table):
39         """ Check that a table with the given name exists in the database.
40         """
41         num = self.scalar("""SELECT count(*) FROM pg_tables
42                              WHERE tablename = %s""", (table, ))
43         return num == 1
44
45
46     def index_exists(self, table, index):
47         """ Check that an indexwith the given name exists on the given table.
48         """
49         num = self.scalar("""SELECT count(*) FROM pg_indexes
50                              WHERE tablename = %s and indexname = %s""",
51                           (table, index))
52         return num == 1
53
54
55     def table_rows(self, table, where=None):
56         """ Return the number of rows in the given table.
57         """
58         if where is None:
59             return self.scalar('SELECT count(*) FROM ' + table)
60
61         return self.scalar('SELECT count(*) FROM {} WHERE {}'.format(table, where))