# SPDX-License-Identifier: GPL-3.0-or-later
# This file is part of Nominatim. (https://nominatim.org)
# Copyright (C) 2023 by the Nominatim developer community.
# For a full list of authors see the git log.
"""
Custom types for SQLAlchemy.
"""
10 from typing import Callable, Any, cast
13 import sqlalchemy as sa
14 from sqlalchemy.ext.compiler import compiles
15 from sqlalchemy import types
17 from nominatim.typing import SaColumn, SaBind
class Geometry_DistanceSpheroid(sa.sql.expression.FunctionElement[float]):
    """ Function to compute the spherical distance in meters.

        The SQL emitted for this construct depends on the dialect and is
        produced by the `@compiles` handlers registered for this class.
    """
    # SQL result type of the expression, matching the `float` type parameter.
    type = sa.Float()
    name = 'Geometry_DistanceSpheroid'
    # The construct carries no state besides its arguments, so statements
    # using it may be stored in SQLAlchemy's compiled-statement cache.
    inherit_cache = True
@compiles(Geometry_DistanceSpheroid) # type: ignore[no-untyped-call, misc]
def _default_distance_spheroid(element: SaColumn,
                               compiler: 'sa.Compiled', **kw: Any) -> str:
    """ Default rendering of Geometry_DistanceSpheroid.

        Emits ST_DistanceSpheroid() over the compiled function arguments,
        followed by an explicit WGS 84 spheroid definition.
    """
    wgs84 = 'SPHEROID["WGS 84",6378137,298.257223563, AUTHORITY["EPSG","7030"]]'
    compiled_args = compiler.process(element.clauses, **kw)
    return f"ST_DistanceSpheroid({compiled_args}, '{wgs84}')"
@compiles(Geometry_DistanceSpheroid, 'sqlite') # type: ignore[no-untyped-call, misc]
def _spatialite_distance_spheroid(element: SaColumn,
                                  compiler: 'sa.Compiled', **kw: Any) -> str:
    """ SQLite rendering of Geometry_DistanceSpheroid.

        Emits Distance() over the compiled function arguments with an
        additional `true` flag. The flag presumably selects SpatiaLite's
        ellipsoid-based distance; confirm against the SpatiaLite reference.
    """
    compiled_args = compiler.process(element.clauses, **kw)
    return f"Distance({compiled_args}, true)"
