]> git.openstreetmap.org Git - nominatim.git/blob - src/nominatim_api/search/db_search_builder.py
Revert "work round typing bug in pyosmium 4.0"
[nominatim.git] / src / nominatim_api / search / db_search_builder.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2024 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Conversion from token assignment to an abstract DB search.
9 """
10 from typing import Optional, List, Tuple, Iterator, Dict
11 import heapq
12
13 from ..types import SearchDetails, DataLayer
14 from .query import QueryStruct, Token, TokenType, TokenRange, BreakType
15 from .token_assignment import TokenAssignment
16 from . import db_search_fields as dbf
17 from . import db_searches as dbs
18 from . import db_search_lookups as lookups
19
20
def wrap_near_search(categories: List[Tuple[str, str]],
                     search: dbs.AbstractSearch) -> dbs.NearSearch:
    """ Wrap the given search into a NearSearch that looks for places of
        one of the given categories near the results of the wrapped search.

        Every category is given a weight of 0.0 and the resulting near
        search takes over the penalty of the wrapped search.
    """
    zero_weights = [0.0 for _ in categories]
    near_categories = dbf.WeightedCategories(categories, zero_weights)
    return dbs.NearSearch(penalty=search.penalty,
                          categories=near_categories,
                          search=search)
30
31
def build_poi_search(category: List[Tuple[str, str]],
                     countries: Optional[List[str]]) -> dbs.PoiSearch:
    """ Create a new search for places of the given categories, possibly
        constrained to the given countries.

        All categories and countries get a weight of 0.0. A missing or
        empty country list yields an empty country restriction.
    """
    country_codes = countries if countries else []
    ccs = dbf.WeightedStrings(country_codes, [0.0] * len(country_codes))

    class _PoiData(dbf.SearchData):
        penalty = 0.0
        qualifiers = dbf.WeightedCategories(category, [0.0] * len(category))
        countries = ccs

    return dbs.PoiSearch(_PoiData())
48
49
class SearchBuilder:
    """ Build the abstract search queries from token assignments.

        The builder holds the query and the search details. `build()`
        converts a single token assignment into zero or more abstract
        searches.
    """

    def __init__(self, query: QueryStruct, details: SearchDetails) -> None:
        self.query = query
        self.details = details


    @property
    def configured_for_country(self) -> bool:
        """ Return true if the search details are configured to
            allow countries in the result, i.e. rank 4 lies within the
            requested rank range and the address layer is enabled.
        """
        return self.details.min_rank <= 4 and self.details.max_rank >= 4 \
               and self.details.layer_enabled(DataLayer.ADDRESS)


    @property
    def configured_for_postcode(self) -> bool:
        """ Return true if the search details are configured to
            allow postcodes in the result, i.e. the requested rank range
            covers ranks 5 to 11 and the address layer is enabled.
        """
        return self.details.min_rank <= 5 and self.details.max_rank >= 11 \
               and self.details.layer_enabled(DataLayer.ADDRESS)


    @property
    def configured_for_housenumbers(self) -> bool:
        """ Return true if the search details are configured to
            allow addresses in the result, i.e. rank 30 is allowed and
            the address layer is enabled.
        """
        return self.details.max_rank >= 30 \
               and self.details.layer_enabled(DataLayer.ADDRESS)


    def build(self, assignment: TokenAssignment) -> Iterator[dbs.AbstractSearch]:
        """ Yield all possible abstract searches for the given token assignment.

            Nothing is yielded when no usable country or qualifier tokens
            exist for the assignment, or when its near items do not match
            the requested categories. Near items either become the
            qualifiers of a POI search or wrap every resulting search into
            a NearSearch. The penalty of the assignment is added to the
            penalty of every yielded search.
        """
        sdata = self.get_search_data(assignment)
        if sdata is None:
            return

        near_items = self.get_near_items(assignment)
        if near_items is not None and not near_items:
            return  # impossible combination of near items and category parameter

        if assignment.name is None:
            if near_items and not sdata.postcodes:
                # The near items are the only category: search for them directly.
                sdata.qualifiers = near_items
                near_items = None
                builder = self.build_poi_search(sdata)
            elif assignment.housenumber:
                hnr_tokens = self.query.get_tokens(assignment.housenumber,
                                                   TokenType.HOUSENUMBER)
                builder = self.build_housenumber_search(sdata, hnr_tokens, assignment.address)
            else:
                builder = self.build_special_search(sdata, assignment.address,
                                                    bool(near_items))
        else:
            builder = self.build_name_search(sdata, assignment.name, assignment.address,
                                             bool(near_items))

        if near_items:
            # Move the smallest near-item penalty onto the wrapping search so
            # that the category penalties become relative to it.
            penalty = min(near_items.penalties)
            near_items.penalties = [p - penalty for p in near_items.penalties]
            for search in builder:
                search_penalty = search.penalty
                search.penalty = 0.0
                yield dbs.NearSearch(penalty + assignment.penalty + search_penalty,
                                     near_items, search)
        else:
            for search in builder:
                search.penalty += assignment.penalty
                yield search


    def build_poi_search(self, sdata: dbf.SearchData) -> Iterator[dbs.AbstractSearch]:
        """ Build abstract search query for a simple category search.
            This kind of search requires an additional geographic constraint:
            either a bounded viewbox or a near point must be given in the
            search details. Searches containing housenumbers are rejected.
        """
        if not sdata.housenumbers \
           and ((self.details.viewbox and self.details.bounded_viewbox) or self.details.near):
            yield dbs.PoiSearch(sdata)


    def build_special_search(self, sdata: dbf.SearchData,
                             address: List[TokenRange],
                             is_category: bool) -> Iterator[dbs.AbstractSearch]:
        """ Build abstract search queries for searches that do not involve
            a named place.

            Yields a CountrySearch when only countries were given (no address
            and no postcode) and countries are allowed in the result. Yields
            a PostcodeSearch when postcodes were given and either this is a
            category search or postcodes are allowed in the result; it is
            penalized when no country is given and when further address
            terms are present.
        """
        if sdata.qualifiers:
            # No special searches over qualifiers supported.
            return

        if sdata.countries and not address and not sdata.postcodes \
           and self.configured_for_country:
            yield dbs.CountrySearch(sdata)

        if sdata.postcodes and (is_category or self.configured_for_postcode):
            penalty = 0.0 if sdata.countries else 0.1
            if address:
                sdata.lookups = [dbf.FieldLookup('nameaddress_vector',
                                                 [t.token for r in address
                                                  for t in self.query.get_partials_list(r)],
                                                 lookups.Restrict)]
                penalty += 0.2
            yield dbs.PostcodeSearch(penalty, sdata)


    def build_housenumber_search(self, sdata: dbf.SearchData, hnrs: List[Token],
                                 address: List[TokenRange]) -> Iterator[dbs.AbstractSearch]:
        """ Build a simple address search for special entries where the
            housenumber is the main name token.

            The housenumber tokens are looked up in the name vector. The
            address partials are then used depending on the expected number
            of housenumber matches:
            * fewer than 8000: the partials only restrict the results;
            * otherwise, unless there is exactly one partial and it occurs
              in at least 10000 addresses: all partials are looked up;
            * otherwise the full words of the first address term are looked
              up instead. No search is built when there are more than 5.
            Nothing is yielded when the address has no partials.
        """
        sdata.lookups = [dbf.FieldLookup('name_vector', [t.token for t in hnrs], lookups.LookupAny)]
        expected_count = sum(t.count for t in hnrs)

        partials = {t.token: t.addr_count for trange in address
                       for t in self.query.get_partials_list(trange)}

        if not partials:
            # can happen when none of the partials is indexed (or the address is empty)
            return

        if expected_count < 8000:
            sdata.lookups.append(dbf.FieldLookup('nameaddress_vector',
                                                 list(partials), lookups.Restrict))
        elif len(partials) != 1 or next(iter(partials.values())) < 10000:
            sdata.lookups.append(dbf.FieldLookup('nameaddress_vector',
                                                 list(partials), lookups.LookupAll))
        else:
            addr_fulls = [t.token for t
                          in self.query.get_tokens(address[0], TokenType.WORD)]
            if len(addr_fulls) > 5:
                return
            sdata.lookups.append(
                dbf.FieldLookup('nameaddress_vector', addr_fulls, lookups.LookupAny))

        sdata.housenumbers = dbf.WeightedStrings([], [])
        yield dbs.PlaceSearch(0.05, sdata, expected_count)


    def build_name_search(self, sdata: dbf.SearchData,
                          name: TokenRange, address: List[TokenRange],
                          is_category: bool) -> Iterator[dbs.AbstractSearch]:
        """ Build abstract search queries for simple name or address searches.

            Searches are only built when this is a category search, no
            housenumber was given, or housenumbers are allowed in the result.
            One PlaceSearch is yielded for every lookup variant produced by
            `yield_lookups()`.
        """
        if is_category or not sdata.housenumbers or self.configured_for_housenumbers:
            ranking = self.get_name_ranking(name)
            name_penalty = ranking.normalize_penalty()
            if ranking.rankings:
                sdata.rankings.append(ranking)
            for penalty, count, lookup in self.yield_lookups(name, address):
                sdata.lookups = lookup
                yield dbs.PlaceSearch(penalty + name_penalty, sdata, count)


    def yield_lookups(self, name: TokenRange, address: List[TokenRange])\
                          -> Iterator[Tuple[float, int, List[dbf.FieldLookup]]]:
        """ Yield all variants how the given name and address should best
            be searched for. This takes into account how frequent the terms
            are and tries to find a lookup that optimizes index use.

            Each variant is a tuple (extra penalty, expected number of
            results, lookups). The strategies are tried in this order:
            1. With more than 3 name partials, or when the partials are rare
               enough, look up by name partials only and stop.
            2. Look up by full names, if they exist and are rare enough or
               the rarest address partial is rare enough.
            3. Look up by name and address partials when a reasonable number
               of results is expected. This variant gets an extra penalty.
        """
        penalty = 0.0  # extra penalty
        name_partials = {t.token: t for t in self.query.get_partials_list(name)}

        addr_partials = [t for r in address for t in self.query.get_partials_list(r)]
        addr_tokens = list({t.token for t in addr_partials})

        # Estimate: count of the rarest name partial, halved for every
        # additional partial in the name.
        exp_count = min(t.count for t in name_partials.values()) / (2**(len(name_partials) - 1))

        if (len(name_partials) > 3 or exp_count < 8000):
            yield penalty, exp_count, dbf.lookup_by_names(list(name_partials.keys()), addr_tokens)
            return

        addr_count = min(t.addr_count for t in addr_partials) if addr_partials else 30000
        # Partial terms too frequent. Try looking up by rare full names first.
        name_fulls = self.query.get_tokens(name, TokenType.WORD)
        if name_fulls:
            fulls_count = sum(t.count for t in name_fulls)

            if fulls_count < 50000 or addr_count < 30000:
                yield penalty, fulls_count / (2**len(addr_tokens)), \
                    self.get_full_name_ranking(name_fulls, addr_partials,
                                               fulls_count > 30000 / max(1, len(addr_tokens)))

        # To catch remaining results, lookup by name and address
        # We only do this if there is a reasonable number of results expected.
        # Every address token halves the estimate (no-op without address terms).
        exp_count /= 2**len(addr_tokens)
        if exp_count < 10000 and addr_count < 20000:
            penalty += 0.35 * max(1 if name_fulls else 0.1,
                                  5 - len(name_partials) - len(addr_tokens))
            yield penalty, exp_count, \
                  self.get_name_address_ranking(list(name_partials.keys()), addr_partials)


    @staticmethod
    def _split_addr_partials(addr_partials: List[Token]) -> Tuple[List[int], List[int]]:
        """ Split address partials into (restrict tokens, lookup tokens).
            Partials occurring in more than 20000 addresses are only used
            as a restriction, all others are used for lookup.
        """
        restrict_tokens: List[int] = []
        lookup_tokens: List[int] = []
        for t in addr_partials:
            if t.addr_count > 20000:
                restrict_tokens.append(t.token)
            else:
                lookup_tokens.append(t.token)
        return restrict_tokens, lookup_tokens


    def get_name_address_ranking(self, name_tokens: List[int],
                                 addr_partials: List[Token]) -> List[dbf.FieldLookup]:
        """ Create a ranking expression looking up by name and address.

            All name tokens are looked up in the name vector. Frequent
            address partials only restrict the result, the others are
            looked up in the address vector.
        """
        lookup = [dbf.FieldLookup('name_vector', name_tokens, lookups.LookupAll)]

        addr_restrict_tokens, addr_lookup_tokens = self._split_addr_partials(addr_partials)

        if addr_restrict_tokens:
            lookup.append(dbf.FieldLookup('nameaddress_vector',
                                          addr_restrict_tokens, lookups.Restrict))
        if addr_lookup_tokens:
            lookup.append(dbf.FieldLookup('nameaddress_vector',
                                          addr_lookup_tokens, lookups.LookupAll))

        return lookup


    def get_full_name_ranking(self, name_fulls: List[Token], addr_partials: List[Token],
                              use_lookup: bool) -> List[dbf.FieldLookup]:
        """ Create a ranking expression with full name terms and
            additional address lookup. When 'use_lookup' is true, then
            address lookups will use the index, when the occurrences are not
            too many.
        """
        # Frequent address partials (or all of them when use_lookup is
        # false) are passed on as restrict tokens, not as lookup tokens.
        # This might yield wrong results, nothing we can do about that.
        if use_lookup:
            addr_restrict_tokens, addr_lookup_tokens = self._split_addr_partials(addr_partials)
        else:
            addr_restrict_tokens = [t.token for t in addr_partials]
            addr_lookup_tokens = []

        return dbf.lookup_by_any_name([t.token for t in name_fulls],
                                      addr_restrict_tokens, addr_lookup_tokens)


    def get_name_ranking(self, trange: TokenRange,
                         db_field: str = 'name_vector') -> dbf.FieldRanking:
        """ Create a ranking expression for a name term in the given range.

            Every full word of the range is ranked with its own penalty,
            sorted by penalty. The default penalty is the sum of the partial
            penalties plus 0.2.
        """
        name_fulls = self.query.get_tokens(trange, TokenType.WORD)
        ranks = [dbf.RankedTokens(t.penalty, [t.token]) for t in name_fulls]
        ranks.sort(key=lambda r: r.penalty)
        # Fallback, sum of penalty for partials
        name_partials = self.query.get_partials_list(trange)
        default = sum(t.penalty for t in name_partials) + 0.2
        return dbf.FieldRanking(db_field, default, ranks)


    def get_addr_ranking(self, trange: TokenRange) -> dbf.FieldRanking:
        """ Create a list of ranking expressions for an address term
            for the given ranges.

            All ways of covering the range with partial and full-word tokens
            are enumerated (at most about 10). Each full word adds its token
            to the ranking, each partial only adds its penalty, and every
            transition between tokens adds PENALTY_WORDCHANGE for the break
            type at that node. The variant with the fewest tokens becomes
            the default (plus 0.3). The rest is returned sorted by penalty.
        """
        # Heap entries: (negated number of token lists consumed so far,
        # node position, ranking collected so far). Paths with the most
        # steps are therefore continued first.
        # NOTE(review): entries with equal (neglen, pos) fall back to
        # comparing the RankedTokens objects themselves. This relies on
        # RankedTokens supporting '<' (or such ties never occurring) - confirm.
        todo: List[Tuple[int, int, dbf.RankedTokens]] = []
        heapq.heappush(todo, (0, trange.start, dbf.RankedTokens(0.0, [])))
        ranks: List[dbf.RankedTokens] = []

        while todo: # pylint: disable=too-many-nested-blocks
            neglen, pos, rank = heapq.heappop(todo)
            for tlist in self.query.nodes[pos].starting:
                if tlist.ttype in (TokenType.PARTIAL, TokenType.WORD):
                    if tlist.end < trange.end:
                        chgpenalty = PENALTY_WORDCHANGE[self.query.nodes[tlist.end].btype]
                        if tlist.ttype == TokenType.PARTIAL:
                            penalty = rank.penalty + chgpenalty \
                                      + max(t.penalty for t in tlist.tokens)
                            heapq.heappush(todo, (neglen - 1, tlist.end,
                                                  dbf.RankedTokens(penalty, rank.tokens)))
                        else:
                            for t in tlist.tokens:
                                heapq.heappush(todo, (neglen - 1, tlist.end,
                                                      rank.with_token(t, chgpenalty)))
                    elif tlist.end == trange.end:
                        if tlist.ttype == TokenType.PARTIAL:
                            ranks.append(dbf.RankedTokens(rank.penalty
                                                          + max(t.penalty for t in tlist.tokens),
                                                          rank.tokens))
                        else:
                            ranks.extend(rank.with_token(t, 0.0) for t in tlist.tokens)
                        if len(ranks) >= 10:
                            # Too many variants, bail out and only add
                            # worst-case fallback: sum of penalty of partials
                            name_partials = self.query.get_partials_list(trange)
                            default = sum(t.penalty for t in name_partials) + 0.2
                            ranks.append(dbf.RankedTokens(rank.penalty + default, []))
                            # Bail out of outer loop
                            todo.clear()
                            break

        # NOTE(review): assumes at least one path reached trange.end,
        # otherwise ranks[0] raises IndexError.
        ranks.sort(key=lambda r: len(r.tokens))
        default = ranks[0].penalty + 0.3
        del ranks[0]
        ranks.sort(key=lambda r: r.penalty)

        return dbf.FieldRanking('nameaddress_vector', default, ranks)


    def get_search_data(self, assignment: TokenAssignment) -> Optional[dbf.SearchData]:
        """ Collect the tokens for the non-name search fields in the
            assignment.

            Countries and categories from the search details are used when
            the assignment has no country or qualifier of its own. Returns
            None when the assignment has a country or qualifier term for
            which no usable token exists.
        """
        sdata = dbf.SearchData()
        sdata.penalty = assignment.penalty
        if assignment.country:
            tokens = self.get_country_tokens(assignment.country)
            if not tokens:
                return None
            sdata.set_strings('countries', tokens)
        elif self.details.countries:
            sdata.countries = dbf.WeightedStrings(self.details.countries,
                                                  [0.0] * len(self.details.countries))
        if assignment.housenumber:
            sdata.set_strings('housenumbers',
                              self.query.get_tokens(assignment.housenumber,
                                                    TokenType.HOUSENUMBER))
        if assignment.postcode:
            sdata.set_strings('postcodes',
                              self.query.get_tokens(assignment.postcode,
                                                    TokenType.POSTCODE))
        if assignment.qualifier:
            tokens = self.get_qualifier_tokens(assignment.qualifier)
            if not tokens:
                return None
            sdata.set_qualifiers(tokens)
        elif self.details.categories:
            sdata.qualifiers = dbf.WeightedCategories(self.details.categories,
                                                      [0.0] * len(self.details.categories))

        if assignment.address:
            if not assignment.name and assignment.housenumber:
                # housenumber search: the first item needs to be handled like
                # a name in ranking or penalties are not comparable with
                # normal searches.
                sdata.set_ranking([self.get_name_ranking(assignment.address[0],
                                                         db_field='nameaddress_vector')]
                                  + [self.get_addr_ranking(r) for r in assignment.address[1:]])
            else:
                sdata.set_ranking([self.get_addr_ranking(r) for r in assignment.address])
        else:
            sdata.rankings = []

        return sdata


    def get_country_tokens(self, trange: TokenRange) -> List[Token]:
        """ Return the list of country tokens for the given range,
            optionally filtered by the country list from the details
            parameters.
        """
        tokens = self.query.get_tokens(trange, TokenType.COUNTRY)
        if self.details.countries:
            tokens = [t for t in tokens if t.lookup_word in self.details.countries]

        return tokens


    def get_qualifier_tokens(self, trange: TokenRange) -> List[Token]:
        """ Return the list of qualifier tokens for the given range,
            optionally filtered by the qualifier list from the details
            parameters.
        """
        tokens = self.query.get_tokens(trange, TokenType.QUALIFIER)
        if self.details.categories:
            tokens = [t for t in tokens if t.get_category() in self.details.categories]

        return tokens


    def get_near_items(self, assignment: TokenAssignment) -> Optional[dbf.WeightedCategories]:
        """ Collect tokens for near items search or use the categories
            requested per parameter.
            Returns None if no category search is requested. Returns an
            empty result when none of the near items matches the category
            parameter. For each category the lowest token penalty is kept.
        """
        if assignment.near_item:
            tokens: Dict[Tuple[str, str], float] = {}
            for t in self.query.get_tokens(assignment.near_item, TokenType.NEAR_ITEM):
                cat = t.get_category()
                # The category of a near search will be that of near_item.
                # Thus, if search is restricted to a category parameter,
                # the two sets must intersect.
                if (not self.details.categories or cat in self.details.categories)\
                   and t.penalty < tokens.get(cat, 1000.0):
                    tokens[cat] = t.penalty
            return dbf.WeightedCategories(list(tokens.keys()), list(tokens.values()))

        return None
447
448
# Penalty added in SearchBuilder.get_addr_ranking() when an address term
# continues with another token across a node of the given break type.
# START, END and PHRASE breaks are free. WORD, PART and TOKEN breaks get
# increasingly expensive.
PENALTY_WORDCHANGE: Dict[BreakType, float] = {
    BreakType.START: 0.0,
    BreakType.END: 0.0,
    BreakType.PHRASE: 0.0,
    BreakType.WORD: 0.1,
    BreakType.PART: 0.2,
    BreakType.TOKEN: 0.4,
}