]> git.openstreetmap.org Git - nominatim.git/blob - nominatim/api/search/db_search_fields.py
hide type differences between Postgres and Sqlite in custom types
[nominatim.git] / nominatim / api / search / db_search_fields.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2023 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Data structures for more complex fields in abstract search descriptions.
9 """
10 from typing import List, Tuple, Iterator, cast, Dict
11 import dataclasses
12
13 import sqlalchemy as sa
14
15 from nominatim.typing import SaFromClause, SaColumn, SaExpression
16 from nominatim.api.search.query import Token
17
18 @dataclasses.dataclass
19 class WeightedStrings:
20     """ A list of strings together with a penalty.
21     """
22     values: List[str]
23     penalties: List[float]
24
25     def __bool__(self) -> bool:
26         return bool(self.values)
27
28
29     def __iter__(self) -> Iterator[Tuple[str, float]]:
30         return iter(zip(self.values, self.penalties))
31
32
33     def get_penalty(self, value: str, default: float = 1000.0) -> float:
34         """ Get the penalty for the given value. Returns the given default
35             if the value does not exist.
36         """
37         try:
38             return self.penalties[self.values.index(value)]
39         except ValueError:
40             pass
41         return default
42
43
44 @dataclasses.dataclass
45 class WeightedCategories:
46     """ A list of class/type tuples together with a penalty.
47     """
48     values: List[Tuple[str, str]]
49     penalties: List[float]
50
51     def __bool__(self) -> bool:
52         return bool(self.values)
53
54
55     def __iter__(self) -> Iterator[Tuple[Tuple[str, str], float]]:
56         return iter(zip(self.values, self.penalties))
57
58
59     def get_penalty(self, value: Tuple[str, str], default: float = 1000.0) -> float:
60         """ Get the penalty for the given value. Returns the given default
61             if the value does not exist.
62         """
63         try:
64             return self.penalties[self.values.index(value)]
65         except ValueError:
66             pass
67         return default
68
69
70     def sql_restrict(self, table: SaFromClause) -> SaExpression:
71         """ Return an SQLAlcheny expression that restricts the
72             class and type columns of the given table to the values
73             in the list.
74             Must not be used with an empty list.
75         """
76         assert self.values
77         if len(self.values) == 1:
78             return sa.and_(table.c.class_ == self.values[0][0],
79                            table.c.type == self.values[0][1])
80
81         return sa.or_(*(sa.and_(table.c.class_ == c, table.c.type == t)
82                         for c, t in self.values))
83
84
85 @dataclasses.dataclass(order=True)
86 class RankedTokens:
87     """ List of tokens together with the penalty of using it.
88     """
89     penalty: float
90     tokens: List[int]
91
92     def with_token(self, t: Token, transition_penalty: float) -> 'RankedTokens':
93         """ Create a new RankedTokens list with the given token appended.
94             The tokens penalty as well as the given transision penalty
95             are added to the overall penalty.
96         """
97         return RankedTokens(self.penalty + t.penalty + transition_penalty,
98                             self.tokens + [t.token])
99
100
101 @dataclasses.dataclass
102 class FieldRanking:
103     """ A list of rankings to be applied sequentially until one matches.
104         The matched ranking determines the penalty. If none matches a
105         default penalty is applied.
106     """
107     column: str
108     default: float
109     rankings: List[RankedTokens]
110
111     def normalize_penalty(self) -> float:
112         """ Reduce the default and ranking penalties, such that the minimum
113             penalty is 0. Return the penalty that was subtracted.
114         """
115         if self.rankings:
116             min_penalty = min(self.default, min(r.penalty for r in self.rankings))
117         else:
118             min_penalty = self.default
119         if min_penalty > 0.0:
120             self.default -= min_penalty
121             for ranking in self.rankings:
122                 ranking.penalty -= min_penalty
123         return min_penalty
124
125
126     def sql_penalty(self, table: SaFromClause) -> SaColumn:
127         """ Create an SQL expression for the rankings.
128         """
129         assert self.rankings
130
131         return sa.func.weigh_search(table.c[self.column],
132                                     [f"{{{','.join((str(s) for s in r.tokens))}}}"
133                                      for r in self.rankings],
134                                     [r.penalty for r in self.rankings],
135                                     self.default)
136
137
138 @dataclasses.dataclass
139 class FieldLookup:
140     """ A list of tokens to be searched for. The column names the database
141         column to search in and the lookup_type the operator that is applied.
142         'lookup_all' requires all tokens to match. 'lookup_any' requires
143         one of the tokens to match. 'restrict' requires to match all tokens
144         but avoids the use of indexes.
145     """
146     column: str
147     tokens: List[int]
148     lookup_type: str
149
150     def sql_condition(self, table: SaFromClause) -> SaColumn:
151         """ Create an SQL expression for the given match condition.
152         """
153         col = table.c[self.column]
154         if self.lookup_type == 'lookup_all':
155             return col.contains(self.tokens)
156         if self.lookup_type == 'lookup_any':
157             return cast(SaColumn, col.overlaps(self.tokens))
158
159         return sa.func.coalesce(sa.null(), col).contains(self.tokens) # pylint: disable=not-callable
160
161
162 class SearchData:
163     """ Search fields derived from query and token assignment
164         to be used with the SQL queries.
165     """
166     penalty: float
167
168     lookups: List[FieldLookup] = []
169     rankings: List[FieldRanking]
170
171     housenumbers: WeightedStrings = WeightedStrings([], [])
172     postcodes: WeightedStrings = WeightedStrings([], [])
173     countries: WeightedStrings = WeightedStrings([], [])
174
175     qualifiers: WeightedCategories = WeightedCategories([], [])
176
177
178     def set_strings(self, field: str, tokens: List[Token]) -> None:
179         """ Set on of the WeightedStrings properties from the given
180             token list. Adapt the global penalty, so that the
181             minimum penalty is 0.
182         """
183         if tokens:
184             min_penalty = min(t.penalty for t in tokens)
185             self.penalty += min_penalty
186             wstrs = WeightedStrings([t.lookup_word for t in tokens],
187                                     [t.penalty - min_penalty for t in tokens])
188
189             setattr(self, field, wstrs)
190
191
192     def set_qualifiers(self, tokens: List[Token]) -> None:
193         """ Set the qulaifier field from the given tokens.
194         """
195         if tokens:
196             categories: Dict[Tuple[str, str], float] = {}
197             min_penalty = 1000.0
198             for t in tokens:
199                 if t.penalty < min_penalty:
200                     min_penalty = t.penalty
201                 cat = t.get_category()
202                 if t.penalty < categories.get(cat, 1000.0):
203                     categories[cat] = t.penalty
204             self.penalty += min_penalty
205             self.qualifiers = WeightedCategories(list(categories.keys()),
206                                                  list(categories.values()))
207
208
209     def set_ranking(self, rankings: List[FieldRanking]) -> None:
210         """ Set the list of rankings and normalize the ranking.
211         """
212         self.rankings = []
213         for ranking in rankings:
214             if ranking.rankings:
215                 self.penalty += ranking.normalize_penalty()
216                 self.rankings.append(ranking)
217             else:
218                 self.penalty += ranking.default
219
220
221 def lookup_by_names(name_tokens: List[int], addr_tokens: List[int]) -> List[FieldLookup]:
222     """ Create a lookup list where name tokens are looked up via index
223         and potential address tokens are used to restrict the search further.
224     """
225     lookup = [FieldLookup('name_vector', name_tokens, 'lookup_all')]
226     if addr_tokens:
227         lookup.append(FieldLookup('nameaddress_vector', addr_tokens, 'restrict'))
228
229     return lookup
230
231
232 def lookup_by_any_name(name_tokens: List[int], addr_tokens: List[int],
233                        lookup_type: str) -> List[FieldLookup]:
234     """ Create a lookup list where name tokens are looked up via index
235         and only one of the name tokens must be present.
236         Potential address tokens are used to restrict the search further.
237     """
238     lookup = [FieldLookup('name_vector', name_tokens, 'lookup_any')]
239     if addr_tokens:
240         lookup.append(FieldLookup('nameaddress_vector', addr_tokens, lookup_type))
241
242     return lookup
243
244
245 def lookup_by_addr(name_tokens: List[int], addr_tokens: List[int]) -> List[FieldLookup]:
246     """ Create a lookup list where address tokens are looked up via index
247         and the name tokens are only used to restrict the search further.
248     """
249     return [FieldLookup('name_vector', name_tokens, 'restrict'),
250             FieldLookup('nameaddress_vector', addr_tokens, 'lookup_all')]