git.openstreetmap.org Git - nominatim.git/blobdiff - nominatim/api/search/db_searches.py
split up get_assignment function into more readable parts
[nominatim.git] / nominatim / api / search / db_searches.py
index fc3f9e0945e7d6ef39d66fc1871bc5d708d1b544..5c1d98c9a60b611b38d96b0c6cb636132f0b4715 100644 (file)
@@ -7,14 +7,14 @@
 """
 Implementation of the acutal database accesses for forward search.
 """
 """
 Implementation of the acutal database accesses for forward search.
 """
-from typing import List, Tuple, AsyncIterator, Dict, Any
+from typing import List, Tuple, AsyncIterator, Dict, Any, Callable
 import abc
 
 import sqlalchemy as sa
 from sqlalchemy.dialects.postgresql import ARRAY, array_agg
 
 from nominatim.typing import SaFromClause, SaScalarSelect, SaColumn, \
 import abc
 
 import sqlalchemy as sa
 from sqlalchemy.dialects.postgresql import ARRAY, array_agg
 
 from nominatim.typing import SaFromClause, SaScalarSelect, SaColumn, \
-                             SaExpression, SaSelect, SaRow
+                             SaExpression, SaSelect, SaLambdaSelect, SaRow, SaBind
 from nominatim.api.connection import SearchConnection
 from nominatim.api.types import SearchDetails, DataLayer, GeometryFormat, Bbox
 import nominatim.api.results as nres
 from nominatim.api.connection import SearchConnection
 from nominatim.api.types import SearchDetails, DataLayer, GeometryFormat, Bbox
 import nominatim.api.results as nres
@@ -39,15 +39,20 @@ def _details_to_bind_params(details: SearchDetails) -> Dict[str, Any]:
             'countries': details.countries}
 
 
             'countries': details.countries}
 
 
-LIMIT_PARAM = sa.bindparam('limit')
-MIN_RANK_PARAM = sa.bindparam('min_rank')
-MAX_RANK_PARAM = sa.bindparam('max_rank')
-VIEWBOX_PARAM = sa.bindparam('viewbox', type_=Geometry)
-VIEWBOX2_PARAM = sa.bindparam('viewbox2', type_=Geometry)
-NEAR_PARAM = sa.bindparam('near', type_=Geometry)
-NEAR_RADIUS_PARAM = sa.bindparam('near_radius')
-EXCLUDED_PARAM = sa.bindparam('excluded')
-COUNTRIES_PARAM = sa.bindparam('countries')
+LIMIT_PARAM: SaBind = sa.bindparam('limit')
+MIN_RANK_PARAM: SaBind = sa.bindparam('min_rank')
+MAX_RANK_PARAM: SaBind = sa.bindparam('max_rank')
+VIEWBOX_PARAM: SaBind = sa.bindparam('viewbox', type_=Geometry)
+VIEWBOX2_PARAM: SaBind = sa.bindparam('viewbox2', type_=Geometry)
+NEAR_PARAM: SaBind = sa.bindparam('near', type_=Geometry)
+NEAR_RADIUS_PARAM: SaBind = sa.bindparam('near_radius')
+COUNTRIES_PARAM: SaBind = sa.bindparam('countries')
+
+def _within_near(t: SaFromClause) -> Callable[[], SaExpression]:
+    return lambda: t.c.geometry.ST_DWithin(NEAR_PARAM, NEAR_RADIUS_PARAM)
+
+def _exclude_places(t: SaFromClause) -> Callable[[], SaExpression]:
+    return lambda: t.c.place_id.not_in(sa.bindparam('excluded'))
 
 def _select_placex(t: SaFromClause) -> SaSelect:
     return sa.select(t.c.place_id, t.c.osm_type, t.c.osm_id, t.c.name,
 
 def _select_placex(t: SaFromClause) -> SaSelect:
     return sa.select(t.c.place_id, t.c.osm_type, t.c.osm_id, t.c.name,
@@ -60,10 +65,7 @@ def _select_placex(t: SaFromClause) -> SaSelect:
                      t.c.geometry.ST_Expand(0).label('bbox'))
 
 
                      t.c.geometry.ST_Expand(0).label('bbox'))
 
 
