]> git.openstreetmap.org Git - nominatim.git/blob - src/nominatim_api/sql/sqlite_functions.py
adapt BDD tests for legacy tokenizer + Python frontend
[nominatim.git] / src / nominatim_api / sql / sqlite_functions.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2024 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Custom functions for SQLite.
9 """
10 from typing import cast, Optional, Set, Any
11 import json
12
13 # pylint: disable=protected-access
14
15 def weigh_search(search_vector: Optional[str], rankings: str, default: float) -> float:
16     """ Custom weight function for search results.
17     """
18     if search_vector is not None:
19         svec = [int(x) for x in search_vector.split(',')]
20         for rank in json.loads(rankings):
21             if all(r in svec for r in rank[1]):
22                 return cast(float, rank[0])
23
24     return default
25
26
27 class ArrayIntersectFuzzy:
28     """ Compute the array of common elements of all input integer arrays.
29         Very large input parameters may be ignored to speed up
30         computation. Therefore, the result is a superset of common elements.
31
32         Input and output arrays are given as comma-separated lists.
33     """
34     def __init__(self) -> None:
35         self.first = ''
36         self.values: Optional[Set[int]] = None
37
38     def step(self, value: Optional[str]) -> None:
39         """ Add the next array to the intersection.
40         """
41         if value is not None:
42             if not self.first:
43                 self.first = value
44             elif len(value) < 10000000:
45                 if self.values is None:
46                     self.values = {int(x) for x in self.first.split(',')}
47                 self.values.intersection_update((int(x) for x in value.split(',')))
48
49     def finalize(self) -> str:
50         """ Return the final result.
51         """
52         if self.values is not None:
53             return ','.join(map(str, self.values))
54
55         return self.first
56
57
58 class ArrayUnion:
59     """ Compute the set of all elements of the input integer arrays.
60
61         Input and output arrays are given as strings of comma-separated lists.
62     """
63     def __init__(self) -> None:
64         self.values: Optional[Set[str]] = None
65
66     def step(self, value: Optional[str]) -> None:
67         """ Add the next array to the union.
68         """
69         if value is not None:
70             if self.values is None:
71                 self.values = set(value.split(','))
72             else:
73                 self.values.update(value.split(','))
74
75     def finalize(self) -> str:
76         """ Return the final result.
77         """
78         return '' if self.values is None else ','.join(self.values)
79
80
81 def array_contains(container: Optional[str], containee: Optional[str]) -> Optional[bool]:
82     """ Is the array 'containee' completely contained in array 'container'.
83     """
84     if container is None or containee is None:
85         return None
86
87     vset = container.split(',')
88     return all(v in vset for v in containee.split(','))
89
90
91 def array_pair_contains(container1: Optional[str], container2: Optional[str],
92                         containee: Optional[str]) -> Optional[bool]:
93     """ Is the array 'containee' completely contained in the union of
94         array 'container1' and array 'container2'.
95     """
96     if container1 is None or container2 is None or containee is None:
97         return None
98
99     vset = container1.split(',') + container2.split(',')
100     return all(v in vset for v in containee.split(','))
101
102
103 def install_custom_functions(conn: Any) -> None:
104     """ Install helper functions for Nominatim into the given SQLite
105         database connection.
106     """
107     conn.create_function('weigh_search', 3, weigh_search, deterministic=True)
108     conn.create_function('array_contains', 2, array_contains, deterministic=True)
109     conn.create_function('array_pair_contains', 3, array_pair_contains, deterministic=True)
110     _create_aggregate(conn, 'array_intersect_fuzzy', 1, ArrayIntersectFuzzy)
111     _create_aggregate(conn, 'array_union', 1, ArrayUnion)
112
113
114 async def _make_aggregate(aioconn: Any, *args: Any) -> None:
115     await aioconn._execute(aioconn._conn.create_aggregate, *args)
116
117
118 def _create_aggregate(conn: Any, name: str, nargs: int, aggregate: Any) -> None:
119     try:
120         conn.await_(_make_aggregate(conn._connection, name, nargs, aggregate))
121     except Exception as error: # pylint: disable=broad-exception-caught
122         conn._handle_exception(error)