git.openstreetmap.org Git - nominatim.git/blobdiff - src/nominatim_api/lookup.py
Update lookup.py - Correct spelling for "simultaneously"
[nominatim.git] / src / nominatim_api / lookup.py
index e451edbee0a4af36e53bee3687082d1a6f6e803c..6ab3cf192e929c0362f8b7e41e050d85e1da7838 100644 (file)
@@ -5,14 +5,15 @@
 # Copyright (C) 2024 by the Nominatim developer community.
 # For a full list of authors see the git log.
 """
-Implementation of place lookup by ID.
+Implementation of place lookup by ID (doing many places at once).
 """
-from typing import Optional, Callable, Tuple, Type
+from typing import Optional, Callable, Type, Iterable, Tuple, Union
+from dataclasses import dataclass
 import datetime as dt
 
 import sqlalchemy as sa
 
-from nominatim_core.typing import SaColumn, SaRow, SaSelect
+from .typing import SaColumn, SaRow, SaSelect
 from .connection import SearchConnection
 from .logging import log
 from . import types as ntyp
@@ -20,135 +21,128 @@ from . import results as nres
 
 RowFunc = Callable[[Optional[SaRow], Type[nres.BaseResultT]], Optional[nres.BaseResultT]]
 
-GeomFunc = Callable[[SaSelect, SaColumn], SaSelect]
+GEOMETRY_TYPE_MAP = {
+    'POINT': 'ST_Point',
+    'MULTIPOINT': 'ST_MultiPoint',
+    'LINESTRING': 'ST_LineString',
+    'MULTILINESTRING': 'ST_MultiLineString',
+    'POLYGON': 'ST_Polygon',
+    'MULTIPOLYGON': 'ST_MultiPolygon',
+    'GEOMETRYCOLLECTION': 'ST_GeometryCollection'
+}
 
 
-async def find_in_placex(conn: SearchConnection, place: ntyp.PlaceRef,
-                         add_geometries: GeomFunc) -> Optional[SaRow]:
-    """ Search for the given place in the placex table and return the
-        base information.
+@dataclass
+class LookupTuple:
+    """ Data class saving the SQL result for a single lookup.
     """
-    log().section("Find in placex table")
-    t = conn.t.placex
-    sql = sa.select(t.c.place_id, t.c.osm_type, t.c.osm_id, t.c.name,
-                    t.c.class_, t.c.type, t.c.admin_level,
-                    t.c.address, t.c.extratags,
-                    t.c.housenumber, t.c.postcode, t.c.country_code,
-                    t.c.importance, t.c.wikipedia, t.c.indexed_date,
-                    t.c.parent_place_id, t.c.rank_address, t.c.rank_search,
-                    t.c.linked_place_id,
-                    t.c.geometry.ST_Expand(0).label('bbox'),
-                    t.c.centroid)
+    pid: ntyp.PlaceRef
+    result: Optional[nres.SearchResult] = None
 
-    if isinstance(place, ntyp.PlaceID):
-        sql = sql.where(t.c.place_id == place.place_id)
-    elif isinstance(place, ntyp.OsmID):
-        sql = sql.where(t.c.osm_type == place.osm_type)\
-                 .where(t.c.osm_id == place.osm_id)
-        if place.osm_class:
-            sql = sql.where(t.c.class_ == place.osm_class)
-        else:
-            sql = sql.order_by(t.c.class_)
-        sql = sql.limit(1)
-    else:
-        return None
 
-    return (await conn.execute(add_geometries(sql, t.c.geometry))).one_or_none()
+class LookupCollector:
+    """ Result collector for the simple lookup.
 
-
-async def find_in_osmline(conn: SearchConnection, place: ntyp.PlaceRef,
-                          add_geometries: GeomFunc) -> Optional[SaRow]:
-    """ Search for the given place in the osmline table and return the
-        base information.
+        Allows for lookup of multiple places simultaneously.
     """
-    log().section("Find in interpolation table")
-    t = conn.t.osmline
-    sql = sa.select(t.c.place_id, t.c.osm_id, t.c.parent_place_id,
-                    t.c.indexed_date, t.c.startnumber, t.c.endnumber,
-                    t.c.step, t.c.address, t.c.postcode, t.c.country_code,
-                    t.c.linegeo.ST_Centroid().label('centroid'))
 
