git.openstreetmap.org Git - nominatim.git/commitdiff
implement actual database searches
author: Sarah Hoffmann <lonvia@denofr.de>
Wed, 24 May 2023 11:52:31 +0000 (13:52 +0200)
committer: Sarah Hoffmann <lonvia@denofr.de>
Wed, 24 May 2023 11:52:31 +0000 (13:52 +0200)
13 files changed:
.pylintrc
nominatim/api/connection.py
nominatim/api/results.py
nominatim/api/search/db_search_fields.py
nominatim/api/search/db_searches.py
nominatim/api/types.py
nominatim/typing.py
test/python/api/conftest.py
test/python/api/search/test_search_country.py [new file with mode: 0644]
test/python/api/search/test_search_near.py [new file with mode: 0644]
test/python/api/search/test_search_places.py [new file with mode: 0644]
test/python/api/search/test_search_poi.py [new file with mode: 0644]
test/python/api/search/test_search_postcode.py [new file with mode: 0644]

index f2d3491f27d60b6ac61bd44865844f3f78a38aa0..c1384c00958a681767ea117f602f24f796591222 100644 (file)
--- a/.pylintrc
+++ b/.pylintrc
@@ -15,4 +15,4 @@ ignored-classes=NominatimArgs,closing
 #   typed Python is enabled. See also https://github.com/PyCQA/pylint/issues/5273
 disable=too-few-public-methods,duplicate-code,too-many-ancestors,bad-option-value,no-self-use,not-context-manager,use-dict-literal,chained-comparison,attribute-defined-outside-init
 
-good-names=i,j,x,y,m,t,fd,db,cc,x1,x2,y1,y2,pt,k,v
+good-names=i,j,x,y,m,t,fd,db,cc,x1,x2,y1,y2,pt,k,v,nr
index efa4490e0cbf6d263398a4cf21d2f883b419bad9..e157d06208a48403de26c7b87189c087a6dcc99c 100644 (file)
@@ -7,11 +7,13 @@
 """
 Extended SQLAlchemy connection class that also includes access to the schema.
 """
-from typing import Any, Mapping, Sequence, Union, Dict, cast
+from typing import cast, Any, Mapping, Sequence, Union, Dict, Optional, Set
 
 import sqlalchemy as sa
+from geoalchemy2 import Geometry
 from sqlalchemy.ext.asyncio import AsyncConnection
 
+from nominatim.typing import SaFromClause
 from nominatim.db.sqlalchemy_schema import SearchTables
 from nominatim.api.logging import log
 
@@ -28,6 +30,7 @@ class SearchConnection:
         self.connection = conn
         self.t = tables # pylint: disable=invalid-name
         self._property_cache = properties
+        self._classtables: Optional[Set[str]] = None
 
 
     async def scalar(self, sql: sa.sql.base.Executable,
@@ -87,3 +90,26 @@ class SearchConnection:
             raise ValueError(f"DB setting '{name}' not found in database.")
 
         return self._property_cache['DB:server_version']
+
+
+    async def get_class_table(self, cls: str, typ: str) -> Optional[SaFromClause]:
+        """ Look up if there is a classtype table for the given category
+            and return a SQLAlchemy table for it, or None if it does not exist.
+        """
+        if self._classtables is None:  # table list is read once, then cached on this object
+            res = await self.execute(sa.text("""SELECT tablename FROM pg_tables
+                                                WHERE tablename LIKE 'place_classtype_%'
+                                             """))
+            self._classtables = {r[0] for r in res}
+
+        tablename = f"place_classtype_{cls}_{typ}"
+
+        if tablename not in self._classtables:
+            return None
+
+        if tablename in self.t.meta.tables:  # reuse the definition from an earlier call
+            return self.t.meta.tables[tablename]
+
+        return sa.Table(tablename, self.t.meta,  # declares only place_id and centroid
+                        sa.Column('place_id', sa.BigInteger),
+                        sa.Column('centroid', Geometry(srid=4326, spatial_index=False)))
index 56243e8d7672121c7cf5471f46c19f09aa12a799..1c313398e0d94b88ec68e32e16f8de6da06d86d7 100644 (file)
@@ -179,6 +179,15 @@ class SearchResult(BaseResult):
     """ A search result for forward geocoding.
     """
     bbox: Optional[Bbox] = None
+    accuracy: float = 0.0
+
+
+    @property
+    def ranking(self) -> float:
+        """ Return the ranking, a combined measure of accuracy and importance.
+        """
+        return (self.accuracy if self.accuracy is not None else 1) \
+               - self.calculated_importance()  # accuracy (1 if unset) minus importance
 
 
 class SearchResults(List[SearchResult]):
@@ -306,6 +315,23 @@ def create_from_postcode_row(row: Optional[SaRow],
                       geometry=_filter_geometries(row))
 
 
+def create_from_country_row(row: Optional[SaRow],
+                        class_type: Type[BaseResultT]) -> Optional[BaseResultT]:
+    """ Construct a new result from a row of the fallback country tables
+        (reads 'centroid', 'name' and 'country_code'). 'class_type' defines
+        the type of result to return. Returns None if the row is None.
+    """
+    if row is None:
+        return None
+
+    return class_type(source_table=SourceTable.COUNTRY,  # always a rank-4 place=country
+                      category=('place', 'country'),
+                      centroid=Point.from_wkb(row.centroid.data),
+                      names=row.name,
+                      rank_address=4, rank_search=4,
+                      country_code=row.country_code)
+
+
 async def add_result_details(conn: SearchConnection, result: BaseResult,
                              details: LookupDetails) -> None:
     """ Retrieve more details from the database according to the
index 9fcc2c4e521e9aa3ba55207edcf438af79a26949..325e08df3559f9afd362242c69ec043c4ff617b0 100644 (file)
@@ -7,13 +7,13 @@
 """
 Data structures for more complex fields in abstract search descriptions.
 """
-from typing import List, Tuple, cast
+from typing import List, Tuple, Iterator, cast
 import dataclasses
 
 import sqlalchemy as sa
 from sqlalchemy.dialects.postgresql import ARRAY
 
-from nominatim.typing import SaFromClause, SaColumn
+from nominatim.typing import SaFromClause, SaColumn, SaExpression
 from nominatim.api.search.query import Token
 
 @dataclasses.dataclass
@@ -27,6 +27,21 @@ class WeightedStrings:
         return bool(self.values)
 
 
+    def __iter__(self) -> Iterator[Tuple[str, float]]:  # yields (value, penalty) pairs
+        return iter(zip(self.values, self.penalties))
+
+
+    def get_penalty(self, value: str, default: float = 1000.0) -> float:
+        """ Get the penalty for the given value (its first occurrence).
+            Returns the given default if the value does not exist.
+        """
+        try:
+            return self.penalties[self.values.index(value)]
+        except ValueError:  # list.index() raises this when the value is absent
+            pass
+        return default
+
+
 @dataclasses.dataclass
 class WeightedCategories:
     """ A list of class/type tuples together with a penalty.
@@ -38,6 +53,36 @@ class WeightedCategories:
         return bool(self.values)
 
 
+    def __iter__(self) -> Iterator[Tuple[Tuple[str, str], float]]:  # ((class, type), penalty)
+        return iter(zip(self.values, self.penalties))
+
+
+    def get_penalty(self, value: Tuple[str, str], default: float = 1000.0) -> float:
+        """ Get the penalty for the given (class, type) value (first occurrence).
+            Returns the given default if the value does not exist.
+        """
+        try:
+            return self.penalties[self.values.index(value)]
+        except ValueError:  # list.index() raises this when the value is absent
+            pass
+        return default
+
+
+    def sql_restrict(self, table: SaFromClause) -> SaExpression:
+        """ Return an SQLAlchemy expression that restricts the
+            class and type columns of the given table to the values
+            in the list.
+            Must not be used with an empty list.
+        """
+        assert self.values
+        if len(self.values) == 1:  # single category: plain AND without OR wrapper
+            return sa.and_(table.c.class_ == self.values[0][0],
+                           table.c.type == self.values[0][1])
+
+        return sa.or_(*(sa.and_(table.c.class_ == c, table.c.type == t)
+                        for c, t in self.values))
+
+
 @dataclasses.dataclass(order=True)
 class RankedTokens:
     """ List of tokens together with the penalty of using it.
index f0d75ad1f301fa5ceee5127f6035445f9c8ecdf3..9a94b4f658bb3ec49d6388f0400a571716875b57 100644 (file)
 """
 Implementation of the acutal database accesses for forward search.
 """
+from typing import List, Tuple, AsyncIterator
 import abc
 
+import sqlalchemy as sa
+from sqlalchemy.dialects.postgresql import ARRAY, array_agg
+
+from nominatim.typing import SaFromClause, SaScalarSelect, SaColumn, \
+                             SaExpression, SaSelect, SaRow
 from nominatim.api.connection import SearchConnection
-from nominatim.api.types import SearchDetails
+from nominatim.api.types import SearchDetails, DataLayer, GeometryFormat, Bbox
 import nominatim.api.results as nres
 from nominatim.api.search.db_search_fields import SearchData, WeightedCategories
 
+#pylint: disable=singleton-comparison
+#pylint: disable=too-many-branches,too-many-arguments,too-many-locals,too-many-statements
+
+def _select_placex(t: SaFromClause) -> SaSelect:  # base query for building placex results
+    return sa.select(t.c.place_id, t.c.osm_type, t.c.osm_id, t.c.name,
+                     t.c.class_, t.c.type,
+                     t.c.address, t.c.extratags,
+                     t.c.housenumber, t.c.postcode, t.c.country_code,
+                     t.c.importance, t.c.wikipedia,
+                     t.c.parent_place_id, t.c.rank_address, t.c.rank_search,
+                     t.c.centroid,
+                     t.c.geometry.ST_Expand(0).label('bbox'))  # bounding box of the geometry
+
+
+def _add_geometry_columns(sql: SaSelect, col: SaColumn, details: SearchDetails) -> SaSelect:
+    if not details.geometry_output:  # no geometry requested: query returned unchanged
+        return sql
+
+    out = []  # one labelled column per requested output format
+
+    if details.geometry_simplification > 0.0:  # simplify once, before formatting
+        col = col.ST_SimplifyPreserveTopology(details.geometry_simplification)
+
+    if details.geometry_output & GeometryFormat.GEOJSON:
+        out.append(col.ST_AsGeoJSON().label('geometry_geojson'))
+    if details.geometry_output & GeometryFormat.TEXT:
+        out.append(col.ST_AsText().label('geometry_text'))
+    if details.geometry_output & GeometryFormat.KML:
+        out.append(col.ST_AsKML().label('geometry_kml'))
+    if details.geometry_output & GeometryFormat.SVG:
+        out.append(col.ST_AsSVG().label('geometry_svg'))
+
+    return sql.add_columns(*out)
+
+
+def _make_interpolation_subquery(table: SaFromClause, inner: SaFromClause,
+                                 numerals: List[int], details: SearchDetails) -> SaScalarSelect:
+    all_ids = array_agg(table.c.place_id) # type: ignore[no-untyped-call]
+    sql = sa.select(all_ids).where(table.c.parent_place_id == inner.c.place_id)
+
+    if len(numerals) == 1:  # single number: no OR needed
+        sql = sql.where(sa.between(numerals[0], table.c.startnumber, table.c.endnumber))\
+                 .where((numerals[0] - table.c.startnumber) % table.c.step == 0)
+    else:  # any of the numbers inside the range and on the step
+        sql = sql.where(sa.or_(
+                *(sa.and_(sa.between(n, table.c.startnumber, table.c.endnumber),
+                          (n - table.c.startnumber) % table.c.step == 0)
+                  for n in numerals)))
+
+    if details.excluded:
+        sql = sql.where(table.c.place_id.not_in(details.excluded))
+
+    return sql.scalar_subquery()  # array of matching place_ids
+
+
+def _filter_by_layer(table: SaFromClause, layers: DataLayer) -> SaColumn:  # OR of layer conditions
+    orexpr: List[SaExpression] = []
+    if layers & DataLayer.ADDRESS and layers & DataLayer.POI:
+        orexpr.append(table.c.rank_address.between(1, 30))
+    elif layers & DataLayer.ADDRESS:
+        orexpr.append(table.c.rank_address.between(1, 29))
+        orexpr.append(sa.and_(table.c.rank_address == 30,  # rank 30 only with number/name
+                              sa.or_(table.c.housenumber != None,
+                                     table.c.address.has_key('housename'))))
+    elif layers & DataLayer.POI:
+        orexpr.append(sa.and_(table.c.rank_address == 30,
+                              table.c.class_.not_in(('place', 'building'))))
+
+    if layers & DataLayer.MANMADE:  # rank-0 features minus classes of unrequested layers
+        exclude = []
+        if not layers & DataLayer.RAILWAY:
+            exclude.append('railway')
+        if not layers & DataLayer.NATURAL:
+            exclude.extend(('natural', 'water', 'waterway'))
+        orexpr.append(sa.and_(table.c.class_.not_in(tuple(exclude)),
+                              table.c.rank_address == 0))
+    else:  # rank-0 features of the requested classes only
+        include = []
+        if layers & DataLayer.RAILWAY:
+            include.append('railway')
+        if layers & DataLayer.NATURAL:
+            include.extend(('natural', 'water', 'waterway'))
+        orexpr.append(sa.and_(table.c.class_.in_(tuple(include)),  # empty include matches nothing
+                              table.c.rank_address == 0))
+
+    if len(orexpr) == 1:  # skip the OR wrapper for a single condition
+        return orexpr[0]
+
+    return sa.or_(*orexpr)
+
+
+def _interpolated_position(table: SaFromClause, nr: SaColumn) -> SaColumn:  # point for 'nr'
+    pos = sa.cast(nr - table.c.startnumber, sa.Float) / (table.c.endnumber - table.c.startnumber)
+    return sa.case(  # single-number line: its centroid; otherwise interpolate along linegeo
+            (table.c.endnumber == table.c.startnumber, table.c.linegeo.ST_Centroid()),
+            else_=table.c.linegeo.ST_LineInterpolatePoint(pos)).label('centroid')
+
+
+async def _get_placex_housenumbers(conn: SearchConnection,  # yield placex rows by place_id
+                                   place_ids: List[int],
+                                   details: SearchDetails) -> AsyncIterator[nres.SearchResult]:
+    t = conn.t.placex
+    sql = _select_placex(t).where(t.c.place_id.in_(place_ids))
+
+    sql = _add_geometry_columns(sql, t.c.geometry, details)
+
+    for row in await conn.execute(sql):
+        result = nres.create_from_placex_row(row, nres.SearchResult)
+        assert result
+        result.bbox = Bbox.from_wkb(row.bbox.data)  # 'bbox' column from _select_placex
+        yield result
+
+
+async def _get_osmline(conn: SearchConnection, place_ids: List[int],  # one result per nr per line
+                       numerals: List[int],
+                       details: SearchDetails) -> AsyncIterator[nres.SearchResult]:
+    t = conn.t.osmline
+    values = sa.values(sa.Column('nr', sa.Integer()), name='housenumber')\
+               .data([(n,) for n in numerals])  # VALUES list, one row per number
+    sql = sa.select(t.c.place_id, t.c.osm_id,
+                    t.c.parent_place_id, t.c.address,
+                    values.c.nr.label('housenumber'),
+                    _interpolated_position(t, values.c.nr),
+                    t.c.postcode, t.c.country_code)\
+            .where(t.c.place_id.in_(place_ids))\
+            .join(values, values.c.nr.between(t.c.startnumber, t.c.endnumber))  # NOTE(review): step not re-checked here
+
+    if details.geometry_output:  # geometry is built from the interpolated centroid
+        sub = sql.subquery()
+        sql = _add_geometry_columns(sa.select(sub), sub.c.centroid, details)
+
+    for row in await conn.execute(sql):
+        result = nres.create_from_osmline_row(row, nres.SearchResult)
+        assert result
+        yield result
+
+
+async def _get_tiger(conn: SearchConnection, place_ids: List[int],  # one result per nr per line
+                     numerals: List[int], osm_id: int,
+                     details: SearchDetails) -> AsyncIterator[nres.SearchResult]:
+    t = conn.t.tiger
+    values = sa.values(sa.Column('nr', sa.Integer()), name='housenumber')\
+               .data([(n,) for n in numerals])  # VALUES list, one row per number
+    sql = sa.select(t.c.place_id, t.c.parent_place_id,
+                    sa.literal('W').label('osm_type'),  # reported as a way with the given osm_id
+                    sa.literal(osm_id).label('osm_id'),
+                    values.c.nr.label('housenumber'),
+                    _interpolated_position(t, values.c.nr),
+                    t.c.postcode)\
+            .where(t.c.place_id.in_(place_ids))\
+            .join(values, values.c.nr.between(t.c.startnumber, t.c.endnumber))  # NOTE(review): step not re-checked here
+
+    if details.geometry_output:  # geometry is built from the interpolated centroid
+        sub = sql.subquery()
+        sql = _add_geometry_columns(sa.select(sub), sub.c.centroid, details)
+
+    for row in await conn.execute(sql):
+        result = nres.create_from_tiger_row(row, nres.SearchResult)
+        assert result
+        yield result
+
+
 class AbstractSearch(abc.ABC):
     """ Encapuslation of a single lookup in the database.
     """
@@ -42,7 +210,79 @@ class NearSearch(AbstractSearch):
                      details: SearchDetails) -> nres.SearchResults:
         """ Find results for the search in the database.
         """
-        return nres.SearchResults([])
+        results = nres.SearchResults()
+        base = await self.search.lookup(conn, details)
+
+        if not base:
+            return results
+
+        base.sort(key=lambda r: (r.accuracy, r.rank_search))
+        max_accuracy = base[0].accuracy + 0.5
+        base = nres.SearchResults(r for r in base if r.source_table == nres.SourceTable.PLACEX
+                                                     and r.accuracy <= max_accuracy
+                                                     and r.bbox and r.bbox.area < 20)
+
+        if base:
+            baseids = [b.place_id for b in base[:5] if b.place_id]
+
+            for category, penalty in self.categories:
+                await self.lookup_category(results, conn, baseids, category, penalty, details)
+                if len(results) >= details.max_results:
+                    break
+
+        return results
+
+
+    async def lookup_category(self, results: nres.SearchResults,
+                              conn: SearchConnection, ids: List[int],
+                              category: Tuple[str, str], penalty: float,
+                              details: SearchDetails) -> None:
+        """ Find places of the given category near the list of
+            place ids and append up to details.max_results of them to 'results'.
+        """
+        table = await conn.get_class_table(*category)  # None if no classtype table exists
+
+        t = conn.t.placex.alias('p')  # candidate places of the category
+        tgeom = conn.t.placex.alias('pgeom')  # the base places given by 'ids'
+
+        sql = _select_placex(t).where(tgeom.c.place_id.in_(ids))\
+                               .where(t.c.class_ == category[0])\
+                               .where(t.c.type == category[1])
+
+        if table is None:
+            # No classtype table available, do a simplified lookup in placex.
+            sql = sql.join(tgeom, t.c.geometry.ST_DWithin(tgeom.c.centroid, 0.01))\
+                     .order_by(tgeom.c.centroid.ST_Distance(t.c.centroid))
+        else:
+            # Use classtype table. We can afford to use a larger
+            # radius for the lookup.
+            sql = sql.join(table, t.c.place_id == table.c.place_id)\
+                     .join(tgeom,
+                           sa.case((sa.and_(tgeom.c.rank_address < 9,  # polygon, rank < 9: containment
+                                            tgeom.c.geometry.ST_GeometryType().in_(
+                                                ('ST_Polygon', 'ST_MultiPolygon'))),
+                                    tgeom.c.geometry.ST_Contains(table.c.centroid)),
+                                   else_ = tgeom.c.centroid.ST_DWithin(table.c.centroid, 0.05)))\
+                     .order_by(tgeom.c.centroid.ST_Distance(table.c.centroid))
+
+        if details.countries:
+            sql = sql.where(t.c.country_code.in_(details.countries))
+        if details.min_rank > 0:
+            sql = sql.where(t.c.rank_address >= details.min_rank)
+        if details.max_rank < 30:
+            sql = sql.where(t.c.rank_address <= details.max_rank)
+        if details.excluded:
+            sql = sql.where(t.c.place_id.not_in(details.excluded))
+        if details.layers is not None:
+            sql = sql.where(_filter_by_layer(t, details.layers))
+
+        for row in await conn.execute(sql.limit(details.max_results)):  # closest first
+            result = nres.create_from_placex_row(row, nres.SearchResult)
+            assert result
+            result.accuracy = self.penalty + penalty  # search penalty plus category penalty
+            result.bbox = Bbox.from_wkb(row.bbox.data)
+            results.append(result)
+
 
 
 class PoiSearch(AbstractSearch):
@@ -58,7 +298,65 @@ class PoiSearch(AbstractSearch):
                      details: SearchDetails) -> nres.SearchResults:
         """ Find results for the search in the database.
         """
-        return nres.SearchResults([])
+        t = conn.t.placex
+
+        rows: List[SaRow] = []
+
+        if details.near and details.near_radius is not None and details.near_radius < 0.2:
+            # simply search in placex table
+            sql = _select_placex(t) \
+                      .where(t.c.linked_place_id == None) \
+                      .where(t.c.geometry.ST_DWithin(details.near.sql_value(),
+                                                     details.near_radius)) \
+                      .order_by(t.c.centroid.ST_Distance(details.near.sql_value()))
+
+            if self.countries:
+                sql = sql.where(t.c.country_code.in_(self.countries.values))
+
+            if details.viewbox is not None and details.bounded_viewbox:
+                sql = sql.where(t.c.geometry.intersects(details.viewbox.sql_value()))
+
+            classtype = self.categories.values
+            if len(classtype) == 1:
+                sql = sql.where(t.c.class_ == classtype[0][0]) \
+                         .where(t.c.type == classtype[0][1])
+            else:
+                sql = sql.where(sa.or_(*(sa.and_(t.c.class_ == cls, t.c.type == typ)
+                                         for cls, typ in classtype)))
+
+            rows.extend(await conn.execute(sql.limit(details.max_results)))
+        else:
+            # use the class type tables
+            for category in self.categories.values:
+                table = await conn.get_class_table(*category)
+                if table is not None:
+                    sql = _select_placex(t)\
+                               .join(table, t.c.place_id == table.c.place_id)\
+                               .where(t.c.class_ == category[0])\
+                               .where(t.c.type == category[1])
+
+                    if details.viewbox is not None and details.bounded_viewbox:
+                        sql = sql.where(table.c.centroid.intersects(details.viewbox.sql_value()))
+
+                    if details.near:
+                        sql = sql.order_by(table.c.centroid.ST_Distance(details.near.sql_value()))\
+                                 .where(table.c.centroid.ST_DWithin(details.near.sql_value(),
+                                                                    details.near_radius or 0.5))
+
+                    if self.countries:
+                        sql = sql.where(t.c.country_code.in_(self.countries.values))
+
+                    rows.extend(await conn.execute(sql.limit(details.max_results)))
+
+        results = nres.SearchResults()
+        for row in rows:
+            result = nres.create_from_placex_row(row, nres.SearchResult)
+            assert result
+            result.accuracy = self.penalty + self.categories.get_penalty((row.class_, row.type))
+            result.bbox = Bbox.from_wkb(row.bbox.data)
+            results.append(result)
+
+        return results
 
 
 class CountrySearch(AbstractSearch):
@@ -73,7 +371,72 @@ class CountrySearch(AbstractSearch):
                      details: SearchDetails) -> nres.SearchResults:
         """ Find results for the search in the database.
         """
-        return nres.SearchResults([])
+        t = conn.t.placex
+
+        sql = _select_placex(t)\
+                .where(t.c.country_code.in_(self.countries.values))\
+                .where(t.c.rank_address == 4)
+
+        sql = _add_geometry_columns(sql, t.c.geometry, details)
+
+        if details.excluded:
+            sql = sql.where(t.c.place_id.not_in(details.excluded))
+
+        if details.viewbox is not None and details.bounded_viewbox:
+            sql = sql.where(t.c.geometry.intersects(details.viewbox.sql_value()))
+
+        if details.near is not None and details.near_radius is not None:
+            sql = sql.where(t.c.geometry.ST_DWithin(details.near.sql_value(),
+                                                    details.near_radius))
+
+        results = nres.SearchResults()
+        for row in await conn.execute(sql):
+            result = nres.create_from_placex_row(row, nres.SearchResult)
+            assert result
+            result.accuracy = self.penalty + self.countries.get_penalty(row.country_code, 5.0)
+            results.append(result)
+
+        return results or await self.lookup_in_country_table(conn, details)
+
+
+    async def lookup_in_country_table(self, conn: SearchConnection,
+                                      details: SearchDetails) -> nres.SearchResults:
+        """ Look up the countries in the fallback country_name/country_grid tables.
+        """
+        t = conn.t.country_name
+        tgrid = conn.t.country_grid
+
+        sql = sa.select(tgrid.c.country_code,  # one centroid per country from its grid cells
+                        tgrid.c.geometry.ST_Centroid().ST_Collect().ST_Centroid()
+                              .label('centroid'))\
+                .where(tgrid.c.country_code.in_(self.countries.values))\
+                .group_by(tgrid.c.country_code)
+
+        if details.viewbox is not None and details.bounded_viewbox:
+            sql = sql.where(tgrid.c.geometry.intersects(details.viewbox.sql_value()))
+        if details.near is not None and details.near_radius is not None:
+            sql = sql.where(tgrid.c.geometry.ST_DWithin(details.near.sql_value(),
+                                                        details.near_radius))
+
+        sub = sql.subquery('grid')
+
+        sql = sa.select(t.c.country_code,
+                        (t.c.name  # names merged with derived names, if any
+                         + sa.func.coalesce(t.c.derived_name,
+                                            sa.cast('', type_=conn.t.types.Composite))  # empty if NULL
+                        ).label('name'),
+                        sub.c.centroid)\
+                .join(sub, t.c.country_code == sub.c.country_code)  # only countries left in the grid
+
+        results = nres.SearchResults()
+        for row in await conn.execute(sql):
+            result = nres.create_from_country_row(row, nres.SearchResult)
+            assert result
+            result.accuracy = self.penalty + self.countries.get_penalty(row.country_code, 5.0)
+            results.append(result)
+
+        return results
+
 
 
 class PostcodeSearch(AbstractSearch):
@@ -91,7 +454,66 @@ class PostcodeSearch(AbstractSearch):
                      details: SearchDetails) -> nres.SearchResults:
         """ Find results for the search in the database.
         """
-        return nres.SearchResults([])
+        t = conn.t.postcode
+
+        sql = sa.select(t.c.place_id, t.c.parent_place_id,
+                        t.c.rank_search, t.c.rank_address,
+                        t.c.postcode, t.c.country_code,
+                        t.c.geometry.label('centroid'))\
+                .where(t.c.postcode.in_(self.postcodes.values))
+
+        sql = _add_geometry_columns(sql, t.c.geometry, details)
+
+        penalty: SaExpression = sa.literal(self.penalty)
+
+        if details.viewbox is not None:
+            if details.bounded_viewbox:
+                sql = sql.where(t.c.geometry.intersects(details.viewbox.sql_value()))
+            else:
+                penalty += sa.case((t.c.geometry.intersects(details.viewbox.sql_value()), 0.0),
+                                   (t.c.geometry.intersects(details.viewbox_x2.sql_value()), 1.0),
+                                   else_=2.0)
+
+        if details.near is not None:
+            if details.near_radius is not None:
+                sql = sql.where(t.c.geometry.ST_DWithin(details.near.sql_value(),
+                                                        details.near_radius))
+            sql = sql.order_by(t.c.geometry.ST_Distance(details.near.sql_value()))
+
+        if self.countries:
+            sql = sql.where(t.c.country_code.in_(self.countries.values))
+
+        if details.excluded:
+            sql = sql.where(t.c.place_id.not_in(details.excluded))
+
+        if self.lookups:
+            assert len(self.lookups) == 1
+            assert self.lookups[0].lookup_type == 'restrict'
+            tsearch = conn.t.search_name
+            sql = sql.where(tsearch.c.place_id == t.c.parent_place_id)\
+                     .where(sa.func.array_cat(tsearch.c.name_vector,
+                                              tsearch.c.nameaddress_vector,
+                                              type_=ARRAY(sa.Integer))
+                                    .contains(self.lookups[0].tokens))
+
+        for ranking in self.rankings:
+            penalty += ranking.sql_penalty(conn.t.search_name)
+        penalty += sa.case(*((t.c.postcode == v, p) for v, p in self.postcodes),
+                       else_=1.0)
+
+
+        sql = sql.add_columns(penalty.label('accuracy'))
+        sql = sql.order_by('accuracy')
+
+        results = nres.SearchResults()
+        for row in await conn.execute(sql.limit(details.max_results)):
+            result = nres.create_from_postcode_row(row, nres.SearchResult)
+            assert result
+            result.accuracy = row.accuracy
+            results.append(result)
+
+        return results
+
 
 
 class PlaceSearch(AbstractSearch):
@@ -112,4 +534,168 @@ class PlaceSearch(AbstractSearch):
                      details: SearchDetails) -> nres.SearchResults:
         """ Find results for the search in the database.
         """
-        return nres.SearchResults([])
+        t = conn.t.placex.alias('p')
+        tsearch = conn.t.search_name.alias('s')
+        limit = details.max_results
+
+        sql = sa.select(t.c.place_id, t.c.osm_type, t.c.osm_id, t.c.name,
+                        t.c.class_, t.c.type,
+                        t.c.address, t.c.extratags,
+                        t.c.housenumber, t.c.postcode, t.c.country_code,
+                        t.c.wikipedia,
+                        t.c.parent_place_id, t.c.rank_address, t.c.rank_search,
+                        t.c.centroid,
+                        t.c.geometry.ST_Expand(0).label('bbox'))\
+                .where(t.c.place_id == tsearch.c.place_id)
+
+
+        sql = _add_geometry_columns(sql, t.c.geometry, details)
+
+        penalty: SaExpression = sa.literal(self.penalty)
+        for ranking in self.rankings:
+            penalty += ranking.sql_penalty(tsearch)
+
+        for lookup in self.lookups:
+            sql = sql.where(lookup.sql_condition(tsearch))
+
+        if self.countries:
+            sql = sql.where(tsearch.c.country_code.in_(self.countries.values))
+
+        if self.postcodes:
+            tpc = conn.t.postcode
+            if self.expected_count > 1000:
+                # Many results expected. Restrict by postcode.
+                sql = sql.where(sa.select(tpc.c.postcode)
+                                  .where(tpc.c.postcode.in_(self.postcodes.values))
+                                  .where(tsearch.c.centroid.ST_DWithin(tpc.c.geometry, 0.12))
+                                  .exists())
+
+            # Less results, only have a preference for close postcodes
+            pc_near = sa.select(sa.func.min(tpc.c.geometry.ST_Distance(tsearch.c.centroid)))\
+                      .where(tpc.c.postcode.in_(self.postcodes.values))\
+                      .scalar_subquery()
+            penalty += sa.case((t.c.postcode.in_(self.postcodes.values), 0.0),
+                               else_=sa.func.coalesce(pc_near, 2.0))
+
+        if details.viewbox is not None:
+            if details.bounded_viewbox:
+                sql = sql.where(tsearch.c.centroid.intersects(details.viewbox.sql_value()))
+            else:
+                penalty += sa.case((t.c.geometry.intersects(details.viewbox.sql_value()), 0.0),
+                                   (t.c.geometry.intersects(details.viewbox_x2.sql_value()), 1.0),
+                                   else_=2.0)
+
+        if details.near is not None:
+            if details.near_radius is not None:
+                sql = sql.where(tsearch.c.centroid.ST_DWithin(details.near.sql_value(),
+                                                         details.near_radius))
+            sql = sql.add_columns(-tsearch.c.centroid.ST_Distance(details.near.sql_value())
+                                      .label('importance'))
+            sql = sql.order_by(sa.desc(sa.text('importance')))
+        else:
+            sql = sql.order_by(penalty - sa.case((tsearch.c.importance > 0, tsearch.c.importance),
+                                  else_=0.75001-(sa.cast(tsearch.c.search_rank, sa.Float())/40)))
+            sql = sql.add_columns(t.c.importance)
+
+
+        sql = sql.add_columns(penalty.label('accuracy'))\
+                 .order_by(sa.text('accuracy'))
+
+        if self.housenumbers:
+            hnr_regexp = f"\\m({'|'.join(self.housenumbers.values)})\\M"
+            sql = sql.where(tsearch.c.address_rank.between(16, 30))\
+                     .where(sa.or_(tsearch.c.address_rank < 30,
+                                  t.c.housenumber.regexp_match(hnr_regexp, flags='i')))
+
+            # Cross check for housenumbers, need to do that on a rather large
+            # set. Worst case there are 40.000 main streets in OSM.
+            inner = sql.limit(10000).subquery()
+
+            # Housenumbers from placex
+            thnr = conn.t.placex.alias('hnr')
+            pid_list = array_agg(thnr.c.place_id) # type: ignore[no-untyped-call]
+            place_sql = sa.select(pid_list)\
+                          .where(thnr.c.parent_place_id == inner.c.place_id)\
+                          .where(thnr.c.housenumber.regexp_match(hnr_regexp, flags='i'))\
+                          .where(thnr.c.linked_place_id == None)\
+                          .where(thnr.c.indexed_status == 0)
+
+            if details.excluded:
+                place_sql = place_sql.where(thnr.c.place_id.not_in(details.excluded))
+            if self.qualifiers:
+                place_sql = place_sql.where(self.qualifiers.sql_restrict(thnr))
+
+            numerals = [int(n) for n in self.housenumbers.values if n.isdigit()]
+            interpol_sql: SaExpression
+            tiger_sql: SaExpression
+            if numerals and \
+               (not self.qualifiers or ('place', 'house') in self.qualifiers.values):
+                # Housenumbers from interpolations
+                interpol_sql = _make_interpolation_subquery(conn.t.osmline, inner,
+                                                            numerals, details)
+                # Housenumbers from Tiger
+                tiger_sql = sa.case((inner.c.country_code == 'us',
+                                     _make_interpolation_subquery(conn.t.tiger, inner,
+                                                                  numerals, details)
+                                    ), else_=None)
+            else:
+                interpol_sql = sa.literal(None)
+                tiger_sql = sa.literal(None)
+
+            unsort = sa.select(inner, place_sql.scalar_subquery().label('placex_hnr'),
+                               interpol_sql.label('interpol_hnr'),
+                               tiger_sql.label('tiger_hnr')).subquery('unsort')
+            sql = sa.select(unsort)\
+                    .order_by(sa.case((unsort.c.placex_hnr != None, 1),
+                                      (unsort.c.interpol_hnr != None, 2),
+                                      (unsort.c.tiger_hnr != None, 3),
+                                      else_=4),
+                              unsort.c.accuracy)
+        else:
+            sql = sql.where(t.c.linked_place_id == None)\
+                     .where(t.c.indexed_status == 0)
+            if self.qualifiers:
+                sql = sql.where(self.qualifiers.sql_restrict(t))
+            if details.excluded:
+                sql = sql.where(tsearch.c.place_id.not_in(details.excluded))
+            if details.min_rank > 0:
+                sql = sql.where(sa.or_(tsearch.c.address_rank >= details.min_rank,
+                                       tsearch.c.search_rank >= details.min_rank))
+            if details.max_rank < 30:
+                sql = sql.where(sa.or_(tsearch.c.address_rank <= details.max_rank,
+                                       tsearch.c.search_rank <= details.max_rank))
+            if details.layers is not None:
+                sql = sql.where(_filter_by_layer(t, details.layers))
+
+
+        results = nres.SearchResults()
+        for row in await conn.execute(sql.limit(limit)):
+            result = nres.create_from_placex_row(row, nres.SearchResult)
+            assert result
+            result.bbox = Bbox.from_wkb(row.bbox.data)
+            result.accuracy = row.accuracy
+            if not details.excluded or not result.place_id in details.excluded:
+                results.append(result)
+
+            if self.housenumbers and row.rank_address < 30:
+                if row.placex_hnr:
+                    subs = _get_placex_housenumbers(conn, row.placex_hnr, details)
+                elif row.interpol_hnr:
+                    subs = _get_osmline(conn, row.interpol_hnr, numerals, details)
+                elif row.tiger_hnr:
+                    subs = _get_tiger(conn, row.tiger_hnr, numerals, row.osm_id, details)
+                else:
+                    subs = None
+
+                if subs is not None:
+                    async for sub in subs:
+                        assert sub.housenumber
+                        sub.accuracy = result.accuracy
+                        if not any(nr in self.housenumbers.values
+                                   for nr in sub.housenumber.split(';')):
+                            sub.accuracy += 0.6
+                        results.append(sub)
+
+                result.accuracy += 1.0 # penalty for missing housenumber
+
+        return results
index ff7457ec01ce3bd52df97176966d2c8ce3ae4306..9042e707db920ac46b2139156b3e53fcb5b6006e 100644 (file)
@@ -15,6 +15,9 @@ import enum
 import math
 from struct import unpack
 
+from geoalchemy2 import WKTElement
+import geoalchemy2.functions
+
 from nominatim.errors import UsageError
 
 # pylint: disable=no-member,too-many-boolean-expressions,too-many-instance-attributes
@@ -119,6 +122,12 @@ class Point(NamedTuple):
         return Point(x, y)
 
 
+    def sql_value(self) -> WKTElement:
+        """ Create an SQL expression for the point.
+        """
+        return WKTElement(f'POINT({self.x} {self.y})', srid=4326)
+
+
 
 AnyPoint = Union[Point, Tuple[float, float]]
 
@@ -163,12 +172,26 @@ class Bbox:
         return self.coords[2]
 
 
+    @property
+    def area(self) -> float:
+        """ Return the area of the box in WGS84.
+        """
+        return (self.coords[2] - self.coords[0]) * (self.coords[3] - self.coords[1])
+
+
+    def sql_value(self) -> Any:
+        """ Create an SQL expression for the box.
+        """
+        return geoalchemy2.functions.ST_MakeEnvelope(*self.coords, 4326)
+
+
     def contains(self, pt: Point) -> bool:
         """ Check if the point is inside or on the boundary of the box.

             The point is only accessed by index as (x, y), so any
             two-element sequence works as well as a Point.
         """
         return self.coords[0] <= pt[0] and self.coords[1] <= pt[1]\
                and self.coords[2] >= pt[0] and self.coords[3] >= pt[1]
 
+
     @staticmethod
     def from_wkb(wkb: Optional[bytes]) -> 'Optional[Bbox]':
         """ Create a Bbox from a bounding box polygon as returned by
@@ -418,7 +441,7 @@ class SearchDetails(LookupDetails):
         if self.viewbox is not None:
             xext = (self.viewbox.maxlon - self.viewbox.minlon)/2
             yext = (self.viewbox.maxlat - self.viewbox.minlat)/2
-            self.viewbox_x2 = Bbox(self.viewbox.minlon - xext, self.viewbox.maxlon - yext,
+            self.viewbox_x2 = Bbox(self.viewbox.minlon - xext, self.viewbox.minlat - yext,
                                    self.viewbox.maxlon + xext, self.viewbox.maxlat + yext)
 
 
index bc4c5534777e537a3d0ffe1ab4a0c6d2b16f2456..d988fe04a3e3d56fb4af801fa4f3cb076c685a4e 100644 (file)
@@ -63,8 +63,10 @@ else:
     TypeAlias = str
 
 SaSelect: TypeAlias = 'sa.Select[Any]'
+SaScalarSelect: TypeAlias = 'sa.ScalarSelect[Any]' # e.g. the result of Select.scalar_subquery()
 SaRow: TypeAlias = 'sa.Row[Any]'
 SaColumn: TypeAlias = 'sa.ColumnElement[Any]'
+SaExpression: TypeAlias = 'sa.ColumnElement[bool]' # NOTE(review): also used for float penalty expressions; confirm [bool]
 SaLabel: TypeAlias = 'sa.Label[Any]'
 SaFromClause: TypeAlias = 'sa.FromClause'
 SaSelectable: TypeAlias = 'sa.Selectable'
index d8a6dfa0ae93dade6097bbcb69482151008de1c5..cfe14e1eddc85797cabaf08672e2ccdf1d63c1c5 100644 (file)
@@ -12,6 +12,8 @@ import pytest
 import time
 import datetime as dt
 
+import sqlalchemy as sa
+
 import nominatim.api as napi
 from nominatim.db.sql_preprocessor import SQLPreprocessor
 import nominatim.api.logging as loglib
@@ -129,6 +131,34 @@ class APITester:
                        'geometry': 'SRID=4326;' + geometry})
 
 
+    def add_country_name(self, country_code, names, partition=0):
+        self.add_data('country_name',
+                      {'country_code': country_code,
+                       'name': names,
+                       'partition': partition})
+
+
+    def add_search_name(self, place_id, **kw):
+        centroid = kw.get('centroid', (23.0, 34.0))
+        self.add_data('search_name',
+                      {'place_id': place_id,
+                       'importance': kw.get('importance', 0.00001),
+                       'search_rank': kw.get('search_rank', 30),
+                       'address_rank': kw.get('address_rank', 30),
+                       'name_vector': kw.get('names', []),
+                       'nameaddress_vector': kw.get('address', []),
+                       'country_code': kw.get('country_code', 'xx'),
+                       'centroid': 'SRID=4326;POINT(%f %f)' % centroid})
+
+
+    def add_class_type_table(self, cls, typ):
+        """ Create the table place_classtype_<cls>_<typ> from the placex
+            rows of that class and type, copying place_id and centroid.
+
+            Note: cls and typ are interpolated directly into the SQL text,
+            so this helper must only be called with trusted test values.
+        """
+        self.async_to_sync(
+            self.exec_async(sa.text(f"""CREATE TABLE place_classtype_{cls}_{typ}
+                                         AS (SELECT place_id, centroid FROM placex
+                                             WHERE class = '{cls}' AND type = '{typ}')
+                                     """)))
+
+
     async def exec_async(self, sql, *args, **kwargs):
         async with self.api._async_api.begin() as conn:
             return await conn.execute(sql, *args, **kwargs)
