]> git.openstreetmap.org Git - nominatim.git/blobdiff - test/bdd/steps/table_compare.py
Merge pull request #3463 from lonvia/sqlalchemy14-with-psycopg
[nominatim.git] / test / bdd / steps / table_compare.py
index dfc261d425367f265166c1965ba9eac41a8399f5..cf2e12f127871390126379fb348f2eb2cfeceec1 100644 (file)
@@ -1,7 +1,16 @@
+# SPDX-License-Identifier: GPL-2.0-only
+#
+# This file is part of Nominatim. (https://nominatim.org)
+#
+# Copyright (C) 2022 by the Nominatim developer community.
+# For a full list of authors see the git log.
 """
 Functions to facilitate accessing and comparing the content of DB tables.
 """
 import re
 """
 Functions to facilitate accessing and comparing the content of DB tables.
 """
 import re
+import json
+
+from steps.check_functions import Almost
 
+# Matches IDs such as 'N1' or 'W10:highway': an OSM type letter (N, R or W),
+# the numeric OSM ID and an optional ':<class>' suffix.
 ID_REGEX = re.compile(r"(?P<typ>[NRW])(?P<oid>\d+)(:(?P<cls>\w+))?")
 
 
 ID_REGEX = re.compile(r"(?P<typ>[NRW])(?P<oid>\d+)(:(?P<cls>\w+))?")
 
@@ -41,22 +50,173 @@ class NominatimID:
             where += ' and class = %s'
             params.append(self.cls)
 
             where += ' and class = %s'
             params.append(self.cls)
 
-        return cur.execute(query.format(where), params)
+        cur.execute(query.format(where), params)
 
 
-    def query_place_id(self, cur, query):
-        """ Run a query on cursor `cur` using the place ID. The `query` string
-            must contain exactly one placeholder '%s' where the 'where' query
-            should go.
+    def row_by_place_id(self, cur, table, extra_columns=None):
+        """ Get a row by place_id from the given table using cursor `cur`.
+            `extra_columns` may contain a list of additional expressions for
+            the SELECT part of the query; they are selected after all regular
+            columns ('*').
+            The query is executed on `cur` and the result is left there for
+            the caller to fetch; nothing is returned. Fails with an assertion
+            if the ID does not map to exactly one place in placex.
         """
         pid = self.get_place_id(cur)
         """
         pid = self.get_place_id(cur)
-        return cur.execute(query, (pid, ))
+        # NOTE(review): `table` and `extra_columns` are inserted verbatim into
+        # the SQL text, so they must come from test code, never from data.
+        query = "SELECT {} FROM {} WHERE place_id = %s".format(
+                    ','.join(['*'] + (extra_columns or [])), table)
+        cur.execute(query, (pid, ))
 
 
-    def get_place_id(self, cur):
+    def get_place_id(self, cur, allow_empty=False):
         """ Look up the place id for the ID. Throws an assertion if the ID
             is not unique.
         """
         self.query_osm_id(cur, "SELECT place_id FROM placex WHERE {}")
         """ Look up the place id for the ID. Throws an assertion if the ID
             is not unique.
         """
         self.query_osm_id(cur, "SELECT place_id FROM placex WHERE {}")
+        # With `allow_empty`, a lookup that finds no place in placex returns
+        # None instead of failing the assertion below.
+        if cur.rowcount == 0 and allow_empty:
+            return None
+
+        # NOTE(review): this also fires when no row was found at all
+        # (rowcount 0); the "not unique" wording is misleading in that case.
         assert cur.rowcount == 1, \
                "Place ID {!s} not unique. Found {} entries.".format(self, cur.rowcount)
 
         return cur.fetchone()[0]
         assert cur.rowcount == 1, \
                "Place ID {!s} not unique. Found {} entries.".format(self, cur.rowcount)
 
         return cur.fetchone()[0]
