]> git.openstreetmap.org Git - nominatim.git/blobdiff - test/bdd/steps/steps_db_ops.py
release 4.5.0.post7
[nominatim.git] / test / bdd / steps / steps_db_ops.py
index c30ee894280d4eb912a325d6669b0148e2c35d7c..fb8431d5ffae827c7c681878d57a3b94bc3375b5 100644 (file)
@@ -1,24 +1,25 @@
-# SPDX-License-Identifier: GPL-2.0-only
+# SPDX-License-Identifier: GPL-3.0-or-later
 #
 # This file is part of Nominatim. (https://nominatim.org)
 #
 #
 # This file is part of Nominatim. (https://nominatim.org)
 #
-# Copyright (C) 2022 by the Nominatim developer community.
+# Copyright (C) 2024 by the Nominatim developer community.
 # For a full list of authors see the git log.
 import logging
 from itertools import chain
 
 # For a full list of authors see the git log.
 import logging
 from itertools import chain
 
-import psycopg2.extras
+import psycopg
+from psycopg import sql as pysql
 
 from place_inserter import PlaceColumn
 from table_compare import NominatimID, DBRow
 
 
 from place_inserter import PlaceColumn
 from table_compare import NominatimID, DBRow
 
-from nominatim.indexer import indexer
-from nominatim.tokenizer import factory as tokenizer_factory
+from nominatim_db.indexer import indexer
+from nominatim_db.tokenizer import factory as tokenizer_factory
 
 def check_database_integrity(context):
     """ Check some generic constraints on the tables.
     """
 
 def check_database_integrity(context):
     """ Check some generic constraints on the tables.
     """
