]> git.openstreetmap.org Git - nominatim.git/blob - src/nominatim_api/sql/sqlalchemy_types/key_value.py
15e1f6c50fd4236c804fca07805597156eb33e59
[nominatim.git] / src / nominatim_api / sql / sqlalchemy_types / key_value.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2024 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 A custom type that implements a simple key-value store of strings.
9 """
10 from typing import Any
11
12 import sqlalchemy as sa
13 from sqlalchemy.ext.compiler import compiles
14 from sqlalchemy.dialects.postgresql import HSTORE
15 from sqlalchemy.dialects.sqlite import JSON as sqlite_json
16
17 from ...typing import SaDialect, SaColumn
18
19 # pylint: disable=all
20
21 class KeyValueStore(sa.types.TypeDecorator[Any]):
22     """ Dialect-independent type of a simple key-value store of strings.
23     """
24     impl = HSTORE
25     cache_ok = True
26
27     def load_dialect_impl(self, dialect: SaDialect) -> sa.types.TypeEngine[Any]:
28         if dialect.name == 'postgresql':
29             return HSTORE() # type: ignore[no-untyped-call]
30
31         return sqlite_json(none_as_null=True)
32
33
34     class comparator_factory(sa.types.UserDefinedType.Comparator): # type: ignore[type-arg]
35
36         def merge(self, other: SaColumn) -> 'sa.Operators':
37             """ Merge the values from the given KeyValueStore into this
38                 one, overwriting values where necessary. When the argument
39                 is null, nothing happens.
40             """
41             return KeyValueConcat(self.expr, other)
42
43
44 class KeyValueConcat(sa.sql.expression.FunctionElement[Any]):
45     """ Return the merged key-value store from the input parameters.
46     """
47     type = KeyValueStore()
48     name = 'JsonConcat'
49     inherit_cache = True
50
51 @compiles(KeyValueConcat) # type: ignore[no-untyped-call, misc]
52 def default_json_concat(element: KeyValueConcat, compiler: 'sa.Compiled', **kw: Any) -> str:
53     arg1, arg2 = list(element.clauses)
54     return "(%s || coalesce(%s, ''::hstore))" % (compiler.process(arg1, **kw), compiler.process(arg2, **kw))
55
56 @compiles(KeyValueConcat, 'sqlite') # type: ignore[no-untyped-call, misc]
57 def sqlite_json_concat(element: KeyValueConcat, compiler: 'sa.Compiled', **kw: Any) -> str:
58     arg1, arg2 = list(element.clauses)
59     return "json_patch(%s, coalesce(%s, '{}'))" % (compiler.process(arg1, **kw), compiler.process(arg2, **kw))
60
61
62