]> git.openstreetmap.org Git - nominatim.git/blob - test/python/api/search/test_api_search_query.py
Merge remote-tracking branch 'upstream/master'
[nominatim.git] / test / python / api / search / test_api_search_query.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2023 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Tests for tokenized query data structures.
9 """
10 import pytest
11
12 from nominatim.api.search import query
13
14 class MyToken(query.Token):
15
16     def get_category(self):
17         return 'this', 'that'
18
19
20 def mktoken(tid: int):
21     return MyToken(3.0, tid, 1, 'foo', True)
22
23
24 @pytest.mark.parametrize('ptype,ttype', [('NONE', 'WORD'),
25                                          ('AMENITY', 'QUALIFIER'),
26                                          ('STREET', 'PARTIAL'),
27                                          ('CITY', 'WORD'),
28                                          ('COUNTRY', 'COUNTRY'),
29                                          ('POSTCODE', 'POSTCODE')])
30 def test_phrase_compatible(ptype, ttype):
31     assert query.PhraseType[ptype].compatible_with(query.TokenType[ttype])
32
33
34 @pytest.mark.parametrize('ptype', ['COUNTRY', 'POSTCODE'])
35 def test_phrase_incompatible(ptype):
36     assert not query.PhraseType[ptype].compatible_with(query.TokenType.PARTIAL)
37
38
39 def test_query_node_empty():
40     qn = query.QueryNode(query.BreakType.PHRASE, query.PhraseType.NONE)
41
42     assert not qn.has_tokens(3, query.TokenType.PARTIAL)
43     assert qn.get_tokens(3, query.TokenType.WORD) is None
44
45
46 def test_query_node_with_content():
47     qn = query.QueryNode(query.BreakType.PHRASE, query.PhraseType.NONE)
48     qn.starting.append(query.TokenList(2, query.TokenType.PARTIAL, [mktoken(100), mktoken(101)]))
49     qn.starting.append(query.TokenList(2, query.TokenType.WORD, [mktoken(1000)]))
50
51     assert not qn.has_tokens(3, query.TokenType.PARTIAL)
52     assert not qn.has_tokens(2, query.TokenType.COUNTRY)
53     assert qn.has_tokens(2, query.TokenType.PARTIAL)
54     assert qn.has_tokens(2, query.TokenType.WORD)
55
56     assert qn.get_tokens(3, query.TokenType.PARTIAL) is None
57     assert qn.get_tokens(2, query.TokenType.COUNTRY) is None
58     assert len(qn.get_tokens(2, query.TokenType.PARTIAL)) == 2
59     assert len(qn.get_tokens(2, query.TokenType.WORD)) == 1
60
61
62 def test_query_struct_empty():
63     q = query.QueryStruct([])
64
65     assert q.num_token_slots() == 0
66
67
68 def test_query_struct_with_tokens():
69     q = query.QueryStruct([query.Phrase(query.PhraseType.NONE, 'foo bar')])
70     q.add_node(query.BreakType.WORD, query.PhraseType.NONE)
71     q.add_node(query.BreakType.END, query.PhraseType.NONE)
72
73     assert q.num_token_slots() == 2
74
75     q.add_token(query.TokenRange(0, 1), query.TokenType.PARTIAL, mktoken(1))
76     q.add_token(query.TokenRange(1, 2), query.TokenType.PARTIAL, mktoken(2))
77     q.add_token(query.TokenRange(1, 2), query.TokenType.WORD, mktoken(99))
78     q.add_token(query.TokenRange(1, 2), query.TokenType.WORD, mktoken(98))
79
80     assert q.get_tokens(query.TokenRange(0, 2), query.TokenType.WORD) == []
81     assert len(q.get_tokens(query.TokenRange(1, 2), query.TokenType.WORD)) == 2
82
83     partials = q.get_partials_list(query.TokenRange(0, 2))
84
85     assert len(partials) == 2
86     assert [t.token for t in partials] == [1, 2]
87
88     assert q.find_lookup_word_by_id(4) == 'None'
89     assert q.find_lookup_word_by_id(99) == '[W]foo'
90
91
92 def test_query_struct_incompatible_token():
93     q = query.QueryStruct([query.Phrase(query.PhraseType.COUNTRY, 'foo bar')])
94     q.add_node(query.BreakType.WORD, query.PhraseType.COUNTRY)
95     q.add_node(query.BreakType.END, query.PhraseType.NONE)
96
97     q.add_token(query.TokenRange(0, 1), query.TokenType.PARTIAL, mktoken(1))
98     q.add_token(query.TokenRange(1, 2), query.TokenType.COUNTRY, mktoken(100))
99
100     assert q.get_tokens(query.TokenRange(0, 1), query.TokenType.PARTIAL) == []
101     assert len(q.get_tokens(query.TokenRange(1, 2), query.TokenType.COUNTRY)) == 1