# SPDX-License-Identifier: GPL-3.0-or-later
#
# This file is part of Nominatim. (https://nominatim.org)
#
# Copyright (C) 2023 by the Nominatim developer community.
# For a full list of authors see the git log.
"""
Conversion from token assignment to an abstract DB search.
"""
from typing import Optional, List, Tuple, Iterator
import heapq

from nominatim.api.types import SearchDetails, DataLayer
from nominatim.api.search.query import QueryStruct, TokenType, TokenRange, BreakType
from nominatim.api.search.token_assignment import TokenAssignment
import nominatim.api.search.db_search_fields as dbf
import nominatim.api.search.db_searches as dbs
from nominatim.api.logging import log


def wrap_near_search(categories: List[Tuple[str, str]],
                     search: dbs.AbstractSearch) -> dbs.NearSearch:
    """ Create a new search that wraps the given search in a search
        for near places of the given category.
    """
    return dbs.NearSearch(penalty=search.penalty,
                          categories=dbf.WeightedCategories(categories,
                                                            [0.0] * len(categories)),
                          search=search)

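# A minimal usage sketch (illustrative only; `base_search` stands in for any
# previously constructed dbs.AbstractSearch and the category pair is
# hypothetical):
#
#   near = wrap_near_search([('amenity', 'restaurant')], base_search)
#
# The wrapper inherits the penalty of the inner search and gives all
# categories an equal weight of 0.0.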

def build_poi_search(category: List[Tuple[str, str]],
                     countries: Optional[List[str]]) -> dbs.PoiSearch:
    """ Create a new search for places of the given category, possibly
        constrained to the given countries.
    """
    if countries:
        ccs = dbf.WeightedStrings(countries, [0.0] * len(countries))
    else:
        ccs = dbf.WeightedStrings([], [])

    class _PoiData(dbf.SearchData):
        penalty = 0.0
        qualifiers = dbf.WeightedCategories(category, [0.0] * len(category))
        countries = ccs

    return dbs.PoiSearch(_PoiData())

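# A minimal usage sketch (the category and country values are hypothetical):
#
#   search = build_poi_search([('amenity', 'cafe')], countries=['de', 'at'])
#
# Both the categories and the country codes enter with a neutral weight of
# 0.0, so they constrain the search without adding any penalty.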

class SearchBuilder:
    """ Build the abstract search queries from token assignments.
    """

    def __init__(self, query: QueryStruct, details: SearchDetails) -> None:
        self.query = query
        self.details = details


    @property
    def configured_for_country(self) -> bool:
        """ Return true if the search details are configured to
            allow countries in the result.
        """
        return self.details.min_rank <= 4 and self.details.max_rank >= 4 \
               and self.details.layer_enabled(DataLayer.ADDRESS)


    @property
    def configured_for_postcode(self) -> bool:
        """ Return true if the search details are configured to
            allow postcodes in the result.
        """
        return self.details.min_rank <= 5 and self.details.max_rank >= 11 \
               and self.details.layer_enabled(DataLayer.ADDRESS)


    @property
    def configured_for_housenumbers(self) -> bool:
        """ Return true if the search details are configured to
            allow addresses in the result.
        """
        return self.details.max_rank >= 30 \
               and self.details.layer_enabled(DataLayer.ADDRESS)

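    # Note on the rank constants above: they refer to Nominatim's address
    # rank scale, where rank 4 corresponds to countries, ranks 5-11 to
    # postcode zones and rank 30 to objects with house-number precision.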

    def build(self, assignment: TokenAssignment) -> Iterator[dbs.AbstractSearch]:
        """ Yield all possible abstract searches for the given token assignment.
        """
        sdata = self.get_search_data(assignment)
        if sdata is None:
            return

        categories = self.get_search_categories(assignment)

        if assignment.name is None:
            if categories and not sdata.postcodes:
                sdata.qualifiers = categories
                categories = None
                builder = self.build_poi_search(sdata)
            else:
                builder = self.build_special_search(sdata, assignment.address,
                                                    bool(categories))
        else:
            builder = self.build_name_search(sdata, assignment.name, assignment.address,
                                             bool(categories))

        if categories:
            penalty = min(categories.penalties)
            categories.penalties = [p - penalty for p in categories.penalties]
            for search in builder:
                yield dbs.NearSearch(penalty, categories, search)
        else:
            yield from builder

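    # A minimal driving sketch (illustrative only; producing the QueryStruct
    # and TokenAssignments requires the query analyzer and is not shown):
    #
    #   builder = SearchBuilder(query, details)
    #   searches = [s for a in assignments for s in builder.build(a)]
    #   searches.sort(key=lambda s: s.penalty)
    #
    # Sorting by penalty before running the searches is an assumption of this
    # sketch, not something build() does itself.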

    def build_poi_search(self, sdata: dbf.SearchData) -> Iterator[dbs.AbstractSearch]:
        """ Build an abstract search query for a simple category search.
            This kind of search requires an additional geographic constraint.
        """
        if not sdata.housenumbers \
           and ((self.details.viewbox and self.details.bounded_viewbox) or self.details.near):
            yield dbs.PoiSearch(sdata)


    def build_special_search(self, sdata: dbf.SearchData,
                             address: List[TokenRange],
                             is_category: bool) -> Iterator[dbs.AbstractSearch]:
        """ Build abstract search queries for searches that do not involve
            a named place.
        """
        if sdata.qualifiers or sdata.housenumbers:
            # No special searches over house numbers or qualifiers supported.
            return

        if sdata.countries and not address and not sdata.postcodes \
           and self.configured_for_country:
            yield dbs.CountrySearch(sdata)

        if sdata.postcodes and (is_category or self.configured_for_postcode):
            if address:
                sdata.lookups = [dbf.FieldLookup('nameaddress_vector',
                                                 [t.token for r in address
                                                  for t in self.query.get_partials_list(r)],
                                                 'restrict')]
            yield dbs.PostcodeSearch(0.4, sdata)


    def build_name_search(self, sdata: dbf.SearchData,
                          name: TokenRange, address: List[TokenRange],
                          is_category: bool) -> Iterator[dbs.AbstractSearch]:
        """ Build abstract search queries for simple name or address searches.
        """
        if is_category or not sdata.housenumbers or self.configured_for_housenumbers:
            sdata.rankings.append(self.get_name_ranking(name))
            name_penalty = sdata.rankings[-1].normalize_penalty()
            for penalty, count, lookup in self.yield_lookups(name, address):
                sdata.lookups = lookup
                yield dbs.PlaceSearch(penalty + name_penalty, sdata, count)

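    # The penalty of each PlaceSearch above combines two parts: the lookup
    # penalty chosen by yield_lookups() and the normalized base penalty of
    # the name ranking. With a lookup penalty of 0.1 and a normalized name
    # penalty of 0.2 (hypothetical numbers), the search would start out
    # at 0.3.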

    def yield_lookups(self, name: TokenRange, address: List[TokenRange]) \
                          -> Iterator[Tuple[float, int, List[dbf.FieldLookup]]]:
        """ Yield all variants of how the given name and address should best
            be searched for. This takes into account how frequent the terms
            are and tries to find a lookup that optimizes index use.
        """
        penalty = 0.0  # extra penalty currently unused

        name_partials = self.query.get_partials_list(name)
        exp_name_count = min(t.count for t in name_partials)
        addr_partials = []
        for trange in address:
            addr_partials.extend(self.query.get_partials_list(trange))
        addr_tokens = [t.token for t in addr_partials]
        partials_indexed = all(t.is_indexed for t in name_partials) \
                           and all(t.is_indexed for t in addr_partials)

        if (len(name_partials) > 3 or exp_name_count < 1000) and partials_indexed:
            # Lookup by name partials, use address partials to restrict results.
            lookup = [dbf.FieldLookup('name_vector',
                                      [t.token for t in name_partials], 'lookup_all')]
            if addr_tokens:
                lookup.append(dbf.FieldLookup('nameaddress_vector', addr_tokens, 'restrict'))
            yield penalty, exp_name_count, lookup
            return

        exp_addr_count = min(t.count for t in addr_partials) if addr_partials else exp_name_count
        if exp_addr_count < 1000 and partials_indexed:
            # Lookup by address partials and restrict results through name terms.
            yield penalty, exp_addr_count, \
                  [dbf.FieldLookup('name_vector', [t.token for t in name_partials], 'restrict'),
                   dbf.FieldLookup('nameaddress_vector', addr_tokens, 'lookup_all')]
            return

        # Partial terms are too frequent. Try looking up by rare full names first.
        name_fulls = self.query.get_tokens(name, TokenType.WORD)
        rare_names = list(filter(lambda t: t.count < 1000, name_fulls))
        # At this point drop unindexed partials from the address.
        # This might yield wrong results, nothing we can do about that.
        if not partials_indexed:
            addr_tokens = [t.token for t in addr_partials if t.is_indexed]
            log().var_dump('before', penalty)
            penalty += 1.2 * sum(t.penalty for t in addr_partials if not t.is_indexed)
            log().var_dump('after', penalty)
        if rare_names:
            # Any of the full names applies with all of the partials from the address.
            lookup = [dbf.FieldLookup('name_vector', [t.token for t in rare_names], 'lookup_any')]
            if addr_tokens:
                lookup.append(dbf.FieldLookup('nameaddress_vector', addr_tokens, 'restrict'))
            yield penalty, sum(t.count for t in rare_names), lookup

        # To catch the remaining results, lookup by name and address.
        if all(t.is_indexed for t in name_partials):
            lookup = [dbf.FieldLookup('name_vector',
                                      [t.token for t in name_partials], 'lookup_all')]
        else:
            # We don't have the partials, try with the non-rare names.
            non_rare_names = [t.token for t in name_fulls if t.count >= 1000]
            if not non_rare_names:
                return
            lookup = [dbf.FieldLookup('name_vector', non_rare_names, 'lookup_any')]
        if addr_tokens:
            lookup.append(dbf.FieldLookup('nameaddress_vector', addr_tokens, 'lookup_all'))
        yield penalty + 0.1 * max(0, 5 - len(name_partials) - len(addr_tokens)), \
              min(exp_name_count, exp_addr_count), lookup

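    # A worked example of the strategy above (all token counts hypothetical):
    # for "main street, springfield" with name partials
    # {'main': 250000, 'street': 800000} and address partial
    # {'springfield': 700}, the first branch is skipped because the rarest
    # name partial is still frequent, but the second branch fires and yields
    # a lookup that scans 'nameaddress_vector' for the rare address term
    # while only restricting on 'name_vector'.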

    def get_name_ranking(self, trange: TokenRange) -> dbf.FieldRanking:
        """ Create a ranking expression for a name term in the given range.
        """
        name_fulls = self.query.get_tokens(trange, TokenType.WORD)
        ranks = [dbf.RankedTokens(t.penalty, [t.token]) for t in name_fulls]
        ranks.sort(key=lambda r: r.penalty)
        # Fallback: sum of the penalties of the partials.
        name_partials = self.query.get_partials_list(trange)
        default = sum(t.penalty for t in name_partials) + 0.2
        return dbf.FieldRanking('name_vector', default, ranks)

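    # For example (hypothetical tokens), a range covering "rose cottage"
    # would produce one RankedTokens entry per matching full-word token,
    # sorted by penalty, while the default of sum of partial penalties + 0.2
    # applies whenever none of the full names matches.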

    def get_addr_ranking(self, trange: TokenRange) -> dbf.FieldRanking:
        """ Create a ranking expression for an address term in the
            given range.
        """
        todo: List[Tuple[int, int, dbf.RankedTokens]] = []
        heapq.heappush(todo, (0, trange.start, dbf.RankedTokens(0.0, [])))
        ranks: List[dbf.RankedTokens] = []

        # The heap is keyed on the negated number of tokens consumed so far,
        # so chains that already cover more of the range are expanded first.
        while todo: # pylint: disable=too-many-nested-blocks
            neglen, pos, rank = heapq.heappop(todo)
            for tlist in self.query.nodes[pos].starting:
                if tlist.ttype in (TokenType.PARTIAL, TokenType.WORD):
                    if tlist.end < trange.end:
                        chgpenalty = PENALTY_WORDCHANGE[self.query.nodes[tlist.end].btype]
                        if tlist.ttype == TokenType.PARTIAL:
                            penalty = rank.penalty + chgpenalty \
                                      + max(t.penalty for t in tlist.tokens)
                            heapq.heappush(todo, (neglen - 1, tlist.end,
                                                  dbf.RankedTokens(penalty, rank.tokens)))
                        else:
                            for t in tlist.tokens:
                                heapq.heappush(todo, (neglen - 1, tlist.end,
                                                      rank.with_token(t, chgpenalty)))
                    elif tlist.end == trange.end:
                        if tlist.ttype == TokenType.PARTIAL:
                            ranks.append(dbf.RankedTokens(rank.penalty
                                                          + max(t.penalty for t in tlist.tokens),
                                                          rank.tokens))
                        else:
                            ranks.extend(rank.with_token(t, 0.0) for t in tlist.tokens)
                        if len(ranks) >= 10:
                            # Too many variants, bail out and only add the
                            # worst-case fallback: the sum of the partial penalties.
                            name_partials = self.query.get_partials_list(trange)
                            default = sum(t.penalty for t in name_partials) + 0.2
                            ranks.append(dbf.RankedTokens(rank.penalty + default, []))
                            # Bail out of the outer loop.
                            todo.clear()
                            break

        ranks.sort(key=lambda r: len(r.tokens))
        default = ranks[0].penalty + 0.3
        del ranks[0]
        ranks.sort(key=lambda r: r.penalty)

        return dbf.FieldRanking('nameaddress_vector', default, ranks)


    def get_search_data(self, assignment: TokenAssignment) -> Optional[dbf.SearchData]:
        """ Collect the tokens for the non-name search fields in the
            assignment.
        """
        sdata = dbf.SearchData()
        sdata.penalty = assignment.penalty
        if assignment.country:
            tokens = self.query.get_tokens(assignment.country, TokenType.COUNTRY)
            if self.details.countries:
                tokens = [t for t in tokens if t.lookup_word in self.details.countries]
                if not tokens:
                    return None
            sdata.set_strings('countries', tokens)
        elif self.details.countries:
            sdata.countries = dbf.WeightedStrings(self.details.countries,
                                                  [0.0] * len(self.details.countries))
        if assignment.housenumber:
            sdata.set_strings('housenumbers',
                              self.query.get_tokens(assignment.housenumber,
                                                    TokenType.HOUSENUMBER))
        if assignment.postcode:
            sdata.set_strings('postcodes',
                              self.query.get_tokens(assignment.postcode,
                                                    TokenType.POSTCODE))
        if assignment.qualifier:
            sdata.set_qualifiers(self.query.get_tokens(assignment.qualifier,
                                                       TokenType.QUALIFIER))

        if assignment.address:
            sdata.set_ranking([self.get_addr_ranking(r) for r in assignment.address])
        else:
            sdata.rankings = []

        return sdata


    def get_search_categories(self,
                              assignment: TokenAssignment) -> Optional[dbf.WeightedCategories]:
        """ Collect tokens for a category search or use the categories
            requested via the search parameters.
            Returns None if no category search is requested.
        """
        if assignment.category:
            tokens = [t for t in self.query.get_tokens(assignment.category,
                                                       TokenType.CATEGORY)
                      if not self.details.categories
                         or t.get_category() in self.details.categories]
            return dbf.WeightedCategories([t.get_category() for t in tokens],
                                          [t.penalty for t in tokens])

        if self.details.categories:
            return dbf.WeightedCategories(self.details.categories,
                                          [0.0] * len(self.details.categories))

        return None


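# Penalty applied when a chain of partial tokens continues across the given
# break type while building address rankings (see get_addr_ranking()).
# Crossing phrase or sentence boundaries is free; breaks inside a word are
# penalized the most.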
PENALTY_WORDCHANGE = {
    BreakType.START: 0.0,
    BreakType.END: 0.0,
    BreakType.PHRASE: 0.0,
    BreakType.WORD: 0.1,
    BreakType.PART: 0.2,
    BreakType.TOKEN: 0.4
}