]> git.openstreetmap.org Git - nominatim.git/blob - test/python/cursor.py
Merge pull request #3670 from lonvia/flake-for-tests
[nominatim.git] / test / python / cursor.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2025 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Specialised psycopg cursor with shortcut functions useful for testing.
9 """
10 import psycopg
11
12
13 class CursorForTesting(psycopg.Cursor):
14     """ Extension to the DictCursor class that provides execution
15         short-cuts that simplify writing assertions.
16     """
17
18     def scalar(self, sql, params=None):
19         """ Execute a query with a single return value and return this value.
20             Raises an assertion when not exactly one row is returned.
21         """
22         self.execute(sql, params)
23         assert self.rowcount == 1
24         return self.fetchone()[0]
25
26     def row_set(self, sql, params=None):
27         """ Execute a query and return the result as a set of tuples.
28             Fails when the SQL command returns duplicate rows.
29         """
30         self.execute(sql, params)
31
32         result = set((tuple(row) for row in self))
33         assert len(result) == self.rowcount
34
35         return result
36
37     def table_exists(self, table):
38         """ Check that a table with the given name exists in the database.
39         """
40         num = self.scalar("""SELECT count(*) FROM pg_tables
41                              WHERE tablename = %s""", (table, ))
42         return num == 1
43
44     def index_exists(self, table, index):
45         """ Check that an indexwith the given name exists on the given table.
46         """
47         num = self.scalar("""SELECT count(*) FROM pg_indexes
48                              WHERE tablename = %s and indexname = %s""",
49                           (table, index))
50         return num == 1
51
52     def table_rows(self, table, where=None):
53         """ Return the number of rows in the given table.
54         """
55         if where is None:
56             return self.scalar('SELECT count(*) FROM ' + table)
57
58         return self.scalar('SELECT count(*) FROM {} WHERE {}'.format(table, where))