git.openstreetmap.org Git - nominatim.git/blob - src/nominatim_api/search/token_assignment.py
port unit tests to new python package layout
[nominatim.git] / src / nominatim_api / search / token_assignment.py
# SPDX-License-Identifier: GPL-3.0-or-later
#
# This file is part of Nominatim. (https://nominatim.org)
#
# Copyright (C) 2024 by the Nominatim developer community.
# For a full list of authors see the git log.
"""
Create query interpretations where each range of positions in the
tokenized query is assigned a specific function (expressed as a token type).
"""
11 from typing import Optional, List, Iterator
12 import dataclasses
13
14 from ..logging import log
15 from . import query as qmod
16
17 # pylint: disable=too-many-return-statements,too-many-branches
18
@dataclasses.dataclass()
class TypedRange:
    """ Pairs a range of query positions with the token type
        that has been chosen for it.
    """
    # Type of token assigned to the range.
    ttype: qmod.TokenType
    # Start and end position of the range within the query.
    trange: qmod.TokenRange
25
26
# Penalty added when a new typed range is started, keyed by the kind of
# break that separates it from the previous range. Starting a new range
# at the start/end of the query or at a phrase boundary is free; the
# penalty grows for word, part and token breaks.
PENALTY_TOKENCHANGE = {
    qmod.BreakType.START: 0.0,
    qmod.BreakType.END: 0.0,
    qmod.BreakType.PHRASE: 0.0,
    qmod.BreakType.WORD: 0.1,
    qmod.BreakType.PART: 0.2,
    qmod.BreakType.TOKEN: 0.4,
}

# An ordered sequence of typed ranges, as built up while walking the query.
TypedRangeSeq = List[TypedRange]
37
@dataclasses.dataclass
class TokenAssignment: # pylint: disable=too-many-instance-attributes
    """ One candidate interpretation of a tokenized query. Every field
        holds the token range(s) assigned to a particular role, and
        `penalty` records how unlikely the interpretation is.
    """
    penalty: float = 0.0
    name: Optional[qmod.TokenRange] = None
    address: List[qmod.TokenRange] = dataclasses.field(default_factory=list)
    housenumber: Optional[qmod.TokenRange] = None
    postcode: Optional[qmod.TokenRange] = None
    country: Optional[qmod.TokenRange] = None
    near_item: Optional[qmod.TokenRange] = None
    qualifier: Optional[qmod.TokenRange] = None


    @staticmethod
    def from_ranges(ranges: TypedRangeSeq) -> 'TokenAssignment':
        """ Build a token assignment out of a sequence of typed ranges.

            Partial ranges are collected, in order, into `address`. Each
            of the single-valued types is stored in its own field (a later
            range of the same type overwrites an earlier one). Any other
            token type is skipped. `name` is not filled in here.
        """
        single_valued = {
            qmod.TokenType.HOUSENUMBER: 'housenumber',
            qmod.TokenType.POSTCODE: 'postcode',
            qmod.TokenType.COUNTRY: 'country',
            qmod.TokenType.NEAR_ITEM: 'near_item',
            qmod.TokenType.QUALIFIER: 'qualifier',
        }

        assignment = TokenAssignment()
        for typed in ranges:
            if typed.ttype == qmod.TokenType.PARTIAL:
                assignment.address.append(typed.trange)
                continue
            field_name = single_valued.get(typed.ttype)
            if field_name is not None:
                setattr(assignment, field_name, typed.trange)
        return assignment
72
73
class _TokenSequence:
    """ Working state used to put together the token assignments.

        Represents an intermediate state while traversing the tokenized
        query: the typed ranges collected so far (`seq`), the reading
        direction they imply (`direction`: 0 = not decided yet,
        1 = left-to-right, -1 = right-to-left) and the penalty accumulated
        for changes of token type so far.
    """
    def __init__(self, seq: TypedRangeSeq,
                 direction: int = 0, penalty: float = 0.0) -> None:
        self.seq = seq
        self.direction = direction
        self.penalty = penalty


    def __str__(self) -> str:
        seq = ''.join(f'[{r.trange.start} - {r.trange.end}: {r.ttype.name}]' for r in self.seq)
        return f'{seq} (dir: {self.direction}, penalty: {self.penalty})'


    @property
    def end_pos(self) -> int:
        """ Return the index of the global end of the current sequence
            (0 for an empty sequence).
        """
        return self.seq[-1].trange.end if self.seq else 0


    def has_types(self, *ttypes: qmod.TokenType) -> bool:
        """ Check if the current sequence contains any typed ranges of
            the given types.
        """
        return any(s.ttype in ttypes for s in self.seq)


    def is_final(self) -> bool:
        """ Return true when the sequence cannot be extended by any
            form of token anymore.
        """
        # A country or near item (category) ends the sequence unless it is
        # the very first term. (Appending a country always switches the
        # direction to left-to-right, see appendable().)
        return len(self.seq) > 1 and \
               self.seq[-1].ttype in (qmod.TokenType.COUNTRY, qmod.TokenType.NEAR_ITEM)


    def appendable(self, ttype: qmod.TokenType) -> Optional[int]:
        """ Check if the given token type is appendable to the existing sequence.

            Returns None if the token type is not appendable, otherwise the
            new direction of the sequence after adding such a type. The
            token is not added.
        """
        if ttype == qmod.TokenType.WORD:
            return None

        if not self.seq:
            # Append unconditionally to the empty list
            if ttype == qmod.TokenType.COUNTRY:
                return -1
            if ttype in (qmod.TokenType.HOUSENUMBER, qmod.TokenType.QUALIFIER):
                return 1
            return self.direction

        # Partial (name) tokens are generally acceptable and don't change direction
        if ttype == qmod.TokenType.PARTIAL:
            # qualifiers cannot appear in the middle of the query. They need
            # to be near the next phrase.
            if self.direction == -1 \
               and any(t.ttype == qmod.TokenType.QUALIFIER for t in self.seq[:-1]):
                return None
            return self.direction

        # Other tokens may only appear once
        if self.has_types(ttype):
            return None

        if ttype == qmod.TokenType.HOUSENUMBER:
            if self.direction == 1:
                if len(self.seq) == 1 and self.seq[0].ttype == qmod.TokenType.QUALIFIER:
                    return None
                if len(self.seq) > 2 \
                   or self.has_types(qmod.TokenType.POSTCODE, qmod.TokenType.COUNTRY):
                    return None # direction left-to-right: housenumber must come before anything
            elif self.direction == -1 \
                 or self.has_types(qmod.TokenType.POSTCODE, qmod.TokenType.COUNTRY):
                return -1 # force direction right-to-left if after other terms

            return self.direction

        if ttype == qmod.TokenType.POSTCODE:
            if self.direction == -1:
                if self.has_types(qmod.TokenType.HOUSENUMBER, qmod.TokenType.QUALIFIER):
                    return None
                return -1
            if self.direction == 1:
                return None if self.has_types(qmod.TokenType.COUNTRY) else 1
            if self.has_types(qmod.TokenType.HOUSENUMBER, qmod.TokenType.QUALIFIER):
                return 1
            return self.direction

        if ttype == qmod.TokenType.COUNTRY:
            return None if self.direction == -1 else 1

        if ttype == qmod.TokenType.NEAR_ITEM:
            return self.direction

        if ttype == qmod.TokenType.QUALIFIER:
            if self.direction == 1:
                if (len(self.seq) == 1
                    and self.seq[0].ttype in (qmod.TokenType.PARTIAL, qmod.TokenType.NEAR_ITEM)) \
                   or (len(self.seq) == 2
                       and self.seq[0].ttype == qmod.TokenType.NEAR_ITEM
                       and self.seq[1].ttype == qmod.TokenType.PARTIAL):
                    return 1
                return None
            if self.direction == -1:
                return -1

            # Direction not decided yet: look at the sequence without a
            # leading near item.
            tempseq = self.seq[1:] if self.seq[0].ttype == qmod.TokenType.NEAR_ITEM else self.seq
            if len(tempseq) == 0:
                return 1
            # A qualifier may not directly follow a lone housenumber (the
            # left-to-right branch above rejects this as well). This must
            # inspect tempseq: self.seq[0] can never be a housenumber here.
            if len(tempseq) == 1 and tempseq[0].ttype == qmod.TokenType.HOUSENUMBER:
                return None
            if len(tempseq) > 1 or self.has_types(qmod.TokenType.POSTCODE, qmod.TokenType.COUNTRY):
                return -1
            return 0

        return None


    def advance(self, ttype: qmod.TokenType, end_pos: int,
                btype: qmod.BreakType) -> Optional['_TokenSequence']:
        """ Return a new token sequence state with the given token type
            extended up to `end_pos`, or None if the type cannot be appended.

            When the previous range has the same type and is not separated
            by a phrase break, that range is extended without penalty.
            Otherwise a new range is started, and the break type's
            PENALTY_TOKENCHANGE value is added to the penalty.
        """
        newdir = self.appendable(ttype)
        if newdir is None:
            return None

        if not self.seq:
            newseq = [TypedRange(ttype, qmod.TokenRange(0, end_pos))]
            new_penalty = 0.0
        else:
            last = self.seq[-1]
            if btype != qmod.BreakType.PHRASE and last.ttype == ttype:
                # extend the existing range
                newseq = self.seq[:-1] + [TypedRange(ttype, last.trange.replace_end(end_pos))]
                new_penalty = 0.0
            else:
                # start a new range
                newseq = list(self.seq) + [TypedRange(ttype,
                                                      qmod.TokenRange(last.trange.end, end_pos))]
                new_penalty = PENALTY_TOKENCHANGE[btype]

        return _TokenSequence(newseq, newdir, self.penalty + new_penalty)


    def _adapt_penalty_from_priors(self, priors: int, new_dir: int) -> bool:
        """ Adapt direction and penalty given the number of partial terms
            (`priors`) found on one side of the housenumber.

            With two or more such terms, an undecided direction is switched
            to `new_dir`. If the direction is already fixed, exactly two
            terms add a penalty of 0.8, and more than two reject the
            sequence. Returns False when the sequence must be discarded.
        """
        if priors >= 2:
            if self.direction == 0:
                self.direction = new_dir
            else:
                if priors == 2:
                    self.penalty += 0.8
                else:
                    return False

        return True


    def recheck_sequence(self) -> bool:
        """ Check that the sequence is a fully valid token assignment
            and adapt direction and penalties further if necessary.

            This function catches some impossible assignments that need
            forward context and can therefore not be excluded when building
            the assignment. May modify `direction` and `penalty`.
        """
        # housenumbers may not be further than 2 words from the beginning.
        # If there are two words in front, give it a penalty.
        hnrpos = next((i for i, tr in enumerate(self.seq)
                       if tr.ttype == qmod.TokenType.HOUSENUMBER),
                      None)
        if hnrpos is not None:
            if self.direction != -1:
                priors = sum(1 for t in self.seq[:hnrpos] if t.ttype == qmod.TokenType.PARTIAL)
                if not self._adapt_penalty_from_priors(priors, -1):
                    return False
            if self.direction != 1:
                priors = sum(1 for t in self.seq[hnrpos+1:] if t.ttype == qmod.TokenType.PARTIAL)
                if not self._adapt_penalty_from_priors(priors, 1):
                    return False
            if any(t.ttype == qmod.TokenType.NEAR_ITEM for t in self.seq):
                self.penalty += 1.0

        return True


    def _get_assignments_postcode(self, base: TokenAssignment,
                                  query_len: int)  -> Iterator[TokenAssignment]:
        """ Yield possible assignments of Postcode searches with an
            address component.

            Side effect: when a postcode search is yielded, `self.direction`
            is restricted so that the subsequent name searches in
            get_assignments() only read away from the postcode.
        """
        assert base.postcode is not None

        if (base.postcode.start == 0 and self.direction != -1)\
           or (base.postcode.end == query_len and self.direction != 1):
            log().comment('postcode search')
            # <address>,<postcode> should give preference to address search
            if base.postcode.start == 0:
                penalty = self.penalty
                self.direction = -1 # name searches are only possible backwards
            else:
                penalty = self.penalty + 0.1
                self.direction = 1 # name searches are only possible forwards
            yield dataclasses.replace(base, penalty=penalty)


    def _get_assignments_address_forward(self, base: TokenAssignment,
                                         query: qmod.QueryStruct) -> Iterator[TokenAssignment]:
        """ Yield possible assignments of address searches with
            left-to-right reading.
        """
        first = base.address[0]

        log().comment('first word = name')
        yield dataclasses.replace(base, penalty=self.penalty,
                                  name=first, address=base.address[1:])

        # To paraphrase:
        #  * if another name term comes after the first one and before the
        #    housenumber
        #  * a qualifier comes after the name
        #  * the containing phrase is strictly typed
        if (base.housenumber and first.end < base.housenumber.start)\
           or (base.qualifier and base.qualifier > first)\
           or (query.nodes[first.start].ptype != qmod.PhraseType.NONE):
            return

        penalty = self.penalty

        # Penalty for:
        #  * <name>, <street>, <housenumber> , ...
        #  * queries that are comma-separated
        if (base.housenumber and base.housenumber > first) or len(query.source) > 1:
            penalty += 0.25

        for i in range(first.start + 1, first.end):
            name, addr = first.split(i)
            log().comment(f'split first word = name ({i - first.start})')
            yield dataclasses.replace(base, name=name, address=[addr] + base.address[1:],
                                      penalty=penalty + PENALTY_TOKENCHANGE[query.nodes[i].btype])


    def _get_assignments_address_backward(self, base: TokenAssignment,
                                          query: qmod.QueryStruct) -> Iterator[TokenAssignment]:
        """ Yield possible assignments of address searches with
            right-to-left reading.
        """
        last = base.address[-1]

        if self.direction == -1 or len(base.address) > 1:
            log().comment('last word = name')
            yield dataclasses.replace(base, penalty=self.penalty,
                                      name=last, address=base.address[:-1])

        # To paraphrase:
        #  * if another name term comes before the last one and after the
        #    housenumber
        #  * a qualifier comes before the name
        #  * the containing phrase is strictly typed
        if (base.housenumber and last.start > base.housenumber.end)\
           or (base.qualifier and base.qualifier < last)\
           or (query.nodes[last.start].ptype != qmod.PhraseType.NONE):
            return

        penalty = self.penalty
        if base.housenumber and base.housenumber < last:
            penalty += 0.4
        if len(query.source) > 1:
            penalty += 0.25

        for i in range(last.start + 1, last.end):
            addr, name = last.split(i)
            log().comment(f'split last word = name ({i - last.start})')
            yield dataclasses.replace(base, name=name, address=base.address[:-1] + [addr],
                                      penalty=penalty + PENALTY_TOKENCHANGE[query.nodes[i].btype])


    def get_assignments(self, query: qmod.QueryStruct) -> Iterator[TokenAssignment]:
        """ Yield possible assignments for the current sequence.

            This function splits up general name assignments into name
            and address and yields all possible variants of that.
            Note that it may modify `self.penalty` and `self.direction`.
        """
        base = TokenAssignment.from_ranges(self.seq)

        # Give up on sequences whose address part spans more than 50 tokens
        # (presumably to bound the work for very long queries).
        num_addr_tokens = sum(t.end - t.start for t in base.address)
        if num_addr_tokens > 50:
            return

        # Postcode search (postcode-only search is covered in next case)
        if base.postcode is not None and base.address:
            yield from self._get_assignments_postcode(base, query.num_token_slots())

        # Postcode or country-only search
        if not base.address:
            if not base.housenumber and (base.postcode or base.country or base.near_item):
                log().comment('postcode/country search')
                yield dataclasses.replace(base, penalty=self.penalty)
        else:
            # <postcode>,<address> should give preference to postcode search
            if base.postcode and base.postcode.start == 0:
                self.penalty += 0.1

            # Left-to-right reading of the address
            if self.direction != -1:
                yield from self._get_assignments_address_forward(base, query)

            # Right-to-left reading of the address
            if self.direction != 1:
                yield from self._get_assignments_address_backward(base, query)

            # variant for special housenumber searches
            if base.housenumber and not base.qualifier:
                yield dataclasses.replace(base, penalty=self.penalty)
397
def yield_token_assignments(query: qmod.QueryStruct) -> Iterator[TokenAssignment]:
    """ Enumerate the possible assignments of token types to the
        positions of the tokenized query.

        The concrete tokens of the query are walked depth-first. Each
        pending partial sequence is extended with every token list that
        starts at its current end. A sequence that reaches the end of the
        query is rechecked and expanded into TokenAssignments. A shorter
        one is queued again unless it can no longer be extended.

        The penalties of the results include transitions from one word
        type to another, but not transitions within a type.
    """
    if query.source[0].ptype == qmod.PhraseType.NONE:
        initial_direction = 0
    else:
        initial_direction = 1
    stack = [_TokenSequence([], direction=initial_direction)]

    while stack:
        current = stack.pop()
        node = query.nodes[current.end_pos]

        for token_list in node.starting:
            extended = current.advance(token_list.ttype, token_list.end, node.btype)
            if extended is None:
                continue

            if extended.end_pos != query.num_token_slots():
                # Not at the end of the query yet: keep it for later
                # unless no further token may be appended.
                if not extended.is_final():
                    stack.append(extended)
                continue

            if extended.recheck_sequence():
                log().var_dump('Assignment', extended)
                yield from extended.get_assignments(query)