]> git.openstreetmap.org Git - nominatim.git/blobdiff - test/python/api/search/test_api_search_query.py
Merge remote-tracking branch 'upstream/master'
[nominatim.git] / test / python / api / search / test_api_search_query.py
index f8c9c2dc865ba9f8ca527014c1d292dfbba14313..412a5bf2478323a38b1a3837ce40add4ad6c506f 100644 (file)
@@ -2,14 +2,14 @@
 #
 # This file is part of Nominatim. (https://nominatim.org)
 #
-# Copyright (C) 2023 by the Nominatim developer community.
+# Copyright (C) 2024 by the Nominatim developer community.
 # For a full list of authors see the git log.
 """
 Tests for tokenized query data structures.
 """
 import pytest
 
-from nominatim.api.search import query
+from nominatim_api.search import query
 
 class MyToken(query.Token):
 
@@ -18,45 +18,46 @@ class MyToken(query.Token):
 
 
 def mktoken(tid: int):
-    return MyToken(3.0, tid, 1, 'foo', True)
+    return MyToken(penalty=3.0, token=tid, count=1, addr_count=1,
+                   lookup_word='foo')
 
 
-@pytest.mark.parametrize('ptype,ttype', [('NONE', 'WORD'),
-                                         ('AMENITY', 'QUALIFIER'),
-                                         ('STREET', 'PARTIAL'),
-                                         ('CITY', 'WORD'),
-                                         ('COUNTRY', 'COUNTRY'),
-                                         ('POSTCODE', 'POSTCODE')])
+@pytest.mark.parametrize('ptype,ttype', [(query.PHRASE_ANY, 'W'),
+                                         (query.PHRASE_AMENITY, 'Q'),
+                                         (query.PHRASE_STREET, 'w'),
+                                         (query.PHRASE_CITY, 'W'),
+                                         (query.PHRASE_COUNTRY, 'C'),
+                                         (query.PHRASE_POSTCODE, 'P')])
 def test_phrase_compatible(ptype, ttype):
-    assert query.PhraseType[ptype].compatible_with(query.TokenType[ttype])
+    assert query._phrase_compatible_with(ptype, ttype, False)
 
 
-@pytest.mark.parametrize('ptype', ['COUNTRY', 'POSTCODE'])
+@pytest.mark.parametrize('ptype', [query.PHRASE_COUNTRY, query.PHRASE_POSTCODE])
 def test_phrase_incompatible(ptype):
-    assert not query.PhraseType[ptype].compatible_with(query.TokenType.PARTIAL)
+    assert not query._phrase_compatible_with(ptype, query.TOKEN_PARTIAL, True)
 
 
 def test_query_node_empty():
-    qn = query.QueryNode(query.BreakType.PHRASE, query.PhraseType.NONE)
+    qn = query.QueryNode(query.BREAK_PHRASE, query.PHRASE_ANY)
 
-    assert not qn.has_tokens(3, query.TokenType.PARTIAL)
-    assert qn.get_tokens(3, query.TokenType.WORD) is None
+    assert not qn.has_tokens(3, query.TOKEN_PARTIAL)
+    assert qn.get_tokens(3, query.TOKEN_WORD) is None
 
 
 def test_query_node_with_content():
-    qn = query.QueryNode(query.BreakType.PHRASE, query.PhraseType.NONE)
-    qn.starting.append(query.TokenList(2, query.TokenType.PARTIAL, [mktoken(100), mktoken(101)]))
-    qn.starting.append(query.TokenList(2, query.TokenType.WORD, [mktoken(1000)]))
+    qn = query.QueryNode(query.BREAK_PHRASE, query.PHRASE_ANY)
+    qn.starting.append(query.TokenList(2, query.TOKEN_PARTIAL, [mktoken(100), mktoken(101)]))
+    qn.starting.append(query.TokenList(2, query.TOKEN_WORD, [mktoken(1000)]))
 
-    assert not qn.has_tokens(3, query.TokenType.PARTIAL)
-    assert not qn.has_tokens(2, query.TokenType.COUNTRY)
-    assert qn.has_tokens(2, query.TokenType.PARTIAL)
-    assert qn.has_tokens(2, query.TokenType.WORD)
+    assert not qn.has_tokens(3, query.TOKEN_PARTIAL)
+    assert not qn.has_tokens(2, query.TOKEN_COUNTRY)
+    assert qn.has_tokens(2, query.TOKEN_PARTIAL)
+    assert qn.has_tokens(2, query.TOKEN_WORD)
 
-    assert qn.get_tokens(3, query.TokenType.PARTIAL) is None
-    assert qn.get_tokens(2, query.TokenType.COUNTRY) is None
-    assert len(qn.get_tokens(2, query.TokenType.PARTIAL)) == 2
-    assert len(qn.get_tokens(2, query.TokenType.WORD)) == 1
+    assert qn.get_tokens(3, query.TOKEN_PARTIAL) is None
+    assert qn.get_tokens(2, query.TOKEN_COUNTRY) is None
+    assert len(qn.get_tokens(2, query.TOKEN_PARTIAL)) == 2
+    assert len(qn.get_tokens(2, query.TOKEN_WORD)) == 1
 
 
 def test_query_struct_empty():
@@ -66,19 +67,19 @@ def test_query_struct_empty():
 
 
 def test_query_struct_with_tokens():
