X-Git-Url: https://git.openstreetmap.org./nominatim.git/blobdiff_plain/bd2c64876f7ddc99da14ea78a652f797e17134f4..bdded69ab636ab8e3bb46322ebbe4d2f9ec41614:/nominatim/api/search/db_search_fields.py?ds=sidebyside diff --git a/nominatim/api/search/db_search_fields.py b/nominatim/api/search/db_search_fields.py index 325e08df..7f775277 100644 --- a/nominatim/api/search/db_search_fields.py +++ b/nominatim/api/search/db_search_fields.py @@ -7,14 +7,16 @@ """ Data structures for more complex fields in abstract search descriptions. """ -from typing import List, Tuple, Iterator, cast +from typing import List, Tuple, Iterator, Dict, Type import dataclasses import sqlalchemy as sa -from sqlalchemy.dialects.postgresql import ARRAY from nominatim.typing import SaFromClause, SaColumn, SaExpression from nominatim.api.search.query import Token +import nominatim.api.search.db_search_lookups as lookups +from nominatim.utils.json_writer import JsonWriter + @dataclasses.dataclass class WeightedStrings: @@ -92,7 +94,7 @@ class RankedTokens: def with_token(self, t: Token, transition_penalty: float) -> 'RankedTokens': """ Create a new RankedTokens list with the given token appended. - The tokens penalty as well as the given transision penalty + The tokens penalty as well as the given transition penalty are added to the overall penalty. """ return RankedTokens(self.penalty + t.penalty + transition_penalty, @@ -129,10 +131,17 @@ class FieldRanking: """ assert self.rankings - col = table.c[self.column] + rout = JsonWriter().start_array() + for rank in self.rankings: + rout.start_array().value(rank.penalty).next() + rout.start_array() + for token in rank.tokens: + rout.value(token).next() + rout.end_array() + rout.end_array().next() + rout.end_array() - return sa.case(*((col.contains(r.tokens),r.penalty) for r in self.rankings), - else_=self.default) + return sa.func.weigh_search(table.c[self.column], rout(), self.default) @dataclasses.dataclass @@ -145,19 +154,12 @@ class FieldLookup: """ column: str tokens: List[int] - lookup_type: str + lookup_type: Type[lookups.LookupType] def sql_condition(self, table: SaFromClause) -> SaColumn: """ Create an SQL expression for the given match condition. """ - col = table.c[self.column] - if self.lookup_type == 'lookup_all': - return col.contains(self.tokens) - if self.lookup_type == 'lookup_any': - return cast(SaColumn, col.overlap(self.tokens)) - - return sa.func.array_cat(col, sa.text('ARRAY[]::integer[]'), - type_=ARRAY(sa.Integer())).contains(self.tokens) + return self.lookup_type(table, self.column, self.tokens) class SearchData: @@ -194,10 +196,16 @@ class SearchData: """ Set the qulaifier field from the given tokens. """ if tokens: - min_penalty = min(t.penalty for t in tokens) + categories: Dict[Tuple[str, str], float] = {} + min_penalty = 1000.0 + for t in tokens: + min_penalty = min(min_penalty, t.penalty) + cat = t.get_category() + if t.penalty < categories.get(cat, 1000.0): + categories[cat] = t.penalty self.penalty += min_penalty - self.qualifiers = WeightedCategories([t.get_category() for t in tokens], - [t.penalty - min_penalty for t in tokens]) + self.qualifiers = WeightedCategories(list(categories.keys()), + list(categories.values())) def set_ranking(self, rankings: List[FieldRanking]) -> None: @@ -210,3 +218,37 @@ class SearchData: self.rankings.append(ranking) else: self.penalty += ranking.default + + +def lookup_by_names(name_tokens: List[int], addr_tokens: List[int]) -> List[FieldLookup]: + """ Create a lookup list where name tokens are looked up via index + and potential address tokens are used to restrict the search further. + """ + lookup = [FieldLookup('name_vector', name_tokens, lookups.LookupAll)] + if addr_tokens: + lookup.append(FieldLookup('nameaddress_vector', addr_tokens, lookups.Restrict)) + + return lookup + + +def lookup_by_any_name(name_tokens: List[int], addr_restrict_tokens: List[int], + addr_lookup_tokens: List[int]) -> List[FieldLookup]: + """ Create a lookup list where name tokens are looked up via index + and only one of the name tokens must be present. + Potential address tokens are used to restrict the search further. + """ + lookup = [FieldLookup('name_vector', name_tokens, lookups.LookupAny)] + if addr_restrict_tokens: + lookup.append(FieldLookup('nameaddress_vector', addr_restrict_tokens, lookups.Restrict)) + if addr_lookup_tokens: + lookup.append(FieldLookup('nameaddress_vector', addr_lookup_tokens, lookups.LookupAll)) + + return lookup + + +def lookup_by_addr(name_tokens: List[int], addr_tokens: List[int]) -> List[FieldLookup]: + """ Create a lookup list where address tokens are looked up via index + and the name tokens are only used to restrict the search further. + """ + return [FieldLookup('name_vector', name_tokens, lookups.Restrict), + FieldLookup('nameaddress_vector', addr_tokens, lookups.LookupAll)]