"""
Data structures for more complex fields in abstract search descriptions.
"""
-from typing import List, Tuple, cast
+from typing import List, Tuple, Iterator, Dict, Type
import dataclasses
import sqlalchemy as sa
-from sqlalchemy.dialects.postgresql import ARRAY
-from nominatim.typing import SaFromClause, SaColumn
+from nominatim.typing import SaFromClause, SaColumn, SaExpression
from nominatim.api.search.query import Token
+import nominatim.api.search.db_search_lookups as lookups
+from nominatim.utils.json_writer import JsonWriter
+
@dataclasses.dataclass
class WeightedStrings:
return bool(self.values)
+ def __iter__(self) -> Iterator[Tuple[str, float]]:
+ """ Iterate over the values, each paired with the penalty stored at
+ the same position. Iteration stops with the shorter of the two
+ sequences.
+ """
+ return iter(zip(self.values, self.penalties))
+
+
+ def get_penalty(self, value: str, default: float = 1000.0) -> float:
+ """ Get the penalty for the given value. Returns the given default
+ if the value does not exist.
+ When the value appears more than once, the penalty of the first
+ match found by index() is returned.
+ """
+ try:
+ # values and penalties are used as parallel sequences: the
+ # position of the value selects the penalty.
+ return self.penalties[self.values.index(value)]
+ except ValueError:
+ # index() reports an unknown value with ValueError; the default
+ # is returned instead. Other errors are not caught here.
+ pass
+ return default
+
+
@dataclasses.dataclass
class WeightedCategories:
""" A list of class/type tuples together with a penalty.
return bool(self.values)
+ def __iter__(self) -> Iterator[Tuple[Tuple[str, str], float]]:
+ """ Iterate over the class/type tuples, each paired with the penalty
+ stored at the same position. Iteration stops with the shorter of
+ the two sequences.
+ """
+ return iter(zip(self.values, self.penalties))
+
+
+ def get_penalty(self, value: Tuple[str, str], default: float = 1000.0) -> float:
+ """ Get the penalty for the given value. Returns the given default
+ if the value does not exist.
+ When the class/type tuple appears more than once, the penalty of
+ the first match found by index() is returned.
+ """
+ try:
+ # values and penalties are used as parallel sequences: the
+ # position of the tuple selects the penalty.
+ return self.penalties[self.values.index(value)]
+ except ValueError:
+ # index() reports an unknown tuple with ValueError; the default
+ # is returned instead. Other errors are not caught here.
+ pass
+ return default
+
+
+ def sql_restrict(self, table: SaFromClause) -> SaExpression:
+ """ Return an SQLAlchemy expression that restricts the
+ class and type columns of the given table to the values
+ in the list.
+ Must not be used with an empty list.
+ """
+ # The non-empty precondition is only checked by this assertion.
+ assert self.values
+ if len(self.values) == 1:
+ # A single class/type pair needs no surrounding OR.
+ return sa.and_(table.c.class_ == self.values[0][0],
+ table.c.type == self.values[0][1])
+
+ # Several pairs: a row matches when its class and type columns both
+ # equal any one of the listed pairs.
+ return sa.or_(*(sa.and_(table.c.class_ == c, table.c.type == t)
+ for c, t in self.values))
+
+
@dataclasses.dataclass(order=True)
class RankedTokens:
""" List of tokens together with the penalty of using it.
def with_token(self, t: Token, transition_penalty: float) -> 'RankedTokens':
""" Create a new RankedTokens list with the given token appended.
- The tokens penalty as well as the given transision penalty
+ The tokens penalty as well as the given transition penalty
are added to the overall penalty.
"""
return RankedTokens(self.penalty + t.penalty + transition_penalty,
"""
assert self.rankings
- col = table.c[self.column]
+ rout = JsonWriter().start_array()
+ for rank in self.rankings:
+ rout.start_array().value(rank.penalty).next()
+ rout.start_array()
+ for token in rank.tokens:
+ rout.value(token).next()
+ rout.end_array()
+ rout.end_array().next()
+ rout.end_array()
- return sa.case(*((col.contains(r.tokens),r.penalty) for r in self.rankings),
- else_=self.default)
+ return sa.func.weigh_search(table.c[self.column], rout(), self.default)
@dataclasses.dataclass
"""
column: str
tokens: List[int]
- lookup_type: str
+ lookup_type: Type[lookups.LookupType]
def sql_condition(self, table: SaFromClause) -> SaColumn:
""" Create an SQL expression for the given match condition.
"""
- col = table.c[self.column]
- if self.lookup_type == 'lookup_all':
- return col.contains(self.tokens)
- if self.lookup_type == 'lookup_any':
- return cast(SaColumn, col.overlap(self.tokens))
-
- return sa.func.array_cat(col, sa.text('ARRAY[]::integer[]'),
- type_=ARRAY(sa.Integer())).contains(self.tokens)
+ return self.lookup_type(table, self.column, self.tokens)
class SearchData:
""" Set the qulaifier field from the given tokens.
"""
if tokens:
- min_penalty = min(t.penalty for t in tokens)
+ categories: Dict[Tuple[str, str], float] = {}
+ min_penalty = 1000.0
+ for t in tokens:
+ min_penalty = min(min_penalty, t.penalty)
+ cat = t.get_category()
+ if t.penalty < categories.get(cat, 1000.0):
+ categories[cat] = t.penalty
self.penalty += min_penalty
- self.qualifiers = WeightedCategories([t.get_category() for t in tokens],
- [t.penalty - min_penalty for t in tokens])
+ self.qualifiers = WeightedCategories(list(categories.keys()),
+ list(categories.values()))
def set_ranking(self, rankings: List[FieldRanking]) -> None:
self.rankings.append(ranking)
else:
self.penalty += ranking.default
+
+
+def lookup_by_names(name_tokens: List[int], addr_tokens: List[int]) -> List[FieldLookup]:
+ """ Create a lookup list where name tokens are looked up via index
+ and potential address tokens are used to restrict the search further.
+ The result always starts with a LookupAll lookup on 'name_vector'.
+ A Restrict lookup on 'nameaddress_vector' is appended only when
+ addr_tokens is not empty.
+ """
+ lookup = [FieldLookup('name_vector', name_tokens, lookups.LookupAll)]
+ if addr_tokens:
+ lookup.append(FieldLookup('nameaddress_vector', addr_tokens, lookups.Restrict))
+
+ return lookup
+
+
+def lookup_by_any_name(name_tokens: List[int], addr_tokens: List[int],
+ use_index_for_addr: bool) -> List[FieldLookup]:
+ """ Create a lookup list where name tokens are looked up via index
+ and only one of the name tokens must be present.
+ Potential address tokens are used to restrict the search further.
+ The result always starts with a LookupAny lookup on 'name_vector'.
+ When addr_tokens is not empty, a lookup on 'nameaddress_vector' is
+ appended: LookupAll if use_index_for_addr is set, Restrict otherwise.
+ """
+ lookup = [FieldLookup('name_vector', name_tokens, lookups.LookupAny)]
+ if addr_tokens:
+ # The flag only selects the lookup type used for the address tokens.
+ lookup.append(FieldLookup('nameaddress_vector', addr_tokens,
+ lookups.LookupAll if use_index_for_addr else lookups.Restrict))
+
+ return lookup
+
+
+def lookup_by_addr(name_tokens: List[int], addr_tokens: List[int]) -> List[FieldLookup]:
+ """ Create a lookup list where address tokens are looked up via index
+ and the name tokens are only used to restrict the search further.
+ The result holds a Restrict lookup on 'name_vector' followed by a
+ LookupAll lookup on 'nameaddress_vector'. Both entries are always
+ present; neither token list is checked for being empty.
+ """
+ return [FieldLookup('name_vector', name_tokens, lookups.Restrict),
+ FieldLookup('nameaddress_vector', addr_tokens, lookups.LookupAll)]