diff --git a/test/python/api/search/test_search_country.py b/test/python/api/search/test_search_country.py
new file mode 100644 (file)
index 0000000..bb0abc3
--- /dev/null
@@ -0,0 +1,61 @@
+# SPDX-License-Identifier: GPL-3.0-or-later
+#
+# This file is part of Nominatim. (https://nominatim.org)
+#
+# Copyright (C) 2023 by the Nominatim developer community.
+# For a full list of authors see the git log.
+"""
+Tests for running the country searcher.
+"""
+import pytest
+
+import nominatim.api as napi
+from nominatim.api.types import SearchDetails
+from nominatim.api.search.db_searches import CountrySearch
+from nominatim.api.search.db_search_fields import WeightedStrings
+
+
+def run_search(apiobj, global_penalty, ccodes,
+               country_penalties=None, details=SearchDetails()):
+    if country_penalties is None:
+        country_penalties = [0.0] * len(ccodes)
+
+    class MySearchData:
+        penalty = global_penalty
+        countries = WeightedStrings(ccodes, country_penalties)
+
+    search = CountrySearch(MySearchData())
+
+    async def run():
+        async with apiobj.api._async_api.begin() as conn:
+            return await search.lookup(conn, details)
+
+    return apiobj.async_to_sync(run())
+
+
+def test_find_from_placex(apiobj):
+    apiobj.add_placex(place_id=55, class_='boundary', type='administrative',
+                      rank_search=4, rank_address=4,
+                      name={'name': 'Lolaland'},
+                      country_code='yw',
+                      centroid=(10, 10),
+                      geometry='POLYGON((9.5 9.5, 9.5 10.5, 10.5 10.5, 10.5 9.5, 9.5 9.5))')
+
+    results = run_search(apiobj, 0.5, ['de', 'yw'], [0.0, 0.3])
+
+    assert len(results) == 1
+    assert results[0].place_id == 55
+    assert results[0].accuracy == 0.8
+
+def test_find_from_fallback_countries(apiobj):
+    """ A country with no placex entry is still found, using only the
+        separately added country geometry and country_name data.
+    """
+    apiobj.add_country('ro', 'POLYGON((0 0, 0 1, 1 1, 1 0, 0 0))')
+    apiobj.add_country_name('ro', {'name': 'România'})
+
+    results = run_search(apiobj, 0.0, ['ro'])
+
+    assert len(results) == 1
+    assert results[0].names == {'name': 'România'}
+
+
+def test_find_none(apiobj):
+    assert len(run_search(apiobj, 0.0, ['xx'])) == 0
diff --git a/test/python/api/search/test_search_near.py b/test/python/api/search/test_search_near.py
new file mode 100644 (file)
index 0000000..cfbdadb
--- /dev/null
@@ -0,0 +1,102 @@
+# SPDX-License-Identifier: GPL-3.0-or-later
+#
+# This file is part of Nominatim. (https://nominatim.org)
+#
+# Copyright (C) 2023 by the Nominatim developer community.
+# For a full list of authors see the git log.
+"""
+Tests for running the near searcher.
+"""
+import pytest
+
+import nominatim.api as napi
+from nominatim.api.types import SearchDetails
+from nominatim.api.search.db_searches import NearSearch, PlaceSearch
+from nominatim.api.search.db_search_fields import WeightedStrings, WeightedCategories,\
+                                                  FieldLookup, FieldRanking, RankedTokens
+
+
+def run_search(apiobj, global_penalty, cat, cat_penalty=None,
+               details=SearchDetails()):
+
+    class PlaceSearchData:
+        penalty = 0.0
+        postcodes = WeightedStrings([], [])
+        countries = WeightedStrings([], [])
+        housenumbers = WeightedStrings([], [])
+        qualifiers = WeightedStrings([], [])
+        lookups = [FieldLookup('name_vector', [56], 'lookup_all')]
+        rankings = []
+
+    place_search = PlaceSearch(0.0, PlaceSearchData(), 2)
+
+    if cat_penalty is None:
+        cat_penalty = [0.0] * len(cat)
+
+    near_search = NearSearch(0.1, WeightedCategories(cat, cat_penalty), place_search)
+
+    async def run():
+        async with apiobj.api._async_api.begin() as conn:
+            return await near_search.lookup(conn, details)
+
+    results = apiobj.async_to_sync(run())
+    results.sort(key=lambda r: r.accuracy)
+
+    return results
+
+
+def test_no_results_inner_query(apiobj):
+    assert not run_search(apiobj, 0.4, [('this', 'that')])
+
+
+class TestNearSearch:
+    """ Near searches around the two places (100 in 'us', 101 in 'mx')
+        that the inner place search finds through name token 56.
+    """
+
+    @pytest.fixture(autouse=True)
+    def fill_database(self, apiobj):
+        apiobj.add_placex(place_id=100, country_code='us',
+                          centroid=(5.6, 4.3))
+        apiobj.add_search_name(100, names=[56], country_code='us',
+                               centroid=(5.6, 4.3))
+        apiobj.add_placex(place_id=101, country_code='mx',
+                          centroid=(-10.3, 56.9))
+        apiobj.add_search_name(101, names=[56], country_code='mx',
+                               centroid=(-10.3, 56.9))
+
+
+    def test_near_in_placex(self, apiobj):
+        # Two POIs right next to place 100; only the requested category may match.
+        apiobj.add_placex(place_id=22, class_='amenity', type='bank',
+                          centroid=(5.6001, 4.2994))
+        apiobj.add_placex(place_id=23, class_='amenity', type='bench',
+                          centroid=(5.6001, 4.2994))
+
+        results = run_search(apiobj, 0.1, [('amenity', 'bank')])
+
+        assert [r.place_id for r in results] == [22]
+
+
+    def test_multiple_types_near_in_placex(self, apiobj):
+        # NOTE(review): both categories have the same penalty, so the expected
+        # order presumably comes from the differing importance values.
+        apiobj.add_placex(place_id=22, class_='amenity', type='bank',
+                          importance=0.002,
+                          centroid=(5.6001, 4.2994))
+        apiobj.add_placex(place_id=23, class_='amenity', type='bench',
+                          importance=0.001,
+                          centroid=(5.6001, 4.2994))
+
+        results = run_search(apiobj, 0.1, [('amenity', 'bank'),
+                                           ('amenity', 'bench')])
+
+        assert [r.place_id for r in results] == [22, 23]
+
+
+    def test_near_in_classtype(self, apiobj):
+        apiobj.add_placex(place_id=22, class_='amenity', type='bank',
+                          centroid=(5.6, 4.34))
+        apiobj.add_placex(place_id=23, class_='amenity', type='bench',
+                          centroid=(5.6, 4.34))
+        # Create the per-category place_classtype_* tables, presumably so that
+        # the search goes through them instead of placex.
+        apiobj.add_class_type_table('amenity', 'bank')
+        apiobj.add_class_type_table('amenity', 'bench')
+
+        results = run_search(apiobj, 0.1, [('amenity', 'bank')])
+
+        assert [r.place_id for r in results] == [22]
+
diff --git a/test/python/api/search/test_search_places.py b/test/python/api/search/test_search_places.py
new file mode 100644 (file)
index 0000000..df369b8
--- /dev/null
@@ -0,0 +1,385 @@
+# SPDX-License-Identifier: GPL-3.0-or-later
+#
+# This file is part of Nominatim. (https://nominatim.org)
+#
+# Copyright (C) 2023 by the Nominatim developer community.
+# For a full list of authors see the git log.
+"""
+Tests for running the generic place searcher.
+"""
+import pytest
+
+import nominatim.api as napi
+from nominatim.api.types import SearchDetails
+from nominatim.api.search.db_searches import PlaceSearch
+from nominatim.api.search.db_search_fields import WeightedStrings, WeightedCategories,\
+                                                  FieldLookup, FieldRanking, RankedTokens
+
+def run_search(apiobj, global_penalty, lookup, ranking, count=2,
+               hnrs=[], pcs=[], ccodes=[], quals=[],
+               details=SearchDetails()):
+    class MySearchData:
+        penalty = global_penalty
+        postcodes = WeightedStrings(pcs, [0.0] * len(pcs))
+        countries = WeightedStrings(ccodes, [0.0] * len(ccodes))
+        housenumbers = WeightedStrings(hnrs, [0.0] * len(hnrs))
+        qualifiers = WeightedCategories(quals, [0.0] * len(quals))
+        lookups = lookup
+        rankings = ranking
+
+    search = PlaceSearch(0.0, MySearchData(), count)
+
+    async def run():
+        async with apiobj.api._async_api.begin() as conn:
+            return await search.lookup(conn, details)
+
+    results = apiobj.async_to_sync(run())
+    results.sort(key=lambda r: r.accuracy)
+
+    return results
+
+
+class TestNameOnlySearches:
+    """ Name-token searches against two places that share the tokens 1 and 2.
+        Place 100 ('us') also has tokens 10/11; place 101 ('mx') also has
+        tokens 20/21.
+    """
+
+    @pytest.fixture(autouse=True)
+    def fill_database(self, apiobj):
+        apiobj.add_placex(place_id=100, country_code='us',
+                          centroid=(5.6, 4.3))
+        apiobj.add_search_name(100, names=[1,2,10,11], country_code='us',
+                               centroid=(5.6, 4.3))
+        apiobj.add_placex(place_id=101, country_code='mx',
+                          centroid=(-10.3, 56.9))
+        apiobj.add_search_name(101, names=[1,2,20,21], country_code='mx',
+                               centroid=(-10.3, 56.9))
+
+
+    # Both places match the lookup; the ranking token decides the order.
+    @pytest.mark.parametrize('lookup_type', ['lookup_all', 'restrict'])
+    @pytest.mark.parametrize('rank,res', [([10], [100, 101]),
+                                          ([20], [101, 100])])
+    def test_lookup_all_match(self, apiobj, lookup_type, rank, res):
+        lookup = FieldLookup('name_vector', [1,2], lookup_type)
+        ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, rank)])
+
+        results = run_search(apiobj, 0.1, [lookup], [ranking])
+
+        assert [r.place_id for r in results] == res
+
+
+    @pytest.mark.parametrize('lookup_type', ['lookup_all', 'restrict'])
+    def test_lookup_all_partial_match(self, apiobj, lookup_type):
+        # Only place 101 has both tokens 1 and 20.
+        lookup = FieldLookup('name_vector', [1,20], lookup_type)
+        ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, [21])])
+
+        results = run_search(apiobj, 0.1, [lookup], [ranking])
+
+        assert len(results) == 1
+        assert results[0].place_id == 101
+
+    @pytest.mark.parametrize('rank,res', [([10], [100, 101]),
+                                          ([20], [101, 100])])
+    def test_lookup_any_match(self, apiobj, rank, res):
+        lookup = FieldLookup('name_vector', [11,21], 'lookup_any')
+        ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, rank)])
+
+        results = run_search(apiobj, 0.1, [lookup], [ranking])
+
+        assert [r.place_id for r in results] == res
+
+
+    def test_lookup_any_partial_match(self, apiobj):
+        # NOTE(review): the name says "any" but the lookup type is 'lookup_all'.
+        # With a single token both behave the same; confirm the intent.
+        lookup = FieldLookup('name_vector', [20], 'lookup_all')
+        ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, [21])])
+
+        results = run_search(apiobj, 0.1, [lookup], [ranking])
+
+        assert len(results) == 1
+        assert results[0].place_id == 101
+
+
+    @pytest.mark.parametrize('cc,res', [('us', 100), ('mx', 101)])
+    def test_lookup_restrict_country(self, apiobj, cc, res):
+        lookup = FieldLookup('name_vector', [1,2], 'lookup_all')
+        ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, [10])])
+
+        results = run_search(apiobj, 0.1, [lookup], [ranking], ccodes=[cc])
+
+        assert [r.place_id for r in results] == [res]
+
+
+    def test_lookup_restrict_placeid(self, apiobj):
+        lookup = FieldLookup('name_vector', [1,2], 'lookup_all')
+        ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, [10])])
+
+        results = run_search(apiobj, 0.1, [lookup], [ranking],
+                             details=SearchDetails(excluded=[101]))
+
+        assert [r.place_id for r in results] == [100]
+
+
+    @pytest.mark.parametrize('geom', [napi.GeometryFormat.GEOJSON,
+                                      napi.GeometryFormat.KML,
+                                      napi.GeometryFormat.SVG,
+                                      napi.GeometryFormat.TEXT])
+    def test_return_geometries(self, apiobj, geom):
+        lookup = FieldLookup('name_vector', [20], 'lookup_all')
+        ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, [21])])
+
+        results = run_search(apiobj, 0.1, [lookup], [ranking],
+                             details=SearchDetails(geometry_output=geom))
+
+        assert geom.name.lower() in results[0].geometry
+
+
+    # Without a viewbox, ranking token 21 puts 101 first. A viewbox around
+    # place 100 must flip the order. The second box does not contain 100
+    # itself; presumably it is caught by the doubled viewbox (viewbox_x2).
+    @pytest.mark.parametrize('viewbox', ['5.0,4.0,6.0,5.0', '5.7,4.0,6.0,5.0'])
+    def test_prefer_viewbox(self, apiobj, viewbox):
+        lookup = FieldLookup('name_vector', [1, 2], 'lookup_all')
+        ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, [21])])
+
+        results = run_search(apiobj, 0.1, [lookup], [ranking])
+        assert [r.place_id for r in results] == [101, 100]
+
+        results = run_search(apiobj, 0.1, [lookup], [ranking],
+                             details=SearchDetails.from_kwargs({'viewbox': viewbox}))
+        assert [r.place_id for r in results] == [100, 101]
+
+
+    def test_force_viewbox(self, apiobj):
+        lookup = FieldLookup('name_vector', [1, 2], 'lookup_all')
+
+        details=SearchDetails.from_kwargs({'viewbox': '5.0,4.0,6.0,5.0',
+                                           'bounded_viewbox': True})
+
+        results = run_search(apiobj, 0.1, [lookup], [], details=details)
+        assert [r.place_id for r in results] == [100]
+
+
+    def test_prefer_near(self, apiobj):
+        lookup = FieldLookup('name_vector', [1, 2], 'lookup_all')
+        ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, [21])])
+
+        results = run_search(apiobj, 0.1, [lookup], [ranking])
+        assert [r.place_id for r in results] == [101, 100]
+
+        results = run_search(apiobj, 0.1, [lookup], [ranking],
+                             details=SearchDetails.from_kwargs({'near': '5.6,4.3'}))
+        # With 'near', the negated distance is reported as importance
+        # (closer means higher), so sort on that rather than on accuracy.
+        results.sort(key=lambda r: -r.importance)
+        assert [r.place_id for r in results] == [100, 101]
+
+
+    def test_force_near(self, apiobj):
+        lookup = FieldLookup('name_vector', [1, 2], 'lookup_all')
+
+        # The radius is in coordinate units; only place 100 at the point itself lies within it.
+        details=SearchDetails.from_kwargs({'near': '5.6,4.3',
+                                           'near_radius': 0.11})
+
+        results = run_search(apiobj, 0.1, [lookup], [], details=details)
+
+        assert [r.place_id for r in results] == [100]
+
+
+class TestStreetWithHousenumber:
+    """ Street plus housenumber searches against two streets matching the
+        same name tokens:
+          * 1000 ('es') with houses '20 a' (1) and '21;22' (2)
+          * 2000 ('pt') with houses 20 (91), 22 (92) and 24 (93)
+    """
+
+    @pytest.fixture(autouse=True)
+    def fill_database(self, apiobj):
+        apiobj.add_placex(place_id=1, class_='place', type='house',
+                          parent_place_id=1000,
+                          housenumber='20 a', country_code='es')
+        apiobj.add_placex(place_id=2, class_='place', type='house',
+                          parent_place_id=1000,
+                          housenumber='21;22', country_code='es')
+        apiobj.add_placex(place_id=1000, class_='highway', type='residential',
+                          rank_search=26, rank_address=26,
+                          country_code='es')
+        apiobj.add_search_name(1000, names=[1,2,10,11],
+                               search_rank=26, address_rank=26,
+                               country_code='es')
+        apiobj.add_placex(place_id=91, class_='place', type='house',
+                          parent_place_id=2000,
+                          housenumber='20', country_code='pt')
+        apiobj.add_placex(place_id=92, class_='place', type='house',
+                          parent_place_id=2000,
+                          housenumber='22', country_code='pt')
+        apiobj.add_placex(place_id=93, class_='place', type='house',
+                          parent_place_id=2000,
+                          housenumber='24', country_code='pt')
+        apiobj.add_placex(place_id=2000, class_='highway', type='residential',
+                          rank_search=26, rank_address=26,
+                          country_code='pt')
+        apiobj.add_search_name(2000, names=[1,2,20,21],
+                               search_rank=26, address_rank=26,
+                               country_code='pt')
+
+
+    # The housenumber regexp matches whole words. So '20' also finds '20 a',
+    # ranked after the exact match 91, and '21'/'22' find the '21;22' house.
+    # The streets themselves are always returned after the houses.
+    @pytest.mark.parametrize('hnr,res', [('20', [91, 1]), ('20 a', [1]),
+                                         ('21', [2]), ('22', [2, 92]),
+                                         ('24', [93]), ('25', [])])
+    def test_lookup_by_single_housenumber(self, apiobj, hnr, res):
+        lookup = FieldLookup('name_vector', [1,2], 'lookup_all')
+        ranking = FieldRanking('name_vector', 0.3, [RankedTokens(0.0, [10])])
+
+        results = run_search(apiobj, 0.1, [lookup], [ranking], hnrs=[hnr])
+
+        assert [r.place_id for r in results] == res + [1000, 2000]
+
+
+    @pytest.mark.parametrize('cc,res', [('es', [2, 1000]), ('pt', [92, 2000])])
+    def test_lookup_with_country_restriction(self, apiobj, cc, res):
+        lookup = FieldLookup('name_vector', [1,2], 'lookup_all')
+        ranking = FieldRanking('name_vector', 0.3, [RankedTokens(0.0, [10])])
+
+        results = run_search(apiobj, 0.1, [lookup], [ranking], hnrs=['22'],
+                             ccodes=[cc])
+
+        assert [r.place_id for r in results] == res
+
+
+    def test_lookup_exclude_housenumber_placeid(self, apiobj):
+        lookup = FieldLookup('name_vector', [1,2], 'lookup_all')
+        ranking = FieldRanking('name_vector', 0.3, [RankedTokens(0.0, [10])])
+
+        results = run_search(apiobj, 0.1, [lookup], [ranking], hnrs=['22'],
+                             details=SearchDetails(excluded=[92]))
+
+        assert [r.place_id for r in results] == [2, 1000, 2000]
+
+
+    def test_lookup_exclude_street_placeid(self, apiobj):
+        # Excluding the street drops only the street itself; its houses remain.
+        lookup = FieldLookup('name_vector', [1,2], 'lookup_all')
+        ranking = FieldRanking('name_vector', 0.3, [RankedTokens(0.0, [10])])
+
+        results = run_search(apiobj, 0.1, [lookup], [ranking], hnrs=['22'],
+                             details=SearchDetails(excluded=[1000]))
+
+        assert [r.place_id for r in results] == [2, 92, 2000]
+
+
+    @pytest.mark.parametrize('geom', [napi.GeometryFormat.GEOJSON,
+                                      napi.GeometryFormat.KML,
+                                      napi.GeometryFormat.SVG,
+                                      napi.GeometryFormat.TEXT])
+    def test_return_geometries(self, apiobj, geom):
+        lookup = FieldLookup('name_vector', [1, 2], 'lookup_all')
+
+        results = run_search(apiobj, 0.1, [lookup], [], hnrs=['20', '21', '22'],
+                             details=SearchDetails(geometry_output=geom))
+
+        assert results
+        assert all(geom.name.lower() in r.geometry for r in results)
+
+
+class TestInterpolations:
+    """ Street 990 with a real house '23' (991) and an OSM interpolation
+        line (992) covering 21 to 29 in steps of 2.
+    """
+
+    @pytest.fixture(autouse=True)
+    def fill_database(self, apiobj):
+        apiobj.add_placex(place_id=990, class_='highway', type='service',
+                          rank_search=27, rank_address=27,
+                          centroid=(10.0, 10.0),
+                          geometry='LINESTRING(9.995 10, 10.005 10)')
+        apiobj.add_search_name(990, names=[111],
+                               search_rank=27, address_rank=27)
+        apiobj.add_placex(place_id=991, class_='place', type='house',
+                          parent_place_id=990,
+                          rank_search=30, rank_address=30,
+                          housenumber='23',
+                          centroid=(10.0, 10.00002))
+        apiobj.add_osmline(place_id=992,
+                           parent_place_id=990,
+                           startnumber=21, endnumber=29, step=2,
+                           centroid=(10.0, 10.00001),
+                           geometry='LINESTRING(9.995 10.00001, 10.005 10.00001)')
+
+
+    # 21 lies on the interpolation and 22 does not (step 2). For 23 the real
+    # house is used, because placex housenumbers take precedence.
+    @pytest.mark.parametrize('hnr,res', [('21', [992]), ('22', []), ('23', [991])])
+    def test_lookup_housenumber(self, apiobj, hnr, res):
+        lookup = FieldLookup('name_vector', [111], 'lookup_all')
+
+        results = run_search(apiobj, 0.1, [lookup], [], hnrs=[hnr])
+
+        assert [r.place_id for r in results] == res + [990]
+
+
+class TestTiger:
+
+    @pytest.fixture(autouse=True)
+    def fill_database(self, apiobj):
+        apiobj.add_placex(place_id=990, class_='highway', type='service',
+                          rank_search=27, rank_address=27,
+                          country_code='us',
+                          centroid=(10.0, 10.0),
+                          geometry='LINESTRING(9.995 10, 10.005 10)')
+        apiobj.add_search_name(990, names=[111], country_code='us',
+                               search_rank=27, address_rank=27)
+        apiobj.add_placex(place_id=991, class_='place', type='house',
+                          parent_place_id=990,
+                          rank_search=30, rank_address=30,
+                          housenumber='23',
+                          country_code='us',
+                          centroid=(10.0, 10.00002))
+        apiobj.add_tiger(place_id=992,
+                         parent_place_id=990,
+                         startnumber=21, endnumber=29, step=2,
+                         centroid=(10.0, 10.00001),
+                         geometry='LINESTRING(9.995 10.00001, 10.005 10.00001)')
+
+
+    @pytest.mark.parametrize('hnr,res', [('21', [992]), ('22', []), ('23', [991])])
+    def test_lookup_housenumber(self, apiobj, hnr, res):
+        lookup = FieldLookup('name_vector', [111], 'lookup_all')
+
+        results = run_search(apiobj, 0.1, [lookup], [], hnrs=[hnr])
+
+        assert [r.place_id for r in results] == res + [990]
+
+
+class TestLayersRank30:
+
+    @pytest.fixture(autouse=True)
+    def fill_database(self, apiobj):
+        apiobj.add_placex(place_id=223, class_='place', type='house',
+                          housenumber='1',
+                          rank_address=30,
+                          rank_search=30)
+        apiobj.add_search_name(223, names=[34],
+                               importance=0.0009,
+                               address_rank=30, search_rank=30)
+        apiobj.add_placex(place_id=224, class_='amenity', type='toilet',
+                          rank_address=30,
+                          rank_search=30)
+        apiobj.add_search_name(224, names=[34],
+                               importance=0.0008,
+                               address_rank=30, search_rank=30)
+        apiobj.add_placex(place_id=225, class_='man_made', type='tower',
+                          rank_address=0,
+                          rank_search=30)
+        apiobj.add_search_name(225, names=[34],
+                               importance=0.0007,
+                               address_rank=0, search_rank=30)
+        apiobj.add_placex(place_id=226, class_='railway', type='station',
+                          rank_address=0,
+                          rank_search=30)
+        apiobj.add_search_name(226, names=[34],
+                               importance=0.0006,
+                               address_rank=0, search_rank=30)
+        apiobj.add_placex(place_id=227, class_='natural', type='cave',
+                          rank_address=0,
+                          rank_search=30)
+        apiobj.add_search_name(227, names=[34],
+                               importance=0.0005,
+                               address_rank=0, search_rank=30)
+
+
+    @pytest.mark.parametrize('layer,res', [(napi.DataLayer.ADDRESS, [223]),
+                                           (napi.DataLayer.POI, [224]),
+                                           (napi.DataLayer.ADDRESS | napi.DataLayer.POI, [223, 224]),
+                                           (napi.DataLayer.MANMADE, [225]),
+                                           (napi.DataLayer.RAILWAY, [226]),
+                                           (napi.DataLayer.NATURAL, [227]),
+                                           (napi.DataLayer.MANMADE | napi.DataLayer.NATURAL, [225, 227]),
+                                           (napi.DataLayer.MANMADE | napi.DataLayer.RAILWAY, [225, 226])])
+    def test_layers_rank30(self, apiobj, layer, res):
+        lookup = FieldLookup('name_vector', [34], 'lookup_any')
+
+        results = run_search(apiobj, 0.1, [lookup], [],
+                             details=SearchDetails(layers=layer))
+
+        assert [r.place_id for r in results] == res
diff --git a/test/python/api/search/test_search_poi.py b/test/python/api/search/test_search_poi.py
new file mode 100644 (file)
index 0000000..b80c075
--- /dev/null
@@ -0,0 +1,108 @@
+# SPDX-License-Identifier: GPL-3.0-or-later
+#
+# This file is part of Nominatim. (https://nominatim.org)
+#
+# Copyright (C) 2023 by the Nominatim developer community.
+# For a full list of authors see the git log.
+"""
+Tests for running the POI searcher.
+"""
+import pytest
+
+import nominatim.api as napi
+from nominatim.api.types import SearchDetails
+from nominatim.api.search.db_searches import PoiSearch
+from nominatim.api.search.db_search_fields import WeightedStrings, WeightedCategories
+
+
+def run_search(apiobj, global_penalty, poitypes, poi_penalties=None,
+               ccodes=[], details=SearchDetails()):
+    if poi_penalties is None:
+        poi_penalties = [0.0] * len(poitypes)
+
+    class MySearchData:
+        penalty = global_penalty
+        qualifiers = WeightedCategories(poitypes, poi_penalties)
+        countries = WeightedStrings(ccodes, [0.0] * len(ccodes))
+
+    search = PoiSearch(MySearchData())
+
+    async def run():
+        async with apiobj.api._async_api.begin() as conn:
+            return await search.lookup(conn, details)
+
+    return apiobj.async_to_sync(run())
+
+
+@pytest.mark.parametrize('coord,pid', [('34.3, 56.100021', 2),
+                                       ('5.0, 4.59933', 1)])
+def test_simple_near_search_in_placex(apiobj, coord, pid):
+    apiobj.add_placex(place_id=1, class_='highway', type='bus_stop',
+                      centroid=(5.0, 4.6))
+    apiobj.add_placex(place_id=2, class_='highway', type='bus_stop',
+                      centroid=(34.3, 56.1))
+
+    details = SearchDetails.from_kwargs({'near': coord, 'near_radius': 0.001})
+
+    results = run_search(apiobj, 0.1, [('highway', 'bus_stop')], [0.5], details=details)
+
+    assert [r.place_id for r in results] == [pid]
+
+
+@pytest.mark.parametrize('coord,pid', [('34.3, 56.100021', 2),
+                                       ('34.3, 56.4', 2),
+                                       ('5.0, 4.59933', 1)])
+def test_simple_near_search_in_classtype(apiobj, coord, pid):
+    apiobj.add_placex(place_id=1, class_='highway', type='bus_stop',
+                      centroid=(5.0, 4.6))
+    apiobj.add_placex(place_id=2, class_='highway', type='bus_stop',
+                      centroid=(34.3, 56.1))
+    apiobj.add_class_type_table('highway', 'bus_stop')
+
+    details = SearchDetails.from_kwargs({'near': coord, 'near_radius': 0.5})
+
+    results = run_search(apiobj, 0.1, [('highway', 'bus_stop')], [0.5], details=details)
+
+    assert [r.place_id for r in results] == [pid]
+
+
+class TestPoiSearchWithRestrictions:
+
+    @pytest.fixture(autouse=True, params=["placex", "classtype"])
+    def fill_database(self, apiobj, request):
+        apiobj.add_placex(place_id=1, class_='highway', type='bus_stop',
+                          country_code='au',
+                          centroid=(34.3, 56.10003))
+        apiobj.add_placex(place_id=2, class_='highway', type='bus_stop',
+                          country_code='nz',
+                          centroid=(34.3, 56.1))
+        if request.param == 'classtype':
+            apiobj.add_class_type_table('highway', 'bus_stop')
+            self.args = {'near': '34.3, 56.4', 'near_radius': 0.5}
+        else:
+            self.args = {'near': '34.3, 56.100021', 'near_radius': 0.001}
+
+
+    def test_unrestricted(self, apiobj):
+        results = run_search(apiobj, 0.1, [('highway', 'bus_stop')], [0.5],
+                             details=SearchDetails.from_kwargs(self.args))
+
+        assert [r.place_id for r in results] == [1, 2]
+
+
+    def test_restict_country(self, apiobj):
+        results = run_search(apiobj, 0.1, [('highway', 'bus_stop')], [0.5],
+                             ccodes=['de', 'nz'],
+                             details=SearchDetails.from_kwargs(self.args))
+
+        assert [r.place_id for r in results] == [2]
+
+
+    def test_restrict_by_viewbox(self, apiobj):
+        args = {'bounded_viewbox': True, 'viewbox': '34.299,56.0,34.3001,56.10001'}
+        args.update(self.args)
+        results = run_search(apiobj, 0.1, [('highway', 'bus_stop')], [0.5],
+                             ccodes=['de', 'nz'],
+                             details=SearchDetails.from_kwargs(args))
+
+        assert [r.place_id for r in results] == [2]
diff --git a/test/python/api/search/test_search_postcode.py b/test/python/api/search/test_search_postcode.py
new file mode 100644 (file)
index 0000000..a43bc89
--- /dev/null
@@ -0,0 +1,97 @@
+# SPDX-License-Identifier: GPL-3.0-or-later
+#
+# This file is part of Nominatim. (https://nominatim.org)
+#
+# Copyright (C) 2023 by the Nominatim developer community.
+# For a full list of authors see the git log.
+"""
+Tests for running the postcode searcher.
+"""
+import pytest
+
+import nominatim.api as napi
+from nominatim.api.types import SearchDetails
+from nominatim.api.search.db_searches import PostcodeSearch
+from nominatim.api.search.db_search_fields import WeightedStrings, FieldLookup, \
+                                                  FieldRanking, RankedTokens
+
+def run_search(apiobj, global_penalty, pcs, pc_penalties=None,
+               ccodes=[], lookup=[], ranking=[], details=SearchDetails()):
+    if pc_penalties is None:
+        pc_penalties = [0.0] * len(pcs)
+
+    class MySearchData:
+        penalty = global_penalty
+        postcodes = WeightedStrings(pcs, pc_penalties)
+        countries = WeightedStrings(ccodes, [0.0] * len(ccodes))
+        lookups = lookup
+        rankings = ranking
+
+    search = PostcodeSearch(0.0, MySearchData())
+
+    async def run():
+        async with apiobj.api._async_api.begin() as conn:
+            return await search.lookup(conn, details)
+
+    return apiobj.async_to_sync(run())
+
+
+def test_postcode_only_search(apiobj):
+    apiobj.add_postcode(place_id=100, country_code='ch', postcode='12345')
+    apiobj.add_postcode(place_id=101, country_code='pl', postcode='12 345')
+
+    results = run_search(apiobj, 0.3, ['12345', '12 345'], [0.0, 0.1])
+
+    assert len(results) == 2
+    assert [r.place_id for r in results] == [100, 101]
+
+
+def test_postcode_with_country(apiobj):
+    apiobj.add_postcode(place_id=100, country_code='ch', postcode='12345')
+    apiobj.add_postcode(place_id=101, country_code='pl', postcode='12 345')
+
+    results = run_search(apiobj, 0.3, ['12345', '12 345'], [0.0, 0.1],
+                         ccodes=['de', 'pl'])
+
+    assert len(results) == 1
+    assert results[0].place_id == 101
+
+
+class TestPostcodeSearchWithAddress:
+
+    @pytest.fixture(autouse=True)
+    def fill_database(self, apiobj):
+        apiobj.add_postcode(place_id=100, country_code='ch',
+                            parent_place_id=1000, postcode='12345')
+        apiobj.add_postcode(place_id=101, country_code='pl',
+                            parent_place_id=2000, postcode='12345')
+        apiobj.add_placex(place_id=1000, class_='place', type='village',
+                          rank_search=22, rank_address=22,
+                          country_code='ch')
+        apiobj.add_search_name(1000, names=[1,2,10,11],
+                               search_rank=22, address_rank=22,
+                               country_code='ch')
+        apiobj.add_placex(place_id=2000, class_='place', type='village',
+                          rank_search=22, rank_address=22,
+                          country_code='pl')
+        apiobj.add_search_name(2000, names=[1,2,20,21],
+                               search_rank=22, address_rank=22,
+                               country_code='pl')
+
+
+    def test_lookup_both(self, apiobj):
+        lookup = FieldLookup('name_vector', [1,2], 'restrict')
+        ranking = FieldRanking('name_vector', 0.3, [RankedTokens(0.0, [10])])
+
+        results = run_search(apiobj, 0.1, ['12345'], lookup=[lookup], ranking=[ranking])
+
+        assert [r.place_id for r in results] == [100, 101]
+
+
+    def test_restrict_by_name(self, apiobj):
+        lookup = FieldLookup('name_vector', [10], 'restrict')
+
+        results = run_search(apiobj, 0.1, ['12345'], lookup=[lookup])
+
+        assert [r.place_id for r in results] == [100]
+