]> git.openstreetmap.org Git - nominatim.git/blobdiff - test/bdd/steps/steps_db_ops.py
test: use src_dir fixture instead of self-computed paths
[nominatim.git] / test / bdd / steps / steps_db_ops.py
index c549f3eb5476e144f3921b3e1d92ea445755f528..8285f9389dfc8ee8f027950e76ce058ebf19fc14 100644 (file)
@@ -1,3 +1,4 @@
+import logging
 from itertools import chain
 
 import psycopg2.extras
 from itertools import chain
 
 import psycopg2.extras
@@ -5,6 +6,8 @@ import psycopg2.extras
 from place_inserter import PlaceColumn
 from table_compare import NominatimID, DBRow
 
 from place_inserter import PlaceColumn
 from table_compare import NominatimID, DBRow
 
+from nominatim.indexer import indexer
+from nominatim.tokenizer import factory as tokenizer_factory
 
 def check_database_integrity(context):
     """ Check some generic constraints on the tables.
 
 def check_database_integrity(context):
     """ Check some generic constraints on the tables.
@@ -84,8 +87,30 @@ def add_data_to_planet_ways(context):
 def import_and_index_data_from_place_table(context):
     """ Import data previously set up in the place table.
     """
 def import_and_index_data_from_place_table(context):
     """ Import data previously set up in the place table.
     """
+    nctx = context.nominatim
+
+    tokenizer = tokenizer_factory.create_tokenizer(nctx.get_test_config())
     context.nominatim.copy_from_place(context.db)
     context.nominatim.copy_from_place(context.db)
-    context.nominatim.run_setup_script('calculate-postcodes', 'index', 'index-noanalyse')
+
+    # XXX use tool function as soon as it is ported
+    with context.db.cursor() as cur:
+        with (context.nominatim.src_dir / 'lib-sql' / 'postcode_tables.sql').open('r') as fd:
+            cur.execute(fd.read())
+        cur.execute("""
+            INSERT INTO location_postcode
+             (place_id, indexed_status, country_code, postcode, geometry)
+            SELECT nextval('seq_place'), 1, country_code,
+                   upper(trim (both ' ' from address->'postcode')) as pc,
+                   ST_Centroid(ST_Collect(ST_Centroid(geometry)))
+              FROM placex
+             WHERE address ? 'postcode' AND address->'postcode' NOT SIMILAR TO '%(,|;)%'
+                   AND geometry IS NOT null
+             GROUP BY country_code, pc""")
+
+    # Call directly as the refresh function does not include postcodes.
+    indexer.LOG.setLevel(logging.ERROR)
+    indexer.Indexer(context.nominatim.get_libpq_dsn(), tokenizer, 1).index_full(analyse=False)
+
     check_database_integrity(context)
 
 @when("updating places")
     check_database_integrity(context)
 
 @when("updating places")
@@ -93,8 +118,7 @@ def update_place_table(context):
     """ Update the place table with the given data. Also runs all triggers
         related to updates and reindexes the new data.
     """
     """ Update the place table with the given data. Also runs all triggers
         related to updates and reindexes the new data.
     """
-    context.nominatim.run_setup_script(
-        'create-functions', 'create-partition-functions', 'enable-diff-updates')
+    context.nominatim.run_nominatim('refresh', '--functions')
     with context.db.cursor() as cur:
         for row in context.table:
             PlaceColumn(context).add_row(row, False).db_insert(cur)
     with context.db.cursor() as cur:
         for row in context.table:
             PlaceColumn(context).add_row(row, False).db_insert(cur)
@@ -106,7 +130,7 @@ def update_place_table(context):
 def update_postcodes(context):
     """ Rerun the calculation of postcodes.
     """
 def update_postcodes(context):
     """ Rerun the calculation of postcodes.
     """
-    context.nominatim.run_update_script('calculate-postcodes')
+    context.nominatim.run_nominatim('refresh', '--postcodes')
 
 @when("marking for delete (?P<oids>.*)")
 def delete_places(context, oids):
 
 @when("marking for delete (?P<oids>.*)")
 def delete_places(context, oids):
@@ -114,8 +138,7 @@ def delete_places(context, oids):
         separated by commas. Also runs all triggers
         related to updates and reindexes the new data.
     """
         separated by commas. Also runs all triggers
         related to updates and reindexes the new data.
     """
-    context.nominatim.run_setup_script(
-        'create-functions', 'create-partition-functions', 'enable-diff-updates')
+    context.nominatim.run_nominatim('refresh', '--functions')
     with context.db.cursor() as cur:
         for oid in oids.split(','):
             NominatimID(oid).query_osm_id(cur, 'DELETE FROM place WHERE {}')
     with context.db.cursor() as cur:
         for oid in oids.split(','):
             NominatimID(oid).query_osm_id(cur, 'DELETE FROM place WHERE {}')