-def _add_geometry_columns(sql: SaSelect, col: SaColumn, details: SearchDetails) -> SaSelect:
-    if not details.geometry_output:
-        return sql
-
+def _add_geometry_columns(sql: SaLambdaSelect, col: SaColumn, details: SearchDetails) -> SaSelect:
     out = []
 
     if details.geometry_simplification > 0.0:
     out = []
 
     if details.geometry_simplification > 0.0:
@@ -96,7 +98,7 @@ def _make_interpolation_subquery(table: SaFromClause, inner: SaFromClause,
                   for n in numerals)))
 
     if details.excluded:
                   for n in numerals)))
 
     if details.excluded:
-        sql = sql.where(table.c.place_id.not_in(EXCLUDED_PARAM))
+        sql = sql.where(_exclude_places(table))
 
     return sql.scalar_subquery()
 
 
     return sql.scalar_subquery()
 
@@ -150,7 +152,8 @@ async def _get_placex_housenumbers(conn: SearchConnection,
     t = conn.t.placex
     sql = _select_placex(t).where(t.c.place_id.in_(place_ids))
 
     t = conn.t.placex
     sql = _select_placex(t).where(t.c.place_id.in_(place_ids))
 
-    sql = _add_geometry_columns(sql, t.c.geometry, details)
+    if details.geometry_output:
+        sql = _add_geometry_columns(sql, t.c.geometry, details)
 
     for row in await conn.execute(sql):
         result = nres.create_from_placex_row(row, nres.SearchResult)
 
     for row in await conn.execute(sql):
         result = nres.create_from_placex_row(row, nres.SearchResult)
@@ -294,7 +297,7 @@ class NearSearch(AbstractSearch):
         if details.countries:
             sql = sql.where(t.c.country_code.in_(COUNTRIES_PARAM))
         if details.excluded:
         if details.countries:
             sql = sql.where(t.c.country_code.in_(COUNTRIES_PARAM))
         if details.excluded:
-            sql = sql.where(t.c.place_id.not_in(EXCLUDED_PARAM))
+            sql = sql.where(_exclude_places(t))
         if details.layers is not None:
             sql = sql.where(_filter_by_layer(t, details.layers))
 
         if details.layers is not None:
             sql = sql.where(_filter_by_layer(t, details.layers))
 
@@ -328,10 +331,22 @@ class PoiSearch(AbstractSearch):
 
         if details.near and details.near_radius is not None and details.near_radius < 0.2:
             # simply search in placex table
 
         if details.near and details.near_radius is not None and details.near_radius < 0.2:
             # simply search in placex table
-            sql = _select_placex(t) \
-                      .where(t.c.linked_place_id == None) \
-                      .where(t.c.geometry.ST_DWithin(NEAR_PARAM, NEAR_RADIUS_PARAM)) \
-                      .order_by(t.c.centroid.ST_Distance(NEAR_PARAM))
+            def _base_query() -> SaSelect:
+                return _select_placex(t) \
+                           .where(t.c.linked_place_id == None) \
+                           .where(t.c.geometry.ST_DWithin(NEAR_PARAM, NEAR_RADIUS_PARAM)) \
+                           .order_by(t.c.centroid.ST_Distance(NEAR_PARAM)) \
+                           .limit(LIMIT_PARAM)
+
+            classtype = self.categories.values
+            if len(classtype) == 1:
+                cclass, ctype = classtype[0]
+                sql: SaLambdaSelect = sa.lambda_stmt(lambda: _base_query()
+                                                 .where(t.c.class_ == cclass)
+                                                 .where(t.c.type == ctype))
+            else:
+                sql = _base_query().where(sa.or_(*(sa.and_(t.c.class_ == cls, t.c.type == typ)
+                                                   for cls, typ in classtype)))
 
             if self.countries:
                 sql = sql.where(t.c.country_code.in_(self.countries.values))
 
             if self.countries:
                 sql = sql.where(t.c.country_code.in_(self.countries.values))
@@ -339,15 +354,6 @@ class PoiSearch(AbstractSearch):
             if details.viewbox is not None and details.bounded_viewbox:
                 sql = sql.where(t.c.geometry.intersects(VIEWBOX_PARAM))
 
             if details.viewbox is not None and details.bounded_viewbox:
                 sql = sql.where(t.c.geometry.intersects(VIEWBOX_PARAM))
 
-            classtype = self.categories.values
-            if len(classtype) == 1:
-                sql = sql.where(t.c.class_ == classtype[0][0]) \
-                         .where(t.c.type == classtype[0][1])
-            else:
-                sql = sql.where(sa.or_(*(sa.and_(t.c.class_ == cls, t.c.type == typ)
-                                         for cls, typ in classtype)))
-
-            sql = sql.limit(LIMIT_PARAM)
             rows.extend(await conn.execute(sql, bind_params))
         else:
             # use the class type tables
             rows.extend(await conn.execute(sql, bind_params))
         else:
             # use the class type tables
@@ -398,20 +404,22 @@ class CountrySearch(AbstractSearch):
         """
         t = conn.t.placex
 
         """
         t = conn.t.placex
 
-        sql = _select_placex(t)\
-                .where(t.c.country_code.in_(self.countries.values))\
-                .where(t.c.rank_address == 4)
+        ccodes = self.countries.values
+        sql: SaLambdaSelect = sa.lambda_stmt(lambda: _select_placex(t)\
+                .where(t.c.country_code.in_(ccodes))\
+                .where(t.c.rank_address == 4))
 
 
-        sql = _add_geometry_columns(sql, t.c.geometry, details)
+        if details.geometry_output:
+            sql = _add_geometry_columns(sql, t.c.geometry, details)
 
         if details.excluded:
 
         if details.excluded:
-            sql = sql.where(t.c.place_id.not_in(EXCLUDED_PARAM))
+            sql = sql.where(_exclude_places(t))
 
         if details.viewbox is not None and details.bounded_viewbox:
 
         if details.viewbox is not None and details.bounded_viewbox:
-            sql = sql.where(t.c.geometry.intersects(VIEWBOX_PARAM))
+            sql = sql.where(lambda: t.c.geometry.intersects(VIEWBOX_PARAM))
 
         if details.near is not None and details.near_radius is not None:
 
         if details.near is not None and details.near_radius is not None:
-            sql = sql.where(t.c.geometry.ST_DWithin(NEAR_PARAM, NEAR_RADIUS_PARAM))
+            sql = sql.where(_within_near(t))
 
         results = nres.SearchResults()
         for row in await conn.execute(sql, _details_to_bind_params(details)):
 
         results = nres.SearchResults()
         for row in await conn.execute(sql, _details_to_bind_params(details)):
@@ -445,7 +453,7 @@ class CountrySearch(AbstractSearch):
         if details.viewbox is not None and details.bounded_viewbox:
             sql = sql.where(tgrid.c.geometry.intersects(VIEWBOX_PARAM))
         if details.near is not None and details.near_radius is not None:
         if details.viewbox is not None and details.bounded_viewbox:
             sql = sql.where(tgrid.c.geometry.intersects(VIEWBOX_PARAM))
         if details.near is not None and details.near_radius is not None:
-            sql = sql.where(tgrid.c.geometry.ST_DWithin(NEAR_PARAM, NEAR_RADIUS_PARAM))
+            sql = sql.where(_within_near(tgrid))
 
         sub = sql.subquery('grid')
 
 
         sub = sql.subquery('grid')
 
@@ -484,14 +492,17 @@ class PostcodeSearch(AbstractSearch):
         """ Find results for the search in the database.
         """
         t = conn.t.postcode
         """ Find results for the search in the database.
         """
         t = conn.t.postcode
+        pcs = self.postcodes.values
 
 
-        sql = sa.select(t.c.place_id, t.c.parent_place_id,
+        sql: SaLambdaSelect = sa.lambda_stmt(lambda:
+                sa.select(t.c.place_id, t.c.parent_place_id,
                         t.c.rank_search, t.c.rank_address,
                         t.c.postcode, t.c.country_code,
                         t.c.rank_search, t.c.rank_address,
                         t.c.postcode, t.c.country_code,
-                        t.c.geometry.label('centroid'))\
-                .where(t.c.postcode.in_(self.postcodes.values))
+                        t.c.geometry.label('centroid'))
+                .where(t.c.postcode.in_(pcs)))
 
 
-        sql = _add_geometry_columns(sql, t.c.geometry, details)
+        if details.geometry_output:
+            sql = _add_geometry_columns(sql, t.c.geometry, details)
 
         penalty: SaExpression = sa.literal(self.penalty)
 
 
         penalty: SaExpression = sa.literal(self.penalty)
 
@@ -505,14 +516,14 @@ class PostcodeSearch(AbstractSearch):
 
         if details.near is not None:
             if details.near_radius is not None:
 
         if details.near is not None:
             if details.near_radius is not None:
-                sql = sql.where(t.c.geometry.ST_DWithin(NEAR_PARAM, NEAR_RADIUS_PARAM))
+                sql = sql.where(_within_near(t))
             sql = sql.order_by(t.c.geometry.ST_Distance(NEAR_PARAM))
 
         if self.countries:
             sql = sql.where(t.c.country_code.in_(self.countries.values))
 
         if details.excluded:
             sql = sql.order_by(t.c.geometry.ST_Distance(NEAR_PARAM))
 
         if self.countries:
             sql = sql.where(t.c.country_code.in_(self.countries.values))
 
         if details.excluded:
-            sql = sql.where(t.c.place_id.not_in(EXCLUDED_PARAM))
+            sql = sql.where(_exclude_places(t))
 
         if self.lookups:
             assert len(self.lookups) == 1
 
         if self.lookups:
             assert len(self.lookups) == 1
@@ -562,21 +573,23 @@ class PlaceSearch(AbstractSearch):
                      details: SearchDetails) -> nres.SearchResults:
         """ Find results for the search in the database.
         """
                      details: SearchDetails) -> nres.SearchResults:
         """ Find results for the search in the database.
         """
