]> git.openstreetmap.org Git - nominatim.git/blob - src/nominatim_api/search/db_search_fields.py
Revert "work round typing bug in pyosmium 4.0"
[nominatim.git] / src / nominatim_api / search / db_search_fields.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2024 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Data structures for more complex fields in abstract search descriptions.
9 """
10 from typing import List, Tuple, Iterator, Dict, Type
11 import dataclasses
12
13 import sqlalchemy as sa
14
15 from ..typing import SaFromClause, SaColumn, SaExpression
16 from ..utils.json_writer import JsonWriter
17 from .query import Token
18 from . import db_search_lookups as lookups
19
20
21 @dataclasses.dataclass
22 class WeightedStrings:
23     """ A list of strings together with a penalty.
24     """
25     values: List[str]
26     penalties: List[float]
27
28     def __bool__(self) -> bool:
29         return bool(self.values)
30
31
32     def __iter__(self) -> Iterator[Tuple[str, float]]:
33         return iter(zip(self.values, self.penalties))
34
35
36     def get_penalty(self, value: str, default: float = 1000.0) -> float:
37         """ Get the penalty for the given value. Returns the given default
38             if the value does not exist.
39         """
40         try:
41             return self.penalties[self.values.index(value)]
42         except ValueError:
43             pass
44         return default
45
46
47 @dataclasses.dataclass
48 class WeightedCategories:
49     """ A list of class/type tuples together with a penalty.
50     """
51     values: List[Tuple[str, str]]
52     penalties: List[float]
53
54     def __bool__(self) -> bool:
55         return bool(self.values)
56
57
58     def __iter__(self) -> Iterator[Tuple[Tuple[str, str], float]]:
59         return iter(zip(self.values, self.penalties))
60
61
62     def get_penalty(self, value: Tuple[str, str], default: float = 1000.0) -> float:
63         """ Get the penalty for the given value. Returns the given default
64             if the value does not exist.
65         """
66         try:
67             return self.penalties[self.values.index(value)]
68         except ValueError:
69             pass
70         return default
71
72
73     def sql_restrict(self, table: SaFromClause) -> SaExpression:
74         """ Return an SQLAlcheny expression that restricts the
75             class and type columns of the given table to the values
76             in the list.
77             Must not be used with an empty list.
78         """
79         assert self.values
80         if len(self.values) == 1:
81             return sa.and_(table.c.class_ == self.values[0][0],
82                            table.c.type == self.values[0][1])
83
84         return sa.or_(*(sa.and_(table.c.class_ == c, table.c.type == t)
85                         for c, t in self.values))
86
87
88 @dataclasses.dataclass(order=True)
89 class RankedTokens:
90     """ List of tokens together with the penalty of using it.
91     """
92     penalty: float
93     tokens: List[int]
94
95     def with_token(self, t: Token, transition_penalty: float) -> 'RankedTokens':
96         """ Create a new RankedTokens list with the given token appended.
97             The tokens penalty as well as the given transition penalty
98             are added to the overall penalty.
99         """
100         return RankedTokens(self.penalty + t.penalty + transition_penalty,
101                             self.tokens + [t.token])
102
103
104 @dataclasses.dataclass
105 class FieldRanking:
106     """ A list of rankings to be applied sequentially until one matches.
107         The matched ranking determines the penalty. If none matches a
108         default penalty is applied.
109     """
110     column: str
111     default: float
112     rankings: List[RankedTokens]
113
114     def normalize_penalty(self) -> float:
115         """ Reduce the default and ranking penalties, such that the minimum
116             penalty is 0. Return the penalty that was subtracted.
117         """
118         if self.rankings:
119             min_penalty = min(self.default, min(r.penalty for r in self.rankings))
120         else:
121             min_penalty = self.default
122         if min_penalty > 0.0:
123             self.default -= min_penalty
124             for ranking in self.rankings:
125                 ranking.penalty -= min_penalty
126         return min_penalty
127
128
129     def sql_penalty(self, table: SaFromClause) -> SaColumn:
130         """ Create an SQL expression for the rankings.
131         """
132         assert self.rankings
133
134         rout = JsonWriter().start_array()
135         for rank in self.rankings:
136             rout.start_array().value(rank.penalty).next()
137             rout.start_array()
138             for token in rank.tokens:
139                 rout.value(token).next()
140             rout.end_array()
141             rout.end_array().next()
142         rout.end_array()
143
144         return sa.func.weigh_search(table.c[self.column], rout(), self.default)
145
146
147 @dataclasses.dataclass
148 class FieldLookup:
149     """ A list of tokens to be searched for. The column names the database
150         column to search in and the lookup_type the operator that is applied.
151         'lookup_all' requires all tokens to match. 'lookup_any' requires
152         one of the tokens to match. 'restrict' requires to match all tokens
153         but avoids the use of indexes.
154     """
155     column: str
156     tokens: List[int]
157     lookup_type: Type[lookups.LookupType]
158
159     def sql_condition(self, table: SaFromClause) -> SaColumn:
160         """ Create an SQL expression for the given match condition.
161         """
162         return self.lookup_type(table, self.column, self.tokens)
163
164
165 class SearchData:
166     """ Search fields derived from query and token assignment
167         to be used with the SQL queries.
168     """
169     penalty: float
170
171     lookups: List[FieldLookup] = []
172     rankings: List[FieldRanking]
173
174     housenumbers: WeightedStrings = WeightedStrings([], [])
175     postcodes: WeightedStrings = WeightedStrings([], [])
176     countries: WeightedStrings = WeightedStrings([], [])
177
178     qualifiers: WeightedCategories = WeightedCategories([], [])
179
180
181     def set_strings(self, field: str, tokens: List[Token]) -> None:
182         """ Set on of the WeightedStrings properties from the given
183             token list. Adapt the global penalty, so that the
184             minimum penalty is 0.
185         """
186         if tokens:
187             min_penalty = min(t.penalty for t in tokens)
188             self.penalty += min_penalty
189             wstrs = WeightedStrings([t.lookup_word for t in tokens],
190                                     [t.penalty - min_penalty for t in tokens])
191
192             setattr(self, field, wstrs)
193
194
195     def set_qualifiers(self, tokens: List[Token]) -> None:
196         """ Set the qulaifier field from the given tokens.
197         """
198         if tokens:
199             categories: Dict[Tuple[str, str], float] = {}
200             min_penalty = 1000.0
201             for t in tokens:
202                 min_penalty = min(min_penalty, t.penalty)
203                 cat = t.get_category()
204                 if t.penalty < categories.get(cat, 1000.0):
205                     categories[cat] = t.penalty
206             self.penalty += min_penalty
207             self.qualifiers = WeightedCategories(list(categories.keys()),
208                                                  list(categories.values()))
209
210
211     def set_ranking(self, rankings: List[FieldRanking]) -> None:
212         """ Set the list of rankings and normalize the ranking.
213         """
214         self.rankings = []
215         for ranking in rankings:
216             if ranking.rankings:
217                 self.penalty += ranking.normalize_penalty()
218                 self.rankings.append(ranking)
219             else:
220                 self.penalty += ranking.default
221
222
223 def lookup_by_names(name_tokens: List[int], addr_tokens: List[int]) -> List[FieldLookup]:
224     """ Create a lookup list where name tokens are looked up via index
225         and potential address tokens are used to restrict the search further.
226     """
227     lookup = [FieldLookup('name_vector', name_tokens, lookups.LookupAll)]
228     if addr_tokens:
229         lookup.append(FieldLookup('nameaddress_vector', addr_tokens, lookups.Restrict))
230
231     return lookup
232
233
234 def lookup_by_any_name(name_tokens: List[int], addr_restrict_tokens: List[int],
235                        addr_lookup_tokens: List[int]) -> List[FieldLookup]:
236     """ Create a lookup list where name tokens are looked up via index
237         and only one of the name tokens must be present.
238         Potential address tokens are used to restrict the search further.
239     """
240     lookup = [FieldLookup('name_vector', name_tokens, lookups.LookupAny)]
241     if addr_restrict_tokens:
242         lookup.append(FieldLookup('nameaddress_vector', addr_restrict_tokens, lookups.Restrict))
243     if addr_lookup_tokens:
244         lookup.append(FieldLookup('nameaddress_vector', addr_lookup_tokens, lookups.LookupAll))
245
246     return lookup
247
248
249 def lookup_by_addr(name_tokens: List[int], addr_tokens: List[int]) -> List[FieldLookup]:
250     """ Create a lookup list where address tokens are looked up via index
251         and the name tokens are only used to restrict the search further.
252     """
253     return [FieldLookup('name_vector', name_tokens, lookups.Restrict),
254             FieldLookup('nameaddress_vector', addr_tokens, lookups.LookupAll)]