]> git.openstreetmap.org Git - nominatim.git/blobdiff - test/bdd/steps/steps_db_ops.py
port code to psycopg3
[nominatim.git] / test / bdd / steps / steps_db_ops.py
index 441198fdd4dfe46391d8fdcffad199d39f74efb5..a0dd9b348e7f60681ad56ed9618f45972d0da945 100644 (file)
@@ -7,7 +7,8 @@
 import logging
 from itertools import chain
 
-import psycopg2.extras
+import psycopg
+from psycopg import sql as pysql
 
 from place_inserter import PlaceColumn
 from table_compare import NominatimID, DBRow
@@ -18,7 +19,7 @@ from nominatim_db.tokenizer import factory as tokenizer_factory
 def check_database_integrity(context):
     """ Check some generic constraints on the tables.
     """
-    with context.db.cursor() as cur:
+    with context.db.cursor(row_factory=psycopg.rows.tuple_row) as cur:
         # place_addressline should not have duplicate (place_id, address_place_id)
         cur.execute("""SELECT count(*) FROM
                         (SELECT place_id, address_place_id, count(*) as c
@@ -54,7 +55,7 @@ def add_data_to_planet_relations(context):
     with context.db.cursor() as cur:
         cur.execute("SELECT value FROM osm2pgsql_properties WHERE property = 'db_format'")
         row = cur.fetchone()
-        if row is None or row[0] == '1':
+        if row is None or row['value'] == '1':
             for r in context.table:
                 last_node = 0
                 last_way = 0
@@ -96,8 +97,8 @@ def add_data_to_planet_relations(context):
 
                 cur.execute("""INSERT INTO planet_osm_rels (id, tags, members)
                                VALUES (%s, %s, %s)""",
-                            (r['id'], psycopg2.extras.Json(tags),
-                             psycopg2.extras.Json(members)))
+                            (r['id'], psycopg.types.json.Json(tags),
+                             psycopg.types.json.Json(members)))
 
 @given("the ways")
 def add_data_to_planet_ways(context):
@@ -107,10 +108,10 @@ def add_data_to_planet_ways(context):
     with context.db.cursor() as cur:
         cur.execute("SELECT value FROM osm2pgsql_properties WHERE property = 'db_format'")
         row = cur.fetchone()
-        json_tags = row is not None and row[0] != '1'
+        json_tags = row is not None and row['value'] != '1'
         for r in context.table:
             if json_tags:
-                tags = psycopg2.extras.Json({h[5:]: r[h] for h in r.headings if h.startswith("tags+")})
+                tags = psycopg.types.json.Json({h[5:]: r[h] for h in r.headings if h.startswith("tags+")})
             else:
                 tags = list(chain.from_iterable([(h[5:], r[h])
                                                  for h in r.headings if h.startswith("tags+")]))
@@ -197,7 +198,7 @@ def check_place_contents(context, table, exact):
         expected rows are expected to be present with at least one database row.
         When 'exactly' is given, there must not be additional rows in the database.
     """
-    with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+    with context.db.cursor() as cur:
         expected_content = set()
         for row in context.table:
             nid = NominatimID(row['object'])
@@ -215,8 +216,9 @@ def check_place_contents(context, table, exact):
                 DBRow(nid, res, context).assert_row(row, ['object'])
 
         if exact:
-            cur.execute('SELECT osm_type, osm_id, class from {}'.format(table))
-            actual = set([(r[0], r[1], r[2]) for r in cur])
+            cur.execute(pysql.SQL('SELECT osm_type, osm_id, class from')
+                        + pysql.Identifier(table))
+            actual = set([(r['osm_type'], r['osm_id'], r['class']) for r in cur])
             assert expected_content == actual, \
                    f"Missing entries: {expected_content - actual}\n" \
                    f"Not expected in table: {actual - expected_content}"
@@ -227,7 +229,7 @@ def check_place_has_entry(context, table, oid):
     """ Ensure that no database row for the given object exists. The ID
         must be of the form '<NRW><osm id>[:<class>]'.
     """
-    with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+    with context.db.cursor() as cur:
         NominatimID(oid).query_osm_id(cur, "SELECT * FROM %s where {}" % table)
         assert cur.rowcount == 0, \
                "Found {} entries for ID {}".format(cur.rowcount, oid)
@@ -244,7 +246,7 @@ def check_search_name_contents(context, exclude):
     tokenizer = tokenizer_factory.get_tokenizer_for_db(context.nominatim.get_test_config())
 
     with tokenizer.name_analyzer() as analyzer:
-        with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+        with context.db.cursor() as cur:
             for row in context.table:
                 nid = NominatimID(row['object'])
                 nid.row_by_place_id(cur, 'search_name',
@@ -276,7 +278,7 @@ def check_search_name_has_entry(context, oid):
     """ Check that there is noentry in the search_name table for the given
         objects. IDs are in format '<NRW><osm id>[:<class>]'.
     """
-    with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+    with context.db.cursor() as cur:
         NominatimID(oid).row_by_place_id(cur, 'search_name')
 
         assert cur.rowcount == 0, \
@@ -290,7 +292,7 @@ def check_location_postcode(context):
         All rows must be present as excepted and there must not be additional
         rows.
     """
-    with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+    with context.db.cursor() as cur:
         cur.execute("SELECT *, ST_AsText(geometry) as geomtxt FROM location_postcode")
         assert cur.rowcount == len(list(context.table)), \
             "Postcode table has {} rows, expected {}.".format(cur.rowcount, len(list(context.table)))
@@ -321,7 +323,7 @@ def check_word_table_for_postcodes(context, exclude, postcodes):
 
     plist.sort()
 
-    with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+    with context.db.cursor() as cur:
         if nctx.tokenizer != 'legacy':
             cur.execute("SELECT word FROM word WHERE type = 'P' and word = any(%s)",
                         (plist,))
@@ -330,7 +332,7 @@ def check_word_table_for_postcodes(context, exclude, postcodes):
                              and class = 'place' and type = 'postcode'""",
                         (plist,))
 
-        found = [row[0] for row in cur]
+        found = [row['word'] for row in cur]
         assert len(found) == len(set(found)), f"Duplicate rows for postcodes: {found}"
 
     if exclude:
@@ -347,7 +349,7 @@ def check_place_addressline(context):
         representing the addressee and the 'address' column, representing the
         address item.
     """
-    with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+    with context.db.cursor() as cur:
         for row in context.table:
             nid = NominatimID(row['object'])
             pid = nid.get_place_id(cur)
@@ -366,7 +368,7 @@ def check_place_addressline_exclude(context):
     """ Check that the place_addressline doesn't contain any entries for the
         given addressee/address item pairs.
     """
-    with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+    with context.db.cursor() as cur:
         for row in context.table:
             pid = NominatimID(row['object']).get_place_id(cur)
             apid = NominatimID(row['address']).get_place_id(cur, allow_empty=True)
@@ -381,7 +383,7 @@ def check_place_addressline_exclude(context):
 def check_location_property_osmline(context, oid, neg):
     """ Check that the given way is present in the interpolation table.
     """
-    with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+    with context.db.cursor() as cur:
         cur.execute("""SELECT *, ST_AsText(linegeo) as geomtxt
                        FROM location_property_osmline
                        WHERE osm_id = %s AND startnumber IS NOT NULL""",
@@ -417,7 +419,7 @@ def check_place_contents(context, exact):
         expected rows are expected to be present with at least one database row.
         When 'exactly' is given, there must not be additional rows in the database.
     """
-    with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+    with context.db.cursor() as cur:
         expected_content = set()
         for row in context.table:
             if ':' in row['object']:
@@ -447,7 +449,7 @@ def check_place_contents(context, exact):
 
         if exact:
             cur.execute('SELECT osm_id, startnumber from location_property_osmline')
-            actual = set([(r[0], r[1]) for r in cur])
+            actual = set([(r['osm_id'], r['startnumber']) for r in cur])
             assert expected_content == actual, \
                    f"Missing entries: {expected_content - actual}\n" \
                    f"Not expected in table: {actual - expected_content}"