]> git.openstreetmap.org Git - nominatim.git/blob - src/nominatim_api/search/db_search_lookups.py
port unit tests to new python package layout
[nominatim.git] / src / nominatim_api / search / db_search_lookups.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2024 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Implementation of lookup functions for the search_name table.
9 """
10 from typing import List, Any
11
12 import sqlalchemy as sa
13 from sqlalchemy.ext.compiler import compiles
14
15 from nominatim_core.typing import SaFromClause
16 from nominatim_core.db.sqlalchemy_types import IntArray
17
18 # pylint: disable=consider-using-f-string
19
20 LookupType = sa.sql.expression.FunctionElement[Any]
21
22 class LookupAll(LookupType):
23     """ Find all entries in search_name table that contain all of
24         a given list of tokens using an index for the search.
25     """
26     inherit_cache = True
27
28     def __init__(self, table: SaFromClause, column: str, tokens: List[int]) -> None:
29         super().__init__(table.c.place_id, getattr(table.c, column), column,
30                          sa.type_coerce(tokens, IntArray))
31
32
33 @compiles(LookupAll) # type: ignore[no-untyped-call, misc]
34 def _default_lookup_all(element: LookupAll,
35                         compiler: 'sa.Compiled', **kw: Any) -> str:
36     _, col, _, tokens = list(element.clauses)
37     return "(%s @> %s)" % (compiler.process(col, **kw),
38                            compiler.process(tokens, **kw))
39
40
41 @compiles(LookupAll, 'sqlite') # type: ignore[no-untyped-call, misc]
42 def _sqlite_lookup_all(element: LookupAll,
43                         compiler: 'sa.Compiled', **kw: Any) -> str:
44     place, col, colname, tokens = list(element.clauses)
45     return "(%s IN (SELECT CAST(value as bigint) FROM"\
46            " (SELECT array_intersect_fuzzy(places) as p FROM"\
47            "   (SELECT places FROM reverse_search_name"\
48            "   WHERE word IN (SELECT value FROM json_each('[' || %s || ']'))"\
49            "     AND column = %s"\
50            "   ORDER BY length(places)) as x) as u,"\
51            " json_each('[' || u.p || ']'))"\
52            " AND array_contains(%s, %s))"\
53              % (compiler.process(place, **kw),
54                 compiler.process(tokens, **kw),
55                 compiler.process(colname, **kw),
56                 compiler.process(col, **kw),
57                 compiler.process(tokens, **kw)
58                 )
59
60
61
62 class LookupAny(LookupType):
63     """ Find all entries that contain at least one of the given tokens.
64         Use an index for the search.
65     """
66     inherit_cache = True
67
68     def __init__(self, table: SaFromClause, column: str, tokens: List[int]) -> None:
69         super().__init__(table.c.place_id, getattr(table.c, column), column,
70                          sa.type_coerce(tokens, IntArray))
71
72 @compiles(LookupAny) # type: ignore[no-untyped-call, misc]
73 def _default_lookup_any(element: LookupAny,
74                         compiler: 'sa.Compiled', **kw: Any) -> str:
75     _, col, _, tokens = list(element.clauses)
76     return "(%s && %s)" % (compiler.process(col, **kw),
77                            compiler.process(tokens, **kw))
78
79 @compiles(LookupAny, 'sqlite') # type: ignore[no-untyped-call, misc]
80 def _sqlite_lookup_any(element: LookupAny,
81                         compiler: 'sa.Compiled', **kw: Any) -> str:
82     place, _, colname, tokens = list(element.clauses)
83     return "%s IN (SELECT CAST(value as bigint) FROM"\
84            " (SELECT array_union(places) as p FROM reverse_search_name"\
85            "   WHERE word IN (SELECT value FROM json_each('[' || %s || ']'))"\
86            "     AND column = %s) as u,"\
87            " json_each('[' || u.p || ']'))" % (compiler.process(place, **kw),
88                                                compiler.process(tokens, **kw),
89                                                compiler.process(colname, **kw))
90
91
92
93 class Restrict(LookupType):
94     """ Find all entries that contain all of the given tokens.
95         Do not use an index for the search.
96     """
97     inherit_cache = True
98
99     def __init__(self, table: SaFromClause, column: str, tokens: List[int]) -> None:
100         super().__init__(getattr(table.c, column),
101                          sa.type_coerce(tokens, IntArray))
102
103
104 @compiles(Restrict) # type: ignore[no-untyped-call, misc]
105 def _default_restrict(element: Restrict,
106                         compiler: 'sa.Compiled', **kw: Any) -> str:
107     arg1, arg2 = list(element.clauses)
108     return "(coalesce(null, %s) @> %s)" % (compiler.process(arg1, **kw),
109                                            compiler.process(arg2, **kw))
110
111 @compiles(Restrict, 'sqlite') # type: ignore[no-untyped-call, misc]
112 def _sqlite_restrict(element: Restrict,
113                         compiler: 'sa.Compiled', **kw: Any) -> str:
114     return "array_contains(%s)" % compiler.process(element.clauses, **kw)