]> git.openstreetmap.org Git - nominatim.git/blob - src/nominatim_api/search/db_search_builder.py
Merge pull request #3658 from lonvia/minor-query-parsing-optimisations
[nominatim.git] / src / nominatim_api / search / db_search_builder.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2024 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Conversion from token assignment to an abstract DB search.
9 """
10 from typing import Optional, List, Tuple, Iterator, Dict
11 import heapq
12
13 from ..types import SearchDetails, DataLayer
14 from . import query as qmod
15 from .token_assignment import TokenAssignment
16 from . import db_search_fields as dbf
17 from . import db_searches as dbs
18 from . import db_search_lookups as lookups
19
20
def wrap_near_search(categories: List[Tuple[str, str]],
                     search: dbs.AbstractSearch) -> dbs.NearSearch:
    """ Wrap the given search into a NearSearch that looks for places
        of the given categories near the results of `search`.

        All categories are weighted equally (penalty 0.0) and the
        wrapping search takes over the penalty of the inner search.
    """
    flat_weights = [0.0] * len(categories)
    near_categories = dbf.WeightedCategories(categories, flat_weights)
    return dbs.NearSearch(penalty=search.penalty,
                          categories=near_categories,
                          search=search)
30
31
def build_poi_search(category: List[Tuple[str, str]],
                     countries: Optional[List[str]]) -> dbs.PoiSearch:
    """ Create a search for places of the given categories, optionally
        constrained to the given countries.

        Every category and every country gets a penalty of 0.0. A missing
        or empty country list results in an empty country restriction.
    """
    country_codes = countries or []
    ccs = dbf.WeightedStrings(country_codes, [0.0] * len(country_codes))
    category_weights = dbf.WeightedCategories(category, [0.0] * len(category))

    # The search data is expressed as class attributes of a throw-away
    # subclass. Right-hand sides use names that are NOT assigned in the
    # class body, so they resolve to the enclosing function's variables.
    class _PoiData(dbf.SearchData):
        penalty = 0.0
        qualifiers = category_weights
        countries = ccs

    return dbs.PoiSearch(_PoiData())
48
49
class SearchBuilder:
    """ Build the abstract search queries from token assignments.

        A builder is bound to one parsed query and the search details
        (request parameters) used for it. Call `build()` once per token
        assignment to get the candidate searches for that assignment.
    """

    def __init__(self, query: qmod.QueryStruct, details: SearchDetails) -> None:
        self.query = query
        self.details = details

    @property
    def configured_for_country(self) -> bool:
        """ Return true if the search details are configured to
            allow countries in the result.

            The requested rank window must include rank 4 and the
            address layer must be enabled.
        """
        return self.details.min_rank <= 4 and self.details.max_rank >= 4 \
            and self.details.layer_enabled(DataLayer.ADDRESS)

    @property
    def configured_for_postcode(self) -> bool:
        """ Return true if the search details are configured to
            allow postcodes in the result.

            The requested rank window must cover ranks 5 through 11 and
            the address layer must be enabled.
        """
        return self.details.min_rank <= 5 and self.details.max_rank >= 11 \
            and self.details.layer_enabled(DataLayer.ADDRESS)

    @property
    def configured_for_housenumbers(self) -> bool:
        """ Return true if the search details are configured to
            allow addresses in the result.

            The requested rank window must reach rank 30 and the address
            layer must be enabled.
        """
        return self.details.max_rank >= 30 \
            and self.details.layer_enabled(DataLayer.ADDRESS)

    def build(self, assignment: TokenAssignment) -> Iterator[dbs.AbstractSearch]:
        """ Yield all possible abstract searches for the given token assignment.

            The penalty of the assignment is added to every yielded search.
            When the assignment has near items and they are not turned
            into a plain POI search, every search is wrapped in a NearSearch.
        """
        sdata = self.get_search_data(assignment)
        if sdata is None:
            return

        near_items = self.get_near_items(assignment)
        if near_items is not None and not near_items:
            return  # impossible combination of near items and category parameter

        if assignment.name is None:
            if near_items and not sdata.postcodes:
                # No name to search for: look up the near-item categories
                # directly instead of wrapping another search.
                sdata.qualifiers = near_items
                near_items = None
                builder = self.build_poi_search(sdata)
            elif assignment.housenumber:
                hnr_tokens = self.query.get_tokens(assignment.housenumber,
                                                   qmod.TOKEN_HOUSENUMBER)
                builder = self.build_housenumber_search(sdata, hnr_tokens, assignment.address)
            else:
                builder = self.build_special_search(sdata, assignment.address,
                                                    bool(near_items))
        else:
            builder = self.build_name_search(sdata, assignment.name, assignment.address,
                                             bool(near_items))

        if near_items:
            # Move the smallest category penalty onto the wrapping search,
            # so that category penalties become relative to the best one.
            penalty = min(near_items.penalties)
            near_items.penalties = [p - penalty for p in near_items.penalties]
            for search in builder:
                # The penalty of the inner search is accounted for in the
                # outer NearSearch instead.
                search_penalty = search.penalty
                search.penalty = 0.0
                yield dbs.NearSearch(penalty + assignment.penalty + search_penalty,
                                     near_items, search)
        else:
            for search in builder:
                search.penalty += assignment.penalty
                yield search

    def build_poi_search(self, sdata: dbf.SearchData) -> Iterator[dbs.AbstractSearch]:
        """ Build abstract search query for a simple category search.
            This kind of search requires an additional geographic constraint:
            either a bounded viewbox or a 'near' point in the search details.
            Nothing is yielded when housenumbers are present.
        """
        if not sdata.housenumbers \
           and ((self.details.viewbox and self.details.bounded_viewbox) or self.details.near):
            yield dbs.PoiSearch(sdata)

    def build_special_search(self, sdata: dbf.SearchData,
                             address: List[qmod.TokenRange],
                             is_category: bool) -> Iterator[dbs.AbstractSearch]:
        """ Build abstract search queries for searches that do not involve
            a named place: country searches and postcode searches.
        """
        if sdata.qualifiers:
            # No special searches over qualifiers supported.
            return

        if sdata.countries and not address and not sdata.postcodes \
           and self.configured_for_country:
            yield dbs.CountrySearch(sdata)

        if sdata.postcodes and (is_category or self.configured_for_postcode):
            # Postcodes without a country restriction are less reliable.
            penalty = 0.0 if sdata.countries else 0.1
            if address:
                sdata.lookups = [dbf.FieldLookup('nameaddress_vector',
                                                 [t.token for r in address
                                                  for t in self.query.get_partials_list(r)],
                                                 lookups.Restrict)]
            yield dbs.PostcodeSearch(penalty, sdata)

    def build_housenumber_search(self, sdata: dbf.SearchData, hnrs: List[qmod.Token],
                                 address: List[qmod.TokenRange]) -> Iterator[dbs.AbstractSearch]:
        """ Build a simple address search for special entries where the
            housenumber is the main name token.

            The housenumber tokens are looked up in the name vector and the
            address partials narrow down the result. Nothing is yielded when
            there are no address partials, or when the only address partial
            is very frequent and the first address term has too many
            full-word variants.
        """
        sdata.lookups = [dbf.FieldLookup('name_vector', [t.token for t in hnrs], lookups.LookupAny)]
        expected_count = sum(t.count for t in hnrs)

        partials = {t.token: t.addr_count for trange in address
                    for t in self.query.get_partials_list(trange)}

        if not partials:
            # can happen when none of the partials is indexed
            return

        if expected_count < 8000:
            # Rare housenumber: address partials only need to filter the candidates.
            sdata.lookups.append(dbf.FieldLookup('nameaddress_vector',
                                                 list(partials), lookups.Restrict))
        elif len(partials) != 1 or next(iter(partials.values())) < 10000:
            sdata.lookups.append(dbf.FieldLookup('nameaddress_vector',
                                                 list(partials), lookups.LookupAll))
        else:
            # A single, very frequent address partial: use the full words of
            # the first address term instead, unless there are too many of them.
            addr_fulls = [t.token for t
                          in self.query.get_tokens(address[0], qmod.TOKEN_WORD)]
            if len(addr_fulls) > 5:
                return
            sdata.lookups.append(
                dbf.FieldLookup('nameaddress_vector', addr_fulls, lookups.LookupAny))

        # The housenumber tokens are already matched via the name lookup
        # above, so the separate housenumber constraint is cleared.
        sdata.housenumbers = dbf.WeightedStrings([], [])
        yield dbs.PlaceSearch(0.05, sdata, expected_count)

    def build_name_search(self, sdata: dbf.SearchData,
                          name: qmod.TokenRange, address: List[qmod.TokenRange],
                          is_category: bool) -> Iterator[dbs.AbstractSearch]:
        """ Build abstract search queries for simple name or address searches.

            Housenumber searches are only built when the search details
            allow addresses in the result (or when this is a category search).
            One PlaceSearch is yielded per lookup variant from `yield_lookups()`.
        """
        if is_category or not sdata.housenumbers or self.configured_for_housenumbers:
            ranking = self.get_name_ranking(name)
            name_penalty = ranking.normalize_penalty()
            if ranking.rankings:
                sdata.rankings.append(ranking)
            for penalty, count, lookup in self.yield_lookups(name, address):
                sdata.lookups = lookup
                yield dbs.PlaceSearch(penalty + name_penalty, sdata, count)

    def yield_lookups(self, name: qmod.TokenRange, address: List[qmod.TokenRange]
                      ) -> Iterator[Tuple[float, int, List[dbf.FieldLookup]]]:
        """ Yield all variants how the given name and address should best
            be searched for. This takes into account how frequent the terms
            are and tries to find a lookup that optimizes index use.

            Each variant is a tuple of (extra penalty, expected number of
            results, list of field lookups).
        """
        penalty = 0.0  # extra penalty
        name_partials = {t.token: t for t in self.query.get_partials_list(name)}

        addr_partials = [t for r in address for t in self.query.get_partials_list(r)]
        addr_tokens = list({t.token for t in addr_partials})

        # Rough estimate of the number of matches: count of the rarest name
        # partial, halved for every additional name partial.
        exp_count = min(t.count for t in name_partials.values()) / (2**(len(name_partials) - 1))

        if len(name_partials) > 3 or exp_count < 8000:
            # Many or rare name partials: a single lookup by partials suffices.
            yield penalty, exp_count, dbf.lookup_by_names(list(name_partials), addr_tokens)
            return

        addr_count = min(t.addr_count for t in addr_partials) if addr_partials else 30000
        # Partial terms too frequent. Try looking up by rare full names first.
        name_fulls = self.query.get_tokens(name, qmod.TOKEN_WORD)
        if name_fulls:
            fulls_count = sum(t.count for t in name_fulls)

            if fulls_count < 50000 or addr_count < 30000:
                yield penalty, fulls_count / (2**len(addr_tokens)), \
                    self.get_full_name_ranking(name_fulls, addr_partials,
                                               fulls_count > 30000 / max(1, len(addr_tokens)))

        # To catch remaining results, lookup by name and address.
        # We only do this if there is a reasonable number of results expected.
        # Halve the estimate once per address token (a no-op without address).
        exp_count /= 2**len(addr_tokens)
        if exp_count < 10000 and addr_count < 20000:
            penalty += 0.35 * max(1 if name_fulls else 0.1,
                                  5 - len(name_partials) - len(addr_tokens))
            yield penalty, exp_count, \
                self.get_name_address_ranking(list(name_partials), addr_partials)

    @staticmethod
    def _split_addr_partials(addr_partials: List[qmod.Token],
                             threshold: int = 20000) -> Tuple[List[int], List[int]]:
        """ Split address partials by their `addr_count`.

            Returns a pair (restrict_tokens, lookup_tokens): tokens whose
            `addr_count` exceeds `threshold` go into the first list, all
            others into the second. Input order is preserved in both lists.
        """
        restrict_tokens: List[int] = []
        lookup_tokens: List[int] = []
        for t in addr_partials:
            if t.addr_count > threshold:
                restrict_tokens.append(t.token)
            else:
                lookup_tokens.append(t.token)
        return restrict_tokens, lookup_tokens

    def get_name_address_ranking(self, name_tokens: List[int],
                                 addr_partials: List[qmod.Token]) -> List[dbf.FieldLookup]:
        """ Create a ranking expression looking up by name and address.

            All name tokens are looked up in the name vector. Address
            partials with an `addr_count` above 20000 are only used as a
            restriction, the others are looked up via the address vector.
        """
        lookup = [dbf.FieldLookup('name_vector', name_tokens, lookups.LookupAll)]

        addr_restrict_tokens, addr_lookup_tokens = self._split_addr_partials(addr_partials)

        if addr_restrict_tokens:
            lookup.append(dbf.FieldLookup('nameaddress_vector',
                                          addr_restrict_tokens, lookups.Restrict))
        if addr_lookup_tokens:
            lookup.append(dbf.FieldLookup('nameaddress_vector',
                                          addr_lookup_tokens, lookups.LookupAll))

        return lookup

    def get_full_name_ranking(self, name_fulls: List[qmod.Token], addr_partials: List[qmod.Token],
                              use_lookup: bool) -> List[dbf.FieldLookup]:
        """ Create a ranking expression with full name terms and
            additional address lookup. When 'use_lookup' is true, then
            address lookups will use the index, when the occurrences are not
            too many (`addr_count` of at most 20000). Otherwise all address
            partials are only used as a restriction.
        """
        # All address partials are kept; only the way they are applied
        # (restriction vs. index lookup) differs.
        if use_lookup:
            addr_restrict_tokens, addr_lookup_tokens = self._split_addr_partials(addr_partials)
        else:
            addr_restrict_tokens = [t.token for t in addr_partials]
            addr_lookup_tokens = []

        return dbf.lookup_by_any_name([t.token for t in name_fulls],
                                      addr_restrict_tokens, addr_lookup_tokens)

    def get_name_ranking(self, trange: qmod.TokenRange,
                         db_field: str = 'name_vector') -> dbf.FieldRanking:
        """ Create a ranking expression for a name term in the given range.

            Each full-word token becomes one ranking entry with its own
            penalty (sorted by penalty). The default penalty, used when no
            full word matches, is the sum of the partial penalties plus 0.2.
        """
        name_fulls = self.query.get_tokens(trange, qmod.TOKEN_WORD)
        ranks = [dbf.RankedTokens(t.penalty, [t.token]) for t in name_fulls]
        ranks.sort(key=lambda r: r.penalty)
        # Fallback, sum of penalty for partials
        name_partials = self.query.get_partials_list(trange)
        default = sum(t.penalty for t in name_partials) + 0.2
        return dbf.FieldRanking(db_field, default, ranks)

    def get_addr_ranking(self, trange: qmod.TokenRange) -> dbf.FieldRanking:
        """ Create a ranking expression for an address term covering
            the given range.

            Enumerates ways of covering the range with a sequence of partial
            and full-word tokens. Full words add their token to the variant;
            partials only add their (maximum) penalty. Breaks between tokens
            add a penalty from PENALTY_WORDCHANGE. The variant with the
            fewest full-word tokens defines the default penalty (plus 0.3);
            the remaining variants are returned sorted by penalty.
            Enumeration stops once 10 or more variants are collected, in
            which case a token-less fallback variant is appended.
        """
        # Heap entries: (negated number of steps taken, position, variant so far).
        # Negating the step count makes the heap pop the longest paths first.
        todo: List[Tuple[int, int, dbf.RankedTokens]] = []
        heapq.heappush(todo, (0, trange.start, dbf.RankedTokens(0.0, [])))
        ranks: List[dbf.RankedTokens] = []

        while todo:
            neglen, pos, rank = heapq.heappop(todo)
            for tlist in self.query.nodes[pos].starting:
                if tlist.ttype in (qmod.TOKEN_PARTIAL, qmod.TOKEN_WORD):
                    if tlist.end < trange.end:
                        # Token ends inside the range: continue the path,
                        # paying for the break at the token's end.
                        chgpenalty = PENALTY_WORDCHANGE[self.query.nodes[tlist.end].btype]
                        if tlist.ttype == qmod.TOKEN_PARTIAL:
                            penalty = rank.penalty + chgpenalty \
                                      + max(t.penalty for t in tlist.tokens)
                            heapq.heappush(todo, (neglen - 1, tlist.end,
                                                  dbf.RankedTokens(penalty, rank.tokens)))
                        else:
                            for t in tlist.tokens:
                                heapq.heappush(todo, (neglen - 1, tlist.end,
                                                      rank.with_token(t, chgpenalty)))
                    elif tlist.end == trange.end:
                        # Token completes the range: record the finished variant(s).
                        if tlist.ttype == qmod.TOKEN_PARTIAL:
                            ranks.append(dbf.RankedTokens(rank.penalty
                                                          + max(t.penalty for t in tlist.tokens),
                                                          rank.tokens))
                        else:
                            ranks.extend(rank.with_token(t, 0.0) for t in tlist.tokens)
                        if len(ranks) >= 10:
                            # Too many variants, bail out and only add
                            # Worst-case Fallback: sum of penalty of partials
                            name_partials = self.query.get_partials_list(trange)
                            default = sum(t.penalty for t in name_partials) + 0.2
                            ranks.append(dbf.RankedTokens(rank.penalty + default, []))
                            # Bail out of outer loop
                            todo.clear()
                            break

        ranks.sort(key=lambda r: len(r.tokens))
        default = ranks[0].penalty + 0.3
        del ranks[0]
        ranks.sort(key=lambda r: r.penalty)

        return dbf.FieldRanking('nameaddress_vector', default, ranks)

    def get_search_data(self, assignment: TokenAssignment) -> Optional[dbf.SearchData]:
        """ Collect the tokens for the non-name search fields in the
            assignment.

            Fills countries, housenumbers, postcodes, qualifiers and the
            address rankings. Country and category restrictions from the
            search details are used when the assignment has no country or
            qualifier term of its own. Returns None when the assignment's
            country or qualifier term yields no matching token.
        """
        sdata = dbf.SearchData()
        sdata.penalty = assignment.penalty
        if assignment.country:
            tokens = self.get_country_tokens(assignment.country)
            if not tokens:
                return None
            sdata.set_strings('countries', tokens)
        elif self.details.countries:
            sdata.countries = dbf.WeightedStrings(self.details.countries,
                                                  [0.0] * len(self.details.countries))
        if assignment.housenumber:
            sdata.set_strings('housenumbers',
                              self.query.get_tokens(assignment.housenumber,
                                                    qmod.TOKEN_HOUSENUMBER))
        if assignment.postcode:
            sdata.set_strings('postcodes',
                              self.query.get_tokens(assignment.postcode,
                                                    qmod.TOKEN_POSTCODE))
        if assignment.qualifier:
            tokens = self.get_qualifier_tokens(assignment.qualifier)
            if not tokens:
                return None
            sdata.set_qualifiers(tokens)
        elif self.details.categories:
            sdata.qualifiers = dbf.WeightedCategories(self.details.categories,
                                                      [0.0] * len(self.details.categories))

        if assignment.address:
            if not assignment.name and assignment.housenumber:
                # housenumber search: the first item needs to be handled like
                # a name in ranking or penalties are not comparable with
                # normal searches.
                sdata.set_ranking([self.get_name_ranking(assignment.address[0],
                                                         db_field='nameaddress_vector')]
                                  + [self.get_addr_ranking(r) for r in assignment.address[1:]])
            else:
                sdata.set_ranking([self.get_addr_ranking(r) for r in assignment.address])
        else:
            sdata.rankings = []

        return sdata

    def get_country_tokens(self, trange: qmod.TokenRange) -> List[qmod.Token]:
        """ Return the list of country tokens for the given range,
            optionally filtered by the country list from the details
            parameters.
        """
        tokens = self.query.get_tokens(trange, qmod.TOKEN_COUNTRY)
        if self.details.countries:
            tokens = [t for t in tokens if t.lookup_word in self.details.countries]

        return tokens

    def get_qualifier_tokens(self, trange: qmod.TokenRange) -> List[qmod.Token]:
        """ Return the list of qualifier tokens for the given range,
            optionally filtered by the qualifier list from the details
            parameters.
        """
        tokens = self.query.get_tokens(trange, qmod.TOKEN_QUALIFIER)
        if self.details.categories:
            tokens = [t for t in tokens if t.get_category() in self.details.categories]

        return tokens

    def get_near_items(self, assignment: TokenAssignment) -> Optional[dbf.WeightedCategories]:
        """ Collect tokens for near items search or use the categories
            requested per parameter.
            Returns None if no category search is requested. Returns an
            empty category list when none of the near-item categories is
            allowed by the category parameter. For each category, the
            lowest token penalty is kept.
        """
        if assignment.near_item:
            tokens: Dict[Tuple[str, str], float] = {}
            for t in self.query.get_tokens(assignment.near_item, qmod.TOKEN_NEAR_ITEM):
                cat = t.get_category()
                # The category of a near search will be that of near_item.
                # Thus, if search is restricted to a category parameter,
                # the two sets must intersect.
                if (not self.details.categories or cat in self.details.categories)\
                   and t.penalty < tokens.get(cat, 1000.0):
                    tokens[cat] = t.penalty
            return dbf.WeightedCategories(list(tokens.keys()), list(tokens.values()))

        return None
429
430
# Extra penalty used by SearchBuilder.get_addr_ranking() when an address
# term is covered by several consecutive tokens: the break type at the end
# of each inner token selects the penalty paid for that transition.
# Start/end and phrase breaks are free; word, part and token breaks cost
# 0.1, 0.2 and 0.4 respectively.
PENALTY_WORDCHANGE = {
    **dict.fromkeys((qmod.BREAK_START,
                     qmod.BREAK_END,
                     qmod.BREAK_PHRASE,
                     qmod.BREAK_SOFT_PHRASE), 0.0),
    qmod.BREAK_WORD: 0.1,
    qmod.BREAK_PART: 0.2,
    qmod.BREAK_TOKEN: 0.4,
}