]> git.openstreetmap.org Git - nominatim.git/blob - nominatim/api/search/db_search_fields.py
Merge remote-tracking branch 'upstream/master'
[nominatim.git] / nominatim / api / search / db_search_fields.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2023 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Data structures for more complex fields in abstract search descriptions.
9 """
10 from typing import List, Tuple, Iterator, cast
11 import dataclasses
12
13 import sqlalchemy as sa
14 from sqlalchemy.dialects.postgresql import ARRAY
15
16 from nominatim.typing import SaFromClause, SaColumn, SaExpression
17 from nominatim.api.search.query import Token
18
19 @dataclasses.dataclass
20 class WeightedStrings:
21     """ A list of strings together with a penalty.
22     """
23     values: List[str]
24     penalties: List[float]
25
26     def __bool__(self) -> bool:
27         return bool(self.values)
28
29
30     def __iter__(self) -> Iterator[Tuple[str, float]]:
31         return iter(zip(self.values, self.penalties))
32
33
34     def get_penalty(self, value: str, default: float = 1000.0) -> float:
35         """ Get the penalty for the given value. Returns the given default
36             if the value does not exist.
37         """
38         try:
39             return self.penalties[self.values.index(value)]
40         except ValueError:
41             pass
42         return default
43
44
45 @dataclasses.dataclass
46 class WeightedCategories:
47     """ A list of class/type tuples together with a penalty.
48     """
49     values: List[Tuple[str, str]]
50     penalties: List[float]
51
52     def __bool__(self) -> bool:
53         return bool(self.values)
54
55
56     def __iter__(self) -> Iterator[Tuple[Tuple[str, str], float]]:
57         return iter(zip(self.values, self.penalties))
58
59
60     def get_penalty(self, value: Tuple[str, str], default: float = 1000.0) -> float:
61         """ Get the penalty for the given value. Returns the given default
62             if the value does not exist.
63         """
64         try:
65             return self.penalties[self.values.index(value)]
66         except ValueError:
67             pass
68         return default
69
70
71     def sql_restrict(self, table: SaFromClause) -> SaExpression:
72         """ Return an SQLAlcheny expression that restricts the
73             class and type columns of the given table to the values
74             in the list.
75             Must not be used with an empty list.
76         """
77         assert self.values
78         if len(self.values) == 1:
79             return sa.and_(table.c.class_ == self.values[0][0],
80                            table.c.type == self.values[0][1])
81
82         return sa.or_(*(sa.and_(table.c.class_ == c, table.c.type == t)
83                         for c, t in self.values))
84
85
86 @dataclasses.dataclass(order=True)
87 class RankedTokens:
88     """ List of tokens together with the penalty of using it.
89     """
90     penalty: float
91     tokens: List[int]
92
93     def with_token(self, t: Token, transition_penalty: float) -> 'RankedTokens':
94         """ Create a new RankedTokens list with the given token appended.
95             The tokens penalty as well as the given transision penalty
96             are added to the overall penalty.
97         """
98         return RankedTokens(self.penalty + t.penalty + transition_penalty,
99                             self.tokens + [t.token])
100
101
102 @dataclasses.dataclass
103 class FieldRanking:
104     """ A list of rankings to be applied sequentially until one matches.
105         The matched ranking determines the penalty. If none matches a
106         default penalty is applied.
107     """
108     column: str
109     default: float
110     rankings: List[RankedTokens]
111
112     def normalize_penalty(self) -> float:
113         """ Reduce the default and ranking penalties, such that the minimum
114             penalty is 0. Return the penalty that was subtracted.
115         """
116         if self.rankings:
117             min_penalty = min(self.default, min(r.penalty for r in self.rankings))
118         else:
119             min_penalty = self.default
120         if min_penalty > 0.0:
121             self.default -= min_penalty
122             for ranking in self.rankings:
123                 ranking.penalty -= min_penalty
124         return min_penalty
125
126
127     def sql_penalty(self, table: SaFromClause) -> SaColumn:
128         """ Create an SQL expression for the rankings.
129         """
130         assert self.rankings
131
132         col = table.c[self.column]
133
134         return sa.case(*((col.contains(r.tokens),r.penalty) for r in self.rankings),
135                        else_=self.default)
136
137
138 @dataclasses.dataclass
139 class FieldLookup:
140     """ A list of tokens to be searched for. The column names the database
141         column to search in and the lookup_type the operator that is applied.
142         'lookup_all' requires all tokens to match. 'lookup_any' requires
143         one of the tokens to match. 'restrict' requires to match all tokens
144         but avoids the use of indexes.
145     """
146     column: str
147     tokens: List[int]
148     lookup_type: str
149
150     def sql_condition(self, table: SaFromClause) -> SaColumn:
151         """ Create an SQL expression for the given match condition.
152         """
153         col = table.c[self.column]
154         if self.lookup_type == 'lookup_all':
155             return col.contains(self.tokens)
156         if self.lookup_type == 'lookup_any':
157             return cast(SaColumn, col.overlap(self.tokens))
158
159         return sa.func.array_cat(col, sa.text('ARRAY[]::integer[]'),
160                                  type_=ARRAY(sa.Integer())).contains(self.tokens)
161
162
163 class SearchData:
164     """ Search fields derived from query and token assignment
165         to be used with the SQL queries.
166     """
167     penalty: float
168
169     lookups: List[FieldLookup] = []
170     rankings: List[FieldRanking]
171
172     housenumbers: WeightedStrings = WeightedStrings([], [])
173     postcodes: WeightedStrings = WeightedStrings([], [])
174     countries: WeightedStrings = WeightedStrings([], [])
175
176     qualifiers: WeightedCategories = WeightedCategories([], [])
177
178
179     def set_strings(self, field: str, tokens: List[Token]) -> None:
180         """ Set on of the WeightedStrings properties from the given
181             token list. Adapt the global penalty, so that the
182             minimum penalty is 0.
183         """
184         if tokens:
185             min_penalty = min(t.penalty for t in tokens)
186             self.penalty += min_penalty
187             wstrs = WeightedStrings([t.lookup_word for t in tokens],
188                                     [t.penalty - min_penalty for t in tokens])
189
190             setattr(self, field, wstrs)
191
192
193     def set_qualifiers(self, tokens: List[Token]) -> None:
194         """ Set the qulaifier field from the given tokens.
195         """
196         if tokens:
197             min_penalty = min(t.penalty for t in tokens)
198             self.penalty += min_penalty
199             self.qualifiers = WeightedCategories([t.get_category() for t in tokens],
200                                                  [t.penalty - min_penalty for t in tokens])
201
202
203     def set_ranking(self, rankings: List[FieldRanking]) -> None:
204         """ Set the list of rankings and normalize the ranking.
205         """
206         self.rankings = []
207         for ranking in rankings:
208             if ranking.rankings:
209                 self.penalty += ranking.normalize_penalty()
210                 self.rankings.append(ranking)
211             else:
212                 self.penalty += ranking.default