]> git.openstreetmap.org Git - nominatim.git/commitdiff
hide type differences between Postgres and Sqlite in custom types
author Sarah Hoffmann <lonvia@denofr.de>
Tue, 5 Dec 2023 10:29:16 +0000 (11:29 +0100)
committer Sarah Hoffmann <lonvia@denofr.de>
Thu, 7 Dec 2023 08:31:00 +0000 (09:31 +0100)
Also define a custom set of operators in preparation of differences
in implementation.

nominatim/api/core.py
nominatim/api/search/db_search_fields.py
nominatim/api/search/db_searches.py
nominatim/api/search/icu_tokenizer.py
nominatim/db/sqlalchemy_schema.py
nominatim/db/sqlalchemy_types/__init__.py [new file with mode: 0644]
nominatim/db/sqlalchemy_types/geometry.py [moved from nominatim/db/sqlalchemy_types.py with 100% similarity]
nominatim/db/sqlalchemy_types/int_array.py [new file with mode: 0644]
nominatim/db/sqlalchemy_types/json.py [new file with mode: 0644]
nominatim/db/sqlalchemy_types/key_value.py [new file with mode: 0644]
nominatim/typing.py

index 44ac91606fef90a746bb26d06b2a9fc6da0e61e4..c8045c2d1494167a6a87dc2ef149b783607aa970 100644 (file)
@@ -137,7 +137,7 @@ class NominatimAPIAsync: #pylint: disable=too-many-instance-attributes
 
             self._property_cache['DB:server_version'] = server_version
 
 
             self._property_cache['DB:server_version'] = server_version
 
-            self._tables = SearchTables(sa.MetaData(), engine.name) # pylint: disable=no-member
+            self._tables = SearchTables(sa.MetaData()) # pylint: disable=no-member
             self._engine = engine
 
 
             self._engine = engine
 
 
index 59af826086db86027f2c808dee51824fb17e72ff..52693e95fce673026d97c545bc70b37ad52a17cf 100644 (file)
@@ -11,7 +11,6 @@ from typing import List, Tuple, Iterator, cast, Dict
 import dataclasses
 
 import sqlalchemy as sa
 import dataclasses
 
 import sqlalchemy as sa
-from sqlalchemy.dialects.postgresql import ARRAY
 
 from nominatim.typing import SaFromClause, SaColumn, SaExpression
 from nominatim.api.search.query import Token
 
 from nominatim.typing import SaFromClause, SaColumn, SaExpression
 from nominatim.api.search.query import Token
@@ -155,10 +154,9 @@ class FieldLookup:
         if self.lookup_type == 'lookup_all':
             return col.contains(self.tokens)
         if self.lookup_type == 'lookup_any':
         if self.lookup_type == 'lookup_all':
             return col.contains(self.tokens)
         if self.lookup_type == 'lookup_any':
-            return cast(SaColumn, col.overlap(self.tokens))
+            return cast(SaColumn, col.overlaps(self.tokens))
 
 
-        return sa.func.array_cat(col, sa.text('ARRAY[]::integer[]'),
-                                 type_=ARRAY(sa.Integer())).contains(self.tokens)
+        return sa.func.coalesce(sa.null(), col).contains(self.tokens) # pylint: disable=not-callable
 
 
 class SearchData:
 
 
 class SearchData:
index 232f816ef89609f050ea15e79f3651410222ef86..2b4dfd3c9bdc5c78c7ccaa07bb1a18a4bef97d16 100644 (file)
@@ -11,7 +11,7 @@ from typing import List, Tuple, AsyncIterator, Dict, Any, Callable
 import abc
 
 import sqlalchemy as sa
 import abc
 
 import sqlalchemy as sa
-from sqlalchemy.dialects.postgresql import ARRAY, array_agg
+from sqlalchemy.dialects.postgresql import array_agg
 
 from nominatim.typing import SaFromClause, SaScalarSelect, SaColumn, \
                              SaExpression, SaSelect, SaLambdaSelect, SaRow, SaBind
 
 from nominatim.typing import SaFromClause, SaScalarSelect, SaColumn, \
                              SaExpression, SaSelect, SaLambdaSelect, SaRow, SaBind
