]> git.openstreetmap.org Git - nominatim.git/blobdiff - nominatim/api/reverse.py
Merge remote-tracking branch 'upstream/master'
[nominatim.git] / nominatim / api / reverse.py
index 63836b4924f18589ddec07b52bc78ff11f942e64..382951082a4d75082fdb0d2cf89fbca3045b5f14 100644 (file)
@@ -99,9 +99,11 @@ class ReverseGeocoder:
         coordinate.
     """
 
-    def __init__(self, conn: SearchConnection, params: ReverseDetails) -> None:
+    def __init__(self, conn: SearchConnection, params: ReverseDetails,
+                 restrict_to_country_areas: bool = False) -> None:
         self.conn = conn
         self.params = params
+        self.restrict_to_country_areas = restrict_to_country_areas
 
         self.bind_params: Dict[str, Any] = {'max_rank': params.max_rank}
 
@@ -477,16 +479,24 @@ class ReverseGeocoder:
         return _get_closest(address_row, other_row)
 
 
-    async def lookup_country(self) -> Optional[SaRow]:
+    async def lookup_country_codes(self) -> List[str]:
         """ Lookup the country for the current search.
         """
         log().section('Reverse lookup by country code')
         t = self.conn.t.country_grid
-        sql: SaLambdaSelect = sa.select(t.c.country_code).distinct()\
+        sql = sa.select(t.c.country_code).distinct()\
                 .where(t.c.geometry.ST_Contains(WKT_PARAM))
 
-        ccodes = tuple((r[0] for r in await self.conn.execute(sql, self.bind_params)))
+        ccodes = [cast(str, r[0]) for r in await self.conn.execute(sql, self.bind_params)]
         log().var_dump('Country codes', ccodes)
+        return ccodes
+
+
+    async def lookup_country(self, ccodes: List[str]) -> Optional[SaRow]:
+        """ Lookup the country for the current search.
+        """
+        if not ccodes:
+            ccodes = await self.lookup_country_codes()
 
         if not ccodes:
             return None
@@ -516,7 +526,7 @@ class ReverseGeocoder:
                     .order_by(sa.desc(inner.c.rank_search), inner.c.distance)\
                     .limit(1)
 
-            sql = sa.lambda_stmt(_base_query)
+            sql: SaLambdaSelect = sa.lambda_stmt(_base_query)
             if self.has_geometries():
                 sql = self._add_geometry_columns(sql, sa.literal_column('area.geometry'))
 
@@ -559,10 +569,19 @@ class ReverseGeocoder:
             row, tmp_row_func = await self.lookup_street_poi()
             if row is not None:
                 row_func = tmp_row_func
-        if row is None and self.max_rank > 4:
-            row = await self.lookup_area()
-        if row is None and self.layer_enabled(DataLayer.ADDRESS):
-            row = await self.lookup_country()
+
+        if row is None:
+            if self.restrict_to_country_areas:
+                ccodes = await self.lookup_country_codes()
+                if not ccodes:
+                    return None
+            else:
+                ccodes = []
+
+            if self.max_rank > 4:
+                row = await self.lookup_area()
+            if row is None and self.layer_enabled(DataLayer.ADDRESS):
+                row = await self.lookup_country(ccodes)
 
         result = row_func(row, nres.ReverseResult)
         if result is not None: