# SPDX-License-Identifier: GPL-3.0-or-later
#
# This file is part of Nominatim. (https://nominatim.org)
#
# Copyright (C) 2024 by the Nominatim developer community.
# For a full list of authors see the git log.
"""
Conversion from token assignment to an abstract DB search.
"""
from typing import Optional, List, Tuple, Iterator, Dict
import heapq

from ..types import SearchDetails, DataLayer
from . import query as qmod
from .token_assignment import TokenAssignment
from . import db_search_fields as dbf
from . import db_searches as dbs
from . import db_search_lookups as lookups


def wrap_near_search(categories: List[Tuple[str, str]],
                     search: dbs.AbstractSearch) -> dbs.NearSearch:
    """ Create a new search that wraps the given search in a search
        for near places of the given category.
    """
    return dbs.NearSearch(penalty=search.penalty,
                          categories=dbf.WeightedCategories(categories,
                                                            [0.0] * len(categories)),
                          search=search)


def build_poi_search(category: List[Tuple[str, str]],
                     countries: Optional[List[str]]) -> dbs.PoiSearch:
    """ Create a new search for places by the given category, possibly
        constrained to the given countries.
    """
    if countries:
        ccs = dbf.WeightedStrings(countries, [0.0] * len(countries))
    else:
        ccs = dbf.WeightedStrings([], [])

    class _PoiData(dbf.SearchData):
        penalty = 0.0
        qualifiers = dbf.WeightedCategories(category, [0.0] * len(category))
        countries = ccs

    return dbs.PoiSearch(_PoiData())


class SearchBuilder:
    """ Build the abstract search queries from token assignments.
    """

    def __init__(self, query: qmod.QueryStruct, details: SearchDetails) -> None:
        self.query = query
        self.details = details

    @property
    def configured_for_country(self) -> bool:
        """ Return true if the search details are configured to
            allow countries in the result.
        """
        return self.details.min_rank <= 4 and self.details.max_rank >= 4 \
            and self.details.layer_enabled(DataLayer.ADDRESS)

    @property
    def configured_for_postcode(self) -> bool:
        """ Return true if the search details are configured to
            allow postcodes in the result.
        """
        return self.details.min_rank <= 5 and self.details.max_rank >= 11 \
            and self.details.layer_enabled(DataLayer.ADDRESS)

    @property
    def configured_for_housenumbers(self) -> bool:
        """ Return true if the search details are configured to
            allow addresses in the result.
        """
        return self.details.max_rank >= 30 \
            and self.details.layer_enabled(DataLayer.ADDRESS)

    def build(self, assignment: TokenAssignment) -> Iterator[dbs.AbstractSearch]:
        """ Yield all possible abstract searches for the given token assignment.
        """
        sdata = self.get_search_data(assignment)
        if sdata is None:
            return

        near_items = self.get_near_items(assignment)
        if near_items is not None and not near_items:
            return  # impossible combination of near items and category parameter

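        # Dispatch on the shape of the assignment: without a name part this is
        # a pure category (POI), housenumber-only or special search; otherwise
        # a regular name search is built.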
        if assignment.name is None:
            if near_items and not sdata.postcodes:
                sdata.qualifiers = near_items
                near_items = None
                builder = self.build_poi_search(sdata)
            elif assignment.housenumber:
                hnr_tokens = self.query.get_tokens(assignment.housenumber,
                                                   qmod.TOKEN_HOUSENUMBER)
                builder = self.build_housenumber_search(sdata, hnr_tokens, assignment.address)
            else:
                builder = self.build_special_search(sdata, assignment.address,
                                                    bool(near_items))
        else:
            builder = self.build_name_search(sdata, assignment.name, assignment.address,
                                             bool(near_items))

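        # For category searches, the smallest category penalty is folded into the
        # wrapping NearSearch together with the assignment and inner search penalty,
        # so the category weights only carry their relative differences.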
        if near_items:
            penalty = min(near_items.penalties)
            near_items.penalties = [p - penalty for p in near_items.penalties]
            for search in builder:
                search_penalty = search.penalty
                search.penalty = 0.0
                yield dbs.NearSearch(penalty + assignment.penalty + search_penalty,
                                     near_items, search)
        else:
            for search in builder:
                search.penalty += assignment.penalty
                yield search

    def build_poi_search(self, sdata: dbf.SearchData) -> Iterator[dbs.AbstractSearch]:
        """ Build abstract search query for a simple category search.
            This kind of search requires an additional geographic constraint.
        """
        if not sdata.housenumbers \
           and ((self.details.viewbox and self.details.bounded_viewbox) or self.details.near):
            yield dbs.PoiSearch(sdata)

    def build_special_search(self, sdata: dbf.SearchData,
                             address: List[qmod.TokenRange],
                             is_category: bool) -> Iterator[dbs.AbstractSearch]:
        """ Build abstract search queries for searches that do not involve
            a named place.
        """
        if sdata.qualifiers:
            # No special searches over qualifiers are supported.
            return

        if sdata.countries and not address and not sdata.postcodes \
           and self.configured_for_country:
            yield dbs.CountrySearch(sdata)

        if sdata.postcodes and (is_category or self.configured_for_postcode):
            penalty = 0.0 if sdata.countries else 0.1
            if address:
                sdata.lookups = [dbf.FieldLookup('nameaddress_vector',
                                                 [t.token for r in address
                                                  for t in self.query.get_partials_list(r)],
                                                 lookups.Restrict)]
            yield dbs.PostcodeSearch(penalty, sdata)

    def build_housenumber_search(self, sdata: dbf.SearchData, hnrs: List[qmod.Token],
                                 address: List[qmod.TokenRange]) -> Iterator[dbs.AbstractSearch]:
        """ Build a simple address search for special entries where the
            housenumber is the main name token.
        """
        sdata.lookups = [dbf.FieldLookup('name_vector', [t.token for t in hnrs], lookups.LookupAny)]
        expected_count = sum(t.count for t in hnrs)

        partials = {t.token: t.addr_count for trange in address
                    for t in self.query.get_partials_list(trange)}

        if not partials:
            # can happen when none of the partials is indexed
            return

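        # Three lookup strategies: if the housenumber tokens are rare enough, the
        # address partials only restrict the results; if the partials are not
        # overly frequent themselves, they are looked up via the index; otherwise
        # fall back to full-word address tokens, giving up if there are too many.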
        if expected_count < 8000:
            sdata.lookups.append(dbf.FieldLookup('nameaddress_vector',
                                                 list(partials), lookups.Restrict))
        elif len(partials) != 1 or list(partials.values())[0] < 10000:
            sdata.lookups.append(dbf.FieldLookup('nameaddress_vector',
                                                 list(partials), lookups.LookupAll))
        else:
            addr_fulls = [t.token for t
                          in self.query.get_tokens(address[0], qmod.TOKEN_WORD)]
            if len(addr_fulls) > 5:
                return
            sdata.lookups.append(
                dbf.FieldLookup('nameaddress_vector', addr_fulls, lookups.LookupAny))

        sdata.housenumbers = dbf.WeightedStrings([], [])
        yield dbs.PlaceSearch(0.05, sdata, expected_count)

    def build_name_search(self, sdata: dbf.SearchData,
                          name: qmod.TokenRange, address: List[qmod.TokenRange],
                          is_category: bool) -> Iterator[dbs.AbstractSearch]:
        """ Build abstract search queries for simple name or address searches.
        """
        if is_category or not sdata.housenumbers or self.configured_for_housenumbers:
            ranking = self.get_name_ranking(name)
            name_penalty = ranking.normalize_penalty()
            if ranking.rankings:
                sdata.rankings.append(ranking)
            for penalty, count, lookup in self.yield_lookups(name, address):
                sdata.lookups = lookup
                yield dbs.PlaceSearch(penalty + name_penalty, sdata, count)

    def yield_lookups(self, name: qmod.TokenRange, address: List[qmod.TokenRange]
                      ) -> Iterator[Tuple[float, int, List[dbf.FieldLookup]]]:
        """ Yield all variants of how the given name and address should best
            be searched for. This takes into account how frequent the terms
            are and tries to find a lookup that optimizes index use.
        """
        penalty = 0.0  # extra penalty
        name_partials = {t.token: t for t in self.query.get_partials_list(name)}

        addr_partials = [t for r in address for t in self.query.get_partials_list(r)]
        addr_tokens = list({t.token for t in addr_partials})

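        # Rough estimate of the expected number of matches: the count of the
        # rarest name partial, divided by 3 for every additional partial term.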
        exp_count = min(t.count for t in name_partials.values()) / (3**(len(name_partials) - 1))

        if (len(name_partials) > 3 or exp_count < 8000):
            yield penalty, exp_count, dbf.lookup_by_names(list(name_partials.keys()), addr_tokens)
            return

        addr_count = min(t.addr_count for t in addr_partials) if addr_partials else 50000
        # Partial terms are too frequent. Try looking up by rare full names first.
        name_fulls = self.query.get_tokens(name, qmod.TOKEN_WORD)
        if name_fulls:
            fulls_count = sum(t.count for t in name_fulls)

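            # Note: the 80000/50000 cut-offs below appear to be empirically tuned
            # limits for when a lookup by full name is still expected to pay off.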
            if fulls_count < 80000 or addr_count < 50000:
                yield penalty, fulls_count / (2**len(addr_tokens)), \
                    self.get_full_name_ranking(name_fulls, addr_partials,
                                               fulls_count > 30000 / max(1, len(addr_tokens)))

        # To catch the remaining results, look up by name and address.
        # We only do this if there is a reasonable number of results expected.
        exp_count /= 2**len(addr_tokens)
        if exp_count < 10000 and addr_count < 20000:
            penalty += 0.35 * max(1 if name_fulls else 0.1,
                                  5 - len(name_partials) - len(addr_tokens))
            yield penalty, exp_count, \
                self.get_name_address_ranking(list(name_partials.keys()), addr_partials)

    def get_name_address_ranking(self, name_tokens: List[int],
                                 addr_partials: List[qmod.Token]) -> List[dbf.FieldLookup]:
        """ Create a ranking expression looking up by name and address.
        """
        lookup = [dbf.FieldLookup('name_vector', name_tokens, lookups.LookupAll)]

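        # Split the address partials: very frequent ones (addr_count > 20000) only
        # restrict the results, the rarer ones are looked up via the index.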
        addr_restrict_tokens = []
        addr_lookup_tokens = []
        for t in addr_partials:
            if t.addr_count > 20000:
                addr_restrict_tokens.append(t.token)
            else:
                addr_lookup_tokens.append(t.token)

        if addr_restrict_tokens:
            lookup.append(dbf.FieldLookup('nameaddress_vector',
                                          addr_restrict_tokens, lookups.Restrict))
        if addr_lookup_tokens:
            lookup.append(dbf.FieldLookup('nameaddress_vector',
                                          addr_lookup_tokens, lookups.LookupAll))

        return lookup

    def get_full_name_ranking(self, name_fulls: List[qmod.Token], addr_partials: List[qmod.Token],
                              use_lookup: bool) -> List[dbf.FieldLookup]:
        """ Create a ranking expression with full name terms and
            additional address lookup. When 'use_lookup' is true, then
            address lookups will use the index; otherwise the address
            terms only restrict the results.
        """
        if use_lookup:
            addr_restrict_tokens = []
            addr_lookup_tokens = [t.token for t in addr_partials]
        else:
            addr_restrict_tokens = [t.token for t in addr_partials]
            addr_lookup_tokens = []

        return dbf.lookup_by_any_name([t.token for t in name_fulls],
                                      addr_restrict_tokens, addr_lookup_tokens)

    def get_name_ranking(self, trange: qmod.TokenRange,
                         db_field: str = 'name_vector') -> dbf.FieldRanking:
        """ Create a ranking expression for a name term in the given range.
        """
        name_fulls = self.query.get_tokens(trange, qmod.TOKEN_WORD)
        ranks = [dbf.RankedTokens(t.penalty, [t.token]) for t in name_fulls]
        ranks.sort(key=lambda r: r.penalty)
        # Fallback: sum of the penalties of the partials
        name_partials = self.query.get_partials_list(trange)
        default = sum(t.penalty for t in name_partials) + 0.2
        return dbf.FieldRanking(db_field, default, ranks)

    def get_addr_ranking(self, trange: qmod.TokenRange) -> dbf.FieldRanking:
        """ Create a ranking expression for an address term in the given range.
        """
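        # Enumerate the ways in which the address range can be covered by partial
        # and full-word tokens. The heap entries (neglen, pos, rank) are ordered by
        # the negative number of token lists consumed so far, so paths that already
        # cover more of the range are expanded first.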
        todo: List[Tuple[int, int, dbf.RankedTokens]] = []
        heapq.heappush(todo, (0, trange.start, dbf.RankedTokens(0.0, [])))
        ranks: List[dbf.RankedTokens] = []

        while todo:
            neglen, pos, rank = heapq.heappop(todo)
            for tlist in self.query.nodes[pos].starting:
                if tlist.ttype in (qmod.TOKEN_PARTIAL, qmod.TOKEN_WORD):
                    if tlist.end < trange.end:
                        chgpenalty = PENALTY_WORDCHANGE[self.query.nodes[tlist.end].btype]
                        if tlist.ttype == qmod.TOKEN_PARTIAL:
                            penalty = rank.penalty + chgpenalty \
                                      + max(t.penalty for t in tlist.tokens)
                            heapq.heappush(todo, (neglen - 1, tlist.end,
                                                  dbf.RankedTokens(penalty, rank.tokens)))
                        else:
                            for t in tlist.tokens:
                                heapq.heappush(todo, (neglen - 1, tlist.end,
                                                      rank.with_token(t, chgpenalty)))
                    elif tlist.end == trange.end:
                        if tlist.ttype == qmod.TOKEN_PARTIAL:
                            ranks.append(dbf.RankedTokens(rank.penalty
                                                          + max(t.penalty for t in tlist.tokens),
                                                          rank.tokens))
                        else:
                            ranks.extend(rank.with_token(t, 0.0) for t in tlist.tokens)
                        if len(ranks) >= 10:
                            # Too many variants, bail out and only add the
                            # worst-case fallback: the sum of the partials' penalties.
                            name_partials = self.query.get_partials_list(trange)
                            default = sum(t.penalty for t in name_partials) + 0.2
                            ranks.append(dbf.RankedTokens(rank.penalty + default, []))
                            # Also bail out of the outer loop.
                            todo.clear()
                            break

        ranks.sort(key=lambda r: len(r.tokens))
        default = ranks[0].penalty + 0.3
        del ranks[0]
        ranks.sort(key=lambda r: r.penalty)

        return dbf.FieldRanking('nameaddress_vector', default, ranks)

    def get_search_data(self, assignment: TokenAssignment) -> Optional[dbf.SearchData]:
        """ Collect the tokens for the non-name search fields in the
            assignment.
        """
        sdata = dbf.SearchData()
        sdata.penalty = assignment.penalty
        if assignment.country:
            tokens = self.get_country_tokens(assignment.country)
            if not tokens:
                return None
            sdata.set_strings('countries', tokens)
        elif self.details.countries:
            sdata.countries = dbf.WeightedStrings(self.details.countries,
                                                  [0.0] * len(self.details.countries))
        if assignment.housenumber:
            sdata.set_strings('housenumbers',
                              self.query.get_tokens(assignment.housenumber,
                                                    qmod.TOKEN_HOUSENUMBER))
        if assignment.postcode:
            sdata.set_strings('postcodes',
                              self.query.get_tokens(assignment.postcode,
                                                    qmod.TOKEN_POSTCODE))
        if assignment.qualifier:
            tokens = self.get_qualifier_tokens(assignment.qualifier)
            if not tokens:
                return None
            sdata.set_qualifiers(tokens)
        elif self.details.categories:
            sdata.qualifiers = dbf.WeightedCategories(self.details.categories,
                                                      [0.0] * len(self.details.categories))

        if assignment.address:
            if not assignment.name and assignment.housenumber:
                # Housenumber search: the first item needs to be handled like
                # a name in the ranking, or the penalties are not comparable
                # with those of normal searches.
                sdata.set_ranking([self.get_name_ranking(assignment.address[0],
                                                         db_field='nameaddress_vector')]
                                  + [self.get_addr_ranking(r) for r in assignment.address[1:]])
            else:
                sdata.set_ranking([self.get_addr_ranking(r) for r in assignment.address])
        else:
            sdata.rankings = []

        return sdata

    def get_country_tokens(self, trange: qmod.TokenRange) -> List[qmod.Token]:
        """ Return the list of country tokens for the given range,
            optionally filtered by the country list from the details
            parameters.
        """
        tokens = self.query.get_tokens(trange, qmod.TOKEN_COUNTRY)
        if self.details.countries:
            tokens = [t for t in tokens if t.lookup_word in self.details.countries]

        return tokens

    def get_qualifier_tokens(self, trange: qmod.TokenRange) -> List[qmod.Token]:
        """ Return the list of qualifier tokens for the given range,
            optionally filtered by the qualifier list from the details
            parameters.
        """
        tokens = self.query.get_tokens(trange, qmod.TOKEN_QUALIFIER)
        if self.details.categories:
            tokens = [t for t in tokens if t.get_category() in self.details.categories]

        return tokens

    def get_near_items(self, assignment: TokenAssignment) -> Optional[dbf.WeightedCategories]:
        """ Collect tokens for a near-items search or use the categories
            requested via parameter.
            Returns None if no category search is requested.
        """
        if assignment.near_item:
            tokens: Dict[Tuple[str, str], float] = {}
            for t in self.query.get_tokens(assignment.near_item, qmod.TOKEN_NEAR_ITEM):
                cat = t.get_category()
                # The category of a near search will be that of near_item.
                # Thus, if the search is restricted to a category parameter,
                # the two sets must intersect.
                if (not self.details.categories or cat in self.details.categories)\
                   and t.penalty < tokens.get(cat, 1000.0):
                    tokens[cat] = t.penalty
            return dbf.WeightedCategories(list(tokens.keys()), list(tokens.values()))

        return None


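# Penalty added in get_addr_ranking() when a partial-token path continues across
# the given break type between query words.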
PENALTY_WORDCHANGE = {
    qmod.BREAK_START: 0.0,
    qmod.BREAK_END: 0.0,
    qmod.BREAK_PHRASE: 0.0,
    qmod.BREAK_SOFT_PHRASE: 0.0,
    qmod.BREAK_WORD: 0.1,
    qmod.BREAK_PART: 0.2,
    qmod.BREAK_TOKEN: 0.4
}