]> git.openstreetmap.org Git - nominatim.git/blobdiff - nominatim/api/search/db_search_fields.py
Merge remote-tracking branch 'upstream/master'
[nominatim.git] / nominatim / api / search / db_search_fields.py
index 9fcc2c4e521e9aa3ba55207edcf438af79a26949..2b2e3e56761e6ff15012297c144c52e30c4d74b1 100644 (file)
@@ -7,13 +7,13 @@
 """
 Data structures for more complex fields in abstract search descriptions.
 """
 """
 Data structures for more complex fields in abstract search descriptions.
 """
-from typing import List, Tuple, cast
+from typing import List, Tuple, Iterator, cast
 import dataclasses
 
 import sqlalchemy as sa
 from sqlalchemy.dialects.postgresql import ARRAY
 
 import dataclasses
 
 import sqlalchemy as sa
 from sqlalchemy.dialects.postgresql import ARRAY
 
-from nominatim.typing import SaFromClause, SaColumn
+from nominatim.typing import SaFromClause, SaColumn, SaExpression
 from nominatim.api.search.query import Token
 
 @dataclasses.dataclass
 from nominatim.api.search.query import Token
 
 @dataclasses.dataclass
@@ -27,6 +27,21 @@ class WeightedStrings:
         return bool(self.values)
 
 
         return bool(self.values)
 
 
+    def __iter__(self) -> Iterator[Tuple[str, float]]:
+        return iter(zip(self.values, self.penalties))
+
+
+    def get_penalty(self, value: str, default: float = 1000.0) -> float:
+        """ Get the penalty for the given value. Returns the given default
+            if the value does not exist.
+        """
+        try:
+            return self.penalties[self.values.index(value)]
+        except ValueError:
+            pass
+        return default
+
+
 @dataclasses.dataclass
 class WeightedCategories:
     """ A list of class/type tuples together with a penalty.
 @dataclasses.dataclass
 class WeightedCategories:
     """ A list of class/type tuples together with a penalty.
@@ -38,6 +53,36 @@ class WeightedCategories:
         return bool(self.values)
 
 
         return bool(self.values)
 
 
+    def __iter__(self) -> Iterator[Tuple[Tuple[str, str], float]]:
+        return iter(zip(self.values, self.penalties))
+
+
+    def get_penalty(self, value: Tuple[str, str], default: float = 1000.0) -> float:
+        """ Get the penalty for the given value. Returns the given default
+            if the value does not exist.
+        """
+        try:
+            return self.penalties[self.values.index(value)]
+        except ValueError:
+            pass
+        return default
+
+
+    def sql_restrict(self, table: SaFromClause) -> SaExpression:
+        """ Return an SQLAlcheny expression that restricts the
+            class and type columns of the given table to the values
+            in the list.
+            Must not be used with an empty list.
+        """
+        assert self.values
+        if len(self.values) == 1:
+            return sa.and_(table.c.class_ == self.values[0][0],
+                           table.c.type == self.values[0][1])
+
+        return sa.or_(*(sa.and_(table.c.class_ == c, table.c.type == t)
+                        for c, t in self.values))
+
+
 @dataclasses.dataclass(order=True)
 class RankedTokens:
     """ List of tokens together with the penalty of using it.
 @dataclasses.dataclass(order=True)
 class RankedTokens:
     """ List of tokens together with the penalty of using it.
@@ -84,10 +129,11 @@ class FieldRanking:
         """
         assert self.rankings
 
         """
         assert self.rankings
 
-        col = table.c[self.column]
-
-        return sa.case(*((col.contains(r.tokens),r.penalty) for r in self.rankings),
-                       else_=self.default)
+        return sa.func.weigh_search(table.c[self.column],
+                                    [f"{{{','.join((str(s) for s in r.tokens))}}}"
+                                     for r in self.rankings],
+                                    [r.penalty for r in self.rankings],
+                                    self.default)
 
 
 @dataclasses.dataclass
 
 
 @dataclasses.dataclass
@@ -165,3 +211,34 @@ class SearchData:
                 self.rankings.append(ranking)
             else:
                 self.penalty += ranking.default
                 self.rankings.append(ranking)
             else:
                 self.penalty += ranking.default
+
+
+def lookup_by_names(name_tokens: List[int], addr_tokens: List[int]) -> List[FieldLookup]:
+    """ Create a lookup list where name tokens are looked up via index
+        and potential address tokens are used to restrict the search further.
+    """
+    lookup = [FieldLookup('name_vector', name_tokens, 'lookup_all')]
+    if addr_tokens:
+        lookup.append(FieldLookup('nameaddress_vector', addr_tokens, 'restrict'))
+
+    return lookup
+
+
+def lookup_by_any_name(name_tokens: List[int], addr_tokens: List[int]) -> List[FieldLookup]:
+    """ Create a lookup list where name tokens are looked up via index
+        and only one of the name tokens must be present.
+        Potential address tokens are used to restrict the search further.
+    """
+    lookup = [FieldLookup('name_vector', name_tokens, 'lookup_any')]
+    if addr_tokens:
+        lookup.append(FieldLookup('nameaddress_vector', addr_tokens, 'restrict'))
+
+    return lookup
+
+
+def lookup_by_addr(name_tokens: List[int], addr_tokens: List[int]) -> List[FieldLookup]:
+    """ Create a lookup list where address tokens are looked up via index
+        and the name tokens are only used to restrict the search further.
+    """
+    return [FieldLookup('name_vector', name_tokens, 'restrict'),
+            FieldLookup('nameaddress_vector', addr_tokens, 'lookup_all')]