git.openstreetmap.org Git - nominatim.git/blobdiff - nominatim/api/search/db_search_builder.py
rename use of category as POI search to near_item
[nominatim.git] / nominatim / api / search / db_search_builder.py
index 905b5c621cfe8cd99e28ec0bf13a2a5c9bff392a..a0018480d2dee7cf10525e051a8e97e6b3641475 100644 (file)
@@ -7,7 +7,7 @@
 """
 Convertion from token assignment to an abstract DB search.
 """
-from typing import Optional, List, Tuple, Iterator
+from typing import Optional, List, Tuple, Iterator, Dict
 import heapq
 
 from nominatim.api.types import SearchDetails, DataLayer
@@ -89,12 +89,12 @@ class SearchBuilder:
         if sdata is None:
             return
 
-        categories = self.get_search_categories(assignment)
+        near_items = self.get_near_items(assignment)
 
         if assignment.name is None:
-            if categories and not sdata.postcodes:
-                sdata.qualifiers = categories
-                categories = None
+            if near_items and not sdata.postcodes:
+                sdata.qualifiers = near_items
+                near_items = None
                 builder = self.build_poi_search(sdata)
             elif assignment.housenumber:
                 hnr_tokens = self.query.get_tokens(assignment.housenumber,
@@ -102,16 +102,16 @@ class SearchBuilder:
                 builder = self.build_housenumber_search(sdata, hnr_tokens, assignment.address)
             else:
                 builder = self.build_special_search(sdata, assignment.address,
-                                                    bool(categories))
+                                                    bool(near_items))
         else:
             builder = self.build_name_search(sdata, assignment.name, assignment.address,
-                                             bool(categories))
+                                             bool(near_items))
 
-        if categories:
-            penalty = min(categories.penalties)
-            categories.penalties = [p - penalty for p in categories.penalties]
+        if near_items:
+            penalty = min(near_items.penalties)
+            near_items.penalties = [p - penalty for p in near_items.penalties]
             for search in builder:
-                yield dbs.NearSearch(penalty + assignment.penalty, categories, search)
+                yield dbs.NearSearch(penalty + assignment.penalty, near_items, search)
         else:
             for search in builder:
                 search.penalty += assignment.penalty
@@ -321,8 +321,15 @@ class SearchBuilder:
                               self.query.get_tokens(assignment.postcode,
                                                     TokenType.POSTCODE))
         if assignment.qualifier:
-            sdata.set_qualifiers(self.query.get_tokens(assignment.qualifier,
-                                                       TokenType.QUALIFIER))
+            tokens = self.query.get_tokens(assignment.qualifier, TokenType.QUALIFIER)
+            if self.details.categories:
+                tokens = [t for t in tokens if t.get_category() in self.details.categories]
+                if not tokens:
+                    return None
+            sdata.set_qualifiers(tokens)
+        elif self.details.categories:
+            sdata.qualifiers = dbf.WeightedCategories(self.details.categories,
+                                                      [0.0] * len(self.details.categories))
 
         if assignment.address:
             sdata.set_ranking([self.get_addr_ranking(r) for r in assignment.address])
@@ -332,23 +339,19 @@ class SearchBuilder:
         return sdata
 
 
-    def get_search_categories(self,
-                              assignment: TokenAssignment) -> Optional[dbf.WeightedCategories]:
-        """ Collect tokens for category search or use the categories
+    def get_near_items(self, assignment: TokenAssignment) -> Optional[dbf.WeightedCategories]:
+        """ Collect the near item tokens, restricted to any categories
             requested per parameter.
             Returns None if no category search is requested.
         """
-        if assignment.category:
-            tokens = [t for t in self.query.get_tokens(assignment.category,
-                                                       TokenType.CATEGORY)
-                      if not self.details.categories
-                         or t.get_category() in self.details.categories]
-            return dbf.WeightedCategories([t.get_category() for t in tokens],
-                                          [t.penalty for t in tokens])
-
-        if self.details.categories:
-            return dbf.WeightedCategories(self.details.categories,
-                                          [0.0] * len(self.details.categories))
+        if assignment.near_item:
+            tokens: Dict[Tuple[str, str], float] = {}
+            for t in self.query.get_tokens(assignment.near_item, TokenType.NEAR_ITEM):
+                cat = t.get_category()
+                if (not self.details.categories or cat in self.details.categories)\
+                   and t.penalty < tokens.get(cat, 1000.0):
+                    tokens[cat] = t.penalty
+            return dbf.WeightedCategories(list(tokens.keys()), list(tokens.values()))
 
         return None