]> git.openstreetmap.org Git - nominatim.git/blob - src/nominatim_api/search/token_assignment.py
de75318ae24b51f3d27e8d4de23c624b65c6d7b9
[nominatim.git] / src / nominatim_api / search / token_assignment.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2024 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
Create query interpretations where each vertex in the query is assigned
a specific function (expressed as a token type).
10 """
11 from typing import Optional, List, Iterator
12 import dataclasses
13
14 from ..logging import log
15 from . import query as qmod
16
17
@dataclasses.dataclass
class TypedRange:
    """ Pairs a range of query tokens with the token type that the
        range is interpreted as.
    """
    # Type assigned to the range as a whole.
    ttype: qmod.TokenType
    # The span of the query that is covered.
    trange: qmod.TokenRange
24
25
# Penalty charged, per break type, when a new typed range is started at
# a break of that type or when a name is split off from an address
# range at such a break. Start, end and phrase breaks are free; the
# penalty rises from word over part to token breaks.
PENALTY_TOKENCHANGE = {
    qmod.BREAK_START: 0.0,
    qmod.BREAK_END: 0.0,
    qmod.BREAK_PHRASE: 0.0,
    qmod.BREAK_SOFT_PHRASE: 0.0,
    qmod.BREAK_WORD: 0.1,
    qmod.BREAK_PART: 0.2,
    qmod.BREAK_TOKEN: 0.4
}

# An ordered list of typed token ranges.
TypedRangeSeq = List[TypedRange]
37
38
@dataclasses.dataclass
class TokenAssignment:
    """ One interpretation of a tokenized query: for every role a
        token range can play in a search, the range that takes it.
    """
    penalty: float = 0.0
    name: Optional[qmod.TokenRange] = None
    address: List[qmod.TokenRange] = dataclasses.field(default_factory=list)
    housenumber: Optional[qmod.TokenRange] = None
    postcode: Optional[qmod.TokenRange] = None
    country: Optional[qmod.TokenRange] = None
    near_item: Optional[qmod.TokenRange] = None
    qualifier: Optional[qmod.TokenRange] = None

    @staticmethod
    def from_ranges(ranges: TypedRangeSeq) -> 'TokenAssignment':
        """ Build a token assignment out of a sequence of typed ranges.

            Partial ranges are collected in the address list in the
            order they appear. Housenumber, postcode, country, near-item
            and qualifier ranges fill the field of the same name, a
            later range replacing an earlier one of the same type.
            Ranges of any other type are skipped. Name and penalty keep
            their defaults.
        """
        # Token types that map onto exactly one field of the assignment.
        single_fields = {
            qmod.TOKEN_HOUSENUMBER: 'housenumber',
            qmod.TOKEN_POSTCODE: 'postcode',
            qmod.TOKEN_COUNTRY: 'country',
            qmod.TOKEN_NEAR_ITEM: 'near_item',
            qmod.TOKEN_QUALIFIER: 'qualifier',
        }

        result = TokenAssignment()
        for span in ranges:
            if span.ttype == qmod.TOKEN_PARTIAL:
                result.address.append(span.trange)
                continue

            field_name = single_fields.get(span.ttype)
            if field_name is not None:
                setattr(result, field_name, span.trange)

        return result
72
73
class _TokenSequence:
    """ Working state used to put together the token assignments.

        Represents an intermediate state while traversing the tokenized
        query: the typed ranges assigned so far, the reading direction
        they imply and the penalty accumulated on the way.

        The direction is 1 for left-to-right reading, -1 for
        right-to-left reading and 0 as long as it is undecided.
    """
    def __init__(self, seq: TypedRangeSeq,
                 direction: int = 0, penalty: float = 0.0) -> None:
        self.seq = seq
        self.direction = direction
        self.penalty = penalty

    def __str__(self) -> str:
        seq = ''.join(f'[{r.trange.start} - {r.trange.end}: {r.ttype}]' for r in self.seq)
        return f'{seq} (dir: {self.direction}, penalty: {self.penalty})'

    @property
    def end_pos(self) -> int:
        """ Return the index of the global end of the current sequence.
            An empty sequence ends at position 0.
        """
        return self.seq[-1].trange.end if self.seq else 0

    def has_types(self, *ttypes: qmod.TokenType) -> bool:
        """ Check if the current sequence contains any typed ranges of
            the given types.
        """
        return any(s.ttype in ttypes for s in self.seq)

    def is_final(self) -> bool:
        """ Return true when the sequence cannot be extended by any
            form of token anymore.
        """
        # Country and category must be the final term for left-to-right
        return len(self.seq) > 1 and \
            self.seq[-1].ttype in (qmod.TOKEN_COUNTRY, qmod.TOKEN_NEAR_ITEM)

    def appendable(self, ttype: qmod.TokenType) -> Optional[int]:
        """ Check if the given token type is appendable to the existing sequence.

            Returns None if the token type is not appendable, otherwise the
            new direction of the sequence after adding such a type. The
            token is not added.
        """
        # Full words are never assigned a function of their own.
        if ttype == qmod.TOKEN_WORD:
            return None

        if not self.seq:
            # Append unconditionally to the empty list
            if ttype == qmod.TOKEN_COUNTRY:
                return -1
            if ttype in (qmod.TOKEN_HOUSENUMBER, qmod.TOKEN_QUALIFIER):
                return 1
            return self.direction

        # Name tokens are always acceptable and don't change direction
        if ttype == qmod.TOKEN_PARTIAL:
            # qualifiers cannot appear in the middle of the query. They need
            # to be near the next phrase.
            if self.direction == -1 \
               and any(t.ttype == qmod.TOKEN_QUALIFIER for t in self.seq[:-1]):
                return None
            return self.direction

        # Other tokens may only appear once
        if self.has_types(ttype):
            return None

        if ttype == qmod.TOKEN_HOUSENUMBER:
            if self.direction == 1:
                if len(self.seq) == 1 and self.seq[0].ttype == qmod.TOKEN_QUALIFIER:
                    return None
                if len(self.seq) > 2 \
                   or self.has_types(qmod.TOKEN_POSTCODE, qmod.TOKEN_COUNTRY):
                    return None  # direction left-to-right: housenumber must come before anything
            elif (self.direction == -1
                  or self.has_types(qmod.TOKEN_POSTCODE, qmod.TOKEN_COUNTRY)):
                return -1  # force direction right-to-left if after other terms

            return self.direction

        if ttype == qmod.TOKEN_POSTCODE:
            if self.direction == -1:
                if self.has_types(qmod.TOKEN_HOUSENUMBER, qmod.TOKEN_QUALIFIER):
                    return None
                return -1
            if self.direction == 1:
                return None if self.has_types(qmod.TOKEN_COUNTRY) else 1
            # Direction still undecided: a preceding housenumber or
            # qualifier fixes it to left-to-right.
            if self.has_types(qmod.TOKEN_HOUSENUMBER, qmod.TOKEN_QUALIFIER):
                return 1
            return self.direction

        if ttype == qmod.TOKEN_COUNTRY:
            # A country after other terms is only possible when reading
            # left-to-right.
            return None if self.direction == -1 else 1

        if ttype == qmod.TOKEN_NEAR_ITEM:
            return self.direction

        if ttype == qmod.TOKEN_QUALIFIER:
            if self.direction == 1:
                # Only directly after a leading name or near-item, or
                # after the combination near-item + name.
                if (len(self.seq) == 1
                    and self.seq[0].ttype in (qmod.TOKEN_PARTIAL, qmod.TOKEN_NEAR_ITEM)) \
                   or (len(self.seq) == 2
                       and self.seq[0].ttype == qmod.TOKEN_NEAR_ITEM
                       and self.seq[1].ttype == qmod.TOKEN_PARTIAL):
                    return 1
                return None
            if self.direction == -1:
                return -1

            # Direction undecided: a leading near-item is ignored when
            # looking at what precedes the qualifier.
            tempseq = self.seq[1:] if self.seq[0].ttype == qmod.TOKEN_NEAR_ITEM else self.seq
            if len(tempseq) == 0:
                return 1
            # NOTE(review): this inspects self.seq[0] rather than tempseq[0],
            # so a housenumber that follows a leading near-item is not
            # caught here - confirm that this is intended.
            if len(tempseq) == 1 and self.seq[0].ttype == qmod.TOKEN_HOUSENUMBER:
                return None
            if len(tempseq) > 1 or self.has_types(qmod.TOKEN_POSTCODE, qmod.TOKEN_COUNTRY):
                return -1
            return 0

        return None

    def advance(self, ttype: qmod.TokenType, end_pos: int,
                btype: qmod.BreakType) -> Optional['_TokenSequence']:
        """ Return a new token sequence state with the given token type
            extended.

            `end_pos` is the position where the added token ends and
            `btype` the type of the break in front of it. Returns None
            when a token of this type cannot be appended. The current
            state is left untouched.

            A token of the same type as the last range extends that
            range without penalty unless it is separated by a phrase
            break. In all other cases a new range is started and the
            penalty for changing the type at a break of type `btype` is
            added. The very first range is free.
        """
        newdir = self.appendable(ttype)
        if newdir is None:
            return None

        if not self.seq:
            newseq = [TypedRange(ttype, qmod.TokenRange(0, end_pos))]
            new_penalty = 0.0
        else:
            last = self.seq[-1]
            if btype != qmod.BREAK_PHRASE and last.ttype == ttype:
                # extend the existing range
                newseq = self.seq[:-1] + [TypedRange(ttype, last.trange.replace_end(end_pos))]
                new_penalty = 0.0
            else:
                # start a new range
                newseq = list(self.seq) + [TypedRange(ttype,
                                                      qmod.TokenRange(last.trange.end, end_pos))]
                new_penalty = PENALTY_TOKENCHANGE[btype]

        return _TokenSequence(newseq, newdir, self.penalty + new_penalty)

    def _adapt_penalty_from_priors(self, priors: int, new_dir: int) -> bool:
        """ Adapt direction and penalty for a housenumber that has
            `priors` name ranges on one of its sides.

            Less than two ranges change nothing. With two or more
            ranges, an undecided direction is set to `new_dir`. When
            the direction is already fixed, exactly two ranges cost a
            penalty of 0.8 and more than two make the sequence invalid.

            Returns False when the sequence is invalid, True otherwise.
        """
        if priors >= 2:
            if self.direction == 0:
                self.direction = new_dir
            elif priors == 2:
                self.penalty += 0.8
            else:
                return False

        return True

    def recheck_sequence(self) -> bool:
        """ Check that the sequence is a fully valid token assignment
            and adapt direction and penalties further if necessary.

            This function catches some impossible assignments that need
            forward context and can therefore not be excluded when building
            the assignment.

            Returns False when the sequence must be dropped. Direction
            and penalty of the sequence are changed in place.
        """
        # housenumbers may not be further than 2 words from the beginning.
        # If there are two words in front, give it a penalty.
        hnrpos = next((i for i, tr in enumerate(self.seq)
                       if tr.ttype == qmod.TOKEN_HOUSENUMBER),
                      None)
        if hnrpos is not None:
            if self.direction != -1:
                priors = sum(1 for t in self.seq[:hnrpos] if t.ttype == qmod.TOKEN_PARTIAL)
                if not self._adapt_penalty_from_priors(priors, -1):
                    return False
            # The direction may just have been changed by the check above.
            if self.direction != 1:
                priors = sum(1 for t in self.seq[hnrpos+1:] if t.ttype == qmod.TOKEN_PARTIAL)
                if not self._adapt_penalty_from_priors(priors, 1):
                    return False
            # A housenumber together with a near-item gets an extra penalty.
            if self.has_types(qmod.TOKEN_NEAR_ITEM):
                self.penalty += 1.0

        return True

    def _get_assignments_postcode(self, base: TokenAssignment,
                                  query_len: int) -> Iterator[TokenAssignment]:
        """ Yield possible assignments of Postcode searches with an
            address component.

            An assignment is only produced when the postcode sits at
            the very beginning or the very end of the query
            (`query_len` being the end position) and the reading
            direction does not contradict that position.
        """
        assert base.postcode is not None

        if (base.postcode.start == 0 and self.direction != -1)\
           or (base.postcode.end == query_len and self.direction != 1):
            log().comment('postcode search')
            # <address>,<postcode> should give preference to address search
            if base.postcode.start == 0:
                penalty = self.penalty
            else:
                penalty = self.penalty + 0.1
            yield dataclasses.replace(base, penalty=penalty)

    def _get_assignments_address_forward(self, base: TokenAssignment,
                                         query: qmod.QueryStruct,
                                         penalty: Optional[float] = None
                                         ) -> Iterator[TokenAssignment]:
        """ Yield possible assignments of address searches with
            left-to-right reading.

            The first address range is used as the name. Where allowed,
            variants are added where only the front part of the first
            range is the name and the rest remains part of the address.

            `penalty` is the base penalty of the yielded assignments.
            When omitted, the penalty of the sequence is used.
        """
        base_penalty = self.penalty if penalty is None else penalty
        first = base.address[0]

        # The postcode must come after the name.
        if base.postcode and base.postcode < first:
            log().var_dump('skip forward', (base.postcode, first))
            return

        log().comment('first word = name')
        yield dataclasses.replace(base, penalty=base_penalty,
                                  name=first, address=base.address[1:])

        # To paraphrase, the first range is not split when:
        #  * another name term comes after the first one and before the
        #    housenumber
        #  * a qualifier comes after the name
        #  * the containing phrase is strictly typed
        if (base.housenumber and first.end < base.housenumber.start)\
           or (base.qualifier and base.qualifier > first)\
           or (query.nodes[first.start].ptype != qmod.PHRASE_ANY):
            return

        split_penalty = base_penalty

        # Penalty for:
        #  * <name>, <street>, <housenumber> , ...
        #  * queries that are comma-separated
        if (base.housenumber and base.housenumber > first) or len(query.source) > 1:
            split_penalty += 0.25

        for i in range(first.start + 1, first.end):
            name, addr = first.split(i)
            log().comment(f'split first word = name ({i - first.start})')
            yield dataclasses.replace(base, name=name, address=[addr] + base.address[1:],
                                      penalty=split_penalty
                                      + PENALTY_TOKENCHANGE[query.nodes[i].btype])

    def _get_assignments_address_backward(self, base: TokenAssignment,
                                          query: qmod.QueryStruct,
                                          penalty: Optional[float] = None
                                          ) -> Iterator[TokenAssignment]:
        """ Yield possible assignments of address searches with
            right-to-left reading.

            The last address range is used as the name. Where allowed,
            variants are added where only the rear part of the last
            range is the name and the rest remains part of the address.

            `penalty` is the base penalty of the yielded assignments.
            When omitted, the penalty of the sequence is used.
        """
        base_penalty = self.penalty if penalty is None else penalty
        last = base.address[-1]

        # The postcode must come before the name for backward direction.
        if base.postcode and base.postcode > last:
            log().var_dump('skip backward', (base.postcode, last))
            return

        # With an undecided direction and a single address range, the
        # unsplit variant is only produced when a postcode is present.
        if self.direction == -1 or len(base.address) > 1 or base.postcode:
            log().comment('last word = name')
            yield dataclasses.replace(base, penalty=base_penalty,
                                      name=last, address=base.address[:-1])

        # To paraphrase, the last range is not split when:
        #  * another name term comes before the last one and after the
        #    housenumber
        #  * a qualifier comes before the name
        #  * the containing phrase is strictly typed
        if (base.housenumber and last.start > base.housenumber.end)\
           or (base.qualifier and base.qualifier < last)\
           or (query.nodes[last.start].ptype != qmod.PHRASE_ANY):
            return

        split_penalty = base_penalty
        if base.housenumber and base.housenumber < last:
            split_penalty += 0.4
        if len(query.source) > 1:
            split_penalty += 0.25

        for i in range(last.start + 1, last.end):
            addr, name = last.split(i)
            log().comment(f'split last word = name ({i - last.start})')
            yield dataclasses.replace(base, name=name, address=base.address[:-1] + [addr],
                                      penalty=split_penalty
                                      + PENALTY_TOKENCHANGE[query.nodes[i].btype])

    def get_assignments(self, query: qmod.QueryStruct) -> Iterator[TokenAssignment]:
        """ Yield possible assignments for the current sequence.

            This function splits up general name assignments into name
            and address and yields all possible variants of that.

            Sequences with more than 50 address tokens yield nothing.
            The sequence itself is not modified.
        """
        base = TokenAssignment.from_ranges(self.seq)

        num_addr_tokens = sum(t.end - t.start for t in base.address)
        if num_addr_tokens > 50:
            return

        # Postcode search (postcode-only search is covered in next case)
        if base.postcode is not None and base.address:
            yield from self._get_assignments_postcode(base, query.num_token_slots())

        # Postcode or country-only search
        if not base.address:
            if not base.housenumber and (base.postcode or base.country or base.near_item):
                log().comment('postcode/country search')
                yield dataclasses.replace(base, penalty=self.penalty)
            return

        # Work on a local copy of the penalty, so that the state is left
        # untouched and the generator may be run repeatedly.
        penalty = self.penalty

        # <postcode>,<address> should give preference to postcode search
        if base.postcode and base.postcode.start == 0:
            penalty += 0.1

        # Left-to-right reading of the address
        if self.direction != -1:
            yield from self._get_assignments_address_forward(base, query, penalty)

        # Right-to-left reading of the address
        if self.direction != 1:
            yield from self._get_assignments_address_backward(base, query, penalty)

        # variant for special housenumber searches
        if base.housenumber and not base.qualifier:
            yield dataclasses.replace(base, penalty=penalty)
392
393
def yield_token_assignments(query: qmod.QueryStruct) -> Iterator[TokenAssignment]:
    """ Enumerate the possible assignments of word types to the word
        positions of the tokenized query.

        Candidate sequences are grown token by token from the tokens
        that start at the node where the sequence currently ends. Once
        a candidate covers all token slots of the query and passes the
        final recheck, its assignments are yielded.

        The penalties of the results account for transitions from one
        word type to another but not for transitions within a type.
    """
    # Only a first phrase of any type leaves the reading direction open.
    if query.source[0].ptype == qmod.PHRASE_ANY:
        start_direction = 0
    else:
        start_direction = 1

    pending = [_TokenSequence([], direction=start_direction)]

    while pending:
        state = pending.pop()
        node = query.nodes[state.end_pos]

        for tlist in node.starting:
            candidate = state.advance(tlist.ttype, tlist.end, node.btype)
            if candidate is None:
                continue

            if candidate.end_pos != query.num_token_slots():
                # Query not fully covered yet: keep the candidate for
                # further extension unless nothing may follow it.
                if not candidate.is_final():
                    pending.append(candidate)
                continue

            if candidate.recheck_sequence():
                log().var_dump('Assignment', candidate)
                yield from candidate.get_assignments(query)