]> git.openstreetmap.org Git - nominatim.git/blobdiff - test/bdd/steps/steps_db_ops.py
enable flake for bdd test code
[nominatim.git] / test / bdd / steps / steps_db_ops.py
index c30ee894280d4eb912a325d6669b0148e2c35d7c..8b62cbc6b3d6bd06baaf001a7daeadce97c22e05 100644 (file)
@@ -1,24 +1,24 @@
-# SPDX-License-Identifier: GPL-2.0-only
+# SPDX-License-Identifier: GPL-3.0-or-later
 #
 # This file is part of Nominatim. (https://nominatim.org)
 #
-# Copyright (C) 2022 by the Nominatim developer community.
+# Copyright (C) 2025 by the Nominatim developer community.
 # For a full list of authors see the git log.
-import logging
 from itertools import chain
 
-import psycopg2.extras
+import psycopg
+from psycopg import sql as pysql
 
 from place_inserter import PlaceColumn
 from table_compare import NominatimID, DBRow
 
-from nominatim.indexer import indexer
-from nominatim.tokenizer import factory as tokenizer_factory
+from nominatim_db.tokenizer import factory as tokenizer_factory
+
 
 def check_database_integrity(context):
     """ Check some generic constraints on the tables.
     """
-    with context.db.cursor() as cur:
+    with context.db.cursor(row_factory=psycopg.rows.tuple_row) as cur:
         # place_addressline should not have duplicate (place_id, address_place_id)
         cur.execute("""SELECT count(*) FROM
                         (SELECT place_id, address_place_id, count(*) as c
@@ -27,14 +27,12 @@ def check_database_integrity(context):
         assert cur.fetchone()[0] == 0, "Duplicates found in place_addressline"
 
         # word table must not have empty word_tokens
-        if context.nominatim.tokenizer != 'legacy':
-            cur.execute("SELECT count(*) FROM word WHERE word_token = ''")
-            assert cur.fetchone()[0] == 0, "Empty word tokens found in word table"
+        cur.execute("SELECT count(*) FROM word WHERE word_token = ''")
+        assert cur.fetchone()[0] == 0, "Empty word tokens found in word table"
 
+# GIVEN ##################################
 
 
-################################ GIVEN ##################################
-
 @given("the (?P<named>named )?places")
 def add_data_to_place_table(context, named):
     """ Add entries into the place table. 'named places' makes sure that
@@ -46,6 +44,7 @@ def add_data_to_place_table(context, named):
             PlaceColumn(context).add_row(row, named is not None).db_insert(cur)
         cur.execute('ALTER TABLE place ENABLE TRIGGER place_before_insert')
 
+
 @given("the relations")
 def add_data_to_planet_relations(context):
     """ Add entries into the osm2pgsql relation middle table. This is needed
@@ -54,7 +53,7 @@ def add_data_to_planet_relations(context):
     with context.db.cursor() as cur:
         cur.execute("SELECT value FROM osm2pgsql_properties WHERE property = 'db_format'")
         row = cur.fetchone()
-        if row is None or row[0] == '1':
+        if row is None or row['value'] == '1':
             for r in context.table:
                 last_node = 0
                 last_way = 0
@@ -77,9 +76,11 @@ def add_data_to_planet_relations(context):
                 else:
                     members = None
 
-                tags = chain.from_iterable([(h[5:], r[h]) for h in r.headings if h.startswith("tags+")])
+                tags = chain.from_iterable([(h[5:], r[h]) for h in r.headings
+                                            if h.startswith("tags+")])
 
-                cur.execute("""INSERT INTO planet_osm_rels (id, way_off, rel_off, parts, members, tags)
+                cur.execute("""INSERT INTO planet_osm_rels (id, way_off, rel_off,
+                                                            parts, members, tags)
                                VALUES (%s, %s, %s, %s, %s, %s)""",
                             (r['id'], last_node, last_way, parts, members, list(tags)))
         else:
@@ -96,8 +97,9 @@ def add_data_to_planet_relations(context):
 
                 cur.execute("""INSERT INTO planet_osm_rels (id, tags, members)
                                VALUES (%s, %s, %s)""",
-                            (r['id'], psycopg2.extras.Json(tags),
-                             psycopg2.extras.Json(members)))
+                            (r['id'], psycopg.types.json.Json(tags),
+                             psycopg.types.json.Json(members)))
+
 
 @given("the ways")
 def add_data_to_planet_ways(context):
