1 # SPDX-License-Identifier: GPL-3.0-or-later
3 # This file is part of Nominatim. (https://nominatim.org)
5 # Copyright (C) 2024 by the Nominatim developer community.
6 # For a full list of authors see the git log.
8 Public interface to the search code.
10 from typing import List, Any, Optional, Iterator, Tuple, Dict
16 from ..connection import SearchConnection
17 from ..types import SearchDetails
18 from ..results import SearchResult, SearchResults, add_result_details
19 from ..logging import log
20 from .token_assignment import yield_token_assignments
21 from .db_search_builder import SearchBuilder, build_poi_search, wrap_near_search
22 from .db_searches import AbstractSearch
23 from .query_analyzer_factory import make_query_analyzer, AbstractQueryAnalyzer
24 from .query import Phrase, QueryStruct
26 class ForwardGeocoder:
27 """ Main class responsible for place search.
30 def __init__(self, conn: SearchConnection,
31 params: SearchDetails, timeout: Optional[int]) -> None:
34 self.timeout = dt.timedelta(seconds=timeout or 1000000)
35 self.query_analyzer: Optional[AbstractQueryAnalyzer] = None
39 def limit(self) -> int:
40 """ Return the configured maximum number of search results.
42 return self.params.max_results
45 async def build_searches(self,
46 phrases: List[Phrase]) -> Tuple[QueryStruct, List[AbstractSearch]]:
47 """ Analyse the query and return the tokenized query and list of
48 possible searches over it.
50 if self.query_analyzer is None:
51 self.query_analyzer = await make_query_analyzer(self.conn)
53 query = await self.query_analyzer.analyze_query(phrases)
55 searches: List[AbstractSearch] = []
56 if query.num_token_slots() > 0:
57 # 2. Compute all possible search interpretations
58 log().section('Compute abstract searches')
59 search_builder = SearchBuilder(query, self.params)
61 for assignment in yield_token_assignments(query):
62 searches.extend(search_builder.build(assignment))
63 if num_searches < len(searches):
64 log().table_dump('Searches for assignment',
65 _dump_searches(searches, query, num_searches))
66 num_searches = len(searches)
67 searches.sort(key=lambda s: (s.penalty, s.SEARCH_PRIO))
69 return query, searches
72 async def execute_searches(self, query: QueryStruct,
73 searches: List[AbstractSearch]) -> SearchResults:
74 """ Run the abstract searches against the database until a result
77 log().section('Execute database searches')
78 results: Dict[Any, SearchResult] = {}
80 end_time = dt.datetime.now() + self.timeout
82 min_ranking = searches[0].penalty + 2.0
84 for i, search in enumerate(searches):
85 if search.penalty > prev_penalty and (search.penalty > min_ranking or i > 20):
87 log().table_dump(f"{i + 1}. Search", _dump_searches([search], query))
88 log().var_dump('Params', self.params)
89 lookup_results = await search.lookup(self.conn, self.params)
90 for result in lookup_results:
91 rhash = (result.source_table, result.place_id,
92 result.housenumber, result.country_code)
93 prevresult = results.get(rhash)
95 prevresult.accuracy = min(prevresult.accuracy, result.accuracy)
97 results[rhash] = result
98 min_ranking = min(min_ranking, result.accuracy * 1.2, 2.0)
99 log().result_dump('Results', ((r.accuracy, r) for r in lookup_results))
100 prev_penalty = search.penalty
101 if dt.datetime.now() >= end_time:
104 return SearchResults(results.values())
107 def pre_filter_results(self, results: SearchResults) -> SearchResults:
108 """ Remove results that are significantly worse than the
112 max_ranking = min(r.ranking for r in results) + 0.5
113 results = SearchResults(r for r in results if r.ranking < max_ranking)
118 def sort_and_cut_results(self, results: SearchResults) -> SearchResults:
119 """ Remove badly matching results, sort by ranking and
120 limit to the configured number of results.
123 results.sort(key=lambda r: r.ranking)
124 min_rank = results[0].rank_search
125 min_ranking = results[0].ranking
126 results = SearchResults(r for r in results
127 if r.ranking + 0.03 * (r.rank_search - min_rank)
130 results = SearchResults(results[:self.limit])
135 def rerank_by_query(self, query: QueryStruct, results: SearchResults) -> None:
136 """ Adjust the accuracy of the localized result according to how well
137 they match the original query.
139 assert self.query_analyzer is not None
140 qwords = [word for phrase in query.source
141 for word in re.split('[, ]+', phrase.text) if word]
145 for result in results:
146 # Negative importance indicates ordering by distance, which is
147 # more important than word matching.
148 if not result.display_name\
149 or (result.importance is not None and result.importance < 0):
152 norm = self.query_analyzer.normalize_text(' '.join((result.display_name,
153 result.country_code or '')))
154 words = set((w for w in norm.split(' ') if w))
158 wdist = max(difflib.SequenceMatcher(a=qword, b=w).quick_ratio() for w in words)
160 distance += len(qword)
162 distance += (1.0 - wdist) * len(qword)
163 # Compensate for the fact that country names do not get a
164 # match penalty yet by the tokenizer.
165 # Temporary hack that needs to be removed!
166 if result.rank_address == 4:
168 result.accuracy += distance * 0.4 / sum(len(w) for w in qwords)
171 async def lookup_pois(self, categories: List[Tuple[str, str]],
172 phrases: List[Phrase]) -> SearchResults:
173 """ Look up places by category. If phrase is given, a place search
174 over the phrase will be executed first and places close to the
177 log().function('forward_lookup_pois', categories=categories, params=self.params)
180 query, searches = await self.build_searches(phrases)
183 searches = [wrap_near_search(categories, s) for s in searches[:50]]
184 results = await self.execute_searches(query, searches)
185 results = self.pre_filter_results(results)
186 await add_result_details(self.conn, results, self.params)
187 log().result_dump('Preliminary Results', ((r.accuracy, r) for r in results))
188 results = self.sort_and_cut_results(results)
190 results = SearchResults()
192 search = build_poi_search(categories, self.params.countries)
193 results = await search.lookup(self.conn, self.params)
194 await add_result_details(self.conn, results, self.params)
196 log().result_dump('Final Results', ((r.accuracy, r) for r in results))
201 async def lookup(self, phrases: List[Phrase]) -> SearchResults:
202 """ Look up a single free-text query.
204 log().function('forward_lookup', phrases=phrases, params=self.params)
205 results = SearchResults()
207 if self.params.is_impossible():
210 query, searches = await self.build_searches(phrases)
213 # Execute SQL until an appropriate result is found.
214 results = await self.execute_searches(query, searches[:50])
215 results = self.pre_filter_results(results)
216 await add_result_details(self.conn, results, self.params)
217 log().result_dump('Preliminary Results', ((r.accuracy, r) for r in results))
218 self.rerank_by_query(query, results)
219 log().result_dump('Results after reranking', ((r.accuracy, r) for r in results))
220 results = self.sort_and_cut_results(results)
221 log().result_dump('Final Results', ((r.accuracy, r) for r in results))
226 # pylint: disable=invalid-name,too-many-locals
227 def _dump_searches(searches: List[AbstractSearch], query: QueryStruct,
228 start: int = 0) -> Iterator[Optional[List[Any]]]:
229 yield ['Penalty', 'Lookups', 'Housenr', 'Postcode', 'Countries',
230 'Qualifier', 'Catgeory', 'Rankings']
232 def tk(tl: List[int]) -> str:
233 tstr = [f"{query.find_lookup_word_by_id(t)}({t})" for t in tl]
235 return f"[{','.join(tstr)}]"
237 def fmt_ranking(f: Any) -> str:
240 ranks = ','.join((f"{tk(r.tokens)}^{r.penalty:.3g}" for r in f.rankings))
242 ranks = ranks[:100] + '...'
243 return f"{f.column}({ranks},def={f.default:.3g})"
245 def fmt_lookup(l: Any) -> str:
249 return f"{l.lookup_type}({l.column}{tk(l.tokens)})"
252 def fmt_cstr(c: Any) -> str:
256 return f'{c[0]}^{c[1]}'
258 for search in searches[start:]:
259 fields = ('lookups', 'rankings', 'countries', 'housenumbers',
260 'postcodes', 'qualifiers')
261 if hasattr(search, 'search'):
262 iters = itertools.zip_longest([f"{search.penalty:.3g}"],
263 *(getattr(search.search, attr, []) for attr in fields),
264 getattr(search, 'categories', []),
267 iters = itertools.zip_longest([f"{search.penalty:.3g}"],
268 *(getattr(search, attr, []) for attr in fields),
271 for penalty, lookup, rank, cc, hnr, pc, qual, cat in iters:
272 yield [penalty, fmt_lookup(lookup), fmt_cstr(hnr),
273 fmt_cstr(pc), fmt_cstr(cc), fmt_cstr(qual), fmt_cstr(cat), fmt_ranking(rank)]