]> git.openstreetmap.org Git - nominatim.git/blob - nominatim/api/search/db_search_builder.py
avoid forwarding variables via SQL
[nominatim.git] / nominatim / api / search / db_search_builder.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2023 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Conversion from token assignment to an abstract DB search.
9 """
10 from typing import Optional, List, Tuple, Iterator
11 import heapq
12
13 from nominatim.api.types import SearchDetails, DataLayer
14 from nominatim.api.search.query import QueryStruct, Token, TokenType, TokenRange, BreakType
15 from nominatim.api.search.token_assignment import TokenAssignment
16 import nominatim.api.search.db_search_fields as dbf
17 import nominatim.api.search.db_searches as dbs
18 from nominatim.api.logging import log
19
20
def wrap_near_search(categories: List[Tuple[str, str]],
                     search: dbs.AbstractSearch) -> dbs.NearSearch:
    """ Embed the given search into a near search for the given categories.

        The new search takes over the penalty of the wrapped search.
        All categories are weighted equally with a penalty of 0.
    """
    inner_penalty = search.penalty
    zero_penalties = [0.0] * len(categories)
    weighted = dbf.WeightedCategories(categories, zero_penalties)

    return dbs.NearSearch(penalty=inner_penalty, categories=weighted, search=search)
30
31
def build_poi_search(category: List[Tuple[str, str]],
                     countries: Optional[List[str]]) -> dbs.PoiSearch:
    """ Create a new search for places of the given categories,
        optionally constrained to the given countries.

        Categories and countries are all weighted with a penalty of 0.
        When no countries are given (None or an empty list), the
        search carries an empty country restriction.
    """
    country_codes = countries or []
    ccs = dbf.WeightedStrings(country_codes, [0.0] * len(country_codes))
    category_penalties = [0.0] * len(category)

    class _PoiData(dbf.SearchData):
        # Search data is fixed through class attributes.
        penalty = 0.0
        qualifiers = dbf.WeightedCategories(category, category_penalties)
        countries = ccs

    poi_data = _PoiData()
    return dbs.PoiSearch(poi_data)
48
49
class SearchBuilder:
    """ Build the abstract search queries from token assignments.

        A builder is bound to one tokenized query and to the search
        details of the request. `build()` is the entry point, the
        remaining methods create the searches for the different kinds
        of token assignments or collect the data needed for them.
    """

    def __init__(self, query: QueryStruct, details: SearchDetails) -> None:
        self.query = query
        self.details = details


    @property
    def configured_for_country(self) -> bool:
        """ Return true if the search details are configured to
            allow countries in the result.

            This requires that rank 4 lies within the requested rank
            range and that the address layer is enabled.
        """
        return self.details.min_rank <= 4 and self.details.max_rank >= 4 \
               and self.details.layer_enabled(DataLayer.ADDRESS)


    @property
    def configured_for_postcode(self) -> bool:
        """ Return true if the search details are configured to
            allow postcodes in the result.

            This requires that the requested rank range covers the
            ranks 5 to 11 and that the address layer is enabled.
        """
        return self.details.min_rank <= 5 and self.details.max_rank >= 11 \
               and self.details.layer_enabled(DataLayer.ADDRESS)


    @property
    def configured_for_housenumbers(self) -> bool:
        """ Return true if the search details are configured to
            allow addresses in the result.

            This requires a maximum rank of at least 30 and that the
            address layer is enabled.
        """
        return self.details.max_rank >= 30 \
               and self.details.layer_enabled(DataLayer.ADDRESS)


    def build(self, assignment: TokenAssignment) -> Iterator[dbs.AbstractSearch]:
        """ Yield all possible abstract searches for the given token assignment.

            Nothing is yielded when no search data can be created for
            the assignment. When categories remain to be searched for,
            each search is wrapped into a near search.
        """
        sdata = self.get_search_data(assignment)
        if sdata is None:
            return

        categories = self.get_search_categories(assignment)

        if assignment.name is None:
            if categories and not sdata.postcodes:
                # The categories become the main search term. They must
                # not be used for an additional near search afterwards.
                sdata.qualifiers = categories
                categories = None
                builder = self.build_poi_search(sdata)
            elif assignment.housenumber:
                hnr_tokens = self.query.get_tokens(assignment.housenumber,
                                                   TokenType.HOUSENUMBER)
                builder = self.build_housenumber_search(sdata, hnr_tokens, assignment.address)
            else:
                builder = self.build_special_search(sdata, assignment.address,
                                                    bool(categories))
        else:
            builder = self.build_name_search(sdata, assignment.name, assignment.address,
                                             bool(categories))

        if categories:
            # Move the smallest category penalty into the penalty of the
            # near search, so that the best category has a penalty of 0.
            penalty = min(categories.penalties)
            categories.penalties = [p - penalty for p in categories.penalties]
            for search in builder:
                yield dbs.NearSearch(penalty, categories, search)
        else:
            yield from builder


    def build_poi_search(self, sdata: dbf.SearchData) -> Iterator[dbs.AbstractSearch]:
        """ Build abstract search query for a simple category search.
            This kind of search requires an additional geographic constraint.

            A search is only yielded when there are no housenumbers and
            either a bounded viewbox or a near point is set in the
            search details.
        """
        if not sdata.housenumbers \
           and ((self.details.viewbox and self.details.bounded_viewbox) or self.details.near):
            yield dbs.PoiSearch(sdata)


    def build_special_search(self, sdata: dbf.SearchData,
                             address: List[TokenRange],
                             is_category: bool) -> Iterator[dbs.AbstractSearch]:
        """ Build abstract search queries for searches that do not involve
            a named place.

            May yield a country search and/or a postcode search.
            When address terms accompany a postcode, `sdata.lookups` is
            set to restrict the postcode search by the address partials.
        """
        if sdata.qualifiers:
            # No special searches over qualifiers supported.
            return

        if sdata.countries and not address and not sdata.postcodes \
           and self.configured_for_country:
            yield dbs.CountrySearch(sdata)

        if sdata.postcodes and (is_category or self.configured_for_postcode):
            # Postcodes without a country get a small extra penalty.
            penalty = 0.0 if sdata.countries else 0.1
            if address:
                sdata.lookups = [dbf.FieldLookup('nameaddress_vector',
                                                 [t.token for r in address
                                                  for t in self.query.get_partials_list(r)],
                                                 'restrict')]
                penalty += 0.2
            yield dbs.PostcodeSearch(penalty, sdata)


    def build_housenumber_search(self, sdata: dbf.SearchData, hnrs: List[Token],
                                 address: List[TokenRange]) -> Iterator[dbs.AbstractSearch]:
        """ Build a simple address search for special entries where the
            housenumber is the main name token.

            Any of the housenumber tokens must match the name while
            all partial tokens of the address must match the address.
            The expected count is the sum of the housenumber counts.
        """
        partial_tokens: List[int] = [t.token for trange in address
                                     for t in self.query.get_partials_list(trange)]

        sdata.lookups = [dbf.FieldLookup('name_vector', [t.token for t in hnrs], 'lookup_any'),
                         dbf.FieldLookup('nameaddress_vector', partial_tokens, 'lookup_all')
                        ]
        yield dbs.PlaceSearch(0.05, sdata, sum(t.count for t in hnrs))


    def build_name_search(self, sdata: dbf.SearchData,
                          name: TokenRange, address: List[TokenRange],
                          is_category: bool) -> Iterator[dbs.AbstractSearch]:
        """ Build abstract search queries for simple name or address searches.

            Searches with housenumbers are only built when the result
            may contain addresses or when the search is used as part
            of a category search. One place search is yielded for each
            lookup variant produced by `yield_lookups()`.
        """
        if is_category or not sdata.housenumbers or self.configured_for_housenumbers:
            ranking = self.get_name_ranking(name)
            name_penalty = ranking.normalize_penalty()
            if ranking.rankings:
                sdata.rankings.append(ranking)
            for penalty, count, lookup in self.yield_lookups(name, address):
                sdata.lookups = lookup
                yield dbs.PlaceSearch(penalty + name_penalty, sdata, count)


    def yield_lookups(self, name: TokenRange, address: List[TokenRange])\
                          -> Iterator[Tuple[float, int, List[dbf.FieldLookup]]]:
        """ Yield all variants how the given name and address should best
            be searched for. This takes into account how frequent the terms
            are and tries to find a lookup that optimizes index use.

            Each variant is a tuple of extra penalty, expected number
            of results and the list of field lookups to use.
        """
        # Only raised below, when unindexed address partials have to be dropped.
        penalty = 0.0

        name_partials = self.query.get_partials_list(name)
        exp_name_count = min(t.count for t in name_partials)
        addr_partials = []
        for trange in address:
            addr_partials.extend(self.query.get_partials_list(trange))
        addr_tokens = [t.token for t in addr_partials]
        names_indexed = all(t.is_indexed for t in name_partials)
        partials_indexed = names_indexed and all(t.is_indexed for t in addr_partials)

        if (len(name_partials) > 3 or exp_name_count < 1000) and partials_indexed:
            # Lookup by name partials, use address partials to restrict results.
            lookup = [dbf.FieldLookup('name_vector',
                                      [t.token for t in name_partials], 'lookup_all')]
            if addr_tokens:
                lookup.append(dbf.FieldLookup('nameaddress_vector', addr_tokens, 'restrict'))
            yield penalty, exp_name_count, lookup
            return

        exp_addr_count = min(t.count for t in addr_partials) if addr_partials else exp_name_count
        if exp_addr_count < 1000 and partials_indexed:
            # Lookup by address partials and restrict results through name terms.
            # Give this a small penalty because lookups in the address index are
            # more expensive
            yield penalty + exp_addr_count/5000, exp_addr_count,\
                  [dbf.FieldLookup('name_vector', [t.token for t in name_partials], 'restrict'),
                   dbf.FieldLookup('nameaddress_vector', addr_tokens, 'lookup_all')]
            return

        # Partial term too frequent. Try looking up by rare full names first.
        name_fulls = self.query.get_tokens(name, TokenType.WORD)
        rare_names = [t for t in name_fulls if t.count < 1000]
        # At this point drop unindexed partials from the address.
        # This might yield wrong results, nothing we can do about that.
        if not partials_indexed:
            addr_tokens = [t.token for t in addr_partials if t.is_indexed]
            penalty += 1.2 * sum(t.penalty for t in addr_partials if not t.is_indexed)
        if rare_names:
            # Any of the full names applies with all of the partials from the address
            lookup = [dbf.FieldLookup('name_vector', [t.token for t in rare_names], 'lookup_any')]
            if addr_tokens:
                lookup.append(dbf.FieldLookup('nameaddress_vector', addr_tokens, 'restrict'))
            yield penalty, sum(t.count for t in rare_names), lookup

        # To catch remaining results, lookup by name and address
        # We only do this if there is a reasonable number of results expected.
        if min(exp_name_count, exp_addr_count) < 10000:
            if names_indexed:
                lookup = [dbf.FieldLookup('name_vector',
                                          [t.token for t in name_partials], 'lookup_all')]
            else:
                # we don't have the partials, try with the non-rare names
                non_rare_names = [t.token for t in name_fulls if t.count >= 1000]
                if not non_rare_names:
                    return
                lookup = [dbf.FieldLookup('name_vector', non_rare_names, 'lookup_any')]
            if addr_tokens:
                lookup.append(dbf.FieldLookup('nameaddress_vector', addr_tokens, 'lookup_all'))
            # The fewer terms take part in the lookup, the higher the penalty.
            yield penalty + 0.1 * max(0, 5 - len(name_partials) - len(addr_tokens)),\
                  min(exp_name_count, exp_addr_count), lookup


    def get_name_ranking(self, trange: TokenRange) -> dbf.FieldRanking:
        """ Create a ranking expression for a name term in the given range.

            Each full word token becomes a ranking of its own. The
            default penalty is the sum of the penalties of the
            partial tokens plus 0.2.
        """
        name_fulls = self.query.get_tokens(trange, TokenType.WORD)
        ranks = [dbf.RankedTokens(t.penalty, [t.token]) for t in name_fulls]
        ranks.sort(key=lambda r: r.penalty)
        # Fallback, sum of penalty for partials
        name_partials = self.query.get_partials_list(trange)
        default = sum(t.penalty for t in name_partials) + 0.2
        return dbf.FieldRanking('name_vector', default, ranks)


    def get_addr_ranking(self, trange: TokenRange) -> dbf.FieldRanking:
        """ Create a list of ranking expressions for an address term
            for the given ranges.

            The query nodes of the range are walked from start to end,
            following partial and full word tokens. Every way to reach
            the end of the range yields a ranking. The ranking with the
            fewest tokens serves as the default penalty.
        """
        # Heap of (negative number of steps taken, node position, ranking so far).
        # NOTE(review): entries with the same first two values are ordered by
        # comparing the RankedTokens themselves - confirm that this type
        # supports ordering.
        todo: List[Tuple[int, int, dbf.RankedTokens]] = []
        heapq.heappush(todo, (0, trange.start, dbf.RankedTokens(0.0, [])))
        ranks: List[dbf.RankedTokens] = []

        while todo: # pylint: disable=too-many-nested-blocks
            neglen, pos, rank = heapq.heappop(todo)
            for tlist in self.query.nodes[pos].starting:
                if tlist.ttype in (TokenType.PARTIAL, TokenType.WORD):
                    if tlist.end < trange.end:
                        # Not at the end of the range yet, continue from
                        # the end of this token list.
                        chgpenalty = PENALTY_WORDCHANGE[self.query.nodes[tlist.end].btype]
                        if tlist.ttype == TokenType.PARTIAL:
                            # Partials add a penalty but no token.
                            penalty = rank.penalty + chgpenalty \
                                      + max(t.penalty for t in tlist.tokens)
                            heapq.heappush(todo, (neglen - 1, tlist.end,
                                                  dbf.RankedTokens(penalty, rank.tokens)))
                        else:
                            for t in tlist.tokens:
                                heapq.heappush(todo, (neglen - 1, tlist.end,
                                                      rank.with_token(t, chgpenalty)))
                    elif tlist.end == trange.end:
                        # End of range reached, the ranking is complete.
                        if tlist.ttype == TokenType.PARTIAL:
                            ranks.append(dbf.RankedTokens(rank.penalty
                                                          + max(t.penalty for t in tlist.tokens),
                                                          rank.tokens))
                        else:
                            ranks.extend(rank.with_token(t, 0.0) for t in tlist.tokens)
                        if len(ranks) >= 10:
                            # Too many variants, bail out and only add
                            # Worst-case Fallback: sum of penalty of partials
                            name_partials = self.query.get_partials_list(trange)
                            default = sum(t.penalty for t in name_partials) + 0.2
                            ranks.append(dbf.RankedTokens(rank.penalty + default, []))
                            # Bail out of outer loop
                            todo.clear()
                            break

        # The ranking with the fewest tokens provides the default penalty
        # and is removed from the list of explicit rankings.
        ranks.sort(key=lambda r: len(r.tokens))
        default = ranks[0].penalty + 0.3
        del ranks[0]
        ranks.sort(key=lambda r: r.penalty)

        return dbf.FieldRanking('nameaddress_vector', default, ranks)


    def get_search_data(self, assignment: TokenAssignment) -> Optional[dbf.SearchData]:
        """ Collect the tokens for the non-name search fields in the
            assignment.

            Returns None when the assignment contains a country term
            but none of its tokens matches the countries that the
            search details restrict the search to.
        """
        sdata = dbf.SearchData()
        sdata.penalty = assignment.penalty
        if assignment.country:
            tokens = self.query.get_tokens(assignment.country, TokenType.COUNTRY)
            if self.details.countries:
                tokens = [t for t in tokens if t.lookup_word in self.details.countries]
                if not tokens:
                    return None
            sdata.set_strings('countries', tokens)
        elif self.details.countries:
            # No country in the query, use the restriction from the details.
            sdata.countries = dbf.WeightedStrings(self.details.countries,
                                                  [0.0] * len(self.details.countries))
        if assignment.housenumber:
            sdata.set_strings('housenumbers',
                              self.query.get_tokens(assignment.housenumber,
                                                    TokenType.HOUSENUMBER))
        if assignment.postcode:
            sdata.set_strings('postcodes',
                              self.query.get_tokens(assignment.postcode,
                                                    TokenType.POSTCODE))
        if assignment.qualifier:
            sdata.set_qualifiers(self.query.get_tokens(assignment.qualifier,
                                                       TokenType.QUALIFIER))

        if assignment.address:
            sdata.set_ranking([self.get_addr_ranking(r) for r in assignment.address])
        else:
            sdata.rankings = []

        return sdata


    def get_search_categories(self,
                              assignment: TokenAssignment) -> Optional[dbf.WeightedCategories]:
        """ Collect tokens for category search or use the categories
            requested per parameter.
            Returns None if no category search is requested.

            Category tokens from the query are filtered by the
            categories from the search details, when those are set.
        """
        if assignment.category:
            tokens = [t for t in self.query.get_tokens(assignment.category,
                                                       TokenType.CATEGORY)
                      if not self.details.categories
                         or t.get_category() in self.details.categories]
            return dbf.WeightedCategories([t.get_category() for t in tokens],
                                          [t.penalty for t in tokens])

        if self.details.categories:
            return dbf.WeightedCategories(self.details.categories,
                                          [0.0] * len(self.details.categories))

        return None
370
371
# Extra penalty by type of break. get_addr_ranking() adds it to a ranking
# whenever a token list ends at a break of that type before the end of
# the address range is reached.
PENALTY_WORDCHANGE = {
    BreakType.START: 0.0,
    BreakType.END: 0.0,
    BreakType.PHRASE: 0.0,
    BreakType.WORD: 0.1,
    BreakType.PART: 0.2,
    BreakType.TOKEN: 0.4,
}