X-Git-Url: https://git.openstreetmap.org./nominatim.git/blobdiff_plain/d586b95ff11a45ab6223e83e1b0540f1db87b325..e8caf8d78d046008b7a951f6b1ba5d015dafb6bc:/test/bdd/steps/table_compare.py?ds=inline

diff --git a/test/bdd/steps/table_compare.py b/test/bdd/steps/table_compare.py
index dfc261d4..2e71d943 100644
--- a/test/bdd/steps/table_compare.py
+++ b/test/bdd/steps/table_compare.py
@@ -2,6 +2,9 @@
 Functions to facilitate accessing and comparing the content of DB tables.
 """
 import re
+import json
+
+from steps.check_functions import Almost
 
 ID_REGEX = re.compile(r"(?P<typ>[NRW])(?P<oid>\d+)(:(?P<cls>\w+))?")
 
@@ -41,15 +44,17 @@ class NominatimID:
             where += ' and class = %s'
             params.append(self.cls)
 
-        return cur.execute(query.format(where), params)
+        cur.execute(query.format(where), params)
 
-    def query_place_id(self, cur, query):
-        """ Run a query on cursor `cur` using the place ID. The `query` string
-            must contain exactly one placeholder '%s' where the 'where' query
-            should go.
+    def row_by_place_id(self, cur, table, extra_columns=None):
+        """ Get a row by place_id from the given table using cursor `cur`.
+            extra_columns may contain a list additional elements for the select
+            part of the query.
         """
         pid = self.get_place_id(cur)
-        return cur.execute(query, (pid, ))
+        query = "SELECT {} FROM {} WHERE place_id = %s".format(
+                    ','.join(['*'] + (extra_columns or [])), table)
+        cur.execute(query, (pid, ))
 
     def get_place_id(self, cur):
         """ Look up the place id for the ID. Throws an assertion if the ID
@@ -60,3 +65,145 @@ class NominatimID:
                "Place ID {!s} not unique. Found {} entries.".format(self, cur.rowcount)
 
         return cur.fetchone()[0]
+
+
+class DBRow:
+    """ Represents a row from a database and offers comparison functions.
+    """
+    def __init__(self, nid, db_row, context):
+        self.nid = nid
+        self.db_row = db_row
+        self.context = context
+
+    def assert_row(self, row, exclude_columns):
+        """ Check that all columns of the given behave row are contained
+            in the database row. Exclude behave rows with the names given
+            in the `exclude_columns` list.
+        """
+        for name, value in zip(row.headings, row.cells):
+            if name not in exclude_columns:
+                assert self.contains(name, value), self.assert_msg(name, value)
+
+    def contains(self, name, expected):
+        """ Check that the DB row contains a column `name` with the given value.
+        """
+        if '+' in name:
+            column, field = name.split('+', 1)
+            return self._contains_hstore_value(column, field, expected)
+
+        if name == 'geometry':
+            return self._has_geometry(expected)
+
+        if name not in self.db_row:
+            return False
+
+        actual = self.db_row[name]
+
+        if expected == '-':
+            return actual is None
+
+        if name == 'name' and ':' not in expected:
+            return self._compare_column(actual[name], expected)
+
+        if 'place_id' in name:
+            return self._compare_place_id(actual, expected)
+
+        if name == 'centroid':
+            return self._has_centroid(expected)
+
+        return self._compare_column(actual, expected)
+
+    def _contains_hstore_value(self, column, field, expected):
+        if column == 'addr':
+            column = 'address'
+
+        if column not in self.db_row:
+            return False
+
+        if expected == '-':
+            return self.db_row[column] is None or field not in self.db_row[column]
+
+        if self.db_row[column] is None:
+            return False
+
+        return self._compare_column(self.db_row[column].get(field), expected)
+
+    def _compare_column(self, actual, expected):
+        if isinstance(actual, dict):
+            return actual == eval('{' + expected + '}')
+
+        return str(actual) == expected
+
+    def _compare_place_id(self, actual, expected):
+        if expected == '0':
+            return actual == 0
+
+        with self.context.db.cursor() as cur:
+            return NominatimID(expected).get_place_id(cur) == actual
+
+    def _has_centroid(self, expected):
+        if expected == 'in geometry':
+            with self.context.db.cursor() as cur:
+                cur.execute("""SELECT ST_Within(ST_SetSRID(ST_Point({cx}, {cy}), 4326),
+                                 ST_SetSRID('{geomtxt}'::geometry, 4326))""".format(**self.db_row))
+                return cur.fetchone()[0]
+
+        x, y = expected.split(' ')
+        return Almost(float(x)) == self.db_row['cx'] and Almost(float(y)) == self.db_row['cy']
+
+    def _has_geometry(self, expected):
+        geom = self.context.osm.parse_geometry(expected, self.context.scene)
+        with self.context.db.cursor() as cur:
+            cur.execute("""SELECT ST_Equals(ST_SnapToGrid({}, 0.00001, 0.00001),
+                                   ST_SnapToGrid(ST_SetSRID('{}'::geometry, 4326), 0.00001, 0.00001))""".format(
+                            geom, self.db_row['geomtxt']))
+            return cur.fetchone()[0]
+
+    def assert_msg(self, name, value):
+        """ Return a string with an informative message for a failed compare.
+        """
+        msg = "\nBad column '{}' in row '{!s}'.".format(name, self.nid)
+        actual = self._get_actual(name)
+        if actual is not None:
+            msg += " Expected: {}, got: {}.".format(value, actual)
+        else:
+            msg += " No such column."
+
+        return msg + "\nFull DB row: {}".format(json.dumps(dict(self.db_row), indent=4, default=str))
+
+    def _get_actual(self, name):
+        if '+' in name:
+            column, field = name.split('+', 1)
+            if column == 'addr':
+                column = 'address'
+            return (self.db_row.get(column) or {}).get(field)
+
+        if name == 'geometry':
+            return self.db_row['geomtxt']
+
+        if name not in self.db_row:
+            return None
+
+        if name == 'centroid':
+            return "POINT({cx} {cy})".format(**self.db_row)
+
+        actual = self.db_row[name]
+
+        if 'place_id' in name:
+            if actual is None:
+                return ''
+
+            if actual == 0:
+                return "place ID 0"
+
+            with self.context.db.cursor() as cur:
+                cur.execute("""SELECT osm_type, osm_id, class
+                               FROM placex WHERE place_id = %s""",
+                            (actual, ))
+
+                if cur.rowcount == 1:
+                    return "{0[0]}{0[1]}:{0[2]}".format(cur.fetchone())
+
+                return "[place ID {} not found]".format(actual)
+
+        return actual