@@ -494,10 +494,7 @@ class CountrySearch(AbstractSearch):
         sub = sql.subquery('grid')
 
         sql = sa.select(t.c.country_code,
         sub = sql.subquery('grid')
 
         sql = sa.select(t.c.country_code,
-                        (t.c.name
-                         + sa.func.coalesce(t.c.derived_name,
-                                            sa.cast('', type_=conn.t.types.Composite))
-                        ).label('name'),
+                        t.c.name.merge(t.c.derived_name).label('name'),
                         sub.c.centroid, sub.c.bbox)\
                 .join(sub, t.c.country_code == sub.c.country_code)
 
                         sub.c.centroid, sub.c.bbox)\
                 .join(sub, t.c.country_code == sub.c.country_code)
 
@@ -569,10 +566,8 @@ class PostcodeSearch(AbstractSearch):
             assert self.lookups[0].lookup_type == 'restrict'
             tsearch = conn.t.search_name
             sql = sql.where(tsearch.c.place_id == t.c.parent_place_id)\
             assert self.lookups[0].lookup_type == 'restrict'
             tsearch = conn.t.search_name
             sql = sql.where(tsearch.c.place_id == t.c.parent_place_id)\
-                     .where(sa.func.array_cat(tsearch.c.name_vector,
-                                              tsearch.c.nameaddress_vector,
-                                              type_=ARRAY(sa.Integer))
-                                    .contains(self.lookups[0].tokens))
+                     .where((tsearch.c.name_vector + tsearch.c.nameaddress_vector)
+                                     .contains(self.lookups[0].tokens))
 
         for ranking in self.rankings:
             penalty += ranking.sql_penalty(conn.t.search_name)
 
         for ranking in self.rankings:
             penalty += ranking.sql_penalty(conn.t.search_name)
index fceec2df522feb5105936204b099e9a8a7a2ad96..eabd329d57e08cac6d8b6cbc8c0a8c3c0fcf8fdd 100644 (file)
@@ -22,6 +22,7 @@ from nominatim.api.connection import SearchConnection
 from nominatim.api.logging import log
 from nominatim.api.search import query as qmod
 from nominatim.api.search.query_analyzer_factory import AbstractQueryAnalyzer
 from nominatim.api.logging import log
 from nominatim.api.search import query as qmod
 from nominatim.api.search.query_analyzer_factory import AbstractQueryAnalyzer
+from nominatim.db.sqlalchemy_types import Json
 
 
 DB_TO_TOKEN_TYPE = {
 
 
 DB_TO_TOKEN_TYPE = {
@@ -159,7 +160,7 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
                      sa.Column('word_token', sa.Text, nullable=False),
                      sa.Column('type', sa.Text, nullable=False),
                      sa.Column('word', sa.Text),
                      sa.Column('word_token', sa.Text, nullable=False),
                      sa.Column('type', sa.Text, nullable=False),
                      sa.Column('word', sa.Text),
-                     sa.Column('info', self.conn.t.types.Json))
+                     sa.Column('info', Json))
 
 
     async def analyze_query(self, phrases: List[qmod.Phrase]) -> qmod.QueryStruct:
 
 
     async def analyze_query(self, phrases: List[qmod.Phrase]) -> qmod.QueryStruct:
index 7dd1e0ce0b046182b6224eab7b5ec16769719b96..0ec22b7e1fa322469a2ea75d38642c3b75f02aa8 100644 (file)
@@ -7,37 +7,10 @@
 """
 SQLAlchemy definitions for all tables used by the frontend.
 """
 """
 SQLAlchemy definitions for all tables used by the frontend.
 """
-from typing import Any
-
 import sqlalchemy as sa
 import sqlalchemy as sa
-from sqlalchemy.dialects.postgresql import HSTORE, ARRAY, JSONB, array
-from sqlalchemy.dialects.sqlite import JSON as sqlite_json
 
 import nominatim.db.sqlalchemy_functions #pylint: disable=unused-import
 
 import nominatim.db.sqlalchemy_functions #pylint: disable=unused-import
-from nominatim.db.sqlalchemy_types import Geometry
-
-class PostgresTypes:
-    """ Type definitions for complex types as used in Postgres variants.
-    """
-    Composite = HSTORE
-    Json = JSONB
-    IntArray = ARRAY(sa.Integer()) #pylint: disable=invalid-name
-    to_array = array
-
-
-class SqliteTypes:
-    """ Type definitions for complex types as used in Postgres variants.
-    """
-    Composite = sqlite_json
-    Json = sqlite_json
-    IntArray = sqlite_json
-
-    @staticmethod
-    def to_array(arr: Any) -> Any:
-        """ Sqlite has no special conversion for arrays.
-        """
-        return arr
-
+from nominatim.db.sqlalchemy_types import Geometry, KeyValueStore, IntArray
 
 #pylint: disable=too-many-instance-attributes
 class SearchTables:
 
 #pylint: disable=too-many-instance-attributes
 class SearchTables:
@@ -47,14 +20,7 @@ class SearchTables:
         Any data used for updates only will not be visible.
     """
 
         Any data used for updates only will not be visible.
     """
 
-    def __init__(self, meta: sa.MetaData, engine_name: str) -> None:
-        if engine_name == 'postgresql':
-            self.types: Any = PostgresTypes
-        elif engine_name == 'sqlite':
-            self.types = SqliteTypes
-        else:
-            raise ValueError("Only 'postgresql' and 'sqlite' engines are supported.")
-
+    def __init__(self, meta: sa.MetaData) -> None:
         self.meta = meta
 
         self.import_status = sa.Table('import_status', meta,
         self.meta = meta
 
         self.import_status = sa.Table('import_status', meta,
@@ -80,9 +46,9 @@ class SearchTables:
             sa.Column('class', sa.Text, nullable=False, key='class_'),
             sa.Column('type', sa.Text, nullable=False),
             sa.Column('admin_level', sa.SmallInteger),
             sa.Column('class', sa.Text, nullable=False, key='class_'),
             sa.Column('type', sa.Text, nullable=False),
             sa.Column('admin_level', sa.SmallInteger),
-            sa.Column('name', self.types.Composite),
-            sa.Column('address', self.types.Composite),
-            sa.Column('extratags', self.types.Composite),
+            sa.Column('name', KeyValueStore),
+            sa.Column('address', KeyValueStore),
+            sa.Column('extratags', KeyValueStore),
             sa.Column('geometry', Geometry, nullable=False),
             sa.Column('wikipedia', sa.Text),
             sa.Column('country_code', sa.String(2)),
             sa.Column('geometry', Geometry, nullable=False),
             sa.Column('wikipedia', sa.Text),
             sa.Column('country_code', sa.String(2)),
@@ -118,14 +84,14 @@ class SearchTables:
             sa.Column('step', sa.SmallInteger),
             sa.Column('indexed_status', sa.SmallInteger),
             sa.Column('linegeo', Geometry),
             sa.Column('step', sa.SmallInteger),
             sa.Column('indexed_status', sa.SmallInteger),
             sa.Column('linegeo', Geometry),
-            sa.Column('address', self.types.Composite),
+            sa.Column('address', KeyValueStore),
             sa.Column('postcode', sa.Text),
             sa.Column('country_code', sa.String(2)))
 
         self.country_name = sa.Table('country_name', meta,
             sa.Column('country_code', sa.String(2)),
             sa.Column('postcode', sa.Text),
             sa.Column('country_code', sa.String(2)))
 
         self.country_name = sa.Table('country_name', meta,
             sa.Column('country_code', sa.String(2)),
-            sa.Column('name', self.types.Composite),
-            sa.Column('derived_name', self.types.Composite),
+            sa.Column('name', KeyValueStore),
+            sa.Column('derived_name', KeyValueStore),
             sa.Column('partition', sa.Integer))
 
         self.country_grid = sa.Table('country_osm_grid', meta,
             sa.Column('partition', sa.Integer))
 
         self.country_grid = sa.Table('country_osm_grid', meta,
@@ -139,8 +105,8 @@ class SearchTables:
             sa.Column('importance', sa.Float),
             sa.Column('search_rank', sa.SmallInteger),
             sa.Column('address_rank', sa.SmallInteger),
             sa.Column('importance', sa.Float),
             sa.Column('search_rank', sa.SmallInteger),
             sa.Column('address_rank', sa.SmallInteger),
-            sa.Column('name_vector', self.types.IntArray),
-            sa.Column('nameaddress_vector', self.types.IntArray),
+            sa.Column('name_vector', IntArray),
+            sa.Column('nameaddress_vector', IntArray),
             sa.Column('country_code', sa.String(2)),
             sa.Column('centroid', Geometry))
 
             sa.Column('country_code', sa.String(2)),
             sa.Column('centroid', Geometry))
 
