1 # SPDX-License-Identifier: GPL-3.0-or-later
3 # This file is part of Nominatim. (https://nominatim.org)
5 # Copyright (C) 2023 by the Nominatim developer community.
6 # For a full list of authors see the git log.
"""
Data structures for more complex fields in abstract search descriptions.
"""
from typing import List, Tuple, Iterator, cast, Dict
import dataclasses

import sqlalchemy as sa

from nominatim.typing import SaFromClause, SaColumn, SaExpression
from nominatim.api.search.query import Token
@dataclasses.dataclass
class WeightedStrings:
    """ A list of strings together with a penalty.

        `values` and `penalties` are parallel lists: the penalty at
        index i belongs to the string at index i.
    """
    values: List[str]
    penalties: List[float]

    def __bool__(self) -> bool:
        """ True when the list contains at least one string.
        """
        return bool(self.values)


    def __iter__(self) -> Iterator[Tuple[str, float]]:
        """ Iterate over (string, penalty) pairs.
        """
        return iter(zip(self.values, self.penalties))


    def get_penalty(self, value: str, default: float = 1000.0) -> float:
        """ Get the penalty for the given value. Returns the given default
            if the value does not exist.
        """
        try:
            return self.penalties[self.values.index(value)]
        except ValueError:
            # list.index() raises ValueError when the value is not present.
            return default
@dataclasses.dataclass
class WeightedCategories:
    """ A list of class/type tuples together with a penalty.

        `values` and `penalties` are parallel lists: the penalty at
        index i belongs to the (class, type) tuple at index i.
    """
    values: List[Tuple[str, str]]
    penalties: List[float]

    def __bool__(self) -> bool:
        """ True when the list contains at least one category.
        """
        return bool(self.values)


    def __iter__(self) -> Iterator[Tuple[Tuple[str, str], float]]:
        """ Iterate over ((class, type), penalty) pairs.
        """
        return iter(zip(self.values, self.penalties))


    def get_penalty(self, value: Tuple[str, str], default: float = 1000.0) -> float:
        """ Get the penalty for the given value. Returns the given default
            if the value does not exist.
        """
        try:
            return self.penalties[self.values.index(value)]
        except ValueError:
            # list.index() raises ValueError when the value is not present.
            return default


    def sql_restrict(self, table: SaFromClause) -> SaExpression:
        """ Return an SQLAlchemy expression that restricts the
            class and type columns of the given table to the values
            in the list.

            Must not be used with an empty list.

            Raises:
                ValueError: if the list of categories is empty.
        """
        if not self.values:
            # An empty OR is not a meaningful restriction; fail loudly instead.
            raise ValueError("Cannot build a class/type restriction from an empty list.")

        if len(self.values) == 1:
            return sa.and_(table.c.class_ == self.values[0][0],
                           table.c.type == self.values[0][1])

        return sa.or_(*(sa.and_(table.c.class_ == c, table.c.type == t)
                        for c, t in self.values))
@dataclasses.dataclass(order=True)
class RankedTokens:
    """ List of tokens together with the penalty of using it.

        Because `penalty` is the first field and the dataclass is created
        with order=True, instances sort by penalty first.
    """
    penalty: float
    tokens: List[int]

    def with_token(self, t: Token, transition_penalty: float) -> 'RankedTokens':
        """ Create a new RankedTokens list with the given token appended.
            The token's penalty as well as the given transition penalty
            are added to the overall penalty. This object is not modified.
        """
        return RankedTokens(self.penalty + t.penalty + transition_penalty,
                            self.tokens + [t.token])
@dataclasses.dataclass
class FieldRanking:
    """ A list of rankings to be applied sequentially until one matches.
        The matched ranking determines the penalty. If none matches a
        default penalty is applied.
    """
    column: str
    default: float
    rankings: List[RankedTokens]

    def normalize_penalty(self) -> float:
        """ Reduce the default and ranking penalties, such that the minimum
            penalty is 0. Return the penalty that was subtracted.

            Without any rankings only the default penalty is considered.
            The minimum is only subtracted when it is positive.
        """
        if self.rankings:
            min_penalty = min(self.default, min(r.penalty for r in self.rankings))
        else:
            # min() over an empty sequence would raise ValueError.
            min_penalty = self.default
        if min_penalty > 0.0:
            self.default -= min_penalty
            for ranking in self.rankings:
                ranking.penalty -= min_penalty
        return min_penalty


    def sql_penalty(self, table: SaFromClause) -> SaColumn:
        """ Create an SQL expression for the rankings.

            Calls the SQL function weigh_search() with the column, the
            token lists (formatted as array literal text, e.g. '{1,2,3}'),
            the matching penalties and the default penalty.

            Raises:
                ValueError: if there are no rankings.
        """
        if not self.rankings:
            raise ValueError("Cannot build a ranking expression without rankings.")

        return sa.func.weigh_search(table.c[self.column],
                                    [f"{{{','.join(str(s) for s in r.tokens)}}}"
                                     for r in self.rankings],
                                    [r.penalty for r in self.rankings],
                                    self.default)
@dataclasses.dataclass
class FieldLookup:
    """ A list of tokens to be searched for. The column names the database
        column to search in and the lookup_type the operator that is applied.
        'lookup_all' requires all tokens to match. 'lookup_any' requires
        one of the tokens to match. 'restrict' requires to match all tokens
        but avoids the use of indexes.
    """
    column: str
    tokens: List[int]
    lookup_type: str

    def sql_condition(self, table: SaFromClause) -> SaColumn:
        """ Create an SQL expression for the given match condition.

            Any lookup type other than 'lookup_all' and 'lookup_any' is
            treated as 'restrict'.
        """
        col = table.c[self.column]
        if self.lookup_type == 'lookup_all':
            return col.contains(self.tokens)
        if self.lookup_type == 'lookup_any':
            return cast(SaColumn, col.overlaps(self.tokens))

        # 'restrict': the same containment test as 'lookup_all', but the
        # column is hidden behind a function call so that no index on it
        # is used.
        return sa.func.coalesce(sa.null(), col).contains(self.tokens) # pylint: disable=not-callable
class SearchData:
    """ Search fields derived from query and token assignment
        to be used with the SQL queries.
    """
    # The set_* methods add to this value, so it must be initialised
    # (presumably by the creator of the object) before they are called.
    penalty: float

    # NOTE(review): class-level list shared by all instances; assign a new
    # list instead of mutating it in place.
    lookups: List[FieldLookup] = []
    rankings: List[FieldRanking]

    housenumbers: WeightedStrings = WeightedStrings([], [])
    postcodes: WeightedStrings = WeightedStrings([], [])
    countries: WeightedStrings = WeightedStrings([], [])

    qualifiers: WeightedCategories = WeightedCategories([], [])


    def set_strings(self, field: str, tokens: List[Token]) -> None:
        """ Set one of the WeightedStrings properties from the given
            token list. Adapt the global penalty, so that the
            minimum penalty is 0.

            An empty token list leaves the field and the penalty unchanged.
        """
        if not tokens:
            # min() over an empty sequence would raise ValueError.
            return

        min_penalty = min(t.penalty for t in tokens)
        self.penalty += min_penalty
        wstrs = WeightedStrings([t.lookup_word for t in tokens],
                                [t.penalty - min_penalty for t in tokens])

        setattr(self, field, wstrs)


    def set_qualifiers(self, tokens: List[Token]) -> None:
        """ Set the qualifier field from the given tokens.

            Tokens are grouped by category and each category keeps the
            lowest penalty among its tokens. The lowest penalty overall is
            added to the global penalty. An empty token list leaves the
            field and the penalty unchanged.
        """
        if not tokens:
            return

        categories: Dict[Tuple[str, str], float] = {}
        for t in tokens:
            cat = t.get_category()
            if cat not in categories or t.penalty < categories[cat]:
                categories[cat] = t.penalty

        # NOTE(review): unlike set_strings(), the stored per-category
        # penalties are not reduced by the minimum — confirm this is intended.
        self.penalty += min(t.penalty for t in tokens)
        self.qualifiers = WeightedCategories(list(categories.keys()),
                                             list(categories.values()))


    def set_ranking(self, rankings: List[FieldRanking]) -> None:
        """ Set the list of rankings and normalize the ranking.

            A ranking with token lists is normalized: its minimum penalty
            moves into the global penalty, and the ranking is kept. A ranking
            without token lists always falls back to its default penalty.
            So only that default is added to the global penalty and the
            ranking itself is dropped.
        """
        self.rankings = []
        for ranking in rankings:
            if ranking.rankings:
                self.penalty += ranking.normalize_penalty()
                self.rankings.append(ranking)
            else:
                self.penalty += ranking.default
def lookup_by_names(name_tokens: List[int], addr_tokens: List[int]) -> List[FieldLookup]:
    """ Create a lookup list where name tokens are looked up via index
        and potential address tokens are used to restrict the search further.

        The address restriction is only added when there are address tokens.
    """
    lookup = [FieldLookup('name_vector', name_tokens, 'lookup_all')]
    if addr_tokens:
        lookup.append(FieldLookup('nameaddress_vector', addr_tokens, 'restrict'))

    return lookup
def lookup_by_any_name(name_tokens: List[int], addr_tokens: List[int],
                       lookup_type: str) -> List[FieldLookup]:
    """ Create a lookup list where name tokens are looked up via index
        and only one of the name tokens must be present.
        Potential address tokens are used to restrict the search further.

        The address lookup (using the given lookup_type) is only added
        when there are address tokens.
    """
    lookup = [FieldLookup('name_vector', name_tokens, 'lookup_any')]
    if addr_tokens:
        lookup.append(FieldLookup('nameaddress_vector', addr_tokens, lookup_type))

    return lookup
def lookup_by_addr(name_tokens: List[int], addr_tokens: List[int]) -> List[FieldLookup]:
    """ Build the lookups for a search that is driven by the address.

        All address tokens must match and are looked up via index. The
        name tokens only narrow the result down further ('restrict'),
        without using an index.
    """
    name_restriction = FieldLookup('name_vector', name_tokens, 'restrict')
    address_lookup = FieldLookup('nameaddress_vector', addr_tokens, 'lookup_all')
    return [name_restriction, address_lookup]