]> git.openstreetmap.org Git - nominatim.git/blobdiff - nominatim/api/search/db_search_fields.py
Merge pull request #3293 from lonvia/rematch-against-country-code
[nominatim.git] / nominatim / api / search / db_search_fields.py
index 52693e95fce673026d97c545bc70b37ad52a17cf..6947a565f80dad421dcd9398975284988121a254 100644 (file)
@@ -7,13 +7,16 @@
 """
 Data structures for more complex fields in abstract search descriptions.
 """
 """
 Data structures for more complex fields in abstract search descriptions.
 """
-from typing import List, Tuple, Iterator, cast, Dict
+from typing import List, Tuple, Iterator, Dict, Type
 import dataclasses
 
 import sqlalchemy as sa
 
 from nominatim.typing import SaFromClause, SaColumn, SaExpression
 from nominatim.api.search.query import Token
 import dataclasses
 
 import sqlalchemy as sa
 
 from nominatim.typing import SaFromClause, SaColumn, SaExpression
 from nominatim.api.search.query import Token
+import nominatim.api.search.db_search_lookups as lookups
+from nominatim.utils.json_writer import JsonWriter
+
 
 @dataclasses.dataclass
 class WeightedStrings:
 
 @dataclasses.dataclass
 class WeightedStrings:
@@ -128,11 +131,17 @@ class FieldRanking:
         """
         assert self.rankings
 
         """
         assert self.rankings
 
-        return sa.func.weigh_search(table.c[self.column],
-                                    [f"{{{','.join((str(s) for s in r.tokens))}}}"
-                                     for r in self.rankings],
-                                    [r.penalty for r in self.rankings],
-                                    self.default)
+        rout = JsonWriter().start_array()
+        for rank in self.rankings:
+            rout.start_array().value(rank.penalty).next()
+            rout.start_array()
+            for token in rank.tokens:
+                rout.value(token).next()
+            rout.end_array()
+            rout.end_array().next()
+        rout.end_array()
+
+        return sa.func.weigh_search(table.c[self.column], rout(), self.default)
 
 
 @dataclasses.dataclass
 
 
 @dataclasses.dataclass
@@ -145,18 +154,12 @@ class FieldLookup:
     """
     column: str
     tokens: List[int]
     """
     column: str
     tokens: List[int]
-    lookup_type: str
+    lookup_type: Type[lookups.LookupType]
 
     def sql_condition(self, table: SaFromClause) -> SaColumn:
         """ Create an SQL expression for the given match condition.
         """
 
     def sql_condition(self, table: SaFromClause) -> SaColumn:
         """ Create an SQL expression for the given match condition.
         """
-        col = table.c[self.column]
-        if self.lookup_type == 'lookup_all':
-            return col.contains(self.tokens)
-        if self.lookup_type == 'lookup_any':
-            return cast(SaColumn, col.overlaps(self.tokens))
-
-        return sa.func.coalesce(sa.null(), col).contains(self.tokens) # pylint: disable=not-callable
+        return self.lookup_type(table, self.column, self.tokens)
 
 
 class SearchData:
 
 
 class SearchData:
@@ -222,22 +225,23 @@ def lookup_by_names(name_tokens: List[int], addr_tokens: List[int]) -> List[Fiel
     """ Create a lookup list where name tokens are looked up via index
         and potential address tokens are used to restrict the search further.
     """
     """ Create a lookup list where name tokens are looked up via index
         and potential address tokens are used to restrict the search further.
     """
-    lookup = [FieldLookup('name_vector', name_tokens, 'lookup_all')]
+    lookup = [FieldLookup('name_vector', name_tokens, lookups.LookupAll)]
     if addr_tokens:
     if addr_tokens:
-        lookup.append(FieldLookup('nameaddress_vector', addr_tokens, 'restrict'))
+        lookup.append(FieldLookup('nameaddress_vector', addr_tokens, lookups.Restrict))
 
     return lookup
 
 
 def lookup_by_any_name(name_tokens: List[int], addr_tokens: List[int],
 
     return lookup
 
 
 def lookup_by_any_name(name_tokens: List[int], addr_tokens: List[int],
-                       lookup_type: str) -> List[FieldLookup]:
+                       use_index_for_addr: bool) -> List[FieldLookup]:
     """ Create a lookup list where name tokens are looked up via index
         and only one of the name tokens must be present.
         Potential address tokens are used to restrict the search further.
     """
     """ Create a lookup list where name tokens are looked up via index
         and only one of the name tokens must be present.
         Potential address tokens are used to restrict the search further.
     """
-    lookup = [FieldLookup('name_vector', name_tokens, 'lookup_any')]
+    lookup = [FieldLookup('name_vector', name_tokens, lookups.LookupAny)]
     if addr_tokens:
     if addr_tokens:
-        lookup.append(FieldLookup('nameaddress_vector', addr_tokens, lookup_type))
+        lookup.append(FieldLookup('nameaddress_vector', addr_tokens,
+                                  lookups.LookupAll if use_index_for_addr else lookups.Restrict))
 
     return lookup
 
 
     return lookup
 
@@ -246,5 +250,5 @@ def lookup_by_addr(name_tokens: List[int], addr_tokens: List[int]) -> List[Field
     """ Create a lookup list where address tokens are looked up via index
         and the name tokens are only used to restrict the search further.
     """
     """ Create a lookup list where address tokens are looked up via index
         and the name tokens are only used to restrict the search further.
     """
-    return [FieldLookup('name_vector', name_tokens, 'restrict'),
-            FieldLookup('nameaddress_vector', addr_tokens, 'lookup_all')]
+    return [FieldLookup('name_vector', name_tokens, lookups.Restrict),
+            FieldLookup('nameaddress_vector', addr_tokens, lookups.LookupAll)]