@@ -107,19 +109,21 @@ def add_data_to_planet_ways(context):
     with context.db.cursor() as cur:
         cur.execute("SELECT value FROM osm2pgsql_properties WHERE property = 'db_format'")
         row = cur.fetchone()
-        json_tags = row is not None and row[0] != '1'
+        json_tags = row is not None and row['value'] != '1'
         for r in context.table:
             if json_tags:
-                tags = psycopg2.extras.Json({h[5:]: r[h] for h in r.headings if h.startswith("tags+")})
+                tags = psycopg.types.json.Json({h[5:]: r[h] for h in r.headings
+                                                if h.startswith("tags+")})
             else:
                 tags = list(chain.from_iterable([(h[5:], r[h])
                                                  for h in r.headings if h.startswith("tags+")]))
-            nodes = [ int(x.strip()) for x in r['nodes'].split(',') ]
+            nodes = [int(x.strip()) for x in r['nodes'].split(',')]
 
             cur.execute("INSERT INTO planet_osm_ways (id, nodes, tags) VALUES (%s, %s, %s)",
                         (r['id'], nodes, tags))
 
-################################ WHEN ##################################
+# WHEN ##################################
+
 
 @when("importing")
 def import_and_index_data_from_place_table(context):
@@ -136,6 +140,7 @@ def import_and_index_data_from_place_table(context):
     # itself.
     context.log_capture.buffer.clear()
 
+
 @when("updating places")
 def update_place_table(context):
     """ Update the place table with the given data. Also runs all triggers
@@ -164,6 +169,7 @@ def update_postcodes(context):
     """
     context.nominatim.run_nominatim('refresh', '--postcodes')
 
+
 @when("marking for delete (?P<oids>.*)")
 def delete_places(context, oids):
     """ Remove entries from the place table. Multiple ids may be given
@@ -184,7 +190,8 @@ def delete_places(context, oids):
     # itself.
     context.log_capture.buffer.clear()
 
-################################ THEN ##################################
+# THEN ##################################
+
 
 @then("(?P<table>placex|place) contains(?P<exact> exactly)?")
 def check_place_contents(context, table, exact):
@@ -197,11 +204,12 @@ def check_place_contents(context, table, exact):
         expected rows are expected to be present with at least one database row.
         When 'exactly' is given, there must not be additional rows in the database.
     """
-    with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+    with context.db.cursor() as cur:
         expected_content = set()
         for row in context.table:
             nid = NominatimID(row['object'])
-            query = 'SELECT *, ST_AsText(geometry) as geomtxt, ST_GeometryType(geometry) as geometrytype'
+            query = """SELECT *, ST_AsText(geometry) as geomtxt,
+                              ST_GeometryType(geometry) as geometrytype """
             if table == 'placex':
                 query += ' ,ST_X(centroid) as cx, ST_Y(centroid) as cy'
             query += " FROM %s WHERE {}" % (table, )
@@ -215,8 +223,9 @@ def check_place_contents(context, table, exact):
                 DBRow(nid, res, context).assert_row(row, ['object'])
 
         if exact:
-            cur.execute('SELECT osm_type, osm_id, class from {}'.format(table))
-            actual = set([(r[0], r[1], r[2]) for r in cur])
+            cur.execute(pysql.SQL('SELECT osm_type, osm_id, class from')
+                        + pysql.Identifier(table))
+            actual = set([(r['osm_type'], r['osm_id'], r['class']) for r in cur])
             assert expected_content == actual, \
                    f"Missing entries: {expected_content - actual}\n" \
                    f"Not expected in table: {actual - expected_content}"
@@ -227,7 +236,7 @@ def check_place_has_entry(context, table, oid):
     """ Ensure that no database row for the given object exists. The ID
         must be of the form '<NRW><osm id>[:<class>]'.
     """