-    q = query.QueryStruct([query.Phrase(query.PhraseType.NONE, 'foo bar')])
-    q.add_node(query.BreakType.WORD, query.PhraseType.NONE)
-    q.add_node(query.BreakType.END, query.PhraseType.NONE)
+    q = query.QueryStruct([query.Phrase(query.PHRASE_ANY, 'foo bar')])
+    q.add_node(query.BREAK_WORD, query.PHRASE_ANY)
+    q.add_node(query.BREAK_END, query.PHRASE_ANY)
 
     assert q.num_token_slots() == 2
 
-    q.add_token(query.TokenRange(0, 1), query.TokenType.PARTIAL, mktoken(1))
-    q.add_token(query.TokenRange(1, 2), query.TokenType.PARTIAL, mktoken(2))
-    q.add_token(query.TokenRange(1, 2), query.TokenType.WORD, mktoken(99))
-    q.add_token(query.TokenRange(1, 2), query.TokenType.WORD, mktoken(98))
+    q.add_token(query.TokenRange(0, 1), query.TOKEN_PARTIAL, mktoken(1))
+    q.add_token(query.TokenRange(1, 2), query.TOKEN_PARTIAL, mktoken(2))
+    q.add_token(query.TokenRange(1, 2), query.TOKEN_WORD, mktoken(99))
+    q.add_token(query.TokenRange(1, 2), query.TOKEN_WORD, mktoken(98))
 
-    assert q.get_tokens(query.TokenRange(0, 2), query.TokenType.WORD) == []
-    assert len(q.get_tokens(query.TokenRange(1, 2), query.TokenType.WORD)) == 2
+    assert q.get_tokens(query.TokenRange(0, 2), query.TOKEN_WORD) == []
+    assert len(q.get_tokens(query.TokenRange(1, 2), query.TOKEN_WORD)) == 2
 
     partials = q.get_partials_list(query.TokenRange(0, 2))
 
@@ -90,12 +91,45 @@ def test_query_struct_with_tokens():
 
 
 def test_query_struct_incompatible_token():
-    q = query.QueryStruct([query.Phrase(query.PhraseType.COUNTRY, 'foo bar')])
-    q.add_node(query.BreakType.WORD, query.PhraseType.COUNTRY)
-    q.add_node(query.BreakType.END, query.PhraseType.NONE)
+    q = query.QueryStruct([query.Phrase(query.PHRASE_COUNTRY, 'foo bar')])
+    q.add_node(query.BREAK_WORD, query.PHRASE_COUNTRY)
+    q.add_node(query.BREAK_END, query.PHRASE_ANY)
 
-    q.add_token(query.TokenRange(0, 1), query.TokenType.PARTIAL, mktoken(1))
-    q.add_token(query.TokenRange(1, 2), query.TokenType.COUNTRY, mktoken(100))
+    q.add_token(query.TokenRange(0, 1), query.TOKEN_PARTIAL, mktoken(1))
+    q.add_token(query.TokenRange(1, 2), query.TOKEN_COUNTRY, mktoken(100))
+
+    assert q.get_tokens(query.TokenRange(0, 1), query.TOKEN_PARTIAL) == []
+    assert len(q.get_tokens(query.TokenRange(1, 2), query.TOKEN_COUNTRY)) == 1
+
+
+def test_query_struct_amenity_single_word():
+    q = query.QueryStruct([query.Phrase(query.PHRASE_AMENITY, 'bar')])
+    q.add_node(query.BREAK_END, query.PHRASE_ANY)
+
+    q.add_token(query.TokenRange(0, 1), query.TOKEN_PARTIAL, mktoken(1))
+    q.add_token(query.TokenRange(0, 1), query.TOKEN_NEAR_ITEM, mktoken(2))
+    q.add_token(query.TokenRange(0, 1), query.TOKEN_QUALIFIER, mktoken(3))
+
+    assert len(q.get_tokens(query.TokenRange(0, 1), query.TOKEN_PARTIAL)) == 1
+    assert len(q.get_tokens(query.TokenRange(0, 1), query.TOKEN_NEAR_ITEM)) == 1
+    assert len(q.get_tokens(query.TokenRange(0, 1), query.TOKEN_QUALIFIER)) == 0
+
+
+def test_query_struct_amenity_two_words():
+    q = query.QueryStruct([query.Phrase(query.PHRASE_AMENITY, 'foo bar')])
+    q.add_node(query.BREAK_WORD, query.PHRASE_AMENITY)
+    q.add_node(query.BREAK_END, query.PHRASE_ANY)
+
+    for trange in [(0, 1), (1, 2)]:
+        q.add_token(query.TokenRange(*trange), query.TOKEN_PARTIAL, mktoken(1))
+        q.add_token(query.TokenRange(*trange), query.TOKEN_NEAR_ITEM, mktoken(2))
+        q.add_token(query.TokenRange(*trange), query.TOKEN_QUALIFIER, mktoken(3))
+
+    assert len(q.get_tokens(query.TokenRange(0, 1), query.TOKEN_PARTIAL)) == 1
+    assert len(q.get_tokens(query.TokenRange(0, 1), query.TOKEN_NEAR_ITEM)) == 0
+    assert len(q.get_tokens(query.TokenRange(0, 1), query.TOKEN_QUALIFIER)) == 1
+
+    assert len(q.get_tokens(query.TokenRange(1, 2), query.TOKEN_PARTIAL)) == 1
+    assert len(q.get_tokens(query.TokenRange(1, 2), query.TOKEN_NEAR_ITEM)) == 0
+    assert len(q.get_tokens(query.TokenRange(1, 2), query.TOKEN_QUALIFIER)) == 1
 
-    assert q.get_tokens(query.TokenRange(0, 1), query.TokenType.PARTIAL) == []
-    assert len(q.get_tokens(query.TokenRange(1, 2), query.TokenType.COUNTRY)) == 1