]> git.openstreetmap.org Git - nominatim.git/blob - src/nominatim_api/search/db_search_fields.py
fix style issue found by flake8
[nominatim.git] / src / nominatim_api / search / db_search_fields.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2024 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Data structures for more complex fields in abstract search descriptions.
9 """
10 from typing import List, Tuple, Iterator, Dict, Type
11 import dataclasses
12
13 import sqlalchemy as sa
14
15 from ..typing import SaFromClause, SaColumn, SaExpression
16 from ..utils.json_writer import JsonWriter
17 from .query import Token
18 from . import db_search_lookups as lookups
19
20
21 @dataclasses.dataclass
22 class WeightedStrings:
23     """ A list of strings together with a penalty.
24     """
25     values: List[str]
26     penalties: List[float]
27
28     def __bool__(self) -> bool:
29         return bool(self.values)
30
31     def __iter__(self) -> Iterator[Tuple[str, float]]:
32         return iter(zip(self.values, self.penalties))
33
34     def get_penalty(self, value: str, default: float = 1000.0) -> float:
35         """ Get the penalty for the given value. Returns the given default
36             if the value does not exist.
37         """
38         try:
39             return self.penalties[self.values.index(value)]
40         except ValueError:
41             pass
42         return default
43
44
45 @dataclasses.dataclass
46 class WeightedCategories:
47     """ A list of class/type tuples together with a penalty.
48     """
49     values: List[Tuple[str, str]]
50     penalties: List[float]
51
52     def __bool__(self) -> bool:
53         return bool(self.values)
54
55     def __iter__(self) -> Iterator[Tuple[Tuple[str, str], float]]:
56         return iter(zip(self.values, self.penalties))
57
58     def get_penalty(self, value: Tuple[str, str], default: float = 1000.0) -> float:
59         """ Get the penalty for the given value. Returns the given default
60             if the value does not exist.
61         """
62         try:
63             return self.penalties[self.values.index(value)]
64         except ValueError:
65             pass
66         return default
67
68     def sql_restrict(self, table: SaFromClause) -> SaExpression:
69         """ Return an SQLAlcheny expression that restricts the
70             class and type columns of the given table to the values
71             in the list.
72             Must not be used with an empty list.
73         """
74         assert self.values
75         if len(self.values) == 1:
76             return sa.and_(table.c.class_ == self.values[0][0],
77                            table.c.type == self.values[0][1])
78
79         return sa.or_(*(sa.and_(table.c.class_ == c, table.c.type == t)
80                         for c, t in self.values))
81
82
83 @dataclasses.dataclass(order=True)
84 class RankedTokens:
85     """ List of tokens together with the penalty of using it.
86     """
87     penalty: float
88     tokens: List[int]
89
90     def with_token(self, t: Token, transition_penalty: float) -> 'RankedTokens':
91         """ Create a new RankedTokens list with the given token appended.
92             The tokens penalty as well as the given transition penalty
93             are added to the overall penalty.
94         """
95         return RankedTokens(self.penalty + t.penalty + transition_penalty,
96                             self.tokens + [t.token])
97
98
99 @dataclasses.dataclass
100 class FieldRanking:
101     """ A list of rankings to be applied sequentially until one matches.
102         The matched ranking determines the penalty. If none matches a
103         default penalty is applied.
104     """
105     column: str
106     default: float
107     rankings: List[RankedTokens]
108
109     def normalize_penalty(self) -> float:
110         """ Reduce the default and ranking penalties, such that the minimum
111             penalty is 0. Return the penalty that was subtracted.
112         """
113         if self.rankings:
114             min_penalty = min(self.default, min(r.penalty for r in self.rankings))
115         else:
116             min_penalty = self.default
117         if min_penalty > 0.0:
118             self.default -= min_penalty
119             for ranking in self.rankings:
120                 ranking.penalty -= min_penalty
121         return min_penalty
122
123     def sql_penalty(self, table: SaFromClause) -> SaColumn:
124         """ Create an SQL expression for the rankings.
125         """
126         assert self.rankings
127
128         rout = JsonWriter().start_array()
129         for rank in self.rankings:
130             rout.start_array().value(rank.penalty).next()
131             rout.start_array()
132             for token in rank.tokens:
133                 rout.value(token).next()
134             rout.end_array()
135             rout.end_array().next()
136         rout.end_array()
137
138         return sa.func.weigh_search(table.c[self.column], rout(), self.default)
139
140
141 @dataclasses.dataclass
142 class FieldLookup:
143     """ A list of tokens to be searched for. The column names the database
144         column to search in and the lookup_type the operator that is applied.
145         'lookup_all' requires all tokens to match. 'lookup_any' requires
146         one of the tokens to match. 'restrict' requires to match all tokens
147         but avoids the use of indexes.
148     """
149     column: str
150     tokens: List[int]
151     lookup_type: Type[lookups.LookupType]
152
153     def sql_condition(self, table: SaFromClause) -> SaColumn:
154         """ Create an SQL expression for the given match condition.
155         """
156         return self.lookup_type(table, self.column, self.tokens)
157
158
159 class SearchData:
160     """ Search fields derived from query and token assignment
161         to be used with the SQL queries.
162     """
163     penalty: float
164
165     lookups: List[FieldLookup] = []
166     rankings: List[FieldRanking]
167
168     housenumbers: WeightedStrings = WeightedStrings([], [])
169     postcodes: WeightedStrings = WeightedStrings([], [])
170     countries: WeightedStrings = WeightedStrings([], [])
171
172     qualifiers: WeightedCategories = WeightedCategories([], [])
173
174     def set_strings(self, field: str, tokens: List[Token]) -> None:
175         """ Set on of the WeightedStrings properties from the given
176             token list. Adapt the global penalty, so that the
177             minimum penalty is 0.
178         """
179         if tokens:
180             min_penalty = min(t.penalty for t in tokens)
181             self.penalty += min_penalty
182             wstrs = WeightedStrings([t.lookup_word for t in tokens],
183                                     [t.penalty - min_penalty for t in tokens])
184
185             setattr(self, field, wstrs)
186
187     def set_qualifiers(self, tokens: List[Token]) -> None:
188         """ Set the qulaifier field from the given tokens.
189         """
190         if tokens:
191             categories: Dict[Tuple[str, str], float] = {}
192             min_penalty = 1000.0
193             for t in tokens:
194                 min_penalty = min(min_penalty, t.penalty)
195                 cat = t.get_category()
196                 if t.penalty < categories.get(cat, 1000.0):
197                     categories[cat] = t.penalty
198             self.penalty += min_penalty
199             self.qualifiers = WeightedCategories(list(categories.keys()),
200                                                  list(categories.values()))
201
202     def set_ranking(self, rankings: List[FieldRanking]) -> None:
203         """ Set the list of rankings and normalize the ranking.
204         """
205         self.rankings = []
206         for ranking in rankings:
207             if ranking.rankings:
208                 self.penalty += ranking.normalize_penalty()
209                 self.rankings.append(ranking)
210             else:
211                 self.penalty += ranking.default
212
213
214 def lookup_by_names(name_tokens: List[int], addr_tokens: List[int]) -> List[FieldLookup]:
215     """ Create a lookup list where name tokens are looked up via index
216         and potential address tokens are used to restrict the search further.
217     """
218     lookup = [FieldLookup('name_vector', name_tokens, lookups.LookupAll)]
219     if addr_tokens:
220         lookup.append(FieldLookup('nameaddress_vector', addr_tokens, lookups.Restrict))
221
222     return lookup
223
224
225 def lookup_by_any_name(name_tokens: List[int], addr_restrict_tokens: List[int],
226                        addr_lookup_tokens: List[int]) -> List[FieldLookup]:
227     """ Create a lookup list where name tokens are looked up via index
228         and only one of the name tokens must be present.
229         Potential address tokens are used to restrict the search further.
230     """
231     lookup = [FieldLookup('name_vector', name_tokens, lookups.LookupAny)]
232     if addr_restrict_tokens:
233         lookup.append(FieldLookup('nameaddress_vector', addr_restrict_tokens, lookups.Restrict))
234     if addr_lookup_tokens:
235         lookup.append(FieldLookup('nameaddress_vector', addr_lookup_tokens, lookups.LookupAll))
236
237     return lookup
238
239
240 def lookup_by_addr(name_tokens: List[int], addr_tokens: List[int]) -> List[FieldLookup]:
241     """ Create a lookup list where address tokens are looked up via index
242         and the name tokens are only used to restrict the search further.
243     """
244     return [FieldLookup('name_vector', name_tokens, lookups.Restrict),
245             FieldLookup('nameaddress_vector', addr_tokens, lookups.LookupAll)]