@@ -176,44 +199,35 @@ def check_search_name_contents(context, exclude):
         have an identifier of the form '<NRW><osm id>[:<class>]'. All
         expected rows are expected to be present with at least one database row.
     """
         have an identifier of the form '<NRW><osm id>[:<class>]'. All
         expected rows are expected to be present with at least one database row.
     """
-    with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
-        for row in context.table:
-            nid = NominatimID(row['object'])
-            nid.row_by_place_id(cur, 'search_name',
-                                ['ST_X(centroid) as cx', 'ST_Y(centroid) as cy'])
-            assert cur.rowcount > 0, "No rows found for " + row['object']
+    tokenizer = tokenizer_factory.get_tokenizer_for_db(context.nominatim.get_test_config())
+
+    with tokenizer.name_analyzer() as analyzer:
+        with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+            for row in context.table:
+                nid = NominatimID(row['object'])
+                nid.row_by_place_id(cur, 'search_name',
+                                    ['ST_X(centroid) as cx', 'ST_Y(centroid) as cy'])
+                assert cur.rowcount > 0, "No rows found for " + row['object']
+
+                for res in cur:
+                    db_row = DBRow(nid, res, context)
+                    for name, value in zip(row.headings, row.cells):
+                        if name in ('name_vector', 'nameaddress_vector'):
+                            items = [x.strip() for x in value.split(',')]
+                            tokens = analyzer.get_word_token_info(context.db, items)
 
 
-            for res in cur:
-                db_row = DBRow(nid, res, context)
-                for name, value in zip(row.headings, row.cells):
-                    if name in ('name_vector', 'nameaddress_vector'):
-                        items = [x.strip() for x in value.split(',')]
-                        with context.db.cursor() as subcur:
-                            subcur.execute(""" SELECT word_id, word_token
-                                               FROM word, (SELECT unnest(%s::TEXT[]) as term) t
-                                               WHERE word_token = make_standard_name(t.term)
-                                                     and class is null and country_code is null
-                                                     and operator is null
-                                              UNION
-                                               SELECT word_id, word_token
-                                               FROM word, (SELECT unnest(%s::TEXT[]) as term) t
-                                               WHERE word_token = ' ' || make_standard_name(t.term)
-                                                     and class is null and country_code is null
-                                                     and operator is null
-                                           """,
-                                           (list(filter(lambda x: not x.startswith('#'), items)),
-                                            list(filter(lambda x: x.startswith('#'), items))))
                             if not exclude:
                             if not exclude:
-                                assert subcur.rowcount >= len(items), \
-                                    "No word entry found for {}. Entries found: {!s}".format(value, subcur.rowcount)
-                            for wid in subcur:
-                                present = wid[0] in res[name]
+                                assert len(tokens) >= len(items), \
+                                       "No word entry found for {}. Entries found: {!s}".format(value, len(tokens))
+                            for word, token, wid in tokens:
                                 if exclude:
                                 if exclude:
-                                    assert not present, "Found term for {}/{}: {}".format(row['object'], name, wid[1])
+                                    assert wid not in res[name], \
+                                           "Found term for {}/{}: {}".format(nid, name, wid)
                                 else:
                                 else:
-                                    assert present, "Missing term for {}/{}: {}".fromat(row['object'], name, wid[1])
-                    elif name != 'object':
-                        assert db_row.contains(name, value), db_row.assert_msg(name, value)
+                                    assert wid in res[name], \
+                                           "Missing term for {}/{}: {}".format(nid, name, wid)
+                        elif name != 'object':
+                            assert db_row.contains(name, value), db_row.assert_msg(name, value)
 
 @then("search_name has no entry for (?P<oid>.*)")
 def check_search_name_has_entry(context, oid):
 
 @then("search_name has no entry for (?P<oid>.*)")
 def check_search_name_has_entry(context, oid):
@@ -237,7 +251,7 @@ def check_location_postcode(context):
     with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
         cur.execute("SELECT *, ST_AsText(geometry) as geomtxt FROM location_postcode")
         assert cur.rowcount == len(list(context.table)), \
     with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
         cur.execute("SELECT *, ST_AsText(geometry) as geomtxt FROM location_postcode")
         assert cur.rowcount == len(list(context.table)), \
-            "Postcode table has {} rows, expected {}.".foramt(cur.rowcount, len(list(context.table)))
+            "Postcode table has {} rows, expected {}.".format(cur.rowcount, len(list(context.table)))
 
         results = {}
         for row in cur:
 
         results = {}
         for row in cur: