# SPDX-License-Identifier: GPL-3.0-or-later
# This file is part of Nominatim. (https://nominatim.org)
# Copyright (C) 2024 by the Nominatim developer community.
# For a full list of authors see the git log.
"""
Data structures for more complex fields in abstract search descriptions.
"""
import dataclasses
from typing import List, Tuple, Iterator, Dict, Type

import sqlalchemy as sa

from ..typing import SaFromClause, SaColumn, SaExpression
from ..utils.json_writer import JsonWriter
from .query import Token
from . import db_search_lookups as lookups


@dataclasses.dataclass
class WeightedStrings:
    """ A list of strings together with a penalty.

        `values` and `penalties` are parallel lists: the penalty at
        index i belongs to the string at index i.
    """
    values: List[str]
    penalties: List[float]

    def __bool__(self) -> bool:
        """ True when the list contains at least one string.
        """
        return bool(self.values)

    def __iter__(self) -> Iterator[Tuple[str, float]]:
        """ Iterate over (value, penalty) pairs.
        """
        return iter(zip(self.values, self.penalties))

    def get_penalty(self, value: str, default: float = 1000.0) -> float:
        """ Get the penalty for the given value. Returns the given default
            if the value does not exist.
        """
        try:
            return self.penalties[self.values.index(value)]
        except ValueError:
            # list.index() signals a missing value with ValueError.
            return default


@dataclasses.dataclass
class WeightedCategories:
    """ A list of class/type tuples together with a penalty.

        `values` and `penalties` are parallel lists: the penalty at
        index i belongs to the (class, type) tuple at index i.
    """
    values: List[Tuple[str, str]]
    penalties: List[float]

    def __bool__(self) -> bool:
        """ True when the list contains at least one category.
        """
        return bool(self.values)

    def __iter__(self) -> Iterator[Tuple[Tuple[str, str], float]]:
        """ Iterate over ((class, type), penalty) pairs.
        """
        return iter(zip(self.values, self.penalties))

    def get_penalty(self, value: Tuple[str, str], default: float = 1000.0) -> float:
        """ Get the penalty for the given value. Returns the given default
            if the value does not exist.
        """
        try:
            return self.penalties[self.values.index(value)]
        except ValueError:
            # list.index() signals a missing value with ValueError.
            return default

    def sql_restrict(self, table: SaFromClause) -> SaExpression:
        """ Return an SQLAlchemy expression that restricts the
            class and type columns of the given table to the values
            in the list.

            Must not be used with an empty list.

            Raises:
                ValueError: if the category list is empty.
        """
        # sa.or_() without arguments yields no usable restriction, so an
        # empty list is rejected explicitly instead of producing bad SQL.
        if not self.values:
            raise ValueError("sql_restrict() requires at least one category")

        if len(self.values) == 1:
            cls, typ = self.values[0]
            return sa.and_(table.c.class_ == cls, table.c.type == typ)

        return sa.or_(*(sa.and_(table.c.class_ == c, table.c.type == t)
                        for c, t in self.values))


@dataclasses.dataclass(order=True)
class RankedTokens:
    """ List of tokens together with the penalty of using it.

        Instances compare by `penalty` first and then by `tokens`
        (dataclass field order), so sorting puts the cheapest first.
    """
    penalty: float
    tokens: List[int]

    def with_token(self, t: Token, transition_penalty: float) -> 'RankedTokens':
        """ Create a new RankedTokens list with the given token appended.
            The tokens penalty as well as the given transition penalty
            are added to the overall penalty.

            The current instance is left unchanged; a fresh token list
            is built for the new object.
        """
        return RankedTokens(self.penalty + t.penalty + transition_penalty,
                            self.tokens + [t.token])


@dataclasses.dataclass
class FieldRanking:
    """ A list of rankings to be applied sequentially until one matches.
        The matched ranking determines the penalty. If none matches a
        default penalty is applied.
    """
    column: str
    default: float
    rankings: List[RankedTokens]

    def normalize_penalty(self) -> float:
        """ Reduce the default and ranking penalties, such that the minimum
            penalty is 0. Return the penalty that was subtracted.

            When the smallest penalty is not positive, nothing is changed
            and 0.0 is returned.
        """
        # The default takes part in the minimum, which also covers the
        # case of an empty ranking list.
        min_penalty = min([self.default, *(r.penalty for r in self.rankings)])
        if min_penalty <= 0.0:
            return 0.0

        self.default -= min_penalty
        for ranking in self.rankings:
            ranking.penalty -= min_penalty
        return min_penalty

    def sql_penalty(self, table: SaFromClause) -> SaColumn:
        """ Create an SQL expression for the rankings.

            The rankings are serialized as a JSON array of
            `[penalty, [token, ...]]` entries and passed, together with the
            ranked column and the default penalty, to the SQL function
            `weigh_search`.
        """
        rout = JsonWriter().start_array()
        for rank in self.rankings:
            rout.start_array().value(rank.penalty).next()
            rout.start_array()
            for token in rank.tokens:
                rout.value(token).next()
            # NOTE(review): relies on JsonWriter.end_array() discarding the
            # separator left pending by the final next() call.
            rout.end_array()
            rout.end_array().next()
        rout.end_array()

        return sa.func.weigh_search(table.c[self.column], rout(), self.default)


@dataclasses.dataclass
class FieldLookup:
    """ A list of tokens to be searched for. The column names the database
        column to search in and the lookup_type the operator that is applied:
        `lookups.LookupAll` requires all tokens to match, `lookups.LookupAny`
        requires one of the tokens to match and `lookups.Restrict` requires
        all tokens to match but avoids the use of indexes.
    """
    column: str
    tokens: List[int]
    lookup_type: Type[lookups.LookupType]

    def sql_condition(self, table: SaFromClause) -> SaColumn:
        """ Create an SQL expression for the given match condition.

            The lookup type is instantiated with the table, the column name
            and the token list; the resulting object is the condition.
        """
        return self.lookup_type(table, self.column, self.tokens)


class SearchData:
    """ Search fields derived from query and token assignment
        to be used with the SQL queries.
    """
    # Global penalty of the search. It has no default here; presumably the
    # caller sets it before the set_*() methods below add to it.
    penalty: float

    # NOTE(review): these class-level defaults are shared by all instances.
    # The methods of this class only ever rebind them (setattr/assignment)
    # and never mutate them in place, which keeps the sharing harmless.
    lookups: List[FieldLookup] = []
    rankings: List[FieldRanking]

    housenumbers: WeightedStrings = WeightedStrings([], [])
    postcodes: WeightedStrings = WeightedStrings([], [])
    countries: WeightedStrings = WeightedStrings([], [])

    qualifiers: WeightedCategories = WeightedCategories([], [])

    def set_strings(self, field: str, tokens: List[Token]) -> None:
        """ Set one of the WeightedStrings properties from the given
            token list. Adapt the global penalty, so that the
            minimum penalty is 0.

            An empty token list leaves the field and the penalty untouched.
        """
        if not tokens:
            return

        min_penalty = min(t.penalty for t in tokens)
        self.penalty += min_penalty
        wstrs = WeightedStrings([t.lookup_word for t in tokens],
                                [t.penalty - min_penalty for t in tokens])

        setattr(self, field, wstrs)

    def set_qualifiers(self, tokens: List[Token]) -> None:
        """ Set the qualifier field from the given tokens.

            Tokens are grouped by their category; each category keeps the
            lowest penalty among its tokens. The global penalty is raised by
            the lowest token penalty overall. An empty token list leaves the
            field and the penalty untouched.
        """
        if not tokens:
            return

        categories: Dict[Tuple[str, str], float] = {}
        min_penalty = 1000.0
        for t in tokens:
            min_penalty = min(min_penalty, t.penalty)
            cat = t.get_category()
            if t.penalty < categories.get(cat, 1000.0):
                categories[cat] = t.penalty
        # NOTE(review): unlike set_strings(), the per-category penalties are
        # not reduced by min_penalty; confirm this is intended.
        self.penalty += min_penalty
        self.qualifiers = WeightedCategories(list(categories.keys()),
                                             list(categories.values()))

    def set_ranking(self, rankings: List[FieldRanking]) -> None:
        """ Set the list of rankings and normalize the ranking.

            Rankings with entries are normalized and kept; their subtracted
            minimum moves into the global penalty. Rankings without entries
            are dropped and only contribute their default penalty.
        """
        self.rankings = []
        for ranking in rankings:
            if ranking.rankings:
                self.penalty += ranking.normalize_penalty()
                self.rankings.append(ranking)
            else:
                # Nothing can match, so the default always applies and can be
                # folded straight into the global penalty.
                self.penalty += ranking.default


def lookup_by_names(name_tokens: List[int], addr_tokens: List[int]) -> List[FieldLookup]:
    """ Create a lookup list where name tokens are looked up via index
        and potential address tokens are used to restrict the search further.

        The address restriction is only added when address tokens are given.
    """
    lookup = [FieldLookup('name_vector', name_tokens, lookups.LookupAll)]
    if addr_tokens:
        lookup.append(FieldLookup('nameaddress_vector', addr_tokens, lookups.Restrict))

    return lookup


def lookup_by_any_name(name_tokens: List[int], addr_restrict_tokens: List[int],
                       addr_lookup_tokens: List[int]) -> List[FieldLookup]:
    """ Create a lookup list where name tokens are looked up via index
        and only one of the name tokens must be present.
        Potential address tokens are used to restrict the search further:
        `addr_restrict_tokens` only filter without an index, while
        `addr_lookup_tokens` are looked up via index and must all match.
        Empty address token lists add no lookup.
    """
    lookup = [FieldLookup('name_vector', name_tokens, lookups.LookupAny)]
    if addr_restrict_tokens:
        lookup.append(FieldLookup('nameaddress_vector', addr_restrict_tokens, lookups.Restrict))
    if addr_lookup_tokens:
        lookup.append(FieldLookup('nameaddress_vector', addr_lookup_tokens, lookups.LookupAll))

    return lookup


def lookup_by_addr(name_tokens: List[int], addr_tokens: List[int]) -> List[FieldLookup]:
    """ Build a lookup list that searches the address tokens via the index
        (all of them must match), while the name tokens only narrow down
        the result without using an index.
    """
    name_filter = FieldLookup('name_vector', name_tokens, lookups.Restrict)
    address_search = FieldLookup('nameaddress_vector', addr_tokens, lookups.LookupAll)
    return [name_filter, address_search]