]> git.openstreetmap.org Git - nominatim.git/blobdiff - nominatim/api/search/db_search_fields.py
Merge pull request #3408 from lonvia/update-postcode-parents
[nominatim.git] / nominatim / api / search / db_search_fields.py
index 13f1c56eb09493261a10ece0e728c9e7cf0f252d..7f775277e6960a7dc37d76a7181b5ad4e4273084 100644 (file)
@@ -7,14 +7,16 @@
 """
 Data structures for more complex fields in abstract search descriptions.
 """
-from typing import List, Tuple, Iterator, cast
+from typing import List, Tuple, Iterator, Dict, Type
 import dataclasses
 
 import sqlalchemy as sa
-from sqlalchemy.dialects.postgresql import ARRAY
 
 from nominatim.typing import SaFromClause, SaColumn, SaExpression
 from nominatim.api.search.query import Token
+import nominatim.api.search.db_search_lookups as lookups
+from nominatim.utils.json_writer import JsonWriter
+
 
 @dataclasses.dataclass
 class WeightedStrings:
@@ -92,7 +94,7 @@ class RankedTokens:
 
     def with_token(self, t: Token, transition_penalty: float) -> 'RankedTokens':
         """ Create a new RankedTokens list with the given token appended.
-            The tokens penalty as well as the given transision penalty
+            The token's penalty as well as the given transition penalty
             are added to the overall penalty.
         """
         return RankedTokens(self.penalty + t.penalty + transition_penalty,
@@ -129,11 +131,17 @@ class FieldRanking:
         """
         assert self.rankings
 
-        return sa.func.weigh_search(table.c[self.column],
-                                    [f"{{{','.join((str(s) for s in r.tokens))}}}"
-                                     for r in self.rankings],
-                                    [r.penalty for r in self.rankings],
-                                    self.default)
+        rout = JsonWriter().start_array()
+        for rank in self.rankings:
+            rout.start_array().value(rank.penalty).next()
+            rout.start_array()
+            for token in rank.tokens:
+                rout.value(token).next()
+            rout.end_array()
+            rout.end_array().next()
+        rout.end_array()
+
+        return sa.func.weigh_search(table.c[self.column], rout(), self.default)
 
 
 @dataclasses.dataclass
@@ -146,19 +154,12 @@ class FieldLookup:
     """
     column: str
     tokens: List[int]
-    lookup_type: str
+    lookup_type: Type[lookups.LookupType]
 
     def sql_condition(self, table: SaFromClause) -> SaColumn:
         """ Create an SQL expression for the given match condition.
         """
-        col = table.c[self.column]
-        if self.lookup_type == 'lookup_all':
-            return col.contains(self.tokens)
-        if self.lookup_type == 'lookup_any':
-            return cast(SaColumn, col.overlap(self.tokens))
-
-        return sa.func.array_cat(col, sa.text('ARRAY[]::integer[]'),
-                                 type_=ARRAY(sa.Integer())).contains(self.tokens)
+        return self.lookup_type(table, self.column, self.tokens)
 
 
 class SearchData:
@@ -195,10 +196,16 @@ class SearchData:
         """ Set the qualifier field from the given tokens.
         """
         if tokens:
-            min_penalty = min(t.penalty for t in tokens)
+            categories: Dict[Tuple[str, str], float] = {}
+            min_penalty = 1000.0
+            for t in tokens:
+                min_penalty = min(min_penalty, t.penalty)
+                cat = t.get_category()
+                if t.penalty < categories.get(cat, 1000.0):
+                    categories[cat] = t.penalty
             self.penalty += min_penalty
-            self.qualifiers = WeightedCategories([t.get_category() for t in tokens],
-                                                 [t.penalty - min_penalty for t in tokens])
+            self.qualifiers = WeightedCategories(list(categories.keys()),
+                                                 list(categories.values()))
 
 
     def set_ranking(self, rankings: List[FieldRanking]) -> None:
@@ -211,3 +218,37 @@ class SearchData:
                 self.rankings.append(ranking)
             else:
                 self.penalty += ranking.default
+
+
+def lookup_by_names(name_tokens: List[int], addr_tokens: List[int]) -> List[FieldLookup]:
+    """ Create a lookup list where name tokens are looked up via index
+        and potential address tokens are used to restrict the search further.
+    """
+    lookup = [FieldLookup('name_vector', name_tokens, lookups.LookupAll)]
+    if addr_tokens:
+        lookup.append(FieldLookup('nameaddress_vector', addr_tokens, lookups.Restrict))
+
+    return lookup
+
+
+def lookup_by_any_name(name_tokens: List[int], addr_restrict_tokens: List[int],
+                       addr_lookup_tokens: List[int]) -> List[FieldLookup]:
+    """ Create a lookup list where name tokens are looked up via index
+        and at least one of the name tokens must be present.
+        Potential address tokens are used to restrict the search further.
+    """
+    lookup = [FieldLookup('name_vector', name_tokens, lookups.LookupAny)]
+    if addr_restrict_tokens:
+        lookup.append(FieldLookup('nameaddress_vector', addr_restrict_tokens, lookups.Restrict))
+    if addr_lookup_tokens:
+        lookup.append(FieldLookup('nameaddress_vector', addr_lookup_tokens, lookups.LookupAll))
+
+    return lookup
+
+
+def lookup_by_addr(name_tokens: List[int], addr_tokens: List[int]) -> List[FieldLookup]:
+    """ Create a lookup list where address tokens are looked up via index
+        and the name tokens are only used to restrict the search further.
+    """
+    return [FieldLookup('name_vector', name_tokens, lookups.Restrict),
+            FieldLookup('nameaddress_vector', addr_tokens, lookups.LookupAll)]