X-Git-Url: https://git.openstreetmap.org./nominatim.git/blobdiff_plain/4bb4db0668a37979725678a1690f6163d5cae03f..c29ffc38e6cef4bb99fd40060be8243ea70e5939:/nominatim/db/sqlalchemy_types.py diff --git a/nominatim/db/sqlalchemy_types.py b/nominatim/db/sqlalchemy_types.py index 88cae29f..7d3789aa 100644 --- a/nominatim/db/sqlalchemy_types.py +++ b/nominatim/db/sqlalchemy_types.py @@ -7,14 +7,17 @@ """ Custom types for SQLAlchemy. """ -from typing import Callable, Any +from typing import Callable, Any, cast +import sys import sqlalchemy as sa -import sqlalchemy.types as types +from sqlalchemy import types -from nominatim.typing import SaColumn +from nominatim.typing import SaColumn, SaBind -class Geometry(types.UserDefinedType[Any]): +#pylint: disable=all + +class Geometry(types.UserDefinedType): # type: ignore[type-arg] """ Simplified type decorator for PostGIS geometry. This type only supports geometries in 4326 projection. """ @@ -28,25 +31,30 @@ class Geometry(types.UserDefinedType[Any]): return f'GEOMETRY({self.subtype}, 4326)' - def bind_processor(self, dialect: sa.Dialect) -> Callable[[Any], str]: + def bind_processor(self, dialect: 'sa.Dialect') -> Callable[[Any], str]: def process(value: Any) -> str: - assert isinstance(value, str) - return value + if isinstance(value, str): + return value + + return cast(str, value.to_wkt()) return process - def result_processor(self, dialect: sa.Dialect, coltype: object) -> Callable[[Any], str]: + def result_processor(self, dialect: 'sa.Dialect', coltype: object) -> Callable[[Any], str]: def process(value: Any) -> str: assert isinstance(value, str) return value return process - def bind_expression(self, bindvalue: sa.BindParameter[Any]) -> SaColumn: - return sa.func.ST_GeomFromText(bindvalue, type_=self) + def bind_expression(self, bindvalue: SaBind) -> SaColumn: + return sa.func.ST_GeomFromText(bindvalue, sa.text('4326'), type_=self) + + class comparator_factory(types.UserDefinedType.Comparator): # type: ignore[type-arg] - class comparator_factory(types.UserDefinedType.Comparator): + def intersects(self, other: SaColumn) -> 'sa.Operators': + return self.op('&&')(other) def is_line_like(self) -> SaColumn: return sa.func.ST_GeometryType(self, type_=sa.String).in_(('ST_LineString', @@ -66,7 +74,11 @@ class Geometry(types.UserDefinedType[Any]): def ST_Contains(self, other: SaColumn) -> SaColumn: - return sa.func.ST_Contains(self, other, type_=sa.Float) + return sa.func.ST_Contains(self, other, type_=sa.Boolean) + + + def ST_CoveredBy(self, other: SaColumn) -> SaColumn: + return sa.func.ST_CoveredBy(self, other, type_=sa.Boolean) def ST_ClosestPoint(self, other: SaColumn) -> SaColumn: @@ -81,6 +93,10 @@ class Geometry(types.UserDefinedType[Any]): return sa.func.ST_Expand(self, other, type_=Geometry) + def ST_Collect(self) -> SaColumn: + return sa.func.ST_Collect(self, type_=Geometry) + + def ST_Centroid(self) -> SaColumn: return sa.func.ST_Centroid(self, type_=Geometry)