]> git.openstreetmap.org Git - nominatim.git/blob - nominatim/api/search/db_search_fields.py
adapt typing for newer version of mypy
[nominatim.git] / nominatim / api / search / db_search_fields.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2023 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Data structures for more complex fields in abstract search descriptions.
9 """
10 from typing import List, Tuple, Iterator, cast
11 import dataclasses
12
13 import sqlalchemy as sa
14 from sqlalchemy.dialects.postgresql import ARRAY
15
16 from nominatim.typing import SaFromClause, SaColumn, SaExpression
17 from nominatim.api.search.query import Token
18
19 @dataclasses.dataclass
20 class WeightedStrings:
21     """ A list of strings together with a penalty.
22     """
23     values: List[str]
24     penalties: List[float]
25
26     def __bool__(self) -> bool:
27         return bool(self.values)
28
29
30     def __iter__(self) -> Iterator[Tuple[str, float]]:
31         return iter(zip(self.values, self.penalties))
32
33
34     def get_penalty(self, value: str, default: float = 1000.0) -> float:
35         """ Get the penalty for the given value. Returns the given default
36             if the value does not exist.
37         """
38         try:
39             return self.penalties[self.values.index(value)]
40         except ValueError:
41             pass
42         return default
43
44
45 @dataclasses.dataclass
46 class WeightedCategories:
47     """ A list of class/type tuples together with a penalty.
48     """
49     values: List[Tuple[str, str]]
50     penalties: List[float]
51
52     def __bool__(self) -> bool:
53         return bool(self.values)
54
55
56     def __iter__(self) -> Iterator[Tuple[Tuple[str, str], float]]:
57         return iter(zip(self.values, self.penalties))
58
59
60     def get_penalty(self, value: Tuple[str, str], default: float = 1000.0) -> float:
61         """ Get the penalty for the given value. Returns the given default
62             if the value does not exist.
63         """
64         try:
65             return self.penalties[self.values.index(value)]
66         except ValueError:
67             pass
68         return default
69
70
71     def sql_restrict(self, table: SaFromClause) -> SaExpression:
72         """ Return an SQLAlcheny expression that restricts the
73             class and type columns of the given table to the values
74             in the list.
75             Must not be used with an empty list.
76         """
77         assert self.values
78         if len(self.values) == 1:
79             return sa.and_(table.c.class_ == self.values[0][0],
80                            table.c.type == self.values[0][1])
81
82         return sa.or_(*(sa.and_(table.c.class_ == c, table.c.type == t)
83                         for c, t in self.values))
84
85
86 @dataclasses.dataclass(order=True)
87 class RankedTokens:
88     """ List of tokens together with the penalty of using it.
89     """
90     penalty: float
91     tokens: List[int]
92
93     def with_token(self, t: Token, transition_penalty: float) -> 'RankedTokens':
94         """ Create a new RankedTokens list with the given token appended.
95             The tokens penalty as well as the given transision penalty
96             are added to the overall penalty.
97         """
98         return RankedTokens(self.penalty + t.penalty + transition_penalty,
99                             self.tokens + [t.token])
100
101
102 @dataclasses.dataclass
103 class FieldRanking:
104     """ A list of rankings to be applied sequentially until one matches.
105         The matched ranking determines the penalty. If none matches a
106         default penalty is applied.
107     """
108     column: str
109     default: float
110     rankings: List[RankedTokens]
111
112     def normalize_penalty(self) -> float:
113         """ Reduce the default and ranking penalties, such that the minimum
114             penalty is 0. Return the penalty that was subtracted.
115         """
116         if self.rankings:
117             min_penalty = min(self.default, min(r.penalty for r in self.rankings))
118         else:
119             min_penalty = self.default
120         if min_penalty > 0.0:
121             self.default -= min_penalty
122             for ranking in self.rankings:
123                 ranking.penalty -= min_penalty
124         return min_penalty
125
126
127     def sql_penalty(self, table: SaFromClause) -> SaColumn:
128         """ Create an SQL expression for the rankings.
129         """
130         assert self.rankings
131
132         return sa.func.weigh_search(table.c[self.column],
133                                     [f"{{{','.join((str(s) for s in r.tokens))}}}"
134                                      for r in self.rankings],
135                                     [r.penalty for r in self.rankings],
136                                     self.default)
137
138
139 @dataclasses.dataclass
140 class FieldLookup:
141     """ A list of tokens to be searched for. The column names the database
142         column to search in and the lookup_type the operator that is applied.
143         'lookup_all' requires all tokens to match. 'lookup_any' requires
144         one of the tokens to match. 'restrict' requires to match all tokens
145         but avoids the use of indexes.
146     """
147     column: str
148     tokens: List[int]
149     lookup_type: str
150
151     def sql_condition(self, table: SaFromClause) -> SaColumn:
152         """ Create an SQL expression for the given match condition.
153         """
154         col = table.c[self.column]
155         if self.lookup_type == 'lookup_all':
156             return col.contains(self.tokens)
157         if self.lookup_type == 'lookup_any':
158             return cast(SaColumn, col.overlap(self.tokens))
159
160         return sa.func.array_cat(col, sa.text('ARRAY[]::integer[]'),
161                                  type_=ARRAY(sa.Integer())).contains(self.tokens)
162
163
164 class SearchData:
165     """ Search fields derived from query and token assignment
166         to be used with the SQL queries.
167     """
168     penalty: float
169
170     lookups: List[FieldLookup] = []
171     rankings: List[FieldRanking]
172
173     housenumbers: WeightedStrings = WeightedStrings([], [])
174     postcodes: WeightedStrings = WeightedStrings([], [])
175     countries: WeightedStrings = WeightedStrings([], [])
176
177     qualifiers: WeightedCategories = WeightedCategories([], [])
178
179
180     def set_strings(self, field: str, tokens: List[Token]) -> None:
181         """ Set on of the WeightedStrings properties from the given
182             token list. Adapt the global penalty, so that the
183             minimum penalty is 0.
184         """
185         if tokens:
186             min_penalty = min(t.penalty for t in tokens)
187             self.penalty += min_penalty
188             wstrs = WeightedStrings([t.lookup_word for t in tokens],
189                                     [t.penalty - min_penalty for t in tokens])
190
191             setattr(self, field, wstrs)
192
193
194     def set_qualifiers(self, tokens: List[Token]) -> None:
195         """ Set the qulaifier field from the given tokens.
196         """
197         if tokens:
198             min_penalty = min(t.penalty for t in tokens)
199             self.penalty += min_penalty
200             self.qualifiers = WeightedCategories([t.get_category() for t in tokens],
201                                                  [t.penalty - min_penalty for t in tokens])
202
203
204     def set_ranking(self, rankings: List[FieldRanking]) -> None:
205         """ Set the list of rankings and normalize the ranking.
206         """
207         self.rankings = []
208         for ranking in rankings:
209             if ranking.rankings:
210                 self.penalty += ranking.normalize_penalty()
211                 self.rankings.append(ranking)
212             else:
213                 self.penalty += ranking.default
214
215
216 def lookup_by_names(name_tokens: List[int], addr_tokens: List[int]) -> List[FieldLookup]:
217     """ Create a lookup list where name tokens are looked up via index
218         and potential address tokens are used to restrict the search further.
219     """
220     lookup = [FieldLookup('name_vector', name_tokens, 'lookup_all')]
221     if addr_tokens:
222         lookup.append(FieldLookup('nameaddress_vector', addr_tokens, 'restrict'))
223
224     return lookup
225
226
227 def lookup_by_any_name(name_tokens: List[int], addr_tokens: List[int],
228                        lookup_type: str) -> List[FieldLookup]:
229     """ Create a lookup list where name tokens are looked up via index
230         and only one of the name tokens must be present.
231         Potential address tokens are used to restrict the search further.
232     """
233     lookup = [FieldLookup('name_vector', name_tokens, 'lookup_any')]
234     if addr_tokens:
235         lookup.append(FieldLookup('nameaddress_vector', addr_tokens, lookup_type))
236
237     return lookup
238
239
240 def lookup_by_addr(name_tokens: List[int], addr_tokens: List[int]) -> List[FieldLookup]:
241     """ Create a lookup list where address tokens are looked up via index
242         and the name tokens are only used to restrict the search further.
243     """
244     return [FieldLookup('name_vector', name_tokens, 'restrict'),
245             FieldLookup('nameaddress_vector', addr_tokens, 'lookup_all')]