]> git.openstreetmap.org Git - nominatim.git/blob - nominatim/db/sqlalchemy_types.py
adapt typing to newest version of SQLAlchemy
[nominatim.git] / nominatim / db / sqlalchemy_types.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2023 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Custom types for SQLAlchemy.
9 """
10 from __future__ import annotations
11 from typing import Callable, Any, cast
12 import sys
13
14 import sqlalchemy as sa
15 from sqlalchemy.ext.compiler import compiles
16 from sqlalchemy import types
17
18 from nominatim.typing import SaColumn, SaBind
19
20 #pylint: disable=all
21
22 class Geometry_DistanceSpheroid(sa.sql.expression.FunctionElement[float]):
23     """ Function to compute the spherical distance in meters.
24     """
25     type = sa.Float()
26     name = 'Geometry_DistanceSpheroid'
27     inherit_cache = True
28
29
30 @compiles(Geometry_DistanceSpheroid) # type: ignore[no-untyped-call, misc]
31 def _default_distance_spheroid(element: SaColumn,
32                                compiler: 'sa.Compiled', **kw: Any) -> str:
33     return "ST_DistanceSpheroid(%s,"\
34            " 'SPHEROID[\"WGS 84\",6378137,298.257223563, AUTHORITY[\"EPSG\",\"7030\"]]')"\
35              % compiler.process(element.clauses, **kw)
36
37
38 @compiles(Geometry_DistanceSpheroid, 'sqlite') # type: ignore[no-untyped-call, misc]
39 def _spatialite_distance_spheroid(element: SaColumn,
40                                   compiler: 'sa.Compiled', **kw: Any) -> str:
41     return "COALESCE(Distance(%s, true), 0.0)" % compiler.process(element.clauses, **kw)
42
43
44 class Geometry_IsLineLike(sa.sql.expression.FunctionElement[Any]):
45     """ Check if the geometry is a line or multiline.
46     """
47     name = 'Geometry_IsLineLike'
48     inherit_cache = True
49
50
51 @compiles(Geometry_IsLineLike) # type: ignore[no-untyped-call, misc]
52 def _default_is_line_like(element: SaColumn,
53                           compiler: 'sa.Compiled', **kw: Any) -> str:
54     return "ST_GeometryType(%s) IN ('ST_LineString', 'ST_MultiLineString')" % \
55                compiler.process(element.clauses, **kw)
56
57
58 @compiles(Geometry_IsLineLike, 'sqlite') # type: ignore[no-untyped-call, misc]
59 def _sqlite_is_line_like(element: SaColumn,
60                          compiler: 'sa.Compiled', **kw: Any) -> str:
61     return "ST_GeometryType(%s) IN ('LINESTRING', 'MULTILINESTRING')" % \
62                compiler.process(element.clauses, **kw)
63
64
65 class Geometry_IsAreaLike(sa.sql.expression.FunctionElement[Any]):
66     """ Check if the geometry is a polygon or multipolygon.
67     """
68     name = 'Geometry_IsLineLike'
69     inherit_cache = True
70
71
72 @compiles(Geometry_IsAreaLike) # type: ignore[no-untyped-call, misc]
73 def _default_is_area_like(element: SaColumn,
74                           compiler: 'sa.Compiled', **kw: Any) -> str:
75     return "ST_GeometryType(%s) IN ('ST_Polygon', 'ST_MultiPolygon')" % \
76                compiler.process(element.clauses, **kw)
77
78
79 @compiles(Geometry_IsAreaLike, 'sqlite') # type: ignore[no-untyped-call, misc]
80 def _sqlite_is_area_like(element: SaColumn,
81                          compiler: 'sa.Compiled', **kw: Any) -> str:
82     return "ST_GeometryType(%s) IN ('POLYGON', 'MULTIPOLYGON')" % \
83                compiler.process(element.clauses, **kw)
84
85
86 class Geometry_IntersectsBbox(sa.sql.expression.FunctionElement[Any]):
87     """ Check if the bounding boxes of the given geometries intersect.
88     """
89     name = 'Geometry_IntersectsBbox'
90     inherit_cache = True
91
92
93 @compiles(Geometry_IntersectsBbox) # type: ignore[no-untyped-call, misc]
94 def _default_intersects(element: SaColumn,
95                         compiler: 'sa.Compiled', **kw: Any) -> str:
96     arg1, arg2 = list(element.clauses)
97     return "%s && %s" % (compiler.process(arg1, **kw), compiler.process(arg2, **kw))
98
99
100 @compiles(Geometry_IntersectsBbox, 'sqlite') # type: ignore[no-untyped-call, misc]
101 def _sqlite_intersects(element: SaColumn,
102                        compiler: 'sa.Compiled', **kw: Any) -> str:
103     return "MbrIntersects(%s) = 1" % compiler.process(element.clauses, **kw)
104
105
106 class Geometry_ColumnIntersectsBbox(sa.sql.expression.FunctionElement[Any]):
107     """ Check if the bounding box of the geometry intersects with the
108         given table column, using the spatial index for the column.
109
110         The index must exist or the query may return nothing.
111     """
112     name = 'Geometry_ColumnIntersectsBbox'
113     inherit_cache = True
114
115
116 @compiles(Geometry_ColumnIntersectsBbox) # type: ignore[no-untyped-call, misc]
117 def default_intersects_column(element: SaColumn,
118                               compiler: 'sa.Compiled', **kw: Any) -> str:
119     arg1, arg2 = list(element.clauses)
120     return "%s && %s" % (compiler.process(arg1, **kw), compiler.process(arg2, **kw))
121
122
123 @compiles(Geometry_ColumnIntersectsBbox, 'sqlite') # type: ignore[no-untyped-call, misc]
124 def spatialite_intersects_column(element: SaColumn,
125                                  compiler: 'sa.Compiled', **kw: Any) -> str:
126     arg1, arg2 = list(element.clauses)
127     return "MbrIntersects(%s, %s) = 1 and "\
128            "%s.ROWID IN (SELECT ROWID FROM SpatialIndex "\
129                         "WHERE f_table_name = '%s' AND f_geometry_column = '%s' "\
130                         "AND search_frame = %s)" %(
131               compiler.process(arg1, **kw),
132               compiler.process(arg2, **kw),
133               arg1.table.name, arg1.table.name, arg1.name,
134               compiler.process(arg2, **kw))
135
136
137 class Geometry_ColumnDWithin(sa.sql.expression.FunctionElement[Any]):
138     """ Check if the geometry is within the distance of the
139         given table column, using the spatial index for the column.
140
141         The index must exist or the query may return nothing.
142     """
143     name = 'Geometry_ColumnDWithin'
144     inherit_cache = True
145
146
147 @compiles(Geometry_ColumnDWithin) # type: ignore[no-untyped-call, misc]
148 def default_dwithin_column(element: SaColumn,
149                            compiler: 'sa.Compiled', **kw: Any) -> str:
150     return "ST_DWithin(%s)" % compiler.process(element.clauses, **kw)
151
152 @compiles(Geometry_ColumnDWithin, 'sqlite') # type: ignore[no-untyped-call, misc]
153 def spatialite_dwithin_column(element: SaColumn,
154                               compiler: 'sa.Compiled', **kw: Any) -> str:
155     geom1, geom2, dist = list(element.clauses)
156     return "ST_Distance(%s, %s) < %s and "\
157            "%s.ROWID IN (SELECT ROWID FROM SpatialIndex "\
158                         "WHERE f_table_name = '%s' AND f_geometry_column = '%s' "\
159                         "AND search_frame = ST_Expand(%s, %s))" %(
160               compiler.process(geom1, **kw),
161               compiler.process(geom2, **kw),
162               compiler.process(dist, **kw),
163               geom1.table.name, geom1.table.name, geom1.name,
164               compiler.process(geom2, **kw),
165               compiler.process(dist, **kw))
166
167
168
169 class Geometry(types.UserDefinedType): # type: ignore[type-arg]
170     """ Simplified type decorator for PostGIS geometry. This type
171         only supports geometries in 4326 projection.
172     """
173     cache_ok = True
174
175     def __init__(self, subtype: str = 'Geometry'):
176         self.subtype = subtype
177
178
179     def get_col_spec(self) -> str:
180         return f'GEOMETRY({self.subtype}, 4326)'
181
182
183     def bind_processor(self, dialect: 'sa.Dialect') -> Callable[[Any], str]:
184         def process(value: Any) -> str:
185             if isinstance(value, str):
186                 return value
187
188             return cast(str, value.to_wkt())
189         return process
190
191
192     def result_processor(self, dialect: 'sa.Dialect', coltype: object) -> Callable[[Any], str]:
193         def process(value: Any) -> str:
194             assert isinstance(value, str)
195             return value
196         return process
197
198
199     def column_expression(self, col: SaColumn) -> SaColumn:
200         return sa.func.ST_AsEWKB(col)
201
202
203     def bind_expression(self, bindvalue: SaBind) -> SaColumn:
204         return sa.func.ST_GeomFromText(bindvalue, sa.text('4326'), type_=self)
205
206
207     class comparator_factory(types.UserDefinedType.Comparator): # type: ignore[type-arg]
208
209         def intersects(self, other: SaColumn) -> 'sa.Operators':
210             if isinstance(self.expr, sa.Column):
211                 return Geometry_ColumnIntersectsBbox(self.expr, other)
212
213             return Geometry_IntersectsBbox(self.expr, other)
214
215
216         def is_line_like(self) -> SaColumn:
217             return Geometry_IsLineLike(self)
218
219
220         def is_area(self) -> SaColumn:
221             return Geometry_IsAreaLike(self)
222
223
224         def ST_DWithin(self, other: SaColumn, distance: SaColumn) -> SaColumn:
225             if isinstance(self.expr, sa.Column):
226                 return Geometry_ColumnDWithin(self.expr, other, distance)
227
228             return sa.func.ST_DWithin(self.expr, other, distance)
229
230
231         def ST_DWithin_no_index(self, other: SaColumn, distance: SaColumn) -> SaColumn:
232             return sa.func.ST_DWithin(sa.func.coalesce(sa.null(), self),
233                                       other, distance)
234
235
236         def ST_Intersects_no_index(self, other: SaColumn) -> 'sa.Operators':
237             return Geometry_IntersectsBbox(sa.func.coalesce(sa.null(), self), other)
238
239
240         def ST_Distance(self, other: SaColumn) -> SaColumn:
241             return sa.func.ST_Distance(self, other, type_=sa.Float)
242
243
244         def ST_Contains(self, other: SaColumn) -> SaColumn:
245             return sa.func.ST_Contains(self, other, type_=sa.Boolean)
246
247
248         def ST_CoveredBy(self, other: SaColumn) -> SaColumn:
249             return sa.func.ST_CoveredBy(self, other, type_=sa.Boolean)
250
251
252         def ST_ClosestPoint(self, other: SaColumn) -> SaColumn:
253             return sa.func.coalesce(sa.func.ST_ClosestPoint(self, other, type_=Geometry),
254                                     other)
255
256
257         def ST_Buffer(self, other: SaColumn) -> SaColumn:
258             return sa.func.ST_Buffer(self, other, type_=Geometry)
259
260
261         def ST_Expand(self, other: SaColumn) -> SaColumn:
262             return sa.func.ST_Expand(self, other, type_=Geometry)
263
264
265         def ST_Collect(self) -> SaColumn:
266             return sa.func.ST_Collect(self, type_=Geometry)
267
268
269         def ST_Centroid(self) -> SaColumn:
270             return sa.func.ST_Centroid(self, type_=Geometry)
271
272
273         def ST_LineInterpolatePoint(self, other: SaColumn) -> SaColumn:
274             return sa.func.ST_LineInterpolatePoint(self, other, type_=Geometry)
275
276
277         def ST_LineLocatePoint(self, other: SaColumn) -> SaColumn:
278             return sa.func.ST_LineLocatePoint(self, other, type_=sa.Float)
279
280
281         def distance_spheroid(self, other: SaColumn) -> SaColumn:
282             return Geometry_DistanceSpheroid(self, other)
283
284
285 @compiles(Geometry, 'sqlite') # type: ignore[no-untyped-call]
286 def get_col_spec(self, *args, **kwargs): # type: ignore[no-untyped-def]
287     return 'GEOMETRY'
288
289
290 SQLITE_FUNCTION_ALIAS = (
291     ('ST_AsEWKB', sa.Text, 'AsEWKB'),
292     ('ST_GeomFromEWKT', Geometry, 'GeomFromEWKT'),
293     ('ST_AsGeoJSON', sa.Text, 'AsGeoJSON'),
294     ('ST_AsKML', sa.Text, 'AsKML'),
295     ('ST_AsSVG', sa.Text, 'AsSVG'),
296     ('ST_LineLocatePoint', sa.Float, 'ST_Line_Locate_Point'),
297     ('ST_LineInterpolatePoint', sa.Float, 'ST_Line_Interpolate_Point'),
298 )
299
300 def _add_function_alias(func: str, ftype: type, alias: str) -> None:
301     _FuncDef = type(func, (sa.sql.functions.GenericFunction, ), {
302         "type": ftype(),
303         "name": func,
304         "identifier": func,
305         "inherit_cache": True})
306
307     func_templ = f"{alias}(%s)"
308
309     def _sqlite_impl(element: Any, compiler: Any, **kw: Any) -> Any:
310         return func_templ % compiler.process(element.clauses, **kw)
311
312     compiles(_FuncDef, 'sqlite')(_sqlite_impl) # type: ignore[no-untyped-call]
313
314 for alias in SQLITE_FUNCTION_ALIAS:
315     _add_function_alias(*alias)
316
317
318 class ST_DWithin(sa.sql.functions.GenericFunction[Any]):
319     name = 'ST_DWithin'
320     inherit_cache = True
321
322
323 @compiles(ST_DWithin, 'sqlite') # type: ignore[no-untyped-call, misc]
324 def default_json_array_each(element: SaColumn, compiler: 'sa.Compiled', **kw: Any) -> str:
325     geom1, geom2, dist = list(element.clauses)
326     return "(MbrIntersects(%s, ST_Expand(%s, %s)) = 1 AND ST_Distance(%s, %s) <= %s)" % (
327         compiler.process(geom1, **kw), compiler.process(geom2, **kw),
328         compiler.process(dist, **kw),
329         compiler.process(geom1, **kw), compiler.process(geom2, **kw),
330         compiler.process(dist, **kw))