]> git.openstreetmap.org Git - nominatim.git/blob - src/nominatim_api/search/query.py
Merge pull request #3678 from lonvia/search-tweaks
[nominatim.git] / src / nominatim_api / search / query.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2024 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Datastructures for a tokenized query.
9 """
10 from typing import Dict, List, Tuple, Optional, Iterator
11 from abc import ABC, abstractmethod
12 from collections import defaultdict
13 import dataclasses
14
15
16 BreakType = str
17 """ Type of break between tokens.
18 """
19 BREAK_START = '<'
20 """ Begin of the query. """
21 BREAK_END = '>'
22 """ End of the query. """
23 BREAK_PHRASE = ','
24 """ Hard break between two phrases. Address parts cannot cross hard
25     phrase boundaries."""
26 BREAK_SOFT_PHRASE = ':'
27 """ Likely break between two phrases. Address parts should not cross soft
28     phrase boundaries. Soft breaks can be inserted by a preprocessor
29     that is analysing the input string.
30 """
31 BREAK_WORD = ' '
32 """ Break between words. """
33 BREAK_PART = '-'
34 """ Break inside a word, for example a hyphen or apostrophe. """
35 BREAK_TOKEN = '`'
36 """ Break created as a result of tokenization.
37     This may happen in languages without spaces between words.
38 """
39
40
41 TokenType = str
42 """ Type of token.
43 """
44 TOKEN_WORD = 'W'
45 """ Full name of a place. """
46 TOKEN_PARTIAL = 'w'
47 """ Word term without breaks, does not necessarily represent a full name. """
48 TOKEN_HOUSENUMBER = 'H'
49 """ Housenumber term. """
50 TOKEN_POSTCODE = 'P'
51 """ Postal code term. """
52 TOKEN_COUNTRY = 'C'
53 """ Country name or reference. """
54 TOKEN_QUALIFIER = 'Q'
55 """ Special term used together with name (e.g. _Hotel_ Bellevue). """
56 TOKEN_NEAR_ITEM = 'N'
57 """ Special term used as searchable object(e.g. supermarket in ...). """
58
59
60 PhraseType = int
61 """ Designation of a phrase.
62 """
63 PHRASE_ANY = 0
64 """ No specific designation (i.e. source is free-form query). """
65 PHRASE_AMENITY = 1
66 """ Contains name or type of a POI. """
67 PHRASE_STREET = 2
68 """ Contains a street name optionally with a housenumber. """
69 PHRASE_CITY = 3
70 """ Contains the postal city. """
71 PHRASE_COUNTY = 4
72 """ Contains the equivalent of a county. """
73 PHRASE_STATE = 5
74 """ Contains a state or province. """
75 PHRASE_POSTCODE = 6
76 """ Contains a postal code. """
77 PHRASE_COUNTRY = 7
78 """ Contains the country name or code. """
79
80
81 def _phrase_compatible_with(ptype: PhraseType, ttype: TokenType,
82                             is_full_phrase: bool) -> bool:
83     """ Check if the given token type can be used with the phrase type.
84     """
85     if ptype == PHRASE_ANY:
86         return not is_full_phrase or ttype != TOKEN_QUALIFIER
87     if ptype == PHRASE_AMENITY:
88         return ttype in (TOKEN_WORD, TOKEN_PARTIAL)\
89                or (is_full_phrase and ttype == TOKEN_NEAR_ITEM)\
90                or (not is_full_phrase and ttype == TOKEN_QUALIFIER)
91     if ptype == PHRASE_STREET:
92         return ttype in (TOKEN_WORD, TOKEN_PARTIAL, TOKEN_HOUSENUMBER)
93     if ptype == PHRASE_POSTCODE:
94         return ttype == TOKEN_POSTCODE
95     if ptype == PHRASE_COUNTRY:
96         return ttype == TOKEN_COUNTRY
97
98     return ttype in (TOKEN_WORD, TOKEN_PARTIAL)
99
100
101 @dataclasses.dataclass
102 class Token(ABC):
103     """ Base type for tokens.
104         Specific query analyzers must implement the concrete token class.
105     """
106
107     penalty: float
108     token: int
109     count: int
110     addr_count: int
111     lookup_word: str
112
113     @abstractmethod
114     def get_category(self) -> Tuple[str, str]:
115         """ Return the category restriction for qualifier terms and
116             category objects.
117         """
118
119
120 @dataclasses.dataclass
121 class TokenRange:
122     """ Indexes of query nodes over which a token spans.
123     """
124     start: int
125     end: int
126     penalty: Optional[float] = None
127
128     def __lt__(self, other: 'TokenRange') -> bool:
129         return self.end <= other.start
130
131     def __le__(self, other: 'TokenRange') -> bool:
132         return NotImplemented
133
134     def __gt__(self, other: 'TokenRange') -> bool:
135         return self.start >= other.end
136
137     def __ge__(self, other: 'TokenRange') -> bool:
138         return NotImplemented
139
140     def replace_start(self, new_start: int) -> 'TokenRange':
141         """ Return a new token range with the new start.
142         """
143         return TokenRange(new_start, self.end)
144
145     def replace_end(self, new_end: int) -> 'TokenRange':
146         """ Return a new token range with the new end.
147         """
148         return TokenRange(self.start, new_end)
149
150     def split(self, index: int) -> Tuple['TokenRange', 'TokenRange']:
151         """ Split the span into two spans at the given index.
152             The index must be within the span.
153         """
154         return self.replace_end(index), self.replace_start(index)
155
156
157 @dataclasses.dataclass
158 class TokenList:
159     """ List of all tokens of a given type going from one breakpoint to another.
160     """
161     end: int
162     ttype: TokenType
163     tokens: List[Token]
164
165     def add_penalty(self, penalty: float) -> None:
166         """ Add the given penalty to all tokens in the list.
167         """
168         for token in self.tokens:
169             token.penalty += penalty
170
171
172 @dataclasses.dataclass
173 class QueryNode:
174     """ A node of the query representing a break between terms.
175
176         The node also contains information on the source term
177         ending at the node. The tokens are created from this information.
178     """
179     btype: BreakType
180     ptype: PhraseType
181
182     penalty: float
183     """ Penalty for the break at this node.
184     """
185     term_lookup: str
186     """ Transliterated term following this node.
187     """
188     term_normalized: str
189     """ Normalised form of term following this node.
190         When the token resulted from a split during transliteration,
191         then this string contains the complete source term.
192     """
193
194     starting: List[TokenList] = dataclasses.field(default_factory=list)
195
196     def adjust_break(self, btype: BreakType, penalty: float) -> None:
197         """ Change the break type and penalty for this node.
198         """
199         self.btype = btype
200         self.penalty = penalty
201
202     def has_tokens(self, end: int, *ttypes: TokenType) -> bool:
203         """ Check if there are tokens of the given types ending at the
204             given node.
205         """
206         return any(tl.end == end and tl.ttype in ttypes for tl in self.starting)
207
208     def get_tokens(self, end: int, ttype: TokenType) -> Optional[List[Token]]:
209         """ Get the list of tokens of the given type starting at this node
210             and ending at the node 'end'. Returns 'None' if no such
211             tokens exist.
212         """
213         for tlist in self.starting:
214             if tlist.end == end and tlist.ttype == ttype:
215                 return tlist.tokens
216         return None
217
218
219 @dataclasses.dataclass
220 class Phrase:
221     """ A normalized query part. Phrases may be typed which means that
222         they then represent a specific part of the address.
223     """
224     ptype: PhraseType
225     text: str
226
227
228 class QueryStruct:
229     """ A tokenized search query together with the normalized source
230         from which the tokens have been parsed.
231
232         The query contains a list of nodes that represent the breaks
233         between words. Tokens span between nodes, which don't necessarily
234         need to be direct neighbours. Thus the query is represented as a
235         directed acyclic graph.
236
237         When created, a query contains a single node: the start of the
238         query. Further nodes can be added by appending to 'nodes'.
239     """
240
241     def __init__(self, source: List[Phrase]) -> None:
242         self.source = source
243         self.nodes: List[QueryNode] = \
244             [QueryNode(BREAK_START, source[0].ptype if source else PHRASE_ANY,
245                        0.0, '', '')]
246
247     def num_token_slots(self) -> int:
248         """ Return the length of the query in vertice steps.
249         """
250         return len(self.nodes) - 1
251
252     def add_node(self, btype: BreakType, ptype: PhraseType,
253                  break_penalty: float = 0.0,
254                  term_lookup: str = '', term_normalized: str = '') -> None:
255         """ Append a new break node with the given break type.
256             The phrase type denotes the type for any tokens starting
257             at the node.
258         """
259         self.nodes.append(QueryNode(btype, ptype, break_penalty, term_lookup, term_normalized))
260
261     def add_token(self, trange: TokenRange, ttype: TokenType, token: Token) -> None:
262         """ Add a token to the query. 'start' and 'end' are the indexes of the
263             nodes from which to which the token spans. The indexes must exist
264             and are expected to be in the same phrase.
265             'ttype' denotes the type of the token and 'token' the token to
266             be inserted.
267
268             If the token type is not compatible with the phrase it should
269             be added to, then the token is silently dropped.
270         """
271         snode = self.nodes[trange.start]
272         full_phrase = snode.btype in (BREAK_START, BREAK_PHRASE)\
273             and self.nodes[trange.end].btype in (BREAK_PHRASE, BREAK_END)
274         if _phrase_compatible_with(snode.ptype, ttype, full_phrase):
275             tlist = snode.get_tokens(trange.end, ttype)
276             if tlist is None:
277                 snode.starting.append(TokenList(trange.end, ttype, [token]))
278             else:
279                 tlist.append(token)
280
281     def get_tokens(self, trange: TokenRange, ttype: TokenType) -> List[Token]:
282         """ Get the list of tokens of a given type, spanning the given
283             nodes. The nodes must exist. If no tokens exist, an
284             empty list is returned.
285         """
286         return self.nodes[trange.start].get_tokens(trange.end, ttype) or []
287
288     def get_partials_list(self, trange: TokenRange) -> List[Token]:
289         """ Create a list of partial tokens between the given nodes.
290             The list is composed of the first token of type PARTIAL
291             going to the subsequent node. Such PARTIAL tokens are
292             assumed to exist.
293         """
294         return [next(iter(self.get_tokens(TokenRange(i, i+1), TOKEN_PARTIAL)))
295                 for i in range(trange.start, trange.end)]
296
297     def iter_token_lists(self) -> Iterator[Tuple[int, QueryNode, TokenList]]:
298         """ Iterator over all token lists in the query.
299         """
300         for i, node in enumerate(self.nodes):
301             for tlist in node.starting:
302                 yield i, node, tlist
303
304     def find_lookup_word_by_id(self, token: int) -> str:
305         """ Find the first token with the given token ID and return
306             its lookup word. Returns 'None' if no such token exists.
307             The function is very slow and must only be used for
308             debugging.
309         """
310         for node in self.nodes:
311             for tlist in node.starting:
312                 for t in tlist.tokens:
313                     if t.token == token:
314                         return f"[{tlist.ttype}]{t.lookup_word}"
315         return 'None'
316
317     def get_transliterated_query(self) -> str:
318         """ Return a string representation of the transliterated query
319             with the character representation of the different break types.
320
321             For debugging purposes only.
322         """
323         return ''.join(''.join((n.term_lookup, n.btype)) for n in self.nodes)
324
325     def extract_words(self, base_penalty: float = 0.0,
326                       start: int = 0,
327                       endpos: Optional[int] = None) -> Dict[str, List[TokenRange]]:
328         """ Add all combinations of words that can be formed from the terms
329             between the given start and endnode. The terms are joined with
330             spaces for each break. Words can never go across a BREAK_PHRASE.
331
332             The functions returns a dictionary of possible words with their
333             position within the query and a penalty. The penalty is computed
334             from the base_penalty plus the penalty for each node the word
335             crosses.
336         """
337         if endpos is None:
338             endpos = len(self.nodes)
339
340         words: Dict[str, List[TokenRange]] = defaultdict(list)
341
342         for first in range(start, endpos - 1):
343             word = self.nodes[first + 1].term_lookup
344             penalty = base_penalty
345             words[word].append(TokenRange(first, first + 1, penalty=penalty))
346             if self.nodes[first + 1].btype != BREAK_PHRASE:
347                 for last in range(first + 2, min(first + 20, endpos)):
348                     word = ' '.join((word, self.nodes[last].term_lookup))
349                     penalty += self.nodes[last - 1].penalty
350                     words[word].append(TokenRange(first, last, penalty=penalty))
351                     if self.nodes[last].btype == BREAK_PHRASE:
352                         break
353
354         return words