X-Git-Url: https://git.openstreetmap.org./nominatim.git/blobdiff_plain/42631b85c7445b3acae9d907330fd2544f3366a1..07e6c5cf6923b99df15d89e587dbdf9109d0836e:/nominatim/db/sqlalchemy_types.py diff --git a/nominatim/db/sqlalchemy_types.py b/nominatim/db/sqlalchemy_types.py index c54d339e..8e8cc9c8 100644 --- a/nominatim/db/sqlalchemy_types.py +++ b/nominatim/db/sqlalchemy_types.py @@ -7,14 +7,40 @@ """ Custom types for SQLAlchemy. """ -from typing import Callable, Any +from typing import Callable, Any, cast +import sys import sqlalchemy as sa -import sqlalchemy.types as types +from sqlalchemy.ext.compiler import compiles +from sqlalchemy import types -from nominatim.typing import SaColumn +from nominatim.typing import SaColumn, SaBind -class Geometry(types.UserDefinedType[Any]): +#pylint: disable=all + +class Geometry_DistanceSpheroid(sa.sql.expression.FunctionElement[float]): + """ Function to compute the spherical distance in meters. + """ + type = sa.Float() + name = 'Geometry_DistanceSpheroid' + inherit_cache = True + + +@compiles(Geometry_DistanceSpheroid) # type: ignore[no-untyped-call, misc] +def _default_distance_spheroid(element: SaColumn, + compiler: 'sa.Compiled', **kw: Any) -> str: + return "ST_DistanceSpheroid(%s,"\ + " 'SPHEROID[\"WGS 84\",6378137,298.257223563, AUTHORITY[\"EPSG\",\"7030\"]]')"\ + % compiler.process(element.clauses, **kw) + + +@compiles(Geometry_DistanceSpheroid, 'sqlite') # type: ignore[no-untyped-call, misc] +def _spatialite_distance_spheroid(element: SaColumn, + compiler: 'sa.Compiled', **kw: Any) -> str: + return "Distance(%s, true)" % compiler.process(element.clauses, **kw) + + +class Geometry(types.UserDefinedType): # type: ignore[type-arg] """ Simplified type decorator for PostGIS geometry. This type only supports geometries in 4326 projection. """ @@ -31,9 +57,9 @@ class Geometry(types.UserDefinedType[Any]): def bind_processor(self, dialect: 'sa.Dialect') -> Callable[[Any], str]: def process(value: Any) -> str: if isinstance(value, str): - return 'SRID=4326;' + value + return value - return 'SRID=4326;' + value.to_wkt() + return cast(str, value.to_wkt()) return process @@ -44,13 +70,17 @@ class Geometry(types.UserDefinedType[Any]): return process - def bind_expression(self, bindvalue: 'sa.BindParameter[Any]') -> SaColumn: - return sa.func.ST_GeomFromText(bindvalue, type_=self) + def column_expression(self, col: SaColumn) -> SaColumn: + return sa.func.ST_AsEWKB(col) - class comparator_factory(types.UserDefinedType.Comparator): + def bind_expression(self, bindvalue: SaBind) -> SaColumn: + return sa.func.ST_GeomFromText(bindvalue, sa.text('4326'), type_=self) - def intersects(self, other: SaColumn) -> SaColumn: + + class comparator_factory(types.UserDefinedType.Comparator): # type: ignore[type-arg] + + def intersects(self, other: SaColumn) -> 'sa.Operators': return self.op('&&')(other) def is_line_like(self) -> SaColumn: @@ -63,7 +93,16 @@ class Geometry(types.UserDefinedType[Any]): def ST_DWithin(self, other: SaColumn, distance: SaColumn) -> SaColumn: - return sa.func.ST_DWithin(self, other, distance, type_=sa.Float) + return sa.func.ST_DWithin(self, other, distance, type_=sa.Boolean) + + + def ST_DWithin_no_index(self, other: SaColumn, distance: SaColumn) -> SaColumn: + return sa.func.ST_DWithin(sa.func.coalesce(sa.null(), self), + other, distance, type_=sa.Boolean) + + + def ST_Intersects_no_index(self, other: SaColumn) -> 'sa.Operators': + return sa.func.coalesce(sa.null(), self).op('&&')(other) def ST_Distance(self, other: SaColumn) -> SaColumn: @@ -71,7 +110,11 @@ class Geometry(types.UserDefinedType[Any]): def ST_Contains(self, other: SaColumn) -> SaColumn: - return sa.func.ST_Contains(self, other, type_=sa.Float) + return sa.func.ST_Contains(self, other, type_=sa.Boolean) + + + def ST_CoveredBy(self, other: SaColumn) -> SaColumn: + return sa.func.ST_CoveredBy(self, other, type_=sa.Boolean) def ST_ClosestPoint(self, other: SaColumn) -> SaColumn: @@ -100,3 +143,41 @@ class Geometry(types.UserDefinedType[Any]): def ST_LineLocatePoint(self, other: SaColumn) -> SaColumn: return sa.func.ST_LineLocatePoint(self, other, type_=sa.Float) + + + def distance_spheroid(self, other: SaColumn) -> SaColumn: + return Geometry_DistanceSpheroid(self, other) + + +@compiles(Geometry, 'sqlite') # type: ignore[no-untyped-call] +def get_col_spec(self, *args, **kwargs): # type: ignore[no-untyped-def] + return 'GEOMETRY' + + +SQLITE_FUNCTION_ALIAS = ( + ('ST_AsEWKB', sa.Text, 'AsEWKB'), + ('ST_GeomFromEWKT', Geometry, 'GeomFromEWKT'), + ('ST_AsGeoJSON', sa.Text, 'AsGeoJSON'), + ('ST_AsKML', sa.Text, 'AsKML'), + ('ST_AsSVG', sa.Text, 'AsSVG'), +) + +def _add_function_alias(func: str, ftype: type, alias: str) -> None: + _FuncDef = type(func, (sa.sql.functions.GenericFunction, ), { + "type": ftype, + "name": func, + "identifier": func, + "inherit_cache": True}) + + func_templ = f"{alias}(%s)" + + def _sqlite_impl(element: Any, compiler: Any, **kw: Any) -> Any: + return func_templ % compiler.process(element.clauses, **kw) + + compiles(_FuncDef, 'sqlite')(_sqlite_impl) # type: ignore[no-untyped-call] + +for alias in SQLITE_FUNCTION_ALIAS: + _add_function_alias(*alias) + + +