git.openstreetmap.org Git - nominatim.git/blob - nominatim/api/search/db_search_builder.py
Merge pull request #3375 from matkoniecz/patch-1
[nominatim.git] / nominatim / api / search / db_search_builder.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2023 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Conversion from token assignment to an abstract DB search.
9 """
10 from typing import Optional, List, Tuple, Iterator, Dict
11 import heapq
12
13 from nominatim.api.types import SearchDetails, DataLayer
14 from nominatim.api.search.query import QueryStruct, Token, TokenType, TokenRange, BreakType
15 from nominatim.api.search.token_assignment import TokenAssignment
16 import nominatim.api.search.db_search_fields as dbf
17 import nominatim.api.search.db_searches as dbs
18 import nominatim.api.search.db_search_lookups as lookups
19
20
def wrap_near_search(categories: List[Tuple[str, str]],
                     search: dbs.AbstractSearch) -> dbs.NearSearch:
    """ Wrap an existing search into a search for places near it that
        belong to one of the given categories.

        All categories get the same weight (0.0) and the resulting near
        search inherits the penalty of the wrapped search.
    """
    equal_weights = [0.0] * len(categories)
    near_categories = dbf.WeightedCategories(categories, equal_weights)
    return dbs.NearSearch(penalty=search.penalty,
                          categories=near_categories,
                          search=search)
30
31
def build_poi_search(category: List[Tuple[str, str]],
                     countries: Optional[List[str]]) -> dbs.PoiSearch:
    """ Create a new search for places of the given categories, optionally
        constrained to the given countries.

        All categories and countries are weighted equally (0.0) and the
        search itself carries no penalty.
    """
    # `None` and an empty list both mean "no country restriction".
    country_list = countries or []
    country_filter = dbf.WeightedStrings(country_list, [0.0] * len(country_list))

    # The search data is expressed as class attributes on a throw-away
    # subclass of SearchData, one per call.
    class _PoiData(dbf.SearchData):
        penalty = 0.0
        qualifiers = dbf.WeightedCategories(category, [0.0] * len(category))
        countries = country_filter

    return dbs.PoiSearch(_PoiData())
48
49
class SearchBuilder:
    """ Build the abstract search queries from token assignments.

        The builder is bound to one parsed query and one set of search
        details. For every token assignment it yields zero or more abstract
        searches (place, POI, postcode, country or near searches) together
        with penalties describing how well the assignment fits.
    """

    def __init__(self, query: QueryStruct, details: SearchDetails) -> None:
        self.query = query
        self.details = details


    @property
    def configured_for_country(self) -> bool:
        """ Return true if the search details are configured to
            allow countries in the result.
        """
        # Country results are checked against rank 4, so that rank has to be
        # inside the requested rank range.
        return self.details.min_rank <= 4 and self.details.max_rank >= 4 \
               and self.details.layer_enabled(DataLayer.ADDRESS)


    @property
    def configured_for_postcode(self) -> bool:
        """ Return true if the search details are configured to
            allow postcodes in the result.
        """
        # The requested rank range must cover the whole span 5 - 11.
        return self.details.min_rank <= 5 and self.details.max_rank >= 11 \
               and self.details.layer_enabled(DataLayer.ADDRESS)


    @property
    def configured_for_housenumbers(self) -> bool:
        """ Return true if the search details are configured to
            allow addresses in the result.
        """
        return self.details.max_rank >= 30 \
               and self.details.layer_enabled(DataLayer.ADDRESS)


    def build(self, assignment: TokenAssignment) -> Iterator[dbs.AbstractSearch]:
        """ Yield all possible abstract searches for the given token assignment.

            The penalty of the assignment is added to every yielded search.
            When the assignment contains a near item, the searches are
            wrapped into a NearSearch for the near-item categories, unless
            the near item can be turned into a plain POI search.
        """
        sdata = self.get_search_data(assignment)
        if sdata is None:
            return

        near_items = self.get_near_items(assignment)
        if near_items is not None and not near_items:
            return # impossible combination of near items and category parameter

        if assignment.name is None:
            if near_items and not sdata.postcodes:
                # Without a name, the near item becomes the category filter
                # of a plain POI search.
                sdata.qualifiers = near_items
                near_items = None
                builder = self.build_poi_search(sdata)
            elif assignment.housenumber:
                hnr_tokens = self.query.get_tokens(assignment.housenumber,
                                                   TokenType.HOUSENUMBER)
                builder = self.build_housenumber_search(sdata, hnr_tokens, assignment.address)
            else:
                builder = self.build_special_search(sdata, assignment.address,
                                                    bool(near_items))
        else:
            builder = self.build_name_search(sdata, assignment.name, assignment.address,
                                             bool(near_items))

        if near_items:
            # Move the smallest near-item penalty onto the outer search so
            # that the best category has a relative penalty of 0.
            penalty = min(near_items.penalties)
            near_items.penalties = [p - penalty for p in near_items.penalties]
            for search in builder:
                search_penalty = search.penalty
                search.penalty = 0.0
                yield dbs.NearSearch(penalty + assignment.penalty + search_penalty,
                                     near_items, search)
        else:
            for search in builder:
                search.penalty += assignment.penalty
                yield search


    def build_poi_search(self, sdata: dbf.SearchData) -> Iterator[dbs.AbstractSearch]:
        """ Build abstract search query for a simple category search.
            This kind of search requires an additional geographic constraint:
            either a bounded viewbox or a near point must be given, and
            housenumbers are not allowed.
        """
        if not sdata.housenumbers \
           and ((self.details.viewbox and self.details.bounded_viewbox) or self.details.near):
            yield dbs.PoiSearch(sdata)


    def build_special_search(self, sdata: dbf.SearchData,
                             address: List[TokenRange],
                             is_category: bool) -> Iterator[dbs.AbstractSearch]:
        """ Build abstract search queries for searches that do not involve
            a named place, i.e. country-only and postcode searches.
        """
        if sdata.qualifiers:
            # No special searches over qualifiers supported.
            return

        if sdata.countries and not address and not sdata.postcodes \
           and self.configured_for_country:
            yield dbs.CountrySearch(sdata)

        if sdata.postcodes and (is_category or self.configured_for_postcode):
            # Postcodes without a country are less reliable.
            penalty = 0.0 if sdata.countries else 0.1
            if address:
                sdata.lookups = [dbf.FieldLookup('nameaddress_vector',
                                                 [t.token for r in address
                                                  for t in self.query.get_partials_list(r)],
                                                 lookups.Restrict)]
                penalty += 0.2
            yield dbs.PostcodeSearch(penalty, sdata)


    def build_housenumber_search(self, sdata: dbf.SearchData, hnrs: List[Token],
                                 address: List[TokenRange]) -> Iterator[dbs.AbstractSearch]:
        """ Build a simple address search for special entries where the
            housenumber is the main name token.
        """
        sdata.lookups = [dbf.FieldLookup('name_vector', [t.token for t in hnrs], lookups.LookupAny)]
        expected_count = sum(t.count for t in hnrs)

        partials = {t.token: t.addr_count for trange in address
                       for t in self.query.get_partials_list(trange)}

        if expected_count < 8000:
            # Few housenumber hits: address partials only need to restrict.
            sdata.lookups.append(dbf.FieldLookup('nameaddress_vector',
                                                 list(partials), lookups.Restrict))
        elif len(partials) != 1 or list(partials.values())[0] < 10000:
            sdata.lookups.append(dbf.FieldLookup('nameaddress_vector',
                                                 list(partials), lookups.LookupAll))
        else:
            # A single, very frequent address partial: try the full words
            # of the first address range instead, but only if there are few.
            addr_fulls = [t.token for t
                          in self.query.get_tokens(address[0], TokenType.WORD)]
            if len(addr_fulls) > 5:
                return
            sdata.lookups.append(
                dbf.FieldLookup('nameaddress_vector', addr_fulls, lookups.LookupAny))

        # The housenumber is already used as the name lookup above.
        sdata.housenumbers = dbf.WeightedStrings([], [])
        yield dbs.PlaceSearch(0.05, sdata, expected_count)


    def build_name_search(self, sdata: dbf.SearchData,
                          name: TokenRange, address: List[TokenRange],
                          is_category: bool) -> Iterator[dbs.AbstractSearch]:
        """ Build abstract search queries for simple name or address searches.
        """
        if is_category or not sdata.housenumbers or self.configured_for_housenumbers:
            ranking = self.get_name_ranking(name)
            name_penalty = ranking.normalize_penalty()
            if ranking.rankings:
                sdata.rankings.append(ranking)
            for penalty, count, lookup in self.yield_lookups(name, address):
                sdata.lookups = lookup
                yield dbs.PlaceSearch(penalty + name_penalty, sdata, count)


    def yield_lookups(self, name: TokenRange, address: List[TokenRange])\
                          -> Iterator[Tuple[float, int, List[dbf.FieldLookup]]]:
        """ Yield all variants how the given name and address should best
            be searched for. This takes into account how frequent the terms
            are and tries to find a lookup that optimizes index use.

            Yields tuples of (extra penalty, expected result count, lookups).
        """
        penalty = 0.0 # extra penalty
        name_partials = {t.token: t for t in self.query.get_partials_list(name)}

        addr_partials = [t for r in address for t in self.query.get_partials_list(r)]
        addr_tokens = list({t.token for t in addr_partials})

        partials_indexed = all(t.is_indexed for t in name_partials.values()) \
                           and all(t.is_indexed for t in addr_partials)
        exp_count = min(t.count for t in name_partials.values()) / (2**(len(name_partials) - 1))

        if (len(name_partials) > 3 or exp_count < 8000) and partials_indexed:
            yield penalty, exp_count, dbf.lookup_by_names(list(name_partials.keys()), addr_tokens)
            return

        addr_count = min(t.addr_count for t in addr_partials) if addr_partials else 30000
        # Partial terms too frequent. Try looking up by rare full names first.
        name_fulls = self.query.get_tokens(name, TokenType.WORD)
        if name_fulls:
            fulls_count = sum(t.count for t in name_fulls)
            if len(name_partials) == 1:
                # Guard the denominator: full words may have a count of 0.
                penalty += min(0.5, max(0, (exp_count - 50 * fulls_count)
                                           / (2000 * max(fulls_count, 1))))
            # The lookups below drop unindexed address partials, which may
            # yield wrong results, so penalize for them. This must not be
            # guarded by `partials_indexed`: that flag is only true when no
            # address partial is unindexed, which would make the sum zero.
            penalty += 1.2 * sum(t.penalty for t in addr_partials if not t.is_indexed)

            if fulls_count < 50000 or addr_count < 30000:
                yield penalty, fulls_count / (2**len(addr_tokens)), \
                    self.get_full_name_ranking(name_fulls, addr_partials,
                                               fulls_count > 30000 / max(1, len(addr_tokens)))

        # To catch remaining results, lookup by name and address
        # We only do this if there is a reasonable number of results expected.
        exp_count = exp_count / (2**len(addr_tokens)) if addr_tokens else exp_count
        if exp_count < 10000 and addr_count < 20000\
           and all(t.is_indexed for t in name_partials.values()):
            penalty += 0.35 * max(1 if name_fulls else 0.1,
                                  5 - len(name_partials) - len(addr_tokens))
            yield penalty, exp_count,\
                  self.get_name_address_ranking(list(name_partials.keys()), addr_partials)


    @staticmethod
    def _split_addr_tokens(addr_partials: List[Token]) -> Tuple[List[int], List[int]]:
        """ Split the indexed address partials into tokens with more than
            20000 address occurrences (returned first, meant for a Restrict
            lookup) and rarer ones (returned second, meant for an index
            lookup). Unindexed partials are dropped.
        """
        restrict_tokens: List[int] = []
        lookup_tokens: List[int] = []
        for t in addr_partials:
            if t.is_indexed:
                if t.addr_count > 20000:
                    restrict_tokens.append(t.token)
                else:
                    lookup_tokens.append(t.token)
        return restrict_tokens, lookup_tokens


    def get_name_address_ranking(self, name_tokens: List[int],
                                 addr_partials: List[Token]) -> List[dbf.FieldLookup]:
        """ Create a ranking expression looking up by name and address.

            All name tokens must match. Frequent address partials only
            restrict the result, rarer ones are looked up via the index.
            Unindexed address partials are ignored.
        """
        lookup = [dbf.FieldLookup('name_vector', name_tokens, lookups.LookupAll)]

        addr_restrict_tokens, addr_lookup_tokens = self._split_addr_tokens(addr_partials)

        if addr_restrict_tokens:
            lookup.append(dbf.FieldLookup('nameaddress_vector',
                                          addr_restrict_tokens, lookups.Restrict))
        if addr_lookup_tokens:
            lookup.append(dbf.FieldLookup('nameaddress_vector',
                                          addr_lookup_tokens, lookups.LookupAll))

        return lookup


    def get_full_name_ranking(self, name_fulls: List[Token], addr_partials: List[Token],
                              use_lookup: bool) -> List[dbf.FieldLookup]:
        """ Create a ranking expression with full name terms and
            additional address lookup. When 'use_lookup' is true, then
            address lookups will use the index, when the occurrences are not
            too many. Otherwise all address partials only restrict.
        """
        # At this point drop unindexed partials from the address.
        # This might yield wrong results, nothing we can do about that.
        if use_lookup:
            addr_restrict_tokens, addr_lookup_tokens = self._split_addr_tokens(addr_partials)
        else:
            addr_restrict_tokens = [t.token for t in addr_partials if t.is_indexed]
            addr_lookup_tokens = []

        return dbf.lookup_by_any_name([t.token for t in name_fulls],
                                      addr_restrict_tokens, addr_lookup_tokens)


    def _partials_fallback_penalty(self, trange: TokenRange) -> float:
        """ Return the worst-case penalty for the given range: the sum of
            the penalties of its partial terms plus 0.2.
        """
        return sum(t.penalty for t in self.query.get_partials_list(trange)) + 0.2


    def get_name_ranking(self, trange: TokenRange,
                         db_field: str = 'name_vector') -> dbf.FieldRanking:
        """ Create a ranking expression for a name term in the given range.

            Each full-word token becomes one ranked variant (sorted by
            penalty); the fallback penalty is derived from the partials.
        """
        name_fulls = self.query.get_tokens(trange, TokenType.WORD)
        ranks = [dbf.RankedTokens(t.penalty, [t.token]) for t in name_fulls]
        ranks.sort(key=lambda r: r.penalty)
        # Fallback, sum of penalty for partials
        default = self._partials_fallback_penalty(trange)
        return dbf.FieldRanking(db_field, default, ranks)


    def get_addr_ranking(self, trange: TokenRange) -> dbf.FieldRanking:
        """ Create a ranking expression for an address term in the given range.

            All ways of covering the range with partial and full-word tokens
            are enumerated. Partial terms only contribute their maximum
            penalty, full words are added as ranked tokens, and changing
            tokens at a break adds the PENALTY_WORDCHANGE for that break.
            The variant with the fewest full-word tokens provides the default
            penalty (plus 0.3); the others become rankings sorted by penalty.
            Enumeration stops after 10 variants, adding a fallback variant
            based on the penalties of the partials.
        """
        todo: List[Tuple[int, int, dbf.RankedTokens]] = []
        heapq.heappush(todo, (0, trange.start, dbf.RankedTokens(0.0, [])))
        ranks: List[dbf.RankedTokens] = []

        while todo: # pylint: disable=too-many-nested-blocks
            neglen, pos, rank = heapq.heappop(todo)
            for tlist in self.query.nodes[pos].starting:
                if tlist.ttype in (TokenType.PARTIAL, TokenType.WORD):
                    if tlist.end < trange.end:
                        chgpenalty = PENALTY_WORDCHANGE[self.query.nodes[tlist.end].btype]
                        if tlist.ttype == TokenType.PARTIAL:
                            penalty = rank.penalty + chgpenalty \
                                      + max(t.penalty for t in tlist.tokens)
                            heapq.heappush(todo, (neglen - 1, tlist.end,
                                                  dbf.RankedTokens(penalty, rank.tokens)))
                        else:
                            for t in tlist.tokens:
                                heapq.heappush(todo, (neglen - 1, tlist.end,
                                                      rank.with_token(t, chgpenalty)))
                    elif tlist.end == trange.end:
                        if tlist.ttype == TokenType.PARTIAL:
                            ranks.append(dbf.RankedTokens(rank.penalty
                                                          + max(t.penalty for t in tlist.tokens),
                                                          rank.tokens))
                        else:
                            ranks.extend(rank.with_token(t, 0.0) for t in tlist.tokens)
                        if len(ranks) >= 10:
                            # Too many variants, bail out and only add
                            # Worst-case Fallback: sum of penalty of partials
                            ranks.append(dbf.RankedTokens(
                                rank.penalty + self._partials_fallback_penalty(trange), []))
                            # Bail out of outer loop
                            todo.clear()
                            break

        ranks.sort(key=lambda r: len(r.tokens))
        default = ranks[0].penalty + 0.3
        del ranks[0]
        ranks.sort(key=lambda r: r.penalty)

        return dbf.FieldRanking('nameaddress_vector', default, ranks)


    def get_search_data(self, assignment: TokenAssignment) -> Optional[dbf.SearchData]:
        """ Collect the tokens for the non-name search fields in the
            assignment.

            Returns None when the country or qualifier tokens of the
            assignment are all filtered out by the search details.
        """
        sdata = dbf.SearchData()
        sdata.penalty = assignment.penalty
        if assignment.country:
            tokens = self.get_country_tokens(assignment.country)
            if not tokens:
                return None
            sdata.set_strings('countries', tokens)
        elif self.details.countries:
            sdata.countries = dbf.WeightedStrings(self.details.countries,
                                                  [0.0] * len(self.details.countries))
        if assignment.housenumber:
            sdata.set_strings('housenumbers',
                              self.query.get_tokens(assignment.housenumber,
                                                    TokenType.HOUSENUMBER))
        if assignment.postcode:
            sdata.set_strings('postcodes',
                              self.query.get_tokens(assignment.postcode,
                                                    TokenType.POSTCODE))
        if assignment.qualifier:
            tokens = self.get_qualifier_tokens(assignment.qualifier)
            if not tokens:
                return None
            sdata.set_qualifiers(tokens)
        elif self.details.categories:
            sdata.qualifiers = dbf.WeightedCategories(self.details.categories,
                                                      [0.0] * len(self.details.categories))

        if assignment.address:
            if not assignment.name and assignment.housenumber:
                # housenumber search: the first item needs to be handled like
                # a name in ranking or penalties are not comparable with
                # normal searches.
                sdata.set_ranking([self.get_name_ranking(assignment.address[0],
                                                         db_field='nameaddress_vector')]
                                  + [self.get_addr_ranking(r) for r in assignment.address[1:]])
            else:
                sdata.set_ranking([self.get_addr_ranking(r) for r in assignment.address])
        else:
            sdata.rankings = []

        return sdata


    def get_country_tokens(self, trange: TokenRange) -> List[Token]:
        """ Return the list of country tokens for the given range,
            optionally filtered by the country list from the details
            parameters.
        """
        tokens = self.query.get_tokens(trange, TokenType.COUNTRY)
        if self.details.countries:
            tokens = [t for t in tokens if t.lookup_word in self.details.countries]

        return tokens


    def get_qualifier_tokens(self, trange: TokenRange) -> List[Token]:
        """ Return the list of qualifier tokens for the given range,
            optionally filtered by the qualifier list from the details
            parameters.
        """
        tokens = self.query.get_tokens(trange, TokenType.QUALIFIER)
        if self.details.categories:
            tokens = [t for t in tokens if t.get_category() in self.details.categories]

        return tokens


    def get_near_items(self, assignment: TokenAssignment) -> Optional[dbf.WeightedCategories]:
        """ Collect the categories for a near search from the near item of
            the assignment.

            When categories were requested per parameter, only near-item
            categories contained in that list are kept. Returns None if the
            assignment has no near item. The result may be empty when no
            category survives the filter.
        """
        if assignment.near_item:
            tokens: Dict[Tuple[str, str], float] = {}
            for t in self.query.get_tokens(assignment.near_item, TokenType.NEAR_ITEM):
                cat = t.get_category()
                # The category of a near search will be that of near_item.
                # Thus, if search is restricted to a category parameter,
                # the two sets must intersect.
                if (not self.details.categories or cat in self.details.categories)\
                   and t.penalty < tokens.get(cat, 1000.0):
                    # Keep only the lowest penalty per category.
                    tokens[cat] = t.penalty
            return dbf.WeightedCategories(list(tokens.keys()), list(tokens.values()))

        return None
449             return dbf.WeightedCategories(list(tokens.keys()), list(tokens.values()))
450
451         return None
452
453
454 PENALTY_WORDCHANGE = {
455     BreakType.START: 0.0,
456     BreakType.END: 0.0,
457     BreakType.PHRASE: 0.0,
458     BreakType.WORD: 0.1,
459     BreakType.PART: 0.2,
460     BreakType.TOKEN: 0.4
461 }