]> git.openstreetmap.org Git - nominatim.git/blob - nominatim/api/search/db_search_fields.py
Merge pull request #3262 from lonvia/fix-category-search
[nominatim.git] / nominatim / api / search / db_search_fields.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2023 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Data structures for more complex fields in abstract search descriptions.
9 """
10 from typing import List, Tuple, Iterator, cast, Dict
11 import dataclasses
12
13 import sqlalchemy as sa
14 from sqlalchemy.dialects.postgresql import ARRAY
15
16 from nominatim.typing import SaFromClause, SaColumn, SaExpression
17 from nominatim.api.search.query import Token
18
19 @dataclasses.dataclass
20 class WeightedStrings:
21     """ A list of strings together with a penalty.
22     """
23     values: List[str]
24     penalties: List[float]
25
26     def __bool__(self) -> bool:
27         return bool(self.values)
28
29
30     def __iter__(self) -> Iterator[Tuple[str, float]]:
31         return iter(zip(self.values, self.penalties))
32
33
34     def get_penalty(self, value: str, default: float = 1000.0) -> float:
35         """ Get the penalty for the given value. Returns the given default
36             if the value does not exist.
37         """
38         try:
39             return self.penalties[self.values.index(value)]
40         except ValueError:
41             pass
42         return default
43
44
45 @dataclasses.dataclass
46 class WeightedCategories:
47     """ A list of class/type tuples together with a penalty.
48     """
49     values: List[Tuple[str, str]]
50     penalties: List[float]
51
52     def __bool__(self) -> bool:
53         return bool(self.values)
54
55
56     def __iter__(self) -> Iterator[Tuple[Tuple[str, str], float]]:
57         return iter(zip(self.values, self.penalties))
58
59
60     def get_penalty(self, value: Tuple[str, str], default: float = 1000.0) -> float:
61         """ Get the penalty for the given value. Returns the given default
62             if the value does not exist.
63         """
64         try:
65             return self.penalties[self.values.index(value)]
66         except ValueError:
67             pass
68         return default
69
70
71     def sql_restrict(self, table: SaFromClause) -> SaExpression:
72         """ Return an SQLAlcheny expression that restricts the
73             class and type columns of the given table to the values
74             in the list.
75             Must not be used with an empty list.
76         """
77         assert self.values
78         if len(self.values) == 1:
79             return sa.and_(table.c.class_ == self.values[0][0],
80                            table.c.type == self.values[0][1])
81
82         return sa.or_(*(sa.and_(table.c.class_ == c, table.c.type == t)
83                         for c, t in self.values))
84
85
86 @dataclasses.dataclass(order=True)
87 class RankedTokens:
88     """ List of tokens together with the penalty of using it.
89     """
90     penalty: float
91     tokens: List[int]
92
93     def with_token(self, t: Token, transition_penalty: float) -> 'RankedTokens':
94         """ Create a new RankedTokens list with the given token appended.
95             The tokens penalty as well as the given transision penalty
96             are added to the overall penalty.
97         """
98         return RankedTokens(self.penalty + t.penalty + transition_penalty,
99                             self.tokens + [t.token])
100
101
102 @dataclasses.dataclass
103 class FieldRanking:
104     """ A list of rankings to be applied sequentially until one matches.
105         The matched ranking determines the penalty. If none matches a
106         default penalty is applied.
107     """
108     column: str
109     default: float
110     rankings: List[RankedTokens]
111
112     def normalize_penalty(self) -> float:
113         """ Reduce the default and ranking penalties, such that the minimum
114             penalty is 0. Return the penalty that was subtracted.
115         """
116         if self.rankings:
117             min_penalty = min(self.default, min(r.penalty for r in self.rankings))
118         else:
119             min_penalty = self.default
120         if min_penalty > 0.0:
121             self.default -= min_penalty
122             for ranking in self.rankings:
123                 ranking.penalty -= min_penalty
124         return min_penalty
125
126
127     def sql_penalty(self, table: SaFromClause) -> SaColumn:
128         """ Create an SQL expression for the rankings.
129         """
130         assert self.rankings
131
132         return sa.func.weigh_search(table.c[self.column],
133                                     [f"{{{','.join((str(s) for s in r.tokens))}}}"
134                                      for r in self.rankings],
135                                     [r.penalty for r in self.rankings],
136                                     self.default)
137
138
139 @dataclasses.dataclass
140 class FieldLookup:
141     """ A list of tokens to be searched for. The column names the database
142         column to search in and the lookup_type the operator that is applied.
143         'lookup_all' requires all tokens to match. 'lookup_any' requires
144         one of the tokens to match. 'restrict' requires to match all tokens
145         but avoids the use of indexes.
146     """
147     column: str
148     tokens: List[int]
149     lookup_type: str
150
151     def sql_condition(self, table: SaFromClause) -> SaColumn:
152         """ Create an SQL expression for the given match condition.
153         """
154         col = table.c[self.column]
155         if self.lookup_type == 'lookup_all':
156             return col.contains(self.tokens)
157         if self.lookup_type == 'lookup_any':
158             return cast(SaColumn, col.overlap(self.tokens))
159
160         return sa.func.array_cat(col, sa.text('ARRAY[]::integer[]'),
161                                  type_=ARRAY(sa.Integer())).contains(self.tokens)
162
163
164 class SearchData:
165     """ Search fields derived from query and token assignment
166         to be used with the SQL queries.
167     """
168     penalty: float
169
170     lookups: List[FieldLookup] = []
171     rankings: List[FieldRanking]
172
173     housenumbers: WeightedStrings = WeightedStrings([], [])
174     postcodes: WeightedStrings = WeightedStrings([], [])
175     countries: WeightedStrings = WeightedStrings([], [])
176
177     qualifiers: WeightedCategories = WeightedCategories([], [])
178
179
180     def set_strings(self, field: str, tokens: List[Token]) -> None:
181         """ Set on of the WeightedStrings properties from the given
182             token list. Adapt the global penalty, so that the
183             minimum penalty is 0.
184         """
185         if tokens:
186             min_penalty = min(t.penalty for t in tokens)
187             self.penalty += min_penalty
188             wstrs = WeightedStrings([t.lookup_word for t in tokens],
189                                     [t.penalty - min_penalty for t in tokens])
190
191             setattr(self, field, wstrs)
192
193
194     def set_qualifiers(self, tokens: List[Token]) -> None:
195         """ Set the qulaifier field from the given tokens.
196         """
197         if tokens:
198             categories: Dict[Tuple[str, str], float] = {}
199             min_penalty = 1000.0
200             for t in tokens:
201                 if t.penalty < min_penalty:
202                     min_penalty = t.penalty
203                 cat = t.get_category()
204                 if t.penalty < categories.get(cat, 1000.0):
205                     categories[cat] = t.penalty
206             self.penalty += min_penalty
207             self.qualifiers = WeightedCategories(list(categories.keys()),
208                                                  list(categories.values()))
209
210
211     def set_ranking(self, rankings: List[FieldRanking]) -> None:
212         """ Set the list of rankings and normalize the ranking.
213         """
214         self.rankings = []
215         for ranking in rankings:
216             if ranking.rankings:
217                 self.penalty += ranking.normalize_penalty()
218                 self.rankings.append(ranking)
219             else:
220                 self.penalty += ranking.default
221
222
223 def lookup_by_names(name_tokens: List[int], addr_tokens: List[int]) -> List[FieldLookup]:
224     """ Create a lookup list where name tokens are looked up via index
225         and potential address tokens are used to restrict the search further.
226     """
227     lookup = [FieldLookup('name_vector', name_tokens, 'lookup_all')]
228     if addr_tokens:
229         lookup.append(FieldLookup('nameaddress_vector', addr_tokens, 'restrict'))
230
231     return lookup
232
233
234 def lookup_by_any_name(name_tokens: List[int], addr_tokens: List[int],
235                        lookup_type: str) -> List[FieldLookup]:
236     """ Create a lookup list where name tokens are looked up via index
237         and only one of the name tokens must be present.
238         Potential address tokens are used to restrict the search further.
239     """
240     lookup = [FieldLookup('name_vector', name_tokens, 'lookup_any')]
241     if addr_tokens:
242         lookup.append(FieldLookup('nameaddress_vector', addr_tokens, lookup_type))
243
244     return lookup
245
246
247 def lookup_by_addr(name_tokens: List[int], addr_tokens: List[int]) -> List[FieldLookup]:
248     """ Create a lookup list where address tokens are looked up via index
249         and the name tokens are only used to restrict the search further.
250     """
251     return [FieldLookup('name_vector', name_tokens, 'restrict'),
252             FieldLookup('nameaddress_vector', addr_tokens, 'lookup_all')]