]> git.openstreetmap.org Git - nominatim.git/blob - nominatim/api/search/geocoder.py
Merge pull request #3262 from lonvia/fix-category-search
[nominatim.git] / nominatim / api / search / geocoder.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2023 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Public interface to the search code.
9 """
from typing import List, Any, Optional, Iterator, Tuple, Dict
import itertools
import re
import datetime as dt
import difflib
import time

from nominatim.api.connection import SearchConnection
from nominatim.api.types import SearchDetails
from nominatim.api.results import SearchResult, SearchResults, add_result_details
from nominatim.api.search.token_assignment import yield_token_assignments
from nominatim.api.search.db_search_builder import SearchBuilder, build_poi_search, wrap_near_search
from nominatim.api.search.db_searches import AbstractSearch
from nominatim.api.search.query_analyzer_factory import make_query_analyzer, AbstractQueryAnalyzer
from nominatim.api.search.query import Phrase, QueryStruct
from nominatim.api.logging import log
25
class ForwardGeocoder:
    """ Main class responsible for place search.

        An instance is bound to one search connection and one set of
        search parameters. The query analyzer is created lazily on the
        first call to build_searches() and reused afterwards.
    """

    def __init__(self, conn: SearchConnection,
                 params: SearchDetails, timeout: Optional[int]) -> None:
        self.conn = conn
        self.params = params
        # A missing (or zero) timeout is replaced by a very large value,
        # i.e. searching is effectively unlimited in time.
        self.timeout = dt.timedelta(seconds=timeout or 1000000)
        self.query_analyzer: Optional[AbstractQueryAnalyzer] = None


    @property
    def limit(self) -> int:
        """ Return the configured maximum number of search results.
        """
        return self.params.max_results


    async def build_searches(self,
                             phrases: List[Phrase]) -> Tuple[QueryStruct, List[AbstractSearch]]:
        """ Analyse the query and return the tokenized query and list of
            possible searches over it.

            The returned searches are sorted by ascending penalty. The list
            is empty when the analysed query has no token slots.
        """
        if self.query_analyzer is None:
            self.query_analyzer = await make_query_analyzer(self.conn)

        query = await self.query_analyzer.analyze_query(phrases)

        searches: List[AbstractSearch] = []
        if query.num_token_slots() > 0:
            # Compute all possible search interpretations.
            log().section('Compute abstract searches')
            search_builder = SearchBuilder(query, self.params)
            num_searches = 0
            for assignment in yield_token_assignments(query):
                searches.extend(search_builder.build(assignment))
                # Only dump the searches added for this assignment.
                if num_searches < len(searches):
                    log().table_dump('Searches for assignment',
                                     _dump_searches(searches, query, num_searches))
                num_searches = len(searches)
            searches.sort(key=lambda s: s.penalty)

        return query, searches


    async def execute_searches(self, query: QueryStruct,
                               searches: List[AbstractSearch]) -> SearchResults:
        """ Run the abstract searches against the database until a result
            is found.

            The searches are run in the given order. Execution stops when
            the next search has a higher penalty than the previous one and
            either that penalty exceeds the bound derived from the results
            found so far or more than 21 searches have already been run.
            It also stops once the configured timeout has elapsed.
            A place returned by several searches is kept once, with the
            smallest accuracy value seen.
        """
        log().section('Execute database searches')
        results: Dict[Any, SearchResult] = {}

        # A monotonic clock is used so that wall-clock adjustments (NTP
        # corrections, DST jumps of the naive local time) can neither cut
        # the search short nor extend it.
        deadline = time.monotonic() + self.timeout.total_seconds()

        min_ranking = 1000.0
        prev_penalty = 0.0
        for i, search in enumerate(searches):
            # Searches sharing the previous penalty are always run.
            if search.penalty > prev_penalty and (search.penalty > min_ranking or i > 20):
                break
            log().table_dump(f"{i + 1}. Search", _dump_searches([search], query))
            lookup_results = await search.lookup(self.conn, self.params)
            for result in lookup_results:
                # Identity of a result for de-duplication across searches.
                rhash = (result.source_table, result.place_id,
                         result.housenumber, result.country_code)
                prevresult = results.get(rhash)
                if prevresult is not None:
                    prevresult.accuracy = min(prevresult.accuracy, result.accuracy)
                else:
                    results[rhash] = result
                min_ranking = min(min_ranking, result.ranking + 0.5, search.penalty + 0.3)
            log().result_dump('Results', ((r.accuracy, r) for r in lookup_results))
            prev_penalty = search.penalty
            if time.monotonic() >= deadline:
                break

        return SearchResults(results.values())


    def sort_and_cut_results(self, results: SearchResults) -> SearchResults:
        """ Remove badly matching results, sort by ranking and
            limit to the configured number of results.
        """
        if not results:
            return results

        # Drop everything that ranks 0.5 or more worse than the best result.
        min_ranking = min(r.ranking for r in results)
        results = SearchResults(r for r in results if r.ranking < min_ranking + 0.5)
        results.sort(key=lambda r: r.ranking)

        if results:
            # Apply an extra 0.05 per search rank that a result lies above
            # the top-ranked result (and a bonus for lying below it), then
            # filter again against the same threshold.
            min_rank = results[0].rank_search
            results = SearchResults(r for r in results
                                    if r.ranking + 0.05 * (r.rank_search - min_rank)
                                       < min_ranking + 0.5)

            results = SearchResults(results[:self.limit])

        return results


    def rerank_by_query(self, query: QueryStruct, results: SearchResults) -> None:
        """ Adjust the accuracy of the localized result according to how well
            they match the original query.

            For every query word, the best fuzzy match against the words of
            the normalized display name is looked up. The mismatch is
            weighted by word length and added to `result.accuracy`. Results
            are modified in place.
        """
        assert self.query_analyzer is not None
        qwords = [word for phrase in query.source
                       for word in re.split('[, ]+', phrase.text) if word]
        if not qwords:
            return

        # Loop-invariant normalisation factor; positive because every word
        # in qwords is non-empty.
        qwords_len = sum(len(w) for w in qwords)

        for result in results:
            # Negative importance indicates ordering by distance, which is
            # more important than word matching.
            if not result.display_name\
               or (result.importance is not None and result.importance < 0):
                continue
            distance = 0.0
            norm = self.query_analyzer.normalize_text(result.display_name)
            words = set((w for w in norm.split(' ') if w))
            if not words:
                continue
            for qword in qwords:
                wdist = max(difflib.SequenceMatcher(a=qword, b=w).quick_ratio() for w in words)
                if wdist < 0.5:
                    # No usable match: count the full word length.
                    distance += len(qword)
                else:
                    distance += (1.0 - wdist) * len(qword)
            # Compensate for the fact that country names do not get a
            # match penalty yet by the tokenizer.
            # Temporary hack that needs to be removed!
            if result.rank_address == 4:
                distance *= 2
            result.accuracy += distance * 0.4 / qwords_len


    async def lookup_pois(self, categories: List[Tuple[str, str]],
                          phrases: List[Phrase]) -> SearchResults:
        """ Look up places by category. If phrases are given, a place search
            over the phrases will be executed first and places close to the
            results returned.
        """
        log().function('forward_lookup_pois', categories=categories, params=self.params)

        if phrases:
            query, searches = await self.build_searches(phrases)

            # build_searches() always returns a query object, so whether
            # there is anything to run is decided by the search list
            # (same check as in lookup()).
            if searches:
                searches = [wrap_near_search(categories, s) for s in searches[:50]]
                results = await self.execute_searches(query, searches)
                await add_result_details(self.conn, results, self.params)
                log().result_dump('Preliminary Results', ((r.accuracy, r) for r in results))
                results = self.sort_and_cut_results(results)
            else:
                results = SearchResults()
        else:
            search = build_poi_search(categories, self.params.countries)
            results = await search.lookup(self.conn, self.params)
            await add_result_details(self.conn, results, self.params)

        log().result_dump('Final Results', ((r.accuracy, r) for r in results))

        return results


    async def lookup(self, phrases: List[Phrase]) -> SearchResults:
        """ Look up a single free-text query.

            Returns an empty result list when the search parameters are
            impossible to satisfy or no search could be built for the query.
            At most the 50 best searches are executed.
        """
        log().function('forward_lookup', phrases=phrases, params=self.params)
        results = SearchResults()

        if self.params.is_impossible():
            return results

        query, searches = await self.build_searches(phrases)

        if searches:
            # Execute SQL until an appropriate result is found.
            results = await self.execute_searches(query, searches[:50])
            await add_result_details(self.conn, results, self.params)
            log().result_dump('Preliminary Results', ((r.accuracy, r) for r in results))
            self.rerank_by_query(query, results)
            log().result_dump('Results after reranking', ((r.accuracy, r) for r in results))
            results = self.sort_and_cut_results(results)
            log().result_dump('Final Results', ((r.accuracy, r) for r in results))

        return results
212
213
# pylint: disable=invalid-name,too-many-locals
def _dump_searches(searches: List[AbstractSearch], query: QueryStruct,
                   start: int = 0) -> Iterator[Optional[List[Any]]]:
    """ Format searches as rows of a debug table.

        The first row yielded is the table header. Every search in
        `searches[start:]` then yields at least one row (as many as the
        longest of its dumped attribute lists), followed by `None` as a
        separator.
    """
    yield ['Penalty', 'Lookups', 'Housenr', 'Postcode', 'Countries',
           'Qualifier', 'Category', 'Rankings']

    # Attributes of a search that are dumped column-wise.
    fields = ('lookups', 'rankings', 'countries', 'housenumbers',
              'postcodes', 'qualifiers')

    def tk(tl: List[int]) -> str:
        """ Format a list of token ids together with their lookup words. """
        tstr = [f"{query.find_lookup_word_by_id(t)}({t})" for t in tl]

        return f"[{','.join(tstr)}]"

    def fmt_ranking(field: Any) -> str:
        """ Format a ranking entry, truncating long token lists. """
        if not field:
            return ''
        ranks = ','.join((f"{tk(r.tokens)}^{r.penalty:.3g}" for r in field.rankings))
        if len(ranks) > 100:
            ranks = ranks[:100] + '...'
        return f"{field.column}({ranks},def={field.default:.3g})"

    def fmt_lookup(lk: Any) -> str:
        """ Format a lookup entry as type(column[tokens]). """
        if not lk:
            return ''

        return f"{lk.lookup_type}({lk.column}{tk(lk.tokens)})"

    def fmt_cstr(cstr: Any) -> str:
        """ Format a (value, penalty) pair as value^penalty. """
        if not cstr:
            return ''

        return f'{cstr[0]}^{cstr[1]}'

    for search in searches[start:]:
        if hasattr(search, 'search'):
            # Searches exposing a `search` attribute (presumably wrappers
            # such as those from wrap_near_search) are dumped with the
            # wrapped search's fields plus their own categories.
            iters = itertools.zip_longest([f"{search.penalty:.3g}"],
                                          *(getattr(search.search, attr, []) for attr in fields),
                                          getattr(search, 'categories', []),
                                          fillvalue='')
        else:
            iters = itertools.zip_longest([f"{search.penalty:.3g}"],
                                          *(getattr(search, attr, []) for attr in fields),
                                          [],
                                          fillvalue='')
        for penalty, lookup, rank, cc, hnr, pc, qual, cat in iters:
            yield [penalty, fmt_lookup(lookup), fmt_cstr(hnr),
                   fmt_cstr(pc), fmt_cstr(cc), fmt_cstr(qual), fmt_cstr(cat), fmt_ranking(rank)]
        yield None