-        t = conn.t.placex.alias('p')
-        tsearch = conn.t.search_name.alias('s')
+        t = conn.t.placex
+        tsearch = conn.t.search_name
 
 
-        sql = sa.select(t.c.place_id, t.c.osm_type, t.c.osm_id, t.c.name,
-                        t.c.class_, t.c.type,
-                        t.c.address, t.c.extratags,
-                        t.c.housenumber, t.c.postcode, t.c.country_code,
-                        t.c.wikipedia,
-                        t.c.parent_place_id, t.c.rank_address, t.c.rank_search,
-                        t.c.centroid,
-                        t.c.geometry.ST_Expand(0).label('bbox'))\
-                .where(t.c.place_id == tsearch.c.place_id)
+        sql: SaLambdaSelect = sa.lambda_stmt(lambda:
+                  sa.select(t.c.place_id, t.c.osm_type, t.c.osm_id, t.c.name,
+                            t.c.class_, t.c.type,
+                            t.c.address, t.c.extratags,
+                            t.c.housenumber, t.c.postcode, t.c.country_code,
+                            t.c.wikipedia,
+                            t.c.parent_place_id, t.c.rank_address, t.c.rank_search,
+                            t.c.centroid,
+                            t.c.geometry.ST_Expand(0).label('bbox'))
+                   .where(t.c.place_id == tsearch.c.place_id))
 
 
 
 
-        sql = _add_geometry_columns(sql, t.c.geometry, details)
+        if details.geometry_output:
+            sql = _add_geometry_columns(sql, t.c.geometry, details)
 
         penalty: SaExpression = sa.literal(self.penalty)
         for ranking in self.rankings:
 
         penalty: SaExpression = sa.literal(self.penalty)
         for ranking in self.rankings:
@@ -592,18 +605,19 @@ class PlaceSearch(AbstractSearch):
             # if a postcode is given, don't search for state or country level objects
             sql = sql.where(tsearch.c.address_rank > 9)
             tpc = conn.t.postcode
             # if a postcode is given, don't search for state or country level objects
             sql = sql.where(tsearch.c.address_rank > 9)
             tpc = conn.t.postcode
+            pcs = self.postcodes.values
             if self.expected_count > 1000:
                 # Many results expected. Restrict by postcode.
             if self.expected_count > 1000:
                 # Many results expected. Restrict by postcode.
-                sql = sql.where(sa.select(tpc.c.postcode)
-                                  .where(tpc.c.postcode.in_(self.postcodes.values))
+                sql = sql.where(lambda: sa.select(tpc.c.postcode)
+                                  .where(tpc.c.postcode.in_(pcs))
                                   .where(tsearch.c.centroid.ST_DWithin(tpc.c.geometry, 0.12))
                                   .exists())
 
             # Less results, only have a preference for close postcodes
             pc_near = sa.select(sa.func.min(tpc.c.geometry.ST_Distance(tsearch.c.centroid)))\
                                   .where(tsearch.c.centroid.ST_DWithin(tpc.c.geometry, 0.12))
                                   .exists())
 
             # Less results, only have a preference for close postcodes
             pc_near = sa.select(sa.func.min(tpc.c.geometry.ST_Distance(tsearch.c.centroid)))\
