git.openstreetmap.org Git - nominatim.git/blob - nominatim/db/sqlalchemy_types.py
make details API work with sqlite incl. unit tests
[nominatim.git] / nominatim / db / sqlalchemy_types.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2023 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Custom types for SQLAlchemy.
9 """
10 from typing import Callable, Any, cast
11 import sys
12
13 import sqlalchemy as sa
14 from sqlalchemy.ext.compiler import compiles
15 from sqlalchemy import types
16
17 from nominatim.typing import SaColumn, SaBind
18
19 #pylint: disable=all
20
class Geometry_DistanceSpheroid(sa.sql.expression.FunctionElement[float]):
    """ Function to compute the distance in meters between two geometries,
        measured on the WGS84 spheroid (ellipsoid), not on a sphere.

        The actual SQL is produced per dialect by the @compiles handlers
        registered for this class: ST_DistanceSpheroid() with an explicit
        WGS84 spheroid definition for PostGIS, and Distance(..., true)
        for SQLite/SpatiaLite.
    """
    type = sa.Float()
    name = 'Geometry_DistanceSpheroid'
    inherit_cache = True
27
28
@compiles(Geometry_DistanceSpheroid) # type: ignore[no-untyped-call, misc]
def _default_distance_spheroid(element: SaColumn,
                               compiler: 'sa.Compiled', **kw: Any) -> str:
    """ Default (PostGIS) rendering of Geometry_DistanceSpheroid.

        The compiled function arguments (the two geometries) are followed
        by an explicit WGS84 spheroid definition for ST_DistanceSpheroid().
    """
    geometries = compiler.process(element.clauses, **kw)
    spheroid = "'SPHEROID[\"WGS 84\",6378137,298.257223563, AUTHORITY[\"EPSG\",\"7030\"]]'"
    return f"ST_DistanceSpheroid({geometries}, {spheroid})"
35
36
@compiles(Geometry_DistanceSpheroid, 'sqlite') # type: ignore[no-untyped-call, misc]
def _spatialite_distance_spheroid(element: SaColumn,
                                  compiler: 'sa.Compiled', **kw: Any) -> str:
    """ SQLite (SpatiaLite) rendering of Geometry_DistanceSpheroid.

        The trailing `true` is SpatiaLite's use_ellipsoid flag for
        Distance(), so the distance is measured on the ellipsoid.
    """
    geometries = compiler.process(element.clauses, **kw)
    return f"Distance({geometries}, true)"
41
42
class Geometry(types.UserDefinedType): # type: ignore[type-arg]
    """ Simplified SQLAlchemy type for PostGIS geometry columns.

        Only geometries in the WGS84 projection (SRID 4326) are supported.
        Bound values are sent as WKT text and wrapped in ST_GeomFromText();
        selected columns are wrapped in ST_AsEWKB().
    """
    cache_ok = True

    def __init__(self, subtype: str = 'Geometry') -> None:
        # Geometry subtype placed into the column definition, see get_col_spec().
        self.subtype = subtype


    def get_col_spec(self) -> str:
        """ Return the column type, e.g. `GEOMETRY(Point, 4326)`.
        """
        return 'GEOMETRY({}, 4326)'.format(self.subtype)


    def bind_processor(self, dialect: 'sa.Dialect') -> Callable[[Any], str]:
        """ Return a converter for values bound to parameters of this type.

            Strings are passed through unchanged (they end up as the text
            argument of ST_GeomFromText(), so they are expected to be WKT);
            any other value is converted by calling its `to_wkt()` method.
        """
        def _to_wkt(value: Any) -> str:
            return value if isinstance(value, str) else cast(str, value.to_wkt())

        return _to_wkt


    def result_processor(self, dialect: 'sa.Dialect', coltype: object) -> Callable[[Any], str]:
        """ Return a converter for values read back from the database.

            The value is handed through unchanged; the assertion checks
            (unless Python runs with -O) that the driver delivered a string.
        """
        def _check_text(value: Any) -> str:
            assert isinstance(value, str)
            return value

        return _check_text


    def column_expression(self, col: SaColumn) -> SaColumn:
        """ Wrap selected geometry columns in ST_AsEWKB().
        """
        return sa.func.ST_AsEWKB(col)


    def bind_expression(self, bindvalue: SaBind) -> SaColumn:
        """ Wrap bound parameters in ST_GeomFromText() with SRID 4326.
        """
        srid = sa.text('4326')
        return sa.func.ST_GeomFromText(bindvalue, srid, type_=self)


    class comparator_factory(types.UserDefinedType.Comparator): # type: ignore[type-arg]
        """ Geometry-specific SQL operations available on expressions
            of type Geometry.
        """

        def intersects(self, other: SaColumn) -> 'sa.Operators':
            """ Bounding-box overlap test using the `&&` operator.
            """
            bbox_overlap = self.op('&&')
            return bbox_overlap(other)

        def is_line_like(self) -> SaColumn:
            """ True when the geometry is a linestring or multilinestring.
            """
            line_types = ('ST_LineString', 'ST_MultiLineString')
            geom_type = sa.func.ST_GeometryType(self, type_=sa.String)
            return geom_type.in_(line_types)

        def is_area(self) -> SaColumn:
            """ True when the geometry is a polygon or multipolygon.
            """
            area_types = ('ST_Polygon', 'ST_MultiPolygon')
            geom_type = sa.func.ST_GeometryType(self, type_=sa.String)
            return geom_type.in_(area_types)


        def ST_DWithin(self, other: SaColumn, distance: SaColumn) -> SaColumn:
            """ Boolean ST_DWithin() test against `other` with `distance`.
            """
            return sa.func.ST_DWithin(self, other, distance, type_=sa.Boolean)


        def ST_DWithin_no_index(self, other: SaColumn, distance: SaColumn) -> SaColumn:
            """ Same as ST_DWithin(), but this geometry is wrapped in
                coalesce(NULL, ...), presumably so that the planner cannot
                use an index on it (as the method name suggests).
            """
            unindexed = sa.func.coalesce(sa.null(), self)
            return sa.func.ST_DWithin(unindexed, other, distance, type_=sa.Boolean)


        def ST_Intersects_no_index(self, other: SaColumn) -> 'sa.Operators':
            """ Same as intersects(), but this geometry is wrapped in
                coalesce(NULL, ...), presumably to keep the planner from
                using an index on it.
            """
            unindexed = sa.func.coalesce(sa.null(), self)
            return unindexed.op('&&')(other)


        def ST_Distance(self, other: SaColumn) -> SaColumn:
            """ ST_Distance() to `other` as a float expression.
            """
            return sa.func.ST_Distance(self, other, type_=sa.Float)


        def ST_Contains(self, other: SaColumn) -> SaColumn:
            """ Boolean ST_Contains() test: does this geometry contain `other`.
            """
            return sa.func.ST_Contains(self, other, type_=sa.Boolean)


        def ST_CoveredBy(self, other: SaColumn) -> SaColumn:
            """ Boolean ST_CoveredBy() test: is this geometry covered by `other`.
            """
            return sa.func.ST_CoveredBy(self, other, type_=sa.Boolean)


        def ST_ClosestPoint(self, other: SaColumn) -> SaColumn:
            """ ST_ClosestPoint() of this geometry towards `other`.
            """
            return sa.func.ST_ClosestPoint(self, other, type_=Geometry)


        def ST_Buffer(self, other: SaColumn) -> SaColumn:
            """ ST_Buffer() of this geometry; `other` is the buffer radius.
            """
            return sa.func.ST_Buffer(self, other, type_=Geometry)


        def ST_Expand(self, other: SaColumn) -> SaColumn:
            """ ST_Expand() of this geometry by the amount `other`.
            """
            return sa.func.ST_Expand(self, other, type_=Geometry)


        def ST_Collect(self) -> SaColumn:
            """ ST_Collect() over this geometry expression.
            """
            return sa.func.ST_Collect(self, type_=Geometry)


        def ST_Centroid(self) -> SaColumn:
            """ ST_Centroid() of this geometry.
            """
            return sa.func.ST_Centroid(self, type_=Geometry)


        def ST_LineInterpolatePoint(self, other: SaColumn) -> SaColumn:
            """ ST_LineInterpolatePoint() on this line at fraction `other`.
            """
            return sa.func.ST_LineInterpolatePoint(self, other, type_=Geometry)


        def ST_LineLocatePoint(self, other: SaColumn) -> SaColumn:
            """ ST_LineLocatePoint() of point `other` on this line, as float.
            """
            return sa.func.ST_LineLocatePoint(self, other, type_=sa.Float)


        def distance_spheroid(self, other: SaColumn) -> SaColumn:
            """ Distance in meters to `other` on the WGS84 spheroid,
                see Geometry_DistanceSpheroid.
            """
            return Geometry_DistanceSpheroid(self, other)
150
151
@compiles(Geometry, 'sqlite') # type: ignore[no-untyped-call]
def get_col_spec(self, *args, **kwargs): # type: ignore[no-untyped-def]
    """ Render the Geometry type for the SQLite dialect.

        Despite its name and the `self` parameter, this is a module-level
        SQLAlchemy compiler hook: `self` receives the Geometry type being
        compiled. SQLite only gets the bare type name; the subtype and
        SRID used by Geometry.get_col_spec() are dropped.
    """
    sqlite_type_name = 'GEOMETRY'
    return sqlite_type_name
155
156
# Functions which the SQLite dialect must emit under a different name.
# Each entry is a triple of
#   (function name used in queries via sa.func,
#    SQLAlchemy result type,
#    function name emitted for SQLite/SpatiaLite).
# The entries are registered by _add_function_alias() further down.
SQLITE_FUNCTION_ALIAS = (
    ('ST_AsEWKB',       sa.Text,  'AsEWKB'),
    ('ST_GeomFromEWKT', Geometry, 'GeomFromEWKT'),
    ('ST_AsGeoJSON',    sa.Text,  'AsGeoJSON'),
    ('ST_AsKML',        sa.Text,  'AsKML'),
    ('ST_AsSVG',        sa.Text,  'AsSVG'),
)
164
def _add_function_alias(func: str, ftype: type, alias: str) -> None:
    """ Declare `func` as a SQLAlchemy generic function with result type
        `ftype` and make the SQLite dialect emit it as `alias(...)`.

        No compiler override is registered for other dialects, so they
        keep emitting the function under its original name `func`.
    """
    attributes = {
        "type": ftype,
        "name": func,
        "identifier": func,
        "inherit_cache": True,
    }
    # Subclassing GenericFunction registers the class under `identifier`,
    # so sa.func.<func>(...) produces an instance of it.
    func_class = type(func, (sa.sql.functions.GenericFunction, ), attributes)

    sqlite_template = f"{alias}(%s)"

    def render_for_sqlite(element: Any, compiler: Any, **kw: Any) -> Any:
        arguments = compiler.process(element.clauses, **kw)
        return sqlite_template % arguments

    compiles(func_class, 'sqlite')(render_for_sqlite) # type: ignore[no-untyped-call]
178
# Register all SQLite function aliases at import time. The loop variable
# is private so that the module does not expose a stray public global
# (and does not reuse the `alias` parameter name of _add_function_alias).
for _alias_def in SQLITE_FUNCTION_ALIAS:
    _add_function_alias(*_alias_def)
181
182
183