]> git.openstreetmap.org Git - nominatim.git/blob - src/nominatim_api/search/query.py
search: merge QueryPart array with QueryNodes
[nominatim.git] / src / nominatim_api / search / query.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2024 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Datastructures for a tokenized query.
9 """
10 from typing import List, Tuple, Optional, Iterator
11 from abc import ABC, abstractmethod
12 import dataclasses
13
14
15 BreakType = str
16 """ Type of break between tokens.
17 """
18 BREAK_START = '<'
19 """ Begin of the query. """
20 BREAK_END = '>'
21 """ End of the query. """
22 BREAK_PHRASE = ','
23 """ Hard break between two phrases. Address parts cannot cross hard
24     phrase boundaries."""
25 BREAK_SOFT_PHRASE = ':'
26 """ Likely break between two phrases. Address parts should not cross soft
27     phrase boundaries. Soft breaks can be inserted by a preprocessor
28     that is analysing the input string.
29 """
30 BREAK_WORD = ' '
31 """ Break between words. """
32 BREAK_PART = '-'
33 """ Break inside a word, for example a hyphen or apostrophe. """
34 BREAK_TOKEN = '`'
35 """ Break created as a result of tokenization.
36     This may happen in languages without spaces between words.
37 """
38
39
40 TokenType = str
41 """ Type of token.
42 """
43 TOKEN_WORD = 'W'
44 """ Full name of a place. """
45 TOKEN_PARTIAL = 'w'
46 """ Word term without breaks, does not necessarily represent a full name. """
47 TOKEN_HOUSENUMBER = 'H'
48 """ Housenumber term. """
49 TOKEN_POSTCODE = 'P'
50 """ Postal code term. """
51 TOKEN_COUNTRY = 'C'
52 """ Country name or reference. """
53 TOKEN_QUALIFIER = 'Q'
54 """ Special term used together with name (e.g. _Hotel_ Bellevue). """
55 TOKEN_NEAR_ITEM = 'N'
56 """ Special term used as searchable object(e.g. supermarket in ...). """
57
58
59 PhraseType = int
60 """ Designation of a phrase.
61 """
62 PHRASE_ANY = 0
63 """ No specific designation (i.e. source is free-form query). """
64 PHRASE_AMENITY = 1
65 """ Contains name or type of a POI. """
66 PHRASE_STREET = 2
67 """ Contains a street name optionally with a housenumber. """
68 PHRASE_CITY = 3
69 """ Contains the postal city. """
70 PHRASE_COUNTY = 4
71 """ Contains the equivalent of a county. """
72 PHRASE_STATE = 5
73 """ Contains a state or province. """
74 PHRASE_POSTCODE = 6
75 """ Contains a postal code. """
76 PHRASE_COUNTRY = 7
77 """ Contains the country name or code. """
78
79
80 def _phrase_compatible_with(ptype: PhraseType, ttype: TokenType,
81                             is_full_phrase: bool) -> bool:
82     """ Check if the given token type can be used with the phrase type.
83     """
84     if ptype == PHRASE_ANY:
85         return not is_full_phrase or ttype != TOKEN_QUALIFIER
86     if ptype == PHRASE_AMENITY:
87         return ttype in (TOKEN_WORD, TOKEN_PARTIAL)\
88                or (is_full_phrase and ttype == TOKEN_NEAR_ITEM)\
89                or (not is_full_phrase and ttype == TOKEN_QUALIFIER)
90     if ptype == PHRASE_STREET:
91         return ttype in (TOKEN_WORD, TOKEN_PARTIAL, TOKEN_HOUSENUMBER)
92     if ptype == PHRASE_POSTCODE:
93         return ttype == TOKEN_POSTCODE
94     if ptype == PHRASE_COUNTRY:
95         return ttype == TOKEN_COUNTRY
96
97     return ttype in (TOKEN_WORD, TOKEN_PARTIAL)
98
99
100 @dataclasses.dataclass
101 class Token(ABC):
102     """ Base type for tokens.
103         Specific query analyzers must implement the concrete token class.
104     """
105
106     penalty: float
107     token: int
108     count: int
109     addr_count: int
110     lookup_word: str
111
112     @abstractmethod
113     def get_category(self) -> Tuple[str, str]:
114         """ Return the category restriction for qualifier terms and
115             category objects.
116         """
117
118
119 @dataclasses.dataclass
120 class TokenRange:
121     """ Indexes of query nodes over which a token spans.
122     """
123     start: int
124     end: int
125     penalty: Optional[float] = None
126
127     def __lt__(self, other: 'TokenRange') -> bool:
128         return self.end <= other.start
129
130     def __le__(self, other: 'TokenRange') -> bool:
131         return NotImplemented
132
133     def __gt__(self, other: 'TokenRange') -> bool:
134         return self.start >= other.end
135
136     def __ge__(self, other: 'TokenRange') -> bool:
137         return NotImplemented
138
139     def replace_start(self, new_start: int) -> 'TokenRange':
140         """ Return a new token range with the new start.
141         """
142         return TokenRange(new_start, self.end)
143
144     def replace_end(self, new_end: int) -> 'TokenRange':
145         """ Return a new token range with the new end.
146         """
147         return TokenRange(self.start, new_end)
148
149     def split(self, index: int) -> Tuple['TokenRange', 'TokenRange']:
150         """ Split the span into two spans at the given index.
151             The index must be within the span.
152         """
153         return self.replace_end(index), self.replace_start(index)
154
155
156 @dataclasses.dataclass
157 class TokenList:
158     """ List of all tokens of a given type going from one breakpoint to another.
159     """
160     end: int
161     ttype: TokenType
162     tokens: List[Token]
163
164     def add_penalty(self, penalty: float) -> None:
165         """ Add the given penalty to all tokens in the list.
166         """
167         for token in self.tokens:
168             token.penalty += penalty
169
170
171 @dataclasses.dataclass
172 class QueryNode:
173     """ A node of the query representing a break between terms.
174
175         The node also contains information on the source term
176         ending at the node. The tokens are created from this information.
177     """
178     btype: BreakType
179     ptype: PhraseType
180
181     penalty: float
182     """ Penalty for the break at this node.
183     """
184     term_lookup: str
185     """ Transliterated term following this node.
186     """
187     term_normalized: str
188     """ Normalised form of term following this node.
189         When the token resulted from a split during transliteration,
190         then this string contains the complete source term.
191     """
192
193     starting: List[TokenList] = dataclasses.field(default_factory=list)
194
195     def adjust_break(self, btype: BreakType, penalty: float) -> None:
196         """ Change the break type and penalty for this node.
197         """
198         self.btype = btype
199         self.penalty = penalty
200
201     def has_tokens(self, end: int, *ttypes: TokenType) -> bool:
202         """ Check if there are tokens of the given types ending at the
203             given node.
204         """
205         return any(tl.end == end and tl.ttype in ttypes for tl in self.starting)
206
207     def get_tokens(self, end: int, ttype: TokenType) -> Optional[List[Token]]:
208         """ Get the list of tokens of the given type starting at this node
209             and ending at the node 'end'. Returns 'None' if no such
210             tokens exist.
211         """
212         for tlist in self.starting:
213             if tlist.end == end and tlist.ttype == ttype:
214                 return tlist.tokens
215         return None
216
217
218 @dataclasses.dataclass
219 class Phrase:
220     """ A normalized query part. Phrases may be typed which means that
221         they then represent a specific part of the address.
222     """
223     ptype: PhraseType
224     text: str
225
226
227 class QueryStruct:
228     """ A tokenized search query together with the normalized source
229         from which the tokens have been parsed.
230
231         The query contains a list of nodes that represent the breaks
232         between words. Tokens span between nodes, which don't necessarily
233         need to be direct neighbours. Thus the query is represented as a
234         directed acyclic graph.
235
236         When created, a query contains a single node: the start of the
237         query. Further nodes can be added by appending to 'nodes'.
238     """
239
240     def __init__(self, source: List[Phrase]) -> None:
241         self.source = source
242         self.nodes: List[QueryNode] = \
243             [QueryNode(BREAK_START, source[0].ptype if source else PHRASE_ANY,
244                        0.0, '', '')]
245
246     def num_token_slots(self) -> int:
247         """ Return the length of the query in vertice steps.
248         """
249         return len(self.nodes) - 1
250
251     def add_node(self, btype: BreakType, ptype: PhraseType,
252                  break_penalty: float = 0.0,
253                  term_lookup: str = '', term_normalized: str = '') -> None:
254         """ Append a new break node with the given break type.
255             The phrase type denotes the type for any tokens starting
256             at the node.
257         """
258         self.nodes.append(QueryNode(btype, ptype, break_penalty, term_lookup, term_normalized))
259
260     def add_token(self, trange: TokenRange, ttype: TokenType, token: Token) -> None:
261         """ Add a token to the query. 'start' and 'end' are the indexes of the
262             nodes from which to which the token spans. The indexes must exist
263             and are expected to be in the same phrase.
264             'ttype' denotes the type of the token and 'token' the token to
265             be inserted.
266
267             If the token type is not compatible with the phrase it should
268             be added to, then the token is silently dropped.
269         """
270         snode = self.nodes[trange.start]
271         full_phrase = snode.btype in (BREAK_START, BREAK_PHRASE)\
272             and self.nodes[trange.end].btype in (BREAK_PHRASE, BREAK_END)
273         if _phrase_compatible_with(snode.ptype, ttype, full_phrase):
274             tlist = snode.get_tokens(trange.end, ttype)
275             if tlist is None:
276                 snode.starting.append(TokenList(trange.end, ttype, [token]))
277             else:
278                 tlist.append(token)
279
280     def get_tokens(self, trange: TokenRange, ttype: TokenType) -> List[Token]:
281         """ Get the list of tokens of a given type, spanning the given
282             nodes. The nodes must exist. If no tokens exist, an
283             empty list is returned.
284         """
285         return self.nodes[trange.start].get_tokens(trange.end, ttype) or []
286
287     def get_partials_list(self, trange: TokenRange) -> List[Token]:
288         """ Create a list of partial tokens between the given nodes.
289             The list is composed of the first token of type PARTIAL
290             going to the subsequent node. Such PARTIAL tokens are
291             assumed to exist.
292         """
293         return [next(iter(self.get_tokens(TokenRange(i, i+1), TOKEN_PARTIAL)))
294                 for i in range(trange.start, trange.end)]
295
296     def iter_token_lists(self) -> Iterator[Tuple[int, QueryNode, TokenList]]:
297         """ Iterator over all token lists in the query.
298         """
299         for i, node in enumerate(self.nodes):
300             for tlist in node.starting:
301                 yield i, node, tlist
302
303     def find_lookup_word_by_id(self, token: int) -> str:
304         """ Find the first token with the given token ID and return
305             its lookup word. Returns 'None' if no such token exists.
306             The function is very slow and must only be used for
307             debugging.
308         """
309         for node in self.nodes:
310             for tlist in node.starting:
311                 for t in tlist.tokens:
312                     if t.token == token:
313                         return f"[{tlist.ttype}]{t.lookup_word}"
314         return 'None'
315
316     def get_transliterated_query(self) -> str:
317         """ Return a string representation of the transliterated query
318             with the character representation of the different break types.
319
320             For debugging purposes only.
321         """
322         return ''.join(''.join((n.term_lookup, n.btype)) for n in self.nodes)