+
+
+class DBRow:
+    """ Represents a row from a database and offers comparison functions.
+    """
+    def __init__(self, nid, db_row, context):
+        self.nid = nid
+        self.db_row = db_row
+        self.context = context
+
+    def assert_row(self, row, exclude_columns):
+        """ Check that all columns of the given behave row are contained
+            in the database row. Exclude behave rows with the names given
+            in the `exclude_columns` list.
+        """
+        for name, value in zip(row.headings, row.cells):
+            if name not in exclude_columns:
+                assert self.contains(name, value), self.assert_msg(name, value)
+
+    def contains(self, name, expected):
+        """ Check that the DB row contains a column `name` with the given value.
+
+            Special column names:
+              * `<column>+<key>` checks a single key of an hstore column
+                (`addr` is an alias for the `address` column).
+              * `geometry` is compared against the `geomtxt` column.
+              * `centroid` is compared against the `cx`/`cy` columns.
+              * `name` with a value without ':' checks the 'name' key of
+                the name hstore only.
+            An expected value of '-' means the column must be NULL.
+        """
+        if '+' in name:
+            column, field = name.split('+', 1)
+            return self._contains_hstore_value(column, field, expected)
+
+        if name == 'geometry':
+            return self._has_geometry(expected)
+
+        if name not in self.db_row:
+            return False
+
+        actual = self.db_row[name]
+
+        if expected == '-':
+            return actual is None
+
+        if name == 'name' and ':' not in expected:
+            # A NULL name hstore or a missing 'name' key is a mismatch, not a
+            # crash, so that assert_row() can still report the offending row.
+            if actual is None or name not in actual:
+                return False
+            return self._compare_column(actual[name], expected)
+
+        if 'place_id' in name:
+            return self._compare_place_id(actual, expected)
+
+        if name == 'centroid':
+            return self._has_centroid(expected)
+
+        return self._compare_column(actual, expected)
+
+    def _contains_hstore_value(self, column, field, expected):
+        """ Check the key `field` of the hstore `column` against `expected`.
+            '-' means that the hstore is NULL or does not contain the key.
+        """
+        if column == 'addr':
+            column = 'address'
+
+        if column not in self.db_row:
+            return False
+
+        if expected == '-':
+            return self.db_row[column] is None or field not in self.db_row[column]
+
+        if self.db_row[column] is None:
+            return False
+
+        return self._compare_column(self.db_row[column].get(field), expected)
+
+    def _compare_column(self, actual, expected):
+        """ Compare a DB value with the expected text from the test table.
+            Dictionaries (hstores) are compared with `expected` read as the
+            body of a Python dict literal, e.g. "'name': 'Foo', 'ref': '5'".
+            Everything else is compared by its string representation.
+        """
+        if isinstance(actual, dict):
+            import ast  # only needed for hstore comparisons
+            # literal_eval instead of eval: the cell only ever needs to describe
+            # a literal dict and must not be able to execute arbitrary code.
+            return actual == ast.literal_eval('{' + expected + '}')
+
+        return str(actual) == expected
+
+    def _compare_place_id(self, actual, expected):
+        """ Compare a place_id with an expected object ID (e.g. 'N1', 'W2:highway')
+            or with '0', which stands for a literal place_id of 0.
+        """
+        if expected == '0':
+            return actual == 0
+
+        with self.context.db.cursor() as cur:
+            return NominatimID(expected).get_place_id(cur) == actual
+
+    def _has_centroid(self, expected):
+        """ Check the centroid (columns `cx`/`cy`) against `expected`, which is
+            either 'in geometry', an 'x y' coordinate pair or a grid node number.
+        """
+        if expected == 'in geometry':
+            with self.context.db.cursor() as cur:
+                cur.execute("""SELECT ST_Within(ST_SetSRID(ST_Point({cx}, {cy}), 4326),
+                                        ST_SetSRID('{geomtxt}'::geometry, 4326))""".format(**self.db_row))
+                return cur.fetchone()[0]
+
+        if ' ' in expected:
+            # split() without argument tolerates repeated or surrounding blanks.
+            x, y = expected.split()
+        else:
+            x, y = self.context.osm.grid_node(int(expected))
+
+        return Almost(float(x)) == self.db_row['cx'] and Almost(float(y)) == self.db_row['cy']
+
+    def _has_geometry(self, expected):
+        """ Check that the row's `geomtxt` equals the geometry described by
+            `expected` after snapping both to a 0.00001 grid.
+        """
+        if 'geomtxt' not in self.db_row:
+            return False
+
+        geom = self.context.osm.parse_geometry(expected)
+        with self.context.db.cursor() as cur:
+            cur.execute("""SELECT ST_Equals(ST_SnapToGrid({}, 0.00001, 0.00001),
+                                   ST_SnapToGrid(ST_SetSRID('{}'::geometry, 4326), 0.00001, 0.00001))""".format(
+                            geom, self.db_row['geomtxt']))
+            return cur.fetchone()[0]
+
+    def assert_msg(self, name, value):
+        """ Return a string with an informative message for a failed compare.
+        """
+        msg = "\nBad column '{}' in row '{!s}'.".format(name, self.nid)
+        # Distinguish a missing column from one that exists but holds NULL;
+        # _get_actual() returns None in both cases.
+        if self._has_column(name):
+            msg += " Expected: {}, got: {}.".format(value, self._get_actual(name))
+        else:
+            msg += " No such column."
+
+        return msg + "\nFull DB row: {}".format(json.dumps(dict(self.db_row), indent=4, default=str))
+
+    def _has_column(self, name):
+        """ Check if the DB row has the column that the test column `name` refers to.
+        """
+        if '+' in name:
+            column = name.split('+', 1)[0]
+            if column == 'addr':
+                column = 'address'
+            return column in self.db_row
+
+        if name == 'geometry':
+            return 'geomtxt' in self.db_row
+
+        return name in self.db_row
+
+    def _get_actual(self, name):
+        """ Return a printable version of the DB value for test column `name`.
+            Place IDs are translated back into OSM object IDs where possible.
+        """
+        if '+' in name:
+            column, field = name.split('+', 1)
+            if column == 'addr':
+                column = 'address'
+            return (self.db_row.get(column) or {}).get(field)
+
+        if name == 'geometry':
+            return self.db_row['geomtxt']
+
+        if name not in self.db_row:
+            return None
+
+        if name == 'centroid':
+            return "POINT({cx} {cy})".format(**self.db_row)
+
+        actual = self.db_row[name]
+
+        if 'place_id' in name:
+            if actual is None:
+                return '<null>'
+
+            if actual == 0:
+                return "place ID 0"
+
+            with self.context.db.cursor() as cur:
+                cur.execute("""SELECT osm_type, osm_id, class
+                               FROM placex WHERE place_id = %s""",
+                            (actual, ))
+
+                if cur.rowcount == 1:
+                    return "{0[0]}{0[1]}:{0[2]}".format(cur.fetchone())
+
+                return "[place ID {} not found]".format(actual)
+
+        return actual