]> git.openstreetmap.org Git - nominatim.git/blobdiff - test/bdd/steps/steps_db_ops.py
Merge remote-tracking branch 'upstream/master'
[nominatim.git] / test / bdd / steps / steps_db_ops.py
index 493b40cc333c57610487d4ac2f76ddf33684e686..a0dd9b348e7f60681ad56ed9618f45972d0da945 100644 (file)
@@ -1,24 +1,25 @@
-# SPDX-License-Identifier: GPL-2.0-only
+# SPDX-License-Identifier: GPL-3.0-or-later
 #
 # This file is part of Nominatim. (https://nominatim.org)
 #
-# Copyright (C) 2022 by the Nominatim developer community.
+# Copyright (C) 2024 by the Nominatim developer community.
 # For a full list of authors see the git log.
 import logging
 from itertools import chain
 
-import psycopg2.extras
+import psycopg
+from psycopg import sql as pysql
 
 from place_inserter import PlaceColumn
 from table_compare import NominatimID, DBRow
 
-from nominatim.indexer import indexer
-from nominatim.tokenizer import factory as tokenizer_factory
+from nominatim_db.indexer import indexer
+from nominatim_db.tokenizer import factory as tokenizer_factory
 
 def check_database_integrity(context):
     """ Check some generic constraints on the tables.
     """