-    if isinstance(place, ntyp.PlaceID):
-        sql = sql.where(t.c.place_id == place.place_id)
-    elif isinstance(place, ntyp.OsmID) and place.osm_type == 'W':
-        # There may be multiple interpolations for a single way.
-        # If 'class' contains a number, return the one that belongs to that number.
-        sql = sql.where(t.c.osm_id == place.osm_id).limit(1)
-        if place.osm_class and place.osm_class.isdigit():
-            sql = sql.order_by(sa.func.greatest(0,
-                                                int(place.osm_class) - t.c.endnumber,
-                                                t.c.startnumber - int(place.osm_class)))
-    else:
-        return None
-
-    return (await conn.execute(add_geometries(sql, t.c.linegeo))).one_or_none()
-
-
-async def find_in_tiger(conn: SearchConnection, place: ntyp.PlaceRef,
-                        add_geometries: GeomFunc) -> Optional[SaRow]:
-    """ Search for the given place in the table of Tiger addresses and return
-        the base information. Only lookup by place ID is supported.
-    """
-    if not isinstance(place, ntyp.PlaceID):
-        return None
-
-    log().section("Find in TIGER table")
-    t = conn.t.tiger
-    parent = conn.t.placex
-    sql = sa.select(t.c.place_id, t.c.parent_place_id,
-                    parent.c.osm_type, parent.c.osm_id,
-                    t.c.startnumber, t.c.endnumber, t.c.step,
-                    t.c.postcode,
-                    t.c.linegeo.ST_Centroid().label('centroid'))\
-            .where(t.c.place_id == place.place_id)\
-            .join(parent, t.c.parent_place_id == parent.c.place_id, isouter=True)
-
-    return (await conn.execute(add_geometries(sql, t.c.linegeo))).one_or_none()
-
-
-async def find_in_postcode(conn: SearchConnection, place: ntyp.PlaceRef,
-                           add_geometries: GeomFunc) -> Optional[SaRow]:
-    """ Search for the given place in the postcode table and return the
-        base information. Only lookup by place ID is supported.
-    """
-    if not isinstance(place, ntyp.PlaceID):
-        return None
+    def __init__(self, places: Iterable[ntyp.PlaceRef],
+                 details: ntyp.LookupDetails) -> None:
+        self.details = details
+        self.lookups = [LookupTuple(p) for p in places]
 
-    log().section("Find in postcode table")
-    t = conn.t.postcode
-    sql = sa.select(t.c.place_id, t.c.parent_place_id,
-                    t.c.rank_search, t.c.rank_address,
-                    t.c.indexed_date, t.c.postcode, t.c.country_code,
-                    t.c.geometry.label('centroid')) \
-            .where(t.c.place_id == place.place_id)
+    def get_results(self) -> nres.SearchResults:
+        """ Return the list of results available.
+        """
+        return nres.SearchResults(p.result for p in self.lookups if p.result is not None)
+
+    async def add_rows_from_sql(self, conn: SearchConnection, sql: SaSelect,
+                                col: SaColumn, row_func: RowFunc[nres.SearchResult]) -> bool:
+        if self.details.geometry_output:
+            if self.details.geometry_simplification > 0.0:
+                col = sa.func.ST_SimplifyPreserveTopology(
+                    col, self.details.geometry_simplification)
+
+            if self.details.geometry_output & ntyp.GeometryFormat.GEOJSON:
+                sql = sql.add_columns(sa.func.ST_AsGeoJSON(col, 7).label('geometry_geojson'))
+            if self.details.geometry_output & ntyp.GeometryFormat.TEXT:
+                sql = sql.add_columns(sa.func.ST_AsText(col).label('geometry_text'))
+            if self.details.geometry_output & ntyp.GeometryFormat.KML:
+                sql = sql.add_columns(sa.func.ST_AsKML(col, 7).label('geometry_kml'))
+            if self.details.geometry_output & ntyp.GeometryFormat.SVG:
+                sql = sql.add_columns(sa.func.ST_AsSVG(col, 0, 7).label('geometry_svg'))
+
+        for row in await conn.execute(sql):
+            result = row_func(row, nres.SearchResult)
+            assert result is not None
+            if hasattr(row, 'bbox'):
+                result.bbox = ntyp.Bbox.from_wkb(row.bbox)
+
+            if self.lookups[row._idx].result is None:
+                self.lookups[row._idx].result = result
 
