]> git.openstreetmap.org Git - nominatim.git/blob - nominatim/db/sqlalchemy_types.py
make code work with Spatialite 4.3
[nominatim.git] / nominatim / db / sqlalchemy_types.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2023 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Custom types for SQLAlchemy.
9 """
10 from __future__ import annotations
11 from typing import Callable, Any, cast
12 import sys
13
14 import sqlalchemy as sa
15 from sqlalchemy.ext.compiler import compiles
16 from sqlalchemy import types
17
18 from nominatim.typing import SaColumn, SaBind
19
20 #pylint: disable=all
21
22 class Geometry_DistanceSpheroid(sa.sql.expression.FunctionElement[float]):
23     """ Function to compute the spherical distance in meters.
24     """
25     type = sa.Float()
26     name = 'Geometry_DistanceSpheroid'
27     inherit_cache = True
28
29
30 @compiles(Geometry_DistanceSpheroid) # type: ignore[no-untyped-call, misc]
31 def _default_distance_spheroid(element: SaColumn,
32                                compiler: 'sa.Compiled', **kw: Any) -> str:
33     return "ST_DistanceSpheroid(%s,"\
34            " 'SPHEROID[\"WGS 84\",6378137,298.257223563, AUTHORITY[\"EPSG\",\"7030\"]]')"\
35              % compiler.process(element.clauses, **kw)
36
37
38 @compiles(Geometry_DistanceSpheroid, 'sqlite') # type: ignore[no-untyped-call, misc]
39 def _spatialite_distance_spheroid(element: SaColumn,
40                                   compiler: 'sa.Compiled', **kw: Any) -> str:
41     return "COALESCE(Distance(%s, true), 0.0)" % compiler.process(element.clauses, **kw)
42
43
44 class Geometry_IsLineLike(sa.sql.expression.FunctionElement[bool]):
45     """ Check if the geometry is a line or multiline.
46     """
47     type = sa.Boolean()
48     name = 'Geometry_IsLineLike'
49     inherit_cache = True
50
51
52 @compiles(Geometry_IsLineLike) # type: ignore[no-untyped-call, misc]
53 def _default_is_line_like(element: SaColumn,
54                           compiler: 'sa.Compiled', **kw: Any) -> str:
55     return "ST_GeometryType(%s) IN ('ST_LineString', 'ST_MultiLineString')" % \
56                compiler.process(element.clauses, **kw)
57
58
59 @compiles(Geometry_IsLineLike, 'sqlite') # type: ignore[no-untyped-call, misc]
60 def _sqlite_is_line_like(element: SaColumn,
61                          compiler: 'sa.Compiled', **kw: Any) -> str:
62     return "ST_GeometryType(%s) IN ('LINESTRING', 'MULTILINESTRING')" % \
63                compiler.process(element.clauses, **kw)
64
65
66 class Geometry_IsAreaLike(sa.sql.expression.FunctionElement[bool]):
67     """ Check if the geometry is a polygon or multipolygon.
68     """
69     type = sa.Boolean()
70     name = 'Geometry_IsLineLike'
71     inherit_cache = True
72
73
74 @compiles(Geometry_IsAreaLike) # type: ignore[no-untyped-call, misc]
75 def _default_is_area_like(element: SaColumn,
76                           compiler: 'sa.Compiled', **kw: Any) -> str:
77     return "ST_GeometryType(%s) IN ('ST_Polygon', 'ST_MultiPolygon')" % \
78                compiler.process(element.clauses, **kw)
79
80
81 @compiles(Geometry_IsAreaLike, 'sqlite') # type: ignore[no-untyped-call, misc]
82 def _sqlite_is_area_like(element: SaColumn,
83                          compiler: 'sa.Compiled', **kw: Any) -> str:
84     return "ST_GeometryType(%s) IN ('POLYGON', 'MULTIPOLYGON')" % \
85                compiler.process(element.clauses, **kw)
86
87
88 class Geometry_IntersectsBbox(sa.sql.expression.FunctionElement[bool]):
89     """ Check if the bounding boxes of the given geometries intersect.
90     """
91     type = sa.Boolean()
92     name = 'Geometry_IntersectsBbox'
93     inherit_cache = True
94
95
96 @compiles(Geometry_IntersectsBbox) # type: ignore[no-untyped-call, misc]
97 def _default_intersects(element: SaColumn,
98                         compiler: 'sa.Compiled', **kw: Any) -> str:
99     arg1, arg2 = list(element.clauses)
100     return "%s && %s" % (compiler.process(arg1, **kw), compiler.process(arg2, **kw))
101
102
103 @compiles(Geometry_IntersectsBbox, 'sqlite') # type: ignore[no-untyped-call, misc]
104 def _sqlite_intersects(element: SaColumn,
105                        compiler: 'sa.Compiled', **kw: Any) -> str:
106     return "MbrIntersects(%s)" % compiler.process(element.clauses, **kw)
107
108
109 class Geometry_ColumnIntersectsBbox(sa.sql.expression.FunctionElement[bool]):
110     """ Check if the bounding box of the geometry intersects with the
111         given table column, using the spatial index for the column.
112
113         The index must exist or the query may return nothing.
114     """
115     type = sa.Boolean()
116     name = 'Geometry_ColumnIntersectsBbox'
117     inherit_cache = True
118
119
120 @compiles(Geometry_ColumnIntersectsBbox) # type: ignore[no-untyped-call, misc]
121 def default_intersects_column(element: SaColumn,
122                               compiler: 'sa.Compiled', **kw: Any) -> str:
123     arg1, arg2 = list(element.clauses)
124     return "%s && %s" % (compiler.process(arg1, **kw), compiler.process(arg2, **kw))
125
126
127 @compiles(Geometry_ColumnIntersectsBbox, 'sqlite') # type: ignore[no-untyped-call, misc]
128 def spatialite_intersects_column(element: SaColumn,
129                                  compiler: 'sa.Compiled', **kw: Any) -> str:
130     arg1, arg2 = list(element.clauses)
131     return "MbrIntersects(%s, %s) and "\
132            "%s.ROWID IN (SELECT ROWID FROM SpatialIndex "\
133                         "WHERE f_table_name = '%s' AND f_geometry_column = '%s' "\
134                         "AND search_frame = %s)" %(
135               compiler.process(arg1, **kw),
136               compiler.process(arg2, **kw),
137               arg1.table.name, arg1.table.name, arg1.name,
138               compiler.process(arg2, **kw))
139
140
141 class Geometry_ColumnDWithin(sa.sql.expression.FunctionElement[bool]):
142     """ Check if the geometry is within the distance of the
143         given table column, using the spatial index for the column.
144
145         The index must exist or the query may return nothing.
146     """
147     type = sa.Boolean()
148     name = 'Geometry_ColumnDWithin'
149     inherit_cache = True
150
151
152 @compiles(Geometry_ColumnDWithin) # type: ignore[no-untyped-call, misc]
153 def default_dwithin_column(element: SaColumn,
154                            compiler: 'sa.Compiled', **kw: Any) -> str:
155     return "ST_DWithin(%s)" % compiler.process(element.clauses, **kw)
156
157 @compiles(Geometry_ColumnDWithin, 'sqlite') # type: ignore[no-untyped-call, misc]
158 def spatialite_dwithin_column(element: SaColumn,
159                               compiler: 'sa.Compiled', **kw: Any) -> str:
160     geom1, geom2, dist = list(element.clauses)
161     return "ST_Distance(%s, %s) < %s and "\
162            "%s.ROWID IN (SELECT ROWID FROM SpatialIndex "\
163                         "WHERE f_table_name = '%s' AND f_geometry_column = '%s' "\
164                         "AND search_frame = ST_Expand(%s, %s))" %(
165               compiler.process(geom1, **kw),
166               compiler.process(geom2, **kw),
167               compiler.process(dist, **kw),
168               geom1.table.name, geom1.table.name, geom1.name,
169               compiler.process(geom2, **kw),
170               compiler.process(dist, **kw))
171
172
173
174 class Geometry(types.UserDefinedType): # type: ignore[type-arg]
175     """ Simplified type decorator for PostGIS geometry. This type
176         only supports geometries in 4326 projection.
177     """
178     cache_ok = True
179
180     def __init__(self, subtype: str = 'Geometry'):
181         self.subtype = subtype
182
183
184     def get_col_spec(self) -> str:
185         return f'GEOMETRY({self.subtype}, 4326)'
186
187
188     def bind_processor(self, dialect: 'sa.Dialect') -> Callable[[Any], str]:
189         def process(value: Any) -> str:
190             if isinstance(value, str):
191                 return value
192
193             return cast(str, value.to_wkt())
194         return process
195
196
197     def result_processor(self, dialect: 'sa.Dialect', coltype: object) -> Callable[[Any], str]:
198         def process(value: Any) -> str:
199             assert isinstance(value, str)
200             return value
201         return process
202
203
204     def column_expression(self, col: SaColumn) -> SaColumn:
205         return sa.func.ST_AsEWKB(col)
206
207
208     def bind_expression(self, bindvalue: SaBind) -> SaColumn:
209         return sa.func.ST_GeomFromText(bindvalue, sa.text('4326'), type_=self)
210
211
212     class comparator_factory(types.UserDefinedType.Comparator): # type: ignore[type-arg]
213
214         def intersects(self, other: SaColumn) -> 'sa.Operators':
215             if isinstance(self.expr, sa.Column):
216                 return Geometry_ColumnIntersectsBbox(self.expr, other)
217
218             return Geometry_IntersectsBbox(self.expr, other)
219
220
221         def is_line_like(self) -> SaColumn:
222             return Geometry_IsLineLike(self)
223
224
225         def is_area(self) -> SaColumn:
226             return Geometry_IsAreaLike(self)
227
228
229         def ST_DWithin(self, other: SaColumn, distance: SaColumn) -> SaColumn:
230             if isinstance(self.expr, sa.Column):
231                 return Geometry_ColumnDWithin(self.expr, other, distance)
232
233             return sa.func.ST_DWithin(self.expr, other, distance)
234
235
236         def ST_DWithin_no_index(self, other: SaColumn, distance: SaColumn) -> SaColumn:
237             return sa.func.ST_DWithin(sa.func.coalesce(sa.null(), self),
238                                       other, distance)
239
240
241         def ST_Intersects_no_index(self, other: SaColumn) -> 'sa.Operators':
242             return Geometry_IntersectsBbox(sa.func.coalesce(sa.null(), self), other)
243
244
245         def ST_Distance(self, other: SaColumn) -> SaColumn:
246             return sa.func.ST_Distance(self, other, type_=sa.Float)
247
248
249         def ST_Contains(self, other: SaColumn) -> SaColumn:
250             return sa.func.ST_Contains(self, other, type_=sa.Boolean)
251
252
253         def ST_CoveredBy(self, other: SaColumn) -> SaColumn:
254             return sa.func.ST_CoveredBy(self, other, type_=sa.Boolean)
255
256
257         def ST_ClosestPoint(self, other: SaColumn) -> SaColumn:
258             return sa.func.coalesce(sa.func.ST_ClosestPoint(self, other, type_=Geometry),
259                                     other)
260
261
262         def ST_Buffer(self, other: SaColumn) -> SaColumn:
263             return sa.func.ST_Buffer(self, other, type_=Geometry)
264
265
266         def ST_Expand(self, other: SaColumn) -> SaColumn:
267             return sa.func.ST_Expand(self, other, type_=Geometry)
268
269
270         def ST_Collect(self) -> SaColumn:
271             return sa.func.ST_Collect(self, type_=Geometry)
272
273
274         def ST_Centroid(self) -> SaColumn:
275             return sa.func.ST_Centroid(self, type_=Geometry)
276
277
278         def ST_LineInterpolatePoint(self, other: SaColumn) -> SaColumn:
279             return sa.func.ST_LineInterpolatePoint(self, other, type_=Geometry)
280
281
282         def ST_LineLocatePoint(self, other: SaColumn) -> SaColumn:
283             return sa.func.ST_LineLocatePoint(self, other, type_=sa.Float)
284
285
286         def distance_spheroid(self, other: SaColumn) -> SaColumn:
287             return Geometry_DistanceSpheroid(self, other)
288
289
290 @compiles(Geometry, 'sqlite') # type: ignore[no-untyped-call]
291 def get_col_spec(self, *args, **kwargs): # type: ignore[no-untyped-def]
292     return 'GEOMETRY'
293
294
295 SQLITE_FUNCTION_ALIAS = (
296     ('ST_AsEWKB', sa.Text, 'AsEWKB'),
297     ('ST_GeomFromEWKT', Geometry, 'GeomFromEWKT'),
298     ('ST_AsGeoJSON', sa.Text, 'AsGeoJSON'),
299     ('ST_AsKML', sa.Text, 'AsKML'),
300     ('ST_AsSVG', sa.Text, 'AsSVG'),
301     ('ST_LineLocatePoint', sa.Float, 'ST_Line_Locate_Point'),
302     ('ST_LineInterpolatePoint', sa.Float, 'ST_Line_Interpolate_Point'),
303 )
304
305 def _add_function_alias(func: str, ftype: type, alias: str) -> None:
306     _FuncDef = type(func, (sa.sql.functions.GenericFunction, ), {
307         "type": ftype(),
308         "name": func,
309         "identifier": func,
310         "inherit_cache": True})
311
312     func_templ = f"{alias}(%s)"
313
314     def _sqlite_impl(element: Any, compiler: Any, **kw: Any) -> Any:
315         return func_templ % compiler.process(element.clauses, **kw)
316
317     compiles(_FuncDef, 'sqlite')(_sqlite_impl) # type: ignore[no-untyped-call]
318
319 for alias in SQLITE_FUNCTION_ALIAS:
320     _add_function_alias(*alias)
321
322
323 class ST_DWithin(sa.sql.functions.GenericFunction[bool]):
324     type = sa.Boolean()
325     name = 'ST_DWithin'
326     inherit_cache = True
327
328
329 @compiles(ST_DWithin, 'sqlite') # type: ignore[no-untyped-call, misc]
330 def default_json_array_each(element: SaColumn, compiler: 'sa.Compiled', **kw: Any) -> str:
331     geom1, geom2, dist = list(element.clauses)
332     return "(MbrIntersects(%s, ST_Expand(%s, %s)) = 1 AND ST_Distance(%s, %s) <= %s)" % (
333         compiler.process(geom1, **kw), compiler.process(geom2, **kw),
334         compiler.process(dist, **kw),
335         compiler.process(geom1, **kw), compiler.process(geom2, **kw),
336         compiler.process(dist, **kw))