]> git.openstreetmap.org Git - nominatim.git/blob - nominatim/api/search/db_search_builder.py
move search to bind parameters
[nominatim.git] / nominatim / api / search / db_search_builder.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2023 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Conversion from token assignment to an abstract DB search.
9 """
10 from typing import Optional, List, Tuple, Iterator
11 import heapq
12
13 from nominatim.api.types import SearchDetails, DataLayer
14 from nominatim.api.search.query import QueryStruct, Token, TokenType, TokenRange, BreakType
15 from nominatim.api.search.token_assignment import TokenAssignment
16 import nominatim.api.search.db_search_fields as dbf
17 import nominatim.api.search.db_searches as dbs
18 from nominatim.api.logging import log
19
20
def wrap_near_search(categories: List[Tuple[str, str]],
                     search: dbs.AbstractSearch) -> dbs.NearSearch:
    """ Wrap an existing search into a near search over the given
        categories.

        The new search takes over the penalty of the wrapped search.
        Every category is weighted with a penalty of 0.0.
    """
    zero_penalties = [0.0] * len(categories)
    weighted = dbf.WeightedCategories(categories, zero_penalties)

    return dbs.NearSearch(penalty=search.penalty, categories=weighted, search=search)
30
31
def build_poi_search(category: List[Tuple[str, str]],
                     countries: Optional[List[str]]) -> dbs.PoiSearch:
    """ Build a POI search over the given categories.

        When a non-empty list of countries is given, it is handed on to
        the search as country constraint, otherwise the country list stays
        empty. All categories and countries get a penalty of 0.0.
    """
    country_list = countries if countries else []
    country_weights = dbf.WeightedStrings(country_list, [0.0] * len(country_list))

    # The search data is set up through class attributes of a local subclass.
    class _PoiData(dbf.SearchData):
        penalty = 0.0
        qualifiers = dbf.WeightedCategories(category, [0.0] * len(category))
        countries = country_weights

    return dbs.PoiSearch(_PoiData())
48
49
50 class SearchBuilder:
51     """ Build the abstract search queries from token assignments.
52     """
53
54     def __init__(self, query: QueryStruct, details: SearchDetails) -> None:
55         self.query = query
56         self.details = details
57
58
59     @property
60     def configured_for_country(self) -> bool:
61         """ Return true if the search details are configured to
62             allow countries in the result.
63         """
64         return self.details.min_rank <= 4 and self.details.max_rank >= 4 \
65                and self.details.layer_enabled(DataLayer.ADDRESS)
66
67
68     @property
69     def configured_for_postcode(self) -> bool:
70         """ Return true if the search details are configured to
71             allow postcodes in the result.
72         """
73         return self.details.min_rank <= 5 and self.details.max_rank >= 11\
74                and self.details.layer_enabled(DataLayer.ADDRESS)
75
76
77     @property
78     def configured_for_housenumbers(self) -> bool:
79         """ Return true if the search details are configured to
80             allow addresses in the result.
81         """
82         return self.details.max_rank >= 30 \
83                and self.details.layer_enabled(DataLayer.ADDRESS)
84
85
86     def build(self, assignment: TokenAssignment) -> Iterator[dbs.AbstractSearch]:
87         """ Yield all possible abstract searches for the given token assignment.
88         """
89         sdata = self.get_search_data(assignment)
90         if sdata is None:
91             return
92
93         categories = self.get_search_categories(assignment)
94
95         if assignment.name is None:
96             if categories and not sdata.postcodes:
97                 sdata.qualifiers = categories
98                 categories = None
99                 builder = self.build_poi_search(sdata)
100             elif assignment.housenumber:
101                 hnr_tokens = self.query.get_tokens(assignment.housenumber,
102                                                    TokenType.HOUSENUMBER)
103                 builder = self.build_housenumber_search(sdata, hnr_tokens, assignment.address)
104             else:
105                 builder = self.build_special_search(sdata, assignment.address,
106                                                     bool(categories))
107         else:
108             builder = self.build_name_search(sdata, assignment.name, assignment.address,
109                                              bool(categories))
110
111         if categories:
112             penalty = min(categories.penalties)
113             categories.penalties = [p - penalty for p in categories.penalties]
114             for search in builder:
115                 yield dbs.NearSearch(penalty, categories, search)
116         else:
117             yield from builder
118
119
120     def build_poi_search(self, sdata: dbf.SearchData) -> Iterator[dbs.AbstractSearch]:
121         """ Build abstract search query for a simple category search.
122             This kind of search requires an additional geographic constraint.
123         """
124         if not sdata.housenumbers \
125            and ((self.details.viewbox and self.details.bounded_viewbox) or self.details.near):
126             yield dbs.PoiSearch(sdata)
127
128
129     def build_special_search(self, sdata: dbf.SearchData,
130                              address: List[TokenRange],
131                              is_category: bool) -> Iterator[dbs.AbstractSearch]:
132         """ Build abstract search queries for searches that do not involve
133             a named place.
134         """
135         if sdata.qualifiers:
136             # No special searches over qualifiers supported.
137             return
138
139         if sdata.countries and not address and not sdata.postcodes \
140            and self.configured_for_country:
141             yield dbs.CountrySearch(sdata)
142
143         if sdata.postcodes and (is_category or self.configured_for_postcode):
144             penalty = 0.0 if sdata.countries else 0.1
145             if address:
146                 sdata.lookups = [dbf.FieldLookup('nameaddress_vector',
147                                                  [t.token for r in address
148                                                   for t in self.query.get_partials_list(r)],
149                                                  'restrict')]
150                 penalty += 0.2
151             yield dbs.PostcodeSearch(penalty, sdata)
152
153
154     def build_housenumber_search(self, sdata: dbf.SearchData, hnrs: List[Token],
155                                  address: List[TokenRange]) -> Iterator[dbs.AbstractSearch]:
156         """ Build a simple address search for special entries where the
157             housenumber is the main name token.
158         """
159         partial_tokens: List[int] = []
160         for trange in address:
161             partial_tokens.extend(t.token for t in self.query.get_partials_list(trange))
162
163         sdata.lookups = [dbf.FieldLookup('name_vector', [t.token for t in hnrs], 'lookup_any'),
164                          dbf.FieldLookup('nameaddress_vector', partial_tokens, 'lookup_all')
165                         ]
166         yield dbs.PlaceSearch(0.05, sdata, sum(t.count for t in hnrs))
167
168
169     def build_name_search(self, sdata: dbf.SearchData,
170                           name: TokenRange, address: List[TokenRange],
171                           is_category: bool) -> Iterator[dbs.AbstractSearch]:
172         """ Build abstract search queries for simple name or address searches.
173         """
174         if is_category or not sdata.housenumbers or self.configured_for_housenumbers:
175             ranking = self.get_name_ranking(name)
176             name_penalty = ranking.normalize_penalty()
177             if ranking.rankings:
178                 sdata.rankings.append(ranking)
179             for penalty, count, lookup in self.yield_lookups(name, address):
180                 sdata.lookups = lookup
181                 yield dbs.PlaceSearch(penalty + name_penalty, sdata, count)
182
183
184     def yield_lookups(self, name: TokenRange, address: List[TokenRange])\
185                           -> Iterator[Tuple[float, int, List[dbf.FieldLookup]]]:
186         """ Yield all variants how the given name and address should best
187             be searched for. This takes into account how frequent the terms
188             are and tries to find a lookup that optimizes index use.
189         """
190         penalty = 0.0 # extra penalty currently unused
191
192         name_partials = self.query.get_partials_list(name)
193         exp_name_count = min(t.count for t in name_partials)
194         addr_partials = []
195         for trange in address:
196             addr_partials.extend(self.query.get_partials_list(trange))
197         addr_tokens = [t.token for t in addr_partials]
198         partials_indexed = all(t.is_indexed for t in name_partials) \
199                            and all(t.is_indexed for t in addr_partials)
200
201         if (len(name_partials) > 3 or exp_name_count < 1000) and partials_indexed:
202             # Lookup by name partials, use address partials to restrict results.
203             lookup = [dbf.FieldLookup('name_vector',
204                                   [t.token for t in name_partials], 'lookup_all')]
205             if addr_tokens:
206                 lookup.append(dbf.FieldLookup('nameaddress_vector', addr_tokens, 'restrict'))
207             yield penalty, exp_name_count, lookup
208             return
209
210         exp_addr_count = min(t.count for t in addr_partials) if addr_partials else exp_name_count
211         if exp_addr_count < 1000 and partials_indexed:
212             # Lookup by address partials and restrict results through name terms.
213             yield penalty, exp_addr_count,\
214                   [dbf.FieldLookup('name_vector', [t.token for t in name_partials], 'restrict'),
215                    dbf.FieldLookup('nameaddress_vector', addr_tokens, 'lookup_all')]
216             return
217
218         # Partial term to frequent. Try looking up by rare full names first.
219         name_fulls = self.query.get_tokens(name, TokenType.WORD)
220         rare_names = list(filter(lambda t: t.count < 1000, name_fulls))
221         # At this point drop unindexed partials from the address.
222         # This might yield wrong results, nothing we can do about that.
223         if not partials_indexed:
224             addr_tokens = [t.token for t in addr_partials if t.is_indexed]
225             log().var_dump('before', penalty)
226             penalty += 1.2 * sum(t.penalty for t in addr_partials if not t.is_indexed)
227             log().var_dump('after', penalty)
228         if rare_names:
229             # Any of the full names applies with all of the partials from the address
230             lookup = [dbf.FieldLookup('name_vector', [t.token for t in rare_names], 'lookup_any')]
231             if addr_tokens:
232                 lookup.append(dbf.FieldLookup('nameaddress_vector', addr_tokens, 'restrict'))
233             yield penalty, sum(t.count for t in rare_names), lookup
234
235         # To catch remaining results, lookup by name and address
236         if all(t.is_indexed for t in name_partials):
237             lookup = [dbf.FieldLookup('name_vector',
238                                       [t.token for t in name_partials], 'lookup_all')]
239         else:
240             # we don't have the partials, try with the non-rare names
241             non_rare_names = [t.token for t in name_fulls if t.count >= 1000]
242             if not non_rare_names:
243                 return
244             lookup = [dbf.FieldLookup('name_vector', non_rare_names, 'lookup_any')]
245         if addr_tokens:
246             lookup.append(dbf.FieldLookup('nameaddress_vector', addr_tokens, 'lookup_all'))
247         yield penalty + 0.1 * max(0, 5 - len(name_partials) - len(addr_tokens)),\
248               min(exp_name_count, exp_addr_count), lookup
249
250
251     def get_name_ranking(self, trange: TokenRange) -> dbf.FieldRanking:
252         """ Create a ranking expression for a name term in the given range.
253         """
254         name_fulls = self.query.get_tokens(trange, TokenType.WORD)
255         ranks = [dbf.RankedTokens(t.penalty, [t.token]) for t in name_fulls]
256         ranks.sort(key=lambda r: r.penalty)
257         # Fallback, sum of penalty for partials
258         name_partials = self.query.get_partials_list(trange)
259         default = sum(t.penalty for t in name_partials) + 0.2
260         return dbf.FieldRanking('name_vector', default, ranks)
261
262
    def get_addr_ranking(self, trange: TokenRange) -> dbf.FieldRanking:
        """ Create a list of ranking expressions for an address term
            for the given ranges.

            The function follows the PARTIAL and WORD token lists of the
            query from trange.start to trange.end and collects one ranking
            for every combination that ends exactly at trange.end. The
            collected ranking with the fewest tokens supplies the default
            penalty (plus 0.3) of the returned field ranking and is removed
            from the list, the remaining ones are sorted by penalty.
        """
        # Heap entries are (negated number of token lists used so far,
        # node position, ranking collected up to that position). heapq pops
        # the smallest tuple first, i.e. the entry with the most token lists.
        # NOTE(review): entries that tie on the first two fields are compared
        # by their RankedTokens - presumably RankedTokens is orderable, verify.
        todo: List[Tuple[int, int, dbf.RankedTokens]] = []
        heapq.heappush(todo, (0, trange.start, dbf.RankedTokens(0.0, [])))
        ranks: List[dbf.RankedTokens] = []

        while todo: # pylint: disable=too-many-nested-blocks
            neglen, pos, rank = heapq.heappop(todo)
            for tlist in self.query.nodes[pos].starting:
                if tlist.ttype in (TokenType.PARTIAL, TokenType.WORD):
                    if tlist.end < trange.end:
                        # The token list ends inside the range: continue from
                        # its end node and add the penalty for the break type
                        # found at that node.
                        chgpenalty = PENALTY_WORDCHANGE[self.query.nodes[tlist.end].btype]
                        if tlist.ttype == TokenType.PARTIAL:
                            # Partials only add penalty (the largest of the
                            # list), the tokens of the ranking stay unchanged.
                            penalty = rank.penalty + chgpenalty \
                                      + max(t.penalty for t in tlist.tokens)
                            heapq.heappush(todo, (neglen - 1, tlist.end,
                                                  dbf.RankedTokens(penalty, rank.tokens)))
                        else:
                            # Full words create one new ranking per token.
                            for t in tlist.tokens:
                                heapq.heappush(todo, (neglen - 1, tlist.end,
                                                      rank.with_token(t, chgpenalty)))
                    elif tlist.end == trange.end:
                        # The token list completes the range: save the result.
                        if tlist.ttype == TokenType.PARTIAL:
                            ranks.append(dbf.RankedTokens(rank.penalty
                                                          + max(t.penalty for t in tlist.tokens),
                                                          rank.tokens))
                        else:
                            ranks.extend(rank.with_token(t, 0.0) for t in tlist.tokens)
                        if len(ranks) >= 10:
                            # Too many variants, bail out and only add
                            # Worst-case Fallback: sum of penalty of partials
                            name_partials = self.query.get_partials_list(trange)
                            default = sum(t.penalty for t in name_partials) + 0.2
                            ranks.append(dbf.RankedTokens(rank.penalty + default, []))
                            # Bail out of outer loop
                            todo.clear()
                            break

        # The ranking with the fewest tokens becomes the default.
        # NOTE(review): ranks[0] raises an IndexError when no combination
        # reached trange.end - presumably that cannot happen, confirm.
        ranks.sort(key=lambda r: len(r.tokens))
        default = ranks[0].penalty + 0.3
        del ranks[0]
        ranks.sort(key=lambda r: r.penalty)

        return dbf.FieldRanking('nameaddress_vector', default, ranks)
309
310
311     def get_search_data(self, assignment: TokenAssignment) -> Optional[dbf.SearchData]:
312         """ Collect the tokens for the non-name search fields in the
313             assignment.
314         """
315         sdata = dbf.SearchData()
316         sdata.penalty = assignment.penalty
317         if assignment.country:
318             tokens = self.query.get_tokens(assignment.country, TokenType.COUNTRY)
319             if self.details.countries:
320                 tokens = [t for t in tokens if t.lookup_word in self.details.countries]
321                 if not tokens:
322                     return None
323             sdata.set_strings('countries', tokens)
324         elif self.details.countries:
325             sdata.countries = dbf.WeightedStrings(self.details.countries,
326                                                   [0.0] * len(self.details.countries))
327         if assignment.housenumber:
328             sdata.set_strings('housenumbers',
329                               self.query.get_tokens(assignment.housenumber,
330                                                     TokenType.HOUSENUMBER))
331         if assignment.postcode:
332             sdata.set_strings('postcodes',
333                               self.query.get_tokens(assignment.postcode,
334                                                     TokenType.POSTCODE))
335         if assignment.qualifier:
336             sdata.set_qualifiers(self.query.get_tokens(assignment.qualifier,
337                                                        TokenType.QUALIFIER))
338
339         if assignment.address:
340             sdata.set_ranking([self.get_addr_ranking(r) for r in assignment.address])
341         else:
342             sdata.rankings = []
343
344         return sdata
345
346
347     def get_search_categories(self,
348                               assignment: TokenAssignment) -> Optional[dbf.WeightedCategories]:
349         """ Collect tokens for category search or use the categories
350             requested per parameter.
351             Returns None if no category search is requested.
352         """
353         if assignment.category:
354             tokens = [t for t in self.query.get_tokens(assignment.category,
355                                                        TokenType.CATEGORY)
356                       if not self.details.categories
357                          or t.get_category() in self.details.categories]
358             return dbf.WeightedCategories([t.get_category() for t in tokens],
359                                           [t.penalty for t in tokens])
360
361         if self.details.categories:
362             return dbf.WeightedCategories(self.details.categories,
363                                           [0.0] * len(self.details.categories))
364
365         return None
366
367
# Penalty per break type. SearchBuilder.get_addr_ranking() adds the value
# for the break type of the node where a token list ends, whenever that
# node still lies before the end of the address range.
PENALTY_WORDCHANGE = {
    BreakType.START: 0.0,
    BreakType.END: 0.0,
    BreakType.PHRASE: 0.0,
    BreakType.WORD: 0.1,
    BreakType.PART: 0.2,
    BreakType.TOKEN: 0.4
}