-                      .where(tpc.c.postcode.in_(self.postcodes.values))\
+                      .where(tpc.c.postcode.in_(pcs))\
                       .scalar_subquery()
                       .scalar_subquery()
-            penalty += sa.case((t.c.postcode.in_(self.postcodes.values), 0.0),
+            penalty += sa.case((t.c.postcode.in_(pcs), 0.0),
                                else_=sa.func.coalesce(pc_near, 2.0))
 
         if details.viewbox is not None:
                                else_=sa.func.coalesce(pc_near, 2.0))
 
         if details.viewbox is not None:
@@ -633,7 +647,7 @@ class PlaceSearch(AbstractSearch):
             hnr_regexp = f"\\m({'|'.join(self.housenumbers.values)})\\M"
             sql = sql.where(tsearch.c.address_rank.between(16, 30))\
                      .where(sa.or_(tsearch.c.address_rank < 30,
             hnr_regexp = f"\\m({'|'.join(self.housenumbers.values)})\\M"
             sql = sql.where(tsearch.c.address_rank.between(16, 30))\
                      .where(sa.or_(tsearch.c.address_rank < 30,
-                                  t.c.housenumber.op('~*')(hnr_regexp)))
+                                   t.c.housenumber.op('~*')(hnr_regexp)))
 
             # Cross check for housenumbers, need to do that on a rather large
             # set. Worst case there are 40.000 main streets in OSM.
 
             # Cross check for housenumbers, need to do that on a rather large
             # set. Worst case there are 40.000 main streets in OSM.
@@ -649,13 +663,13 @@ class PlaceSearch(AbstractSearch):
                           .where(thnr.c.indexed_status == 0)
 
             if details.excluded:
                           .where(thnr.c.indexed_status == 0)
 
             if details.excluded:
-                place_sql = place_sql.where(thnr.c.place_id.not_in(EXCLUDED_PARAM))
+                place_sql = place_sql.where(_exclude_places(thnr))
             if self.qualifiers:
                 place_sql = place_sql.where(self.qualifiers.sql_restrict(thnr))
 
             numerals = [int(n) for n in self.housenumbers.values if n.isdigit()]
             if self.qualifiers:
                 place_sql = place_sql.where(self.qualifiers.sql_restrict(thnr))
 
             numerals = [int(n) for n in self.housenumbers.values if n.isdigit()]
-            interpol_sql: SaExpression
-            tiger_sql: SaExpression
+            interpol_sql: SaColumn
+            tiger_sql: SaColumn
             if numerals and \
                (not self.qualifiers or ('place', 'house') in self.qualifiers.values):
                 # Housenumbers from interpolations
             if numerals and \
                (not self.qualifiers or ('place', 'house') in self.qualifiers.values):
                 # Housenumbers from interpolations
@@ -667,8 +681,8 @@ class PlaceSearch(AbstractSearch):
                                                                   numerals, details)
                                     ), else_=None)
             else:
                                                                   numerals, details)
                                     ), else_=None)
             else:
-                interpol_sql = sa.literal_column('NULL')
-                tiger_sql = sa.literal_column('NULL')
+                interpol_sql = sa.null()
+                tiger_sql = sa.null()
 
             unsort = sa.select(inner, place_sql.scalar_subquery().label('placex_hnr'),
                                interpol_sql.label('interpol_hnr'),
 
             unsort = sa.select(inner, place_sql.scalar_subquery().label('placex_hnr'),
                                interpol_sql.label('interpol_hnr'),
@@ -685,7 +699,7 @@ class PlaceSearch(AbstractSearch):
             if self.qualifiers:
                 sql = sql.where(self.qualifiers.sql_restrict(t))
             if details.excluded:
             if self.qualifiers:
                 sql = sql.where(self.qualifiers.sql_restrict(t))
             if details.excluded:
-                sql = sql.where(tsearch.c.place_id.not_in(EXCLUDED_PARAM))
+                sql = sql.where(_exclude_places(tsearch))
             if details.min_rank > 0:
                 sql = sql.where(sa.or_(tsearch.c.address_rank >= MIN_RANK_PARAM,
                                        tsearch.c.search_rank >= MIN_RANK_PARAM))
             if details.min_rank > 0:
                 sql = sql.where(sa.or_(tsearch.c.address_rank >= MIN_RANK_PARAM,
                                        tsearch.c.search_rank >= MIN_RANK_PARAM))