]> git.openstreetmap.org Git - nominatim.git/blob - src/nominatim_api/search/geocoder.py
fix style issue found by flake8
[nominatim.git] / src / nominatim_api / search / geocoder.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2024 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Public interface to the search code.
9 """
10 from typing import List, Any, Optional, Iterator, Tuple, Dict
11 import itertools
12 import re
13 import datetime as dt
14 import difflib
15
16 from ..connection import SearchConnection
17 from ..types import SearchDetails
18 from ..results import SearchResult, SearchResults, add_result_details
19 from ..logging import log
20 from .token_assignment import yield_token_assignments
21 from .db_search_builder import SearchBuilder, build_poi_search, wrap_near_search
22 from .db_searches import AbstractSearch
23 from .query_analyzer_factory import make_query_analyzer, AbstractQueryAnalyzer
24 from .query import Phrase, QueryStruct
25
26
class ForwardGeocoder:
    """ Main class responsible for place search.

        The geocoder analyses the query phrases, builds a list of abstract
        searches from them, runs these searches against the database
        connection and finally filters, reranks, sorts and cuts the results.
    """

    def __init__(self, conn: SearchConnection,
                 params: SearchDetails, timeout: Optional[int]) -> None:
        """ Create a geocoder working on connection `conn` with the search
            parameters `params`. `timeout` is given in seconds; when it is
            `None` or 0, a very large default is used instead.
        """
        self.conn = conn
        self.params = params
        self.timeout = dt.timedelta(seconds=timeout or 1000000)
        # Created lazily on the first call to build_searches().
        self.query_analyzer: Optional[AbstractQueryAnalyzer] = None

    @property
    def limit(self) -> int:
        """ Return the configured maximum number of search results.
        """
        return self.params.max_results

    async def build_searches(self,
                             phrases: List[Phrase]) -> Tuple[QueryStruct, List[AbstractSearch]]:
        """ Analyse the query and return the tokenized query and list of
            possible searches over it.

            The returned searches are sorted by penalty and search priority.
            The list is empty when the analysed query has no token slots.
        """
        if self.query_analyzer is None:
            self.query_analyzer = await make_query_analyzer(self.conn)

        query = await self.query_analyzer.analyze_query(phrases)

        searches: List[AbstractSearch] = []
        if query.num_token_slots() > 0:
            # Compute all possible search interpretations.
            log().section('Compute abstract searches')
            search_builder = SearchBuilder(query, self.params)
            num_searches = 0
            for assignment in yield_token_assignments(query):
                searches.extend(search_builder.build(assignment))
                # Only dump the searches that this assignment has added.
                if num_searches < len(searches):
                    log().table_dump('Searches for assignment',
                                     _dump_searches(searches, query, num_searches))
                num_searches = len(searches)
            searches.sort(key=lambda s: (s.penalty, s.SEARCH_PRIO))

        return query, searches

    async def execute_searches(self, query: QueryStruct,
                               searches: List[AbstractSearch]) -> SearchResults:
        """ Run the abstract searches against the database until a result
            is found.

            The searches are executed in the order given. Execution stops
            at the first search whose penalty is higher than that of its
            predecessor and which either exceeds the current ranking
            threshold or comes after more than 20 searches. It also stops
            once the timeout has elapsed. Results found by more than one
            search are merged, keeping the lowest accuracy value.

            An empty list of searches yields an empty result set.
        """
        log().section('Execute database searches')

        # Nothing to execute. Without this guard the threshold
        # computation below would fail on `searches[0]`.
        if not searches:
            return SearchResults()

        results: Dict[Any, SearchResult] = {}

        end_time = dt.datetime.now() + self.timeout

        min_ranking = searches[0].penalty + 2.0
        prev_penalty = 0.0
        for i, search in enumerate(searches):
            if search.penalty > prev_penalty and (search.penalty > min_ranking or i > 20):
                break
            log().table_dump(f"{i + 1}. Search", _dump_searches([search], query))
            log().var_dump('Params', self.params)
            lookup_results = await search.lookup(self.conn, self.params)
            for result in lookup_results:
                # Key used to recognise the same place coming from different searches.
                rhash = (result.source_table, result.place_id,
                         result.housenumber, result.country_code)
                prevresult = results.get(rhash)
                if prevresult:
                    prevresult.accuracy = min(prevresult.accuracy, result.accuracy)
                else:
                    results[rhash] = result
                # Tighten the threshold according to the results found so far.
                min_ranking = min(min_ranking, result.accuracy * 1.2, 2.0)
            log().result_dump('Results', ((r.accuracy, r) for r in lookup_results))
            prev_penalty = search.penalty
            if dt.datetime.now() >= end_time:
                break

        return SearchResults(results.values())

    def pre_filter_results(self, results: SearchResults) -> SearchResults:
        """ Remove results that are significantly worse than the
            best match.

            Only results whose ranking is less than 0.5 above the lowest
            ranking in the set are kept.
        """
        if results:
            max_ranking = min(r.ranking for r in results) + 0.5
            results = SearchResults(r for r in results if r.ranking < max_ranking)

        return results

    def sort_and_cut_results(self, results: SearchResults) -> SearchResults:
        """ Remove badly matching results, sort by ranking and
            limit to the configured number of results.

            Note that a non-empty input is sorted in place before the
            filtered copy is created.
        """
        if results:
            results.sort(key=lambda r: r.ranking)
            min_rank = results[0].rank_search
            min_ranking = results[0].ranking
            # Results with a higher search rank than the best result
            # have to clear a slightly stricter ranking bar.
            results = SearchResults(r for r in results
                                    if (r.ranking + 0.03 * (r.rank_search - min_rank)
                                        < min_ranking + 0.5))

            results = SearchResults(results[:self.limit])

        return results

    def rerank_by_query(self, query: QueryStruct, results: SearchResults) -> None:
        """ Adjust the accuracy of the localized result according to how well
            they match the original query.

            The accuracy of each result is increased in place by a penalty
            that grows with the share of query characters that have no
            similar word in the normalized display name and country code.
            Results without display name and results with a negative
            importance are left untouched.
        """
        assert self.query_analyzer is not None
        qwords = [word for phrase in query.source
                  for word in re.split('[, ]+', phrase.text) if word]
        if not qwords:
            return

        # Total number of query characters, used to scale the penalty.
        # qwords only holds non-empty words, so this is never zero.
        total_qlen = sum(len(w) for w in qwords)

        for result in results:
            # Negative importance indicates ordering by distance, which is
            # more important than word matching.
            if not result.display_name\
               or (result.importance is not None and result.importance < 0):
                continue
            distance = 0.0
            norm = self.query_analyzer.normalize_text(' '.join((result.display_name,
                                                                result.country_code or '')))
            words = {w for w in norm.split(' ') if w}
            if not words:
                continue
            for qword in qwords:
                # Similarity to the best matching word of the result.
                wdist = max(difflib.SequenceMatcher(a=qword, b=w).quick_ratio() for w in words)
                if wdist < 0.5:
                    # Too different: count the word as not matched at all.
                    distance += len(qword)
                else:
                    distance += (1.0 - wdist) * len(qword)
            # Compensate for the fact that country names do not get a
            # match penalty yet by the tokenizer.
            # Temporary hack that needs to be removed!
            if result.rank_address == 4:
                distance *= 2
            result.accuracy += distance * 0.4 / total_qlen

    async def lookup_pois(self, categories: List[Tuple[str, str]],
                          phrases: List[Phrase]) -> SearchResults:
        """ Look up places by category. If phrase is given, a place search
            over the phrase will be executed first and places close to the
            results returned.

            Without phrases, a plain category search restricted to the
            configured countries is executed.
        """
        log().function('forward_lookup_pois', categories=categories, params=self.params)

        if phrases:
            query, searches = await self.build_searches(phrases)

            if query:
                # Only the first 50 searches are wrapped and executed.
                searches = [wrap_near_search(categories, s) for s in searches[:50]]
                results = await self.execute_searches(query, searches)
                results = self.pre_filter_results(results)
                await add_result_details(self.conn, results, self.params)
                log().result_dump('Preliminary Results', ((r.accuracy, r) for r in results))
                results = self.sort_and_cut_results(results)
            else:
                results = SearchResults()
        else:
            search = build_poi_search(categories, self.params.countries)
            results = await search.lookup(self.conn, self.params)
            await add_result_details(self.conn, results, self.params)

        log().result_dump('Final Results', ((r.accuracy, r) for r in results))

        return results

    async def lookup(self, phrases: List[Phrase]) -> SearchResults:
        """ Look up a single free-text query.

            Returns an empty result set when the search parameters are
            impossible to fulfil or when no searches could be built
            from the phrases.
        """
        log().function('forward_lookup', phrases=phrases, params=self.params)
        results = SearchResults()

        if self.params.is_impossible():
            return results

        query, searches = await self.build_searches(phrases)

        if searches:
            # Execute SQL until an appropriate result is found.
            # Only the first 50 searches are considered.
            results = await self.execute_searches(query, searches[:50])
            results = self.pre_filter_results(results)
            await add_result_details(self.conn, results, self.params)
            log().result_dump('Preliminary Results', ((r.accuracy, r) for r in results))
            self.rerank_by_query(query, results)
            log().result_dump('Results after reranking', ((r.accuracy, r) for r in results))
            results = self.sort_and_cut_results(results)
            log().result_dump('Final Results', ((r.accuracy, r) for r in results))

        return results
217
218
219 def _dump_searches(searches: List[AbstractSearch], query: QueryStruct,
220                    start: int = 0) -> Iterator[Optional[List[Any]]]:
221     yield ['Penalty', 'Lookups', 'Housenr', 'Postcode', 'Countries',
222            'Qualifier', 'Catgeory', 'Rankings']
223
224     def tk(tl: List[int]) -> str:
225         tstr = [f"{query.find_lookup_word_by_id(t)}({t})" for t in tl]
226
227         return f"[{','.join(tstr)}]"
228
229     def fmt_ranking(f: Any) -> str:
230         if not f:
231             return ''
232         ranks = ','.join((f"{tk(r.tokens)}^{r.penalty:.3g}" for r in f.rankings))
233         if len(ranks) > 100:
234             ranks = ranks[:100] + '...'
235         return f"{f.column}({ranks},def={f.default:.3g})"
236
237     def fmt_lookup(lk: Any) -> str:
238         if not lk:
239             return ''
240
241         return f"{lk.lookup_type}({lk.column}{tk(lk.tokens)})"
242
243     def fmt_cstr(c: Any) -> str:
244         if not c:
245             return ''
246
247         return f'{c[0]}^{c[1]}'
248
249     for search in searches[start:]:
250         fields = ('lookups', 'rankings', 'countries', 'housenumbers',
251                   'postcodes', 'qualifiers')
252         if hasattr(search, 'search'):
253             iters = itertools.zip_longest([f"{search.penalty:.3g}"],
254                                           *(getattr(search.search, attr, []) for attr in fields),
255                                           getattr(search, 'categories', []),
256                                           fillvalue='')
257         else:
258             iters = itertools.zip_longest([f"{search.penalty:.3g}"],
259                                           *(getattr(search, attr, []) for attr in fields),
260                                           [],
261                                           fillvalue='')
262         for penalty, lookup, rank, cc, hnr, pc, qual, cat in iters:
263             yield [penalty, fmt_lookup(lookup), fmt_cstr(hnr),
264                    fmt_cstr(pc), fmt_cstr(cc), fmt_cstr(qual), fmt_cstr(cat), fmt_ranking(rank)]
265         yield None