X-Git-Url: https://git.openstreetmap.org./nominatim.git/blobdiff_plain/c42273a4db2d7b4fe05a0be9210901d35e038887..7d911f9ffbdf63b2b2a45c3a3ee7063d006a5779:/nominatim/api/search/db_search_fields.py?ds=sidebyside diff --git a/nominatim/api/search/db_search_fields.py b/nominatim/api/search/db_search_fields.py index 9fcc2c4e..325e08df 100644 --- a/nominatim/api/search/db_search_fields.py +++ b/nominatim/api/search/db_search_fields.py @@ -7,13 +7,13 @@ """ Data structures for more complex fields in abstract search descriptions. """ -from typing import List, Tuple, cast +from typing import List, Tuple, Iterator, cast import dataclasses import sqlalchemy as sa from sqlalchemy.dialects.postgresql import ARRAY -from nominatim.typing import SaFromClause, SaColumn +from nominatim.typing import SaFromClause, SaColumn, SaExpression from nominatim.api.search.query import Token @dataclasses.dataclass @@ -27,6 +27,21 @@ class WeightedStrings: return bool(self.values) + def __iter__(self) -> Iterator[Tuple[str, float]]: + return iter(zip(self.values, self.penalties)) + + + def get_penalty(self, value: str, default: float = 1000.0) -> float: + """ Get the penalty for the given value. Returns the given default + if the value does not exist. + """ + try: + return self.penalties[self.values.index(value)] + except ValueError: + pass + return default + + @dataclasses.dataclass class WeightedCategories: """ A list of class/type tuples together with a penalty. @@ -38,6 +53,36 @@ class WeightedCategories: return bool(self.values) + def __iter__(self) -> Iterator[Tuple[Tuple[str, str], float]]: + return iter(zip(self.values, self.penalties)) + + + def get_penalty(self, value: Tuple[str, str], default: float = 1000.0) -> float: + """ Get the penalty for the given value. Returns the given default + if the value does not exist. + """ + try: + return self.penalties[self.values.index(value)] + except ValueError: + pass + return default + + + def sql_restrict(self, table: SaFromClause) -> SaExpression: + """ Return an SQLAlcheny expression that restricts the + class and type columns of the given table to the values + in the list. + Must not be used with an empty list. + """ + assert self.values + if len(self.values) == 1: + return sa.and_(table.c.class_ == self.values[0][0], + table.c.type == self.values[0][1]) + + return sa.or_(*(sa.and_(table.c.class_ == c, table.c.type == t) + for c, t in self.values)) + + @dataclasses.dataclass(order=True) class RankedTokens: """ List of tokens together with the penalty of using it.