-    with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+    with context.db.cursor() as cur:
         NominatimID(oid).query_osm_id(cur, "SELECT * FROM %s where {}" % table)
         assert cur.rowcount == 0, \
                "Found {} entries for ID {}".format(cur.rowcount, oid)
@@ -244,7 +253,7 @@ def check_search_name_contents(context, exclude):
     tokenizer = tokenizer_factory.get_tokenizer_for_db(context.nominatim.get_test_config())
 
     with tokenizer.name_analyzer() as analyzer:
-        with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+        with context.db.cursor() as cur:
             for row in context.table:
                 nid = NominatimID(row['object'])
                 nid.row_by_place_id(cur, 'search_name',
@@ -260,28 +269,30 @@ def check_search_name_contents(context, exclude):
 
                             if not exclude:
                                 assert len(tokens) >= len(items), \
-                                       "No word entry found for {}. Entries found: {!s}".format(value, len(tokens))
+                                    f"No word entry found for {value}. Entries found: {len(tokens)}"
                             for word, token, wid in tokens:
                                 if exclude:
                                     assert wid not in res[name], \
-                                           "Found term for {}/{}: {}".format(nid, name, wid)
+                                        "Found term for {}/{}: {}".format(nid, name, wid)
                                 else:
                                     assert wid in res[name], \
-                                           "Missing term for {}/{}: {}".format(nid, name, wid)
+                                        "Missing term for {}/{}: {}".format(nid, name, wid)
                         elif name != 'object':
                             assert db_row.contains(name, value), db_row.assert_msg(name, value)
 
+
 @then("search_name has no entry for (?P<oid>.*)")
 def check_search_name_has_entry(context, oid):
     """ Check that there is noentry in the search_name table for the given
         objects. IDs are in format '<NRW><osm id>[:<class>]'.
     """
-    with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+    with context.db.cursor() as cur:
         NominatimID(oid).row_by_place_id(cur, 'search_name')
 
         assert cur.rowcount == 0, \
                "Found {} entries for ID {}".format(cur.rowcount, oid)
 
+
 @then("location_postcode contains exactly")
 def check_location_postcode(context):
     """ Check full contents for location_postcode table. Each row represents a table row
@@ -290,24 +301,25 @@ def check_location_postcode(context):
         All rows must be present as excepted and there must not be additional
         rows.
     """
-    with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+    with context.db.cursor() as cur:
         cur.execute("SELECT *, ST_AsText(geometry) as geomtxt FROM location_postcode")
         assert cur.rowcount == len(list(context.table)), \
-            "Postcode table has {} rows, expected {}.".format(cur.rowcount, len(list(context.table)))
+            "Postcode table has {cur.rowcount} rows, expected {len(list(context.table))}."
 
         results = {}
         for row in cur:
             key = (row['country_code'], row['postcode'])
             assert key not in results, "Postcode table has duplicate entry: {}".format(row)
-            results[key] = DBRow((row['country_code'],row['postcode']), row, context)
+            results[key] = DBRow((row['country_code'], row['postcode']), row, context)
 
         for row in context.table:
-            db_row = results.get((row['country'],row['postcode']))
+            db_row = results.get((row['country'], row['postcode']))
             assert db_row is not None, \
                 f"Missing row for country '{row['country']}' postcode '{row['postcode']}'."
 
             db_row.assert_row(row, ('country', 'postcode'))
 
+
 @then("there are(?P<exclude> no)? word tokens for postcodes (?P<postcodes>.*)")
 def check_word_table_for_postcodes(context, exclude, postcodes):
     """ Check that the tokenizer produces postcode tokens for the given
@@ -321,23 +333,19 @@ def check_word_table_for_postcodes(context, exclude, postcodes):
 
     plist.sort()
 
-    with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
-        if nctx.tokenizer != 'legacy':
-            cur.execute("SELECT word FROM word WHERE type = 'P' and word = any(%s)",
-                        (plist,))
-        else:
-            cur.execute("""SELECT word FROM word WHERE word = any(%s)
-                             and class = 'place' and type = 'postcode'""",
-                        (plist,))
+    with context.db.cursor() as cur:
+        cur.execute("SELECT word FROM word WHERE type = 'P' and word = any(%s)",
+                    (plist,))
 
-        found = [row[0] for row in cur]
+        found = [row['word'] for row in cur]
         assert len(found) == len(set(found)), f"Duplicate rows for postcodes: {found}"
 
     if exclude:
         assert len(found) == 0, f"Unexpected postcodes: {found}"
     else:
         assert set(found) == set(plist), \
-        f"Missing postcodes {set(plist) - set(found)}. Found: {found}"
+            f"Missing postcodes {set(plist) - set(found)}. Found: {found}"
+
 
 @then("place_addressline contains")
 def check_place_addressline(context):
@@ -347,7 +355,7 @@ def check_place_addressline(context):
         representing the addressee and the 'address' column, representing the
         address item.
     """
-    with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+    with context.db.cursor() as cur:
         for row in context.table:
             nid = NominatimID(row['object'])
             pid = nid.get_place_id(cur)
@@ -356,17 +364,18 @@ def check_place_addressline(context):
                             WHERE place_id = %s AND address_place_id = %s""",
                         (pid, apid))
             assert cur.rowcount > 0, \
-                        "No rows found for place %s and address %s" % (row['object'], row['address'])
+                f"No rows found for place {row['object']} and address {row['address']}."
 
             for res in cur:
                 DBRow(nid, res, context).assert_row(row, ('address', 'object'))
 
+
 @then("place_addressline doesn't contain")
 def check_place_addressline_exclude(context):
     """ Check that the place_addressline doesn't contain any entries for the
         given addressee/address item pairs.
     """
-    with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+    with context.db.cursor() as cur:
         for row in context.table:
             pid = NominatimID(row['object']).get_place_id(cur)
             apid = NominatimID(row['address']).get_place_id(cur, allow_empty=True)
@@ -375,13 +384,14 @@ def check_place_addressline_exclude(context):
                                 WHERE place_id = %s AND address_place_id = %s""",
                             (pid, apid))
                 assert cur.rowcount == 0, \
-                    "Row found for place %s and address %s" % (row['object'], row['address'])
+                    f"Row found for place {row['object']} and address {row['address']}."
+
 
-@then("W(?P<oid>\d+) expands to(?P<neg> no)? interpolation")
+@then(r"W(?P<oid>\d+) expands to(?P<neg> no)? interpolation")
 def check_location_property_osmline(context, oid, neg):
     """ Check that the given way is present in the interpolation table.
     """
-    with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+    with context.db.cursor() as cur:
         cur.execute("""SELECT *, ST_AsText(linegeo) as geomtxt
                        FROM location_property_osmline
                        WHERE osm_id = %s AND startnumber IS NOT NULL""",
@@ -396,7 +406,7 @@ def check_location_property_osmline(context, oid, neg):
             for i in todo:
                 row = context.table[i]
                 if (int(row['start']) == res['startnumber']
-                    and int(row['end']) == res['endnumber']):
+                        and int(row['end']) == res['endnumber']):
                     todo.remove(i)
                     break
             else:
@@ -406,8 +416,9 @@ def check_location_property_osmline(context, oid, neg):
 
         assert not todo, f"Unmatched lines in table: {list(context.table[i] for i in todo)}"
 
+
 @then("location_property_osmline contains(?P<exact> exactly)?")
-def check_place_contents(context, exact):
+def check_osmline_contents(context, exact):
     """ Check contents of the interpolation table. Each row represents a table row
         and all data must match. Data not present in the expected table, may
         be arbitrary. The rows are identified via the 'object' column which must
@@ -417,7 +428,7 @@ def check_place_contents(context, exact):
         expected rows are expected to be present with at least one database row.
         When 'exactly' is given, there must not be additional rows in the database.
     """
-    with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+    with context.db.cursor() as cur:
         expected_content = set()
         for row in context.table:
             if ':' in row['object']:
@@ -447,8 +458,7 @@ def check_place_contents(context, exact):
 
         if exact:
             cur.execute('SELECT osm_id, startnumber from location_property_osmline')
-            actual = set([(r[0], r[1]) for r in cur])
+            actual = set([(r['osm_id'], r['startnumber']) for r in cur])
             assert expected_content == actual, \
                    f"Missing entries: {expected_content - actual}\n" \
                    f"Not expected in table: {actual - expected_content}"
-