]> git.openstreetmap.org Git - nominatim.git/blob - src/nominatim_api/sql/sqlalchemy_types/int_array.py
e53f8bfdd7439ee027950b46b86ca7b13fb51c62
[nominatim.git] / src / nominatim_api / sql / sqlalchemy_types / int_array.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2024 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Custom type for an array of integers.
9 """
10 from typing import Any, List, cast, Optional
11
12 import sqlalchemy as sa
13 from sqlalchemy.ext.compiler import compiles
14 from sqlalchemy.dialects.postgresql import ARRAY
15
16 from ...typing import SaDialect, SaColumn
17
18 # pylint: disable=all
19
20 class IntList(sa.types.TypeDecorator[Any]):
21     """ A list of integers saved as a text of comma-separated numbers.
22     """
23     impl = sa.types.Unicode
24     cache_ok = True
25
26     def process_bind_param(self, value: Optional[Any], dialect: 'sa.Dialect') -> Optional[str]:
27         if value is None:
28             return None
29
30         assert isinstance(value, list)
31         return ','.join(map(str, value))
32
33     def process_result_value(self, value: Optional[Any],
34                              dialect: SaDialect) -> Optional[List[int]]:
35         return [int(v) for v in value.split(',')] if value is not None else None
36
37     def copy(self, **kw: Any) -> 'IntList':
38         return IntList(self.impl.length)
39
40
41 class IntArray(sa.types.TypeDecorator[Any]):
42     """ Dialect-independent list of integers.
43     """
44     impl = IntList
45     cache_ok = True
46
47     def load_dialect_impl(self, dialect: SaDialect) -> sa.types.TypeEngine[Any]:
48         if dialect.name == 'postgresql':
49             return ARRAY(sa.Integer()) #pylint: disable=invalid-name
50
51         return IntList()
52
53
54     class comparator_factory(sa.types.UserDefinedType.Comparator): # type: ignore[type-arg]
55
56         def __add__(self, other: SaColumn) -> 'sa.ColumnOperators':
57             """ Concate the array with the given array. If one of the
58                 operants is null, the value of the other will be returned.
59             """
60             return ArrayCat(self.expr, other)
61
62
63         def contains(self, other: SaColumn, **kwargs: Any) -> 'sa.ColumnOperators':
64             """ Return true if the array contains all the value of the argument
65                 array.
66             """
67             return ArrayContains(self.expr, other)
68
69
70
71 class ArrayAgg(sa.sql.functions.GenericFunction[Any]):
72     """ Aggregate function to collect elements in an array.
73     """
74     type = IntArray()
75     identifier = 'ArrayAgg'
76     name = 'array_agg'
77     inherit_cache = True
78
79
80 @compiles(ArrayAgg, 'sqlite')
81 def sqlite_array_agg(element: ArrayAgg, compiler: 'sa.Compiled', **kw: Any) -> str:
82     return "group_concat(%s, ',')" % compiler.process(element.clauses, **kw)
83
84
85
86 class ArrayContains(sa.sql.expression.FunctionElement[Any]):
87     """ Function to check if an array is fully contained in another.
88     """
89     name = 'ArrayContains'
90     inherit_cache = True
91
92
93 @compiles(ArrayContains)
94 def generic_array_contains(element: ArrayContains, compiler: 'sa.Compiled', **kw: Any) -> str:
95     arg1, arg2 = list(element.clauses)
96     return "(%s @> %s)" % (compiler.process(arg1, **kw),
97                            compiler.process(arg2, **kw))
98
99
100 @compiles(ArrayContains, 'sqlite')
101 def sqlite_array_contains(element: ArrayContains, compiler: 'sa.Compiled', **kw: Any) -> str:
102     return "array_contains(%s)" % compiler.process(element.clauses, **kw)
103
104
105
106 class ArrayCat(sa.sql.expression.FunctionElement[Any]):
107     """ Function to check if an array is fully contained in another.
108     """
109     type = IntArray()
110     identifier = 'ArrayCat'
111     inherit_cache = True
112
113
114 @compiles(ArrayCat)
115 def generic_array_cat(element: ArrayCat, compiler: 'sa.Compiled', **kw: Any) -> str:
116     return "array_cat(%s)" % compiler.process(element.clauses, **kw)
117
118
119 @compiles(ArrayCat, 'sqlite')
120 def sqlite_array_cat(element: ArrayCat, compiler: 'sa.Compiled', **kw: Any) -> str:
121     arg1, arg2 = list(element.clauses)
122     return "(%s || ',' || %s)" % (compiler.process(arg1, **kw), compiler.process(arg2, **kw))
123