-    with context.db.cursor() as cur:
+    with context.db.cursor(row_factory=psycopg.rows.tuple_row) as cur:
         # place_addressline should not have duplicate (place_id, address_place_id)
         cur.execute("""SELECT count(*) FROM
                         (SELECT place_id, address_place_id, count(*) as c
@@ -52,33 +53,52 @@ def add_data_to_planet_relations(context):
         for tests on data that looks up members.
     """
     with context.db.cursor() as cur:
-        for r in context.table:
-            last_node = 0
-            last_way = 0
-            parts = []
-            if r['members']:
-                members = []
-                for m in r['members'].split(','):
-                    mid = NominatimID(m)
-                    if mid.typ == 'N':
-                        parts.insert(last_node, int(mid.oid))
-                        last_node += 1
-                        last_way += 1
-                    elif mid.typ == 'W':
-                        parts.insert(last_way, int(mid.oid))
-                        last_way += 1
-                    else:
-                        parts.append(int(mid.oid))
-
-                    members.extend((mid.typ.lower() + mid.oid, mid.cls or ''))
-            else:
-                members = None
-
-            tags = chain.from_iterable([(h[5:], r[h]) for h in r.headings if h.startswith("tags+")])
-
-            cur.execute("""INSERT INTO planet_osm_rels (id, way_off, rel_off, parts, members, tags)
-                           VALUES (%s, %s, %s, %s, %s, %s)""",
-                        (r['id'], last_node, last_way, parts, members, list(tags)))
+        cur.execute("SELECT value FROM osm2pgsql_properties WHERE property = 'db_format'")
+        row = cur.fetchone()
+        if row is None or row['value'] == '1':
+            for r in context.table:
+                last_node = 0
+                last_way = 0
+                parts = []
+                if r['members']:
+                    members = []
+                    for m in r['members'].split(','):
+                        mid = NominatimID(m)
+                        if mid.typ == 'N':
+                            parts.insert(last_node, int(mid.oid))
+                            last_node += 1
+                            last_way += 1
+                        elif mid.typ == 'W':
+                            parts.insert(last_way, int(mid.oid))
+                            last_way += 1
+                        else:
+                            parts.append(int(mid.oid))
+
+                        members.extend((mid.typ.lower() + mid.oid, mid.cls or ''))
+                else:
+                    members = None
+
+                tags = chain.from_iterable([(h[5:], r[h]) for h in r.headings if h.startswith("tags+")])
+
+                cur.execute("""INSERT INTO planet_osm_rels (id, way_off, rel_off, parts, members, tags)
+                               VALUES (%s, %s, %s, %s, %s, %s)""",
+                            (r['id'], last_node, last_way, parts, members, list(tags)))
+        else:
+            for r in context.table:
+                if r['members']:
+                    members = []
+                    for m in r['members'].split(','):
+                        mid = NominatimID(m)
+                        members.append({'ref': mid.oid, 'role': mid.cls or '', 'type': mid.typ})
+                else:
+                    members = []
+
+                tags = {h[5:]: r[h] for h in r.headings if h.startswith("tags+")}
+
+                cur.execute("""INSERT INTO planet_osm_rels (id, tags, members)
+                               VALUES (%s, %s, %s)""",
+                            (r['id'], psycopg.types.json.Json(tags),
+                             psycopg.types.json.Json(members)))
 
 @given("the ways")
 def add_data_to_planet_ways(context):
@@ -86,12 +106,19 @@ def add_data_to_planet_ways(context):
         tests on that that looks up node ids in this table.
     """
     with context.db.cursor() as cur:
+        cur.execute("SELECT value FROM osm2pgsql_properties WHERE property = 'db_format'")
+        row = cur.fetchone()
+        json_tags = row is not None and row['value'] != '1'
         for r in context.table:
-            tags = chain.from_iterable([(h[5:], r[h]) for h in r.headings if h.startswith("tags+")])
+            if json_tags:
+                tags = psycopg.types.json.Json({h[5:]: r[h] for h in r.headings if h.startswith("tags+")})
+            else:
+                tags = list(chain.from_iterable([(h[5:], r[h])
+                                                 for h in r.headings if h.startswith("tags+")]))
             nodes = [ int(x.strip()) for x in r['nodes'].split(',') ]
 
             cur.execute("INSERT INTO planet_osm_ways (id, nodes, tags) VALUES (%s, %s, %s)",
-                        (r['id'], nodes, list(tags)))
+                        (r['id'], nodes, tags))
 
 ################################ WHEN ##################################
 
@@ -171,7 +198,7 @@ def check_place_contents(context, table, exact):
         expected rows are expected to be present with at least one database row.
         When 'exactly' is given, there must not be additional rows in the database.
     """
-    with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+    with context.db.cursor() as cur:
         expected_content = set()
         for row in context.table:
             nid = NominatimID(row['object'])
@@ -189,8 +216,9 @@ def check_place_contents(context, table, exact):
                 DBRow(nid, res, context).assert_row(row, ['object'])
 
         if exact:
-            cur.execute('SELECT osm_type, osm_id, class from {}'.format(table))
-            actual = set([(r[0], r[1], r[2]) for r in cur])
+            cur.execute(pysql.SQL('SELECT osm_type, osm_id, class from')
+                        + pysql.Identifier(table))
+            actual = set([(r['osm_type'], r['osm_id'], r['class']) for r in cur])
             assert expected_content == actual, \
                    f"Missing entries: {expected_content - actual}\n" \
                    f"Not expected in table: {actual - expected_content}"
@@ -201,7 +229,7 @@ def check_place_has_entry(context, table, oid):
     """ Ensure that no database row for the given object exists. The ID
         must be of the form '<NRW><osm id>[:<class>]'.
     """
-    with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+    with context.db.cursor() as cur:
         NominatimID(oid).query_osm_id(cur, "SELECT * FROM %s where {}" % table)
         assert cur.rowcount == 0, \
                "Found {} entries for ID {}".format(cur.rowcount, oid)
@@ -218,7 +246,7 @@ def check_search_name_contents(context, exclude):
     tokenizer = tokenizer_factory.get_tokenizer_for_db(context.nominatim.get_test_config())
 
     with tokenizer.name_analyzer() as analyzer:
-        with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+        with context.db.cursor() as cur:
             for row in context.table:
                 nid = NominatimID(row['object'])
                 nid.row_by_place_id(cur, 'search_name',
@@ -250,7 +278,7 @@ def check_search_name_has_entry(context, oid):
     """ Check that there is noentry in the search_name table for the given
         objects. IDs are in format '<NRW><osm id>[:<class>]'.
     """
-    with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+    with context.db.cursor() as cur:
         NominatimID(oid).row_by_place_id(cur, 'search_name')
 
         assert cur.rowcount == 0, \
@@ -264,7 +292,7 @@ def check_location_postcode(context):
         All rows must be present as excepted and there must not be additional
         rows.
     """
-    with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+    with context.db.cursor() as cur:
         cur.execute("SELECT *, ST_AsText(geometry) as geomtxt FROM location_postcode")
         assert cur.rowcount == len(list(context.table)), \
             "Postcode table has {} rows, expected {}.".format(cur.rowcount, len(list(context.table)))
@@ -295,7 +323,7 @@ def check_word_table_for_postcodes(context, exclude, postcodes):
 
     plist.sort()
 
-    with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+    with context.db.cursor() as cur:
         if nctx.tokenizer != 'legacy':
             cur.execute("SELECT word FROM word WHERE type = 'P' and word = any(%s)",
                         (plist,))
@@ -304,7 +332,7 @@ def check_word_table_for_postcodes(context, exclude, postcodes):
                              and class = 'place' and type = 'postcode'""",
                         (plist,))
 
-        found = [row[0] for row in cur]
+        found = [row['word'] for row in cur]
         assert len(found) == len(set(found)), f"Duplicate rows for postcodes: {found}"
 
     if exclude:
@@ -321,7 +349,7 @@ def check_place_addressline(context):
         representing the addressee and the 'address' column, representing the
         address item.
     """
-    with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+    with context.db.cursor() as cur:
         for row in context.table:
             nid = NominatimID(row['object'])
             pid = nid.get_place_id(cur)
@@ -340,7 +368,7 @@ def check_place_addressline_exclude(context):
     """ Check that the place_addressline doesn't contain any entries for the
         given addressee/address item pairs.
     """
-    with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+    with context.db.cursor() as cur:
         for row in context.table:
             pid = NominatimID(row['object']).get_place_id(cur)
             apid = NominatimID(row['address']).get_place_id(cur, allow_empty=True)
@@ -355,7 +383,7 @@ def check_place_addressline_exclude(context):
 def check_location_property_osmline(context, oid, neg):
     """ Check that the given way is present in the interpolation table.
     """
-    with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+    with context.db.cursor() as cur:
         cur.execute("""SELECT *, ST_AsText(linegeo) as geomtxt
                        FROM location_property_osmline
                        WHERE osm_id = %s AND startnumber IS NOT NULL""",
@@ -391,7 +419,7 @@ def check_place_contents(context, exact):
         expected rows are expected to be present with at least one database row.
         When 'exactly' is given, there must not be additional rows in the database.
     """
-    with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+    with context.db.cursor() as cur:
         expected_content = set()
         for row in context.table:
             if ':' in row['object']:
@@ -421,7 +449,7 @@ def check_place_contents(context, exact):
 
         if exact:
             cur.execute('SELECT osm_id, startnumber from location_property_osmline')
-            actual = set([(r[0], r[1]) for r in cur])
+            actual = set([(r['osm_id'], r['startnumber']) for r in cur])
             assert expected_content == actual, \
                    f"Missing entries: {expected_content - actual}\n" \
                    f"Not expected in table: {actual - expected_content}"