]> git.openstreetmap.org Git - nominatim.git/blobdiff - src/nominatim_api/search/db_searches.py
Fix style issues found by flake8
[nominatim.git] / src / nominatim_api / search / db_searches.py
index f5c431460d6a34677d37c645e32c047c60c753ac..3a4c826fd871c0a60725d62188066e0bbd9ae7d9 100644 (file)
@@ -12,22 +12,20 @@ import abc
 
 import sqlalchemy as sa
 
 
 import sqlalchemy as sa
 
-from nominatim_core.typing import SaFromClause, SaScalarSelect, SaColumn, \
-                                  SaExpression, SaSelect, SaLambdaSelect, SaRow, SaBind
-from nominatim_core.db.sqlalchemy_types import Geometry, IntArray
+from ..typing import SaFromClause, SaScalarSelect, SaColumn, \
+                     SaExpression, SaSelect, SaLambdaSelect, SaRow, SaBind
+from ..sql.sqlalchemy_types import Geometry, IntArray
 from ..connection import SearchConnection
 from ..types import SearchDetails, DataLayer, GeometryFormat, Bbox
 from .. import results as nres
 from .db_search_fields import SearchData, WeightedCategories
 
 from ..connection import SearchConnection
 from ..types import SearchDetails, DataLayer, GeometryFormat, Bbox
 from .. import results as nres
 from .db_search_fields import SearchData, WeightedCategories
 
-#pylint: disable=singleton-comparison,not-callable
-#pylint: disable=too-many-branches,too-many-arguments,too-many-locals,too-many-statements
 
 def no_index(expr: SaColumn) -> SaColumn:
     """ Wrap the given expression, so that the query planner will
         refrain from using the expression for index lookup.
     """
 
 def no_index(expr: SaColumn) -> SaColumn:
     """ Wrap the given expression, so that the query planner will
         refrain from using the expression for index lookup.
     """
-    return sa.func.coalesce(sa.null(), expr) # pylint: disable=not-callable
+    return sa.func.coalesce(sa.null(), expr)
 
 
 def _details_to_bind_params(details: SearchDetails) -> Dict[str, Any]:
 
 
 def _details_to_bind_params(details: SearchDetails) -> Dict[str, Any]:
@@ -68,7 +66,7 @@ def filter_by_area(sql: SaSelect, t: SaFromClause,
     if details.viewbox is not None and details.bounded_viewbox:
         sql = sql.where(t.c.geometry.intersects(VIEWBOX_PARAM,
                                                 use_index=not avoid_index and
     if details.viewbox is not None and details.bounded_viewbox:
         sql = sql.where(t.c.geometry.intersects(VIEWBOX_PARAM,
                                                 use_index=not avoid_index and
-                                                          details.viewbox.area < 0.2))
+                                                details.viewbox.area < 0.2))
 
     return sql
 
 
     return sql
 
@@ -190,7 +188,7 @@ def _int_list_to_subquery(inp: List[int]) -> 'sa.Subquery':
         as rows in the column 'nr'.
     """
     vtab = sa.func.JsonArrayEach(sa.type_coerce(inp, sa.JSON))\
         as rows in the column 'nr'.
     """
     vtab = sa.func.JsonArrayEach(sa.type_coerce(inp, sa.JSON))\
-               .table_valued(sa.column('value', type_=sa.JSON))
+             .table_valued(sa.column('value', type_=sa.JSON))
     return sa.select(sa.cast(sa.cast(vtab.c.value, sa.Text), sa.Integer).label('nr')).subquery()
 
 
     return sa.select(sa.cast(sa.cast(vtab.c.value, sa.Text), sa.Integer).label('nr')).subquery()
 
 
@@ -266,7 +264,6 @@ class NearSearch(AbstractSearch):
         self.search = search
         self.categories = categories
 
         self.search = search
         self.categories = categories
 
-
     async def lookup(self, conn: SearchConnection,
                      details: SearchDetails) -> nres.SearchResults:
         """ Find results for the search in the database.
     async def lookup(self, conn: SearchConnection,
                      details: SearchDetails) -> nres.SearchResults:
         """ Find results for the search in the database.
@@ -288,11 +285,12 @@ class NearSearch(AbstractSearch):
         else:
             min_rank = 26
             max_rank = 30
         else:
             min_rank = 26
             max_rank = 30
-        base = nres.SearchResults(r for r in base if r.source_table == nres.SourceTable.PLACEX
-                                                     and r.accuracy <= max_accuracy
-                                                     and r.bbox and r.bbox.area < 20
-                                                     and r.rank_address >= min_rank
-                                                     and r.rank_address <= max_rank)
+        base = nres.SearchResults(r for r in base
+                                  if (r.source_table == nres.SourceTable.PLACEX
+                                      and r.accuracy <= max_accuracy
+                                      and r.bbox and r.bbox.area < 20
+                                      and r.rank_address >= min_rank
+                                      and r.rank_address <= max_rank))
 
         if base:
             baseids = [b.place_id for b in base[:5] if b.place_id]
 
         if base:
             baseids = [b.place_id for b in base[:5] if b.place_id]
@@ -304,7 +302,6 @@ class NearSearch(AbstractSearch):
 
         return results
 
 
         return results
 
-
     async def lookup_category(self, results: nres.SearchResults,
                               conn: SearchConnection, ids: List[int],
                               category: Tuple[str, str], penalty: float,
     async def lookup_category(self, results: nres.SearchResults,
                               conn: SearchConnection, ids: List[int],
                               category: Tuple[str, str], penalty: float,
@@ -334,9 +331,9 @@ class NearSearch(AbstractSearch):
                     .join(tgeom,
                           table.c.centroid.ST_CoveredBy(
                               sa.case((sa.and_(tgeom.c.rank_address > 9,
                     .join(tgeom,
                           table.c.centroid.ST_CoveredBy(
                               sa.case((sa.and_(tgeom.c.rank_address > 9,
-                                                tgeom.c.geometry.is_area()),
+                                               tgeom.c.geometry.is_area()),
                                        tgeom.c.geometry),
                                        tgeom.c.geometry),
-                                      else_ = tgeom.c.centroid.ST_Expand(0.05))))
+                                      else_=tgeom.c.centroid.ST_Expand(0.05))))
 
         inner = sql.where(tgeom.c.place_id.in_(ids))\
                    .group_by(table.c.place_id).subquery()
 
         inner = sql.where(tgeom.c.place_id.in_(ids))\
                    .group_by(table.c.place_id).subquery()
@@ -363,7 +360,6 @@ class NearSearch(AbstractSearch):
             results.append(result)
 
 
             results.append(result)
 
 
-
 class PoiSearch(AbstractSearch):
     """ Category search in a geographic area.
     """
 class PoiSearch(AbstractSearch):
     """ Category search in a geographic area.
     """
@@ -372,7 +368,6 @@ class PoiSearch(AbstractSearch):
         self.qualifiers = sdata.qualifiers
         self.countries = sdata.countries
 
         self.qualifiers = sdata.qualifiers
         self.countries = sdata.countries
 
-
     async def lookup(self, conn: SearchConnection,
                      details: SearchDetails) -> nres.SearchResults:
         """ Find results for the search in the database.
     async def lookup(self, conn: SearchConnection,
                      details: SearchDetails) -> nres.SearchResults:
         """ Find results for the search in the database.
@@ -387,7 +382,7 @@ class PoiSearch(AbstractSearch):
             def _base_query() -> SaSelect:
                 return _select_placex(t) \
                            .add_columns((-t.c.centroid.ST_Distance(NEAR_PARAM))
             def _base_query() -> SaSelect:
                 return _select_placex(t) \
                            .add_columns((-t.c.centroid.ST_Distance(NEAR_PARAM))
-                                         .label('importance'))\
+                                        .label('importance'))\
                            .where(t.c.linked_place_id == None) \
                            .where(t.c.geometry.within_distance(NEAR_PARAM, NEAR_RADIUS_PARAM)) \
                            .order_by(t.c.centroid.ST_Distance(NEAR_PARAM)) \
                            .where(t.c.linked_place_id == None) \
                            .where(t.c.geometry.within_distance(NEAR_PARAM, NEAR_RADIUS_PARAM)) \
                            .order_by(t.c.centroid.ST_Distance(NEAR_PARAM)) \
@@ -396,9 +391,9 @@ class PoiSearch(AbstractSearch):
             classtype = self.qualifiers.values
             if len(classtype) == 1:
                 cclass, ctype = classtype[0]
             classtype = self.qualifiers.values
             if len(classtype) == 1:
                 cclass, ctype = classtype[0]
-                sql: SaLambdaSelect = sa.lambda_stmt(lambda: _base_query()
-                                                 .where(t.c.class_ == cclass)
-                                                 .where(t.c.type == ctype))
+                sql: SaLambdaSelect = sa.lambda_stmt(
+                    lambda: _base_query().where(t.c.class_ == cclass)
+                                         .where(t.c.type == ctype))
             else:
                 sql = _base_query().where(sa.or_(*(sa.and_(t.c.class_ == cls, t.c.type == typ)
                                                    for cls, typ in classtype)))
             else:
                 sql = _base_query().where(sa.or_(*(sa.and_(t.c.class_ == cls, t.c.type == typ)
                                                    for cls, typ in classtype)))
@@ -455,7 +450,6 @@ class CountrySearch(AbstractSearch):
         super().__init__(sdata.penalty)
         self.countries = sdata.countries
 
         super().__init__(sdata.penalty)
         self.countries = sdata.countries
 
-
     async def lookup(self, conn: SearchConnection,
                      details: SearchDetails) -> nres.SearchResults:
         """ Find results for the search in the database.
     async def lookup(self, conn: SearchConnection,
                      details: SearchDetails) -> nres.SearchResults:
         """ Find results for the search in the database.
@@ -464,9 +458,9 @@ class CountrySearch(AbstractSearch):
 
         ccodes = self.countries.values
         sql = _select_placex(t)\
 
         ccodes = self.countries.values
         sql = _select_placex(t)\
-                .add_columns(t.c.importance)\
-                .where(t.c.country_code.in_(ccodes))\
-                .where(t.c.rank_address == 4)
+            .add_columns(t.c.importance)\
+            .where(t.c.country_code.in_(ccodes))\
+            .where(t.c.rank_address == 4)
 
         if details.geometry_output:
             sql = _add_geometry_columns(sql, t.c.geometry, details)
 
         if details.geometry_output:
             sql = _add_geometry_columns(sql, t.c.geometry, details)
@@ -493,7 +487,6 @@ class CountrySearch(AbstractSearch):
 
         return results
 
 
         return results
 
-
     async def lookup_in_country_table(self, conn: SearchConnection,
                                       details: SearchDetails) -> nres.SearchResults:
         """ Look up the country in the fallback country tables.
     async def lookup_in_country_table(self, conn: SearchConnection,
                                       details: SearchDetails) -> nres.SearchResults:
         """ Look up the country in the fallback country tables.
@@ -509,7 +502,7 @@ class CountrySearch(AbstractSearch):
 
         sql = sa.select(tgrid.c.country_code,
                         tgrid.c.geometry.ST_Centroid().ST_Collect().ST_Centroid()
 
         sql = sa.select(tgrid.c.country_code,
                         tgrid.c.geometry.ST_Centroid().ST_Collect().ST_Centroid()
-                              .label('centroid'),
+                             .label('centroid'),
                         tgrid.c.geometry.ST_Collect().ST_Expand(0).label('bbox'))\
                 .where(tgrid.c.country_code.in_(self.countries.values))\
                 .group_by(tgrid.c.country_code)
                         tgrid.c.geometry.ST_Collect().ST_Expand(0).label('bbox'))\
                 .where(tgrid.c.country_code.in_(self.countries.values))\
                 .group_by(tgrid.c.country_code)
@@ -537,7 +530,6 @@ class CountrySearch(AbstractSearch):
         return results
 
 
         return results
 
 
-
 class PostcodeSearch(AbstractSearch):
     """ Search for a postcode.
     """
 class PostcodeSearch(AbstractSearch):
     """ Search for a postcode.
     """
@@ -548,7 +540,6 @@ class PostcodeSearch(AbstractSearch):
         self.lookups = sdata.lookups
         self.rankings = sdata.rankings
 
         self.lookups = sdata.lookups
         self.rankings = sdata.rankings
 
-
     async def lookup(self, conn: SearchConnection,
                      details: SearchDetails) -> nres.SearchResults:
         """ Find results for the search in the database.
     async def lookup(self, conn: SearchConnection,
                      details: SearchDetails) -> nres.SearchResults:
         """ Find results for the search in the database.
@@ -588,14 +579,13 @@ class PostcodeSearch(AbstractSearch):
             tsearch = conn.t.search_name
             sql = sql.where(tsearch.c.place_id == t.c.parent_place_id)\
                      .where((tsearch.c.name_vector + tsearch.c.nameaddress_vector)
             tsearch = conn.t.search_name
             sql = sql.where(tsearch.c.place_id == t.c.parent_place_id)\
                      .where((tsearch.c.name_vector + tsearch.c.nameaddress_vector)
-                                     .contains(sa.type_coerce(self.lookups[0].tokens,
-                                                              IntArray)))
+                            .contains(sa.type_coerce(self.lookups[0].tokens,
+                                                     IntArray)))
 
         for ranking in self.rankings:
             penalty += ranking.sql_penalty(conn.t.search_name)
         penalty += sa.case(*((t.c.postcode == v, p) for v, p in self.postcodes),
 
         for ranking in self.rankings:
             penalty += ranking.sql_penalty(conn.t.search_name)
         penalty += sa.case(*((t.c.postcode == v, p) for v, p in self.postcodes),
-                       else_=1.0)
-
+                           else_=1.0)
 
         sql = sql.add_columns(penalty.label('accuracy'))
         sql = sql.order_by('accuracy').limit(LIMIT_PARAM)
 
         sql = sql.add_columns(penalty.label('accuracy'))
         sql = sql.order_by('accuracy').limit(LIMIT_PARAM)
@@ -603,19 +593,22 @@ class PostcodeSearch(AbstractSearch):
         results = nres.SearchResults()
         for row in await conn.execute(sql, _details_to_bind_params(details)):
             p = conn.t.placex
         results = nres.SearchResults()
         for row in await conn.execute(sql, _details_to_bind_params(details)):
             p = conn.t.placex
-            placex_sql = _select_placex(p).add_columns(p.c.importance)\
-                             .where(sa.text("""class = 'boundary'
-                                               AND type = 'postal_code'
-                                               AND osm_type = 'R'"""))\
-                             .where(p.c.country_code == row.country_code)\
-                             .where(p.c.postcode == row.postcode)\
-                             .limit(1)
+            placex_sql = _select_placex(p)\
+                .add_columns(p.c.importance)\
+                .where(sa.text("""class = 'boundary'
+                                  AND type = 'postal_code'
+                                  AND osm_type = 'R'"""))\
+                .where(p.c.country_code == row.country_code)\
+                .where(p.c.postcode == row.postcode)\
+                .limit(1)
 
             if details.geometry_output:
                 placex_sql = _add_geometry_columns(placex_sql, p.c.geometry, details)
 
             for prow in await conn.execute(placex_sql, _details_to_bind_params(details)):
                 result = nres.create_from_placex_row(prow, nres.SearchResult)
 
             if details.geometry_output:
                 placex_sql = _add_geometry_columns(placex_sql, p.c.geometry, details)
 
             for prow in await conn.execute(placex_sql, _details_to_bind_params(details)):
                 result = nres.create_from_placex_row(prow, nres.SearchResult)
+                if result is not None:
+                    result.bbox = Bbox.from_wkb(prow.bbox)
                 break
             else:
                 result = nres.create_from_postcode_row(row, nres.SearchResult)
                 break
             else:
                 result = nres.create_from_postcode_row(row, nres.SearchResult)
@@ -628,7 +621,6 @@ class PostcodeSearch(AbstractSearch):
         return results
 
 
         return results
 
 
-
 class PlaceSearch(AbstractSearch):
     """ Generic search for an address or named place.
     """
 class PlaceSearch(AbstractSearch):
     """ Generic search for an address or named place.
     """
@@ -644,7 +636,6 @@ class PlaceSearch(AbstractSearch):
         self.rankings = sdata.rankings
         self.expected_count = expected_count
 
         self.rankings = sdata.rankings
         self.expected_count = expected_count
 
-
     def _inner_search_name_cte(self, conn: SearchConnection,
                                details: SearchDetails) -> 'sa.CTE':
         """ Create a subquery that preselects the rows in the search_name
     def _inner_search_name_cte(self, conn: SearchConnection,
                                details: SearchDetails) -> 'sa.CTE':
         """ Create a subquery that preselects the rows in the search_name
@@ -697,7 +688,7 @@ class PlaceSearch(AbstractSearch):
                                                              NEAR_RADIUS_PARAM))
             else:
                 sql = sql.where(t.c.centroid
                                                              NEAR_RADIUS_PARAM))
             else:
                 sql = sql.where(t.c.centroid
-                                   .ST_Distance(NEAR_PARAM) <  NEAR_RADIUS_PARAM)
+                                 .ST_Distance(NEAR_PARAM) < NEAR_RADIUS_PARAM)
 
         if self.housenumbers:
             sql = sql.where(t.c.address_rank.between(16, 30))
 
         if self.housenumbers:
             sql = sql.where(t.c.address_rank.between(16, 30))