class Geometry(types.UserDefinedType): # type: ignore[type-arg]
    """ Simplified type decorator for PostGIS geometry. This type
        only supports geometries in 4326 projection.
    """
    # The type's only state is `subtype`, a plain string that is also the
    # name of the __init__ parameter, so SQLAlchemy can safely derive a cache
    # key from it and cache statements using this type.
    cache_ok = True

    def __init__(self, subtype: str = 'Geometry'):
        # Geometry subtype used in the PostgreSQL column spec (e.g. 'Geometry').
        self.subtype = subtype


    def get_col_spec(self) -> str:
        """ Return the column type DDL: the configured subtype with SRID 4326.
        """
        return f'GEOMETRY({self.subtype}, 4326)'


    def bind_processor(self, dialect: 'sa.Dialect') -> Callable[[Any], str]:
        """ Return the converter applied to outgoing parameter values.

            Strings are passed through unchanged. Any other object is
            converted by calling its `to_wkt()` method. The result is handed
            to ST_GeomFromText() by bind_expression().
        """
        def process(value: Any) -> str:
            if isinstance(value, str):
                return value

            return cast(str, value.to_wkt())
        return process


    def result_processor(self, dialect: 'sa.Dialect', coltype: object) -> Callable[[Any], str]:
        """ Return the converter applied to incoming result values.

            Values are expected to arrive as strings and are returned
            unchanged.
        """
        def process(value: Any) -> str:
            # NOTE: this is an assert, so the check disappears under `python -O`.
            # NOTE(review): column_expression() selects ST_AsEWKB(), which is
            # binary in PostGIS; this relies on the driver/dialect handing it
            # over as a string -- confirm per backend.
            assert isinstance(value, str)
            return value
        return process


    def column_expression(self, col: SaColumn) -> SaColumn:
        """ Wrap selected geometry columns in ST_AsEWKB().
        """
        return sa.func.ST_AsEWKB(col)


    def bind_expression(self, bindvalue: SaBind) -> SaColumn:
        """ Wrap bound parameters in ST_GeomFromText(<value>, 4326).
        """
        return sa.func.ST_GeomFromText(bindvalue, sa.text('4326'), type_=self)


    class comparator_factory(types.UserDefinedType.Comparator): # type: ignore[type-arg]
        """ Geometry-specific SQL operations available on expressions of
            type Geometry.
        """

        def intersects(self, other: SaColumn) -> 'sa.Operators':
            """ Bounding-box overlap test using the `&&` operator.
            """
            return self.op('&&')(other)

        def is_line_like(self) -> SaColumn:
            """ True if ST_GeometryType() reports a (multi)linestring.
            """
            return sa.func.ST_GeometryType(self, type_=sa.String).in_(('ST_LineString',
                                                                       'ST_MultiLineString'))

        def is_area(self) -> SaColumn:
            """ True if ST_GeometryType() reports a (multi)polygon.
            """
            return sa.func.ST_GeometryType(self, type_=sa.String).in_(('ST_Polygon',
                                                                       'ST_MultiPolygon'))


        def ST_DWithin(self, other: SaColumn, distance: SaColumn) -> SaColumn:
            """ SQL ST_DWithin(self, other, distance).
            """
            return sa.func.ST_DWithin(self, other, distance, type_=sa.Boolean)


        def ST_DWithin_no_index(self, other: SaColumn, distance: SaColumn) -> SaColumn:
            """ Like ST_DWithin() but with this expression wrapped in
                coalesce(NULL, ...).
            """
            # coalesce(NULL, x) evaluates to x. Wrapping the column in a
            # function presumably keeps the planner from using a spatial
            # index, as the method name suggests.
            return sa.func.ST_DWithin(sa.func.coalesce(sa.null(), self),
                                      other, distance, type_=sa.Boolean)


        def ST_Intersects_no_index(self, other: SaColumn) -> 'sa.Operators':
            """ Bounding-box overlap (`&&`) with this expression wrapped in
                coalesce(NULL, ...).
            """
            # NOTE: despite the name, this uses the `&&` operator like
            # intersects(), not the SQL function ST_Intersects().
            return sa.func.coalesce(sa.null(), self).op('&&')(other)


        def ST_Distance(self, other: SaColumn) -> SaColumn:
            """ SQL ST_Distance(self, other) as a float.
            """
            return sa.func.ST_Distance(self, other, type_=sa.Float)


        def ST_Contains(self, other: SaColumn) -> SaColumn:
            """ SQL ST_Contains(self, other).
            """
            return sa.func.ST_Contains(self, other, type_=sa.Boolean)


        def ST_CoveredBy(self, other: SaColumn) -> SaColumn:
            """ SQL ST_CoveredBy(self, other).
            """
            return sa.func.ST_CoveredBy(self, other, type_=sa.Boolean)


        def ST_ClosestPoint(self, other: SaColumn) -> SaColumn:
            """ SQL ST_ClosestPoint(self, other) as a Geometry.
            """
            return sa.func.ST_ClosestPoint(self, other, type_=Geometry)


        def ST_Buffer(self, other: SaColumn) -> SaColumn:
            """ SQL ST_Buffer(self, other) as a Geometry.
            """
            return sa.func.ST_Buffer(self, other, type_=Geometry)


        def ST_Expand(self, other: SaColumn) -> SaColumn:
            """ SQL ST_Expand(self, other) as a Geometry.
            """
            return sa.func.ST_Expand(self, other, type_=Geometry)


        def ST_Collect(self) -> SaColumn:
            """ SQL ST_Collect(self) as a Geometry.
            """
            return sa.func.ST_Collect(self, type_=Geometry)


        def ST_Centroid(self) -> SaColumn:
            """ SQL ST_Centroid(self) as a Geometry.
            """
            return sa.func.ST_Centroid(self, type_=Geometry)


        def ST_LineInterpolatePoint(self, other: SaColumn) -> SaColumn:
            """ SQL ST_LineInterpolatePoint(self, other) as a Geometry.
            """
            return sa.func.ST_LineInterpolatePoint(self, other, type_=Geometry)


        def ST_LineLocatePoint(self, other: SaColumn) -> SaColumn:
            """ SQL ST_LineLocatePoint(self, other) as a float.
            """
            return sa.func.ST_LineLocatePoint(self, other, type_=sa.Float)


        def distance_spheroid(self, other: SaColumn) -> SaColumn:
            """ Spheroid distance in meters between self and other. The SQL
                is chosen per dialect by Geometry_DistanceSpheroid's
                compile handlers.
            """
            return Geometry_DistanceSpheroid(self, other)
@compiles(Geometry, 'sqlite') # type: ignore[no-untyped-call]
def get_col_spec(self, *args, **kwargs): # type: ignore[no-untyped-def]
    """ Column type DDL for Geometry columns when compiling for SQLite.

        Always renders the generic 'GEOMETRY' type. The subtype and SRID
        used by Geometry.get_col_spec() are not included.
    """
    return 'GEOMETRY'
# Functions that need a different name on SQLite. Each entry is
# (function name as used via sa.func, SQLAlchemy result type,
#  name of the function to emit when compiling for SQLite).
SQLITE_FUNCTION_ALIAS = (
    ('ST_AsEWKB', sa.Text, 'AsEWKB'),
    ('ST_GeomFromEWKT', Geometry, 'GeomFromEWKT'),
    ('ST_AsGeoJSON', sa.Text, 'AsGeoJSON'),
    ('ST_AsKML', sa.Text, 'AsKML'),
    ('ST_AsSVG', sa.Text, 'AsSVG'),
)
def _add_function_alias(func: str, ftype: type, alias: str) -> None:
    """ Declare `func` as a generic SQL function and make it render as
        `alias` when compiling for SQLite.

        Parameters:
            func: Name of the function. It is used as the class name and
                  as the GenericFunction `name`/`identifier`, so
                  `sa.func.<func>(...)` resolves to the generated class.
            ftype: SQLAlchemy type of the function's result.
            alias: Function name emitted instead of `func` on SQLite.
    """
    _FuncDef = type(func, (sa.sql.functions.GenericFunction, ), {
        "type": ftype,
        "name": func,
        "identifier": func,
        # No state beyond the arguments, so statements remain cacheable.
        "inherit_cache": True})

    func_templ = f"{alias}(%s)"

    def _sqlite_impl(element: Any, compiler: Any, **kw: Any) -> Any:
        return func_templ % compiler.process(element.clauses, **kw)

    compiles(_FuncDef, 'sqlite')(_sqlite_impl) # type: ignore[no-untyped-call]
# Register the SQLite renderings for every entry of SQLITE_FUNCTION_ALIAS.
# This runs once at import time. The loop variable `alias` remains bound
# as a module-level name afterwards.
for alias in SQLITE_FUNCTION_ALIAS:
    _add_function_alias(*alias)