git.openstreetmap.org Git - nominatim.git/blobdiff - nominatim/api/search/geocoder.py
Merge pull request #3375 from matkoniecz/patch-1
[nominatim.git] / nominatim / api / search / geocoder.py
index d341b6cd3532f52c8e69f0a3fb93b8cb13f37b6b..775606aab755d6f68180e696828115c285a99f11 100644 (file)
@@ -7,12 +7,15 @@
 """
 Public interface to the search code.
 """
-from typing import List, Any, Optional, Iterator, Tuple
+from typing import List, Any, Optional, Iterator, Tuple, Dict
 import itertools
+import re
+import datetime as dt
+import difflib
 
 from nominatim.api.connection import SearchConnection
 from nominatim.api.types import SearchDetails
-from nominatim.api.results import SearchResults, add_result_details
+from nominatim.api.results import SearchResult, SearchResults, add_result_details
 from nominatim.api.search.token_assignment import yield_token_assignments
 from nominatim.api.search.db_search_builder import SearchBuilder, build_poi_search, wrap_near_search
 from nominatim.api.search.db_searches import AbstractSearch
@@ -24,9 +27,11 @@ class ForwardGeocoder:
     """ Main class responsible for place search.
     """
 
-    def __init__(self, conn: SearchConnection, params: SearchDetails) -> None:
+    def __init__(self, conn: SearchConnection,
+                 params: SearchDetails, timeout: Optional[int]) -> None:
         self.conn = conn
         self.params = params
+        self.timeout = dt.timedelta(seconds=timeout or 1000000)
         self.query_analyzer: Optional[AbstractQueryAnalyzer] = None
 
 
@@ -59,7 +64,7 @@ class ForwardGeocoder:
                     log().table_dump('Searches for assignment',
                                      _dump_searches(searches, query, num_searches))
                 num_searches = len(searches)
-            searches.sort(key=lambda s: s.penalty)
+            searches.sort(key=lambda s: (s.penalty, s.SEARCH_PRIO))
 
         return query, searches
 
@@ -70,39 +75,99 @@ class ForwardGeocoder:
             is found.
         """
         log().section('Execute database searches')
-        results = SearchResults()
+        results: Dict[Any, SearchResult] = {}
+
+        end_time = dt.datetime.now() + self.timeout
 
-        num_results = 0
-        min_ranking = 1000.0
+        min_ranking = searches[0].penalty + 2.0
         prev_penalty = 0.0
         for i, search in enumerate(searches):
             if search.penalty > prev_penalty and (search.penalty > min_ranking or i > 20):
                 break
             log().table_dump(f"{i + 1}. Search", _dump_searches([search], query))
-            for result in await search.lookup(self.conn, self.params):
-                results.append(result)
-                min_ranking = min(min_ranking, result.ranking + 0.5, search.penalty + 0.3)
-            log().result_dump('Results', ((r.accuracy, r) for r in results[num_results:]))
-            num_results = len(results)
+            log().var_dump('Params', self.params)
+            lookup_results = await search.lookup(self.conn, self.params)
+            for result in lookup_results:
+                rhash = (result.source_table, result.place_id,
+                         result.housenumber, result.country_code)
+                prevresult = results.get(rhash)
+                if prevresult:
+                    prevresult.accuracy = min(prevresult.accuracy, result.accuracy)
+                else:
+                    results[rhash] = result
+                min_ranking = min(min_ranking, result.accuracy * 1.2, 2.0)
+            log().result_dump('Results', ((r.accuracy, r) for r in lookup_results))
             prev_penalty = search.penalty
+            if dt.datetime.now() >= end_time:
+                break
 
-        if results:
-            min_ranking = min(r.ranking for r in results)
-            results = SearchResults(r for r in results if r.ranking < min_ranking + 0.5)
+        return SearchResults(results.values())
 
+
+    def pre_filter_results(self, results: SearchResults) -> SearchResults:
+        """ Remove results that are significantly worse than the
+            best match.
+        """
         if results:
-            min_rank = min(r.rank_search for r in results)
+            max_ranking = min(r.ranking for r in results) + 0.5
+            results = SearchResults(r for r in results if r.ranking < max_ranking)
+
+        return results
 
+
+    def sort_and_cut_results(self, results: SearchResults) -> SearchResults:
+        """ Remove badly matching results, sort by ranking and
+            limit to the configured number of results.
+        """
+        if results:
+            results.sort(key=lambda r: r.ranking)
+            min_rank = results[0].rank_search
+            min_ranking = results[0].ranking
             results = SearchResults(r for r in results
-                                    if r.ranking + 0.05 * (r.rank_search - min_rank)
+                                    if r.ranking + 0.03 * (r.rank_search - min_rank)
                                        < min_ranking + 0.5)
 
-            results.sort(key=lambda r: r.accuracy - r.calculated_importance())
             results = SearchResults(results[:self.limit])
 
         return results
 
 
+    def rerank_by_query(self, query: QueryStruct, results: SearchResults) -> None:
+        """ Adjust the accuracy of the localized result according to how well
+            they match the original query.
+        """
+        assert self.query_analyzer is not None
+        qwords = [word for phrase in query.source
+                       for word in re.split('[, ]+', phrase.text) if word]
+        if not qwords:
+            return
+
+        for result in results:
+            # Negative importance indicates ordering by distance, which is
+            # more important than word matching.
+            if not result.display_name\
+               or (result.importance is not None and result.importance < 0):
+                continue
+            distance = 0.0
+            norm = self.query_analyzer.normalize_text(' '.join((result.display_name,
+                                                                result.country_code or '')))
+            words = set((w for w in norm.split(' ') if w))
+            if not words:
+                continue
+            for qword in qwords:
+                wdist = max(difflib.SequenceMatcher(a=qword, b=w).quick_ratio() for w in words)
+                if wdist < 0.5:
+                    distance += len(qword)
+                else:
+                    distance += (1.0 - wdist) * len(qword)
+            # Compensate for the fact that country names do not get a
+            # match penalty yet by the tokenizer.
+            # Temporary hack that needs to be removed!
+            if result.rank_address == 4:
+                distance *= 2
+            result.accuracy += distance * 0.4 / sum(len(w) for w in qwords)
+
+
     async def lookup_pois(self, categories: List[Tuple[str, str]],
                           phrases: List[Phrase]) -> SearchResults:
         """ Look up places by category. If phrase is given, a place search
