]> git.openstreetmap.org Git - nominatim.git/blob - src/nominatim_api/search/db_search_builder.py
Merge remote-tracking branch 'upstream/master'
[nominatim.git] / src / nominatim_api / search / db_search_builder.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2024 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Conversion from token assignment to an abstract DB search.
9 """
10 from typing import Optional, List, Tuple, Iterator, Dict
11 import heapq
12
13 from ..types import SearchDetails, DataLayer
14 from . import query as qmod
15 from .token_assignment import TokenAssignment
16 from . import db_search_fields as dbf
17 from . import db_searches as dbs
18 from . import db_search_lookups as lookups
19
20
def wrap_near_search(categories: List[Tuple[str, str]],
                     search: dbs.AbstractSearch) -> dbs.NearSearch:
    """ Wrap the given search into a search for near places of the
        given categories.

        The inner search's penalty is carried over onto the wrapper;
        every category enters with a zero weight.
    """
    zero_weights = [0.0] * len(categories)
    weighted_cats = dbf.WeightedCategories(categories, zero_weights)
    return dbs.NearSearch(penalty=search.penalty,
                          categories=weighted_cats,
                          search=search)
30
31
def build_poi_search(category: List[Tuple[str, str]],
                     countries: Optional[List[str]]) -> dbs.PoiSearch:
    """ Create a new search for places by the given category, possibly
        constrained to the given countries.

        All categories and countries are weighted equally (zero penalty).
    """
    country_list = countries or []
    ccs = dbf.WeightedStrings(country_list, [0.0] * len(country_list))

    class _PoiData(dbf.SearchData):
        penalty = 0.0
        qualifiers = dbf.WeightedCategories(category, [0.0] * len(category))
        countries = ccs

    return dbs.PoiSearch(_PoiData())
48
49
class SearchBuilder:
    """ Build the abstract search queries from token assignments.

        One builder instance serves a single query: build() is called
        for every token assignment produced for the query and yields
        the abstract searches that can answer it.
    """

    def __init__(self, query: qmod.QueryStruct, details: SearchDetails) -> None:
        # The tokenized query and the caller-supplied search parameters.
        self.query = query
        self.details = details

    @property
    def configured_for_country(self) -> bool:
        """ Return true if the search details are configured to
            allow countries in the result.
        """
        # Countries sit at address rank 4 and need the address layer enabled.
        return self.details.min_rank <= 4 and self.details.max_rank >= 4 \
            and self.details.layer_enabled(DataLayer.ADDRESS)

    @property
    def configured_for_postcode(self) -> bool:
        """ Return true if the search details are configured to
            allow postcodes in the result.
        """
        # Postcodes cover the address rank range 5-11.
        return self.details.min_rank <= 5 and self.details.max_rank >= 11\
            and self.details.layer_enabled(DataLayer.ADDRESS)

    @property
    def configured_for_housenumbers(self) -> bool:
        """ Return true if the search details are configured to
            allow addresses in the result.
        """
        # Housenumber-level places have rank 30.
        return self.details.max_rank >= 30 \
            and self.details.layer_enabled(DataLayer.ADDRESS)

    def build(self, assignment: TokenAssignment) -> Iterator[dbs.AbstractSearch]:
        """ Yield all possible abstract searches for the given token assignment.
        """
        sdata = self.get_search_data(assignment)
        if sdata is None:
            return

        near_items = self.get_near_items(assignment)
        if near_items is not None and not near_items:
            return  # impossible combination of near items and category parameter

        if assignment.name is None:
            if near_items and not sdata.postcodes:
                # No name and no postcode: run the near items as qualifiers
                # of a plain category (POI) search instead of a near search.
                sdata.qualifiers = near_items
                near_items = None
                builder = self.build_poi_search(sdata)
            elif assignment.housenumber:
                hnr_tokens = self.query.get_tokens(assignment.housenumber,
                                                   qmod.TOKEN_HOUSENUMBER)
                builder = self.build_housenumber_search(sdata, hnr_tokens, assignment.address)
            else:
                builder = self.build_special_search(sdata, assignment.address,
                                                    bool(near_items))
        else:
            builder = self.build_name_search(sdata, assignment.name, assignment.address,
                                             bool(near_items))

        if near_items:
            # Rebase near-item penalties so the cheapest is 0 and charge
            # that minimum (plus assignment and inner penalty) on the
            # wrapping NearSearch instead.
            penalty = min(near_items.penalties)
            near_items.penalties = [p - penalty for p in near_items.penalties]
            for search in builder:
                search_penalty = search.penalty
                search.penalty = 0.0
                yield dbs.NearSearch(penalty + assignment.penalty + search_penalty,
                                     near_items, search)
        else:
            for search in builder:
                search.penalty += assignment.penalty
                yield search

    def build_poi_search(self, sdata: dbf.SearchData) -> Iterator[dbs.AbstractSearch]:
        """ Build abstract search query for a simple category search.
            This kind of search requires an additional geographic constraint.
        """
        # Only viable with a bounded viewbox or a near point; without a
        # geographic constraint no search is produced.
        if not sdata.housenumbers \
           and ((self.details.viewbox and self.details.bounded_viewbox) or self.details.near):
            yield dbs.PoiSearch(sdata)

    def build_special_search(self, sdata: dbf.SearchData,
                             address: List[qmod.TokenRange],
                             is_category: bool) -> Iterator[dbs.AbstractSearch]:
        """ Build abstract search queries for searches that do not involve
            a named place.
        """
        if sdata.qualifiers:
            # No special searches over qualifiers supported.
            return

        if sdata.countries and not address and not sdata.postcodes \
           and self.configured_for_country:
            yield dbs.CountrySearch(sdata)

        if sdata.postcodes and (is_category or self.configured_for_postcode):
            # A country restriction makes the postcode hit more specific,
            # so it gets a smaller penalty.
            penalty = 0.0 if sdata.countries else 0.1
            if address:
                # Address terms act only as a restriction on the postcode hit.
                sdata.lookups = [dbf.FieldLookup('nameaddress_vector',
                                                 [t.token for r in address
                                                  for t in self.query.get_partials_list(r)],
                                                 lookups.Restrict)]
            yield dbs.PostcodeSearch(penalty, sdata)

    def build_housenumber_search(self, sdata: dbf.SearchData, hnrs: List[qmod.Token],
                                 address: List[qmod.TokenRange]) -> Iterator[dbs.AbstractSearch]:
        """ Build a simple address search for special entries where the
            housenumber is the main name token.
        """
        sdata.lookups = [dbf.FieldLookup('name_vector', [t.token for t in hnrs], lookups.LookupAny)]
        expected_count = sum(t.count for t in hnrs)

        # Map each partial address token to its address occurrence count.
        partials = {t.token: t.addr_count for trange in address
                    for t in self.query.get_partials_list(trange)}

        if not partials:
            # can happen when none of the partials is indexed
            return

        if expected_count < 8000:
            # Few expected hits: the address terms only need to restrict.
            sdata.lookups.append(dbf.FieldLookup('nameaddress_vector',
                                                 list(partials), lookups.Restrict))
        elif len(partials) != 1 or list(partials.values())[0] < 10000:
            # Otherwise go through the address index, unless the single
            # partial is itself too frequent.
            sdata.lookups.append(dbf.FieldLookup('nameaddress_vector',
                                                 list(partials), lookups.LookupAll))
        else:
            # Fall back to full-word address tokens; give up when there
            # are too many variants.
            addr_fulls = [t.token for t
                          in self.query.get_tokens(address[0], qmod.TOKEN_WORD)]
            if len(addr_fulls) > 5:
                return
            sdata.lookups.append(
                dbf.FieldLookup('nameaddress_vector', addr_fulls, lookups.LookupAny))

        # The housenumber tokens already serve as the name lookup above.
        sdata.housenumbers = dbf.WeightedStrings([], [])
        yield dbs.PlaceSearch(0.05, sdata, expected_count)

    def build_name_search(self, sdata: dbf.SearchData,
                          name: qmod.TokenRange, address: List[qmod.TokenRange],
                          is_category: bool) -> Iterator[dbs.AbstractSearch]:
        """ Build abstract search queries for simple name or address searches.
        """
        if is_category or not sdata.housenumbers or self.configured_for_housenumbers:
            ranking = self.get_name_ranking(name)
            name_penalty = ranking.normalize_penalty()
            if ranking.rankings:
                sdata.rankings.append(ranking)
            # One PlaceSearch per lookup strategy produced for the name.
            for penalty, count, lookup in self.yield_lookups(name, address):
                sdata.lookups = lookup
                yield dbs.PlaceSearch(penalty + name_penalty, sdata, count)

    def yield_lookups(self, name: qmod.TokenRange, address: List[qmod.TokenRange]
                      ) -> Iterator[Tuple[float, int, List[dbf.FieldLookup]]]:
        """ Yield all variants how the given name and address should best
            be searched for. This takes into account how frequent the terms
            are and tries to find a lookup that optimizes index use.

            Yields tuples of (extra penalty, expected count, lookups).
        """
        penalty = 0.0  # extra penalty
        name_partials = {t.token: t for t in self.query.get_partials_list(name)}

        addr_partials = [t for r in address for t in self.query.get_partials_list(r)]
        addr_tokens = list({t.token for t in addr_partials})

        # Estimated hits: every additional name partial roughly halves them.
        exp_count = min(t.count for t in name_partials.values()) / (2**(len(name_partials) - 1))

        if (len(name_partials) > 3 or exp_count < 8000):
            yield penalty, exp_count, dbf.lookup_by_names(list(name_partials.keys()), addr_tokens)
            return

        addr_count = min(t.addr_count for t in addr_partials) if addr_partials else 50000
        # Partial term too frequent. Try looking up by rare full names first.
        name_fulls = self.query.get_tokens(name, qmod.TOKEN_WORD)
        if name_fulls:
            fulls_count = sum(t.count for t in name_fulls)

            if fulls_count < 80000 or addr_count < 50000:
                yield penalty, fulls_count / (2**len(addr_tokens)), \
                    self.get_full_name_ranking(name_fulls, addr_partials,
                                               fulls_count > 30000 / max(1, len(addr_tokens)))

        # To catch remaining results, lookup by name and address
        # We only do this if there is a reasonable number of results expected.
        exp_count = exp_count / (2**len(addr_tokens)) if addr_tokens else exp_count
        if exp_count < 10000 and addr_count < 20000:
            penalty += 0.35 * max(1 if name_fulls else 0.1,
                                  5 - len(name_partials) - len(addr_tokens))
            yield penalty, exp_count, \
                self.get_name_address_ranking(list(name_partials.keys()), addr_partials)

    def get_name_address_ranking(self, name_tokens: List[int],
                                 addr_partials: List[qmod.Token]) -> List[dbf.FieldLookup]:
        """ Create a ranking expression looking up by name and address.
        """
        lookup = [dbf.FieldLookup('name_vector', name_tokens, lookups.LookupAll)]

        # Frequent address terms only restrict; rare ones do an index lookup.
        addr_restrict_tokens = []
        addr_lookup_tokens = []
        for t in addr_partials:
            if t.addr_count > 20000:
                addr_restrict_tokens.append(t.token)
            else:
                addr_lookup_tokens.append(t.token)

        if addr_restrict_tokens:
            lookup.append(dbf.FieldLookup('nameaddress_vector',
                                          addr_restrict_tokens, lookups.Restrict))
        if addr_lookup_tokens:
            lookup.append(dbf.FieldLookup('nameaddress_vector',
                                          addr_lookup_tokens, lookups.LookupAll))

        return lookup

    def get_full_name_ranking(self, name_fulls: List[qmod.Token], addr_partials: List[qmod.Token],
                              use_lookup: bool) -> List[dbf.FieldLookup]:
        """ Create a ranking expression with full name terms and
            additional address lookup. When 'use_lookup' is true, then
            address lookups will use the index, when the occurrences are not
            too many.
        """
        # At this point drop unindexed partials from the address.
        # This might yield wrong results, nothing we can do about that.
        if use_lookup:
            addr_restrict_tokens = []
            addr_lookup_tokens = [t.token for t in addr_partials]
        else:
            addr_restrict_tokens = [t.token for t in addr_partials]
            addr_lookup_tokens = []

        return dbf.lookup_by_any_name([t.token for t in name_fulls],
                                      addr_restrict_tokens, addr_lookup_tokens)

    def get_name_ranking(self, trange: qmod.TokenRange,
                         db_field: str = 'name_vector') -> dbf.FieldRanking:
        """ Create a ranking expression for a name term in the given range.
        """
        name_fulls = self.query.get_tokens(trange, qmod.TOKEN_WORD)
        ranks = [dbf.RankedTokens(t.penalty, [t.token]) for t in name_fulls]
        ranks.sort(key=lambda r: r.penalty)
        # Fallback, sum of penalty for partials
        name_partials = self.query.get_partials_list(trange)
        default = sum(t.penalty for t in name_partials) + 0.2
        return dbf.FieldRanking(db_field, default, ranks)

    def get_addr_ranking(self, trange: qmod.TokenRange) -> dbf.FieldRanking:
        """ Create a list of ranking expressions for an address term
            for the given ranges.
        """
        # Best-first search over token paths through the range. The heap
        # key is the negative number of steps taken, so paths that have
        # consumed more tokens are expanded first.
        todo: List[Tuple[int, int, dbf.RankedTokens]] = []
        heapq.heappush(todo, (0, trange.start, dbf.RankedTokens(0.0, [])))
        ranks: List[dbf.RankedTokens] = []

        while todo:
            neglen, pos, rank = heapq.heappop(todo)
            for tlist in self.query.nodes[pos].starting:
                if tlist.ttype in (qmod.TOKEN_PARTIAL, qmod.TOKEN_WORD):
                    if tlist.end < trange.end:
                        chgpenalty = PENALTY_WORDCHANGE[self.query.nodes[tlist.end].btype]
                        if tlist.ttype == qmod.TOKEN_PARTIAL:
                            # Partials contribute only penalty, no tokens.
                            penalty = rank.penalty + chgpenalty \
                                      + max(t.penalty for t in tlist.tokens)
                            heapq.heappush(todo, (neglen - 1, tlist.end,
                                                  dbf.RankedTokens(penalty, rank.tokens)))
                        else:
                            for t in tlist.tokens:
                                heapq.heappush(todo, (neglen - 1, tlist.end,
                                                      rank.with_token(t, chgpenalty)))
                    elif tlist.end == trange.end:
                        # Path spans the full range: record the finished rank.
                        if tlist.ttype == qmod.TOKEN_PARTIAL:
                            ranks.append(dbf.RankedTokens(rank.penalty
                                                          + max(t.penalty for t in tlist.tokens),
                                                          rank.tokens))
                        else:
                            ranks.extend(rank.with_token(t, 0.0) for t in tlist.tokens)
                        if len(ranks) >= 10:
                            # Too many variants, bail out and only add
                            # Worst-case Fallback: sum of penalty of partials
                            name_partials = self.query.get_partials_list(trange)
                            default = sum(t.penalty for t in name_partials) + 0.2
                            ranks.append(dbf.RankedTokens(rank.penalty + default, []))
                            # Bail out of outer loop
                            todo.clear()
                            break

        # The variant with the fewest tokens becomes the default penalty;
        # the rest are ranked by penalty.
        # NOTE(review): assumes at least one completed path exists for the
        # range, otherwise ranks[0] raises IndexError — verify with callers.
        ranks.sort(key=lambda r: len(r.tokens))
        default = ranks[0].penalty + 0.3
        del ranks[0]
        ranks.sort(key=lambda r: r.penalty)

        return dbf.FieldRanking('nameaddress_vector', default, ranks)

    def get_search_data(self, assignment: TokenAssignment) -> Optional[dbf.SearchData]:
        """ Collect the tokens for the non-name search fields in the
            assignment.

            Returns None when the assignment conflicts with the filter
            parameters (country or qualifier lists with no matching token).
        """
        sdata = dbf.SearchData()
        sdata.penalty = assignment.penalty
        if assignment.country:
            tokens = self.get_country_tokens(assignment.country)
            if not tokens:
                return None
            sdata.set_strings('countries', tokens)
        elif self.details.countries:
            # No country in the query: apply the parameter filter directly.
            sdata.countries = dbf.WeightedStrings(self.details.countries,
                                                  [0.0] * len(self.details.countries))
        if assignment.housenumber:
            sdata.set_strings('housenumbers',
                              self.query.get_tokens(assignment.housenumber,
                                                    qmod.TOKEN_HOUSENUMBER))
        if assignment.postcode:
            sdata.set_strings('postcodes',
                              self.query.get_tokens(assignment.postcode,
                                                    qmod.TOKEN_POSTCODE))
        if assignment.qualifier:
            tokens = self.get_qualifier_tokens(assignment.qualifier)
            if not tokens:
                return None
            sdata.set_qualifiers(tokens)
        elif self.details.categories:
            sdata.qualifiers = dbf.WeightedCategories(self.details.categories,
                                                      [0.0] * len(self.details.categories))

        if assignment.address:
            if not assignment.name and assignment.housenumber:
                # housenumber search: the first item needs to be handled like
                # a name in ranking or penalties are not comparable with
                # normal searches.
                sdata.set_ranking([self.get_name_ranking(assignment.address[0],
                                                         db_field='nameaddress_vector')]
                                  + [self.get_addr_ranking(r) for r in assignment.address[1:]])
            else:
                sdata.set_ranking([self.get_addr_ranking(r) for r in assignment.address])
        else:
            sdata.rankings = []

        return sdata

    def get_country_tokens(self, trange: qmod.TokenRange) -> List[qmod.Token]:
        """ Return the list of country tokens for the given range,
            optionally filtered by the country list from the details
            parameters.
        """
        tokens = self.query.get_tokens(trange, qmod.TOKEN_COUNTRY)
        if self.details.countries:
            tokens = [t for t in tokens if t.lookup_word in self.details.countries]

        return tokens

    def get_qualifier_tokens(self, trange: qmod.TokenRange) -> List[qmod.Token]:
        """ Return the list of qualifier tokens for the given range,
            optionally filtered by the qualifier list from the details
            parameters.
        """
        tokens = self.query.get_tokens(trange, qmod.TOKEN_QUALIFIER)
        if self.details.categories:
            tokens = [t for t in tokens if t.get_category() in self.details.categories]

        return tokens

    def get_near_items(self, assignment: TokenAssignment) -> Optional[dbf.WeightedCategories]:
        """ Collect tokens for near items search or use the categories
            requested per parameter.
            Returns None if no category search is requested.
        """
        if assignment.near_item:
            tokens: Dict[Tuple[str, str], float] = {}
            for t in self.query.get_tokens(assignment.near_item, qmod.TOKEN_NEAR_ITEM):
                cat = t.get_category()
                # The category of a near search will be that of near_item.
                # Thus, if search is restricted to a category parameter,
                # the two sets must intersect.
                if (not self.details.categories or cat in self.details.categories)\
                   and t.penalty < tokens.get(cat, 1000.0):
                    # Keep only the cheapest token per category.
                    tokens[cat] = t.penalty
            return dbf.WeightedCategories(list(tokens.keys()), list(tokens.values()))

        return None
424
425
# Penalty applied in get_addr_ranking when a token sequence continues
# across a break of the given type: crossing phrase or query boundaries
# is free, while continuing inside a word (PART/TOKEN) is penalised.
PENALTY_WORDCHANGE = {
    qmod.BREAK_START: 0.0,
    qmod.BREAK_END: 0.0,
    qmod.BREAK_PHRASE: 0.0,
    qmod.BREAK_SOFT_PHRASE: 0.0,
    qmod.BREAK_WORD: 0.1,
    qmod.BREAK_PART: 0.2,
    qmod.BREAK_TOKEN: 0.4
}