-    with context.db.cursor() as cur:
+    with context.db.cursor(row_factory=psycopg.rows.tuple_row) as cur:
         # place_addressline should not have duplicate (place_id, address_place_id)
         cur.execute("""SELECT count(*) FROM
                         (SELECT place_id, address_place_id, count(*) as c
         # place_addressline should not have duplicate (place_id, address_place_id)
         cur.execute("""SELECT count(*) FROM
                         (SELECT place_id, address_place_id, count(*) as c
@@ -27,9 +28,8 @@ def check_database_integrity(context):
         assert cur.fetchone()[0] == 0, "Duplicates found in place_addressline"
 
         # word table must not have empty word_tokens
         assert cur.fetchone()[0] == 0, "Duplicates found in place_addressline"
 
         # word table must not have empty word_tokens
-        if context.nominatim.tokenizer != 'legacy':
-            cur.execute("SELECT count(*) FROM word WHERE word_token = ''")
-            assert cur.fetchone()[0] == 0, "Empty word tokens found in word table"
+        cur.execute("SELECT count(*) FROM word WHERE word_token = ''")
+        assert cur.fetchone()[0] == 0, "Empty word tokens found in word table"
 
 
 
 
 
 
@@ -54,7 +54,7 @@ def add_data_to_planet_relations(context):
     with context.db.cursor() as cur:
         cur.execute("SELECT value FROM osm2pgsql_properties WHERE property = 'db_format'")
         row = cur.fetchone()
     with context.db.cursor() as cur:
         cur.execute("SELECT value FROM osm2pgsql_properties WHERE property = 'db_format'")
         row = cur.fetchone()
-        if row is None or row[0] == '1':
+        if row is None or row['value'] == '1':
             for r in context.table:
                 last_node = 0
                 last_way = 0
             for r in context.table:
                 last_node = 0
                 last_way = 0
@@ -96,8 +96,8 @@ def add_data_to_planet_relations(context):
 
                 cur.execute("""INSERT INTO planet_osm_rels (id, tags, members)
                                VALUES (%s, %s, %s)""",
 
                 cur.execute("""INSERT INTO planet_osm_rels (id, tags, members)
                                VALUES (%s, %s, %s)""",
-                            (r['id'], psycopg2.extras.Json(tags),
-                             psycopg2.extras.Json(members)))
+                            (r['id'], psycopg.types.json.Json(tags),
+                             psycopg.types.json.Json(members)))
 
 @given("the ways")
 def add_data_to_planet_ways(context):
 
 @given("the ways")
 def add_data_to_planet_ways(context):
@@ -107,10 +107,10 @@ def add_data_to_planet_ways(context):
     with context.db.cursor() as cur:
         cur.execute("SELECT value FROM osm2pgsql_properties WHERE property = 'db_format'")
         row = cur.fetchone()
     with context.db.cursor() as cur:
         cur.execute("SELECT value FROM osm2pgsql_properties WHERE property = 'db_format'")
         row = cur.fetchone()
-        json_tags = row is not None and row[0] != '1'
+        json_tags = row is not None and row['value'] != '1'
         for r in context.table:
             if json_tags:
         for r in context.table:
             if json_tags:
-                tags = psycopg2.extras.Json({h[5:]: r[h] for h in r.headings if h.startswith("tags+")})
+                tags = psycopg.types.json.Json({h[5:]: r[h] for h in r.headings if h.startswith("tags+")})
             else:
                 tags = list(chain.from_iterable([(h[5:], r[h])
                                                  for h in r.headings if h.startswith("tags+")]))
             else:
                 tags = list(chain.from_iterable([(h[5:], r[h])
                                                  for h in r.headings if h.startswith("tags+")]))
@@ -197,7 +197,7 @@ def check_place_contents(context, table, exact):
         expected rows are expected to be present with at least one database row.
         When 'exactly' is given, there must not be additional rows in the database.
     """
         expected rows are expected to be present with at least one database row.
         When 'exactly' is given, there must not be additional rows in the database.
     """
-    with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+    with context.db.cursor() as cur:
         expected_content = set()
         for row in context.table:
             nid = NominatimID(row['object'])
         expected_content = set()
         for row in context.table:
             nid = NominatimID(row['object'])
@@ -215,8 +215,9 @@ def check_place_contents(context, table, exact):
                 DBRow(nid, res, context).assert_row(row, ['object'])
 
         if exact:
                 DBRow(nid, res, context).assert_row(row, ['object'])
 
         if exact:
-            cur.execute('SELECT osm_type, osm_id, class from {}'.format(table))
-            actual = set([(r[0], r[1], r[2]) for r in cur])
+            cur.execute(pysql.SQL('SELECT osm_type, osm_id, class from')
+                        + pysql.Identifier(table))
+            actual = set([(r['osm_type'], r['osm_id'], r['class']) for r in cur])
             assert expected_content == actual, \
                    f"Missing entries: {expected_content - actual}\n" \
                    f"Not expected in table: {actual - expected_content}"
             assert expected_content == actual, \
                    f"Missing entries: {expected_content - actual}\n" \
                    f"Not expected in table: {actual - expected_content}"
@@ -227,7 +228,7 @@ def check_place_has_entry(context, table, oid):
     """ Ensure that no database row for the given object exists. The ID
         must be of the form '<NRW><osm id>[:<class>]'.
     """
     """ Ensure that no database row for the given object exists. The ID
         must be of the form '<NRW><osm id>[:<class>]'.
     """
-    with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+    with context.db.cursor() as cur:
         NominatimID(oid).query_osm_id(cur, "SELECT * FROM %s where {}" % table)
         assert cur.rowcount == 0, \
                "Found {} entries for ID {}".format(cur.rowcount, oid)
         NominatimID(oid).query_osm_id(cur, "SELECT * FROM %s where {}" % table)
         assert cur.rowcount == 0, \
                "Found {} entries for ID {}".format(cur.rowcount, oid)
@@ -244,7 +245,7 @@ def check_search_name_contents(context, exclude):
     tokenizer = tokenizer_factory.get_tokenizer_for_db(context.nominatim.get_test_config())
 
     with tokenizer.name_analyzer() as analyzer:
     tokenizer = tokenizer_factory.get_tokenizer_for_db(context.nominatim.get_test_config())
 
     with tokenizer.name_analyzer() as analyzer:
-        with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+        with context.db.cursor() as cur:
             for row in context.table:
                 nid = NominatimID(row['object'])
                 nid.row_by_place_id(cur, 'search_name',
             for row in context.table:
                 nid = NominatimID(row['object'])
                 nid.row_by_place_id(cur, 'search_name',
@@ -276,7 +277,7 @@ def check_search_name_has_entry(context, oid):
     """ Check that there is noentry in the search_name table for the given
         objects. IDs are in format '<NRW><osm id>[:<class>]'.
     """
     """ Check that there is noentry in the search_name table for the given
         objects. IDs are in format '<NRW><osm id>[:<class>]'.
     """
-    with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+    with context.db.cursor() as cur:
         NominatimID(oid).row_by_place_id(cur, 'search_name')
 
         assert cur.rowcount == 0, \
         NominatimID(oid).row_by_place_id(cur, 'search_name')
 
         assert cur.rowcount == 0, \
@@ -290,7 +291,7 @@ def check_location_postcode(context):
         All rows must be present as excepted and there must not be additional
         rows.
     """
         All rows must be present as excepted and there must not be additional
         rows.
     """
-    with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+    with context.db.cursor() as cur:
         cur.execute("SELECT *, ST_AsText(geometry) as geomtxt FROM location_postcode")
         assert cur.rowcount == len(list(context.table)), \
             "Postcode table has {} rows, expected {}.".format(cur.rowcount, len(list(context.table)))
         cur.execute("SELECT *, ST_AsText(geometry) as geomtxt FROM location_postcode")
         assert cur.rowcount == len(list(context.table)), \
             "Postcode table has {} rows, expected {}.".format(cur.rowcount, len(list(context.table)))
@@ -321,16 +322,11 @@ def check_word_table_for_postcodes(context, exclude, postcodes):
 
     plist.sort()
 
 
     plist.sort()
 
-    with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
-        if nctx.tokenizer != 'legacy':
-            cur.execute("SELECT word FROM word WHERE type = 'P' and word = any(%s)",
-                        (plist,))
-        else:
-            cur.execute("""SELECT word FROM word WHERE word = any(%s)
-                             and class = 'place' and type = 'postcode'""",
-                        (plist,))
+    with context.db.cursor() as cur:
+        cur.execute("SELECT word FROM word WHERE type = 'P' and word = any(%s)",
+                    (plist,))
 
 
-        found = [row[0] for row in cur]
+        found = [row['word'] for row in cur]
         assert len(found) == len(set(found)), f"Duplicate rows for postcodes: {found}"
 
     if exclude:
         assert len(found) == len(set(found)), f"Duplicate rows for postcodes: {found}"
 
     if exclude:
@@ -347,7 +343,7 @@ def check_place_addressline(context):
         representing the addressee and the 'address' column, representing the
         address item.
     """
         representing the addressee and the 'address' column, representing the
         address item.
     """
-    with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+    with context.db.cursor() as cur:
         for row in context.table:
             nid = NominatimID(row['object'])
             pid = nid.get_place_id(cur)
         for row in context.table:
             nid = NominatimID(row['object'])
             pid = nid.get_place_id(cur)
@@ -366,7 +362,7 @@ def check_place_addressline_exclude(context):
     """ Check that the place_addressline doesn't contain any entries for the
         given addressee/address item pairs.
     """
     """ Check that the place_addressline doesn't contain any entries for the
         given addressee/address item pairs.
     """
-    with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+    with context.db.cursor() as cur:
         for row in context.table:
             pid = NominatimID(row['object']).get_place_id(cur)
             apid = NominatimID(row['address']).get_place_id(cur, allow_empty=True)
         for row in context.table:
             pid = NominatimID(row['object']).get_place_id(cur)
             apid = NominatimID(row['address']).get_place_id(cur, allow_empty=True)
@@ -381,7 +377,7 @@ def check_place_addressline_exclude(context):
 def check_location_property_osmline(context, oid, neg):
     """ Check that the given way is present in the interpolation table.
     """
 def check_location_property_osmline(context, oid, neg):
     """ Check that the given way is present in the interpolation table.
     """
-    with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+    with context.db.cursor() as cur:
         cur.execute("""SELECT *, ST_AsText(linegeo) as geomtxt
                        FROM location_property_osmline
                        WHERE osm_id = %s AND startnumber IS NOT NULL""",
         cur.execute("""SELECT *, ST_AsText(linegeo) as geomtxt
                        FROM location_property_osmline
                        WHERE osm_id = %s AND startnumber IS NOT NULL""",
@@ -417,7 +413,7 @@ def check_place_contents(context, exact):
         expected rows are expected to be present with at least one database row.
         When 'exactly' is given, there must not be additional rows in the database.
     """
         expected rows are expected to be present with at least one database row.
         When 'exactly' is given, there must not be additional rows in the database.
     """
-    with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+    with context.db.cursor() as cur:
         expected_content = set()
         for row in context.table:
             if ':' in row['object']:
         expected_content = set()
         for row in context.table:
             if ':' in row['object']:
@@ -447,7 +443,7 @@ def check_place_contents(context, exact):
 
         if exact:
             cur.execute('SELECT osm_id, startnumber from location_property_osmline')
 
         if exact:
             cur.execute('SELECT osm_id, startnumber from location_property_osmline')
-            actual = set([(r[0], r[1]) for r in cur])
+            actual = set([(r['osm_id'], r['startnumber']) for r in cur])
             assert expected_content == actual, \
                    f"Missing entries: {expected_content - actual}\n" \
                    f"Not expected in table: {actual - expected_content}"
             assert expected_content == actual, \
                    f"Missing entries: {expected_content - actual}\n" \
                    f"Not expected in table: {actual - expected_content}"