@@ -117,13 +182,17 @@ class ForwardGeocoder:
             if query:
                 searches = [wrap_near_search(categories, s) for s in searches[:50]]
                 results = await self.execute_searches(query, searches)
+                results = self.pre_filter_results(results)
+                await add_result_details(self.conn, results, self.params)
+                log().result_dump('Preliminary Results', ((r.accuracy, r) for r in results))
+                results = self.sort_and_cut_results(results)
             else:
                 results = SearchResults()
         else:
             search = build_poi_search(categories, self.params.countries)
             results = await search.lookup(self.conn, self.params)
+            await add_result_details(self.conn, results, self.params)
 
-        await add_result_details(self.conn, results, self.params)
         log().result_dump('Final Results', ((r.accuracy, r) for r in results))
 
         return results
@@ -143,7 +212,12 @@ class ForwardGeocoder:
         if searches:
             # Execute SQL until an appropriate result is found.
             results = await self.execute_searches(query, searches[:50])
+            results = self.pre_filter_results(results)
             await add_result_details(self.conn, results, self.params)
+            log().result_dump('Preliminary Results', ((r.accuracy, r) for r in results))
+            self.rerank_by_query(query, results)
+            log().result_dump('Results after reranking', ((r.accuracy, r) for r in results))
+            results = self.sort_and_cut_results(results)
             log().result_dump('Final Results', ((r.accuracy, r) for r in results))
 
         return results
@@ -152,7 +226,8 @@ class ForwardGeocoder:
 # pylint: disable=invalid-name,too-many-locals
 def _dump_searches(searches: List[AbstractSearch], query: QueryStruct,
                    start: int = 0) -> Iterator[Optional[List[Any]]]:
-    yield ['Penalty', 'Lookups', 'Housenr', 'Postcode', 'Countries', 'Qualifier', 'Rankings']
+    yield ['Penalty', 'Lookups', 'Housenr', 'Postcode', 'Countries',
+           'Qualifier', 'Catgeory', 'Rankings']
 
     def tk(tl: List[int]) -> str:
         tstr = [f"{query.find_lookup_word_by_id(t)}({t})" for t in tl]
@@ -182,11 +257,18 @@ def _dump_searches(searches: List[AbstractSearch], query: QueryStruct,
 
     for search in searches[start:]:
         fields = ('lookups', 'rankings', 'countries', 'housenumbers',
-                  'postcodes', 'qualifier')
-        iters = itertools.zip_longest([f"{search.penalty:.3g}"],
-                                      *(getattr(search, attr, []) for attr in fields),
-                                      fillvalue= '')
-        for penalty, lookup, rank, cc, hnr, pc, qual in iters:
+                  'postcodes', 'qualifiers')
+        if hasattr(search, 'search'):
+            iters = itertools.zip_longest([f"{search.penalty:.3g}"],
+                                          *(getattr(search.search, attr, []) for attr in fields),
+                                          getattr(search, 'categories', []),
+                                          fillvalue='')
+        else:
+            iters = itertools.zip_longest([f"{search.penalty:.3g}"],
+                                          *(getattr(search, attr, []) for attr in fields),
+                                          [],
+                                          fillvalue='')
+        for penalty, lookup, rank, cc, hnr, pc, qual, cat in iters:
             yield [penalty, fmt_lookup(lookup), fmt_cstr(hnr),
-                   fmt_cstr(pc), fmt_cstr(cc), fmt_cstr(qual), fmt_ranking(rank)]
+                   fmt_cstr(pc), fmt_cstr(cc), fmt_cstr(qual), fmt_cstr(cat), fmt_ranking(rank)]
         yield None