-    return (await conn.execute(add_geometries(sql, t.c.geometry))).one_or_none()
+        return all(p.result is not None for p in self.lookups)
 
+    def enumerate_free_place_ids(self) -> Iterable[Tuple[int, ntyp.PlaceID]]:
+        return ((i, p.pid) for i, p in enumerate(self.lookups)
+                if p.result is None and isinstance(p.pid, ntyp.PlaceID))
 
-async def find_in_all_tables(conn: SearchConnection, place: ntyp.PlaceRef,
-                             add_geometries: GeomFunc
-                            ) -> Tuple[Optional[SaRow], RowFunc[nres.BaseResultT]]:
-    """ Search for the given place in all data tables
-        and return the base information.
+    def enumerate_free_osm_ids(self) -> Iterable[Tuple[int, ntyp.OsmID]]:
+        return ((i, p.pid) for i, p in enumerate(self.lookups)
+                if p.result is None and isinstance(p.pid, ntyp.OsmID))
+
+
+class DetailedCollector:
+    """ Result collector for detailed lookup.
+
+        Only one place at a time may be looked up.
     """
-    row = await find_in_placex(conn, place, add_geometries)
-    log().var_dump('Result (placex)', row)
-    if row is not None:
-        return row, nres.create_from_placex_row
 
-    row = await find_in_osmline(conn, place, add_geometries)
-    log().var_dump('Result (osmline)', row)
-    if row is not None:
-        return row, nres.create_from_osmline_row
+    def __init__(self, place: ntyp.PlaceRef, with_geometry: bool) -> None:
+        self.with_geometry = with_geometry
+        self.place = place
+        self.result: Optional[nres.DetailedResult] = None
+
+    async def add_rows_from_sql(self, conn: SearchConnection, sql: SaSelect,
+                                col: SaColumn, row_func: RowFunc[nres.DetailedResult]) -> bool:
+        if self.with_geometry:
+            sql = sql.add_columns(
+                sa.func.ST_AsGeoJSON(
+                    sa.case((sa.func.ST_NPoints(col) > 5000,
+                             sa.func.ST_SimplifyPreserveTopology(col, 0.0001)),
+                            else_=col), 7).label('geometry_geojson'))
+        else:
+            sql = sql.add_columns(sa.func.ST_GeometryType(col).label('geometry_type'))
+
+        for row in await conn.execute(sql):
+            self.result = row_func(row, nres.DetailedResult)
+            assert self.result is not None
+            # add missing details
+            if 'type' in self.result.geometry:
+                self.result.geometry['type'] = \
+                    GEOMETRY_TYPE_MAP.get(self.result.geometry['type'],
+                                          self.result.geometry['type'])
+            indexed_date = getattr(row, 'indexed_date', None)
+            if indexed_date is not None:
+                self.result.indexed_date = indexed_date.replace(tzinfo=dt.timezone.utc)
+
+            return True
 
-    row = await find_in_postcode(conn, place, add_geometries)
-    log().var_dump('Result (postcode)', row)
-    if row is not None:
-        return row, nres.create_from_postcode_row
+        # Nothing found.
+        return False
 
-    row = await find_in_tiger(conn, place, add_geometries)
-    log().var_dump('Result (tiger)', row)
-    return row, nres.create_from_tiger_row
+    def enumerate_free_place_ids(self) -> Iterable[Tuple[int, ntyp.PlaceID]]:
+        if self.result is None and isinstance(self.place, ntyp.PlaceID):
+            return [(0, self.place)]
+        return []
+
+    def enumerate_free_osm_ids(self) -> Iterable[Tuple[int, ntyp.OsmID]]:
+        if self.result is None and isinstance(self.place, ntyp.OsmID):
+            return [(0, self.place)]
+        return []
+
+
+Collector = Union[LookupCollector, DetailedCollector]
 
 
 async def get_detailed_place(conn: SearchConnection, place: ntyp.PlaceRef,
@@ -160,91 +154,180 @@ async def get_detailed_place(conn: SearchConnection, place: ntyp.PlaceRef,
     if details.geometry_output and details.geometry_output != ntyp.GeometryFormat.GEOJSON:
         raise ValueError("lookup only supports geojosn polygon output.")
 
-    if details.geometry_output & ntyp.GeometryFormat.GEOJSON:
-        def _add_geometry(sql: SaSelect, column: SaColumn) -> SaSelect:
-            return sql.add_columns(sa.func.ST_AsGeoJSON(
-                                    sa.case((sa.func.ST_NPoints(column) > 5000,
-                                             sa.func.ST_SimplifyPreserveTopology(column, 0.0001)),
-                                            else_=column), 7).label('geometry_geojson'))
-    else:
-        def _add_geometry(sql: SaSelect, column: SaColumn) -> SaSelect:
-            return sql.add_columns(sa.func.ST_GeometryType(column).label('geometry_type'))
+    collector = DetailedCollector(place,
+                                  bool(details.geometry_output & ntyp.GeometryFormat.GEOJSON))
 
-    row_func: RowFunc[nres.DetailedResult]
-    row, row_func = await find_in_all_tables(conn, place, _add_geometry)
+    for func in (find_in_placex, find_in_osmline, find_in_postcode, find_in_tiger):
+        if await func(conn, collector):
+            break
 
-    if row is None:
-        return None
+    if collector.result is not None:
+        await nres.add_result_details(conn, [collector.result], details)
 
-    result = row_func(row, nres.DetailedResult)
-    assert result is not None
+    return collector.result
 
-    # add missing details
-    assert result is not None
-    if 'type' in result.geometry:
-        result.geometry['type'] = GEOMETRY_TYPE_MAP.get(result.geometry['type'],
-                                                        result.geometry['type'])
-    indexed_date = getattr(row, 'indexed_date', None)
-    if indexed_date is not None:
-        result.indexed_date = indexed_date.replace(tzinfo=dt.timezone.utc)
 
-    await nres.add_result_details(conn, [result], details)
+async def get_places(conn: SearchConnection, places: Iterable[ntyp.PlaceRef],
+                     details: ntyp.LookupDetails) -> nres.SearchResults:
+    """ Retrieve a list of places as simple search results from the
+        database.
+    """
+    log().function('get_places', places=places, details=details)
+
+    collector = LookupCollector(places, details)
 
-    return result
+    for func in (find_in_placex, find_in_osmline, find_in_postcode, find_in_tiger):
+        if await func(conn, collector):
+            break
 
+    results = collector.get_results()
+    await nres.add_result_details(conn, results, details)
 
-async def get_simple_place(conn: SearchConnection, place: ntyp.PlaceRef,
-                           details: ntyp.LookupDetails) -> Optional[nres.SearchResult]:
-    """ Retrieve a place as a simple search result from the database.
+    return results
+
+
+async def find_in_placex(conn: SearchConnection, collector: Collector) -> bool:
+    """ Search for the given places in the main placex table.
     """
-    log().function('get_simple_place', place=place, details=details)
+    log().section("Find in placex table")
+    t = conn.t.placex
+    sql = sa.select(t.c.place_id, t.c.osm_type, t.c.osm_id, t.c.name,
+                    t.c.class_, t.c.type, t.c.admin_level,
+                    t.c.address, t.c.extratags,
+                    t.c.housenumber, t.c.postcode, t.c.country_code,
+                    t.c.importance, t.c.wikipedia, t.c.indexed_date,
+                    t.c.parent_place_id, t.c.rank_address, t.c.rank_search,
+                    t.c.linked_place_id,
+                    t.c.geometry.ST_Expand(0).label('bbox'),
+                    t.c.centroid)
 
-    def _add_geometry(sql: SaSelect, col: SaColumn) -> SaSelect:
-        if not details.geometry_output:
-            return sql
+    osm_ids = [{'i': i, 'ot': p.osm_type, 'oi': p.osm_id, 'oc': p.osm_class or ''}
+               for i, p in collector.enumerate_free_osm_ids()]
 
-        out = []
+    if osm_ids:
+        oid_tab = sa.func.JsonArrayEach(sa.type_coerce(osm_ids, sa.JSON))\
+                    .table_valued(sa.column('value', type_=sa.JSON))
+        psql = sql.add_columns(oid_tab.c.value['i'].as_integer().label('_idx'))\
+                  .where(t.c.osm_type == oid_tab.c.value['ot'].as_string())\
+                  .where(t.c.osm_id == oid_tab.c.value['oi'].as_string().cast(sa.BigInteger))\
+                  .where(sa.or_(oid_tab.c.value['oc'].as_string() == '',
+                                oid_tab.c.value['oc'].as_string() == t.c.class_))\
+                  .order_by(t.c.class_)
 
