]> git.openstreetmap.org Git - nominatim.git/blob - src/nominatim_api/search/token_assignment.py
fix style issue found by flake8
[nominatim.git] / src / nominatim_api / search / token_assignment.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2024 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
Create query interpretations where each vertex in the query is assigned
a specific function (expressed as a token type).
10 """
11 from typing import Optional, List, Iterator
12 import dataclasses
13
14 from ..logging import log
15 from . import query as qmod
16
17
@dataclasses.dataclass
class TypedRange:
    """ A token range for a specific type of tokens.

        Pairs a range of the tokenized query with the token type that
        the range is interpreted as.
    """
    # Token type assigned to the range.
    ttype: qmod.TokenType
    # Part of the tokenized query that is covered by this range.
    trange: qmod.TokenRange
24
25
# Penalty per break type. It is added when a new typed range is started
# at a break of the given type (see _TokenSequence.advance()) and when
# a name/address split is placed at such a break.
PENALTY_TOKENCHANGE = {
    qmod.BreakType.START: 0.0,
    qmod.BreakType.END: 0.0,
    qmod.BreakType.PHRASE: 0.0,
    qmod.BreakType.WORD: 0.1,
    qmod.BreakType.PART: 0.2,
    qmod.BreakType.TOKEN: 0.4
}

# A list of typed ranges, as collected while traversing the query.
TypedRangeSeq = List[TypedRange]
36
37
@dataclasses.dataclass
class TokenAssignment:
    """ One possible interpretation of a tokenized query: each field
        holds the token range(s) that were given the respective role.
    """
    # Penalty of this interpretation.
    penalty: float = 0.0
    # Single-valued roles stay None when no range was assigned to them.
    name: Optional[qmod.TokenRange] = None
    address: List[qmod.TokenRange] = dataclasses.field(default_factory=list)
    housenumber: Optional[qmod.TokenRange] = None
    postcode: Optional[qmod.TokenRange] = None
    country: Optional[qmod.TokenRange] = None
    near_item: Optional[qmod.TokenRange] = None
    qualifier: Optional[qmod.TokenRange] = None

    @staticmethod
    def from_ranges(ranges: TypedRangeSeq) -> 'TokenAssignment':
        """ Build a token assignment from a sequence of typed ranges.

            Ranges of type PARTIAL are collected in `address` in the
            order given. Housenumber, postcode, country, near-item and
            qualifier ranges fill the field of the same name; if a type
            appears more than once, the last range wins. Ranges of any
            other type are ignored. `name` and `penalty` keep their
            defaults.
        """
        # Token types that map onto exactly one field of the assignment.
        single_fields = (
            (qmod.TokenType.HOUSENUMBER, 'housenumber'),
            (qmod.TokenType.POSTCODE, 'postcode'),
            (qmod.TokenType.COUNTRY, 'country'),
            (qmod.TokenType.NEAR_ITEM, 'near_item'),
            (qmod.TokenType.QUALIFIER, 'qualifier'),
        )

        result = TokenAssignment()
        for item in ranges:
            if item.ttype == qmod.TokenType.PARTIAL:
                result.address.append(item.trange)
                continue

            for ttype, field_name in single_fields:
                if item.ttype == ttype:
                    setattr(result, field_name, item.trange)
                    break

        return result
71
72
class _TokenSequence:
    """ Working state used to put together the token assignments.

        Represents an intermediate state while traversing the tokenized
        query. `seq` holds the typed ranges assigned so far and `penalty`
        the penalty accumulated for them. `direction` is the reading
        direction of the query: 1 for left-to-right, -1 for right-to-left
        and 0 as long as the direction is still undecided.
    """
    def __init__(self, seq: TypedRangeSeq,
                 direction: int = 0, penalty: float = 0.0) -> None:
        self.seq = seq
        self.direction = direction
        self.penalty = penalty

    def __str__(self) -> str:
        seq = ''.join(f'[{r.trange.start} - {r.trange.end}: {r.ttype.name}]' for r in self.seq)
        return f'{seq} (dir: {self.direction}, penalty: {self.penalty})'

    @property
    def end_pos(self) -> int:
        """ Return the index of the global end of the current sequence.
            An empty sequence ends at position 0.
        """
        return self.seq[-1].trange.end if self.seq else 0

    def has_types(self, *ttypes: qmod.TokenType) -> bool:
        """ Check if the current sequence contains any typed ranges of
            the given types.
        """
        return any(s.ttype in ttypes for s in self.seq)

    def is_final(self) -> bool:
        """ Return true when the sequence cannot be extended by any
            form of token anymore.
        """
        # Country and category (near item) must be the final term for
        # left-to-right. A sequence consisting only of such a term can
        # still be extended.
        return len(self.seq) > 1 and \
            self.seq[-1].ttype in (qmod.TokenType.COUNTRY, qmod.TokenType.NEAR_ITEM)

    def appendable(self, ttype: qmod.TokenType) -> Optional[int]:
        """ Check if the given token type is appendable to the existing sequence.

            Returns None if the token type is not appendable, otherwise the
            new direction of the sequence after adding such a type. The
            token is not added.
        """
        if ttype == qmod.TokenType.WORD:
            return None

        if not self.seq:
            # Append unconditionally to the empty list
            if ttype == qmod.TokenType.COUNTRY:
                return -1
            if ttype in (qmod.TokenType.HOUSENUMBER, qmod.TokenType.QUALIFIER):
                return 1
            return self.direction

        # Name tokens are always acceptable and don't change direction
        if ttype == qmod.TokenType.PARTIAL:
            # qualifiers cannot appear in the middle of the query. They need
            # to be near the next phrase.
            if self.direction == -1 \
               and any(t.ttype == qmod.TokenType.QUALIFIER for t in self.seq[:-1]):
                return None
            return self.direction

        # Other tokens may only appear once
        if self.has_types(ttype):
            return None

        if ttype == qmod.TokenType.HOUSENUMBER:
            if self.direction == 1:
                if len(self.seq) == 1 and self.seq[0].ttype == qmod.TokenType.QUALIFIER:
                    return None
                if len(self.seq) > 2 \
                   or self.has_types(qmod.TokenType.POSTCODE, qmod.TokenType.COUNTRY):
                    return None  # direction left-to-right: housenumber must come before anything
            elif (self.direction == -1
                  or self.has_types(qmod.TokenType.POSTCODE, qmod.TokenType.COUNTRY)):
                return -1  # force direction right-to-left if after other terms

            return self.direction

        if ttype == qmod.TokenType.POSTCODE:
            if self.direction == -1:
                if self.has_types(qmod.TokenType.HOUSENUMBER, qmod.TokenType.QUALIFIER):
                    return None
                return -1
            if self.direction == 1:
                return None if self.has_types(qmod.TokenType.COUNTRY) else 1
            if self.has_types(qmod.TokenType.HOUSENUMBER, qmod.TokenType.QUALIFIER):
                return 1
            return self.direction

        if ttype == qmod.TokenType.COUNTRY:
            return None if self.direction == -1 else 1

        if ttype == qmod.TokenType.NEAR_ITEM:
            return self.direction

        if ttype == qmod.TokenType.QUALIFIER:
            if self.direction == 1:
                if (len(self.seq) == 1
                    and self.seq[0].ttype in (qmod.TokenType.PARTIAL, qmod.TokenType.NEAR_ITEM)) \
                   or (len(self.seq) == 2
                       and self.seq[0].ttype == qmod.TokenType.NEAR_ITEM
                       and self.seq[1].ttype == qmod.TokenType.PARTIAL):
                    return 1
                return None
            if self.direction == -1:
                return -1

            # Direction is still undecided. A leading near-item is ignored
            # for the positional checks below.
            tempseq = self.seq[1:] if self.seq[0].ttype == qmod.TokenType.NEAR_ITEM else self.seq
            if len(tempseq) == 0:
                return 1
            # A qualifier may not directly follow a housenumber that opens
            # the query. The check must look at tempseq: self.seq[0] can
            # never be a housenumber here because a sequence built through
            # advance() that starts with a housenumber has direction 1.
            if len(tempseq) == 1 and tempseq[0].ttype == qmod.TokenType.HOUSENUMBER:
                return None
            if len(tempseq) > 1 or self.has_types(qmod.TokenType.POSTCODE, qmod.TokenType.COUNTRY):
                return -1
            return 0

        return None

    def advance(self, ttype: qmod.TokenType, end_pos: int,
                btype: qmod.BreakType) -> Optional['_TokenSequence']:
        """ Return a new token sequence state with the given token type
            extended.

            `end_pos` is the position where the added token ends and
            `btype` the type of the break in front of it. Returns None
            when the token type cannot be appended. The current state
            is left unchanged.
        """
        newdir = self.appendable(ttype)
        if newdir is None:
            return None

        if not self.seq:
            newseq = [TypedRange(ttype, qmod.TokenRange(0, end_pos))]
            new_penalty = 0.0
        else:
            last = self.seq[-1]
            if btype != qmod.BreakType.PHRASE and last.ttype == ttype:
                # extend the existing range
                newseq = self.seq[:-1] + [TypedRange(ttype, last.trange.replace_end(end_pos))]
                new_penalty = 0.0
            else:
                # start a new range
                newseq = list(self.seq) + [TypedRange(ttype,
                                                      qmod.TokenRange(last.trange.end, end_pos))]
                new_penalty = PENALTY_TOKENCHANGE[btype]

        return _TokenSequence(newseq, newdir, self.penalty + new_penalty)

    def _adapt_penalty_from_priors(self, priors: int, new_dir: int) -> bool:
        """ Apply the housenumber distance rule for one side of the
            housenumber, where `priors` is the number of partial ranges
            found on that side.

            Fewer than two ranges change nothing. With two or more, an
            undecided direction is set to `new_dir`. If the direction is
            already decided, exactly two ranges add a penalty of 0.8 and
            more than two make the sequence invalid.

            Returns False when the sequence is invalid, True otherwise.
        """
        if priors >= 2:
            if self.direction == 0:
                self.direction = new_dir
            else:
                if priors == 2:
                    self.penalty += 0.8
                else:
                    return False

        return True

    def recheck_sequence(self) -> bool:
        """ Check that the sequence is a fully valid token assignment
            and adapt direction and penalties further if necessary.

            This function catches some impossible assignments that need
            forward context and can therefore not be excluded when building
            the assignment.

            Returns False when the sequence needs to be discarded.
        """
        # housenumbers may not be further than 2 words from the beginning.
        # If there are two words in front, give it a penalty.
        hnrpos = next((i for i, tr in enumerate(self.seq)
                       if tr.ttype == qmod.TokenType.HOUSENUMBER),
                      None)
        if hnrpos is not None:
            if self.direction != -1:
                priors = sum(1 for t in self.seq[:hnrpos] if t.ttype == qmod.TokenType.PARTIAL)
                if not self._adapt_penalty_from_priors(priors, -1):
                    return False
            # The first check may have just set the direction to -1, in
            # which case the terms after the housenumber are checked as well.
            if self.direction != 1:
                priors = sum(1 for t in self.seq[hnrpos+1:] if t.ttype == qmod.TokenType.PARTIAL)
                if not self._adapt_penalty_from_priors(priors, 1):
                    return False
            # Housenumber together with a near-item gets an extra penalty.
            if any(t.ttype == qmod.TokenType.NEAR_ITEM for t in self.seq):
                self.penalty += 1.0

        return True

    def _get_assignments_postcode(self, base: TokenAssignment,
                                  query_len: int) -> Iterator[TokenAssignment]:
        """ Yield possible assignments of Postcode searches with an
            address component.

            An assignment is only produced when the postcode sits at the
            very beginning or the very end of the query (`query_len` being
            the end position) and the direction does not contradict it.
            As a side effect, the direction of the sequence is then fixed
            for the name searches that follow.
        """
        assert base.postcode is not None

        if (base.postcode.start == 0 and self.direction != -1)\
           or (base.postcode.end == query_len and self.direction != 1):
            log().comment('postcode search')
            # <address>,<postcode> should give preference to address search
            if base.postcode.start == 0:
                penalty = self.penalty
                self.direction = -1  # name searches are only possible backwards
            else:
                penalty = self.penalty + 0.1
                self.direction = 1  # name searches are only possible forwards
            yield dataclasses.replace(base, penalty=penalty)

    def _get_assignments_address_forward(self, base: TokenAssignment,
                                         query: qmod.QueryStruct) -> Iterator[TokenAssignment]:
        """ Yield possible assignments of address searches with
            left-to-right reading.

            The first address range is always offered as the name. Where
            permitted, every split of that first range into a leading
            name part and a trailing address part is offered as well.
        """
        first = base.address[0]

        log().comment('first word = name')
        yield dataclasses.replace(base, penalty=self.penalty,
                                  name=first, address=base.address[1:])

        # No splitting of the first range:
        #  * if another name term comes after the first one and before the
        #    housenumber
        #  * a qualifier comes after the name
        #  * the containing phrase is strictly typed
        if (base.housenumber and first.end < base.housenumber.start)\
           or (base.qualifier and base.qualifier > first)\
           or (query.nodes[first.start].ptype != qmod.PhraseType.NONE):
            return

        penalty = self.penalty

        # Penalty for:
        #  * <name>, <street>, <housenumber> , ...
        #  * queries that are comma-separated
        if (base.housenumber and base.housenumber > first) or len(query.source) > 1:
            penalty += 0.25

        for i in range(first.start + 1, first.end):
            name, addr = first.split(i)
            log().comment(f'split first word = name ({i - first.start})')
            yield dataclasses.replace(base, name=name, address=[addr] + base.address[1:],
                                      penalty=penalty + PENALTY_TOKENCHANGE[query.nodes[i].btype])

    def _get_assignments_address_backward(self, base: TokenAssignment,
                                          query: qmod.QueryStruct) -> Iterator[TokenAssignment]:
        """ Yield possible assignments of address searches with
            right-to-left reading.

            The last address range is offered as the name when the
            direction is right-to-left or when there is more than one
            address range. Where permitted, every split of the last range
            into a leading address part and a trailing name part is
            offered as well.
        """
        last = base.address[-1]

        if self.direction == -1 or len(base.address) > 1:
            log().comment('last word = name')
            yield dataclasses.replace(base, penalty=self.penalty,
                                      name=last, address=base.address[:-1])

        # No splitting of the last range:
        #  * if another name term comes before the last one and after the
        #    housenumber
        #  * a qualifier comes before the name
        #  * the containing phrase is strictly typed
        if (base.housenumber and last.start > base.housenumber.end)\
           or (base.qualifier and base.qualifier < last)\
           or (query.nodes[last.start].ptype != qmod.PhraseType.NONE):
            return

        penalty = self.penalty
        if base.housenumber and base.housenumber < last:
            penalty += 0.4
        if len(query.source) > 1:
            penalty += 0.25

        for i in range(last.start + 1, last.end):
            addr, name = last.split(i)
            log().comment(f'split last word = name ({i - last.start})')
            yield dataclasses.replace(base, name=name, address=base.address[:-1] + [addr],
                                      penalty=penalty + PENALTY_TOKENCHANGE[query.nodes[i].btype])

    def get_assignments(self, query: qmod.QueryStruct) -> Iterator[TokenAssignment]:
        """ Yield possible assignments for the current sequence.

            This function splits up general name assignments into name
            and address and yields all possible variants of that.

            Note that producing the assignments may change the direction
            and the penalty of this sequence.
        """
        base = TokenAssignment.from_ranges(self.seq)

        # Give up on sequences with an excessive number of address tokens.
        num_addr_tokens = sum(t.end - t.start for t in base.address)
        if num_addr_tokens > 50:
            return

        # Postcode search (postcode-only search is covered in next case)
        if base.postcode is not None and base.address:
            yield from self._get_assignments_postcode(base, query.num_token_slots())

        # Postcode or country-only search
        if not base.address:
            if not base.housenumber and (base.postcode or base.country or base.near_item):
                log().comment('postcode/country search')
                yield dataclasses.replace(base, penalty=self.penalty)
        else:
            # <postcode>,<address> should give preference to postcode search
            if base.postcode and base.postcode.start == 0:
                self.penalty += 0.1

            # Left-to-right reading of the address
            if self.direction != -1:
                yield from self._get_assignments_address_forward(base, query)

            # Right-to-left reading of the address
            if self.direction != 1:
                yield from self._get_assignments_address_backward(base, query)

            # variant for special housenumber searches
            if base.housenumber and not base.qualifier:
                yield dataclasses.replace(base, penalty=self.penalty)
383
384
def yield_token_assignments(query: qmod.QueryStruct) -> Iterator[TokenAssignment]:
    """ Generate the possible word type assignments for the word
        positions of the given tokenized query.

        Assignments are derived from the concrete tokens that are listed
        in the query.

        The penalty of each result covers the transitions from one word
        type to another. Transitions within a type are not penalized here.
    """
    # A query whose first phrase is typed is read left-to-right.
    start_direction = 0 if query.source[0].ptype == qmod.PhraseType.NONE else 1
    pending = [_TokenSequence([], direction=start_direction)]

    while pending:
        current = pending.pop()
        node = query.nodes[current.end_pos]

        for tlist in node.starting:
            candidate = current.advance(tlist.ttype, tlist.end, node.btype)
            if candidate is None:
                continue

            if candidate.end_pos != query.num_token_slots():
                # Not yet at the end of the query: keep extending unless
                # nothing may follow the last term.
                if not candidate.is_final():
                    pending.append(candidate)
                continue

            if candidate.recheck_sequence():
                log().var_dump('Assignment', candidate)
                yield from candidate.get_assignments(query)