@@ -725,8 +716,8 @@ class PlaceSearch(AbstractSearch):
            and (details.near is None or details.near_radius is not None)\
            and not self.qualifiers:
             sql = sql.add_columns(sa.func.first_value(inner.c.penalty - inner.c.importance)
            and (details.near is None or details.near_radius is not None)\
            and not self.qualifiers:
             sql = sql.add_columns(sa.func.first_value(inner.c.penalty - inner.c.importance)
-                                       .over(order_by=inner.c.penalty - inner.c.importance)
-                                       .label('min_penalty'))
+                                    .over(order_by=inner.c.penalty - inner.c.importance)
+                                    .label('min_penalty'))
 
             inner = sql.subquery()
 
 
             inner = sql.subquery()
 
@@ -737,7 +728,6 @@ class PlaceSearch(AbstractSearch):
 
         return sql.cte('searches')
 
 
         return sql.cte('searches')
 
-
     async def lookup(self, conn: SearchConnection,
                      details: SearchDetails) -> nres.SearchResults:
         """ Find results for the search in the database.
     async def lookup(self, conn: SearchConnection,
                      details: SearchDetails) -> nres.SearchResults:
         """ Find results for the search in the database.
@@ -757,8 +747,8 @@ class PlaceSearch(AbstractSearch):
             pcs = self.postcodes.values
 
             pc_near = sa.select(sa.func.min(tpc.c.geometry.ST_Distance(t.c.centroid)))\
             pcs = self.postcodes.values
 
             pc_near = sa.select(sa.func.min(tpc.c.geometry.ST_Distance(t.c.centroid)))\
