]> git.openstreetmap.org Git - nominatim.git/blob - src/nominatim_api/sql/sqlalchemy_types/geometry.py
90adcce850ec6c7d82c1b41c8a32065e7a3b49e7
[nominatim.git] / src / nominatim_api / sql / sqlalchemy_types / geometry.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2024 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Custom types for SQLAlchemy.
9 """
10 from __future__ import annotations
11 from typing import Callable, Any, cast
12
13 import sqlalchemy as sa
14 from sqlalchemy.ext.compiler import compiles
15 from sqlalchemy import types
16
17 from ...typing import SaColumn, SaBind
18
19
class Geometry_DistanceSpheroid(sa.sql.expression.FunctionElement[float]):
    """ SQL construct for the distance between two geometries in meters,
        measured on a spheroid. Rendering is dialect-specific and is
        provided by the `compiles` handlers registered for this class.
    """
    name = 'Geometry_DistanceSpheroid'
    inherit_cache = True
    type = sa.Float()
26
27
@compiles(Geometry_DistanceSpheroid)
def _default_distance_spheroid(element: Geometry_DistanceSpheroid,
                               compiler: 'sa.Compiled', **kw: Any) -> str:
    """ Default (PostGIS) rendering: `ST_DistanceSpheroid` over the two
        geometry arguments, using an explicit WGS 84 spheroid definition.
    """
    geoms_sql = compiler.process(element.clauses, **kw)
    wgs84 = 'SPHEROID["WGS 84",6378137,298.257223563, AUTHORITY["EPSG","7030"]]'
    return f"ST_DistanceSpheroid({geoms_sql}, '{wgs84}')"
34
35
@compiles(Geometry_DistanceSpheroid, 'sqlite')
def _spatialite_distance_spheroid(element: Geometry_DistanceSpheroid,
                                  compiler: 'sa.Compiled', **kw: Any) -> str:
    """ Spatialite rendering: `Distance(..., true)`, where the trailing flag
        asks for the distance on the ellipsoid. A NULL result is mapped to 0.0.
    """
    geoms_sql = compiler.process(element.clauses, **kw)
    return f"COALESCE(Distance({geoms_sql}, true), 0.0)"
40
41
class Geometry_IsLineLike(sa.sql.expression.FunctionElement[Any]):
    """ SQL predicate: true when the geometry is a LineString or a
        MultiLineString. Rendered per dialect by the `compiles` handlers.
    """
    inherit_cache = True
    name = 'Geometry_IsLineLike'
47
48
@compiles(Geometry_IsLineLike)
def _default_is_line_like(element: Geometry_IsLineLike,
                          compiler: 'sa.Compiled', **kw: Any) -> str:
    """ PostGIS rendering: compare the `ST_GeometryType` name against the
        (multi)linestring type names as PostGIS spells them.
    """
    geom_sql = compiler.process(element.clauses, **kw)
    return f"ST_GeometryType({geom_sql}) IN ('ST_LineString', 'ST_MultiLineString')"
54
55
@compiles(Geometry_IsLineLike, 'sqlite')
def _sqlite_is_line_like(element: Geometry_IsLineLike,
                         compiler: 'sa.Compiled', **kw: Any) -> str:
    """ Spatialite rendering: same test, but with the upper-case type names
        used by Spatialite's `ST_GeometryType`.
    """
    geom_sql = compiler.process(element.clauses, **kw)
    return f"ST_GeometryType({geom_sql}) IN ('LINESTRING', 'MULTILINESTRING')"
61
62
class Geometry_IsAreaLike(sa.sql.expression.FunctionElement[Any]):
    """ Check if the geometry is a polygon or multipolygon.

        Rendered per dialect by the `compiles` handlers registered for
        this class.
    """
    # The construct name must match this class; reusing
    # 'Geometry_IsLineLike' would make it indistinguishable by name
    # from the line-like predicate.
    name = 'Geometry_IsAreaLike'
    inherit_cache = True
68
69
@compiles(Geometry_IsAreaLike)
def _default_is_area_like(element: Geometry_IsAreaLike,
                          compiler: 'sa.Compiled', **kw: Any) -> str:
    """ PostGIS rendering: compare the `ST_GeometryType` name against the
        (multi)polygon type names as PostGIS spells them.
    """
    geom_sql = compiler.process(element.clauses, **kw)
    return f"ST_GeometryType({geom_sql}) IN ('ST_Polygon', 'ST_MultiPolygon')"
75
76
@compiles(Geometry_IsAreaLike, 'sqlite')
def _sqlite_is_area_like(element: Geometry_IsAreaLike,
                         compiler: 'sa.Compiled', **kw: Any) -> str:
    """ Spatialite rendering: same test, but with the upper-case type names
        used by Spatialite's `ST_GeometryType`.
    """
    geom_sql = compiler.process(element.clauses, **kw)
    return f"ST_GeometryType({geom_sql}) IN ('POLYGON', 'MULTIPOLYGON')"
82
83
class Geometry_IntersectsBbox(sa.sql.expression.FunctionElement[Any]):
    """ SQL predicate: true when the bounding boxes of the two geometry
        arguments overlap. Rendered per dialect by `compiles` handlers.
    """
    inherit_cache = True
    name = 'Geometry_IntersectsBbox'
89
90
@compiles(Geometry_IntersectsBbox)
def _default_intersects(element: Geometry_IntersectsBbox,
                        compiler: 'sa.Compiled', **kw: Any) -> str:
    """ PostGIS rendering: the `&&` bounding-box overlap operator between
        exactly two geometry arguments.
    """
    left, right = list(element.clauses)
    return " && ".join((compiler.process(left, **kw), compiler.process(right, **kw)))
96
97
@compiles(Geometry_IntersectsBbox, 'sqlite')
def _sqlite_intersects(element: Geometry_IntersectsBbox,
                       compiler: 'sa.Compiled', **kw: Any) -> str:
    """ Spatialite rendering: `MbrIntersects()` over both arguments,
        compared to 1 to yield a boolean condition.
    """
    geoms_sql = compiler.process(element.clauses, **kw)
    return f"MbrIntersects({geoms_sql}) = 1"
102
103
class Geometry_ColumnIntersectsBbox(sa.sql.expression.FunctionElement[Any]):
    """ Bounding-box intersection between a table column (first argument)
        and a geometry (second argument) that goes through the column's
        spatial index.

        The spatial index has to exist, otherwise the query may return
        no rows at all.
    """
    inherit_cache = True
    name = 'Geometry_ColumnIntersectsBbox'
112
113
@compiles(Geometry_ColumnIntersectsBbox)
def default_intersects_column(element: Geometry_ColumnIntersectsBbox,
                              compiler: 'sa.Compiled', **kw: Any) -> str:
    """ PostGIS rendering: plain `&&` operator between column and geometry;
        no explicit index lookup is needed here.
    """
    geom_a, geom_b = list(element.clauses)
    sql_a = compiler.process(geom_a, **kw)
    sql_b = compiler.process(geom_b, **kw)
    return f"{sql_a} && {sql_b}"
119
120
@compiles(Geometry_ColumnIntersectsBbox, 'sqlite')
def spatialite_intersects_column(element: Geometry_ColumnIntersectsBbox,
                                 compiler: 'sa.Compiled', **kw: Any) -> str:
    """ Spatialite rendering: an `MbrIntersects` test combined with a ROWID
        restriction taken from the `SpatialIndex` virtual table for the
        column's table and geometry column.

        The first argument must be a table column (its `.table.name` and
        `.name` are inlined into the SQL).
    """
    column, frame = list(element.clauses)
    column_sql = compiler.process(column, **kw)
    frame_sql = compiler.process(frame, **kw)
    table = column.table.name
    col_name = column.name
    # The frame expression is compiled a second time for the index lookup.
    index_frame_sql = compiler.process(frame, **kw)
    return (f"MbrIntersects({column_sql}, {frame_sql}) = 1 and "
            f"{table}.ROWID IN (SELECT ROWID FROM SpatialIndex "
            f"             WHERE f_table_name = '{table}' AND f_geometry_column = '{col_name}' "
            f"             AND search_frame = {index_frame_sql})")
133
134
class Geometry_ColumnDWithin(sa.sql.expression.FunctionElement[Any]):
    """ Distance predicate between a table column (first argument) and a
        geometry (second argument) within a distance (third argument) that
        goes through the column's spatial index.

        The spatial index has to exist, otherwise the query may return
        no rows at all.
    """
    inherit_cache = True
    name = 'Geometry_ColumnDWithin'
143
144
@compiles(Geometry_ColumnDWithin)
def default_dwithin_column(element: Geometry_ColumnDWithin,
                           compiler: 'sa.Compiled', **kw: Any) -> str:
    """ PostGIS rendering: `ST_DWithin` applied to all arguments in order. """
    args_sql = compiler.process(element.clauses, **kw)
    return f"ST_DWithin({args_sql})"
149
150
@compiles(Geometry_ColumnDWithin, 'sqlite')
def spatialite_dwithin_column(element: Geometry_ColumnDWithin,
                              compiler: 'sa.Compiled', **kw: Any) -> str:
    """ Spatialite rendering: an exact `ST_Distance(...) < dist` test plus a
        ROWID restriction from the `SpatialIndex` virtual table, searching
        the second geometry's frame expanded by the distance.

        The first argument must be a table column (its `.table.name` and
        `.name` are inlined into the SQL).
    """
    column, geom, dist = list(element.clauses)
    column_sql = compiler.process(column, **kw)
    geom_sql = compiler.process(geom, **kw)
    dist_sql = compiler.process(dist, **kw)
    table = column.table.name
    col_name = column.name
    # Geometry and distance are compiled a second time for ST_Expand().
    frame_geom_sql = compiler.process(geom, **kw)
    frame_dist_sql = compiler.process(dist, **kw)
    return (f"ST_Distance({column_sql}, {geom_sql}) < {dist_sql} and "
            f"{table}.ROWID IN (SELECT ROWID FROM SpatialIndex "
            f"             WHERE f_table_name = '{table}' AND f_geometry_column = '{col_name}' "
            f"             AND search_frame = ST_Expand({frame_geom_sql}, {frame_dist_sql}))")
165
166
class Geometry(types.UserDefinedType):  # type: ignore[type-arg]
    """ Simplified SQLAlchemy user-defined type for PostGIS geometry.
        This type only supports geometries in 4326 projection.

        Bound values may be WKT strings or objects providing a `to_wkt()`
        method; they are passed through `ST_GeomFromText(..., 4326)`.
        Selected values are wrapped in `ST_AsEWKB()` and are expected to
        arrive as strings.
    """
    cache_ok = True

    def __init__(self, subtype: str = 'Geometry'):
        # Geometry subtype written into the column spec.
        self.subtype = subtype

    def get_col_spec(self) -> str:
        """ Column type used in DDL; the SRID is always 4326. """
        return f'GEOMETRY({self.subtype}, 4326)'

    def bind_processor(self, dialect: 'sa.Dialect') -> Callable[[Any], str | None]:
        """ Return a converter that turns bound values into WKT text.

            Strings are passed through unchanged, other objects are
            converted with their `to_wkt()` method. SQLAlchemy also hands
            NULL parameters (None) to bind processors; those are passed
            through unchanged instead of failing on `None.to_wkt()`.
        """
        def process(value: Any) -> str | None:
            if value is None or isinstance(value, str):
                return value

            return cast(str, value.to_wkt())
        return process

    def result_processor(self, dialect: 'sa.Dialect', coltype: object) -> Callable[[Any], str]:
        """ Return a converter for selected values, which must already be
            strings (checked with an assertion).
        """
        def process(value: Any) -> str:
            assert isinstance(value, str)
            return value
        return process

    def column_expression(self, col: SaColumn) -> SaColumn:
        """ Wrap selected columns of this type in `ST_AsEWKB()`. """
        return sa.func.ST_AsEWKB(col)

    def bind_expression(self, bindvalue: SaBind) -> SaColumn:
        """ Parse bound WKT with `ST_GeomFromText()` in SRID 4326. """
        return sa.func.ST_GeomFromText(bindvalue, sa.text('4326'), type_=self)

    class comparator_factory(types.UserDefinedType.Comparator):  # type: ignore[type-arg]
        """ Geometry-specific SQL operations available on expressions of
            this type.
        """

        def intersects(self, other: SaColumn, use_index: bool = True) -> 'sa.Operators':
            """ Bounding-box intersection test against `other`.

                Plain table columns use Geometry_ColumnIntersectsBbox (the
                index-assisted variant). With `use_index=False` the
                expression is wrapped in `coalesce(NULL, ...)` and the
                generic bbox test is used; presumably this hides the column
                from the planner so that no spatial index is chosen.
            """
            if not use_index:
                return Geometry_IntersectsBbox(sa.func.coalesce(sa.null(), self.expr), other)

            if isinstance(self.expr, sa.Column):
                return Geometry_ColumnIntersectsBbox(self.expr, other)

            return Geometry_IntersectsBbox(self.expr, other)

        def is_line_like(self) -> SaColumn:
            """ Test if the geometry is a LineString or MultiLineString. """
            return Geometry_IsLineLike(self)

        def is_area(self) -> SaColumn:
            """ Test if the geometry is a Polygon or MultiPolygon. """
            return Geometry_IsAreaLike(self)

        def within_distance(self, other: SaColumn, distance: SaColumn) -> SaColumn:
            """ Test if `other` lies within `distance` of this geometry.

                Plain table columns use the index-assisted
                Geometry_ColumnDWithin; other expressions fall back to
                `ST_Distance(...) < distance`.
            """
            if isinstance(self.expr, sa.Column):
                return Geometry_ColumnDWithin(self.expr, other, distance)

            return self.ST_Distance(other) < distance

        def ST_Distance(self, other: SaColumn) -> SaColumn:
            return sa.func.ST_Distance(self, other, type_=sa.Float)

        def ST_Contains(self, other: SaColumn) -> SaColumn:
            return sa.func.ST_Contains(self, other, type_=sa.Boolean)

        def ST_CoveredBy(self, other: SaColumn) -> SaColumn:
            return sa.func.ST_CoveredBy(self, other, type_=sa.Boolean)

        def ST_ClosestPoint(self, other: SaColumn) -> SaColumn:
            """ Closest point on this geometry to `other`; falls back to
                `other` itself when ST_ClosestPoint yields NULL.
            """
            return sa.func.coalesce(sa.func.ST_ClosestPoint(self, other, type_=Geometry),
                                    other)

        def ST_Buffer(self, other: SaColumn) -> SaColumn:
            return sa.func.ST_Buffer(self, other, type_=Geometry)

        def ST_Expand(self, other: SaColumn) -> SaColumn:
            return sa.func.ST_Expand(self, other, type_=Geometry)

        def ST_Collect(self) -> SaColumn:
            return sa.func.ST_Collect(self, type_=Geometry)

        def ST_Centroid(self) -> SaColumn:
            return sa.func.ST_Centroid(self, type_=Geometry)

        def ST_LineInterpolatePoint(self, other: SaColumn) -> SaColumn:
            return sa.func.ST_LineInterpolatePoint(self, other, type_=Geometry)

        def ST_LineLocatePoint(self, other: SaColumn) -> SaColumn:
            return sa.func.ST_LineLocatePoint(self, other, type_=sa.Float)

        def distance_spheroid(self, other: SaColumn) -> SaColumn:
            """ Distance to `other` in meters on the spheroid. """
            return Geometry_DistanceSpheroid(self, other)
255
256
@compiles(Geometry, 'sqlite')
def get_col_spec(self, *args, **kwargs):  # type: ignore[no-untyped-def]
    """ DDL column type for Geometry on SQLite.

        Only the bare type name is emitted; subtype and SRID are not part
        of the column spec on this dialect.
    """
    spec = "GEOMETRY"
    return spec
260
261
# Functions that are spelled differently in Spatialite.
# Each entry is (generic/PostGIS function name, SQL return type, Spatialite name).
# The return type becomes the default type of the registered generic
# function, so it must describe what the function really returns:
# ST_LineInterpolatePoint produces a point geometry, not a number.
SQLITE_FUNCTION_ALIAS = (
    ('ST_AsEWKB', sa.Text, 'AsEWKB'),
    ('ST_GeomFromEWKT', Geometry, 'GeomFromEWKT'),
    ('ST_AsGeoJSON', sa.Text, 'AsGeoJSON'),
    ('ST_AsKML', sa.Text, 'AsKML'),
    ('ST_AsSVG', sa.Text, 'AsSVG'),
    ('ST_LineLocatePoint', sa.Float, 'ST_Line_Locate_Point'),
    ('ST_LineInterpolatePoint', Geometry, 'ST_Line_Interpolate_Point'),
)
271
272
def _add_function_alias(func: str, ftype: type, alias: str) -> None:
    """ Register `func` as a generic SQL function with return type `ftype`
        and make it render as `alias(...)` when compiled for SQLite.

        Other dialects keep the default rendering under the name `func`.
    """
    class_attrs = {
        "type": ftype(),
        "name": func,
        "identifier": func,
        "inherit_cache": True,
    }
    generic_cls = type(func, (sa.sql.functions.GenericFunction, ), class_attrs)
    template = f"{alias}(%s)"

    @compiles(generic_cls, 'sqlite')
    def _sqlite_impl(element: Any, compiler: Any, **kw: Any) -> Any:
        return template % compiler.process(element.clauses, **kw)
286
287
# Register every alias at import time. The entries are unpacked into private
# names so that no public module-level `alias` is left behind, which would
# also shadow the parameter of the same name in _add_function_alias.
for _func_name, _func_type, _func_alias in SQLITE_FUNCTION_ALIAS:
    _add_function_alias(_func_name, _func_type, _func_alias)