git.openstreetmap.org Git - nominatim.git/blob - src/nominatim_api/search/db_search_builder.py
Merge pull request #3586 from lonvia/reduce-lookup-calls
[nominatim.git] / src / nominatim_api / search / db_search_builder.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2024 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Conversion from token assignment to an abstract DB search.
9 """
10 from typing import Optional, List, Tuple, Iterator, Dict
11 import heapq
12
13 from ..types import SearchDetails, DataLayer
14 from .query import QueryStruct, Token, TokenType, TokenRange, BreakType
15 from .token_assignment import TokenAssignment
16 from . import db_search_fields as dbf
17 from . import db_searches as dbs
18 from . import db_search_lookups as lookups
19
20
def wrap_near_search(categories: List[Tuple[str, str]],
                     search: dbs.AbstractSearch) -> dbs.NearSearch:
    """ Wrap the given search into a search for places near it that
        belong to one of the given categories.

        Every category is weighted with a penalty of 0.0. The resulting
        NearSearch takes over the penalty of the wrapped search.
    """
    zero_penalties = [0.0 for _ in categories]
    near_categories = dbf.WeightedCategories(categories, zero_penalties)
    return dbs.NearSearch(penalty=search.penalty,
                          categories=near_categories,
                          search=search)
30
31
def build_poi_search(category: List[Tuple[str, str]],
                     countries: Optional[List[str]]) -> dbs.PoiSearch:
    """ Create a new search for places of the given categories, optionally
        restricted to the given countries.

        All categories and countries are weighted with a penalty of 0.0.
        A missing or empty country list yields an empty country
        restriction.
    """
    # Falls back to a fresh empty list when no countries were given.
    country_list = countries or []
    weighted_countries = dbf.WeightedStrings(country_list, [0.0] * len(country_list))

    # Ad-hoc search data; the class body reads the enclosing function's
    # locals to fill in the fields used by PoiSearch.
    class _PoiData(dbf.SearchData):
        penalty = 0.0
        qualifiers = dbf.WeightedCategories(category, [0.0] * len(category))
        countries = weighted_countries

    return dbs.PoiSearch(_PoiData())
48
49
class SearchBuilder:
    """ Build the abstract search queries from token assignments.
    """

    def __init__(self, query: QueryStruct, details: SearchDetails) -> None:
        self.query = query
        self.details = details

    @property
    def configured_for_country(self) -> bool:
        """ Return true if the search details are configured to
            allow countries in the result, i.e. the requested rank range
            includes rank 4 and the address layer is enabled.
        """
        return self.details.min_rank <= 4 and self.details.max_rank >= 4 \
            and self.details.layer_enabled(DataLayer.ADDRESS)

    @property
    def configured_for_postcode(self) -> bool:
        """ Return true if the search details are configured to
            allow postcodes in the result, i.e. the requested rank range
            spans at least ranks 5 to 11 and the address layer is enabled.
        """
        return self.details.min_rank <= 5 and self.details.max_rank >= 11 \
            and self.details.layer_enabled(DataLayer.ADDRESS)

    @property
    def configured_for_housenumbers(self) -> bool:
        """ Return true if the search details are configured to
            allow addresses in the result, i.e. rank 30 may be returned
            and the address layer is enabled.
        """
        return self.details.max_rank >= 30 \
            and self.details.layer_enabled(DataLayer.ADDRESS)

    def build(self, assignment: TokenAssignment) -> Iterator[dbs.AbstractSearch]:
        """ Yield all possible abstract searches for the given token assignment.

            The penalty of the assignment is added to every yielded search.
            When the assignment contains near items, the searches are
            wrapped into a NearSearch. The exception is a nameless search
            without postcode, which is turned into a POI search using the
            near items as qualifiers.
        """
        sdata = self.get_search_data(assignment)
        if sdata is None:
            return

        near_items = self.get_near_items(assignment)
        if near_items is not None and not near_items:
            return  # impossible combination of near items and category parameter

        if assignment.name is None:
            if near_items and not sdata.postcodes:
                # Pure category search: filter by the near items directly
                # instead of wrapping the search into a NearSearch.
                sdata.qualifiers = near_items
                near_items = None
                builder = self.build_poi_search(sdata)
            elif assignment.housenumber:
                hnr_tokens = self.query.get_tokens(assignment.housenumber,
                                                   TokenType.HOUSENUMBER)
                builder = self.build_housenumber_search(sdata, hnr_tokens, assignment.address)
            else:
                builder = self.build_special_search(sdata, assignment.address,
                                                    bool(near_items))
        else:
            builder = self.build_name_search(sdata, assignment.name, assignment.address,
                                             bool(near_items))

        if near_items:
            # Move the smallest near-item penalty onto the wrapping search
            # so that the category penalties become relative to it.
            penalty = min(near_items.penalties)
            near_items.penalties = [p - penalty for p in near_items.penalties]
            for search in builder:
                search_penalty = search.penalty
                search.penalty = 0.0
                yield dbs.NearSearch(penalty + assignment.penalty + search_penalty,
                                     near_items, search)
        else:
            for search in builder:
                search.penalty += assignment.penalty
                yield search

    def build_poi_search(self, sdata: dbf.SearchData) -> Iterator[dbs.AbstractSearch]:
        """ Build abstract search query for a simple category search.
            This kind of search requires an additional geographic constraint:
            either a bounded viewbox or a 'near' point in the search details.
            Nothing is yielded when housenumbers are present.
        """
        if not sdata.housenumbers \
           and ((self.details.viewbox and self.details.bounded_viewbox) or self.details.near):
            yield dbs.PoiSearch(sdata)

    def build_special_search(self, sdata: dbf.SearchData,
                             address: List[TokenRange],
                             is_category: bool) -> Iterator[dbs.AbstractSearch]:
        """ Build abstract search queries for searches that do not involve
            a named place.

            May yield a CountrySearch (countries only, no address, no
            postcode) and/or a PostcodeSearch (postcodes present). Nothing
            is yielded when the search data carries qualifiers.
        """
        if sdata.qualifiers:
            # No special searches over qualifiers supported.
            return

        if sdata.countries and not address and not sdata.postcodes \
           and self.configured_for_country:
            yield dbs.CountrySearch(sdata)

        if sdata.postcodes and (is_category or self.configured_for_postcode):
            # Postcodes without a country restriction are slightly penalized.
            penalty = 0.0 if sdata.countries else 0.1
            if address:
                sdata.lookups = [dbf.FieldLookup('nameaddress_vector',
                                                 [t.token for r in address
                                                  for t in self.query.get_partials_list(r)],
                                                 lookups.Restrict)]
                penalty += 0.2
            yield dbs.PostcodeSearch(penalty, sdata)

    def build_housenumber_search(self, sdata: dbf.SearchData, hnrs: List[Token],
                                 address: List[TokenRange]) -> Iterator[dbs.AbstractSearch]:
        """ Build a simple address search for special entries where the
            housenumber is the main name token.

            The lookup strategy for the address part is chosen by the
            expected number of housenumber matches and the address
            frequency of the address partials.
        """
        sdata.lookups = [dbf.FieldLookup('name_vector', [t.token for t in hnrs], lookups.LookupAny)]
        expected_count = sum(t.count for t in hnrs)

        partials = {t.token: t.addr_count for trange in address
                    for t in self.query.get_partials_list(trange)}

        if not partials:
            # No address partials to restrict the housenumber search with.
            return

        if expected_count < 8000:
            # Few housenumber matches: address partials only restrict.
            sdata.lookups.append(dbf.FieldLookup('nameaddress_vector',
                                                 list(partials), lookups.Restrict))
        elif len(partials) != 1 or list(partials.values())[0] < 10000:
            sdata.lookups.append(dbf.FieldLookup('nameaddress_vector',
                                                 list(partials), lookups.LookupAll))
        else:
            # A single, very frequent address partial: use the full words
            # of the first address term instead, unless there are too many.
            addr_fulls = [t.token for t
                          in self.query.get_tokens(address[0], TokenType.WORD)]
            if len(addr_fulls) > 5:
                return
            sdata.lookups.append(
                dbf.FieldLookup('nameaddress_vector', addr_fulls, lookups.LookupAny))

        sdata.housenumbers = dbf.WeightedStrings([], [])
        yield dbs.PlaceSearch(0.05, sdata, expected_count)

    def build_name_search(self, sdata: dbf.SearchData,
                          name: TokenRange, address: List[TokenRange],
                          is_category: bool) -> Iterator[dbs.AbstractSearch]:
        """ Build abstract search queries for simple name or address searches.

            One PlaceSearch is yielded for each lookup variant produced by
            `yield_lookups()`. Searches with housenumbers are only built when
            the details allow addresses or this is a category search.
        """
        if is_category or not sdata.housenumbers or self.configured_for_housenumbers:
            ranking = self.get_name_ranking(name)
            name_penalty = ranking.normalize_penalty()
            if ranking.rankings:
                sdata.rankings.append(ranking)
            for penalty, count, lookup in self.yield_lookups(name, address):
                sdata.lookups = lookup
                yield dbs.PlaceSearch(penalty + name_penalty, sdata, count)

    def yield_lookups(self, name: TokenRange, address: List[TokenRange]
                      ) -> Iterator[Tuple[float, int, List[dbf.FieldLookup]]]:
        """ Yield all variants how the given name and address should best
            be searched for. This takes into account how frequent the terms
            are and tries to find a lookup that optimizes index use.

            Each variant is a tuple of (extra penalty, expected result
            count, list of field lookups).
        """
        penalty = 0.0  # extra penalty
        name_partials = {t.token: t for t in self.query.get_partials_list(name)}

        addr_partials = [t for r in address for t in self.query.get_partials_list(r)]
        addr_tokens = list({t.token for t in addr_partials})

        # Estimate: the rarest partial, halved for every further partial.
        exp_count = min(t.count for t in name_partials.values()) / (2**(len(name_partials) - 1))

        if len(name_partials) > 3 or exp_count < 8000:
            yield penalty, exp_count, dbf.lookup_by_names(list(name_partials.keys()), addr_tokens)
            return

        addr_count = min(t.addr_count for t in addr_partials) if addr_partials else 30000
        # Partial terms too frequent. Try looking up by rare full names first.
        name_fulls = self.query.get_tokens(name, TokenType.WORD)
        if name_fulls:
            fulls_count = sum(t.count for t in name_fulls)

            if fulls_count < 50000 or addr_count < 30000:
                yield penalty, fulls_count / (2**len(addr_tokens)), \
                    self.get_full_name_ranking(name_fulls, addr_partials,
                                               fulls_count > 30000 / max(1, len(addr_tokens)))

        # To catch remaining results, lookup by name and address
        # We only do this if there is a reasonable number of results expected.
        exp_count = exp_count / (2**len(addr_tokens)) if addr_tokens else exp_count
        if exp_count < 10000 and addr_count < 20000:
            penalty += 0.35 * max(1 if name_fulls else 0.1,
                                  5 - len(name_partials) - len(addr_tokens))
            yield penalty, exp_count, \
                self.get_name_address_ranking(list(name_partials.keys()), addr_partials)

    @staticmethod
    def _split_addr_partials(addr_partials: List[Token],
                             max_lookup_count: int = 20000
                             ) -> Tuple[List[int], List[int]]:
        """ Split the address partials by their address count.

            Returns a pair of token lists: tokens whose `addr_count` exceeds
            `max_lookup_count` (callers use them with `lookups.Restrict`)
            and all remaining tokens (used with `lookups.LookupAll`).
            The input order is preserved within each list.
        """
        restrict_tokens: List[int] = []
        lookup_tokens: List[int] = []
        for t in addr_partials:
            if t.addr_count > max_lookup_count:
                restrict_tokens.append(t.token)
            else:
                lookup_tokens.append(t.token)
        return restrict_tokens, lookup_tokens

    def get_name_address_ranking(self, name_tokens: List[int],
                                 addr_partials: List[Token]) -> List[dbf.FieldLookup]:
        """ Create a ranking expression looking up by name and address.

            The name tokens are looked up in 'name_vector' with
            `lookups.LookupAll`. Frequent address partials are added as
            restriction, the other ones with `lookups.LookupAll`.
        """
        lookup = [dbf.FieldLookup('name_vector', name_tokens, lookups.LookupAll)]

        addr_restrict_tokens, addr_lookup_tokens = self._split_addr_partials(addr_partials)

        if addr_restrict_tokens:
            lookup.append(dbf.FieldLookup('nameaddress_vector',
                                          addr_restrict_tokens, lookups.Restrict))
        if addr_lookup_tokens:
            lookup.append(dbf.FieldLookup('nameaddress_vector',
                                          addr_lookup_tokens, lookups.LookupAll))

        return lookup

    def get_full_name_ranking(self, name_fulls: List[Token], addr_partials: List[Token],
                              use_lookup: bool) -> List[dbf.FieldLookup]:
        """ Create a ranking expression with full name terms and
            additional address lookup. When 'use_lookup' is true, then
            address lookups will use the index, when the occurrences are not
            too many.
        """
        if use_lookup:
            addr_restrict_tokens, addr_lookup_tokens = self._split_addr_partials(addr_partials)
        else:
            # Without lookup all address partials only act as restriction.
            addr_restrict_tokens = [t.token for t in addr_partials]
            addr_lookup_tokens = []

        return dbf.lookup_by_any_name([t.token for t in name_fulls],
                                      addr_restrict_tokens, addr_lookup_tokens)

    def _get_partials_penalty(self, trange: TokenRange) -> float:
        """ Return the worst-case fallback penalty for the given range:
            the sum of the penalties of its partial tokens plus 0.2.
        """
        return sum(t.penalty for t in self.query.get_partials_list(trange)) + 0.2

    def get_name_ranking(self, trange: TokenRange,
                         db_field: str = 'name_vector') -> dbf.FieldRanking:
        """ Create a ranking expression for a name term in the given range.

            Each full word token becomes a ranking entry (sorted by
            penalty). The default penalty is the partial-token fallback.
        """
        name_fulls = self.query.get_tokens(trange, TokenType.WORD)
        ranks = [dbf.RankedTokens(t.penalty, [t.token]) for t in name_fulls]
        ranks.sort(key=lambda r: r.penalty)
        # Fallback, sum of penalty for partials
        default = self._get_partials_penalty(trange)
        return dbf.FieldRanking(db_field, default, ranks)

    def get_addr_ranking(self, trange: TokenRange) -> dbf.FieldRanking:
        """ Create a list of ranking expressions for an address term
            for the given ranges.

            All ways of covering the range with partial and full word tokens
            are explored. The heap is keyed by the negative path length,
            so longer paths are expanded first. Full words are recorded as
            tokens; partials only contribute their maximum penalty. Word
            changes inside the range add PENALTY_WORDCHANGE for the break
            type of the node. The variant with the fewest tokens provides
            the default penalty (+0.3) and is removed from the list.
        """
        todo: List[Tuple[int, int, dbf.RankedTokens]] = []
        heapq.heappush(todo, (0, trange.start, dbf.RankedTokens(0.0, [])))
        ranks: List[dbf.RankedTokens] = []

        while todo:
            neglen, pos, rank = heapq.heappop(todo)
            for tlist in self.query.nodes[pos].starting:
                if tlist.ttype in (TokenType.PARTIAL, TokenType.WORD):
                    if tlist.end < trange.end:
                        chgpenalty = PENALTY_WORDCHANGE[self.query.nodes[tlist.end].btype]
                        if tlist.ttype == TokenType.PARTIAL:
                            penalty = rank.penalty + chgpenalty \
                                      + max(t.penalty for t in tlist.tokens)
                            heapq.heappush(todo, (neglen - 1, tlist.end,
                                                  dbf.RankedTokens(penalty, rank.tokens)))
                        else:
                            for t in tlist.tokens:
                                heapq.heappush(todo, (neglen - 1, tlist.end,
                                                      rank.with_token(t, chgpenalty)))
                    elif tlist.end == trange.end:
                        if tlist.ttype == TokenType.PARTIAL:
                            ranks.append(dbf.RankedTokens(rank.penalty
                                                          + max(t.penalty for t in tlist.tokens),
                                                          rank.tokens))
                        else:
                            ranks.extend(rank.with_token(t, 0.0) for t in tlist.tokens)
                        if len(ranks) >= 10:
                            # Too many variants, bail out and only add
                            # worst-case fallback: sum of penalty of partials
                            ranks.append(dbf.RankedTokens(
                                rank.penalty + self._get_partials_penalty(trange), []))
                            # Bail out of outer loop
                            todo.clear()
                            break

        ranks.sort(key=lambda r: len(r.tokens))
        default = ranks[0].penalty + 0.3
        del ranks[0]
        ranks.sort(key=lambda r: r.penalty)

        return dbf.FieldRanking('nameaddress_vector', default, ranks)

    def get_search_data(self, assignment: TokenAssignment) -> Optional[dbf.SearchData]:
        """ Collect the tokens for the non-name search fields in the
            assignment.

            Returns None when the assignment's country or qualifier tokens
            are all excluded by the search details.
        """
        sdata = dbf.SearchData()
        sdata.penalty = assignment.penalty
        if assignment.country:
            tokens = self.get_country_tokens(assignment.country)
            if not tokens:
                return None
            sdata.set_strings('countries', tokens)
        elif self.details.countries:
            sdata.countries = dbf.WeightedStrings(self.details.countries,
                                                  [0.0] * len(self.details.countries))
        if assignment.housenumber:
            sdata.set_strings('housenumbers',
                              self.query.get_tokens(assignment.housenumber,
                                                    TokenType.HOUSENUMBER))
        if assignment.postcode:
            sdata.set_strings('postcodes',
                              self.query.get_tokens(assignment.postcode,
                                                    TokenType.POSTCODE))
        if assignment.qualifier:
            tokens = self.get_qualifier_tokens(assignment.qualifier)
            if not tokens:
                return None
            sdata.set_qualifiers(tokens)
        elif self.details.categories:
            sdata.qualifiers = dbf.WeightedCategories(self.details.categories,
                                                      [0.0] * len(self.details.categories))

        if assignment.address:
            if not assignment.name and assignment.housenumber:
                # housenumber search: the first item needs to be handled like
                # a name in ranking or penalties are not comparable with
                # normal searches.
                sdata.set_ranking([self.get_name_ranking(assignment.address[0],
                                                         db_field='nameaddress_vector')]
                                  + [self.get_addr_ranking(r) for r in assignment.address[1:]])
            else:
                sdata.set_ranking([self.get_addr_ranking(r) for r in assignment.address])
        else:
            sdata.rankings = []

        return sdata

    def get_country_tokens(self, trange: TokenRange) -> List[Token]:
        """ Return the list of country tokens for the given range,
            optionally filtered by the country list from the details
            parameters.
        """
        tokens = self.query.get_tokens(trange, TokenType.COUNTRY)
        if self.details.countries:
            tokens = [t for t in tokens if t.lookup_word in self.details.countries]

        return tokens

    def get_qualifier_tokens(self, trange: TokenRange) -> List[Token]:
        """ Return the list of qualifier tokens for the given range,
            optionally filtered by the qualifier list from the details
            parameters.
        """
        tokens = self.query.get_tokens(trange, TokenType.QUALIFIER)
        if self.details.categories:
            tokens = [t for t in tokens if t.get_category() in self.details.categories]

        return tokens

    def get_near_items(self, assignment: TokenAssignment) -> Optional[dbf.WeightedCategories]:
        """ Collect tokens for near items search or use the categories
            requested per parameter.
            Returns None if no category search is requested. Returns an
            empty result when no near item matches the category parameter.
            For each category only the lowest token penalty is kept.
        """
        if assignment.near_item:
            tokens: Dict[Tuple[str, str], float] = {}
            for t in self.query.get_tokens(assignment.near_item, TokenType.NEAR_ITEM):
                cat = t.get_category()
                # The category of a near search will be that of near_item.
                # Thus, if search is restricted to a category parameter,
                # the two sets must intersect.
                if (not self.details.categories or cat in self.details.categories)\
                   and t.penalty < tokens.get(cat, 1000.0):
                    tokens[cat] = t.penalty
            return dbf.WeightedCategories(list(tokens.keys()), list(tokens.values()))

        return None
430
431
# Extra penalty added by SearchBuilder.get_addr_ranking() when a
# token path continues across a node inside an address range. The
# penalty is chosen by the break type of that node.
PENALTY_WORDCHANGE: Dict[BreakType, float] = {
    # Query and phrase boundaries cost nothing.
    BreakType.START: 0.0,
    BreakType.END: 0.0,
    BreakType.PHRASE: 0.0,
    # The penalty grows from word over part to token breaks.
    BreakType.WORD: 0.1,
    BreakType.PART: 0.2,
    BreakType.TOKEN: 0.4,
}