]> git.openstreetmap.org Git - nominatim.git/blob - src/nominatim_api/search/db_search_builder.py
docs: rework library getting started
[nominatim.git] / src / nominatim_api / search / db_search_builder.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2024 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Conversion from token assignment to an abstract DB search.
9 """
10 from typing import Optional, List, Tuple, Iterator, Dict
11 import heapq
12
13 from ..types import SearchDetails, DataLayer
14 from .query import QueryStruct, Token, TokenType, TokenRange, BreakType
15 from .token_assignment import TokenAssignment
16 from . import db_search_fields as dbf
17 from . import db_searches as dbs
18 from . import db_search_lookups as lookups
19
20
def wrap_near_search(categories: List[Tuple[str, str]],
                     search: dbs.AbstractSearch) -> dbs.NearSearch:
    """ Wrap the given search into a near search for the given categories.

        The resulting search takes over the penalty of `search`. Every
        category is handed in with a penalty of 0.0.
    """
    zero_penalties = [0.0 for _ in categories]
    weighted = dbf.WeightedCategories(categories, zero_penalties)

    return dbs.NearSearch(penalty=search.penalty, categories=weighted, search=search)
30
31
def build_poi_search(category: List[Tuple[str, str]],
                     countries: Optional[List[str]]) -> dbs.PoiSearch:
    """ Create a search for places of the given categories.

        When `countries` is given and not empty, the search data carries
        these countries as restriction, otherwise an empty country list.
        All categories and countries are handed in with a penalty of 0.0.
    """
    country_list = countries or []
    country_filter = dbf.WeightedStrings(country_list, [0.0] * len(country_list))

    # Search data is supplied through class attributes of a local subclass.
    class _PoiData(dbf.SearchData):
        penalty = 0.0
        qualifiers = dbf.WeightedCategories(category, [0.0] * len(category))
        countries = country_filter

    return dbs.PoiSearch(_PoiData())
48
49
50 class SearchBuilder:
51     """ Build the abstract search queries from token assignments.
52     """
53
54     def __init__(self, query: QueryStruct, details: SearchDetails) -> None:
55         self.query = query
56         self.details = details
57
58
59     @property
60     def configured_for_country(self) -> bool:
61         """ Return true if the search details are configured to
62             allow countries in the result.
63         """
64         return self.details.min_rank <= 4 and self.details.max_rank >= 4 \
65                and self.details.layer_enabled(DataLayer.ADDRESS)
66
67
68     @property
69     def configured_for_postcode(self) -> bool:
70         """ Return true if the search details are configured to
71             allow postcodes in the result.
72         """
73         return self.details.min_rank <= 5 and self.details.max_rank >= 11\
74                and self.details.layer_enabled(DataLayer.ADDRESS)
75
76
77     @property
78     def configured_for_housenumbers(self) -> bool:
79         """ Return true if the search details are configured to
80             allow addresses in the result.
81         """
82         return self.details.max_rank >= 30 \
83                and self.details.layer_enabled(DataLayer.ADDRESS)
84
85
86     def build(self, assignment: TokenAssignment) -> Iterator[dbs.AbstractSearch]:
87         """ Yield all possible abstract searches for the given token assignment.
88         """
89         sdata = self.get_search_data(assignment)
90         if sdata is None:
91             return
92
93         near_items = self.get_near_items(assignment)
94         if near_items is not None and not near_items:
95             return # impossible compbination of near items and category parameter
96
97         if assignment.name is None:
98             if near_items and not sdata.postcodes:
99                 sdata.qualifiers = near_items
100                 near_items = None
101                 builder = self.build_poi_search(sdata)
102             elif assignment.housenumber:
103                 hnr_tokens = self.query.get_tokens(assignment.housenumber,
104                                                    TokenType.HOUSENUMBER)
105                 builder = self.build_housenumber_search(sdata, hnr_tokens, assignment.address)
106             else:
107                 builder = self.build_special_search(sdata, assignment.address,
108                                                     bool(near_items))
109         else:
110             builder = self.build_name_search(sdata, assignment.name, assignment.address,
111                                              bool(near_items))
112
113         if near_items:
114             penalty = min(near_items.penalties)
115             near_items.penalties = [p - penalty for p in near_items.penalties]
116             for search in builder:
117                 search_penalty = search.penalty
118                 search.penalty = 0.0
119                 yield dbs.NearSearch(penalty + assignment.penalty + search_penalty,
120                                      near_items, search)
121         else:
122             for search in builder:
123                 search.penalty += assignment.penalty
124                 yield search
125
126
127     def build_poi_search(self, sdata: dbf.SearchData) -> Iterator[dbs.AbstractSearch]:
128         """ Build abstract search query for a simple category search.
129             This kind of search requires an additional geographic constraint.
130         """
131         if not sdata.housenumbers \
132            and ((self.details.viewbox and self.details.bounded_viewbox) or self.details.near):
133             yield dbs.PoiSearch(sdata)
134
135
136     def build_special_search(self, sdata: dbf.SearchData,
137                              address: List[TokenRange],
138                              is_category: bool) -> Iterator[dbs.AbstractSearch]:
139         """ Build abstract search queries for searches that do not involve
140             a named place.
141         """
142         if sdata.qualifiers:
143             # No special searches over qualifiers supported.
144             return
145
146         if sdata.countries and not address and not sdata.postcodes \
147            and self.configured_for_country:
148             yield dbs.CountrySearch(sdata)
149
150         if sdata.postcodes and (is_category or self.configured_for_postcode):
151             penalty = 0.0 if sdata.countries else 0.1
152             if address:
153                 sdata.lookups = [dbf.FieldLookup('nameaddress_vector',
154                                                  [t.token for r in address
155                                                   for t in self.query.get_partials_list(r)],
156                                                  lookups.Restrict)]
157                 penalty += 0.2
158             yield dbs.PostcodeSearch(penalty, sdata)
159
160
161     def build_housenumber_search(self, sdata: dbf.SearchData, hnrs: List[Token],
162                                  address: List[TokenRange]) -> Iterator[dbs.AbstractSearch]:
163         """ Build a simple address search for special entries where the
164             housenumber is the main name token.
165         """
166         sdata.lookups = [dbf.FieldLookup('name_vector', [t.token for t in hnrs], lookups.LookupAny)]
167         expected_count = sum(t.count for t in hnrs)
168
169         partials = {t.token: t.addr_count for trange in address
170                        for t in self.query.get_partials_list(trange)
171                        if t.is_indexed}
172
173         if not partials:
174             # can happen when none of the partials is indexed
175             return
176
177         if expected_count < 8000:
178             sdata.lookups.append(dbf.FieldLookup('nameaddress_vector',
179                                                  list(partials), lookups.Restrict))
180         elif len(partials) != 1 or list(partials.values())[0] < 10000:
181             sdata.lookups.append(dbf.FieldLookup('nameaddress_vector',
182                                                  list(partials), lookups.LookupAll))
183         else:
184             addr_fulls = [t.token for t
185                           in self.query.get_tokens(address[0], TokenType.WORD)]
186             if len(addr_fulls) > 5:
187                 return
188             sdata.lookups.append(
189                 dbf.FieldLookup('nameaddress_vector', addr_fulls, lookups.LookupAny))
190
191         sdata.housenumbers = dbf.WeightedStrings([], [])
192         yield dbs.PlaceSearch(0.05, sdata, expected_count)
193
194
195     def build_name_search(self, sdata: dbf.SearchData,
196                           name: TokenRange, address: List[TokenRange],
197                           is_category: bool) -> Iterator[dbs.AbstractSearch]:
198         """ Build abstract search queries for simple name or address searches.
199         """
200         if is_category or not sdata.housenumbers or self.configured_for_housenumbers:
201             ranking = self.get_name_ranking(name)
202             name_penalty = ranking.normalize_penalty()
203             if ranking.rankings:
204                 sdata.rankings.append(ranking)
205             for penalty, count, lookup in self.yield_lookups(name, address):
206                 sdata.lookups = lookup
207                 yield dbs.PlaceSearch(penalty + name_penalty, sdata, count)
208
209
210     def yield_lookups(self, name: TokenRange, address: List[TokenRange])\
211                           -> Iterator[Tuple[float, int, List[dbf.FieldLookup]]]:
212         """ Yield all variants how the given name and address should best
213             be searched for. This takes into account how frequent the terms
214             are and tries to find a lookup that optimizes index use.
215         """
216         penalty = 0.0 # extra penalty
217         name_partials = {t.token: t for t in self.query.get_partials_list(name)}
218
219         addr_partials = [t for r in address for t in self.query.get_partials_list(r)]
220         addr_tokens = list({t.token for t in addr_partials})
221
222         partials_indexed = all(t.is_indexed for t in name_partials.values()) \
223                            and all(t.is_indexed for t in addr_partials)
224         exp_count = min(t.count for t in name_partials.values()) / (2**(len(name_partials) - 1))
225
226         if (len(name_partials) > 3 or exp_count < 8000) and partials_indexed:
227             yield penalty, exp_count, dbf.lookup_by_names(list(name_partials.keys()), addr_tokens)
228             return
229
230         addr_count = min(t.addr_count for t in addr_partials) if addr_partials else 30000
231         # Partial term to frequent. Try looking up by rare full names first.
232         name_fulls = self.query.get_tokens(name, TokenType.WORD)
233         if name_fulls:
234             fulls_count = sum(t.count for t in name_fulls)
235             if partials_indexed:
236                 penalty += 1.2 * sum(t.penalty for t in addr_partials if not t.is_indexed)
237
238             if fulls_count < 50000 or addr_count < 30000:
239                 yield penalty,fulls_count / (2**len(addr_tokens)), \
240                     self.get_full_name_ranking(name_fulls, addr_partials,
241                                                fulls_count > 30000 / max(1, len(addr_tokens)))
242
243         # To catch remaining results, lookup by name and address
244         # We only do this if there is a reasonable number of results expected.
245         exp_count = exp_count / (2**len(addr_tokens)) if addr_tokens else exp_count
246         if exp_count < 10000 and addr_count < 20000\
247            and all(t.is_indexed for t in name_partials.values()):
248             penalty += 0.35 * max(1 if name_fulls else 0.1,
249                                   5 - len(name_partials) - len(addr_tokens))
250             yield penalty, exp_count,\
251                   self.get_name_address_ranking(list(name_partials.keys()), addr_partials)
252
253
254     def get_name_address_ranking(self, name_tokens: List[int],
255                                  addr_partials: List[Token]) -> List[dbf.FieldLookup]:
256         """ Create a ranking expression looking up by name and address.
257         """
258         lookup = [dbf.FieldLookup('name_vector', name_tokens, lookups.LookupAll)]
259
260         addr_restrict_tokens = []
261         addr_lookup_tokens = []
262         for t in addr_partials:
263             if t.is_indexed:
264                 if t.addr_count > 20000:
265                     addr_restrict_tokens.append(t.token)
266                 else:
267                     addr_lookup_tokens.append(t.token)
268
269         if addr_restrict_tokens:
270             lookup.append(dbf.FieldLookup('nameaddress_vector',
271                                           addr_restrict_tokens, lookups.Restrict))
272         if addr_lookup_tokens:
273             lookup.append(dbf.FieldLookup('nameaddress_vector',
274                                           addr_lookup_tokens, lookups.LookupAll))
275
276         return lookup
277
278
279     def get_full_name_ranking(self, name_fulls: List[Token], addr_partials: List[Token],
280                               use_lookup: bool) -> List[dbf.FieldLookup]:
281         """ Create a ranking expression with full name terms and
282             additional address lookup. When 'use_lookup' is true, then
283             address lookups will use the index, when the occurrences are not
284             too many.
285         """
286         # At this point drop unindexed partials from the address.
287         # This might yield wrong results, nothing we can do about that.
288         if use_lookup:
289             addr_restrict_tokens = []
290             addr_lookup_tokens = []
291             for t in addr_partials:
292                 if t.is_indexed:
293                     if t.addr_count > 20000:
294                         addr_restrict_tokens.append(t.token)
295                     else:
296                         addr_lookup_tokens.append(t.token)
297         else:
298             addr_restrict_tokens = [t.token for t in addr_partials if t.is_indexed]
299             addr_lookup_tokens = []
300
301         return dbf.lookup_by_any_name([t.token for t in name_fulls],
302                                       addr_restrict_tokens, addr_lookup_tokens)
303
304
305     def get_name_ranking(self, trange: TokenRange,
306                          db_field: str = 'name_vector') -> dbf.FieldRanking:
307         """ Create a ranking expression for a name term in the given range.
308         """
309         name_fulls = self.query.get_tokens(trange, TokenType.WORD)
310         ranks = [dbf.RankedTokens(t.penalty, [t.token]) for t in name_fulls]
311         ranks.sort(key=lambda r: r.penalty)
312         # Fallback, sum of penalty for partials
313         name_partials = self.query.get_partials_list(trange)
314         default = sum(t.penalty for t in name_partials) + 0.2
315         return dbf.FieldRanking(db_field, default, ranks)
316
317
    def get_addr_ranking(self, trange: TokenRange) -> dbf.FieldRanking:
        """ Create the ranking expression for an address term in the
            given range, to be applied to the 'nameaddress_vector' field.

            The function enumerates the ways to cover the range with
            PARTIAL and WORD token lists. Each complete cover becomes a
            ranking made of the WORD tokens used. Partial token lists
            only contribute their highest token penalty.
        """
        # Heap of unfinished covers: (negated number of token lists used,
        # node position reached, tokens and penalty collected so far).
        # The negated count makes covers with more steps pop first.
        todo: List[Tuple[int, int, dbf.RankedTokens]] = []
        heapq.heappush(todo, (0, trange.start, dbf.RankedTokens(0.0, [])))
        # Covers that have reached the end of the range.
        ranks: List[dbf.RankedTokens] = []

        while todo: # pylint: disable=too-many-nested-blocks
            neglen, pos, rank = heapq.heappop(todo)
            for tlist in self.query.nodes[pos].starting:
                if tlist.ttype in (TokenType.PARTIAL, TokenType.WORD):
                    if tlist.end < trange.end:
                        # Token list ends inside the range: continue from its
                        # end node, adding the penalty for the break there.
                        chgpenalty = PENALTY_WORDCHANGE[self.query.nodes[tlist.end].btype]
                        if tlist.ttype == TokenType.PARTIAL:
                            # Partials add no token, only their worst penalty.
                            penalty = rank.penalty + chgpenalty \
                                      + max(t.penalty for t in tlist.tokens)
                            heapq.heappush(todo, (neglen - 1, tlist.end,
                                                  dbf.RankedTokens(penalty, rank.tokens)))
                        else:
                            # One continuation per alternative full word token.
                            for t in tlist.tokens:
                                heapq.heappush(todo, (neglen - 1, tlist.end,
                                                      rank.with_token(t, chgpenalty)))
                    elif tlist.end == trange.end:
                        # Token list completes the range: record the cover.
                        if tlist.ttype == TokenType.PARTIAL:
                            ranks.append(dbf.RankedTokens(rank.penalty
                                                          + max(t.penalty for t in tlist.tokens),
                                                          rank.tokens))
                        else:
                            ranks.extend(rank.with_token(t, 0.0) for t in tlist.tokens)
                        if len(ranks) >= 10:
                            # Too many variants, bail out and only add
                            # Worst-case Fallback: sum of penalty of partials
                            name_partials = self.query.get_partials_list(trange)
                            default = sum(t.penalty for t in name_partials) + 0.2
                            ranks.append(dbf.RankedTokens(rank.penalty + default, []))
                            # Bail out of outer loop
                            todo.clear()
                            break

        # The cover with the fewest tokens supplies the default penalty
        # (plus 0.3) and is removed from the list of rankings.
        # NOTE(review): ranks[0] raises IndexError when no cover reached the
        # end of the range - presumably partials always provide one; confirm.
        ranks.sort(key=lambda r: len(r.tokens))
        default = ranks[0].penalty + 0.3
        del ranks[0]
        ranks.sort(key=lambda r: r.penalty)

        return dbf.FieldRanking('nameaddress_vector', default, ranks)
364
365
366     def get_search_data(self, assignment: TokenAssignment) -> Optional[dbf.SearchData]:
367         """ Collect the tokens for the non-name search fields in the
368             assignment.
369         """
370         sdata = dbf.SearchData()
371         sdata.penalty = assignment.penalty
372         if assignment.country:
373             tokens = self.get_country_tokens(assignment.country)
374             if not tokens:
375                 return None
376             sdata.set_strings('countries', tokens)
377         elif self.details.countries:
378             sdata.countries = dbf.WeightedStrings(self.details.countries,
379                                                   [0.0] * len(self.details.countries))
380         if assignment.housenumber:
381             sdata.set_strings('housenumbers',
382                               self.query.get_tokens(assignment.housenumber,
383                                                     TokenType.HOUSENUMBER))
384         if assignment.postcode:
385             sdata.set_strings('postcodes',
386                               self.query.get_tokens(assignment.postcode,
387                                                     TokenType.POSTCODE))
388         if assignment.qualifier:
389             tokens = self.get_qualifier_tokens(assignment.qualifier)
390             if not tokens:
391                 return None
392             sdata.set_qualifiers(tokens)
393         elif self.details.categories:
394             sdata.qualifiers = dbf.WeightedCategories(self.details.categories,
395                                                       [0.0] * len(self.details.categories))
396
397         if assignment.address:
398             if not assignment.name and assignment.housenumber:
399                 # housenumber search: the first item needs to be handled like
400                 # a name in ranking or penalties are not comparable with
401                 # normal searches.
402                 sdata.set_ranking([self.get_name_ranking(assignment.address[0],
403                                                          db_field='nameaddress_vector')]
404                                   + [self.get_addr_ranking(r) for r in assignment.address[1:]])
405             else:
406                 sdata.set_ranking([self.get_addr_ranking(r) for r in assignment.address])
407         else:
408             sdata.rankings = []
409
410         return sdata
411
412
413     def get_country_tokens(self, trange: TokenRange) -> List[Token]:
414         """ Return the list of country tokens for the given range,
415             optionally filtered by the country list from the details
416             parameters.
417         """
418         tokens = self.query.get_tokens(trange, TokenType.COUNTRY)
419         if self.details.countries:
420             tokens = [t for t in tokens if t.lookup_word in self.details.countries]
421
422         return tokens
423
424
425     def get_qualifier_tokens(self, trange: TokenRange) -> List[Token]:
426         """ Return the list of qualifier tokens for the given range,
427             optionally filtered by the qualifier list from the details
428             parameters.
429         """
430         tokens = self.query.get_tokens(trange, TokenType.QUALIFIER)
431         if self.details.categories:
432             tokens = [t for t in tokens if t.get_category() in self.details.categories]
433
434         return tokens
435
436
437     def get_near_items(self, assignment: TokenAssignment) -> Optional[dbf.WeightedCategories]:
438         """ Collect tokens for near items search or use the categories
439             requested per parameter.
440             Returns None if no category search is requested.
441         """
442         if assignment.near_item:
443             tokens: Dict[Tuple[str, str], float] = {}
444             for t in self.query.get_tokens(assignment.near_item, TokenType.NEAR_ITEM):
445                 cat = t.get_category()
446                 # The category of a near search will be that of near_item.
447                 # Thus, if search is restricted to a category parameter,
448                 # the two sets must intersect.
449                 if (not self.details.categories or cat in self.details.categories)\
450                    and t.penalty < tokens.get(cat, 1000.0):
451                     tokens[cat] = t.penalty
452             return dbf.WeightedCategories(list(tokens.keys()), list(tokens.values()))
453
454         return None
455
456
# Penalty by break type. SearchBuilder.get_addr_ranking() adds the value
# for the break type of the node where a token list ends, whenever that
# node lies before the end of the address range being ranked.
PENALTY_WORDCHANGE = {
    BreakType.START: 0.0,
    BreakType.END: 0.0,
    BreakType.PHRASE: 0.0,
    BreakType.WORD: 0.1,
    BreakType.PART: 0.2,
    BreakType.TOKEN: 0.4
}