# SPDX-License-Identifier: GPL-3.0-or-later
#
# This file is part of Nominatim. (https://nominatim.org)
#
# Copyright (C) 2023 by the Nominatim developer community.
# For a full list of authors see the git log.
"""
Public interface to the search code.
"""
from typing import List, Any, Optional, Iterator, Tuple
import itertools
import re
import datetime as dt
import difflib

from nominatim.api.connection import SearchConnection
from nominatim.api.types import SearchDetails
from nominatim.api.results import SearchResults, add_result_details
from nominatim.api.search.token_assignment import yield_token_assignments
from nominatim.api.search.db_search_builder import SearchBuilder, build_poi_search, wrap_near_search
from nominatim.api.search.db_searches import AbstractSearch
from nominatim.api.search.query_analyzer_factory import make_query_analyzer, AbstractQueryAnalyzer
from nominatim.api.search.query import Phrase, QueryStruct
from nominatim.api.logging import log

class ForwardGeocoder:
    """ Main class responsible for place search.
    """

    def __init__(self, conn: SearchConnection,
                 params: SearchDetails, timeout: Optional[int]) -> None:
        self.conn = conn
        self.params = params
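        # A missing or zero timeout falls back to a value so large that the
        # deadline effectively never triggers.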
        self.timeout = dt.timedelta(seconds=timeout or 1000000)
        self.query_analyzer: Optional[AbstractQueryAnalyzer] = None


    @property
    def limit(self) -> int:
        """ Return the configured maximum number of search results.
        """
        return self.params.max_results


    async def build_searches(self,
                             phrases: List[Phrase]) -> Tuple[QueryStruct, List[AbstractSearch]]:
        """ Analyse the query and return the tokenized query and list of
            possible searches over it.
        """
        if self.query_analyzer is None:
            self.query_analyzer = await make_query_analyzer(self.conn)

        query = await self.query_analyzer.analyze_query(phrases)

        searches: List[AbstractSearch] = []
        if query.num_token_slots() > 0:
            # 2. Compute all possible search interpretations
            log().section('Compute abstract searches')
            search_builder = SearchBuilder(query, self.params)
            num_searches = 0
            for assignment in yield_token_assignments(query):
                searches.extend(search_builder.build(assignment))
                if num_searches < len(searches):
                    log().table_dump('Searches for assignment',
                                     _dump_searches(searches, query, num_searches))
                num_searches = len(searches)
            searches.sort(key=lambda s: s.penalty)

        return query, searches


    async def execute_searches(self, query: QueryStruct,
                               searches: List[AbstractSearch]) -> SearchResults:
        """ Run the abstract searches against the database until a result
            is found.
        """
        log().section('Execute database searches')
        results = SearchResults()
        end_time = dt.datetime.now() + self.timeout

        num_results = 0
        min_ranking = 1000.0
        prev_penalty = 0.0
        for i, search in enumerate(searches):
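            # Stop once this search has a higher penalty than the previous one
            # and either cannot beat the ranking of the results found so far
            # or more than 20 searches have already been tried.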
            if search.penalty > prev_penalty and (search.penalty > min_ranking or i > 20):
                break
            log().table_dump(f"{i + 1}. Search", _dump_searches([search], query))
            for result in await search.lookup(self.conn, self.params):
                results.append(result)
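                # Tighten the cut-off penalty for subsequent searches based on
                # the ranking of the results already found.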
                min_ranking = min(min_ranking, result.ranking + 0.5, search.penalty + 0.3)
            log().result_dump('Results', ((r.accuracy, r) for r in results[num_results:]))
            num_results = len(results)
            prev_penalty = search.penalty
            if dt.datetime.now() >= end_time:
                break

        return results


    def sort_and_cut_results(self, results: SearchResults) -> SearchResults:
        """ Remove badly matching results, sort by ranking and
            limit to the configured number of results.
        """
        if results:
            min_ranking = min(r.ranking for r in results)
            results = SearchResults(r for r in results if r.ranking < min_ranking + 0.5)
            results.sort(key=lambda r: r.ranking)

        # Filter again with a penalty that grows with the difference in search
        # rank to the best result, then cut to the configured limit.
        if results:
            min_rank = results[0].rank_search
            results = SearchResults(r for r in results
                                    if r.ranking + 0.05 * (r.rank_search - min_rank)
                                       < min_ranking + 0.5)

            results = SearchResults(results[:self.limit])

        return results


    def rerank_by_query(self, query: QueryStruct, results: SearchResults) -> None:
        """ Adjust the accuracy of the localized results according to how
            well they match the original query.
        """
        assert self.query_analyzer is not None
        qwords = [word for phrase in query.source
                       for word in re.split('[, ]+', phrase.text) if word]
        if not qwords:
            return

        for result in results:
            if not result.display_name:
                continue
            distance = 0.0
            norm = self.query_analyzer.normalize_text(result.display_name)
            words = set((w for w in norm.split(' ') if w))
            if not words:
                continue
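            # Compare each query word against the best-matching word of the
            # normalized display name; poor matches count with their full
            # word length, good matches in proportion to the remaining gap.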
            for qword in qwords:
                wdist = max(difflib.SequenceMatcher(a=qword, b=w).quick_ratio() for w in words)
                if wdist < 0.5:
                    distance += len(qword)
                else:
                    distance += (1.0 - wdist) * len(qword)
            result.accuracy += distance * 0.5 / sum(len(w) for w in qwords)


    async def lookup_pois(self, categories: List[Tuple[str, str]],
                          phrases: List[Phrase]) -> SearchResults:
        """ Look up places by category. If phrases are given, a place search
            over the phrases is executed first and only places close to its
            results are returned.
        """
        log().function('forward_lookup_pois', categories=categories, params=self.params)

        if phrases:
            query, searches = await self.build_searches(phrases)

            if query:
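                # Wrap each abstract search so that it looks for the requested
                # categories near the places the search would normally return.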
                searches = [wrap_near_search(categories, s) for s in searches[:50]]
                results = await self.execute_searches(query, searches)
                await add_result_details(self.conn, results, self.params)
                log().result_dump('Preliminary Results', ((r.accuracy, r) for r in results))
                results = self.sort_and_cut_results(results)
            else:
                results = SearchResults()
        else:
            search = build_poi_search(categories, self.params.countries)
            results = await search.lookup(self.conn, self.params)
            await add_result_details(self.conn, results, self.params)

        log().result_dump('Final Results', ((r.accuracy, r) for r in results))

        return results


    async def lookup(self, phrases: List[Phrase]) -> SearchResults:
        """ Look up a single free-text query.
        """
        log().function('forward_lookup', phrases=phrases, params=self.params)
        results = SearchResults()

        if self.params.is_impossible():
            return results

        query, searches = await self.build_searches(phrases)

        if searches:
            # Execute SQL until an appropriate result is found.
            results = await self.execute_searches(query, searches[:50])
            await add_result_details(self.conn, results, self.params)
            log().result_dump('Preliminary Results', ((r.accuracy, r) for r in results))
            self.rerank_by_query(query, results)
            log().result_dump('Results after reranking', ((r.accuracy, r) for r in results))
            results = self.sort_and_cut_results(results)
            log().result_dump('Final Results', ((r.accuracy, r) for r in results))

        return results


# pylint: disable=invalid-name,too-many-locals
def _dump_searches(searches: List[AbstractSearch], query: QueryStruct,
                   start: int = 0) -> Iterator[Optional[List[Any]]]:
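    """ Format the given abstract searches, starting at index `start`, as
        rows for the debug log. The first row contains the column headers;
        the rows of each search are followed by a single ``None``.
    """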
    yield ['Penalty', 'Lookups', 'Housenr', 'Postcode', 'Countries',
           'Qualifier', 'Category', 'Rankings']

    def tk(tl: List[int]) -> str:
        tstr = [f"{query.find_lookup_word_by_id(t)}({t})" for t in tl]

        return f"[{','.join(tstr)}]"

    def fmt_ranking(f: Any) -> str:
        if not f:
            return ''
        ranks = ','.join((f"{tk(r.tokens)}^{r.penalty:.3g}" for r in f.rankings))
        if len(ranks) > 100:
            ranks = ranks[:100] + '...'
        return f"{f.column}({ranks},def={f.default:.3g})"

    def fmt_lookup(l: Any) -> str:
        if not l:
            return ''

        return f"{l.lookup_type}({l.column}{tk(l.tokens)})"


    def fmt_cstr(c: Any) -> str:
        if not c:
            return ''

        return f'{c[0]}^{c[1]}'

    for search in searches[start:]:
        fields = ('lookups', 'rankings', 'countries', 'housenumbers',
                  'postcodes', 'qualifiers')
        if hasattr(search, 'search'):
            iters = itertools.zip_longest([f"{search.penalty:.3g}"],
                                          *(getattr(search.search, attr, []) for attr in fields),
                                          getattr(search, 'categories', []),
                                          fillvalue='')
        else:
            iters = itertools.zip_longest([f"{search.penalty:.3g}"],
                                          *(getattr(search, attr, []) for attr in fields),
                                          [],
                                          fillvalue='')
        for penalty, lookup, rank, cc, hnr, pc, qual, cat in iters:
            yield [penalty, fmt_lookup(lookup), fmt_cstr(hnr),
                   fmt_cstr(pc), fmt_cstr(cc), fmt_cstr(qual), fmt_cstr(cat), fmt_ranking(rank)]
        yield None
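

# A minimal usage sketch (not part of the original module): it shows how the
# public lookup() entry point is typically driven from an async context. It
# assumes the caller already holds an open SearchConnection; the PhraseType
# import and the concrete parameter values are illustrative assumptions only.
#
# from nominatim.api.search.query import PhraseType
#
# async def example_lookup(conn: SearchConnection) -> SearchResults:
#     geocoder = ForwardGeocoder(conn, SearchDetails(max_results=10), timeout=30)
#     return await geocoder.lookup([Phrase(PhraseType.NONE, 'alte schmiede, hamburg')])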