]> git.openstreetmap.org Git - nominatim.git/blobdiff - nominatim/db/sqlalchemy_types.py
make code work with Spatialite 4.3
[nominatim.git] / nominatim / db / sqlalchemy_types.py
index df2bf1504e11ef6b87bebc96a679d0c20294e2ef..036b25dd9ac48e2500b4f31e8d7dcc6221bd38bf 100644 (file)
@@ -38,7 +38,7 @@ def _default_distance_spheroid(element: SaColumn,
 @compiles(Geometry_DistanceSpheroid, 'sqlite') # type: ignore[no-untyped-call, misc]
 def _spatialite_distance_spheroid(element: SaColumn,
                                   compiler: 'sa.Compiled', **kw: Any) -> str:
-    return "Distance(%s, true)" % compiler.process(element.clauses, **kw)
+    return f"COALESCE(Distance({compiler.process(element.clauses, **kw)}, true), 0.0)"
 
 
 class Geometry_IsLineLike(sa.sql.expression.FunctionElement[bool]):
@@ -106,6 +106,71 @@ def _sqlite_intersects(element: SaColumn,
     return "MbrIntersects(%s)" % compiler.process(element.clauses, **kw)
 
 
+class Geometry_ColumnIntersectsBbox(sa.sql.expression.FunctionElement[bool]):
+    """ Check if the bounding box of the geometry in a table column (first
+        argument) intersects with the bounding box of the second geometry.
+
+        Needs the column's spatial index; without it the query may return nothing.
+    """
+    type = sa.Boolean()
+    name = 'Geometry_ColumnIntersectsBbox'
+    inherit_cache = True
+
+
+@compiles(Geometry_ColumnIntersectsBbox) # type: ignore[no-untyped-call, misc]
+def default_intersects_column(element: SaColumn,
+                              compiler: 'sa.Compiled', **kw: Any) -> str:
+    arg1, arg2 = list(element.clauses)
+    return "%s && %s" % (compiler.process(arg1, **kw), compiler.process(arg2, **kw))
+
+
+@compiles(Geometry_ColumnIntersectsBbox, 'sqlite') # type: ignore[no-untyped-call, misc]
+def spatialite_intersects_column(element: SaColumn,
+                                 compiler: 'sa.Compiled', **kw: Any) -> str:
+    arg1, arg2 = list(element.clauses)
+    # The SpatialIndex subquery lets Spatialite use the column's spatial index.
+    return "MbrIntersects(%s, %s) and "\
+           "%s.ROWID IN (SELECT ROWID FROM SpatialIndex "\
+                        "WHERE f_table_name = '%s' AND f_geometry_column = '%s' "\
+                        "AND search_frame = %s)" % (
+              compiler.process(arg1, **kw), compiler.process(arg2, **kw),
+              arg1.table.name, arg1.table.name, arg1.name,
+              compiler.process(arg2, **kw))
+
+
+class Geometry_ColumnDWithin(sa.sql.expression.FunctionElement[bool]):
+    """ Check if the geometry in a table column (first argument) lies within
+        the given distance (inclusive) of the second geometry.
+
+        Needs the column's spatial index; without it the query may return nothing.
+    """
+    type = sa.Boolean()
+    name = 'Geometry_ColumnDWithin'
+    inherit_cache = True
+
+
+@compiles(Geometry_ColumnDWithin) # type: ignore[no-untyped-call, misc]
+def default_dwithin_column(element: SaColumn,
+                           compiler: 'sa.Compiled', **kw: Any) -> str:
+    return "ST_DWithin(%s)" % compiler.process(element.clauses, **kw)
+
+
+@compiles(Geometry_ColumnDWithin, 'sqlite') # type: ignore[no-untyped-call, misc]
+def spatialite_dwithin_column(element: SaColumn,
+                              compiler: 'sa.Compiled', **kw: Any) -> str:
+    geom1, geom2, dist = list(element.clauses)
+    # '<=' keeps the distance bound inclusive, as with PostGIS ST_DWithin.
+    return "ST_Distance(%s, %s) <= %s and "\
+           "%s.ROWID IN (SELECT ROWID FROM SpatialIndex "\
+                        "WHERE f_table_name = '%s' AND f_geometry_column = '%s' "\
+                        "AND search_frame = ST_Expand(%s, %s))" % (
+              compiler.process(geom1, **kw), compiler.process(geom2, **kw),
+              compiler.process(dist, **kw),
+              geom1.table.name, geom1.table.name, geom1.name,
+              compiler.process(geom2, **kw),
+              compiler.process(dist, **kw))
+
+
 class Geometry(types.UserDefinedType): # type: ignore[type-arg]
     """ Simplified type decorator for PostGIS geometry. This type
         only supports geometries in 4326 projection.
@@ -147,7 +212,10 @@ class Geometry(types.UserDefinedType): # type: ignore[type-arg]
     class comparator_factory(types.UserDefinedType.Comparator): # type: ignore[type-arg]
 
         def intersects(self, other: SaColumn) -> 'sa.Operators':
-            return Geometry_IntersectsBbox(self, other)
+            if not isinstance(self.expr, sa.Column):
+                return Geometry_IntersectsBbox(self.expr, other)
+            # Table columns get the variant that goes through the spatial index.
+            return Geometry_ColumnIntersectsBbox(self.expr, other)
 
 
         def is_line_like(self) -> SaColumn:
@@ -159,16 +227,19 @@ class Geometry(types.UserDefinedType): # type: ignore[type-arg]
 
 
         def ST_DWithin(self, other: SaColumn, distance: SaColumn) -> SaColumn:
-            return sa.func.ST_DWithin(self, other, distance, type_=sa.Boolean)
+            if not isinstance(self.expr, sa.Column):
+                return sa.func.ST_DWithin(self.expr, other, distance)
+            # Table columns get the variant that goes through the spatial index.
+            return Geometry_ColumnDWithin(self.expr, other, distance)
 
 
         def ST_DWithin_no_index(self, other: SaColumn, distance: SaColumn) -> SaColumn:
             return sa.func.ST_DWithin(sa.func.coalesce(sa.null(), self),
-                                      other, distance, type_=sa.Boolean)
+                                      other, distance)
 
 
         def ST_Intersects_no_index(self, other: SaColumn) -> 'sa.Operators':
-            return sa.func.coalesce(sa.null(), self).op('&&')(other)
+            return Geometry_IntersectsBbox(sa.func.coalesce(sa.null(), self.expr), other)
 
 
         def ST_Distance(self, other: SaColumn) -> SaColumn: