]> git.openstreetmap.org Git - nominatim.git/blob - nominatim/db/sqlalchemy_types/int_array.py
335d5541972bebad04b7ab79367357bdba49bb29
[nominatim.git] / nominatim / db / sqlalchemy_types / int_array.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2023 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Custom type for an array of integers.
9 """
10 from typing import Any, List, cast, Optional
11
12 import sqlalchemy as sa
13 from sqlalchemy.dialects.postgresql import ARRAY
14
15 from nominatim.typing import SaDialect, SaColumn
16
17 # pylint: disable=all
18
19 class IntList(sa.types.TypeDecorator[Any]):
20     """ A list of integers saved as a text of comma-separated numbers.
21     """
22     impl = sa.types.Unicode
23     cache_ok = True
24
25     def process_bind_param(self, value: Optional[Any], dialect: 'sa.Dialect') -> Optional[str]:
26         if value is None:
27             return None
28
29         assert isinstance(value, list)
30         return ','.join(map(str, value))
31
32     def process_result_value(self, value: Optional[Any],
33                              dialect: SaDialect) -> Optional[List[int]]:
34         return [int(v) for v in value.split(',')] if value is not None else None
35
36     def copy(self, **kw: Any) -> 'IntList':
37         return IntList(self.impl.length)
38
39
40 class IntArray(sa.types.TypeDecorator[Any]):
41     """ Dialect-independent list of integers.
42     """
43     impl = IntList
44     cache_ok = True
45
46     def load_dialect_impl(self, dialect: SaDialect) -> sa.types.TypeEngine[Any]:
47         if dialect.name == 'postgresql':
48             return ARRAY(sa.Integer()) #pylint: disable=invalid-name
49
50         return IntList()
51
52
53     class comparator_factory(sa.types.UserDefinedType.Comparator): # type: ignore[type-arg]
54
55         def __add__(self, other: SaColumn) -> 'sa.ColumnOperators':
56             """ Concate the array with the given array. If one of the
57                 operants is null, the value of the other will be returned.
58             """
59             return sa.func.array_cat(self, other, type_=IntArray)
60
61
62         def contains(self, other: SaColumn, **kwargs: Any) -> 'sa.ColumnOperators':
63             """ Return true if the array contains all the value of the argument
64                 array.
65             """
66             return cast('sa.ColumnOperators', self.op('@>', is_comparison=True)(other))
67
68
69         def overlaps(self, other: SaColumn) -> 'sa.Operators':
70             """ Return true if at least one value of the argument is contained
71                 in the array.
72             """
73             return self.op('&&', is_comparison=True)(other)