"""
Custom types for SQLAlchemy.
"""
-from typing import Callable, Any
+from typing import Callable, Any, cast
+import sys
import sqlalchemy as sa
-import sqlalchemy.types as types
+from sqlalchemy import types
-from nominatim.typing import SaColumn
+from nominatim.typing import SaColumn, SaBind
# Change: the base class drops the generic subscript UserDefinedType[Any] in
# favour of a bare UserDefinedType plus a mypy "type-arg" ignore.
# NOTE(review): presumably so the class statement also evaluates on SQLAlchemy
# versions where UserDefinedType is not subscriptable at runtime -- confirm.
# NOTE(review): '#pylint: disable=all' silences every pylint check from that
# point on in the module; consider narrowing it to the specific messages.
-class Geometry(types.UserDefinedType[Any]):
+#pylint: disable=all
+
+class Geometry(types.UserDefinedType): # type: ignore[type-arg]
""" Simplified type decorator for PostGIS geometry. This type
only supports geometries in 4326 projection.
"""
# NOTE(review): the method owning the return below (presumably
# UserDefinedType.get_col_spec) is not visible in this fragment. The SRID 4326
# is hard-coded, matching the docstring's 4326-only contract.
return f'GEOMETRY({self.subtype}, 4326)'
# Bind processor: converts a Python value into the text sent as the bound
# parameter.
# Change: the dialect annotation is now a string ('sa.Dialect'), so it is not
# evaluated at definition time. The old body asserted a str; the new body
# passes str values through and converts any other value via its to_wkt()
# method.
# NOTE(review): None (SQL NULL) is not a str, so it would reach
# value.to_wkt() and raise AttributeError -- confirm NULLs are never bound
# through this type, or add an explicit None check.
- def bind_processor(self, dialect: sa.Dialect) -> Callable[[Any], str]:
+ def bind_processor(self, dialect: 'sa.Dialect') -> Callable[[Any], str]:
def process(value: Any) -> str:
- assert isinstance(value, str)
- return value
+ if isinstance(value, str):
+ return value
+
+ return cast(str, value.to_wkt())
return process
# Result processor: returns values read from the database unchanged, after
# asserting they are str. Change: only the dialect annotation is now quoted.
# NOTE(review): a NULL geometry column presumably arrives as None and would
# fail the assert. Asserts are also stripped under `python -O`, so this is not
# a reliable validation -- confirm the intended NULL handling.
- def result_processor(self, dialect: sa.Dialect, coltype: object) -> Callable[[Any], str]:
+ def result_processor(self, dialect: 'sa.Dialect', coltype: object) -> Callable[[Any], str]:
def process(value: Any) -> str:
assert isinstance(value, str)
return value
return process
# Bind expression: wraps every bound parameter of this type in
# ST_GeomFromText on the SQL side.
# Change: the SRID 4326 is now passed explicitly as the second argument,
# matching the class's 4326-only contract. The parameter annotation uses the
# project alias SaBind instead of sa.BindParameter[Any].
- def bind_expression(self, bindvalue: sa.BindParameter[Any]) -> SaColumn:
- return sa.func.ST_GeomFromText(bindvalue, type_=self)
+ def bind_expression(self, bindvalue: SaBind) -> SaColumn:
+ return sa.func.ST_GeomFromText(bindvalue, sa.text('4326'), type_=self)
# SQLAlchemy comparator_factory hook: methods defined on this class are
# exposed on column expressions of the Geometry type.
# Change: a mypy "type-arg" ignore is added, plus a new intersects() method.
# NOTE(review): '&&' is the PostGIS bounding-box overlap operator, so
# intersects() compares bounding boxes only, not the exact geometries.
# NOTE(review): the is_line_like() body below is cut off in this fragment.
- class comparator_factory(types.UserDefinedType.Comparator):
+ class comparator_factory(types.UserDefinedType.Comparator): # type: ignore[type-arg]
+
+ def intersects(self, other: SaColumn) -> 'sa.Operators':
+ return self.op('&&')(other)
def is_line_like(self) -> SaColumn:
return sa.func.ST_GeometryType(self, type_=sa.String).in_(('ST_LineString',
# Change: ST_DWithin is a predicate, so its SQLAlchemy result type is
# corrected from sa.Float to sa.Boolean.
# New: ST_DWithin_no_index / ST_Intersects_no_index call the
# underscore-prefixed PostGIS functions _ST_DWithin / _ST_Intersects with a
# boolean result type.
# NOTE(review): judging by the method names, these are meant to bypass the
# spatial index. Underscore-prefixed functions are PostGIS internals -- verify
# they exist in every supported PostGIS version.
def ST_DWithin(self, other: SaColumn, distance: SaColumn) -> SaColumn:
- return sa.func.ST_DWithin(self, other, distance, type_=sa.Float)
+ return sa.func.ST_DWithin(self, other, distance, type_=sa.Boolean)
+
+
+ def ST_DWithin_no_index(self, other: SaColumn, distance: SaColumn) -> SaColumn:
+ return sa.func._ST_DWithin(self, other, distance, type_=sa.Boolean)
+
+
+ def ST_Intersects_no_index(self, other: SaColumn) -> SaColumn:
+ return sa.func._ST_Intersects(self, other, type_=sa.Boolean)
# NOTE(review): the body of ST_Distance is not visible in this fragment.
def ST_Distance(self, other: SaColumn) -> SaColumn:
# Change: ST_Contains is a predicate, so its result type is corrected from
# sa.Float to sa.Boolean. ST_CoveredBy is added with the same boolean type.
def ST_Contains(self, other: SaColumn) -> SaColumn:
- return sa.func.ST_Contains(self, other, type_=sa.Float)
+ return sa.func.ST_Contains(self, other, type_=sa.Boolean)
+
+
+ def ST_CoveredBy(self, other: SaColumn) -> SaColumn:
+ return sa.func.ST_CoveredBy(self, other, type_=sa.Boolean)
# NOTE(review): the body of ST_ClosestPoint and the definition line of the
# method returning ST_Expand are missing from this fragment. Only the
# ST_Expand return statement is visible.
def ST_ClosestPoint(self, other: SaColumn) -> SaColumn:
return sa.func.ST_Expand(self, other, type_=Geometry)
# New: ST_Collect takes the geometry itself as its single argument, which is
# the PostGIS aggregate form (it collects the geometries across grouped rows).
+ def ST_Collect(self) -> SaColumn:
+ return sa.func.ST_Collect(self, type_=Geometry)
+
+
def ST_Centroid(self) -> SaColumn:
return sa.func.ST_Centroid(self, type_=Geometry)