diff --git a/nominatim/db/sqlalchemy_types/__init__.py b/nominatim/db/sqlalchemy_types/__init__.py
new file mode 100644 (file)
index 0000000..dc41799
--- /dev/null
@@ -0,0 +1,17 @@
+# SPDX-License-Identifier: GPL-3.0-or-later
+#
+# This file is part of Nominatim. (https://nominatim.org)
+#
+# Copyright (C) 2023 by the Nominatim developer community.
+# For a full list of authors see the git log.
+"""
+Module with custom types for SQLAlchemy
+"""
+
+# See also https://github.com/PyCQA/pylint/issues/6006
+# pylint: disable=useless-import-alias
+
+from .geometry import (Geometry as Geometry)
+from .int_array import (IntArray as IntArray)
+from .key_value import (KeyValueStore as KeyValueStore)
+from .json import (Json as Json)
diff --git a/nominatim/db/sqlalchemy_types/int_array.py b/nominatim/db/sqlalchemy_types/int_array.py
new file mode 100644 (file)
index 0000000..335d554
--- /dev/null
@@ -0,0 +1,73 @@
+# SPDX-License-Identifier: GPL-3.0-or-later
+#
+# This file is part of Nominatim. (https://nominatim.org)
+#
+# Copyright (C) 2023 by the Nominatim developer community.
+# For a full list of authors see the git log.
+"""
+Custom type for an array of integers.
+"""
+from typing import Any, List, cast, Optional
+
+import sqlalchemy as sa
+from sqlalchemy.dialects.postgresql import ARRAY
+
+from nominatim.typing import SaDialect, SaColumn
+
+# pylint: disable=all
+
+class IntList(sa.types.TypeDecorator[Any]):
+    """ A list of integers saved as a text of comma-separated numbers.
+    """
+    impl = sa.types.Unicode
+    cache_ok = True
+
+    def process_bind_param(self, value: Optional[Any], dialect: 'sa.Dialect') -> Optional[str]:
+        if value is None:
+            return None
+
+        assert isinstance(value, list)
+        return ','.join(map(str, value))
+
+    def process_result_value(self, value: Optional[Any],
+                             dialect: SaDialect) -> Optional[List[int]]:
+        return [int(v) for v in value.split(',')] if value is not None else None
+
+    def copy(self, **kw: Any) -> 'IntList':
+        return IntList(self.impl.length)
+
+
+class IntArray(sa.types.TypeDecorator[Any]):
+    """ Dialect-independent list of integers.
+    """
+    impl = IntList
+    cache_ok = True
+
+    def load_dialect_impl(self, dialect: SaDialect) -> sa.types.TypeEngine[Any]:
+        if dialect.name == 'postgresql':
+            return ARRAY(sa.Integer()) #pylint: disable=invalid-name
+
+        return IntList()
+
+
+    class comparator_factory(sa.types.UserDefinedType.Comparator): # type: ignore[type-arg]
+
+        def __add__(self, other: SaColumn) -> 'sa.ColumnOperators':
+            """ Concatenate the array with the given array. If one of the
+                operands is null, the value of the other will be returned.
+            """
+            return sa.func.array_cat(self, other, type_=IntArray)
+
+
+        def contains(self, other: SaColumn, **kwargs: Any) -> 'sa.ColumnOperators':
+            """ Return true if the array contains all the values of the argument
+                array.
+            """
+            return cast('sa.ColumnOperators', self.op('@>', is_comparison=True)(other))
+
+
+        def overlaps(self, other: SaColumn) -> 'sa.Operators':
+            """ Return true if at least one value of the argument is contained
+                in the array.
+            """
+            return self.op('&&', is_comparison=True)(other)
diff --git a/nominatim/db/sqlalchemy_types/json.py b/nominatim/db/sqlalchemy_types/json.py
new file mode 100644 (file)
index 0000000..31635fd
--- /dev/null
@@ -0,0 +1,30 @@
+# SPDX-License-Identifier: GPL-3.0-or-later
+#
+# This file is part of Nominatim. (https://nominatim.org)
+#
+# Copyright (C) 2023 by the Nominatim developer community.
+# For a full list of authors see the git log.
+"""
+Common json type for different dialects.
+"""
+from typing import Any
+
+import sqlalchemy as sa
+from sqlalchemy.dialects.postgresql import JSONB
+from sqlalchemy.dialects.sqlite import JSON as sqlite_json
+
+from nominatim.typing import SaDialect
+
+# pylint: disable=all
+
+class Json(sa.types.TypeDecorator[Any]):
+    """ Dialect-independent type for JSON.
+    """
+    impl = sa.types.JSON
+    cache_ok = True
+
+    def load_dialect_impl(self, dialect: SaDialect) -> sa.types.TypeEngine[Any]:
+        if dialect.name == 'postgresql':
+            return JSONB(none_as_null=True) # type: ignore[no-untyped-call]
+
+        return sqlite_json(none_as_null=True)
diff --git a/nominatim/db/sqlalchemy_types/key_value.py b/nominatim/db/sqlalchemy_types/key_value.py
new file mode 100644 (file)
index 0000000..4f2d824
--- /dev/null
@@ -0,0 +1,47 @@
+# SPDX-License-Identifier: GPL-3.0-or-later
+#
+# This file is part of Nominatim. (https://nominatim.org)
+#
+# Copyright (C) 2023 by the Nominatim developer community.
+# For a full list of authors see the git log.
+"""
+A custom type that implements a simple key-value store of strings.
+"""
+from typing import Any
+
+import sqlalchemy as sa
+from sqlalchemy.dialects.postgresql import HSTORE
+from sqlalchemy.dialects.sqlite import JSON as sqlite_json
+
+from nominatim.typing import SaDialect, SaColumn
+
+# pylint: disable=all
+
+class KeyValueStore(sa.types.TypeDecorator[Any]):
+    """ Dialect-independent type of a simple key-value store of strings.
+    """
+    impl = HSTORE
+    cache_ok = True
+
+    def load_dialect_impl(self, dialect: SaDialect) -> sa.types.TypeEngine[Any]:
+        if dialect.name == 'postgresql':
+            return HSTORE() # type: ignore[no-untyped-call]
+
+        return sqlite_json(none_as_null=True)
+
+
+    class comparator_factory(sa.types.UserDefinedType.Comparator): # type: ignore[type-arg]
+
+        def merge(self, other: SaColumn) -> 'sa.Operators':
+            """ Merge the values from the given KeyValueStore into this
+                one, overwriting values where necessary. When the argument
+                is null, nothing happens.
+            """
+            return self.op('||')(sa.func.coalesce(other,
+                                                  sa.type_coerce('', KeyValueStore)))
+
+
+        def has_key(self, key: SaColumn) -> 'sa.Operators':
+            """ Return true if the key is contained in the store.
+            """
+            return self.op('?', is_comparison=True)(key)
index 7274f1d396f8159b714c80fff14fd25b3455b345..62ecd8c3e169ce7340dca7c6eb6a83a7881cd3d5 100644 (file)
@@ -72,3 +72,4 @@ SaLabel: TypeAlias = 'sa.Label[Any]'
 SaFromClause: TypeAlias = 'sa.FromClause'
 SaSelectable: TypeAlias = 'sa.Selectable'
 SaBind: TypeAlias = 'sa.BindParameter[Any]'
 SaFromClause: TypeAlias = 'sa.FromClause'
 SaSelectable: TypeAlias = 'sa.Selectable'
 SaBind: TypeAlias = 'sa.BindParameter[Any]'
+SaDialect: TypeAlias = 'sa.Dialect'