]> git.openstreetmap.org Git - nominatim.git/blob - nominatim/db/sqlalchemy_types/geometry.py
Merge remote-tracking branch 'upstream/master'
[nominatim.git] / nominatim / db / sqlalchemy_types / geometry.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2023 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Custom types for SQLAlchemy.
9 """
10 from __future__ import annotations
11 from typing import Callable, Any, cast
12 import sys
13
14 import sqlalchemy as sa
15 from sqlalchemy.ext.compiler import compiles
16 from sqlalchemy import types
17
18 from nominatim.typing import SaColumn, SaBind
19
20 #pylint: disable=all
21
class Geometry_DistanceSpheroid(sa.sql.expression.FunctionElement[float]):
    """ SQL function element computing the spherical distance in meters
        between two geometries. The SQL text comes from the @compiles
        handlers registered for this class.
    """
    inherit_cache = True
    name = 'Geometry_DistanceSpheroid'
    type = sa.Float()
28
29
@compiles(Geometry_DistanceSpheroid) # type: ignore[no-untyped-call, misc]
def _default_distance_spheroid(element: Geometry_DistanceSpheroid,
                               compiler: 'sa.Compiled', **kw: Any) -> str:
    """ Default (PostGIS) rendering: ST_DistanceSpheroid() of the given
        arguments, measured on the WGS 84 spheroid.
    """
    wgs84 = 'SPHEROID["WGS 84",6378137,298.257223563, AUTHORITY["EPSG","7030"]]'
    geoms = compiler.process(element.clauses, **kw)
    return f"ST_DistanceSpheroid({geoms}, '{wgs84}')"
36
37
@compiles(Geometry_DistanceSpheroid, 'sqlite') # type: ignore[no-untyped-call, misc]
def _spatialite_distance_spheroid(element: Geometry_DistanceSpheroid,
                                  compiler: 'sa.Compiled', **kw: Any) -> str:
    """ SQLite/SpatiaLite rendering: Distance(<args>, true), with a NULL
        result replaced by 0.0.
    """
    geoms = compiler.process(element.clauses, **kw)
    return f"COALESCE(Distance({geoms}, true), 0.0)"
42
43
class Geometry_IsLineLike(sa.sql.expression.FunctionElement[Any]):
    """ SQL test that is true when the geometry is a linestring or a
        multi-linestring. Rendered by the dialect-specific @compiles
        handlers for this class.
    """
    inherit_cache = True
    name = 'Geometry_IsLineLike'
49
50
@compiles(Geometry_IsLineLike) # type: ignore[no-untyped-call, misc]
def _default_is_line_like(element: Geometry_IsLineLike,
                          compiler: 'sa.Compiled', **kw: Any) -> str:
    """ Default (PostGIS) rendering: match the ST_-prefixed type names
        returned by ST_GeometryType().
    """
    geom = compiler.process(element.clauses, **kw)
    return f"ST_GeometryType({geom}) IN ('ST_LineString', 'ST_MultiLineString')"
56
57
@compiles(Geometry_IsLineLike, 'sqlite') # type: ignore[no-untyped-call, misc]
def _sqlite_is_line_like(element: Geometry_IsLineLike,
                         compiler: 'sa.Compiled', **kw: Any) -> str:
    """ SQLite rendering: match the upper-case type names used on this
        dialect.
    """
    geom = compiler.process(element.clauses, **kw)
    return f"ST_GeometryType({geom}) IN ('LINESTRING', 'MULTILINESTRING')"
63
64
class Geometry_IsAreaLike(sa.sql.expression.FunctionElement[Any]):
    """ Check if the geometry is a polygon or multipolygon.

        The SQL text is produced by the dialect-specific @compiles
        handlers registered for this class.
    """
    # Must name this element, not the line-like one, so that it is
    # distinguishable from Geometry_IsLineLike when inspected or debugged.
    name = 'Geometry_IsAreaLike'
    inherit_cache = True
70
71
@compiles(Geometry_IsAreaLike) # type: ignore[no-untyped-call, misc]
def _default_is_area_like(element: Geometry_IsAreaLike,
                          compiler: 'sa.Compiled', **kw: Any) -> str:
    """ Default (PostGIS) rendering: match the ST_-prefixed polygon type
        names returned by ST_GeometryType().
    """
    geom = compiler.process(element.clauses, **kw)
    return f"ST_GeometryType({geom}) IN ('ST_Polygon', 'ST_MultiPolygon')"
77
78
@compiles(Geometry_IsAreaLike, 'sqlite') # type: ignore[no-untyped-call, misc]
def _sqlite_is_area_like(element: Geometry_IsAreaLike,
                         compiler: 'sa.Compiled', **kw: Any) -> str:
    """ SQLite rendering: match the upper-case polygon type names used on
        this dialect.
    """
    geom = compiler.process(element.clauses, **kw)
    return f"ST_GeometryType({geom}) IN ('POLYGON', 'MULTIPOLYGON')"
84
85
class Geometry_IntersectsBbox(sa.sql.expression.FunctionElement[Any]):
    """ SQL test that is true when the bounding boxes of the two given
        geometries overlap.
    """
    inherit_cache = True
    name = 'Geometry_IntersectsBbox'
91
92
@compiles(Geometry_IntersectsBbox) # type: ignore[no-untyped-call, misc]
def _default_intersects(element: Geometry_IntersectsBbox,
                        compiler: 'sa.Compiled', **kw: Any) -> str:
    """ Default (PostGIS) rendering: the bounding-box overlap operator &&
        between exactly two arguments.
    """
    left, right = list(element.clauses)
    return ' && '.join((compiler.process(left, **kw), compiler.process(right, **kw)))
98
99
@compiles(Geometry_IntersectsBbox, 'sqlite') # type: ignore[no-untyped-call, misc]
def _sqlite_intersects(element: Geometry_IntersectsBbox,
                       compiler: 'sa.Compiled', **kw: Any) -> str:
    """ SQLite rendering: MbrIntersects() over the arguments, compared to 1. """
    geoms = compiler.process(element.clauses, **kw)
    return f"MbrIntersects({geoms}) = 1"
104
105
class Geometry_ColumnIntersectsBbox(sa.sql.expression.FunctionElement[Any]):
    """ SQL test that is true when the bounding box of a table column's
        geometry overlaps the given geometry. The spatial index of the
        column is used for the lookup.

        The index must exist or the query may return nothing.
    """
    inherit_cache = True
    name = 'Geometry_ColumnIntersectsBbox'
114
115
@compiles(Geometry_ColumnIntersectsBbox) # type: ignore[no-untyped-call, misc]
def default_intersects_column(element: Geometry_ColumnIntersectsBbox,
                              compiler: 'sa.Compiled', **kw: Any) -> str:
    """ Default (PostGIS) rendering: plain && between column and geometry. """
    column, other = list(element.clauses)
    column_sql = compiler.process(column, **kw)
    other_sql = compiler.process(other, **kw)
    return f"{column_sql} && {other_sql}"
121
122
@compiles(Geometry_ColumnIntersectsBbox, 'sqlite') # type: ignore[no-untyped-call, misc]
def spatialite_intersects_column(element: Geometry_ColumnIntersectsBbox,
                                 compiler: 'sa.Compiled', **kw: Any) -> str:
    """ SpatiaLite rendering: MbrIntersects() restricted to the row ids
        returned by the SpatialIndex virtual table for the column.

        The first argument must be a table column; its table name and
        column name are inlined into the SQL.
    """
    column, frame = list(element.clauses)
    column_sql = compiler.process(column, **kw)
    frame_sql = compiler.process(frame, **kw)
    table = column.table.name
    # The frame is rendered again instead of reusing frame_sql. With
    # positional parameters (SQLite) each placeholder may need its own
    # process() call; verify before merging the two calls.
    index_frame_sql = compiler.process(frame, **kw)
    return (f"MbrIntersects({column_sql}, {frame_sql}) = 1 and "
            f"{table}.ROWID IN (SELECT ROWID FROM SpatialIndex "
            f"WHERE f_table_name = '{table}' AND f_geometry_column = '{column.name}' "
            f"AND search_frame = {index_frame_sql})")
135
136
class Geometry_ColumnDWithin(sa.sql.expression.FunctionElement[Any]):
    """ SQL test that is true when a table column's geometry lies within
        the given distance of another geometry. The spatial index of the
        column is used for the lookup.

        The index must exist or the query may return nothing.
    """
    inherit_cache = True
    name = 'Geometry_ColumnDWithin'
145
146
@compiles(Geometry_ColumnDWithin) # type: ignore[no-untyped-call, misc]
def default_dwithin_column(element: Geometry_ColumnDWithin,
                           compiler: 'sa.Compiled', **kw: Any) -> str:
    """ Default (PostGIS) rendering: ST_DWithin() over all arguments. """
    args = compiler.process(element.clauses, **kw)
    return f"ST_DWithin({args})"
151
@compiles(Geometry_ColumnDWithin, 'sqlite') # type: ignore[no-untyped-call, misc]
def spatialite_dwithin_column(element: Geometry_ColumnDWithin,
                              compiler: 'sa.Compiled', **kw: Any) -> str:
    """ SpatiaLite rendering: ST_Distance() strictly below the limit,
        restricted to the row ids that the SpatialIndex virtual table
        returns for the other geometry expanded by the distance.

        The first argument must be a table column; its table name and
        column name are inlined into the SQL.
    """
    column, other, distance = list(element.clauses)
    column_sql = compiler.process(column, **kw)
    other_sql = compiler.process(other, **kw)
    distance_sql = compiler.process(distance, **kw)
    table = column.table.name
    # Rendered again rather than reused. With positional parameters each
    # placeholder may need its own process() call, in SQL order.
    frame_other_sql = compiler.process(other, **kw)
    frame_distance_sql = compiler.process(distance, **kw)
    return (f"ST_Distance({column_sql}, {other_sql}) < {distance_sql} and "
            f"{table}.ROWID IN (SELECT ROWID FROM SpatialIndex "
            f"WHERE f_table_name = '{table}' AND f_geometry_column = '{column.name}' "
            f"AND search_frame = ST_Expand({frame_other_sql}, {frame_distance_sql}))")
166
167
class Geometry(types.UserDefinedType): # type: ignore[type-arg]
    """ Simplified type decorator for PostGIS geometry. This type
        only supports geometries in 4326 projection.

        Bound values may be WKT strings, which are sent unchanged, or
        objects providing a `to_wkt()` method. None is sent as NULL.
        Selected values are fetched through ST_AsEWKB().
    """
    cache_ok = True

    def __init__(self, subtype: str = 'Geometry'):
        # Subtype written into the column spec GEOMETRY(<subtype>, 4326).
        self.subtype = subtype


    def get_col_spec(self) -> str:
        """ Column type for DDL: GEOMETRY(<subtype>, 4326).
        """
        return f'GEOMETRY({self.subtype}, 4326)'


    def bind_processor(self, dialect: 'sa.Dialect') -> Callable[[Any], str | None]:
        """ Convert Python-side values into WKT text for ST_GeomFromText().
        """
        def process(value: Any) -> str | None:
            # SQLAlchemy also passes None through bind processors. Keep it
            # as NULL instead of failing on the missing to_wkt() method.
            if value is None:
                return None
            if isinstance(value, str):
                return value

            return cast(str, value.to_wkt())
        return process


    def result_processor(self, dialect: 'sa.Dialect', coltype: object) -> Callable[[Any], str]:
        """ Return fetched values unchanged. They are expected to be strings.
        """
        def process(value: Any) -> str:
            assert isinstance(value, str)
            return value
        return process


    def column_expression(self, col: SaColumn) -> SaColumn:
        """ Wrap selected geometry expressions in ST_AsEWKB().
        """
        return sa.func.ST_AsEWKB(col)


    def bind_expression(self, bindvalue: SaBind) -> SaColumn:
        """ Parse bound WKT text with ST_GeomFromText() in SRID 4326.
        """
        return sa.func.ST_GeomFromText(bindvalue, sa.text('4326'), type_=self)


    class comparator_factory(types.UserDefinedType.Comparator): # type: ignore[type-arg]
        """ Geometry-specific operations available on expressions of this type.
        """

        def intersects(self, other: SaColumn, use_index: bool = True) -> 'sa.Operators':
            """ Test whether the bounding boxes of this geometry and `other`
                intersect.

                A plain table column uses the index-aware variant. With
                `use_index=False` the expression is wrapped in
                COALESCE(NULL, ...) so it is no longer a plain column.
            """
            if not use_index:
                return Geometry_IntersectsBbox(sa.func.coalesce(sa.null(), self.expr), other)

            if isinstance(self.expr, sa.Column):
                return Geometry_ColumnIntersectsBbox(self.expr, other)

            return Geometry_IntersectsBbox(self.expr, other)


        def is_line_like(self) -> SaColumn:
            """ True when the geometry is a linestring or multi-linestring.
            """
            return Geometry_IsLineLike(self)


        def is_area(self) -> SaColumn:
            """ True when the geometry is a polygon or multipolygon.
            """
            return Geometry_IsAreaLike(self)


        def within_distance(self, other: SaColumn, distance: SaColumn) -> SaColumn:
            """ Test whether the geometry lies within `distance` of `other`.

                A plain table column uses the index-aware variant.
                Otherwise ST_Distance() is compared with a strict `<`.
            """
            if isinstance(self.expr, sa.Column):
                return Geometry_ColumnDWithin(self.expr, other, distance)

            return self.ST_Distance(other) < distance


        def ST_Distance(self, other: SaColumn) -> SaColumn:
            """ ST_Distance() to `other`, typed as Float.
            """
            return sa.func.ST_Distance(self, other, type_=sa.Float)


        def ST_Contains(self, other: SaColumn) -> SaColumn:
            """ ST_Contains(self, other), typed as Boolean.
            """
            return sa.func.ST_Contains(self, other, type_=sa.Boolean)


        def ST_CoveredBy(self, other: SaColumn) -> SaColumn:
            """ ST_CoveredBy(self, other), typed as Boolean.
            """
            return sa.func.ST_CoveredBy(self, other, type_=sa.Boolean)


        def ST_ClosestPoint(self, other: SaColumn) -> SaColumn:
            """ ST_ClosestPoint(self, other), falling back to `other` when
                the result is NULL.
            """
            return sa.func.coalesce(sa.func.ST_ClosestPoint(self, other, type_=Geometry),
                                    other)


        def ST_Buffer(self, other: SaColumn) -> SaColumn:
            """ ST_Buffer(self, other), typed as Geometry.
            """
            return sa.func.ST_Buffer(self, other, type_=Geometry)


        def ST_Expand(self, other: SaColumn) -> SaColumn:
            """ ST_Expand(self, other), typed as Geometry.
            """
            return sa.func.ST_Expand(self, other, type_=Geometry)


        def ST_Collect(self) -> SaColumn:
            """ ST_Collect() over this expression, typed as Geometry.
            """
            return sa.func.ST_Collect(self, type_=Geometry)


        def ST_Centroid(self) -> SaColumn:
            """ ST_Centroid() of this geometry, typed as Geometry.
            """
            return sa.func.ST_Centroid(self, type_=Geometry)


        def ST_LineInterpolatePoint(self, other: SaColumn) -> SaColumn:
            """ ST_LineInterpolatePoint(self, other), typed as Geometry.
            """
            return sa.func.ST_LineInterpolatePoint(self, other, type_=Geometry)


        def ST_LineLocatePoint(self, other: SaColumn) -> SaColumn:
            """ ST_LineLocatePoint(self, other), typed as Float.
            """
            return sa.func.ST_LineLocatePoint(self, other, type_=sa.Float)


        def distance_spheroid(self, other: SaColumn) -> SaColumn:
            """ Spherical distance in meters to `other`.
            """
            return Geometry_DistanceSpheroid(self, other)
276
277
@compiles(Geometry, 'sqlite') # type: ignore[no-untyped-call]
def get_col_spec(self, *args, **kwargs): # type: ignore[no-untyped-def]
    """ SQLite column type for Geometry: a bare GEOMETRY, without the
        subtype and SRID used by Geometry.get_col_spec().
    """
    spatialite_spec = 'GEOMETRY'
    return spatialite_spec
281
282
# Functions that use a different name when compiled for SQLite.
# Each entry is (SQLAlchemy function name, result type, SQLite function name).
# The result type is the default type of sa.func.<name>() whenever no
# explicit type_ is passed.
SQLITE_FUNCTION_ALIAS = (
    ('ST_AsEWKB', sa.Text, 'AsEWKB'),
    ('ST_GeomFromEWKT', Geometry, 'GeomFromEWKT'),
    ('ST_AsGeoJSON', sa.Text, 'AsGeoJSON'),
    ('ST_AsKML', sa.Text, 'AsKML'),
    ('ST_AsSVG', sa.Text, 'AsSVG'),
    ('ST_LineLocatePoint', sa.Float, 'ST_Line_Locate_Point'),
    # Interpolating along a line yields a point geometry, not a number.
    ('ST_LineInterpolatePoint', Geometry, 'ST_Line_Interpolate_Point'),
)
292
def _add_function_alias(func: str, ftype: type, alias: str) -> None:
    """ Declare `func` as a SQLAlchemy generic function whose result type
        is an instance of `ftype`, and make it render as `alias(...)` on
        the SQLite dialect.
    """
    class_attrs = {
        "type": ftype(),
        "name": func,
        "identifier": func,
        "inherit_cache": True,
    }
    func_class = type(func, (sa.sql.functions.GenericFunction, ), class_attrs)

    sql_template = alias + "(%s)"

    def _render_for_sqlite(element: Any, compiler: Any, **kw: Any) -> Any:
        return sql_template % compiler.process(element.clauses, **kw)

    compiles(func_class, 'sqlite')(_render_for_sqlite) # type: ignore[no-untyped-call]
306
# Register every SQLite function alias when the module is imported.
for func_name, result_type, sqlite_name in SQLITE_FUNCTION_ALIAS:
    _add_function_alias(func_name, result_type, sqlite_name)