# SPDX-License-Identifier: GPL-3.0-or-later
#
# This file is part of Nominatim. (https://nominatim.org)
#
# Copyright (C) 2024 by the Nominatim developer community.
# For a full list of authors see the git log.
"""
Data structures for more complex fields in abstract search descriptions.
"""
import dataclasses
from typing import List, Tuple, Iterator, Dict, Type

import sqlalchemy as sa

from ..typing import SaFromClause, SaColumn, SaExpression
from ..utils.json_writer import JsonWriter
from .query import Token
from . import db_search_lookups as lookups
@dataclasses.dataclass
class WeightedStrings:
    """ A list of strings together with a penalty.

        `values` and `penalties` are parallel lists: the penalty at
        position i belongs to the value at position i.
    """
    values: List[str]
    penalties: List[float]

    def __bool__(self) -> bool:
        """ The object is truthy when it contains at least one value.
        """
        return bool(self.values)

    def __iter__(self) -> Iterator[Tuple[str, float]]:
        """ Iterate over (value, penalty) pairs.
        """
        return iter(zip(self.values, self.penalties))

    def get_penalty(self, value: str, default: float = 1000.0) -> float:
        """ Get the penalty for the given value. Returns the given default
            if the value does not exist.
        """
        try:
            return self.penalties[self.values.index(value)]
        except ValueError:
            # list.index() reports a missing value with ValueError.
            return default
@dataclasses.dataclass
class WeightedCategories:
    """ A list of class/type tuples together with a penalty.

        `values` and `penalties` are parallel lists: the penalty at
        position i belongs to the (class, type) tuple at position i.
    """
    values: List[Tuple[str, str]]
    penalties: List[float]

    def __bool__(self) -> bool:
        """ The object is truthy when it contains at least one category.
        """
        return bool(self.values)

    def __iter__(self) -> Iterator[Tuple[Tuple[str, str], float]]:
        """ Iterate over ((class, type), penalty) pairs.
        """
        return iter(zip(self.values, self.penalties))

    def get_penalty(self, value: Tuple[str, str], default: float = 1000.0) -> float:
        """ Get the penalty for the given value. Returns the given default
            if the value does not exist.
        """
        try:
            return self.penalties[self.values.index(value)]
        except ValueError:
            # list.index() reports a missing value with ValueError.
            return default

    def sql_restrict(self, table: SaFromClause) -> SaExpression:
        """ Return an SQLAlchemy expression that restricts the
            class and type columns of the given table to the values
            in the list.
            Must not be used with an empty list.
        """
        # An OR over zero conditions has no useful meaning, so make the
        # documented precondition explicit.
        assert self.values

        if len(self.values) == 1:
            # A single category does not need the surrounding OR.
            return sa.and_(table.c.class_ == self.values[0][0],
                           table.c.type == self.values[0][1])

        return sa.or_(*(sa.and_(table.c.class_ == c, table.c.type == t)
                        for c, t in self.values))
@dataclasses.dataclass(order=True)
class RankedTokens:
    """ List of tokens together with the penalty of using it.

        The penalty is declared first so that the generated ordering
        compares rankings by penalty before looking at the tokens.
    """
    penalty: float
    tokens: List[int]

    def with_token(self, t: Token, transition_penalty: float) -> 'RankedTokens':
        """ Create a new RankedTokens list with the given token appended.
            The tokens penalty as well as the given transition penalty
            are added to the overall penalty.
        """
        # List concatenation builds a new list, so this object stays untouched.
        return RankedTokens(self.penalty + t.penalty + transition_penalty,
                            self.tokens + [t.token])
@dataclasses.dataclass
class FieldRanking:
    """ A list of rankings to be applied sequentially until one matches.
        The matched ranking determines the penalty. If none matches a
        default penalty is applied.
    """
    column: str
    default: float
    rankings: List[RankedTokens]

    def normalize_penalty(self) -> float:
        """ Reduce the default and ranking penalties, such that the minimum
            penalty is 0. Return the penalty that was subtracted.
        """
        # Including the default in the candidates keeps this working
        # when there are no rankings at all.
        min_penalty = min([self.default, *(r.penalty for r in self.rankings)])

        if min_penalty > 0.0:
            self.default -= min_penalty
            for ranking in self.rankings:
                ranking.penalty -= min_penalty

        return min_penalty

    def sql_penalty(self, table: SaFromClause) -> SaColumn:
        """ Create an SQL expression for the rankings.

            The rankings are serialised as a JSON array with one
            `[penalty, [token, ...]]` entry per ranking and handed to the
            SQL function `weigh_search` together with the default penalty.
            Must not be used with an empty list of rankings.
        """
        assert self.rankings

        rout = JsonWriter().start_array()
        for rank in self.rankings:
            rout.start_array().value(rank.penalty).next()
            rout.start_array()
            for token in rank.tokens:
                rout.value(token).next()
            rout.end_array()
            rout.end_array().next()
        rout.end_array()

        return sa.func.weigh_search(table.c[self.column], rout(), self.default)
@dataclasses.dataclass
class FieldLookup:
    """ A list of tokens to be searched for. The column names the database
        column to search in and the lookup_type the operator that is applied.
        'lookup_all' requires all tokens to match. 'lookup_any' requires
        one of the tokens to match. 'restrict' requires to match all tokens
        but avoids the use of indexes.
    """
    column: str
    tokens: List[int]
    lookup_type: Type[lookups.LookupType]

    def sql_condition(self, table: SaFromClause) -> SaColumn:
        """ Create an SQL expression for the given match condition.
        """
        # The lookup type is called with the table, the column name and
        # the tokens; its result is returned as the condition.
        return self.lookup_type(table, self.column, self.tokens)
class SearchData:
    """ Search fields derived from query and token assignment
        to be used with the SQL queries.
    """
    penalty: float

    # NOTE: the defaults below are class attributes shared by all instances.
    # The setters of this class replace them per instance instead of
    # mutating them in place.
    lookups: List[FieldLookup] = []
    rankings: List[FieldRanking]

    housenumbers: WeightedStrings = WeightedStrings([], [])
    postcodes: WeightedStrings = WeightedStrings([], [])
    countries: WeightedStrings = WeightedStrings([], [])

    qualifiers: WeightedCategories = WeightedCategories([], [])

    def set_strings(self, field: str, tokens: List[Token]) -> None:
        """ Set one of the WeightedStrings properties from the given
            token list. Adapt the global penalty, so that the
            minimum penalty is 0.
        """
        # Nothing to set for an empty token list; min() would fail on it.
        if not tokens:
            return

        min_penalty = min(t.penalty for t in tokens)
        self.penalty += min_penalty
        wstrs = WeightedStrings([t.lookup_word for t in tokens],
                                [t.penalty - min_penalty for t in tokens])

        setattr(self, field, wstrs)

    def set_qualifiers(self, tokens: List[Token]) -> None:
        """ Set the qualifier field from the given tokens.

            Only the lowest penalty is kept for each category. The
            smallest token penalty is added to the global penalty.
        """
        if not tokens:
            return

        categories: Dict[Tuple[str, str], float] = {}
        min_penalty = 1000.0
        for t in tokens:
            min_penalty = min(min_penalty, t.penalty)
            cat = t.get_category()
            if t.penalty < categories.get(cat, 1000.0):
                categories[cat] = t.penalty
        self.penalty += min_penalty
        self.qualifiers = WeightedCategories(list(categories.keys()),
                                             list(categories.values()))

    def set_ranking(self, rankings: List[FieldRanking]) -> None:
        """ Set the list of rankings and normalize the ranking.

            A ranking without any entries only adds its default penalty
            to the global penalty and is not kept.
        """
        # Start with a fresh per-instance list before appending.
        self.rankings = []
        for ranking in rankings:
            if ranking.rankings:
                self.penalty += ranking.normalize_penalty()
                self.rankings.append(ranking)
            else:
                self.penalty += ranking.default
def lookup_by_names(name_tokens: List[int], addr_tokens: List[int]) -> List[FieldLookup]:
    """ Create a lookup list where name tokens are looked up via index
        and potential address tokens are used to restrict the search further.
    """
    lookup = [FieldLookup('name_vector', name_tokens, lookups.LookupAll)]
    # Only add the address restriction when there are address tokens.
    if addr_tokens:
        lookup.append(FieldLookup('nameaddress_vector', addr_tokens, lookups.Restrict))

    return lookup
def lookup_by_any_name(name_tokens: List[int], addr_restrict_tokens: List[int],
                       addr_lookup_tokens: List[int]) -> List[FieldLookup]:
    """ Create a lookup list where name tokens are looked up via index
        and only one of the name tokens must be present.
        Potential address tokens are used to restrict the search further.
    """
    lookup = [FieldLookup('name_vector', name_tokens, lookups.LookupAny)]
    # Each kind of address token gets its own lookup, and only when present.
    if addr_restrict_tokens:
        lookup.append(FieldLookup('nameaddress_vector', addr_restrict_tokens, lookups.Restrict))
    if addr_lookup_tokens:
        lookup.append(FieldLookup('nameaddress_vector', addr_lookup_tokens, lookups.LookupAll))

    return lookup
def lookup_by_addr(name_tokens: List[int], addr_tokens: List[int]) -> List[FieldLookup]:
    """ Build the lookup list for a search driven by the address:
        the address tokens are looked up via index while the name
        tokens only serve to restrict the search further.
    """
    name_restriction = FieldLookup('name_vector', name_tokens, lookups.Restrict)
    address_lookup = FieldLookup('nameaddress_vector', addr_tokens, lookups.LookupAll)
    return [name_restriction, address_lookup]