-                      .where(tpc.c.postcode.in_(pcs))\
-                      .scalar_subquery()
+                        .where(tpc.c.postcode.in_(pcs))\
+                        .scalar_subquery()
             penalty += sa.case((t.c.postcode.in_(pcs), 0.0),
                                else_=sa.func.coalesce(pc_near, cast(SaColumn, 2.0)))
 
             penalty += sa.case((t.c.postcode.in_(pcs), 0.0),
                                else_=sa.func.coalesce(pc_near, cast(SaColumn, 2.0)))
 
@@ -769,13 +759,12 @@ class PlaceSearch(AbstractSearch):
 
         if details.near is not None:
             sql = sql.add_columns((-tsearch.c.centroid.ST_Distance(NEAR_PARAM))
 
         if details.near is not None:
             sql = sql.add_columns((-tsearch.c.centroid.ST_Distance(NEAR_PARAM))
-                                      .label('importance'))
+                                  .label('importance'))
             sql = sql.order_by(sa.desc(sa.text('importance')))
         else:
             sql = sql.order_by(penalty - tsearch.c.importance)
             sql = sql.add_columns(tsearch.c.importance)
 
             sql = sql.order_by(sa.desc(sa.text('importance')))
         else:
             sql = sql.order_by(penalty - tsearch.c.importance)
             sql = sql.add_columns(tsearch.c.importance)
 
-
         sql = sql.add_columns(penalty.label('accuracy'))\
                  .order_by(sa.text('accuracy'))
 
         sql = sql.add_columns(penalty.label('accuracy'))\
                  .order_by(sa.text('accuracy'))
 
@@ -812,7 +801,7 @@ class PlaceSearch(AbstractSearch):
                 tiger_sql = sa.case((inner.c.country_code == 'us',
                                      _make_interpolation_subquery(conn.t.tiger, inner,
                                                                   numerals, details)
                 tiger_sql = sa.case((inner.c.country_code == 'us',
                                      _make_interpolation_subquery(conn.t.tiger, inner,
                                                                   numerals, details)
-                                    ), else_=None)
+                                     ), else_=None)
             else:
                 interpol_sql = sa.null()
                 tiger_sql = sa.null()
             else:
                 interpol_sql = sa.null()
                 tiger_sql = sa.null()
@@ -866,7 +855,7 @@ class PlaceSearch(AbstractSearch):
                 if (not details.excluded or result.place_id not in details.excluded)\
                    and (not self.qualifiers or result.category in self.qualifiers.values)\
                    and result.rank_address >= details.min_rank:
                 if (not details.excluded or result.place_id not in details.excluded)\
                    and (not self.qualifiers or result.category in self.qualifiers.values)\
                    and result.rank_address >= details.min_rank:
-                    result.accuracy += 1.0 # penalty for missing housenumber
+                    result.accuracy += 1.0  # penalty for missing housenumber
                     results.append(result)
             else:
                 results.append(result)
                     results.append(result)
             else:
                 results.append(result)