]> git.openstreetmap.org Git - nominatim.git/blob - src/nominatim_api/search/token_assignment.py
Merge pull request #3626 from lonvia/import-performance
[nominatim.git] / src / nominatim_api / search / token_assignment.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2024 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
Create query interpretations where each token range in the query is assigned
a specific function (expressed as a token type).
10 """
11 from typing import Optional, List, Iterator
12 import dataclasses
13
14 from ..logging import log
15 from . import query as qmod
16
17
@dataclasses.dataclass()
class TypedRange:
    """ Pairs a range of tokens in the query with the token type
        that has been assigned to it.
    """
    ttype: qmod.TokenType    # the token type assigned to the range
    trange: qmod.TokenRange  # the span of token positions covered
24
25
# Penalty added whenever a new typed range has to start at a break of the
# given type (used by _TokenSequence.advance and when splitting name terms).
# Changing the token type at start/end or phrase boundaries is free;
# changes at word, part and token breaks are penalised increasingly.
PENALTY_TOKENCHANGE = {
    **dict.fromkeys((qmod.BreakType.START,
                     qmod.BreakType.END,
                     qmod.BreakType.PHRASE,
                     qmod.BreakType.SOFT_PHRASE), 0.0),
    qmod.BreakType.WORD: 0.1,
    qmod.BreakType.PART: 0.2,
    qmod.BreakType.TOKEN: 0.4,
}

# A sequence of typed ranges as it is built up while traversing the query.
TypedRangeSeq = List[TypedRange]
37
38
@dataclasses.dataclass
class TokenAssignment:
    """ One possible interpretation of a tokenized query: which token
        range plays which role. All roles except `address` hold at most
        one range; `address` collects any number of ranges.
    """
    penalty: float = 0.0
    name: Optional[qmod.TokenRange] = None
    address: List[qmod.TokenRange] = dataclasses.field(default_factory=list)
    housenumber: Optional[qmod.TokenRange] = None
    postcode: Optional[qmod.TokenRange] = None
    country: Optional[qmod.TokenRange] = None
    near_item: Optional[qmod.TokenRange] = None
    qualifier: Optional[qmod.TokenRange] = None

    @staticmethod
    def from_ranges(ranges: TypedRangeSeq) -> 'TokenAssignment':
        """ Build a token assignment out of a sequence of typed ranges.

            Partial ranges are appended to `address` in the order given.
            Housenumber, postcode, country, near item and qualifier ranges
            fill their single-valued field; a later range of the same type
            overwrites an earlier one. The `name` field is left empty and
            ranges of any other type are ignored.
        """
        single_valued = {
            qmod.TokenType.HOUSENUMBER: 'housenumber',
            qmod.TokenType.POSTCODE: 'postcode',
            qmod.TokenType.COUNTRY: 'country',
            qmod.TokenType.NEAR_ITEM: 'near_item',
            qmod.TokenType.QUALIFIER: 'qualifier',
        }

        assignment = TokenAssignment()
        for typed_range in ranges:
            if typed_range.ttype == qmod.TokenType.PARTIAL:
                assignment.address.append(typed_range.trange)
                continue
            field_name = single_valued.get(typed_range.ttype)
            if field_name is not None:
                setattr(assignment, field_name, typed_range.trange)
        return assignment
72
73
class _TokenSequence:
    """ Working state used to put together the token assignments.

        Represents an intermediate state while traversing the tokenized
        query. It holds the typed ranges collected so far, the reading
        direction they imply (1 = left-to-right, -1 = right-to-left,
        0 = not yet decided) and the penalty accumulated on the way.
    """
    def __init__(self, seq: TypedRangeSeq,
                 direction: int = 0, penalty: float = 0.0) -> None:
        self.seq = seq
        self.direction = direction
        self.penalty = penalty

    def __str__(self) -> str:
        seq = ''.join(f'[{r.trange.start} - {r.trange.end}: {r.ttype.name}]' for r in self.seq)
        return f'{seq} (dir: {self.direction}, penalty: {self.penalty})'

    @property
    def end_pos(self) -> int:
        """ Return the index of the global end of the current sequence
            (0 when the sequence is still empty).
        """
        return self.seq[-1].trange.end if self.seq else 0

    def has_types(self, *ttypes: qmod.TokenType) -> bool:
        """ Check if the current sequence contains any typed ranges of
            the given types.
        """
        return any(s.ttype in ttypes for s in self.seq)

    def is_final(self) -> bool:
        """ Return true when the sequence cannot be extended by any
            form of token anymore.
        """
        # Country and category must be the final term for left-to-right
        return len(self.seq) > 1 and \
            self.seq[-1].ttype in (qmod.TokenType.COUNTRY, qmod.TokenType.NEAR_ITEM)

    def appendable(self, ttype: qmod.TokenType) -> Optional[int]:
        """ Check if the given token type is appendable to the existing sequence.

            Returns None if the token type is not appendable, otherwise the
            new direction of the sequence after adding such a type. The
            token is not added.
        """
        if ttype == qmod.TokenType.WORD:
            return None

        if not self.seq:
            # Append unconditionally to the empty list
            if ttype == qmod.TokenType.COUNTRY:
                return -1
            if ttype in (qmod.TokenType.HOUSENUMBER, qmod.TokenType.QUALIFIER):
                return 1
            return self.direction

        # Partial (name) tokens may repeat and never change the direction.
        if ttype == qmod.TokenType.PARTIAL:
            # qualifiers cannot appear in the middle of the query. They need
            # to be near the next phrase.
            if self.direction == -1 \
               and any(t.ttype == qmod.TokenType.QUALIFIER for t in self.seq[:-1]):
                return None
            return self.direction

        # Other tokens may only appear once
        if self.has_types(ttype):
            return None

        if ttype == qmod.TokenType.HOUSENUMBER:
            if self.direction == 1:
                if len(self.seq) == 1 and self.seq[0].ttype == qmod.TokenType.QUALIFIER:
                    return None
                if len(self.seq) > 2 \
                   or self.has_types(qmod.TokenType.POSTCODE, qmod.TokenType.COUNTRY):
                    return None  # direction left-to-right: housenumber must come before anything
            elif (self.direction == -1
                  or self.has_types(qmod.TokenType.POSTCODE, qmod.TokenType.COUNTRY)):
                return -1  # force direction right-to-left if after other terms

            return self.direction

        if ttype == qmod.TokenType.POSTCODE:
            if self.direction == -1:
                if self.has_types(qmod.TokenType.HOUSENUMBER, qmod.TokenType.QUALIFIER):
                    return None
                return -1
            if self.direction == 1:
                return None if self.has_types(qmod.TokenType.COUNTRY) else 1
            if self.has_types(qmod.TokenType.HOUSENUMBER, qmod.TokenType.QUALIFIER):
                return 1
            return self.direction

        if ttype == qmod.TokenType.COUNTRY:
            return None if self.direction == -1 else 1

        if ttype == qmod.TokenType.NEAR_ITEM:
            return self.direction

        if ttype == qmod.TokenType.QUALIFIER:
            if self.direction == 1:
                if (len(self.seq) == 1
                    and self.seq[0].ttype in (qmod.TokenType.PARTIAL, qmod.TokenType.NEAR_ITEM)) \
                   or (len(self.seq) == 2
                       and self.seq[0].ttype == qmod.TokenType.NEAR_ITEM
                       and self.seq[1].ttype == qmod.TokenType.PARTIAL):
                    return 1
                return None
            if self.direction == -1:
                return -1

            # Direction still undecided: a leading near item is not counted.
            tempseq = self.seq[1:] if self.seq[0].ttype == qmod.TokenType.NEAR_ITEM else self.seq
            if len(tempseq) == 0:
                return 1
            # A qualifier may not directly follow a lone housenumber. This
            # must look at tempseq: a sequence that starts with a housenumber
            # always has direction 1, so checking self.seq[0] here could never
            # match and let [near item, housenumber, qualifier] through.
            if len(tempseq) == 1 and tempseq[0].ttype == qmod.TokenType.HOUSENUMBER:
                return None
            if len(tempseq) > 1 or self.has_types(qmod.TokenType.POSTCODE, qmod.TokenType.COUNTRY):
                return -1
            return 0

        return None

    def advance(self, ttype: qmod.TokenType, end_pos: int,
                btype: qmod.BreakType) -> Optional['_TokenSequence']:
        """ Return a new token sequence state with the given token type
            extended up to `end_pos`, or None if the type cannot be
            appended. `btype` is the break type at the current end of the
            sequence and determines the penalty for starting a new range.
            The current state is not modified.
        """
        newdir = self.appendable(ttype)
        if newdir is None:
            return None

        if not self.seq:
            newseq = [TypedRange(ttype, qmod.TokenRange(0, end_pos))]
            new_penalty = 0.0
        else:
            last = self.seq[-1]
            if btype != qmod.BreakType.PHRASE and last.ttype == ttype:
                # extend the existing range
                newseq = self.seq[:-1] + [TypedRange(ttype, last.trange.replace_end(end_pos))]
                new_penalty = 0.0
            else:
                # start a new range
                newseq = list(self.seq) + [TypedRange(ttype,
                                                      qmod.TokenRange(last.trange.end, end_pos))]
                new_penalty = PENALTY_TOKENCHANGE[btype]

        return _TokenSequence(newseq, newdir, self.penalty + new_penalty)

    def _adapt_penalty_from_priors(self, priors: int, new_dir: int) -> bool:
        """ Adapt direction and penalty given the number of partial terms
            (`priors`) found on one side of the housenumber.

            With two or more such terms, an undecided direction is fixed to
            `new_dir`. If the direction is already decided, exactly two terms
            cost an extra penalty and more than two make the sequence invalid.
            Returns False when the sequence must be discarded.
        """
        if priors >= 2:
            if self.direction == 0:
                self.direction = new_dir
            else:
                if priors == 2:
                    self.penalty += 0.8
                else:
                    return False

        return True

    def recheck_sequence(self) -> bool:
        """ Check that the sequence is a fully valid token assignment
            and adapt direction and penalties further if necessary.

            This function catches some impossible assignments that need
            forward context and can therefore not be excluded when building
            the assignment.
        """
        # housenumbers may not be further than 2 words from the beginning.
        # If there are two words in front, give it a penalty.
        hnrpos = next((i for i, tr in enumerate(self.seq)
                       if tr.ttype == qmod.TokenType.HOUSENUMBER),
                      None)
        if hnrpos is not None:
            if self.direction != -1:
                priors = sum(1 for t in self.seq[:hnrpos] if t.ttype == qmod.TokenType.PARTIAL)
                if not self._adapt_penalty_from_priors(priors, -1):
                    return False
            if self.direction != 1:
                priors = sum(1 for t in self.seq[hnrpos+1:] if t.ttype == qmod.TokenType.PARTIAL)
                if not self._adapt_penalty_from_priors(priors, 1):
                    return False
            if any(t.ttype == qmod.TokenType.NEAR_ITEM for t in self.seq):
                self.penalty += 1.0

        return True

    def _get_assignments_postcode(self, base: TokenAssignment,
                                  query_len: int) -> Iterator[TokenAssignment]:
        """ Yield possible assignments of Postcode searches with an
            address component.

            Only applies when the postcode is at the start or the end of the
            query. As a side effect, this restricts self.direction so that
            the subsequent name searches are read away from the postcode.
        """
        assert base.postcode is not None

        if (base.postcode.start == 0 and self.direction != -1)\
           or (base.postcode.end == query_len and self.direction != 1):
            log().comment('postcode search')
            # <address>,<postcode> should give preference to address search
            if base.postcode.start == 0:
                penalty = self.penalty
                self.direction = -1  # name searches are only possible backwards
            else:
                penalty = self.penalty + 0.1
                self.direction = 1  # name searches are only possible forwards
            yield dataclasses.replace(base, penalty=penalty)

    def _get_assignments_address_forward(self, base: TokenAssignment,
                                         query: qmod.QueryStruct) -> Iterator[TokenAssignment]:
        """ Yield possible assignments of address searches with
            left-to-right reading.
        """
        first = base.address[0]

        log().comment('first word = name')
        yield dataclasses.replace(base, penalty=self.penalty,
                                  name=first, address=base.address[1:])

        # To paraphrase:
        #  * if another name term comes after the first one and before the
        #    housenumber
        #  * a qualifier comes after the name
        #  * the containing phrase is strictly typed
        if (base.housenumber and first.end < base.housenumber.start)\
           or (base.qualifier and base.qualifier > first)\
           or (query.nodes[first.start].ptype != qmod.PhraseType.NONE):
            return

        penalty = self.penalty

        # Penalty for:
        #  * <name>, <street>, <housenumber> , ...
        #  * queries that are comma-separated
        if (base.housenumber and base.housenumber > first) or len(query.source) > 1:
            penalty += 0.25

        for i in range(first.start + 1, first.end):
            name, addr = first.split(i)
            log().comment(f'split first word = name ({i - first.start})')
            yield dataclasses.replace(base, name=name, address=[addr] + base.address[1:],
                                      penalty=penalty + PENALTY_TOKENCHANGE[query.nodes[i].btype])

    def _get_assignments_address_backward(self, base: TokenAssignment,
                                          query: qmod.QueryStruct) -> Iterator[TokenAssignment]:
        """ Yield possible assignments of address searches with
            right-to-left reading.
        """
        last = base.address[-1]

        # With an undecided direction and a single address range, the
        # forward reading already yielded this range as the name.
        if self.direction == -1 or len(base.address) > 1:
            log().comment('last word = name')
            yield dataclasses.replace(base, penalty=self.penalty,
                                      name=last, address=base.address[:-1])

        # To paraphrase:
        #  * if another name term comes before the last one and after the
        #    housenumber
        #  * a qualifier comes before the name
        #  * the containing phrase is strictly typed
        if (base.housenumber and last.start > base.housenumber.end)\
           or (base.qualifier and base.qualifier < last)\
           or (query.nodes[last.start].ptype != qmod.PhraseType.NONE):
            return

        penalty = self.penalty
        if base.housenumber and base.housenumber < last:
            penalty += 0.4
        if len(query.source) > 1:
            penalty += 0.25

        for i in range(last.start + 1, last.end):
            addr, name = last.split(i)
            log().comment(f'split last word = name ({i - last.start})')
            yield dataclasses.replace(base, name=name, address=base.address[:-1] + [addr],
                                      penalty=penalty + PENALTY_TOKENCHANGE[query.nodes[i].btype])

    def get_assignments(self, query: qmod.QueryStruct) -> Iterator[TokenAssignment]:
        """ Yield possible assignments for the current sequence.

            This function splits up general name assignments into name
            and address and yields all possible variants of that.
            Note that it may modify self.penalty and self.direction while
            producing the assignments.
        """
        base = TokenAssignment.from_ranges(self.seq)

        # Give up on sequences with too many address tokens.
        num_addr_tokens = sum(t.end - t.start for t in base.address)
        if num_addr_tokens > 50:
            return

        # Postcode search (postcode-only search is covered in next case)
        if base.postcode is not None and base.address:
            yield from self._get_assignments_postcode(base, query.num_token_slots())

        # Postcode or country-only search
        if not base.address:
            if not base.housenumber and (base.postcode or base.country or base.near_item):
                log().comment('postcode/country search')
                yield dataclasses.replace(base, penalty=self.penalty)
        else:
            # <postcode>,<address> should give preference to postcode search
            if base.postcode and base.postcode.start == 0:
                self.penalty += 0.1

            # Left-to-right reading of the address
            if self.direction != -1:
                yield from self._get_assignments_address_forward(base, query)

            # Right-to-left reading of the address
            if self.direction != 1:
                yield from self._get_assignments_address_backward(base, query)

            # variant for special housenumber searches
            if base.housenumber and not base.qualifier:
                yield dataclasses.replace(base, penalty=self.penalty)
384
385
def yield_token_assignments(query: qmod.QueryStruct) -> Iterator[TokenAssignment]:
    """ Enumerate the possible assignments of token types to the word
        positions of the tokenized query.

        The search is a depth-first walk over the tokens starting at each
        node. Every sequence that reaches the end of the query and passes
        the final sanity check is expanded into concrete assignments.

        The penalties of the yielded assignments cover transitions from one
        token type to another, but not transitions within the same type.
    """
    # Strictly typed first phrase forces left-to-right reading from the start.
    if query.source[0].ptype == qmod.PhraseType.NONE:
        start_direction = 0
    else:
        start_direction = 1
    pending = [_TokenSequence([], direction=start_direction)]

    while pending:
        current = pending.pop()
        start_node = query.nodes[current.end_pos]

        for token_list in start_node.starting:
            extended = current.advance(token_list.ttype, token_list.end, start_node.btype)
            if extended is None:
                continue

            if extended.end_pos != query.num_token_slots():
                # Incomplete sequence: keep exploring unless it is closed off.
                if not extended.is_final():
                    pending.append(extended)
            elif extended.recheck_sequence():
                log().var_dump('Assignment', extended)
                yield from extended.get_assignments(query)