"""
Data structures for more complex fields in abstract search descriptions.
"""
-from typing import List, Tuple, cast
+from typing import List, Tuple, Iterator, cast
import dataclasses
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import ARRAY
-from nominatim.typing import SaFromClause, SaColumn
+from nominatim.typing import SaFromClause, SaColumn, SaExpression
from nominatim.api.search.query import Token
@dataclasses.dataclass
return bool(self.values)
+ def __iter__(self) -> Iterator[Tuple[str, float]]:
+ return iter(zip(self.values, self.penalties))
+
+
+ def get_penalty(self, value: str, default: float = 1000.0) -> float:
+ """ Get the penalty for the given value. Returns the given default
+ if the value does not exist.
+ """
+ try:
+ return self.penalties[self.values.index(value)]
+ except ValueError:
+ pass
+ return default
+
+
@dataclasses.dataclass
class WeightedCategories:
""" A list of class/type tuples together with a penalty.
return bool(self.values)
+ def __iter__(self) -> Iterator[Tuple[Tuple[str, str], float]]:
+ return iter(zip(self.values, self.penalties))
+
+
+ def get_penalty(self, value: Tuple[str, str], default: float = 1000.0) -> float:
+ """ Get the penalty for the given value. Returns the given default
+ if the value does not exist.
+ """
+ try:
+ return self.penalties[self.values.index(value)]
+ except ValueError:
+ pass
+ return default
+
+
+ def sql_restrict(self, table: SaFromClause) -> SaExpression:
+ """ Return an SQLAlcheny expression that restricts the
+ class and type columns of the given table to the values
+ in the list.
+ Must not be used with an empty list.
+ """
+ assert self.values
+ if len(self.values) == 1:
+ return sa.and_(table.c.class_ == self.values[0][0],
+ table.c.type == self.values[0][1])
+
+ return sa.or_(*(sa.and_(table.c.class_ == c, table.c.type == t)
+ for c, t in self.values))
+
+
@dataclasses.dataclass(order=True)
class RankedTokens:
""" List of tokens together with the penalty of using it.
"""
assert self.rankings
- col = table.c[self.column]
-
- return sa.case(*((col.contains(r.tokens),r.penalty) for r in self.rankings),
- else_=self.default)
+ return sa.func.weigh_search(table.c[self.column],
+ [f"{{{','.join((str(s) for s in r.tokens))}}}"
+ for r in self.rankings],
+ [r.penalty for r in self.rankings],
+ self.default)
@dataclasses.dataclass