# SPDX-License-Identifier: GPL-3.0-or-later
# This file is part of Nominatim. (https://nominatim.org)
# Copyright (C) 2024 by the Nominatim developer community.
# For a full list of authors see the git log.
"""
Custom SQLAlchemy type for an array of integers, together with the SQL
functions (aggregation, containment, concatenation) that operate on it.
"""
10 from typing import Any, List, cast, Optional
12 import sqlalchemy as sa
13 from sqlalchemy.ext.compiler import compiles
14 from sqlalchemy.dialects.postgresql import ARRAY
16 from ...typing import SaDialect, SaColumn
class IntList(sa.types.TypeDecorator[Any]):
    """ A list of integers saved as a text of comma-separated numbers.

        SQL NULL maps to Python None in both directions. The empty list is
        stored as the empty string and read back as an empty list.
    """
    impl = sa.types.Unicode
    # The conversion depends on no per-instance state, so SQLAlchemy may
    # safely use this type as part of its statement cache key.
    cache_ok = True

    def process_bind_param(self, value: Optional[Any], dialect: 'sa.Dialect') -> Optional[str]:
        """ Convert a list (or tuple) of integers into its comma-separated
            text form. None is passed through unchanged (stored as NULL).
        """
        if value is None:
            return None

        assert isinstance(value, (list, tuple))
        return ','.join(map(str, value))

    def process_result_value(self, value: Optional[Any],
                             dialect: SaDialect) -> Optional[List[int]]:
        """ Parse comma-separated text back into a list of integers.
            NULL becomes None.

            Empty items are skipped: the empty string written for an empty
            list must read back as [] instead of failing on int('').
        """
        if value is None:
            return None

        return [int(v) for v in value.split(',') if v]

    def copy(self, **kw: Any) -> 'IntList':
        """ Return a new IntList whose underlying text type has the same
            length as this one.
        """
        return IntList(self.impl.length)


class IntArray(sa.types.TypeDecorator[Any]):
    """ Dialect-independent list of integers.

        PostgreSQL uses its native integer ARRAY type; every other dialect
        falls back to IntList, a comma-separated text encoding.
    """
    # TypeDecorator requires a class-level impl; the concrete per-dialect
    # type is chosen in load_dialect_impl().
    impl = IntList
    cache_ok = True

    def load_dialect_impl(self, dialect: SaDialect) -> sa.types.TypeEngine[Any]:
        """ Return the concrete column type to use for the given dialect.
        """
        if dialect.name == 'postgresql':
            return ARRAY(sa.Integer()) #pylint: disable=invalid-name

        return IntList()

    class comparator_factory(sa.types.UserDefinedType.Comparator): # type: ignore[type-arg]
        """ Array-specific SQL operators for expressions of this type.
        """

        def __add__(self, other: SaColumn) -> 'sa.ColumnOperators':
            """ Concatenate the array with the given array. If one of the
                operands is null, the value of the other will be returned.
            """
            return ArrayCat(self.expr, other)

        def contains(self, other: SaColumn, **kwargs: Any) -> 'sa.ColumnOperators':
            """ Return true if the array contains all the values of the
                argument array.
            """
            return ArrayContains(self.expr, other)


class ArrayAgg(sa.sql.functions.GenericFunction[Any]):
    """ Aggregate function to collect elements in an array.

        Renders as array_agg() unless a dialect-specific compilation rule
        overrides it. The result is typed as IntArray, so that the value
        is returned to Python as a list of integers.
    """
    type = IntArray()
    identifier = 'ArrayAgg'
    # SQL function name; without it GenericFunction would fall back to
    # rendering the class name ("ArrayAgg"), which is not a SQL function.
    name = 'array_agg'
    inherit_cache = True


@compiles(ArrayAgg, 'sqlite') # type: ignore[no-untyped-call, misc]
def sqlite_array_agg(element: ArrayAgg, compiler: 'sa.Compiled', **kw: Any) -> str:
    """ Emulate array_agg() on SQLite by collecting the values into a
        comma-separated string with group_concat().
    """
    aggregated = compiler.process(element.clauses, **kw)
    return f"group_concat({aggregated}, ',')"


class ArrayContains(sa.sql.expression.FunctionElement[Any]):
    """ Function to check if the first array contains every element of
        the second array.

        The SQL is produced by dialect-specific compilation rules.
    """
    name = 'ArrayContains'
    # Opt in to SQLAlchemy statement caching; this subclass adds no state
    # that would affect the cache key.
    inherit_cache = True


@compiles(ArrayContains) # type: ignore[no-untyped-call, misc]
def generic_array_contains(element: ArrayContains, compiler: 'sa.Compiled', **kw: Any) -> str:
    """ Default rendering: the `@>` array containment operator
        (PostgreSQL syntax).
    """
    container, contained = list(element.clauses)
    container_sql = compiler.process(container, **kw)
    contained_sql = compiler.process(contained, **kw)
    return f"({container_sql} @> {contained_sql})"


@compiles(ArrayContains, 'sqlite') # type: ignore[no-untyped-call, misc]
def sqlite_array_contains(element: ArrayContains, compiler: 'sa.Compiled', **kw: Any) -> str:
    """ SQLite rendering: delegate the check to an array_contains() SQL
        function.
    """
    # NOTE(review): array_contains is not a SQLite built-in; presumably it
    # is registered as a user-defined function on the connection — confirm.
    arguments = compiler.process(element.clauses, **kw)
    return f"array_contains({arguments})"


class ArrayCat(sa.sql.expression.FunctionElement[Any]):
    """ Function to concatenate two integer arrays.

        The result is typed as IntArray, so it can be combined further with
        the IntArray operators and is returned as a list of integers.
    """
    type = IntArray()
    identifier = 'ArrayCat'
    # Opt in to SQLAlchemy statement caching; this subclass adds no state
    # that would affect the cache key.
    inherit_cache = True


@compiles(ArrayCat) # type: ignore[no-untyped-call, misc]
def generic_array_cat(element: ArrayCat, compiler: 'sa.Compiled', **kw: Any) -> str:
    """ Default rendering through the SQL function array_cat().
    """
    arguments = compiler.process(element.clauses, **kw)
    return f"array_cat({arguments})"


@compiles(ArrayCat, 'sqlite') # type: ignore[no-untyped-call, misc]
def sqlite_array_cat(element: ArrayCat, compiler: 'sa.Compiled', **kw: Any) -> str:
    """ Concatenate two comma-separated integer lists on SQLite.

        SQLite's || operator returns NULL as soon as one operand is NULL.
        To honour the contract of IntArray's + operator (a NULL operand
        yields the other operand), coalesce() falls back to whichever
        argument is not NULL. Each argument is therefore rendered twice.
    """
    arg1, arg2 = list(element.clauses)
    # Compile every occurrence separately rather than reusing the rendered
    # string, so the compiler tracks each bind placeholder it emits, in the
    # same order as the placeholders appear in the SQL text.
    return "coalesce(%s || ',' || %s, %s, %s)" % (
        compiler.process(arg1, **kw), compiler.process(arg2, **kw),
        compiler.process(arg1, **kw), compiler.process(arg2, **kw))