]> git.openstreetmap.org Git - nominatim.git/blob - nominatim/api/search/db_search_builder.py
take token_assignment penalty into account
[nominatim.git] / nominatim / api / search / db_search_builder.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2023 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Conversion from token assignment to an abstract DB search.
9 """
10 from typing import Optional, List, Tuple, Iterator
11 import heapq
12
13 from nominatim.api.types import SearchDetails, DataLayer
14 from nominatim.api.search.query import QueryStruct, Token, TokenType, TokenRange, BreakType
15 from nominatim.api.search.token_assignment import TokenAssignment
16 import nominatim.api.search.db_search_fields as dbf
17 import nominatim.api.search.db_searches as dbs
18
19
def wrap_near_search(categories: List[Tuple[str, str]],
                     search: dbs.AbstractSearch) -> dbs.NearSearch:
    """ Wrap the given search into a NearSearch for places of the
        given categories.

        The wrapping search takes over the penalty of the inner search.
        All categories are weighted equally with a penalty of 0.
    """
    zero_penalties = [0.0] * len(categories)
    weighted = dbf.WeightedCategories(categories, zero_penalties)

    return dbs.NearSearch(penalty=search.penalty, categories=weighted, search=search)
29
30
def build_poi_search(category: List[Tuple[str, str]],
                     countries: Optional[List[str]]) -> dbs.PoiSearch:
    """ Create a new search for places of the given categories, optionally
        restricted to the given countries.

        All categories and countries are weighted with a penalty of 0 and
        the search itself gets no extra penalty.
    """
    # An empty or missing country list results in no country restriction.
    country_codes = countries or []
    ccs = dbf.WeightedStrings(country_codes, [0.0] * len(country_codes))

    class _PoiData(dbf.SearchData):
        penalty = 0.0
        qualifiers = dbf.WeightedCategories(category, [0.0] * len(category))
        countries = ccs

    return dbs.PoiSearch(_PoiData())
47
48
class SearchBuilder:
    """ Build the abstract search queries from token assignments.

        The builder keeps the tokenized query and the search details of
        the request and turns each token assignment into zero or more
        abstract DB searches.
    """

    def __init__(self, query: QueryStruct, details: SearchDetails) -> None:
        self.query = query
        self.details = details


    @property
    def configured_for_country(self) -> bool:
        """ Return true if the search details are configured to
            allow countries in the result, i.e. the requested rank range
            contains rank 4 and the address layer is enabled.
        """
        return self.details.min_rank <= 4 and self.details.max_rank >= 4 \
               and self.details.layer_enabled(DataLayer.ADDRESS)


    @property
    def configured_for_postcode(self) -> bool:
        """ Return true if the search details are configured to
            allow postcodes in the result, i.e. the requested rank range
            covers ranks 5 to 11 and the address layer is enabled.
        """
        return self.details.min_rank <= 5 and self.details.max_rank >= 11\
               and self.details.layer_enabled(DataLayer.ADDRESS)


    @property
    def configured_for_housenumbers(self) -> bool:
        """ Return true if the search details are configured to
            allow addresses in the result, i.e. the requested rank range
            reaches rank 30 and the address layer is enabled.
        """
        return self.details.max_rank >= 30 \
               and self.details.layer_enabled(DataLayer.ADDRESS)


    def build(self, assignment: TokenAssignment) -> Iterator[dbs.AbstractSearch]:
        """ Yield all possible abstract searches for the given token assignment.

            Nothing is yielded when the assignment is incompatible with the
            country or category restrictions of the search details.
        """
        sdata = self.get_search_data(assignment)
        if sdata is None:
            return

        categories = self.get_search_categories(assignment)
        if categories is not None and not categories.penalties:
            # The query contains a category term, but none of its categories
            # is allowed by the category filter of the search details.
            # Continuing would either drop the requested category restriction
            # (empty categories are treated as "no categories" below) or fail
            # on min() of an empty penalty list. The combination is
            # impossible, so there is nothing to search for.
            return

        if assignment.name is None:
            if categories and not sdata.postcodes:
                # Without a name, the categories become the qualifiers of a
                # POI search instead of wrapping another search.
                sdata.qualifiers = categories
                categories = None
                builder = self.build_poi_search(sdata)
            elif assignment.housenumber:
                hnr_tokens = self.query.get_tokens(assignment.housenumber,
                                                   TokenType.HOUSENUMBER)
                builder = self.build_housenumber_search(sdata, hnr_tokens, assignment.address)
            else:
                builder = self.build_special_search(sdata, assignment.address,
                                                    bool(categories))
        else:
            builder = self.build_name_search(sdata, assignment.name, assignment.address,
                                             bool(categories))

        if categories:
            # Move the smallest category penalty to the NearSearch itself,
            # so that the best category has a relative penalty of 0.
            penalty = min(categories.penalties)
            categories.penalties = [p - penalty for p in categories.penalties]
            for search in builder:
                yield dbs.NearSearch(penalty + assignment.penalty, categories, search)
        else:
            # NOTE(review): get_search_data() already copies
            # assignment.penalty into sdata.penalty. Confirm that the search
            # classes do not add sdata.penalty to their own penalty as well,
            # otherwise the assignment penalty is counted twice here.
            for search in builder:
                search.penalty += assignment.penalty
                yield search


    def build_poi_search(self, sdata: dbf.SearchData) -> Iterator[dbs.AbstractSearch]:
        """ Build abstract search query for a simple category search.

            This kind of search requires an additional geographic constraint:
            either a bounded viewbox or a 'near' parameter must be given in
            the search details. Nothing is yielded when housenumbers are
            present.
        """
        if not sdata.housenumbers \
           and ((self.details.viewbox and self.details.bounded_viewbox) or self.details.near):
            yield dbs.PoiSearch(sdata)


    def build_special_search(self, sdata: dbf.SearchData,
                             address: List[TokenRange],
                             is_category: bool) -> Iterator[dbs.AbstractSearch]:
        """ Build abstract search queries for searches that do not involve
            a named place.

            Yields a country search when countries are set and there are
            neither address terms nor postcodes, and a postcode search when
            postcodes are present. Nothing is yielded when qualifiers are set.
        """
        if sdata.qualifiers:
            # No special searches over qualifiers supported.
            return

        if sdata.countries and not address and not sdata.postcodes \
           and self.configured_for_country:
            yield dbs.CountrySearch(sdata)

        if sdata.postcodes and (is_category or self.configured_for_postcode):
            # Penalize postcode searches that are not restricted to a country.
            penalty = 0.0 if sdata.countries else 0.1
            if address:
                # Remaining address terms restrict the postcode results.
                sdata.lookups = [dbf.FieldLookup('nameaddress_vector',
                                                 [t.token for r in address
                                                  for t in self.query.get_partials_list(r)],
                                                 'restrict')]
                penalty += 0.2
            yield dbs.PostcodeSearch(penalty, sdata)


    def build_housenumber_search(self, sdata: dbf.SearchData, hnrs: List[Token],
                                 address: List[TokenRange]) -> Iterator[dbs.AbstractSearch]:
        """ Build a simple address search for special entries where the
            housenumber is the main name token.

            The housenumber tokens are looked up in the name vector, the
            address terms in the address vector. When the address consists
            of a single partial term with 10000 or more occurrences, any of
            the full words of the first address range is used instead.
        """
        sdata.lookups = [dbf.FieldLookup('name_vector', [t.token for t in hnrs], 'lookup_any')]

        partials = [t for trange in address
                       for t in self.query.get_partials_list(trange)]

        if len(partials) != 1 or partials[0].count < 10000:
            sdata.lookups.append(dbf.FieldLookup('nameaddress_vector',
                                                 [t.token for t in partials], 'lookup_all'))
        else:
            # A single, very frequent partial: use the full words instead.
            sdata.lookups.append(
                dbf.FieldLookup('nameaddress_vector',
                                [t.token for t
                                 in self.query.get_tokens(address[0], TokenType.WORD)],
                                'lookup_any'))

        # The housenumber tokens are already matched through the name lookup.
        sdata.housenumbers = dbf.WeightedStrings([], [])
        yield dbs.PlaceSearch(0.05, sdata, sum(t.count for t in hnrs))


    def build_name_search(self, sdata: dbf.SearchData,
                          name: TokenRange, address: List[TokenRange],
                          is_category: bool) -> Iterator[dbs.AbstractSearch]:
        """ Build abstract search queries for simple name or address searches.

            Searches with housenumbers are only built when the search details
            allow housenumber results (see configured_for_housenumbers) or
            when this is a category search.
        """
        if is_category or not sdata.housenumbers or self.configured_for_housenumbers:
            ranking = self.get_name_ranking(name)
            name_penalty = ranking.normalize_penalty()
            if ranking.rankings:
                sdata.rankings.append(ranking)
            for penalty, count, lookup in self.yield_lookups(name, address):
                sdata.lookups = lookup
                yield dbs.PlaceSearch(penalty + name_penalty, sdata, count)


    def yield_lookups(self, name: TokenRange, address: List[TokenRange])\
                          -> Iterator[Tuple[float, int, List[dbf.FieldLookup]]]:
        """ Yield all variants how the given name and address should best
            be searched for. This takes into account how frequent the terms
            are and tries to find a lookup that optimizes index use.

            Each variant is a tuple of (extra penalty, expected number of
            results, list of field lookups).
        """
        penalty = 0.0 # extra penalty
        name_partials = self.query.get_partials_list(name)
        name_tokens = [t.token for t in name_partials]

        addr_partials = [t for r in address for t in self.query.get_partials_list(r)]
        addr_tokens = [t.token for t in addr_partials]

        partials_indexed = all(t.is_indexed for t in name_partials) \
                           and all(t.is_indexed for t in addr_partials)
        exp_count = min(t.count for t in name_partials)

        # Many name terms or a rare name term: a direct lookup over all
        # partial terms is used.
        if (len(name_partials) > 3 or exp_count < 1000) and partials_indexed:
            yield penalty, exp_count, dbf.lookup_by_names(name_tokens, addr_tokens)
            return

        # Rough estimate: each address term halves the number of results.
        exp_count = exp_count / (2**len(addr_partials)) if addr_partials else exp_count

        # Partial terms too frequent. Try looking up by rare full names first.
        name_fulls = self.query.get_tokens(name, TokenType.WORD)
        rare_names = list(filter(lambda t: t.count < 10000, name_fulls))
        # At this point drop unindexed partials from the address.
        # This might yield wrong results, nothing we can do about that.
        if not partials_indexed:
            addr_tokens = [t.token for t in addr_partials if t.is_indexed]
            penalty += 1.2 * sum(t.penalty for t in addr_partials if not t.is_indexed)
        if rare_names:
            # Any of the full names applies with all of the partials from the address
            yield penalty, sum(t.count for t in rare_names),\
                  dbf.lookup_by_any_name([t.token for t in rare_names], addr_tokens)

        # To catch remaining results, look up by name and address.
        # We only do this if there is a reasonable number of results expected.
        if exp_count < 10000:
            if all(t.is_indexed for t in name_partials):
                lookup = [dbf.FieldLookup('name_vector', name_tokens, 'lookup_all')]
            else:
                # we don't have the partials, try with the non-rare names
                non_rare_names = [t.token for t in name_fulls if t.count >= 10000]
                if not non_rare_names:
                    return
                lookup = [dbf.FieldLookup('name_vector', non_rare_names, 'lookup_any')]
            if addr_tokens:
                lookup.append(dbf.FieldLookup('nameaddress_vector', addr_tokens, 'lookup_all'))
            # Penalize lookups with fewer than five terms in total.
            penalty += 0.1 * max(0, 5 - len(name_partials) - len(addr_tokens))
            if len(rare_names) == len(name_fulls):
                # if there already was a search for all full tokens,
                # avoid this if anything has been found
                penalty += 0.25
            yield penalty, exp_count, lookup


    def get_name_ranking(self, trange: TokenRange) -> dbf.FieldRanking:
        """ Create a ranking expression for a name term in the given range.

            Each full-word token of the range is ranked with its own penalty
            (sorted ascending). The default penalty is the sum of the
            penalties of the partial terms plus 0.2.
        """
        name_fulls = self.query.get_tokens(trange, TokenType.WORD)
        ranks = [dbf.RankedTokens(t.penalty, [t.token]) for t in name_fulls]
        ranks.sort(key=lambda r: r.penalty)
        # Fallback, sum of penalty for partials
        name_partials = self.query.get_partials_list(trange)
        default = sum(t.penalty for t in name_partials) + 0.2
        return dbf.FieldRanking('name_vector', default, ranks)


    def get_addr_ranking(self, trange: TokenRange) -> dbf.FieldRanking:
        """ Create a list of ranking expressions for an address term
            for the given ranges.

            All ways of covering the range with partial and full-word tokens
            are explored. A partial step only adds its maximum token penalty,
            a full-word step adds the token to the ranking. Steps that do not
            end the range add the PENALTY_WORDCHANGE penalty for the break
            they end on. Paths using more tokens so far are popped from the
            heap first. After 10 complete variants, a worst-case fallback is
            added and the exploration stops.

            The variant with the fewest tokens provides the default penalty
            (plus 0.3) and is removed; the rest is sorted by penalty.
        """
        # Heap entries: (negative number of steps, current node, path so far).
        todo: List[Tuple[int, int, dbf.RankedTokens]] = []
        heapq.heappush(todo, (0, trange.start, dbf.RankedTokens(0.0, [])))
        ranks: List[dbf.RankedTokens] = []

        while todo: # pylint: disable=too-many-nested-blocks
            neglen, pos, rank = heapq.heappop(todo)
            for tlist in self.query.nodes[pos].starting:
                if tlist.ttype in (TokenType.PARTIAL, TokenType.WORD):
                    if tlist.end < trange.end:
                        chgpenalty = PENALTY_WORDCHANGE[self.query.nodes[tlist.end].btype]
                        if tlist.ttype == TokenType.PARTIAL:
                            penalty = rank.penalty + chgpenalty \
                                      + max(t.penalty for t in tlist.tokens)
                            heapq.heappush(todo, (neglen - 1, tlist.end,
                                                  dbf.RankedTokens(penalty, rank.tokens)))
                        else:
                            for t in tlist.tokens:
                                heapq.heappush(todo, (neglen - 1, tlist.end,
                                                      rank.with_token(t, chgpenalty)))
                    elif tlist.end == trange.end:
                        if tlist.ttype == TokenType.PARTIAL:
                            ranks.append(dbf.RankedTokens(rank.penalty
                                                          + max(t.penalty for t in tlist.tokens),
                                                          rank.tokens))
                        else:
                            ranks.extend(rank.with_token(t, 0.0) for t in tlist.tokens)
                        if len(ranks) >= 10:
                            # Too many variants, bail out and only add
                            # Worst-case Fallback: sum of penalty of partials
                            name_partials = self.query.get_partials_list(trange)
                            default = sum(t.penalty for t in name_partials) + 0.2
                            ranks.append(dbf.RankedTokens(rank.penalty + default, []))
                            # Bail out of outer loop
                            todo.clear()
                            break

        ranks.sort(key=lambda r: len(r.tokens))
        default = ranks[0].penalty + 0.3
        del ranks[0]
        ranks.sort(key=lambda r: r.penalty)

        return dbf.FieldRanking('nameaddress_vector', default, ranks)


    def get_search_data(self, assignment: TokenAssignment) -> Optional[dbf.SearchData]:
        """ Collect the tokens for the non-name search fields in the
            assignment.

            Returns None when the assignment contains a country and none
            of its country tokens is in the country list of the search details.
        """
        sdata = dbf.SearchData()
        sdata.penalty = assignment.penalty
        if assignment.country:
            tokens = self.query.get_tokens(assignment.country, TokenType.COUNTRY)
            if self.details.countries:
                tokens = [t for t in tokens if t.lookup_word in self.details.countries]
                if not tokens:
                    return None
            sdata.set_strings('countries', tokens)
        elif self.details.countries:
            sdata.countries = dbf.WeightedStrings(self.details.countries,
                                                  [0.0] * len(self.details.countries))
        if assignment.housenumber:
            sdata.set_strings('housenumbers',
                              self.query.get_tokens(assignment.housenumber,
                                                    TokenType.HOUSENUMBER))
        if assignment.postcode:
            sdata.set_strings('postcodes',
                              self.query.get_tokens(assignment.postcode,
                                                    TokenType.POSTCODE))
        if assignment.qualifier:
            sdata.set_qualifiers(self.query.get_tokens(assignment.qualifier,
                                                       TokenType.QUALIFIER))

        if assignment.address:
            sdata.set_ranking([self.get_addr_ranking(r) for r in assignment.address])
        else:
            sdata.rankings = []

        return sdata


    def get_search_categories(self,
                              assignment: TokenAssignment) -> Optional[dbf.WeightedCategories]:
        """ Collect tokens for category search or use the categories
            requested per parameter.
            Returns None if no category search is requested. The result
            may be empty when the category tokens of the query are all
            excluded by the categories of the search details.
        """
        if assignment.category:
            tokens = [t for t in self.query.get_tokens(assignment.category,
                                                       TokenType.CATEGORY)
                      if not self.details.categories
                         or t.get_category() in self.details.categories]
            return dbf.WeightedCategories([t.get_category() for t in tokens],
                                          [t.penalty for t in tokens])

        if self.details.categories:
            return dbf.WeightedCategories(self.details.categories,
                                          [0.0] * len(self.details.categories))

        return None
366
367
# Extra penalty added in SearchBuilder.get_addr_ranking() when a token inside
# an address term ends before the end of the term, keyed by the break type of
# the query node the token ends on.
PENALTY_WORDCHANGE = {
    # No extra penalty at the start/end of the query or at phrase breaks.
    **dict.fromkeys((BreakType.START, BreakType.END, BreakType.PHRASE), 0.0),
    BreakType.WORD: 0.1,
    BreakType.PART: 0.2,
    BreakType.TOKEN: 0.4,
}