]> git.openstreetmap.org Git - nominatim.git/blobdiff - test/bdd/steps/steps_db_ops.py
lift restrictions on search with frequent terms slightly
[nominatim.git] / test / bdd / steps / steps_db_ops.py
index 441198fdd4dfe46391d8fdcffad199d39f74efb5..a0dd9b348e7f60681ad56ed9618f45972d0da945 100644 (file)
@@ -7,7 +7,8 @@
 import logging
 from itertools import chain
 
 import logging
 from itertools import chain
 
-import psycopg2.extras
+import psycopg
+from psycopg import sql as pysql
 
 from place_inserter import PlaceColumn
 from table_compare import NominatimID, DBRow
 
 from place_inserter import PlaceColumn
 from table_compare import NominatimID, DBRow
@@ -18,7 +19,7 @@ from nominatim_db.tokenizer import factory as tokenizer_factory
 def check_database_integrity(context):
     """ Check some generic constraints on the tables.
     """
 def check_database_integrity(context):
     """ Check some generic constraints on the tables.
     """
-    with context.db.cursor() as cur:
+    with context.db.cursor(row_factory=psycopg.rows.tuple_row) as cur:
         # place_addressline should not have duplicate (place_id, address_place_id)
         cur.execute("""SELECT count(*) FROM
                         (SELECT place_id, address_place_id, count(*) as c
         # place_addressline should not have duplicate (place_id, address_place_id)
         cur.execute("""SELECT count(*) FROM
                         (SELECT place_id, address_place_id, count(*) as c
@@ -54,7 +55,7 @@ def add_data_to_planet_relations(context):
     with context.db.cursor() as cur:
         cur.execute("SELECT value FROM osm2pgsql_properties WHERE property = 'db_format'")
         row = cur.fetchone()
     with context.db.cursor() as cur:
         cur.execute("SELECT value FROM osm2pgsql_properties WHERE property = 'db_format'")
         row = cur.fetchone()
-        if row is None or row[0] == '1':
+        if row is None or row['value'] == '1':
             for r in context.table:
                 last_node = 0
                 last_way = 0
             for r in context.table:
                 last_node = 0
                 last_way = 0
@@ -96,8 +97,8 @@ def add_data_to_planet_relations(context):
 
                 cur.execute("""INSERT INTO planet_osm_rels (id, tags, members)
                                VALUES (%s, %s, %s)""",
 
                 cur.execute("""INSERT INTO planet_osm_rels (id, tags, members)
                                VALUES (%s, %s, %s)""",
-                            (r['id'], psycopg2.extras.Json(tags),
-                             psycopg2.extras.Json(members)))
+                            (r['id'], psycopg.types.json.Json(tags),
+                             psycopg.types.json.Json(members)))
 
 @given("the ways")
 def add_data_to_planet_ways(context):
 
 @given("the ways")
 def add_data_to_planet_ways(context):
@@ -107,10 +108,10 @@ def add_data_to_planet_ways(context):
     with context.db.cursor() as cur:
         cur.execute("SELECT value FROM osm2pgsql_properties WHERE property = 'db_format'")
         row = cur.fetchone()
     with context.db.cursor() as cur:
         cur.execute("SELECT value FROM osm2pgsql_properties WHERE property = 'db_format'")
         row = cur.fetchone()
-        json_tags = row is not None and row[0] != '1'
+        json_tags = row is not None and row['value'] != '1'
         for r in context.table:
             if json_tags:
         for r in context.table:
             if json_tags:
-                tags = psycopg2.extras.Json({h[5:]: r[h] for h in r.headings if h.startswith("tags+")})
+                tags = psycopg.types.json.Json({h[5:]: r[h] for h in r.headings if h.startswith("tags+")})
             else:
                 tags = list(chain.from_iterable([(h[5:], r[h])
                                                  for h in r.headings if h.startswith("tags+")]))
             else:
                 tags = list(chain.from_iterable([(h[5:], r[h])
                                                  for h in r.headings if h.startswith("tags+")]))
@@ -197,7 +198,7 @@ def check_place_contents(context, table, exact):
         expected rows are expected to be present with at least one database row.
         When 'exactly' is given, there must not be additional rows in the database.
     """
         expected rows are expected to be present with at least one database row.
         When 'exactly' is given, there must not be additional rows in the database.
     """
-    with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+    with context.db.cursor() as cur:
         expected_content = set()
         for row in context.table:
             nid = NominatimID(row['object'])
         expected_content = set()
         for row in context.table:
             nid = NominatimID(row['object'])
@@ -215,8 +216,9 @@ def check_place_contents(context, table, exact):
                 DBRow(nid, res, context).assert_row(row, ['object'])
 
         if exact:
                 DBRow(nid, res, context).assert_row(row, ['object'])
 
         if exact:
-            cur.execute('SELECT osm_type, osm_id, class from {}'.format(table))
-            actual = set([(r[0], r[1], r[2]) for r in cur])
+            cur.execute(pysql.SQL('SELECT osm_type, osm_id, class from')
+                        + pysql.Identifier(table))
+            actual = set([(r['osm_type'], r['osm_id'], r['class']) for r in cur])
             assert expected_content == actual, \
                    f"Missing entries: {expected_content - actual}\n" \
                    f"Not expected in table: {actual - expected_content}"
             assert expected_content == actual, \
                    f"Missing entries: {expected_content - actual}\n" \
                    f"Not expected in table: {actual - expected_content}"
@@ -227,7 +229,7 @@ def check_place_has_entry(context, table, oid):
     """ Ensure that no database row for the given object exists. The ID
         must be of the form '<NRW><osm id>[:<class>]'.
     """
     """ Ensure that no database row for the given object exists. The ID
         must be of the form '<NRW><osm id>[:<class>]'.
     """
-    with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+    with context.db.cursor() as cur:
         NominatimID(oid).query_osm_id(cur, "SELECT * FROM %s where {}" % table)
         assert cur.rowcount == 0, \
                "Found {} entries for ID {}".format(cur.rowcount, oid)
         NominatimID(oid).query_osm_id(cur, "SELECT * FROM %s where {}" % table)
         assert cur.rowcount == 0, \
                "Found {} entries for ID {}".format(cur.rowcount, oid)
@@ -244,7 +246,7 @@ def check_search_name_contents(context, exclude):
     tokenizer = tokenizer_factory.get_tokenizer_for_db(context.nominatim.get_test_config())
 
     with tokenizer.name_analyzer() as analyzer:
     tokenizer = tokenizer_factory.get_tokenizer_for_db(context.nominatim.get_test_config())
 
     with tokenizer.name_analyzer() as analyzer:
-        with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+        with context.db.cursor() as cur:
             for row in context.table:
                 nid = NominatimID(row['object'])
                 nid.row_by_place_id(cur, 'search_name',
             for row in context.table:
                 nid = NominatimID(row['object'])
                 nid.row_by_place_id(cur, 'search_name',
@@ -276,7 +278,7 @@ def check_search_name_has_entry(context, oid):
     """ Check that there is noentry in the search_name table for the given
         objects. IDs are in format '<NRW><osm id>[:<class>]'.
     """
     """ Check that there is noentry in the search_name table for the given
         objects. IDs are in format '<NRW><osm id>[:<class>]'.
     """
-    with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+    with context.db.cursor() as cur:
         NominatimID(oid).row_by_place_id(cur, 'search_name')
 
         assert cur.rowcount == 0, \
         NominatimID(oid).row_by_place_id(cur, 'search_name')
 
         assert cur.rowcount == 0, \
@@ -290,7 +292,7 @@ def check_location_postcode(context):
         All rows must be present as excepted and there must not be additional
         rows.
     """
         All rows must be present as excepted and there must not be additional
         rows.
     """
-    with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+    with context.db.cursor() as cur:
         cur.execute("SELECT *, ST_AsText(geometry) as geomtxt FROM location_postcode")
         assert cur.rowcount == len(list(context.table)), \
             "Postcode table has {} rows, expected {}.".format(cur.rowcount, len(list(context.table)))
         cur.execute("SELECT *, ST_AsText(geometry) as geomtxt FROM location_postcode")
         assert cur.rowcount == len(list(context.table)), \
             "Postcode table has {} rows, expected {}.".format(cur.rowcount, len(list(context.table)))
@@ -321,7 +323,7 @@ def check_word_table_for_postcodes(context, exclude, postcodes):
 
     plist.sort()
 
 
     plist.sort()
 
-    with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+    with context.db.cursor() as cur:
         if nctx.tokenizer != 'legacy':
             cur.execute("SELECT word FROM word WHERE type = 'P' and word = any(%s)",
                         (plist,))
         if nctx.tokenizer != 'legacy':
             cur.execute("SELECT word FROM word WHERE type = 'P' and word = any(%s)",
                         (plist,))
@@ -330,7 +332,7 @@ def check_word_table_for_postcodes(context, exclude, postcodes):
                              and class = 'place' and type = 'postcode'""",
                         (plist,))
 
                              and class = 'place' and type = 'postcode'""",
                         (plist,))
 
-        found = [row[0] for row in cur]
+        found = [row['word'] for row in cur]
         assert len(found) == len(set(found)), f"Duplicate rows for postcodes: {found}"
 
     if exclude:
         assert len(found) == len(set(found)), f"Duplicate rows for postcodes: {found}"
 
     if exclude:
@@ -347,7 +349,7 @@ def check_place_addressline(context):
         representing the addressee and the 'address' column, representing the
         address item.
     """
         representing the addressee and the 'address' column, representing the
         address item.
     """
-    with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+    with context.db.cursor() as cur:
         for row in context.table:
             nid = NominatimID(row['object'])
             pid = nid.get_place_id(cur)
         for row in context.table:
             nid = NominatimID(row['object'])
             pid = nid.get_place_id(cur)
@@ -366,7 +368,7 @@ def check_place_addressline_exclude(context):
     """ Check that the place_addressline doesn't contain any entries for the
         given addressee/address item pairs.
     """
     """ Check that the place_addressline doesn't contain any entries for the
         given addressee/address item pairs.
     """
-    with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+    with context.db.cursor() as cur:
         for row in context.table:
             pid = NominatimID(row['object']).get_place_id(cur)
             apid = NominatimID(row['address']).get_place_id(cur, allow_empty=True)
         for row in context.table:
             pid = NominatimID(row['object']).get_place_id(cur)
             apid = NominatimID(row['address']).get_place_id(cur, allow_empty=True)
@@ -381,7 +383,7 @@ def check_place_addressline_exclude(context):
 def check_location_property_osmline(context, oid, neg):
     """ Check that the given way is present in the interpolation table.
     """
 def check_location_property_osmline(context, oid, neg):
     """ Check that the given way is present in the interpolation table.
     """
-    with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+    with context.db.cursor() as cur:
         cur.execute("""SELECT *, ST_AsText(linegeo) as geomtxt
                        FROM location_property_osmline
                        WHERE osm_id = %s AND startnumber IS NOT NULL""",
         cur.execute("""SELECT *, ST_AsText(linegeo) as geomtxt
                        FROM location_property_osmline
                        WHERE osm_id = %s AND startnumber IS NOT NULL""",
@@ -417,7 +419,7 @@ def check_place_contents(context, exact):
         expected rows are expected to be present with at least one database row.
         When 'exactly' is given, there must not be additional rows in the database.
     """
         expected rows are expected to be present with at least one database row.
         When 'exactly' is given, there must not be additional rows in the database.
     """
-    with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+    with context.db.cursor() as cur:
         expected_content = set()
         for row in context.table:
             if ':' in row['object']:
         expected_content = set()
         for row in context.table:
             if ':' in row['object']:
@@ -447,7 +449,7 @@ def check_place_contents(context, exact):
 
         if exact:
             cur.execute('SELECT osm_id, startnumber from location_property_osmline')
 
         if exact:
             cur.execute('SELECT osm_id, startnumber from location_property_osmline')
-            actual = set([(r[0], r[1]) for r in cur])
+            actual = set([(r['osm_id'], r['startnumber']) for r in cur])
             assert expected_content == actual, \
                    f"Missing entries: {expected_content - actual}\n" \
                    f"Not expected in table: {actual - expected_content}"
             assert expected_content == actual, \
                    f"Missing entries: {expected_content - actual}\n" \
                    f"Not expected in table: {actual - expected_content}"