-        if details.geometry_simplification > 0.0:
-            col = sa.func.ST_SimplifyPreserveTopology(col, details.geometry_simplification)
+        if await collector.add_rows_from_sql(conn, psql, t.c.geometry,
+                                             nres.create_from_placex_row):
+            return True
 
-        if details.geometry_output & ntyp.GeometryFormat.GEOJSON:
-            out.append(sa.func.ST_AsGeoJSON(col, 7).label('geometry_geojson'))
-        if details.geometry_output & ntyp.GeometryFormat.TEXT:
-            out.append(sa.func.ST_AsText(col).label('geometry_text'))
-        if details.geometry_output & ntyp.GeometryFormat.KML:
-            out.append(sa.func.ST_AsKML(col, 7).label('geometry_kml'))
-        if details.geometry_output & ntyp.GeometryFormat.SVG:
-            out.append(sa.func.ST_AsSVG(col, 0, 7).label('geometry_svg'))
+    place_ids = [{'i': i, 'id': p.place_id}
+                 for i, p in collector.enumerate_free_place_ids()]
 
-        return sql.add_columns(*out)
+    if place_ids:
+        pid_tab = sa.func.JsonArrayEach(sa.type_coerce(place_ids, sa.JSON))\
+                    .table_valued(sa.column('value', type_=sa.JSON))
+        psql = sql.add_columns(pid_tab.c.value['i'].as_integer().label('_idx'))\
+                  .where(t.c.place_id == pid_tab.c.value['id'].as_string().cast(sa.BigInteger))
 
+        return await collector.add_rows_from_sql(conn, psql, t.c.geometry,
+                                                 nres.create_from_placex_row)
 
-    row_func: RowFunc[nres.SearchResult]
-    row, row_func = await find_in_all_tables(conn, place, _add_geometry)
+    return False
 
-    if row is None:
-        return None
 
-    result = row_func(row, nres.SearchResult)
-    assert result is not None
+async def find_in_osmline(conn: SearchConnection, collector: Collector) -> bool:
+    """ Search for the given places in the table for address interpolations.
 
-    # add missing details
-    assert result is not None
-    if hasattr(row, 'bbox'):
-        result.bbox = ntyp.Bbox.from_wkb(row.bbox)
+        Return true when all places have been resolved.
+    """
+    log().section("Find in interpolation table")
+    t = conn.t.osmline
+    sql = sa.select(t.c.place_id, t.c.osm_id, t.c.parent_place_id,
+                    t.c.indexed_date, t.c.startnumber, t.c.endnumber,
+                    t.c.step, t.c.address, t.c.postcode, t.c.country_code,
+                    t.c.linegeo.ST_Centroid().label('centroid'))
 
-    await nres.add_result_details(conn, [result], details)
+    osm_ids = [{'i': i, 'oi': p.osm_id, 'oc': p.class_as_housenumber()}
+               for i, p in collector.enumerate_free_osm_ids() if p.osm_type == 'W']
 
-    return result
+    if osm_ids:
+        oid_tab = sa.func.JsonArrayEach(sa.type_coerce(osm_ids, sa.JSON))\
+                    .table_valued(sa.column('value', type_=sa.JSON))
+        psql = sql.add_columns(oid_tab.c.value['i'].as_integer().label('_idx'))\
+                  .where(t.c.osm_id == oid_tab.c.value['oi'].as_string().cast(sa.BigInteger))\
+                  .order_by(sa.func.greatest(0,
+                                             oid_tab.c.value['oc'].as_integer() - t.c.endnumber,
+                                             t.c.startnumber - oid_tab.c.value['oc'].as_integer()))
 
+        if await collector.add_rows_from_sql(conn, psql, t.c.linegeo,
+                                             nres.create_from_osmline_row):
+            return True
 
-GEOMETRY_TYPE_MAP = {
-    'POINT': 'ST_Point',
-    'MULTIPOINT': 'ST_MultiPoint',
-    'LINESTRING': 'ST_LineString',
-    'MULTILINESTRING': 'ST_MultiLineString',
-    'POLYGON': 'ST_Polygon',
-    'MULTIPOLYGON': 'ST_MultiPolygon',
-    'GEOMETRYCOLLECTION': 'ST_GeometryCollection'
-}
+    place_ids = [{'i': i, 'id': p.place_id}
+                 for i, p in collector.enumerate_free_place_ids()]
+
+    if place_ids:
+        pid_tab = sa.func.JsonArrayEach(sa.type_coerce(place_ids, sa.JSON))\
+                    .table_valued(sa.column('value', type_=sa.JSON))
+        psql = sql.add_columns(pid_tab.c.value['i'].label('_idx'))\
+                  .where(t.c.place_id == pid_tab.c.value['id'].as_string().cast(sa.BigInteger))
+
+        return await collector.add_rows_from_sql(conn, psql, t.c.linegeo,
+                                                 nres.create_from_osmline_row)
+
+    return False
+
+
+async def find_in_postcode(conn: SearchConnection, collector: Collector) -> bool:
+    """ Search for the given places in the postcode table.
+
+        Return true when all places have been resolved.
+    """
+    log().section("Find in postcode table")
+
+    place_ids = [{'i': i, 'id': p.place_id}
+                 for i, p in collector.enumerate_free_place_ids()]
+
+    if place_ids:
+        pid_tab = sa.func.JsonArrayEach(sa.type_coerce(place_ids, sa.JSON))\
+                    .table_valued(sa.column('value', type_=sa.JSON))
+        t = conn.t.postcode
+        sql = sa.select(pid_tab.c.value['i'].as_integer().label('_idx'),
+                        t.c.place_id, t.c.parent_place_id,
+                        t.c.rank_search, t.c.rank_address,
+                        t.c.indexed_date, t.c.postcode, t.c.country_code,
+                        t.c.geometry.label('centroid'))\
+                .where(t.c.place_id == pid_tab.c.value['id'].as_string().cast(sa.BigInteger))
+
+        return await collector.add_rows_from_sql(conn, sql, t.c.geometry,
+                                                 nres.create_from_postcode_row)
+
+    return False
+
+
+async def find_in_tiger(conn: SearchConnection, collector: Collector) -> bool:
+    """ Search for the given places in the TIGER address table.
+
+        Return true when all places have been resolved.
+    """
+    log().section("Find in tiger table")
+
+    place_ids = [{'i': i, 'id': p.place_id}
+                 for i, p in collector.enumerate_free_place_ids()]
+
+    if place_ids:
+        pid_tab = sa.func.JsonArrayEach(sa.type_coerce(place_ids, sa.JSON))\
+                    .table_valued(sa.column('value', type_=sa.JSON))
+        t = conn.t.tiger
+        parent = conn.t.placex
+        sql = sa.select(pid_tab.c.value['i'].as_integer().label('_idx'),
+                        t.c.place_id, t.c.parent_place_id,
+                        parent.c.osm_type, parent.c.osm_id,
+                        t.c.startnumber, t.c.endnumber, t.c.step,
+                        t.c.postcode,
+                        t.c.linegeo.ST_Centroid().label('centroid'))\
+                .join(parent, t.c.parent_place_id == parent.c.place_id, isouter=True)\
+                .where(t.c.place_id == pid_tab.c.value['id'].as_string().cast(sa.BigInteger))
+
+        return await collector.add_rows_from_sql(conn, sql, t.c.linegeo,
+                                                 nres.create_from_tiger_row)
+
+    return False