]> git.openstreetmap.org Git - nominatim.git/blob - src/nominatim_api/search/geocoder.py
work round typing bug in pyosmium 4.0
[nominatim.git] / src / nominatim_api / search / geocoder.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2024 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Public interface to the search code.
9 """
10 from typing import List, Any, Optional, Iterator, Tuple, Dict
11 import itertools
12 import re
13 import datetime as dt
14 import difflib
15
16 from ..connection import SearchConnection
17 from ..types import SearchDetails
18 from ..results import SearchResult, SearchResults, add_result_details
19 from ..logging import log
20 from .token_assignment import yield_token_assignments
21 from .db_search_builder import SearchBuilder, build_poi_search, wrap_near_search
22 from .db_searches import AbstractSearch
23 from .query_analyzer_factory import make_query_analyzer, AbstractQueryAnalyzer
24 from .query import Phrase, QueryStruct
25
class ForwardGeocoder:
    """ Main class responsible for place search.

        A geocoder instance analyses the query with a query analyzer,
        derives a list of abstract searches from it, runs those searches
        against the database and then filters, reranks and cuts down
        the results.
    """

    def __init__(self, conn: SearchConnection,
                 params: SearchDetails, timeout: Optional[int]) -> None:
        self.conn = conn
        self.params = params
        # A missing (or zero) timeout is replaced by a very large value,
        # which effectively disables the time limit.
        self.timeout = dt.timedelta(seconds=timeout or 1000000)
        # Created lazily in build_searches() because setting it up requires
        # awaiting on the database connection.
        self.query_analyzer: Optional[AbstractQueryAnalyzer] = None


    @property
    def limit(self) -> int:
        """ Return the configured maximum number of search results.
        """
        return self.params.max_results


    async def build_searches(self,
                             phrases: List[Phrase]) -> Tuple[QueryStruct, List[AbstractSearch]]:
        """ Analyse the query and return the tokenized query and list of
            possible searches over it.

            The searches are sorted by penalty, ties being broken by their
            SEARCH_PRIO. The list is empty when the analysed query has no
            token slots.
        """
        if self.query_analyzer is None:
            self.query_analyzer = await make_query_analyzer(self.conn)

        query = await self.query_analyzer.analyze_query(phrases)

        searches: List[AbstractSearch] = []
        if query.num_token_slots() > 0:
            # Compute all possible search interpretations.
            log().section('Compute abstract searches')
            search_builder = SearchBuilder(query, self.params)
            num_searches = 0
            for assignment in yield_token_assignments(query):
                searches.extend(search_builder.build(assignment))
                # Only dump the searches that this assignment added.
                if num_searches < len(searches):
                    log().table_dump('Searches for assignment',
                                     _dump_searches(searches, query, num_searches))
                num_searches = len(searches)
            searches.sort(key=lambda s: (s.penalty, s.SEARCH_PRIO))

        return query, searches


    async def execute_searches(self, query: QueryStruct,
                               searches: List[AbstractSearch]) -> SearchResults:
        """ Run the abstract searches against the database until a result
            is found.

            `searches` should be ordered by ascending penalty, as returned
            by build_searches(). Execution stops at the first search whose
            penalty is higher than the previous one and that either exceeds
            the current ranking threshold or comes after the 21st search.
            It also stops once the configured timeout has elapsed.
            Results returned by several searches are merged, keeping the
            lower accuracy value.

            An empty list of searches yields an empty result.
        """
        log().section('Execute database searches')
        results: Dict[Any, SearchResult] = {}

        # Guard against an empty search list: the threshold computation
        # below needs at least one search.
        if not searches:
            return SearchResults()

        end_time = dt.datetime.now() + self.timeout

        # Searches are cut off once their penalty exceeds this threshold.
        # It is tightened as soon as well-matching results have been found.
        min_ranking = searches[0].penalty + 2.0
        prev_penalty = 0.0
        for i, search in enumerate(searches):
            # Searches with the same penalty as the previous one are always
            # run, so that equally ranked interpretations are not cut off.
            if search.penalty > prev_penalty and (search.penalty > min_ranking or i > 20):
                break
            log().table_dump(f"{i + 1}. Search", _dump_searches([search], query))
            log().var_dump('Params', self.params)
            lookup_results = await search.lookup(self.conn, self.params)
            for result in lookup_results:
                # The same place may be found by several searches.
                rhash = (result.source_table, result.place_id,
                         result.housenumber, result.country_code)
                prevresult = results.get(rhash)
                if prevresult is not None:
                    prevresult.accuracy = min(prevresult.accuracy, result.accuracy)
                else:
                    results[rhash] = result
                min_ranking = min(min_ranking, result.accuracy * 1.2, 2.0)
            log().result_dump('Results', ((r.accuracy, r) for r in lookup_results))
            prev_penalty = search.penalty
            if dt.datetime.now() >= end_time:
                break

        return SearchResults(results.values())


    def pre_filter_results(self, results: SearchResults) -> SearchResults:
        """ Remove results that are significantly worse than the
            best match.

            Only results whose ranking is less than 0.5 above the best
            ranking are kept.
        """
        if results:
            max_ranking = min(r.ranking for r in results) + 0.5
            results = SearchResults(r for r in results if r.ranking < max_ranking)

        return results


    def sort_and_cut_results(self, results: SearchResults) -> SearchResults:
        """ Remove badly matching results, sort by ranking and
            limit to the configured number of results.

            Note that the incoming `results` list is sorted in place.
        """
        if results:
            results.sort(key=lambda r: r.ranking)
            min_rank = results[0].rank_search
            min_ranking = results[0].ranking
            # Results with a higher search rank than the best result get a
            # small extra penalty when compared against the cut-off.
            results = SearchResults(r for r in results
                                    if r.ranking + 0.03 * (r.rank_search - min_rank)
                                       < min_ranking + 0.5)

            results = SearchResults(results[:self.limit])

        return results


    def rerank_by_query(self, query: QueryStruct, results: SearchResults) -> None:
        """ Adjust the accuracy of the localized result according to how well
            they match the original query.

            Each word of the query is compared against the words of the
            normalized display name (plus country code). Query words without
            a close match add a penalty proportional to their length. The
            accuracy of the results is modified in place.
        """
        assert self.query_analyzer is not None
        qwords = [word for phrase in query.source
                       for word in re.split('[, ]+', phrase.text) if word]
        if not qwords:
            return
        # Total length of all query words, used to normalize the penalty.
        # Never zero because empty words were filtered out above.
        qwords_len = sum(len(w) for w in qwords)

        for result in results:
            # Negative importance indicates ordering by distance, which is
            # more important than word matching.
            if not result.display_name\
               or (result.importance is not None and result.importance < 0):
                continue
            distance = 0.0
            norm = self.query_analyzer.normalize_text(' '.join((result.display_name,
                                                                result.country_code or '')))
            words = set((w for w in norm.split(' ') if w))
            if not words:
                continue
            for qword in qwords:
                # Similarity in [0, 1] to the best matching result word.
                similarity = max(difflib.SequenceMatcher(a=qword, b=w).quick_ratio()
                                 for w in words)
                if similarity < 0.5:
                    distance += len(qword)
                else:
                    distance += (1.0 - similarity) * len(qword)
            # Compensate for the fact that country names do not get a
            # match penalty yet by the tokenizer.
            # Temporary hack that needs to be removed!
            if result.rank_address == 4:
                distance *= 2
            result.accuracy += distance * 0.4 / qwords_len


    async def lookup_pois(self, categories: List[Tuple[str, str]],
                          phrases: List[Phrase]) -> SearchResults:
        """ Look up places by category. If phrase is given, a place search
            over the phrase will be executed first and places close to the
            results returned.
        """
        log().function('forward_lookup_pois', categories=categories, params=self.params)

        if phrases:
            query, searches = await self.build_searches(phrases)

            # Check the searches, not the query object: an unusable query
            # shows up as an empty search list.
            if searches:
                searches = [wrap_near_search(categories, s) for s in searches[:50]]
                results = await self.execute_searches(query, searches)
                results = self.pre_filter_results(results)
                await add_result_details(self.conn, results, self.params)
                log().result_dump('Preliminary Results', ((r.accuracy, r) for r in results))
                results = self.sort_and_cut_results(results)
            else:
                results = SearchResults()
        else:
            search = build_poi_search(categories, self.params.countries)
            results = await search.lookup(self.conn, self.params)
            await add_result_details(self.conn, results, self.params)

        log().result_dump('Final Results', ((r.accuracy, r) for r in results))

        return results


    async def lookup(self, phrases: List[Phrase]) -> SearchResults:
        """ Look up a single free-text query.

            Returns an empty result when the search parameters can never
            produce a match or when the query yields no searches.
        """
        log().function('forward_lookup', phrases=phrases, params=self.params)
        results = SearchResults()

        if self.params.is_impossible():
            return results

        query, searches = await self.build_searches(phrases)

        if searches:
            # Execute SQL until an appropriate result is found.
            results = await self.execute_searches(query, searches[:50])
            results = self.pre_filter_results(results)
            await add_result_details(self.conn, results, self.params)
            log().result_dump('Preliminary Results', ((r.accuracy, r) for r in results))
            self.rerank_by_query(query, results)
            log().result_dump('Results after reranking', ((r.accuracy, r) for r in results))
            results = self.sort_and_cut_results(results)
            log().result_dump('Final Results', ((r.accuracy, r) for r in results))

        return results
222
223         return results
224
225
# pylint: disable=invalid-name,too-many-locals
def _dump_searches(searches: List[AbstractSearch], query: QueryStruct,
                   start: int = 0) -> Iterator[Optional[List[Any]]]:
    """ Format searches as table rows for the debug log.

        Yields a header row first. Then, for each search from index `start`
        onwards, it yields one or more rows describing that search, followed
        by `None` (presumably rendered as a separator by the log table).
    """
    yield ['Penalty', 'Lookups', 'Housenr', 'Postcode', 'Countries',
           'Qualifier', 'Category', 'Rankings']

    # Search attributes shown in the table. The order must match the
    # unpacking in the loop at the end of this function.
    fields = ('lookups', 'rankings', 'countries', 'housenumbers',
              'postcodes', 'qualifiers')

    def tk(tl: List[int]) -> str:
        # Render token ids together with their lookup word.
        tstr = [f"{query.find_lookup_word_by_id(t)}({t})" for t in tl]

        return f"[{','.join(tstr)}]"

    def fmt_ranking(f: Any) -> str:
        # Render a field ranking. Overly long token lists are truncated
        # to 100 characters.
        if not f:
            return ''
        ranks = ','.join((f"{tk(r.tokens)}^{r.penalty:.3g}" for r in f.rankings))
        if len(ranks) > 100:
            ranks = ranks[:100] + '...'
        return f"{f.column}({ranks},def={f.default:.3g})"

    def fmt_lookup(l: Any) -> str:
        # Render a field lookup as type(column[tokens]).
        if not l:
            return ''

        return f"{l.lookup_type}({l.column}{tk(l.tokens)})"


    def fmt_cstr(c: Any) -> str:
        # Render a (value, penalty)-like pair as value^penalty.
        if not c:
            return ''

        return f'{c[0]}^{c[1]}'

    for search in searches[start:]:
        # A search that wraps another one exposes it as `search` and may
        # carry its own `categories`.
        if hasattr(search, 'search'):
            iters = itertools.zip_longest([f"{search.penalty:.3g}"],
                                          *(getattr(search.search, attr, []) for attr in fields),
                                          getattr(search, 'categories', []),
                                          fillvalue='')
        else:
            iters = itertools.zip_longest([f"{search.penalty:.3g}"],
                                          *(getattr(search, attr, []) for attr in fields),
                                          [],
                                          fillvalue='')
        for penalty, lookup, rank, cc, hnr, pc, qual, cat in iters:
            yield [penalty, fmt_lookup(lookup), fmt_cstr(hnr),
                   fmt_cstr(pc), fmt_cstr(cc), fmt_cstr(qual), fmt_cstr(cat), fmt_ranking(rank)]
        yield None
272             yield [penalty, fmt_lookup(lookup), fmt_cstr(hnr),
273                    fmt_cstr(pc), fmt_cstr(cc), fmt_cstr(qual), fmt_cstr(cat), fmt_ranking(rank)]
274         yield None