# SPDX-License-Identifier: GPL-3.0-or-later
#
# This file is part of Nominatim. (https://nominatim.org)
#
# Copyright (C) 2024 by the Nominatim developer community.
# For a full list of authors see the git log.
"""
Conversion from token assignment to an abstract DB search.
"""
from typing import Optional, List, Tuple, Iterator, Dict
import heapq

from ..types import SearchDetails, DataLayer
from .query import QueryStruct, Token, TokenType, TokenRange, BreakType
from .token_assignment import TokenAssignment
from . import db_search_fields as dbf
from . import db_searches as dbs
from . import db_search_lookups as lookups


def wrap_near_search(categories: List[Tuple[str, str]],
                     search: dbs.AbstractSearch) -> dbs.NearSearch:
    """ Create a new search that wraps the given search in a search
        for near places of the given category.
    """
    return dbs.NearSearch(penalty=search.penalty,
                          categories=dbf.WeightedCategories(categories,
                                                            [0.0] * len(categories)),
                          search=search)

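# Illustration (hypothetical objects, not part of the module): given an
# abstract search 'street_search' that finds a street, the wrapper turns it
# into a search for cafes near that street. Zero weights mean no category
# is preferred over another.
#
#     near = wrap_near_search([('amenity', 'cafe')], street_search)
#     # near.penalty == street_search.penalty; the results are cafes close
#     # to the places found by street_search.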

def build_poi_search(category: List[Tuple[str, str]],
                     countries: Optional[List[str]]) -> dbs.PoiSearch:
    """ Create a new search for places by the given category, possibly
        constrained to the given countries.
    """
    if countries:
        ccs = dbf.WeightedStrings(countries, [0.0] * len(countries))
    else:
        ccs = dbf.WeightedStrings([], [])

    class _PoiData(dbf.SearchData):
        penalty = 0.0
        qualifiers = dbf.WeightedCategories(category, [0.0] * len(category))
        countries = ccs

    return dbs.PoiSearch(_PoiData())

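# Illustration (hypothetical values): a pure category search for pharmacies,
# restricted to Germany and Austria, with neutral (zero) weights throughout.
#
#     search = build_poi_search([('amenity', 'pharmacy')], ['de', 'at'])
#
# This module-level helper builds the search directly from parameters; the
# SearchBuilder method of the same name below additionally requires a
# geographic constraint (bounded viewbox or near point).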

class SearchBuilder:
    """ Build the abstract search queries from token assignments.
    """

    def __init__(self, query: QueryStruct, details: SearchDetails) -> None:
        self.query = query
        self.details = details

    @property
    def configured_for_country(self) -> bool:
        """ Return true if the search details are configured to
            allow countries in the result.
        """
        return self.details.min_rank <= 4 and self.details.max_rank >= 4 \
            and self.details.layer_enabled(DataLayer.ADDRESS)

    @property
    def configured_for_postcode(self) -> bool:
        """ Return true if the search details are configured to
            allow postcodes in the result.
        """
        return self.details.min_rank <= 5 and self.details.max_rank >= 11 \
            and self.details.layer_enabled(DataLayer.ADDRESS)

    @property
    def configured_for_housenumbers(self) -> bool:
        """ Return true if the search details are configured to
            allow addresses in the result.
        """
        return self.details.max_rank >= 30 \
            and self.details.layer_enabled(DataLayer.ADDRESS)

    def build(self, assignment: TokenAssignment) -> Iterator[dbs.AbstractSearch]:
        """ Yield all possible abstract searches for the given token assignment.
        """
        sdata = self.get_search_data(assignment)
        if sdata is None:
            return

        near_items = self.get_near_items(assignment)
        if near_items is not None and not near_items:
            return  # impossible combination of near items and category parameter

        if assignment.name is None:
            if near_items and not sdata.postcodes:
                sdata.qualifiers = near_items
                near_items = None
                builder = self.build_poi_search(sdata)
            elif assignment.housenumber:
                hnr_tokens = self.query.get_tokens(assignment.housenumber,
                                                   TokenType.HOUSENUMBER)
                builder = self.build_housenumber_search(sdata, hnr_tokens, assignment.address)
            else:
                builder = self.build_special_search(sdata, assignment.address,
                                                    bool(near_items))
        else:
            builder = self.build_name_search(sdata, assignment.name, assignment.address,
                                             bool(near_items))

        if near_items:
            penalty = min(near_items.penalties)
            near_items.penalties = [p - penalty for p in near_items.penalties]
            for search in builder:
                search_penalty = search.penalty
                search.penalty = 0.0
                yield dbs.NearSearch(penalty + assignment.penalty + search_penalty,
                                     near_items, search)
        else:
            for search in builder:
                search.penalty += assignment.penalty
                yield search

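    # Worked example for the near-item branch above (made-up numbers): with
    # near_items.penalties == [0.3, 0.5], the minimum 0.3 is moved out of the
    # category weights (leaving [0.0, 0.2]) and charged once per search, so a
    # wrapped search with penalty 0.1 and an assignment penalty of 0.05 gives
    # a NearSearch penalty of 0.3 + 0.05 + 0.1 = 0.45 while the inner search
    # is reset to 0.0.
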
    def build_poi_search(self, sdata: dbf.SearchData) -> Iterator[dbs.AbstractSearch]:
        """ Build abstract search query for a simple category search.
            This kind of search requires an additional geographic constraint.
        """
        if not sdata.housenumbers \
           and ((self.details.viewbox and self.details.bounded_viewbox) or self.details.near):
            yield dbs.PoiSearch(sdata)

    def build_special_search(self, sdata: dbf.SearchData,
                             address: List[TokenRange],
                             is_category: bool) -> Iterator[dbs.AbstractSearch]:
        """ Build abstract search queries for searches that do not involve
            a named place.
        """
        if sdata.qualifiers:
            # No special searches over qualifiers supported.
            return

        if sdata.countries and not address and not sdata.postcodes \
           and self.configured_for_country:
            yield dbs.CountrySearch(sdata)

        if sdata.postcodes and (is_category or self.configured_for_postcode):
            penalty = 0.0 if sdata.countries else 0.1
            if address:
                sdata.lookups = [dbf.FieldLookup('nameaddress_vector',
                                                 [t.token for r in address
                                                  for t in self.query.get_partials_list(r)],
                                                 lookups.Restrict)]
                penalty += 0.2
            yield dbs.PostcodeSearch(penalty, sdata)

    def build_housenumber_search(self, sdata: dbf.SearchData, hnrs: List[Token],
                                 address: List[TokenRange]) -> Iterator[dbs.AbstractSearch]:
        """ Build a simple address search for special entries where the
            housenumber is the main name token.
        """
        sdata.lookups = [dbf.FieldLookup('name_vector', [t.token for t in hnrs], lookups.LookupAny)]
        expected_count = sum(t.count for t in hnrs)

        partials = {t.token: t.addr_count for trange in address
                    for t in self.query.get_partials_list(trange)}

        if not partials:
            # can happen when none of the partials is indexed
            return

        if expected_count < 8000:
            sdata.lookups.append(dbf.FieldLookup('nameaddress_vector',
                                                 list(partials), lookups.Restrict))
        elif len(partials) != 1 or list(partials.values())[0] < 10000:
            sdata.lookups.append(dbf.FieldLookup('nameaddress_vector',
                                                 list(partials), lookups.LookupAll))
        else:
            addr_fulls = [t.token for t
                          in self.query.get_tokens(address[0], TokenType.WORD)]
            if len(addr_fulls) > 5:
                return
            sdata.lookups.append(
                dbf.FieldLookup('nameaddress_vector', addr_fulls, lookups.LookupAny))

        sdata.housenumbers = dbf.WeightedStrings([], [])
        yield dbs.PlaceSearch(0.05, sdata, expected_count)

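    # How build_housenumber_search above picks the address lookup (thresholds
    # from the code, example counts made up):
    # - expected_count 500 (< 8000): few housenumber candidates anyway, so the
    #   address partials merely restrict them (Restrict).
    # - expected_count 20000, partial addr_counts {2000, 9000}: go through the
    #   address index instead (LookupAll).
    # - expected_count 20000, a single partial with addr_count 50000: partials
    #   are too frequent, so use at most five full address words (LookupAny);
    #   with more than five full words the search is dropped altogether.
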
    def build_name_search(self, sdata: dbf.SearchData,
                          name: TokenRange, address: List[TokenRange],
                          is_category: bool) -> Iterator[dbs.AbstractSearch]:
        """ Build abstract search queries for simple name or address searches.
        """
        if is_category or not sdata.housenumbers or self.configured_for_housenumbers:
            ranking = self.get_name_ranking(name)
            name_penalty = ranking.normalize_penalty()
            if ranking.rankings:
                sdata.rankings.append(ranking)
            for penalty, count, lookup in self.yield_lookups(name, address):
                sdata.lookups = lookup
                yield dbs.PlaceSearch(penalty + name_penalty, sdata, count)

    def yield_lookups(self, name: TokenRange, address: List[TokenRange]
                      ) -> Iterator[Tuple[float, int, List[dbf.FieldLookup]]]:
        """ Yield all variants of how the given name and address should best
            be searched for. This takes into account how frequent the terms
            are and tries to find a lookup that optimizes index use.
        """
        penalty = 0.0  # extra penalty
        name_partials = {t.token: t for t in self.query.get_partials_list(name)}

        addr_partials = [t for r in address for t in self.query.get_partials_list(r)]
        addr_tokens = list({t.token for t in addr_partials})

        exp_count = min(t.count for t in name_partials.values()) / (2**(len(name_partials) - 1))

        if len(name_partials) > 3 or exp_count < 8000:
            yield penalty, exp_count, dbf.lookup_by_names(list(name_partials.keys()), addr_tokens)
            return

        addr_count = min(t.addr_count for t in addr_partials) if addr_partials else 50000
        # Partial terms are too frequent. Try looking up by rare full names first.
        name_fulls = self.query.get_tokens(name, TokenType.WORD)
        if name_fulls:
            fulls_count = sum(t.count for t in name_fulls)

            if fulls_count < 80000 or addr_count < 50000:
                yield penalty, fulls_count / (2**len(addr_tokens)), \
                    self.get_full_name_ranking(name_fulls, addr_partials,
                                               fulls_count > 30000 / max(1, len(addr_tokens)))

        # To catch remaining results, look up by name and address.
        # We only do this if a reasonable number of results is expected.
        exp_count = exp_count / (2**len(addr_tokens)) if addr_tokens else exp_count
        if exp_count < 10000 and addr_count < 20000:
            penalty += 0.35 * max(1 if name_fulls else 0.1,
                                  5 - len(name_partials) - len(addr_tokens))
            yield penalty, exp_count, \
                self.get_name_address_ranking(list(name_partials.keys()), addr_partials)

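    # Worked example for the estimate in yield_lookups (made-up counts):
    # name partials 'rue'/'de'/'paris' with counts 90000/80000/2000 give
    # exp_count = 2000 / 2**2 = 500 < 8000, so the name-vector lookup is
    # emitted directly. With counts 90000/80000/40000 instead, exp_count is
    # 10000 and the partial-only lookup is skipped in favour of the full-name
    # and name+address variants, each halved once more per address token and
    # subject to its own thresholds.
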
    def get_name_address_ranking(self, name_tokens: List[int],
                                 addr_partials: List[Token]) -> List[dbf.FieldLookup]:
        """ Create a ranking expression looking up by name and address.
        """
        lookup = [dbf.FieldLookup('name_vector', name_tokens, lookups.LookupAll)]

        addr_restrict_tokens = []
        addr_lookup_tokens = []
        for t in addr_partials:
            if t.addr_count > 20000:
                addr_restrict_tokens.append(t.token)
            else:
                addr_lookup_tokens.append(t.token)

        if addr_restrict_tokens:
            lookup.append(dbf.FieldLookup('nameaddress_vector',
                                          addr_restrict_tokens, lookups.Restrict))
        if addr_lookup_tokens:
            lookup.append(dbf.FieldLookup('nameaddress_vector',
                                          addr_lookup_tokens, lookups.LookupAll))

        return lookup

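    # The 20000 cut-off above sorts address partials by selectivity: a very
    # common term (think 'street', example illustrative) only filters rows
    # already found via the name index (Restrict), while a rare term is worth
    # an index scan of its own (LookupAll).
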
    def get_full_name_ranking(self, name_fulls: List[Token], addr_partials: List[Token],
                              use_lookup: bool) -> List[dbf.FieldLookup]:
        """ Create a ranking expression with full name terms and
            an additional address lookup. When 'use_lookup' is true,
            the address terms are looked up via the index; this is
            only worthwhile when they do not occur too often.
        """
        # At this point drop unindexed partials from the address.
        # This might yield wrong results, nothing we can do about that.
        if use_lookup:
            addr_restrict_tokens = []
            addr_lookup_tokens = [t.token for t in addr_partials]
        else:
            addr_restrict_tokens = [t.token for t in addr_partials]
            addr_lookup_tokens = []

        return dbf.lookup_by_any_name([t.token for t in name_fulls],
                                      addr_restrict_tokens, addr_lookup_tokens)

    def get_name_ranking(self, trange: TokenRange,
                         db_field: str = 'name_vector') -> dbf.FieldRanking:
        """ Create a ranking expression for a name term in the given range.
        """
        name_fulls = self.query.get_tokens(trange, TokenType.WORD)
        ranks = [dbf.RankedTokens(t.penalty, [t.token]) for t in name_fulls]
        ranks.sort(key=lambda r: r.penalty)
        # Fallback: sum of the partials' penalties
        name_partials = self.query.get_partials_list(trange)
        default = sum(t.penalty for t in name_partials) + 0.2
        return dbf.FieldRanking(db_field, default, ranks)

    def get_addr_ranking(self, trange: TokenRange) -> dbf.FieldRanking:
        """ Create a list of ranking expressions for an address term
            in the given range.
        """
        todo: List[Tuple[int, int, dbf.RankedTokens]] = []
        heapq.heappush(todo, (0, trange.start, dbf.RankedTokens(0.0, [])))
        ranks: List[dbf.RankedTokens] = []

        while todo:
            neglen, pos, rank = heapq.heappop(todo)
            for tlist in self.query.nodes[pos].starting:
                if tlist.ttype in (TokenType.PARTIAL, TokenType.WORD):
                    if tlist.end < trange.end:
                        chgpenalty = PENALTY_WORDCHANGE[self.query.nodes[tlist.end].btype]
                        if tlist.ttype == TokenType.PARTIAL:
                            penalty = rank.penalty + chgpenalty \
                                      + max(t.penalty for t in tlist.tokens)
                            heapq.heappush(todo, (neglen - 1, tlist.end,
                                                  dbf.RankedTokens(penalty, rank.tokens)))
                        else:
                            for t in tlist.tokens:
                                heapq.heappush(todo, (neglen - 1, tlist.end,
                                                      rank.with_token(t, chgpenalty)))
                    elif tlist.end == trange.end:
                        if tlist.ttype == TokenType.PARTIAL:
                            ranks.append(dbf.RankedTokens(rank.penalty
                                                          + max(t.penalty for t in tlist.tokens),
                                                          rank.tokens))
                        else:
                            ranks.extend(rank.with_token(t, 0.0) for t in tlist.tokens)
                        if len(ranks) >= 10:
                            # Too many variants, bail out and only add the
                            # worst-case fallback: the sum of the partials' penalties.
                            name_partials = self.query.get_partials_list(trange)
                            default = sum(t.penalty for t in name_partials) + 0.2
                            ranks.append(dbf.RankedTokens(rank.penalty + default, []))
                            # Bail out of the outer loop as well.
                            todo.clear()
                            break

        ranks.sort(key=lambda r: len(r.tokens))
        default = ranks[0].penalty + 0.3
        del ranks[0]
        ranks.sort(key=lambda r: r.penalty)

        return dbf.FieldRanking('nameaddress_vector', default, ranks)

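    # Note on the heap in get_addr_ranking: entries are keyed by negative
    # token count, so decompositions that already cover more words are
    # expanded before shorter ones. A sketch with a two-word range 'main st'
    # (hypothetical tokens): the heap starts at trange.start with an empty
    # RankedTokens; a full-word token covering 'main st' ends the range and
    # is appended to ranks, while a partial for 'main' re-enters the heap at
    # the break before 'st' with PENALTY_WORDCHANGE[btype] added on top.
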
    def get_search_data(self, assignment: TokenAssignment) -> Optional[dbf.SearchData]:
        """ Collect the tokens for the non-name search fields in the
            assignment.
        """
        sdata = dbf.SearchData()
        sdata.penalty = assignment.penalty
        if assignment.country:
            tokens = self.get_country_tokens(assignment.country)
            if not tokens:
                return None
            sdata.set_strings('countries', tokens)
        elif self.details.countries:
            sdata.countries = dbf.WeightedStrings(self.details.countries,
                                                  [0.0] * len(self.details.countries))
        if assignment.housenumber:
            sdata.set_strings('housenumbers',
                              self.query.get_tokens(assignment.housenumber,
                                                    TokenType.HOUSENUMBER))
        if assignment.postcode:
            sdata.set_strings('postcodes',
                              self.query.get_tokens(assignment.postcode,
                                                    TokenType.POSTCODE))
        if assignment.qualifier:
            tokens = self.get_qualifier_tokens(assignment.qualifier)
            if not tokens:
                return None
            sdata.set_qualifiers(tokens)
        elif self.details.categories:
            sdata.qualifiers = dbf.WeightedCategories(self.details.categories,
                                                      [0.0] * len(self.details.categories))

        if assignment.address:
            if not assignment.name and assignment.housenumber:
                # Housenumber search: the first item needs to be handled like
                # a name in the ranking, otherwise penalties are not
                # comparable with normal searches.
                sdata.set_ranking([self.get_name_ranking(assignment.address[0],
                                                         db_field='nameaddress_vector')]
                                  + [self.get_addr_ranking(r) for r in assignment.address[1:]])
            else:
                sdata.set_ranking([self.get_addr_ranking(r) for r in assignment.address])
        else:
            sdata.rankings = []

        return sdata

    def get_country_tokens(self, trange: TokenRange) -> List[Token]:
        """ Return the list of country tokens for the given range,
            optionally filtered by the country list from the details
            parameters.
        """
        tokens = self.query.get_tokens(trange, TokenType.COUNTRY)
        if self.details.countries:
            tokens = [t for t in tokens if t.lookup_word in self.details.countries]

        return tokens

    def get_qualifier_tokens(self, trange: TokenRange) -> List[Token]:
        """ Return the list of qualifier tokens for the given range,
            optionally filtered by the qualifier list from the details
            parameters.
        """
        tokens = self.query.get_tokens(trange, TokenType.QUALIFIER)
        if self.details.categories:
            tokens = [t for t in tokens if t.get_category() in self.details.categories]

        return tokens

    def get_near_items(self, assignment: TokenAssignment) -> Optional[dbf.WeightedCategories]:
        """ Collect tokens for a near-item search or use the categories
            requested via the category parameter.
            Returns None if no category search is requested.
        """
        if assignment.near_item:
            tokens: Dict[Tuple[str, str], float] = {}
            for t in self.query.get_tokens(assignment.near_item, TokenType.NEAR_ITEM):
                cat = t.get_category()
                # The category of a near search will be that of near_item.
                # Thus, if the search is restricted by a category parameter,
                # the two sets must intersect.
                if (not self.details.categories or cat in self.details.categories)\
                   and t.penalty < tokens.get(cat, 1000.0):
                    tokens[cat] = t.penalty
            return dbf.WeightedCategories(list(tokens.keys()), list(tokens.values()))

        return None

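    # Example for the dedup above (hypothetical tokens): two NEAR_ITEM tokens
    # both mapping to ('amenity', 'cafe') with penalties 0.2 and 0.4 collapse
    # into a single weighted category with penalty 0.2; a category not in
    # self.details.categories (when that filter is set) is dropped entirely.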

PENALTY_WORDCHANGE = {
    BreakType.START: 0.0,
    BreakType.END: 0.0,
    BreakType.PHRASE: 0.0,
    BreakType.WORD: 0.1,
    BreakType.PART: 0.2,
    BreakType.TOKEN: 0.4
}
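
# Penalties applied by get_addr_ranking when a token chain continues across a
# break: the tighter the original binding at the break point, the more it
# costs to treat the two sides as separate words. Boundaries at the start or
# end of the query and between phrases are free.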