X-Git-Url: https://git.openstreetmap.org./nominatim.git/blobdiff_plain/42631b85c7445b3acae9d907330fd2544f3366a1..c29ffc38e6cef4bb99fd40060be8243ea70e5939:/nominatim/db/sqlalchemy_types.py diff --git a/nominatim/db/sqlalchemy_types.py b/nominatim/db/sqlalchemy_types.py index c54d339e..7d3789aa 100644 --- a/nominatim/db/sqlalchemy_types.py +++ b/nominatim/db/sqlalchemy_types.py @@ -7,14 +7,17 @@ """ Custom types for SQLAlchemy. """ -from typing import Callable, Any +from typing import Callable, Any, cast +import sys import sqlalchemy as sa -import sqlalchemy.types as types +from sqlalchemy import types -from nominatim.typing import SaColumn +from nominatim.typing import SaColumn, SaBind -class Geometry(types.UserDefinedType[Any]): +#pylint: disable=all + +class Geometry(types.UserDefinedType): # type: ignore[type-arg] """ Simplified type decorator for PostGIS geometry. This type only supports geometries in 4326 projection. """ @@ -31,9 +34,9 @@ class Geometry(types.UserDefinedType[Any]): def bind_processor(self, dialect: 'sa.Dialect') -> Callable[[Any], str]: def process(value: Any) -> str: if isinstance(value, str): - return 'SRID=4326;' + value + return value - return 'SRID=4326;' + value.to_wkt() + return cast(str, value.to_wkt()) return process @@ -44,13 +47,13 @@ class Geometry(types.UserDefinedType[Any]): return process - def bind_expression(self, bindvalue: 'sa.BindParameter[Any]') -> SaColumn: - return sa.func.ST_GeomFromText(bindvalue, type_=self) + def bind_expression(self, bindvalue: SaBind) -> SaColumn: + return sa.func.ST_GeomFromText(bindvalue, sa.text('4326'), type_=self) - class comparator_factory(types.UserDefinedType.Comparator): + class comparator_factory(types.UserDefinedType.Comparator): # type: ignore[type-arg] - def intersects(self, other: SaColumn) -> SaColumn: + def intersects(self, other: SaColumn) -> 'sa.Operators': return self.op('&&')(other) def is_line_like(self) -> SaColumn: @@ -71,7 +74,11 @@ class Geometry(types.UserDefinedType[Any]): def ST_Contains(self, other: SaColumn) -> SaColumn: - return sa.func.ST_Contains(self, other, type_=sa.Float) + return sa.func.ST_Contains(self, other, type_=sa.Boolean) + + + def ST_CoveredBy(self, other: SaColumn) -> SaColumn: + return sa.func.ST_CoveredBy(self, other, type_=sa.Boolean) def ST_ClosestPoint(self, other: SaColumn) -> SaColumn: