git.openstreetmap.org Git - nominatim.git/blobdiff - nominatim/db/sqlalchemy_types.py
Merge pull request #3128 from lonvia/fix-classtype-lookup
[nominatim.git] / nominatim / db / sqlalchemy_types.py
index c54d339e6d903b202ec1ad7e551ee881681c8856..7d3789aa0d3b44854583aeb09686d8edaa1c421e 100644 (file)
@@ -7,14 +7,17 @@
 """
 Custom types for SQLAlchemy.
 """
-from typing import Callable, Any
+from typing import Callable, Any, cast
+import sys
 
 import sqlalchemy as sa
-import sqlalchemy.types as types
+from sqlalchemy import types
 
-from nominatim.typing import SaColumn
+from nominatim.typing import SaColumn, SaBind
 
-class Geometry(types.UserDefinedType[Any]):
+#pylint: disable=all
+
+class Geometry(types.UserDefinedType): # type: ignore[type-arg]
     """ Simplified type decorator for PostGIS geometry. This type
         only supports geometries in 4326 projection.
     """
@@ -31,9 +34,9 @@ class Geometry(types.UserDefinedType[Any]):
     def bind_processor(self, dialect: 'sa.Dialect') -> Callable[[Any], str]:
         def process(value: Any) -> str:
             if isinstance(value, str):
-                return 'SRID=4326;' + value
+                return value
 
-            return 'SRID=4326;' + value.to_wkt()
+            return cast(str, value.to_wkt())
         return process
 
 
@@ -44,13 +47,13 @@ class Geometry(types.UserDefinedType[Any]):
         return process
 
 
-    def bind_expression(self, bindvalue: 'sa.BindParameter[Any]') -> SaColumn:
-        return sa.func.ST_GeomFromText(bindvalue, type_=self)
+    def bind_expression(self, bindvalue: SaBind) -> SaColumn:
+        return sa.func.ST_GeomFromText(bindvalue, sa.text('4326'), type_=self)
 
 
-    class comparator_factory(types.UserDefinedType.Comparator):
+    class comparator_factory(types.UserDefinedType.Comparator): # type: ignore[type-arg]
 
-        def intersects(self, other: SaColumn) -> SaColumn:
+        def intersects(self, other: SaColumn) -> 'sa.Operators':
             return self.op('&&')(other)
 
         def is_line_like(self) -> SaColumn:
@@ -71,7 +74,11 @@ class Geometry(types.UserDefinedType[Any]):
 
 
         def ST_Contains(self, other: SaColumn) -> SaColumn:
-            return sa.func.ST_Contains(self, other, type_=sa.Float)
+            return sa.func.ST_Contains(self, other, type_=sa.Boolean)
+
+
+        def ST_CoveredBy(self, other: SaColumn) -> SaColumn:
+            return sa.func.ST_CoveredBy(self, other, type_=sa.Boolean)
 
 
         def ST_ClosestPoint(self, other: SaColumn) -> SaColumn: