git.openstreetmap.org Git - nominatim.git/blobdiff - nominatim/db/sqlalchemy_types.py
Added missing return types to functions
[nominatim.git] / nominatim / db / sqlalchemy_types.py
index e34dc7c188106c8f031c6e53474264409664a324..a36e8c462acfce3b4cc5e730b2eb5c008f1dfa14 100644 (file)
 """
 Custom types for SQLAlchemy.
 """
-from typing import Callable, Any
+from __future__ import annotations
+from typing import Callable, Any, cast
+import sys
 
 import sqlalchemy as sa
-import sqlalchemy.types as types
+from sqlalchemy.ext.compiler import compiles
+from sqlalchemy import types
 
-from nominatim.typing import SaColumn
+from nominatim.typing import SaColumn, SaBind
 
-class Geometry(types.UserDefinedType[Any]):
+#pylint: disable=all
+
+class Geometry_DistanceSpheroid(sa.sql.expression.FunctionElement[float]):
+    """ Function element for the distance between geometries in meters.
+        The SQL is produced by the dialect-specific compilations
+        registered for this class.
+    """
+    name = 'Geometry_DistanceSpheroid'
+    inherit_cache = True
+    type = sa.Float()
+
+
+@compiles(Geometry_DistanceSpheroid) # type: ignore[no-untyped-call, misc]
+def _default_distance_spheroid(element: SaColumn,
+                               compiler: 'sa.Compiled', **kw: Any) -> str:
+    """ Default compilation: render a call to ST_DistanceSpheroid()
+        on the element's arguments with a fixed WGS 84 spheroid.
+    """
+    template = ("ST_DistanceSpheroid(%s,"
+                " 'SPHEROID[\"WGS 84\",6378137,298.257223563, AUTHORITY[\"EPSG\",\"7030\"]]')")
+    return template % compiler.process(element.clauses, **kw)
+
+
+@compiles(Geometry_DistanceSpheroid, 'sqlite') # type: ignore[no-untyped-call, misc]
+def _spatialite_distance_spheroid(element: SaColumn,
+                                  compiler: 'sa.Compiled', **kw: Any) -> str:
+    """ Compilation for the 'sqlite' dialect: render Distance(..., true)
+        and substitute 0.0 when the result is NULL.
+    """
+    template = "COALESCE(Distance(%s, true), 0.0)"
+    return template % compiler.process(element.clauses, **kw)
+
+
+class Geometry_IsLineLike(sa.sql.expression.FunctionElement[Any]):
+    """ Test whether a geometry is a line string or a
+        multi line string.
+    """
+    inherit_cache = True
+    name = 'Geometry_IsLineLike'
+
+
+@compiles(Geometry_IsLineLike) # type: ignore[no-untyped-call, misc]
+def _default_is_line_like(element: SaColumn,
+                          compiler: 'sa.Compiled', **kw: Any) -> str:
+    """ Default compilation: compare ST_GeometryType() of the argument
+        with the 'ST_'-prefixed names of the line types.
+    """
+    template = "ST_GeometryType(%s) IN ('ST_LineString', 'ST_MultiLineString')"
+    return template % compiler.process(element.clauses, **kw)
+
+
+@compiles(Geometry_IsLineLike, 'sqlite') # type: ignore[no-untyped-call, misc]
+def _sqlite_is_line_like(element: SaColumn,
+                         compiler: 'sa.Compiled', **kw: Any) -> str:
+    """ Compilation for the 'sqlite' dialect: compare ST_GeometryType()
+        of the argument with the upper-case names of the line types.
+    """
+    template = "ST_GeometryType(%s) IN ('LINESTRING', 'MULTILINESTRING')"
+    return template % compiler.process(element.clauses, **kw)
+
+
+class Geometry_IsAreaLike(sa.sql.expression.FunctionElement[Any]):
+    """ Check if the geometry is a polygon or multipolygon.
+    """
+    # The name follows the class name, like for all other function
+    # elements in this module.
+    name = 'Geometry_IsAreaLike'
+    inherit_cache = True
+
+
+@compiles(Geometry_IsAreaLike) # type: ignore[no-untyped-call, misc]
+def _default_is_area_like(element: SaColumn,
+                          compiler: 'sa.Compiled', **kw: Any) -> str:
+    """ Default compilation: compare ST_GeometryType() of the argument
+        with the 'ST_'-prefixed names of the polygon types.
+    """
+    template = "ST_GeometryType(%s) IN ('ST_Polygon', 'ST_MultiPolygon')"
+    return template % compiler.process(element.clauses, **kw)
+
+
+@compiles(Geometry_IsAreaLike, 'sqlite') # type: ignore[no-untyped-call, misc]
+def _sqlite_is_area_like(element: SaColumn,
+                         compiler: 'sa.Compiled', **kw: Any) -> str:
+    """ Compilation for the 'sqlite' dialect: compare ST_GeometryType()
+        of the argument with the upper-case names of the polygon types.
+    """
+    template = "ST_GeometryType(%s) IN ('POLYGON', 'MULTIPOLYGON')"
+    return template % compiler.process(element.clauses, **kw)
+
+
+class Geometry_IntersectsBbox(sa.sql.expression.FunctionElement[Any]):
+    """ Test whether the bounding boxes of two geometries intersect.
+    """
+    inherit_cache = True
+    name = 'Geometry_IntersectsBbox'
+
+
+@compiles(Geometry_IntersectsBbox) # type: ignore[no-untyped-call, misc]
+def _default_intersects(element: SaColumn,
+                        compiler: 'sa.Compiled', **kw: Any) -> str:
+    """ Default compilation: join the two arguments with the
+        '&&' operator.
+    """
+    first, second = element.clauses
+    rendered = (compiler.process(first, **kw), compiler.process(second, **kw))
+    return "%s && %s" % rendered
+
+
+@compiles(Geometry_IntersectsBbox, 'sqlite') # type: ignore[no-untyped-call, misc]
+def _sqlite_intersects(element: SaColumn,
+                       compiler: 'sa.Compiled', **kw: Any) -> str:
+    """ Compilation for the 'sqlite' dialect: test the result of
+        MbrIntersects() on the arguments against 1.
+    """
+    template = "MbrIntersects(%s) = 1"
+    return template % compiler.process(element.clauses, **kw)
+
+
+class Geometry_ColumnIntersectsBbox(sa.sql.expression.FunctionElement[Any]):
+    """ Test whether the bounding box of a geometry intersects with a
+        table column. The test goes through the spatial index of
+        that column.
+
+        Without the index the query may come back empty.
+    """
+    inherit_cache = True
+    name = 'Geometry_ColumnIntersectsBbox'
+
+
+@compiles(Geometry_ColumnIntersectsBbox) # type: ignore[no-untyped-call, misc]
+def default_intersects_column(element: SaColumn,
+                              compiler: 'sa.Compiled', **kw: Any) -> str:
+    """ Default compilation: join the two arguments with the
+        '&&' operator.
+    """
+    column, geometry = element.clauses
+    rendered = (compiler.process(column, **kw), compiler.process(geometry, **kw))
+    return "%s && %s" % rendered
+
+
+@compiles(Geometry_ColumnIntersectsBbox, 'sqlite') # type: ignore[no-untyped-call, misc]
+def spatialite_intersects_column(element: SaColumn,
+                                 compiler: 'sa.Compiled', **kw: Any) -> str:
+    """ Compilation for the 'sqlite' dialect: an MbrIntersects() test
+        combined with a ROWID lookup in the SpatialIndex table for the
+        table and column of the first argument.
+    """
+    column, geometry = element.clauses
+    template = ("MbrIntersects(%s, %s) = 1 and "
+                "%s.ROWID IN (SELECT ROWID FROM SpatialIndex "
+                "WHERE f_table_name = '%s' AND f_geometry_column = '%s' "
+                "AND search_frame = %s)")
+    # The geometry is compiled again for search_frame instead of reusing
+    # the first result; presumably every occurrence has to go through the
+    # compiler on its own -- confirm before merging the two calls.
+    values = (compiler.process(column, **kw),
+              compiler.process(geometry, **kw),
+              column.table.name, column.table.name, column.name,
+              compiler.process(geometry, **kw))
+    return template % values
+
+
+class Geometry_ColumnDWithin(sa.sql.expression.FunctionElement[Any]):
+    """ Test whether a geometry lies within a given distance of a
+        table column. The test goes through the spatial index of
+        that column.
+
+        Without the index the query may come back empty.
+    """
+    inherit_cache = True
+    name = 'Geometry_ColumnDWithin'
+
+
+@compiles(Geometry_ColumnDWithin) # type: ignore[no-untyped-call, misc]
+def default_dwithin_column(element: SaColumn,
+                           compiler: 'sa.Compiled', **kw: Any) -> str:
+    """ Default compilation: render a call to ST_DWithin() on the
+        element's arguments.
+    """
+    template = "ST_DWithin(%s)"
+    return template % compiler.process(element.clauses, **kw)
+
+@compiles(Geometry_ColumnDWithin, 'sqlite') # type: ignore[no-untyped-call, misc]
+def spatialite_dwithin_column(element: SaColumn,
+                              compiler: 'sa.Compiled', **kw: Any) -> str:
+    """ Compilation for the 'sqlite' dialect: an ST_Distance() comparison
+        combined with a ROWID lookup in the SpatialIndex table for the
+        table and column of the first argument.
+
+        NOTE(review): the comparison is a strict '<' while the 'sqlite'
+        compilation of ST_DWithin in this module uses '<=' -- confirm
+        which one is intended.
+    """
+    column, geometry, distance = element.clauses
+    template = ("ST_Distance(%s, %s) < %s and "
+                "%s.ROWID IN (SELECT ROWID FROM SpatialIndex "
+                "WHERE f_table_name = '%s' AND f_geometry_column = '%s' "
+                "AND search_frame = ST_Expand(%s, %s))")
+    # Geometry and distance are compiled again for search_frame instead
+    # of reusing the first results; presumably every occurrence has to go
+    # through the compiler on its own -- confirm before merging the calls.
+    values = (compiler.process(column, **kw),
+              compiler.process(geometry, **kw),
+              compiler.process(distance, **kw),
+              column.table.name, column.table.name, column.name,
+              compiler.process(geometry, **kw),
+              compiler.process(distance, **kw))
+    return template % values
+
+
+
+class Geometry(types.UserDefinedType): # type: ignore[type-arg]
     """ Simplified type decorator for PostGIS geometry. This type
         only supports geometries in 4326 projection.
     """
@@ -28,42 +180,61 @@ class Geometry(types.UserDefinedType[Any]):
         return f'GEOMETRY({self.subtype}, 4326)'
 
 
-    def bind_processor(self, dialect: sa.Dialect) -> Callable[[Any], str]:
+    def bind_processor(self, dialect: 'sa.Dialect') -> Callable[[Any], str]:
         def process(value: Any) -> str:
             if isinstance(value, str):
-                return 'SRID=4326;' + value
+                return value
 
-            return 'SRID=4326;' + value.to_wkt()
+            return cast(str, value.to_wkt())
         return process
 
 
-    def result_processor(self, dialect: sa.Dialect, coltype: object) -> Callable[[Any], str]:
+    def result_processor(self, dialect: 'sa.Dialect', coltype: object) -> Callable[[Any], str]:
         def process(value: Any) -> str:
             assert isinstance(value, str)
             return value
         return process
 
 
-    def bind_expression(self, bindvalue: sa.BindParameter[Any]) -> SaColumn:
-        return sa.func.ST_GeomFromText(bindvalue, type_=self)
+    def column_expression(self, col: SaColumn) -> SaColumn:
+        """ Return the column wrapped in a call to ST_AsEWKB().
+        """
+        return sa.func.ST_AsEWKB(col)
 
 
-    class comparator_factory(types.UserDefinedType.Comparator):
+    def bind_expression(self, bindvalue: SaBind) -> SaColumn:
+        """ Return the bound value wrapped in a call to ST_GeomFromText()
+            with the literal 4326 as second argument. The result has
+            this type.
+        """
+        return sa.func.ST_GeomFromText(bindvalue, sa.text('4326'), type_=self)
+
+
+    class comparator_factory(types.UserDefinedType.Comparator): # type: ignore[type-arg]
+
+        def intersects(self, other: SaColumn) -> 'sa.Operators':
+            """ Return a bounding box intersection test with `other`.
+
+                Expressions that are table columns get the column-specific
+                element, everything else the generic one.
+            """
+            if not isinstance(self.expr, sa.Column):
+                return Geometry_IntersectsBbox(self.expr, other)
+
+            return Geometry_ColumnIntersectsBbox(self.expr, other)
 
-        def intersects(self, other: SaColumn) -> SaColumn:
-            return self.op('&&')(other)
 
         def is_line_like(self) -> SaColumn:
-            return sa.func.ST_GeometryType(self, type_=sa.String).in_(('ST_LineString',
-                                                                       'ST_MultiLineString'))
+            return Geometry_IsLineLike(self)
+
 
         def is_area(self) -> SaColumn:
-            return sa.func.ST_GeometryType(self, type_=sa.String).in_(('ST_Polygon',
-                                                                       'ST_MultiPolygon'))
+            return Geometry_IsAreaLike(self)
 
 
         def ST_DWithin(self, other: SaColumn, distance: SaColumn) -> SaColumn:
-            return sa.func.ST_DWithin(self, other, distance, type_=sa.Float)
+            if isinstance(self.expr, sa.Column):
+                return Geometry_ColumnDWithin(self.expr, other, distance)
+
+            return sa.func.ST_DWithin(self.expr, other, distance)
+
+
+        def ST_DWithin_no_index(self, other: SaColumn, distance: SaColumn) -> SaColumn:
+            """ Call ST_DWithin() with this geometry wrapped in
+                coalesce(NULL, ...).
+            """
+            # presumably the coalesce() wrapper keeps the database from
+            # using the spatial index on the column -- TODO confirm
+            return sa.func.ST_DWithin(sa.func.coalesce(sa.null(), self),
+                                      other, distance)
+
+
+        def ST_Intersects_no_index(self, other: SaColumn) -> 'sa.Operators':
+            """ Return a Geometry_IntersectsBbox element with this geometry
+                wrapped in coalesce(NULL, ...).
+            """
+            # presumably the coalesce() wrapper keeps the database from
+            # using the spatial index on the column -- TODO confirm
+            return Geometry_IntersectsBbox(sa.func.coalesce(sa.null(), self), other)
 
 
         def ST_Distance(self, other: SaColumn) -> SaColumn:
@@ -71,11 +242,16 @@ class Geometry(types.UserDefinedType[Any]):
 
 
         def ST_Contains(self, other: SaColumn) -> SaColumn:
-            return sa.func.ST_Contains(self, other, type_=sa.Float)
+            return sa.func.ST_Contains(self, other, type_=sa.Boolean)
+
+
+        def ST_CoveredBy(self, other: SaColumn) -> SaColumn:
+            """ Call ST_CoveredBy() on this geometry and `other`.
+                The result is typed as boolean.
+            """
+            return sa.func.ST_CoveredBy(self, other, type_=sa.Boolean)
 
 
         def ST_ClosestPoint(self, other: SaColumn) -> SaColumn:
-            return sa.func.ST_ClosestPoint(self, other, type_=Geometry)
+            return sa.func.coalesce(sa.func.ST_ClosestPoint(self, other, type_=Geometry),
+                                    other)
 
 
         def ST_Buffer(self, other: SaColumn) -> SaColumn:
@@ -100,3 +276,55 @@ class Geometry(types.UserDefinedType[Any]):
 
         def ST_LineLocatePoint(self, other: SaColumn) -> SaColumn:
             return sa.func.ST_LineLocatePoint(self, other, type_=sa.Float)
+
+
+        def distance_spheroid(self, other: SaColumn) -> SaColumn:
+            """ Return a Geometry_DistanceSpheroid element for this
+                geometry and `other`.
+            """
+            return Geometry_DistanceSpheroid(self, other)
+
+
+@compiles(Geometry, 'sqlite') # type: ignore[no-untyped-call]
+def get_col_spec(self, *args, **kwargs): # type: ignore[no-untyped-def]
+    """ Compile the Geometry type for the 'sqlite' dialect to the plain
+        type name 'GEOMETRY'. All arguments are ignored.
+    """
+    return 'GEOMETRY'
+
+
+# Functions that are rendered under a different name for the 'sqlite'
+# dialect. Each entry is (function name, return type, sqlite name),
+# in the parameter order of _add_function_alias().
+SQLITE_FUNCTION_ALIAS = (
+    ('ST_AsEWKB', sa.Text, 'AsEWKB'),
+    ('ST_GeomFromEWKT', Geometry, 'GeomFromEWKT'),
+    ('ST_AsGeoJSON', sa.Text, 'AsGeoJSON'),
+    ('ST_AsKML', sa.Text, 'AsKML'),
+    ('ST_AsSVG', sa.Text, 'AsSVG'),
+    ('ST_LineLocatePoint', sa.Float, 'ST_Line_Locate_Point'),
+    ('ST_LineInterpolatePoint', sa.Float, 'ST_Line_Interpolate_Point'),
+)
+
+def _add_function_alias(func: str, ftype: type, alias: str) -> None:
+    """ Define a generic SQL function named `func` with an instance of
+        `ftype` as its type and register a compilation for the 'sqlite'
+        dialect that renders it as a call to `alias`.
+    """
+    attributes = {"type": ftype(),
+                  "name": func,
+                  "identifier": func,
+                  "inherit_cache": True}
+    function_class = type(func, (sa.sql.functions.GenericFunction, ), attributes)
+
+    template = f"{alias}(%s)"
+
+    @compiles(function_class, 'sqlite') # type: ignore[no-untyped-call, misc]
+    def _sqlite_impl(element: Any, compiler: Any, **kw: Any) -> Any:
+        return template % compiler.process(element.clauses, **kw)
+
+# Register an alias function for every entry of SQLITE_FUNCTION_ALIAS.
+for alias in SQLITE_FUNCTION_ALIAS:
+    _add_function_alias(*alias)
+
+
+class ST_DWithin(sa.sql.functions.GenericFunction[Any]):
+    """ Generic function ST_DWithin. A separate compilation is
+        registered for the 'sqlite' dialect.
+    """
+    inherit_cache = True
+    name = 'ST_DWithin'
+
+
+@compiles(ST_DWithin, 'sqlite') # type: ignore[no-untyped-call, misc]
+def default_json_array_each(element: SaColumn, compiler: 'sa.Compiled', **kw: Any) -> str:
+    """ Compilation of ST_DWithin for the 'sqlite' dialect: an
+        MbrIntersects() test against the second geometry expanded by the
+        distance, combined with an ST_Distance() comparison.
+
+        NOTE(review): the function name does not describe what is
+        compiled here; it is left alone because renaming changes the
+        names this module exposes.
+    """
+    geom1, geom2, dist = element.clauses
+    template = "(MbrIntersects(%s, ST_Expand(%s, %s)) = 1 AND ST_Distance(%s, %s) <= %s)"
+    # Every argument is compiled once per occurrence, in output order.
+    values = (compiler.process(geom1, **kw), compiler.process(geom2, **kw),
+              compiler.process(dist, **kw),
+              compiler.process(geom1, **kw), compiler.process(geom2, **kw),
+              compiler.process(dist, **kw))
+    return template % values