]> git.openstreetmap.org Git - nominatim.git/blob - nominatim/api/search/db_search_builder.py
fix various failing BDD tests
[nominatim.git] / nominatim / api / search / db_search_builder.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2023 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
Conversion from token assignment to an abstract DB search.
9 """
10 from typing import Optional, List, Tuple, Iterator
11 import heapq
12
13 from nominatim.api.types import SearchDetails, DataLayer
14 from nominatim.api.search.query import QueryStruct, Token, TokenType, TokenRange, BreakType
15 from nominatim.api.search.token_assignment import TokenAssignment
16 import nominatim.api.search.db_search_fields as dbf
17 import nominatim.api.search.db_searches as dbs
18 from nominatim.api.logging import log
19
20
def wrap_near_search(categories: List[Tuple[str, str]],
                     search: dbs.AbstractSearch) -> dbs.NearSearch:
    """ Wrap the given search into a search for places near objects of
        the given categories.

        Every category is given a penalty of 0.0. The new near search
        inherits the penalty of the wrapped search.
    """
    zero_penalties = [0.0] * len(categories)
    near_categories = dbf.WeightedCategories(categories, zero_penalties)
    inherited_penalty = search.penalty
    return dbs.NearSearch(penalty=inherited_penalty,
                          categories=near_categories,
                          search=search)
30
31
def build_poi_search(category: List[Tuple[str, str]],
                     countries: Optional[List[str]]) -> dbs.PoiSearch:
    """ Create a new search for places of the given categories, optionally
        constrained to the given list of country codes.

        All categories and countries are given a penalty of 0.0. When no
        countries (None or an empty list) are given, the search is not
        restricted by country.
    """
    ccs = dbf.WeightedStrings(countries, [0.0] * len(countries)) if countries \
          else dbf.WeightedStrings([], [])

    # Search data is provided through class attributes of an ad-hoc
    # SearchData subclass.
    class _PoiData(dbf.SearchData):
        penalty = 0.0
        qualifiers = dbf.WeightedCategories(category, [0.0] * len(category))
        countries = ccs

    poi_data = _PoiData()
    return dbs.PoiSearch(poi_data)
48
49
class SearchBuilder:
    """ Build the abstract search queries from token assignments.
    """

    def __init__(self, query: QueryStruct, details: SearchDetails) -> None:
        self.query = query
        self.details = details


    @property
    def configured_for_country(self) -> bool:
        """ Return true if the search details are configured to
            allow countries in the result, i.e. rank 4 lies within the
            requested rank range and the address layer is enabled.
        """
        return self.details.min_rank <= 4 and self.details.max_rank >= 4 \
               and self.details.layer_enabled(DataLayer.ADDRESS)


    @property
    def configured_for_postcode(self) -> bool:
        """ Return true if the search details are configured to
            allow postcodes in the result, i.e. the requested rank range
            covers ranks 5 to 11 and the address layer is enabled.
        """
        return self.details.min_rank <= 5 and self.details.max_rank >= 11\
               and self.details.layer_enabled(DataLayer.ADDRESS)


    @property
    def configured_for_housenumbers(self) -> bool:
        """ Return true if the search details are configured to
            allow addresses in the result, i.e. the maximum requested rank
            is at least 30 and the address layer is enabled.
        """
        return self.details.max_rank >= 30 \
               and self.details.layer_enabled(DataLayer.ADDRESS)


    def build(self, assignment: TokenAssignment) -> Iterator[dbs.AbstractSearch]:
        """ Yield all possible abstract searches for the given token assignment.

            When categories are requested (through a category term in the
            query or through the search details), the searches are either
            restricted to these categories directly (category-only searches
            without a name) or wrapped into a near search for them.
            Nothing is yielded when the assignment cannot be fulfilled,
            e.g. when the country or category terms of the query contradict
            the search details.
        """
        sdata = self.get_search_data(assignment)
        if sdata is None:
            return

        categories = self.get_search_categories(assignment)
        if categories is not None and not categories.penalties:
            # The query has a category term, but none of its categories
            # survived the filter from the search details. This is an
            # impossible combination. Carrying on would either drop the
            # category restriction silently or make min() below fail on an
            # empty list (depending on how WeightedCategories evaluates in a
            # boolean context).
            return

        if assignment.name is None:
            if categories and not sdata.postcodes:
                sdata.qualifiers = categories
                categories = None
                builder = self.build_poi_search(sdata)
            elif assignment.housenumber:
                hnr_tokens = self.query.get_tokens(assignment.housenumber,
                                                   TokenType.HOUSENUMBER)
                builder = self.build_housenumber_search(sdata, hnr_tokens, assignment.address)
            else:
                builder = self.build_special_search(sdata, assignment.address,
                                                    bool(categories))
        else:
            builder = self.build_name_search(sdata, assignment.name, assignment.address,
                                             bool(categories))

        if categories:
            # Move the smallest category penalty onto the near search itself,
            # so that the best category ends up with a penalty of 0.
            penalty = min(categories.penalties)
            categories.penalties = [p - penalty for p in categories.penalties]
            for search in builder:
                yield dbs.NearSearch(penalty, categories, search)
        else:
            yield from builder


    def build_poi_search(self, sdata: dbf.SearchData) -> Iterator[dbs.AbstractSearch]:
        """ Build abstract search query for a simple category search.
            This kind of search requires an additional geographic constraint:
            either a bounded viewbox or a near point must be given. Searches
            with house numbers are not supported.
        """
        if not sdata.housenumbers \
           and ((self.details.viewbox and self.details.bounded_viewbox) or self.details.near):
            yield dbs.PoiSearch(sdata)


    def build_special_search(self, sdata: dbf.SearchData,
                             address: List[TokenRange],
                             is_category: bool) -> Iterator[dbs.AbstractSearch]:
        """ Build abstract search queries for searches that do not involve
            a named place.

            A country search is produced for a bare country term (no address,
            no postcode) when countries are allowed in the result. A postcode
            search is produced when a postcode is present. Any address terms
            then restrict the postcode results via their partial tokens.
        """
        if sdata.qualifiers:
            # No special searches over qualifiers supported.
            return

        if sdata.countries and not address and not sdata.postcodes \
           and self.configured_for_country:
            yield dbs.CountrySearch(sdata)

        if sdata.postcodes and (is_category or self.configured_for_postcode):
            if address:
                sdata.lookups = [dbf.FieldLookup('nameaddress_vector',
                                                 [t.token for r in address
                                                  for t in self.query.get_partials_list(r)],
                                                 'restrict')]
            yield dbs.PostcodeSearch(0.4, sdata)


    def build_housenumber_search(self, sdata: dbf.SearchData, hnrs: List[Token],
                                 address: List[TokenRange]) -> Iterator[dbs.AbstractSearch]:
        """ Build a simple address search for special entries where the
            housenumber is the main name token.

            Any of the house number tokens must appear in the name and all
            partial tokens of the address must appear in the address terms.
            The expected result count is the sum of the house number token
            counts.
        """
        partial_tokens: List[int] = []
        for trange in address:
            partial_tokens.extend(t.token for t in self.query.get_partials_list(trange))

        sdata.lookups = [dbf.FieldLookup('name_vector', [t.token for t in hnrs], 'lookup_any'),
                         dbf.FieldLookup('nameaddress_vector', partial_tokens, 'lookup_all')
                        ]
        yield dbs.PlaceSearch(0.05, sdata, sum(t.count for t in hnrs))


    def build_name_search(self, sdata: dbf.SearchData,
                          name: TokenRange, address: List[TokenRange],
                          is_category: bool) -> Iterator[dbs.AbstractSearch]:
        """ Build abstract search queries for simple name or address searches.

            Searches with a house number are only built when addresses are
            allowed in the result (or when this is a category search). One
            place search is yielded for every lookup variant that
            yield_lookups() proposes.
        """
        if is_category or not sdata.housenumbers or self.configured_for_housenumbers:
            ranking = self.get_name_ranking(name)
            name_penalty = ranking.normalize_penalty()
            if ranking.rankings:
                sdata.rankings.append(ranking)
            for penalty, count, lookup in self.yield_lookups(name, address):
                sdata.lookups = lookup
                yield dbs.PlaceSearch(penalty + name_penalty, sdata, count)


    def yield_lookups(self, name: TokenRange, address: List[TokenRange])\
                          -> Iterator[Tuple[float, int, List[dbf.FieldLookup]]]:
        """ Yield all variants how the given name and address should best
            be searched for. This takes into account how frequent the terms
            are and tries to find a lookup that optimizes index use.

            Each variant is a tuple of (extra penalty, expected result count,
            list of field lookups).
        """
        # Extra penalty for the lookups. It is raised below when unindexed
        # address partials have to be dropped.
        penalty = 0.0

        name_partials = self.query.get_partials_list(name)
        exp_name_count = min(t.count for t in name_partials)
        addr_partials = [t for trange in address
                         for t in self.query.get_partials_list(trange)]
        addr_tokens = [t.token for t in addr_partials]
        partials_indexed = all(t.is_indexed for t in name_partials) \
                           and all(t.is_indexed for t in addr_partials)

        if (len(name_partials) > 3 or exp_name_count < 1000) and partials_indexed:
            # Lookup by name partials, use address partials to restrict results.
            lookup = [dbf.FieldLookup('name_vector',
                                  [t.token for t in name_partials], 'lookup_all')]
            if addr_tokens:
                lookup.append(dbf.FieldLookup('nameaddress_vector', addr_tokens, 'restrict'))
            yield penalty, exp_name_count, lookup
            return

        exp_addr_count = min(t.count for t in addr_partials) if addr_partials else exp_name_count
        if exp_addr_count < 1000 and partials_indexed:
            # Lookup by address partials and restrict results through name terms.
            yield penalty, exp_addr_count,\
                  [dbf.FieldLookup('name_vector', [t.token for t in name_partials], 'restrict'),
                   dbf.FieldLookup('nameaddress_vector', addr_tokens, 'lookup_all')]
            return

        # Partial terms too frequent. Try looking up by rare full names first.
        name_fulls = self.query.get_tokens(name, TokenType.WORD)
        rare_names = list(filter(lambda t: t.count < 1000, name_fulls))
        # At this point drop unindexed partials from the address.
        # This might yield wrong results, nothing we can do about that.
        if not partials_indexed:
            addr_tokens = [t.token for t in addr_partials if t.is_indexed]
            penalty += 1.2 * sum(t.penalty for t in addr_partials if not t.is_indexed)
        if rare_names:
            # Any of the full names applies with all of the partials from the address
            lookup = [dbf.FieldLookup('name_vector', [t.token for t in rare_names], 'lookup_any')]
            if addr_tokens:
                lookup.append(dbf.FieldLookup('nameaddress_vector', addr_tokens, 'restrict'))
            yield penalty, sum(t.count for t in rare_names), lookup

        # To catch remaining results, lookup by name and address
        if all(t.is_indexed for t in name_partials):
            lookup = [dbf.FieldLookup('name_vector',
                                      [t.token for t in name_partials], 'lookup_all')]
        else:
            # we don't have the partials, try with the non-rare names
            non_rare_names = [t.token for t in name_fulls if t.count >= 1000]
            if not non_rare_names:
                return
            lookup = [dbf.FieldLookup('name_vector', non_rare_names, 'lookup_any')]
        if addr_tokens:
            lookup.append(dbf.FieldLookup('nameaddress_vector', addr_tokens, 'lookup_all'))
        # Few search terms give a less specific lookup, so penalize it.
        yield penalty + 0.1 * max(0, 5 - len(name_partials) - len(addr_tokens)),\
              min(exp_name_count, exp_addr_count), lookup


    def get_name_ranking(self, trange: TokenRange) -> dbf.FieldRanking:
        """ Create a ranking expression for a name term in the given range.

            Every full-word token of the range becomes a ranking entry
            (sorted by penalty). The default penalty is the sum of the
            penalties of the partial tokens plus 0.2.
        """
        name_fulls = self.query.get_tokens(trange, TokenType.WORD)
        ranks = [dbf.RankedTokens(t.penalty, [t.token]) for t in name_fulls]
        ranks.sort(key=lambda r: r.penalty)
        # Fallback, sum of penalty for partials
        name_partials = self.query.get_partials_list(trange)
        default = sum(t.penalty for t in name_partials) + 0.2
        return dbf.FieldRanking('name_vector', default, ranks)


    def get_addr_ranking(self, trange: TokenRange) -> dbf.FieldRanking:
        """ Create a list of ranking expressions for an address term
            for the given ranges.

            Combinations of partial and full-word tokens covering the range
            are enumerated. Once 10 variants are collected, a worst-case
            fallback is added and the enumeration stops. The variant with
            the fewest full-word tokens defines the default penalty (plus
            0.3). The remaining variants become ranking entries sorted by
            penalty.
        """
        # Heap entries: (-path length, position, penalty, insertion counter,
        # ranked tokens). The penalty and the counter break ties between
        # entries with the same length and position, so that tuple comparison
        # never has to compare the RankedTokens objects themselves. RankedTokens
        # is not known to define an ordering.
        todo: List[Tuple[int, int, float, int, dbf.RankedTokens]] = []
        ranks: List[dbf.RankedTokens] = []
        pushed = 0

        def push(neglen: int, pos: int, rank: dbf.RankedTokens) -> None:
            nonlocal pushed
            pushed += 1
            heapq.heappush(todo, (neglen, pos, rank.penalty, pushed, rank))

        push(0, trange.start, dbf.RankedTokens(0.0, []))

        while todo: # pylint: disable=too-many-nested-blocks
            neglen, pos, _, _, rank = heapq.heappop(todo)
            for tlist in self.query.nodes[pos].starting:
                if tlist.ttype in (TokenType.PARTIAL, TokenType.WORD):
                    if tlist.end < trange.end:
                        chgpenalty = PENALTY_WORDCHANGE[self.query.nodes[tlist.end].btype]
                        if tlist.ttype == TokenType.PARTIAL:
                            penalty = rank.penalty + chgpenalty \
                                      + max(t.penalty for t in tlist.tokens)
                            push(neglen - 1, tlist.end, dbf.RankedTokens(penalty, rank.tokens))
                        else:
                            for t in tlist.tokens:
                                push(neglen - 1, tlist.end, rank.with_token(t, chgpenalty))
                    elif tlist.end == trange.end:
                        if tlist.ttype == TokenType.PARTIAL:
                            ranks.append(dbf.RankedTokens(rank.penalty
                                                          + max(t.penalty for t in tlist.tokens),
                                                          rank.tokens))
                        else:
                            ranks.extend(rank.with_token(t, 0.0) for t in tlist.tokens)
                        if len(ranks) >= 10:
                            # Too many variants, bail out and only add
                            # Worst-case Fallback: sum of penalty of partials
                            name_partials = self.query.get_partials_list(trange)
                            default = sum(t.penalty for t in name_partials) + 0.2
                            ranks.append(dbf.RankedTokens(rank.penalty + default, []))
                            # Bail out of outer loop
                            todo.clear()
                            break

        ranks.sort(key=lambda r: len(r.tokens))
        default = ranks[0].penalty + 0.3
        del ranks[0]
        ranks.sort(key=lambda r: r.penalty)

        return dbf.FieldRanking('nameaddress_vector', default, ranks)


    def get_search_data(self, assignment: TokenAssignment) -> Optional[dbf.SearchData]:
        """ Collect the tokens for the non-name search fields in the
            assignment.

            Returns None when the country terms of the query do not match
            any of the countries requested in the search details.
        """
        sdata = dbf.SearchData()
        sdata.penalty = assignment.penalty
        if assignment.country:
            tokens = self.query.get_tokens(assignment.country, TokenType.COUNTRY)
            if self.details.countries:
                tokens = [t for t in tokens if t.lookup_word in self.details.countries]
                if not tokens:
                    return None
            sdata.set_strings('countries', tokens)
        elif self.details.countries:
            sdata.countries = dbf.WeightedStrings(self.details.countries,
                                                  [0.0] * len(self.details.countries))
        if assignment.housenumber:
            sdata.set_strings('housenumbers',
                              self.query.get_tokens(assignment.housenumber,
                                                    TokenType.HOUSENUMBER))
        if assignment.postcode:
            sdata.set_strings('postcodes',
                              self.query.get_tokens(assignment.postcode,
                                                    TokenType.POSTCODE))
        if assignment.qualifier:
            sdata.set_qualifiers(self.query.get_tokens(assignment.qualifier,
                                                       TokenType.QUALIFIER))

        if assignment.address:
            sdata.set_ranking([self.get_addr_ranking(r) for r in assignment.address])
        else:
            sdata.rankings = []

        return sdata


    def get_search_categories(self,
                              assignment: TokenAssignment) -> Optional[dbf.WeightedCategories]:
        """ Collect tokens for category search or use the categories
            requested per parameter.
            Returns None if no category search is requested. When the query
            has a category term, only categories that are also allowed by
            the search details are kept, so the result may be empty.
        """
        if assignment.category:
            tokens = [t for t in self.query.get_tokens(assignment.category,
                                                       TokenType.CATEGORY)
                      if not self.details.categories
                         or t.get_category() in self.details.categories]
            return dbf.WeightedCategories([t.get_category() for t in tokens],
                                          [t.penalty for t in tokens])

        if self.details.categories:
            return dbf.WeightedCategories(self.details.categories,
                                          [0.0] * len(self.details.categories))

        return None
364
365
# Penalty added in get_addr_ranking() when a token sequence for an address
# term continues across a break of the given type. Start, end and phrase
# breaks cost nothing. Breaks inside a word become increasingly expensive.
PENALTY_WORDCHANGE = {
    **{btype: 0.0 for btype in (BreakType.START, BreakType.END, BreakType.PHRASE)},
    BreakType.WORD: 0.1,
    BreakType.PART: 0.2,
    BreakType.TOKEN: 0.4,
}