git.openstreetmap.org Git - nominatim.git/blobdiff - nominatim/db/sqlalchemy_types.py
make lookup call work with sqlite
[nominatim.git] / nominatim / db / sqlalchemy_types.py
index 5131dad3fd5bb7e4b09621187a75e094c0417908..9d1e48fae31e3c1763a28abe43ce13a37e1dd03c 100644 (file)
@@ -7,14 +7,43 @@
 """
 Custom types for SQLAlchemy.
 """
-from typing import Callable, Any
+from typing import Callable, Any, cast
+import sys
 
 import sqlalchemy as sa
-import sqlalchemy.types as types
+from sqlalchemy.ext.compiler import compiles
+from sqlalchemy import types
 
-from nominatim.typing import SaColumn
+from nominatim.typing import SaColumn, SaBind
 
-class Geometry(types.UserDefinedType[Any]):
+#pylint: disable=all
+
+SQLITE_FUNCTION_ALIAS = (
+    ('ST_AsEWKB', sa.Text, 'AsEWKB'),
+    ('ST_AsGeoJSON', sa.Text, 'AsGeoJSON'),
+    ('ST_AsKML', sa.Text, 'AsKML'),
+    ('ST_AsSVG', sa.Text, 'AsSVG'),
+)
+
+def _add_function_alias(func: str, ftype: type, alias: str) -> None:
+    _FuncDef = type(func, (sa.sql.functions.GenericFunction, ), {
+        "type": ftype,
+        "name": func,
+        "identifier": func,
+        "inherit_cache": True})
+
+    func_templ = f"{alias}(%s)"
+
+    def _sqlite_impl(element: Any, compiler: Any, **kw: Any) -> Any:
+        return func_templ % compiler.process(element.clauses, **kw)
+
+    compiles(_FuncDef, 'sqlite')(_sqlite_impl) # type: ignore[no-untyped-call]
+
+for alias in SQLITE_FUNCTION_ALIAS:
+    _add_function_alias(*alias)
+
+
+class Geometry(types.UserDefinedType): # type: ignore[type-arg]
     """ Simplified type decorator for PostGIS geometry. This type
         only supports geometries in 4326 projection.
     """
@@ -28,27 +57,33 @@ class Geometry(types.UserDefinedType[Any]):
         return f'GEOMETRY({self.subtype}, 4326)'
 
 
-    def bind_processor(self, dialect: sa.Dialect) -> Callable[[Any], str]:
+    def bind_processor(self, dialect: 'sa.Dialect') -> Callable[[Any], str]:
         def process(value: Any) -> str:
-            assert isinstance(value, str)
-            return value
+            if isinstance(value, str):
+                return value
+
+            return cast(str, value.to_wkt())
         return process
 
 
-    def result_processor(self, dialect: sa.Dialect, coltype: object) -> Callable[[Any], str]:
+    def result_processor(self, dialect: 'sa.Dialect', coltype: object) -> Callable[[Any], str]:
         def process(value: Any) -> str:
             assert isinstance(value, str)
             return value
         return process
 
 
-    def bind_expression(self, bindvalue: sa.BindParameter[Any]) -> SaColumn:
-        return sa.func.ST_GeomFromText(bindvalue, type_=self)
+    def column_expression(self, col: SaColumn) -> SaColumn:
+        return sa.func.ST_AsEWKB(col)
+
 
+    def bind_expression(self, bindvalue: SaBind) -> SaColumn:
+        return sa.func.ST_GeomFromText(bindvalue, sa.text('4326'), type_=self)
 
-    class comparator_factory(types.UserDefinedType.Comparator):
 
-        def intersects(self, other: SaColumn) -> SaColumn:
+    class comparator_factory(types.UserDefinedType.Comparator): # type: ignore[type-arg]
+
+        def intersects(self, other: SaColumn) -> 'sa.Operators':
             return self.op('&&')(other)
 
         def is_line_like(self) -> SaColumn:
@@ -61,7 +96,16 @@ class Geometry(types.UserDefinedType[Any]):
 
 
         def ST_DWithin(self, other: SaColumn, distance: SaColumn) -> SaColumn:
-            return sa.func.ST_DWithin(self, other, distance, type_=sa.Float)
+            return sa.func.ST_DWithin(self, other, distance, type_=sa.Boolean)
+
+
+        def ST_DWithin_no_index(self, other: SaColumn, distance: SaColumn) -> SaColumn:
+            return sa.func.ST_DWithin(sa.func.coalesce(sa.null(), self),
+                                      other, distance, type_=sa.Boolean)
+
+
+        def ST_Intersects_no_index(self, other: SaColumn) -> 'sa.Operators':
+            return sa.func.coalesce(sa.null(), self).op('&&')(other)
 
 
         def ST_Distance(self, other: SaColumn) -> SaColumn:
@@ -69,7 +113,11 @@ class Geometry(types.UserDefinedType[Any]):
 
 
         def ST_Contains(self, other: SaColumn) -> SaColumn:
-            return sa.func.ST_Contains(self, other, type_=sa.Float)
+            return sa.func.ST_Contains(self, other, type_=sa.Boolean)
+
+
+        def ST_CoveredBy(self, other: SaColumn) -> SaColumn:
+            return sa.func.ST_CoveredBy(self, other, type_=sa.Boolean)
 
 
         def ST_ClosestPoint(self, other: SaColumn) -> SaColumn:
@@ -84,6 +132,10 @@ class Geometry(types.UserDefinedType[Any]):
             return sa.func.ST_Expand(self, other, type_=Geometry)
 
 
+        def ST_Collect(self) -> SaColumn:
+            return sa.func.ST_Collect(self, type_=Geometry)
+
+
         def ST_Centroid(self) -> SaColumn:
             return sa.func.ST_Centroid(self, type_=Geometry)
 
@@ -94,3 +146,8 @@ class Geometry(types.UserDefinedType[Any]):
 
         def ST_LineLocatePoint(self, other: SaColumn) -> SaColumn:
             return sa.func.ST_LineLocatePoint(self, other, type_=sa.Float)
+
+
+@compiles(Geometry, 'sqlite') # type: ignore[no-untyped-call]
+def get_col_spec(self, *args, **kwargs): # type: ignore[no-untyped-def]
+    return 'GEOMETRY'