]> git.openstreetmap.org Git - nominatim.git/blob - nominatim/api/search/token_assignment.py
avoid duplicate lines during category search
[nominatim.git] / nominatim / api / search / token_assignment.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2023 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Create query interpretations where each vertice in the query is assigned
9 a specific function (expressed as a token type).
10 """
11 from typing import Optional, List, Iterator
12 import dataclasses
13
14 import nominatim.api.search.query as qmod
15 from nominatim.api.logging import log
16
17 # pylint: disable=too-many-return-statements,too-many-branches
18
@dataclasses.dataclass
class TypedRange:
    """ A range of query tokens together with the single token type
        that has been assigned to the whole range.
    """
    # Token type assigned to the range.
    ttype: qmod.TokenType
    # Position of the range (start/end) within the query.
    trange: qmod.TokenRange
26
# Penalty for changing the token type at a break of the given kind.
# A change at the start/end of the query or at a phrase boundary is free;
# the penalty grows from word breaks over part breaks to token breaks.
PENALTY_TOKENCHANGE = {
    qmod.BreakType.START: 0.0,
    qmod.BreakType.END: 0.0,
    qmod.BreakType.PHRASE: 0.0,
    qmod.BreakType.WORD: 0.1,
    qmod.BreakType.PART: 0.2,
    qmod.BreakType.TOKEN: 0.4,
}

# A list of typed ranges, kept in the order they appear in the query.
TypedRangeSeq = List[TypedRange]
37
@dataclasses.dataclass
class TokenAssignment: # pylint: disable=too-many-instance-attributes
    """ One possible interpretation of a tokenized query, in which each
        token range has been given a role (name, address, housenumber, ...).
    """
    # Accumulated penalty of this interpretation.
    penalty: float = 0.0
    # Range used as the name of the searched place.
    name: Optional[qmod.TokenRange] = None
    # Ranges used as address parts, in query order.
    address: List[qmod.TokenRange] = dataclasses.field(default_factory=list)
    # Single-valued roles; None when the role is absent.
    housenumber: Optional[qmod.TokenRange] = None
    postcode: Optional[qmod.TokenRange] = None
    country: Optional[qmod.TokenRange] = None
    category: Optional[qmod.TokenRange] = None
    qualifier: Optional[qmod.TokenRange] = None


    @staticmethod
    def from_ranges(ranges: TypedRangeSeq) -> 'TokenAssignment':
        """ Build an assignment from a sequence of typed ranges.

            Partial-word ranges are collected, in order, as address parts.
            Housenumber, postcode, country, category and qualifier ranges
            fill the field of the same name (a later range of the same type
            overwrites an earlier one). Ranges of any other type are ignored
            and the name is left unset.
        """
        single_valued = {
            qmod.TokenType.HOUSENUMBER: 'housenumber',
            qmod.TokenType.POSTCODE: 'postcode',
            qmod.TokenType.COUNTRY: 'country',
            qmod.TokenType.CATEGORY: 'category',
            qmod.TokenType.QUALIFIER: 'qualifier',
        }

        result = TokenAssignment()
        for typed_range in ranges:
            if typed_range.ttype == qmod.TokenType.PARTIAL:
                result.address.append(typed_range.trange)
                continue
            field_name = single_valued.get(typed_range.ttype)
            if field_name is not None:
                setattr(result, field_name, typed_range.trange)
        return result
72
73
class _TokenSequence:
    """ Working state used to put together the token assignments.

        Represents an intermediate state while traversing the tokenized
        query.
    """
    def __init__(self, seq: TypedRangeSeq,
                 direction: int = 0, penalty: float = 0.0) -> None:
        # Typed ranges assigned so far, ordered by their position in the query.
        self.seq = seq
        # Reading direction of the address: 1 = left-to-right,
        # -1 = right-to-left, 0 = not decided yet (both readings possible).
        self.direction = direction
        # Penalty accumulated for this (partial) assignment.
        self.penalty = penalty


    def __str__(self) -> str:
        seq = ''.join(f'[{r.trange.start} - {r.trange.end}: {r.ttype.name}]' for r in self.seq)
        return f'{seq} (dir: {self.direction}, penalty: {self.penalty})'


    @property
    def end_pos(self) -> int:
        """ Return the index of the global end of the current sequence
            (0 for an empty sequence).
        """
        return self.seq[-1].trange.end if self.seq else 0


    def has_types(self, *ttypes: qmod.TokenType) -> bool:
        """ Check if the current sequence contains any typed ranges of
            the given types.
        """
        return any(s.ttype in ttypes for s in self.seq)


    def is_final(self) -> bool:
        """ Return true when the sequence cannot be extended by any
            form of token anymore.
        """
        # Country and category must be the final term for left-to-right
        return len(self.seq) > 1 and \
               self.seq[-1].ttype in (qmod.TokenType.COUNTRY, qmod.TokenType.CATEGORY)


    def appendable(self, ttype: qmod.TokenType) -> Optional[int]:
        """ Check if the given token type is appendable to the existing sequence.

            Returns None if the token type is not appendable, otherwise the
            new direction of the sequence after adding such a type. The
            token is not added.
        """
        if ttype == qmod.TokenType.WORD:
            return None

        if not self.seq:
            # Append unconditionally to the empty list
            if ttype == qmod.TokenType.COUNTRY:
                return -1
            if ttype in (qmod.TokenType.HOUSENUMBER, qmod.TokenType.QUALIFIER):
                return 1
            return self.direction

        # Name tokens are always acceptable and don't change direction
        if ttype == qmod.TokenType.PARTIAL:
            return self.direction

        # Other tokens may only appear once
        if self.has_types(ttype):
            return None

        if ttype == qmod.TokenType.HOUSENUMBER:
            if self.direction == 1:
                if len(self.seq) == 1 and self.seq[0].ttype == qmod.TokenType.QUALIFIER:
                    return None
                if len(self.seq) > 2 \
                   or self.has_types(qmod.TokenType.POSTCODE, qmod.TokenType.COUNTRY):
                    return None # direction left-to-right: housenumber must come before anything
            elif self.direction == -1 \
                 or self.has_types(qmod.TokenType.POSTCODE, qmod.TokenType.COUNTRY):
                return -1 # force direction right-to-left if after other terms

            return self.direction

        if ttype == qmod.TokenType.POSTCODE:
            if self.direction == -1:
                if self.has_types(qmod.TokenType.HOUSENUMBER, qmod.TokenType.QUALIFIER):
                    return None
                return -1
            if self.direction == 1:
                return None if self.has_types(qmod.TokenType.COUNTRY) else 1
            if self.has_types(qmod.TokenType.HOUSENUMBER, qmod.TokenType.QUALIFIER):
                return 1
            return self.direction

        if ttype == qmod.TokenType.COUNTRY:
            return None if self.direction == -1 else 1

        if ttype == qmod.TokenType.CATEGORY:
            return self.direction

        if ttype == qmod.TokenType.QUALIFIER:
            if self.direction == 1:
                if (len(self.seq) == 1
                    and self.seq[0].ttype in (qmod.TokenType.PARTIAL, qmod.TokenType.CATEGORY)) \
                   or (len(self.seq) == 2
                       and self.seq[0].ttype == qmod.TokenType.CATEGORY
                       and self.seq[1].ttype == qmod.TokenType.PARTIAL):
                    return 1
                return None
            if self.direction == -1:
                return -1

            # Direction still undecided: judge the sequence without a
            # leading category.
            tempseq = self.seq[1:] if self.seq[0].ttype == qmod.TokenType.CATEGORY else self.seq
            if len(tempseq) == 0:
                return 1
            # A qualifier may not directly follow a lone housenumber (the
            # left-to-right branch above rejects this case as well). The test
            # must look at the category-stripped sequence, otherwise
            # [category, housenumber] would slip through.
            if len(tempseq) == 1 and tempseq[0].ttype == qmod.TokenType.HOUSENUMBER:
                return None
            if len(tempseq) > 1 or self.has_types(qmod.TokenType.POSTCODE, qmod.TokenType.COUNTRY):
                return -1
            return 0

        return None


    def advance(self, ttype: qmod.TokenType, end_pos: int,
                btype: qmod.BreakType) -> Optional['_TokenSequence']:
        """ Return a new token sequence state with the given token type
            extended up to `end_pos`, or None if the type cannot be appended.

            The current state is not modified.
        """
        newdir = self.appendable(ttype)
        if newdir is None:
            return None

        if not self.seq:
            newseq = [TypedRange(ttype, qmod.TokenRange(0, end_pos))]
            new_penalty = 0.0
        else:
            last = self.seq[-1]
            if btype != qmod.BreakType.PHRASE and last.ttype == ttype:
                # extend the existing range
                newseq = self.seq[:-1] + [TypedRange(ttype, last.trange.replace_end(end_pos))]
                new_penalty = 0.0
            else:
                # start a new range; changing the type costs a penalty that
                # depends on the kind of break where the change happens
                newseq = list(self.seq) + [TypedRange(ttype,
                                                      qmod.TokenRange(last.trange.end, end_pos))]
                new_penalty = PENALTY_TOKENCHANGE[btype]

        return _TokenSequence(newseq, newdir, self.penalty + new_penalty)


    def _adapt_penalty_from_priors(self, priors: int, new_dir: int) -> bool:
        """ Adjust penalty and direction according to the number of partial
            ranges `priors` on the constraining side of a housenumber.

            Exactly two such ranges add a penalty of 1.0. More than two are
            only acceptable while the direction is still undecided; the
            direction is then fixed to `new_dir`. Returns False when the
            sequence must be rejected, True otherwise.
        """
        if priors == 2:
            self.penalty += 1.0
        elif priors > 2:
            if self.direction == 0:
                self.direction = new_dir
            else:
                return False

        return True


    def recheck_sequence(self) -> bool:
        """ Check that the sequence is a fully valid token assignment
            and adapt direction and penalties further if necessary.

            This function catches some impossible assignments that need
            forward context and can therefore not be excluded when building
            the assignment.
        """
        # A housenumber may be preceded by at most two partial ranges when
        # reading left-to-right and followed by at most two when reading
        # right-to-left. Exactly two incur a penalty; more than two either
        # fix the direction to the other reading (if still undecided) or
        # make the sequence invalid.
        hnrpos = next((i for i, tr in enumerate(self.seq)
                       if tr.ttype == qmod.TokenType.HOUSENUMBER),
                      None)
        if hnrpos is not None:
            if self.direction != -1:
                priors = sum(1 for t in self.seq[:hnrpos] if t.ttype == qmod.TokenType.PARTIAL)
                if not self._adapt_penalty_from_priors(priors, -1):
                    return False
            if self.direction != 1:
                priors = sum(1 for t in self.seq[hnrpos+1:] if t.ttype == qmod.TokenType.PARTIAL)
                if not self._adapt_penalty_from_priors(priors, 1):
                    return False
            # housenumber combined with a category is unlikely
            if self.has_types(qmod.TokenType.CATEGORY):
                self.penalty += 1.0

        return True


    def _get_assignments_postcode(self, base: TokenAssignment,
                                  query_len: int)  -> Iterator[TokenAssignment]:
        """ Yield possible assignments of Postcode searches with an
            address component.

            Only applies when the postcode is at the start or the end of
            the query. As a side effect, restricts self.direction for the
            address readings that follow.
        """
        assert base.postcode is not None

        if (base.postcode.start == 0 and self.direction != -1)\
           or (base.postcode.end == query_len and self.direction != 1):
            log().comment('postcode search')
            # <address>,<postcode> should give preference to address search
            if base.postcode.start == 0:
                penalty = self.penalty
                self.direction = -1 # name searches are only possible backwards
            else:
                penalty = self.penalty + 0.1
                self.direction = 1 # name searches are only possible forwards
            yield dataclasses.replace(base, penalty=penalty)


    def _get_assignments_address_forward(self, base: TokenAssignment,
                                         query: qmod.QueryStruct) -> Iterator[TokenAssignment]:
        """ Yield possible assignments of address searches with
            left-to-right reading.
        """
        first = base.address[0]

        log().comment('first word = name')
        yield dataclasses.replace(base, penalty=self.penalty,
                                  name=first, address=base.address[1:])

        # To paraphrase:
        #  * if another name term comes after the first one and before the
        #    housenumber
        #  * a qualifier comes after the name
        #  * the containing phrase is strictly typed
        if (base.housenumber and first.end < base.housenumber.start)\
           or (base.qualifier and base.qualifier > first)\
           or (query.nodes[first.start].ptype != qmod.PhraseType.NONE):
            return

        penalty = self.penalty

        # Penalty for:
        #  * <name>, <street>, <housenumber> , ...
        #  * queries that are comma-separated
        if (base.housenumber and base.housenumber > first) or len(query.source) > 1:
            penalty += 0.25

        for i in range(first.start + 1, first.end):
            name, addr = first.split(i)
            log().comment(f'split first word = name ({i - first.start})')
            yield dataclasses.replace(base, name=name, address=[addr] + base.address[1:],
                                      penalty=penalty + PENALTY_TOKENCHANGE[query.nodes[i].btype])


    def _get_assignments_address_backward(self, base: TokenAssignment,
                                          query: qmod.QueryStruct) -> Iterator[TokenAssignment]:
        """ Yield possible assignments of address searches with
            right-to-left reading.
        """
        last = base.address[-1]

        if self.direction == -1 or len(base.address) > 1:
            log().comment('last word = name')
            yield dataclasses.replace(base, penalty=self.penalty,
                                      name=last, address=base.address[:-1])

        # To paraphrase:
        #  * if another name term comes before the last one and after the
        #    housenumber
        #  * a qualifier comes before the name
        #  * the containing phrase is strictly typed
        if (base.housenumber and last.start > base.housenumber.end)\
           or (base.qualifier and base.qualifier < last)\
           or (query.nodes[last.start].ptype != qmod.PhraseType.NONE):
            return

        penalty = self.penalty
        if base.housenumber and base.housenumber < last:
            penalty += 0.4
        if len(query.source) > 1:
            penalty += 0.25

        for i in range(last.start + 1, last.end):
            addr, name = last.split(i)
            log().comment(f'split last word = name ({i - last.start})')
            yield dataclasses.replace(base, name=name, address=base.address[:-1] + [addr],
                                      penalty=penalty + PENALTY_TOKENCHANGE[query.nodes[i].btype])


    def get_assignments(self, query: qmod.QueryStruct) -> Iterator[TokenAssignment]:
        """ Yield possible assignments for the current sequence.

            This function splits up general name assignments into name
            and address and yields all possible variants of that.
        """
        base = TokenAssignment.from_ranges(self.seq)

        # Give up on sequences with overly long address parts.
        num_addr_tokens = sum(t.end - t.start for t in base.address)
        if num_addr_tokens > 50:
            return

        # Postcode search (postcode-only search is covered in next case).
        # Note: this may restrict self.direction for the readings below.
        if base.postcode is not None and base.address:
            yield from self._get_assignments_postcode(base, query.num_token_slots())

        # Postcode, country or category-only search
        if not base.address:
            if not base.housenumber and (base.postcode or base.country or base.category):
                log().comment('postcode/country search')
                yield dataclasses.replace(base, penalty=self.penalty)
        else:
            # <postcode>,<address> should give preference to postcode search
            if base.postcode and base.postcode.start == 0:
                self.penalty += 0.1

            # Left-to-right reading of the address
            if self.direction != -1:
                yield from self._get_assignments_address_forward(base, query)

            # Right-to-left reading of the address
            if self.direction != 1:
                yield from self._get_assignments_address_backward(base, query)

            # variant for special housenumber searches
            if base.housenumber:
                yield dataclasses.replace(base, penalty=self.penalty)
387             # variant for special housenumber searches
388             if base.housenumber:
389                 yield dataclasses.replace(base, penalty=self.penalty)
390
391
def yield_token_assignments(query: qmod.QueryStruct) -> Iterator[TokenAssignment]:
    """ Enumerate the possible assignments of word types to word positions.

        Assignments are derived from the concrete tokens listed in the
        tokenized query. Each result carries the penalty accumulated for
        switching from one word type to another; switching within a type
        is not penalized here.
    """
    # Strictly typed first phrase forces left-to-right reading from the start.
    if query.source[0].ptype == qmod.PhraseType.NONE:
        start_direction = 0
    else:
        start_direction = 1
    stack = [_TokenSequence([], direction=start_direction)]

    while stack:
        current = stack.pop()
        node = query.nodes[current.end_pos]

        for tlist in node.starting:
            candidate = current.advance(tlist.ttype, tlist.end, node.btype)
            if candidate is None:
                continue

            if candidate.end_pos != query.num_token_slots():
                # Not at the end of the query yet: keep extending unless
                # the sequence can no longer grow.
                if not candidate.is_final():
                    stack.append(candidate)
                continue

            # Sequence covers the whole query: validate and expand it.
            if candidate.recheck_sequence():
                log().var_dump('Assignment', candidate)
                yield from candidate.get_assignments(query)
414                         yield from newstate.get_assignments(query)
415                 elif not newstate.is_final():
416                     todo.append(newstate)