]> git.openstreetmap.org Git - nominatim.git/blob - nominatim/api/search/db_search_fields.py
factor out SQL for filtering by location
[nominatim.git] / nominatim / api / search / db_search_fields.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2023 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Data structures for more complex fields in abstract search descriptions.
9 """
10 from typing import List, Tuple, Iterator, Dict, Type
11 import dataclasses
12
13 import sqlalchemy as sa
14
15 from nominatim.typing import SaFromClause, SaColumn, SaExpression
16 from nominatim.api.search.query import Token
17 import nominatim.api.search.db_search_lookups as lookups
18 from nominatim.utils.json_writer import JsonWriter
19
20
21 @dataclasses.dataclass
22 class WeightedStrings:
23     """ A list of strings together with a penalty.
24     """
25     values: List[str]
26     penalties: List[float]
27
28     def __bool__(self) -> bool:
29         return bool(self.values)
30
31
32     def __iter__(self) -> Iterator[Tuple[str, float]]:
33         return iter(zip(self.values, self.penalties))
34
35
36     def get_penalty(self, value: str, default: float = 1000.0) -> float:
37         """ Get the penalty for the given value. Returns the given default
38             if the value does not exist.
39         """
40         try:
41             return self.penalties[self.values.index(value)]
42         except ValueError:
43             pass
44         return default
45
46
47 @dataclasses.dataclass
48 class WeightedCategories:
49     """ A list of class/type tuples together with a penalty.
50     """
51     values: List[Tuple[str, str]]
52     penalties: List[float]
53
54     def __bool__(self) -> bool:
55         return bool(self.values)
56
57
58     def __iter__(self) -> Iterator[Tuple[Tuple[str, str], float]]:
59         return iter(zip(self.values, self.penalties))
60
61
62     def get_penalty(self, value: Tuple[str, str], default: float = 1000.0) -> float:
63         """ Get the penalty for the given value. Returns the given default
64             if the value does not exist.
65         """
66         try:
67             return self.penalties[self.values.index(value)]
68         except ValueError:
69             pass
70         return default
71
72
73     def sql_restrict(self, table: SaFromClause) -> SaExpression:
74         """ Return an SQLAlcheny expression that restricts the
75             class and type columns of the given table to the values
76             in the list.
77             Must not be used with an empty list.
78         """
79         assert self.values
80         if len(self.values) == 1:
81             return sa.and_(table.c.class_ == self.values[0][0],
82                            table.c.type == self.values[0][1])
83
84         return sa.or_(*(sa.and_(table.c.class_ == c, table.c.type == t)
85                         for c, t in self.values))
86
87
88 @dataclasses.dataclass(order=True)
89 class RankedTokens:
90     """ List of tokens together with the penalty of using it.
91     """
92     penalty: float
93     tokens: List[int]
94
95     def with_token(self, t: Token, transition_penalty: float) -> 'RankedTokens':
96         """ Create a new RankedTokens list with the given token appended.
97             The tokens penalty as well as the given transision penalty
98             are added to the overall penalty.
99         """
100         return RankedTokens(self.penalty + t.penalty + transition_penalty,
101                             self.tokens + [t.token])
102
103
104 @dataclasses.dataclass
105 class FieldRanking:
106     """ A list of rankings to be applied sequentially until one matches.
107         The matched ranking determines the penalty. If none matches a
108         default penalty is applied.
109     """
110     column: str
111     default: float
112     rankings: List[RankedTokens]
113
114     def normalize_penalty(self) -> float:
115         """ Reduce the default and ranking penalties, such that the minimum
116             penalty is 0. Return the penalty that was subtracted.
117         """
118         if self.rankings:
119             min_penalty = min(self.default, min(r.penalty for r in self.rankings))
120         else:
121             min_penalty = self.default
122         if min_penalty > 0.0:
123             self.default -= min_penalty
124             for ranking in self.rankings:
125                 ranking.penalty -= min_penalty
126         return min_penalty
127
128
129     def sql_penalty(self, table: SaFromClause) -> SaColumn:
130         """ Create an SQL expression for the rankings.
131         """
132         assert self.rankings
133
134         rout = JsonWriter().start_array()
135         for rank in self.rankings:
136             rout.start_array().value(rank.penalty).next()
137             rout.start_array()
138             for token in rank.tokens:
139                 rout.value(token).next()
140             rout.end_array()
141             rout.end_array().next()
142         rout.end_array()
143
144         return sa.func.weigh_search(table.c[self.column], rout(), self.default)
145
146
147 @dataclasses.dataclass
148 class FieldLookup:
149     """ A list of tokens to be searched for. The column names the database
150         column to search in and the lookup_type the operator that is applied.
151         'lookup_all' requires all tokens to match. 'lookup_any' requires
152         one of the tokens to match. 'restrict' requires to match all tokens
153         but avoids the use of indexes.
154     """
155     column: str
156     tokens: List[int]
157     lookup_type: Type[lookups.LookupType]
158
159     def sql_condition(self, table: SaFromClause) -> SaColumn:
160         """ Create an SQL expression for the given match condition.
161         """
162         return self.lookup_type(table, self.column, self.tokens)
163
164
165 class SearchData:
166     """ Search fields derived from query and token assignment
167         to be used with the SQL queries.
168     """
169     penalty: float
170
171     lookups: List[FieldLookup] = []
172     rankings: List[FieldRanking]
173
174     housenumbers: WeightedStrings = WeightedStrings([], [])
175     postcodes: WeightedStrings = WeightedStrings([], [])
176     countries: WeightedStrings = WeightedStrings([], [])
177
178     qualifiers: WeightedCategories = WeightedCategories([], [])
179
180
181     def set_strings(self, field: str, tokens: List[Token]) -> None:
182         """ Set on of the WeightedStrings properties from the given
183             token list. Adapt the global penalty, so that the
184             minimum penalty is 0.
185         """
186         if tokens:
187             min_penalty = min(t.penalty for t in tokens)
188             self.penalty += min_penalty
189             wstrs = WeightedStrings([t.lookup_word for t in tokens],
190                                     [t.penalty - min_penalty for t in tokens])
191
192             setattr(self, field, wstrs)
193
194
195     def set_qualifiers(self, tokens: List[Token]) -> None:
196         """ Set the qulaifier field from the given tokens.
197         """
198         if tokens:
199             categories: Dict[Tuple[str, str], float] = {}
200             min_penalty = 1000.0
201             for t in tokens:
202                 if t.penalty < min_penalty:
203                     min_penalty = t.penalty
204                 cat = t.get_category()
205                 if t.penalty < categories.get(cat, 1000.0):
206                     categories[cat] = t.penalty
207             self.penalty += min_penalty
208             self.qualifiers = WeightedCategories(list(categories.keys()),
209                                                  list(categories.values()))
210
211
212     def set_ranking(self, rankings: List[FieldRanking]) -> None:
213         """ Set the list of rankings and normalize the ranking.
214         """
215         self.rankings = []
216         for ranking in rankings:
217             if ranking.rankings:
218                 self.penalty += ranking.normalize_penalty()
219                 self.rankings.append(ranking)
220             else:
221                 self.penalty += ranking.default
222
223
224 def lookup_by_names(name_tokens: List[int], addr_tokens: List[int]) -> List[FieldLookup]:
225     """ Create a lookup list where name tokens are looked up via index
226         and potential address tokens are used to restrict the search further.
227     """
228     lookup = [FieldLookup('name_vector', name_tokens, lookups.LookupAll)]
229     if addr_tokens:
230         lookup.append(FieldLookup('nameaddress_vector', addr_tokens, lookups.Restrict))
231
232     return lookup
233
234
235 def lookup_by_any_name(name_tokens: List[int], addr_tokens: List[int],
236                        use_index_for_addr: bool) -> List[FieldLookup]:
237     """ Create a lookup list where name tokens are looked up via index
238         and only one of the name tokens must be present.
239         Potential address tokens are used to restrict the search further.
240     """
241     lookup = [FieldLookup('name_vector', name_tokens, lookups.LookupAny)]
242     if addr_tokens:
243         lookup.append(FieldLookup('nameaddress_vector', addr_tokens,
244                                   lookups.LookupAll if use_index_for_addr else lookups.Restrict))
245
246     return lookup
247
248
249 def lookup_by_addr(name_tokens: List[int], addr_tokens: List[int]) -> List[FieldLookup]:
250     """ Create a lookup list where address tokens are looked up via index
251         and the name tokens are only used to restrict the search further.
252     """
253     return [FieldLookup('name_vector', name_tokens, lookups.Restrict),
254             FieldLookup('nameaddress_vector', addr_tokens, lookups.LookupAll)]