]> git.openstreetmap.org Git - nominatim.git/commitdiff
add API functions for search functions
authorSarah Hoffmann <lonvia@denofr.de>
Wed, 24 May 2023 15:43:28 +0000 (17:43 +0200)
committerSarah Hoffmann <lonvia@denofr.de>
Wed, 24 May 2023 16:05:43 +0000 (18:05 +0200)
Search is now split into three functions: one for free-text search,
one for structured search and one for search by category. Note that
the free-text search supports fewer hidden features; for example,
coordinate search is not available. Use the search parameters for that
instead.

nominatim/api/core.py
nominatim/api/logging.py
nominatim/api/lookup.py
nominatim/api/results.py
nominatim/api/reverse.py
nominatim/api/search/__init__.py
nominatim/api/search/db_search_builder.py
nominatim/api/search/geocoder.py [new file with mode: 0644]
nominatim/api/search/query_analyzer_factory.py
test/python/api/test_api_search.py [new file with mode: 0644]

index f1a656da483eda6fecf3b5f2038b3c1806b38142..a9fc12439364415e8a8039f2f45b1eccbea9756e 100644 (file)
@@ -7,7 +7,7 @@
 """
 Implementation of classes for API access via libraries.
 """
 """
 Implementation of classes for API access via libraries.
 """
-from typing import Mapping, Optional, Any, AsyncIterator, Dict, Sequence
+from typing import Mapping, Optional, Any, AsyncIterator, Dict, Sequence, List, Tuple
 import asyncio
 import contextlib
 from pathlib import Path
 import asyncio
 import contextlib
 from pathlib import Path
@@ -15,7 +15,7 @@ from pathlib import Path
 import sqlalchemy as sa
 import sqlalchemy.ext.asyncio as sa_asyncio
 
 import sqlalchemy as sa
 import sqlalchemy.ext.asyncio as sa_asyncio
 
-
+from nominatim.errors import UsageError
 from nominatim.db.sqlalchemy_schema import SearchTables
 from nominatim.db.async_core_library import PGCORE_LIB, PGCORE_ERROR
 from nominatim.config import Configuration
 from nominatim.db.sqlalchemy_schema import SearchTables
 from nominatim.db.async_core_library import PGCORE_LIB, PGCORE_ERROR
 from nominatim.config import Configuration
@@ -23,6 +23,7 @@ from nominatim.api.connection import SearchConnection
 from nominatim.api.status import get_status, StatusResult
 from nominatim.api.lookup import get_detailed_place, get_simple_place
 from nominatim.api.reverse import ReverseGeocoder
 from nominatim.api.status import get_status, StatusResult
 from nominatim.api.lookup import get_detailed_place, get_simple_place
 from nominatim.api.reverse import ReverseGeocoder
+from nominatim.api.search import ForwardGeocoder, Phrase, PhraseType, make_query_analyzer
 import nominatim.api.types as ntyp
 from nominatim.api.results import DetailedResult, ReverseResult, SearchResults
 
 import nominatim.api.types as ntyp
 from nominatim.api.results import DetailedResult, ReverseResult, SearchResults
 
@@ -133,9 +134,11 @@ class NominatimAPIAsync:
 
             Returns None if there is no entry under the given ID.
         """
 
             Returns None if there is no entry under the given ID.
         """
+        details = ntyp.LookupDetails.from_kwargs(params)
         async with self.begin() as conn:
         async with self.begin() as conn:
-            return await get_detailed_place(conn, place,
-                                            ntyp.LookupDetails.from_kwargs(params))
+            if details.keywords:
+                await make_query_analyzer(conn)
+            return await get_detailed_place(conn, place, details)
 
 
     async def lookup(self, places: Sequence[ntyp.PlaceRef], **params: Any) -> SearchResults:
 
 
     async def lookup(self, places: Sequence[ntyp.PlaceRef], **params: Any) -> SearchResults:
@@ -145,6 +148,8 @@ class NominatimAPIAsync:
         """
         details = ntyp.LookupDetails.from_kwargs(params)
         async with self.begin() as conn:
         """
         details = ntyp.LookupDetails.from_kwargs(params)
         async with self.begin() as conn:
+            if details.keywords:
+                await make_query_analyzer(conn)
             return SearchResults(filter(None,
                                         [await get_simple_place(conn, p, details) for p in places]))
 
             return SearchResults(filter(None,
                                         [await get_simple_place(conn, p, details) for p in places]))
 
@@ -160,11 +165,107 @@ class NominatimAPIAsync:
             # There are no results to be expected outside valid coordinates.
             return None
 
             # There are no results to be expected outside valid coordinates.
             return None
 
+        details = ntyp.ReverseDetails.from_kwargs(params)
         async with self.begin() as conn:
         async with self.begin() as conn:
-            geocoder = ReverseGeocoder(conn, ntyp.ReverseDetails.from_kwargs(params))
+            if details.keywords:
+                await make_query_analyzer(conn)
+            geocoder = ReverseGeocoder(conn, details)
             return await geocoder.lookup(coord)
 
 
             return await geocoder.lookup(coord)
 
 
+    async def search(self, query: str, **params: Any) -> SearchResults:
+        """ Find a place by free-text search. Also known as forward geocoding.
+        """
+        query = query.strip()
+        if not query:
+            raise UsageError('Nothing to search for.')
+
+        async with self.begin() as conn:
+            geocoder = ForwardGeocoder(conn, ntyp.SearchDetails.from_kwargs(params))
+            phrases = [Phrase(PhraseType.NONE, p.strip()) for p in query.split(',')]
+            return await geocoder.lookup(phrases)
+
+
+    # pylint: disable=too-many-arguments,too-many-branches
+    async def search_address(self, amenity: Optional[str] = None,
+                             street: Optional[str] = None,
+                             city: Optional[str] = None,
+                             county: Optional[str] = None,
+                             state: Optional[str] = None,
+                             country: Optional[str] = None,
+                             postalcode: Optional[str] = None,
+                             **params: Any) -> SearchResults:
+        """ Find an address using structured search.
+        """
+        async with self.begin() as conn:
+            details = ntyp.SearchDetails.from_kwargs(params)
+
+            phrases: List[Phrase] = []
+
+            if amenity:
+                phrases.append(Phrase(PhraseType.AMENITY, amenity))
+            if street:
+                phrases.append(Phrase(PhraseType.STREET, street))
+            if city:
+                phrases.append(Phrase(PhraseType.CITY, city))
+            if county:
+                phrases.append(Phrase(PhraseType.COUNTY, county))
+            if state:
+                phrases.append(Phrase(PhraseType.STATE, state))
+            if postalcode:
+                phrases.append(Phrase(PhraseType.POSTCODE, postalcode))
+            if country:
+                phrases.append(Phrase(PhraseType.COUNTRY, country))
+
+            if not phrases:
+                raise UsageError('Nothing to search for.')
+
+            if amenity or street:
+                details.restrict_min_max_rank(26, 30)
+            elif city:
+                details.restrict_min_max_rank(13, 25)
+            elif county:
+                details.restrict_min_max_rank(10, 12)
+            elif state:
+                details.restrict_min_max_rank(5, 9)
+            elif postalcode:
+                details.restrict_min_max_rank(5, 11)
+            else:
+                details.restrict_min_max_rank(4, 4)
+
+            if 'layers' not in params:
+                details.layers = ntyp.DataLayer.ADDRESS
+                if amenity:
+                    details.layers |= ntyp.DataLayer.POI
+
+            geocoder = ForwardGeocoder(conn, details)
+            return await geocoder.lookup(phrases)
+
+
+    async def search_category(self, categories: List[Tuple[str, str]],
+                              near_query: Optional[str] = None,
+                              **params: Any) -> SearchResults:
+        """ Find an object of a certain category near another place.
+            The near place may either be given as an unstructured search
+            query in itself or as coordinates.
+        """
+        if not categories:
+            return SearchResults()
+
+        details = ntyp.SearchDetails.from_kwargs(params)
+        async with self.begin() as conn:
+            if near_query:
+                phrases = [Phrase(PhraseType.NONE, p) for p in near_query.split(',')]
+            else:
+                phrases = []
+                if details.keywords:
+                    await make_query_analyzer(conn)
+
+            geocoder = ForwardGeocoder(conn, details)
+            return await geocoder.lookup_pois(categories, phrases)
+
+
+
 class NominatimAPI:
     """ API loader, synchronous version.
     """
 class NominatimAPI:
     """ API loader, synchronous version.
     """
@@ -217,3 +318,38 @@ class NominatimAPI:
             no place matches the given criteria.
         """
         return self._loop.run_until_complete(self._async_api.reverse(coord, **params))
             no place matches the given criteria.
         """
         return self._loop.run_until_complete(self._async_api.reverse(coord, **params))
+
+
+    def search(self, query: str, **params: Any) -> SearchResults:
+        """ Find a place by free-text search. Also known as forward geocoding.
+        """
+        return self._loop.run_until_complete(
+                   self._async_api.search(query, **params))
+
+
+    # pylint: disable=too-many-arguments
+    def search_address(self, amenity: Optional[str] = None,
+                       street: Optional[str] = None,
+                       city: Optional[str] = None,
+                       county: Optional[str] = None,
+                       state: Optional[str] = None,
+                       country: Optional[str] = None,
+                       postalcode: Optional[str] = None,
+                       **params: Any) -> SearchResults:
+        """ Find an address using structured search.
+        """
+        return self._loop.run_until_complete(
+                   self._async_api.search_address(amenity, street, city, county,
+                                                  state, country, postalcode, **params))
+
+
+    def search_category(self, categories: List[Tuple[str, str]],
+                        near_query: Optional[str] = None,
+                        **params: Any) -> SearchResults:
+        """ Find an object of a certain category near another place.
+            The near place may either be given as an unstructured search
+            query in itself or as a geographic area through the
+            viewbox or near parameters.
+        """
+        return self._loop.run_until_complete(
+                   self._async_api.search_category(categories, near_query, **params))
index fdff73beb078ad14f0e6ad6104ece434e93e387a..351da9a1d6ebe81272df4b45a09e569a9dce65da 100644 (file)
@@ -7,7 +7,7 @@
 """
 Functions for specialised logging with HTML output.
 """
 """
 Functions for specialised logging with HTML output.
 """
-from typing import Any, Iterator, Optional, List, cast
+from typing import Any, Iterator, Optional, List, Tuple, cast
 from contextvars import ContextVar
 import textwrap
 import io
 from contextvars import ContextVar
 import textwrap
 import io
@@ -24,6 +24,13 @@ except ModuleNotFoundError:
     CODE_HIGHLIGHT = False
 
 
     CODE_HIGHLIGHT = False
 
 
+def _debug_name(res: Any) -> str:
+    if res.names:
+        return cast(str, res.names.get('name', next(iter(res.names.values()))))
+
+    return f"Hnr {res.housenumber}" if res.housenumber is not None else '[NONE]'
+
+
 class BaseLogger:
     """ Interface for logging function.
 
 class BaseLogger:
     """ Interface for logging function.
 
@@ -61,6 +68,11 @@ class BaseLogger:
         """
 
 
         """
 
 
+    def result_dump(self, heading: str, results: Iterator[Tuple[Any, Any]]) -> None:
+        """ Print a list of search results generated by the generator function.
+        """
+
+
     def sql(self, conn: AsyncConnection, statement: 'sa.Executable') -> None:
         """ Print the SQL for the given statement.
         """
     def sql(self, conn: AsyncConnection, statement: 'sa.Executable') -> None:
         """ Print the SQL for the given statement.
         """
@@ -128,6 +140,38 @@ class HTMLLogger(BaseLogger):
         self._write('</tbody></table>')
 
 
         self._write('</tbody></table>')
 
 
+    def result_dump(self, heading: str, results: Iterator[Tuple[Any, Any]]) -> None:
+        """ Print a list of search results generated by the generator function.
+        """
+        def format_osm(osm_object: Optional[Tuple[str, int]]) -> str:
+            if not osm_object:
+                return '-'
+
+            t, i = osm_object
+            if t == 'N':
+                fullt = 'node'
+            elif t == 'W':
+                fullt = 'way'
+            elif t == 'R':
+                fullt = 'relation'
+            else:
+                return f'{t}{i}'
+
+            return f'<a href="https://www.openstreetmap.org/{fullt}/{i}">{t}{i}</a>'
+
+        self._write(f'<h5>{heading}</h5><p><dl>')
+        total = 0
+        for rank, res in results:
+            self._write(f'<dt>[{rank:.3f}]</dt>  <dd>{res.source_table.name}(')
+            self._write(f"{_debug_name(res)}, type=({','.join(res.category)}), ")
+            self._write(f"rank={res.rank_address}, ")
+            self._write(f"osm={format_osm(res.osm_object)}, ")
+            self._write(f'cc={res.country_code}, ')
+            self._write(f'importance={res.importance or -1:.5f})</dd>')
+            total += 1
+        self._write(f'</dl><b>TOTAL:</b> {total}</p>')
+
+
     def sql(self, conn: AsyncConnection, statement: 'sa.Executable') -> None:
         sqlstr = self.format_sql(conn, statement)
         if CODE_HIGHLIGHT:
     def sql(self, conn: AsyncConnection, statement: 'sa.Executable') -> None:
         sqlstr = self.format_sql(conn, statement)
         if CODE_HIGHLIGHT:
@@ -206,6 +250,20 @@ class TextLogger(BaseLogger):
             self._write('-'*tablewidth + '\n')
 
 
             self._write('-'*tablewidth + '\n')
 
 
+    def result_dump(self, heading: str, results: Iterator[Tuple[Any, Any]]) -> None:
+        self._write(f'{heading}:\n')
+        total = 0
+        for rank, res in results:
+            self._write(f'[{rank:.3f}]  {res.source_table.name}(')
+            self._write(f"{_debug_name(res)}, type=({','.join(res.category)}), ")
+            self._write(f"rank={res.rank_address}, ")
+            self._write(f"osm={''.join(map(str, res.osm_object or []))}, ")
+            self._write(f'cc={res.country_code}, ')
+            self._write(f'importance={res.importance or -1:.5f})\n')
+            total += 1
+        self._write(f'TOTAL: {total}\n\n')
+
+
     def sql(self, conn: AsyncConnection, statement: 'sa.Executable') -> None:
         sqlstr = '\n| '.join(textwrap.wrap(self.format_sql(conn, statement), width=78))
         self._write(f"| {sqlstr}\n\n")
     def sql(self, conn: AsyncConnection, statement: 'sa.Executable') -> None:
         sqlstr = '\n| '.join(textwrap.wrap(self.format_sql(conn, statement), width=78))
         self._write(f"| {sqlstr}\n\n")
index 823527025d59c65baa729d620cd95ca10ab50d72..0e1fd9cec6303ed188adf5d8e54223e732e3c165 100644 (file)
@@ -189,13 +189,13 @@ async def get_detailed_place(conn: SearchConnection, place: ntyp.PlaceRef,
     if indexed_date is not None:
         result.indexed_date = indexed_date.replace(tzinfo=dt.timezone.utc)
 
     if indexed_date is not None:
         result.indexed_date = indexed_date.replace(tzinfo=dt.timezone.utc)
 
-    await nres.add_result_details(conn, result, details)
+    await nres.add_result_details(conn, [result], details)
 
     return result
 
 
 async def get_simple_place(conn: SearchConnection, place: ntyp.PlaceRef,
 
     return result
 
 
 async def get_simple_place(conn: SearchConnection, place: ntyp.PlaceRef,
-                             details: ntyp.LookupDetails) -> Optional[nres.SearchResult]:
+                           details: ntyp.LookupDetails) -> Optional[nres.SearchResult]:
     """ Retrieve a place as a simple search result from the database.
     """
     log().function('get_simple_place', place=place, details=details)
     """ Retrieve a place as a simple search result from the database.
     """
     log().function('get_simple_place', place=place, details=details)
@@ -234,6 +234,6 @@ async def get_simple_place(conn: SearchConnection, place: ntyp.PlaceRef,
     assert result is not None
     result.bbox = getattr(row, 'bbox', None)
 
     assert result is not None
     result.bbox = getattr(row, 'bbox', None)
 
-    await nres.add_result_details(conn, result, details)
+    await nres.add_result_details(conn, [result], details)
 
     return result
 
     return result
index 1c313398e0d94b88ec68e32e16f8de6da06d86d7..5981cb3ecc8f753a55a3186abbe9617ed8d1eabb 100644 (file)
@@ -11,7 +11,7 @@ Data classes are part of the public API while the functions are for
 internal use only. That's why they are implemented as free-standing functions
 instead of member functions.
 """
 internal use only. That's why they are implemented as free-standing functions
 instead of member functions.
 """
-from typing import Optional, Tuple, Dict, Sequence, TypeVar, Type, List
+from typing import Optional, Tuple, Dict, Sequence, TypeVar, Type, List, Any
 import enum
 import dataclasses
 import datetime as dt
 import enum
 import dataclasses
 import datetime as dt
@@ -23,7 +23,6 @@ from nominatim.api.types import Point, Bbox, LookupDetails
 from nominatim.api.connection import SearchConnection
 from nominatim.api.logging import log
 from nominatim.api.localization import Locales
 from nominatim.api.connection import SearchConnection
 from nominatim.api.logging import log
 from nominatim.api.localization import Locales
-from nominatim.api.search.query_analyzer_factory import make_query_analyzer
 
 # This file defines complex result data classes.
 # pylint: disable=too-many-instance-attributes
 
 # This file defines complex result data classes.
 # pylint: disable=too-many-instance-attributes
@@ -147,6 +146,7 @@ class BaseResult:
         """
         return self.importance or (0.7500001 - (self.rank_search/40.0))
 
         """
         return self.importance or (0.7500001 - (self.rank_search/40.0))
 
+
 BaseResultT = TypeVar('BaseResultT', bound=BaseResult)
 
 @dataclasses.dataclass
 BaseResultT = TypeVar('BaseResultT', bound=BaseResult)
 
 @dataclasses.dataclass
@@ -332,24 +332,28 @@ def create_from_country_row(row: Optional[SaRow],
                       country_code=row.country_code)
 
 
                       country_code=row.country_code)
 
 
-async def add_result_details(conn: SearchConnection, result: BaseResult,
+async def add_result_details(conn: SearchConnection, results: List[BaseResultT],
                              details: LookupDetails) -> None:
     """ Retrieve more details from the database according to the
         parameters specified in 'details'.
     """
                              details: LookupDetails) -> None:
     """ Retrieve more details from the database according to the
         parameters specified in 'details'.
     """
-    log().section('Query details for result')
-    if details.address_details:
-        log().comment('Query address details')
-        await complete_address_details(conn, result)
-    if details.linked_places:
-        log().comment('Query linked places')
-        await complete_linked_places(conn, result)
-    if details.parented_places:
-        log().comment('Query parent places')
-        await complete_parented_places(conn, result)
-    if details.keywords:
-        log().comment('Query keywords')
-        await complete_keywords(conn, result)
+    if results:
+        log().section('Query details for result')
+        if details.address_details:
+            log().comment('Query address details')
+            await complete_address_details(conn, results)
+        if details.linked_places:
+            log().comment('Query linked places')
+            for result in results:
+                await complete_linked_places(conn, result)
+        if details.parented_places:
+            log().comment('Query parent places')
+            for result in results:
+                await complete_parented_places(conn, result)
+        if details.keywords:
+            log().comment('Query keywords')
+            for result in results:
+                await complete_keywords(conn, result)
 
 
 def _result_row_to_address_row(row: SaRow) -> AddressLine:
 
 
 def _result_row_to_address_row(row: SaRow) -> AddressLine:
@@ -377,35 +381,60 @@ def _result_row_to_address_row(row: SaRow) -> AddressLine:
                        distance=row.distance)
 
 
                        distance=row.distance)
 
 
-async def complete_address_details(conn: SearchConnection, result: BaseResult) -> None:
+async def complete_address_details(conn: SearchConnection, results: List[BaseResultT]) -> None:
     """ Retrieve information about places that make up the address of the result.
     """
     """ Retrieve information about places that make up the address of the result.
     """
-    housenumber = -1
-    if result.source_table in (SourceTable.TIGER, SourceTable.OSMLINE):
-        if result.housenumber is not None:
-            housenumber = int(result.housenumber)
-        elif result.extratags is not None and 'startnumber' in result.extratags:
-            # details requests do not come with a specific house number
-            housenumber = int(result.extratags['startnumber'])
-
-    sfn = sa.func.get_addressdata(result.place_id, housenumber)\
-            .table_valued( # type: ignore[no-untyped-call]
-                sa.column('place_id', type_=sa.Integer),
-                'osm_type',
-                sa.column('osm_id', type_=sa.BigInteger),
-                sa.column('name', type_=conn.t.types.Composite),
-                'class', 'type', 'place_type',
-                sa.column('admin_level', type_=sa.Integer),
-                sa.column('fromarea', type_=sa.Boolean),
-                sa.column('isaddress', type_=sa.Boolean),
-                sa.column('rank_address', type_=sa.SmallInteger),
-                sa.column('distance', type_=sa.Float))
-    sql = sa.select(sfn).order_by(sa.column('rank_address').desc(),
-                                  sa.column('isaddress').desc())
-
-    result.address_rows = AddressLines()
+    def get_hnr(result: BaseResult) -> Tuple[int, int]:
+        housenumber = -1
+        if result.source_table in (SourceTable.TIGER, SourceTable.OSMLINE):
+            if result.housenumber is not None:
+                housenumber = int(result.housenumber)
+            elif result.extratags is not None and 'startnumber' in result.extratags:
+                # details requests do not come with a specific house number
+                housenumber = int(result.extratags['startnumber'])
+        assert result.place_id
+        return result.place_id, housenumber
+
+    data: List[Tuple[Any, ...]] = [get_hnr(r) for r in results if r.place_id]
+
+    if not data:
+        return
+
+    values = sa.values(sa.column('place_id', type_=sa.Integer),
+                       sa.column('housenumber', type_=sa.Integer),
+                       name='places',
+                       literal_binds=True).data(data)
+
+    sfn = sa.func.get_addressdata(values.c.place_id, values.c.housenumber)\
+                .table_valued( # type: ignore[no-untyped-call]
+                    sa.column('place_id', type_=sa.Integer),
+                    'osm_type',
+                    sa.column('osm_id', type_=sa.BigInteger),
+                    sa.column('name', type_=conn.t.types.Composite),
+                    'class', 'type', 'place_type',
+                    sa.column('admin_level', type_=sa.Integer),
+                    sa.column('fromarea', type_=sa.Boolean),
+                    sa.column('isaddress', type_=sa.Boolean),
+                    sa.column('rank_address', type_=sa.SmallInteger),
+                    sa.column('distance', type_=sa.Float),
+                    joins_implicitly=True)
+
+    sql = sa.select(values.c.place_id.label('result_place_id'), sfn)\
+            .order_by(values.c.place_id,
+                      sa.column('rank_address').desc(),
+                      sa.column('isaddress').desc())
+
+    current_result = None
     for row in await conn.execute(sql):
     for row in await conn.execute(sql):
-        result.address_rows.append(_result_row_to_address_row(row))
+        if current_result is None or row.result_place_id != current_result.place_id:
+            for result in results:
+                if result.place_id == row.result_place_id:
+                    current_result = result
+                    break
+            else:
+                assert False
+            current_result.address_rows = AddressLines()
+        current_result.address_rows.append(_result_row_to_address_row(row))
 
 
 # pylint: disable=consider-using-f-string
 
 
 # pylint: disable=consider-using-f-string
@@ -440,6 +469,9 @@ async def complete_linked_places(conn: SearchConnection, result: BaseResult) ->
 
 async def complete_keywords(conn: SearchConnection, result: BaseResult) -> None:
     """ Retrieve information about the search terms used for this place.
 
 async def complete_keywords(conn: SearchConnection, result: BaseResult) -> None:
     """ Retrieve information about the search terms used for this place.
+
+        Requires that the query analyzer was initialised to get access to
+        the word table.
     """
     t = conn.t.search_name
     sql = sa.select(t.c.name_vector, t.c.nameaddress_vector)\
     """
     t = conn.t.search_name
     sql = sa.select(t.c.name_vector, t.c.nameaddress_vector)\
@@ -448,7 +480,6 @@ async def complete_keywords(conn: SearchConnection, result: BaseResult) -> None:
     result.name_keywords = []
     result.address_keywords = []
 
     result.name_keywords = []
     result.address_keywords = []
 
-    await make_query_analyzer(conn)
     t = conn.t.meta.tables['word']
     sel = sa.select(t.c.word_id, t.c.word_token, t.c.word)
 
     t = conn.t.meta.tables['word']
     sel = sa.select(t.c.word_id, t.c.word_token, t.c.word)
 
index d6976c06c2f8307cae1dcf8fc3f9972672ef557f..10c97cad221702e513d175c339bf9db46e73a148 100644 (file)
@@ -548,6 +548,6 @@ class ReverseGeocoder:
             result.distance = row.distance
             if hasattr(row, 'bbox'):
                 result.bbox = Bbox.from_wkb(row.bbox.data)
             result.distance = row.distance
             if hasattr(row, 'bbox'):
                 result.bbox = Bbox.from_wkb(row.bbox.data)
-            await nres.add_result_details(self.conn, result, self.params)
+            await nres.add_result_details(self.conn, [result], self.params)
 
         return result
 
         return result
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..f60cbe1e22c629d62a7c6d5690b3ad5af79cd045 100644 (file)
@@ -0,0 +1,15 @@
+# SPDX-License-Identifier: GPL-3.0-or-later
+#
+# This file is part of Nominatim. (https://nominatim.org)
+#
+# Copyright (C) 2023 by the Nominatim developer community.
+# For a full list of authors see the git log.
+"""
+Module for forward search.
+"""
+# pylint: disable=useless-import-alias
+
+from .geocoder import (ForwardGeocoder as ForwardGeocoder)
+from .query import (Phrase as Phrase,
+                    PhraseType as PhraseType)
+from .query_analyzer_factory import (make_query_analyzer as make_query_analyzer)
index c0c55a18a428f299aaece8e869f32babc8ed5fb5..9ea0cfedf5017101ac747d5a3df10fdad8c1f07e 100644 (file)
@@ -17,6 +17,36 @@ import nominatim.api.search.db_search_fields as dbf
 import nominatim.api.search.db_searches as dbs
 from nominatim.api.logging import log
 
 import nominatim.api.search.db_searches as dbs
 from nominatim.api.logging import log
 
+
+def wrap_near_search(categories: List[Tuple[str, str]],
+                     search: dbs.AbstractSearch) -> dbs.NearSearch:
+    """ Create a new search that wraps the given search in a search
+        for near places of the given category.
+    """
+    return dbs.NearSearch(penalty=search.penalty,
+                          categories=dbf.WeightedCategories(categories,
+                                                            [0.0] * len(categories)),
+                          search=search)
+
+
+def build_poi_search(category: List[Tuple[str, str]],
+                     countries: Optional[List[str]]) -> dbs.PoiSearch:
+    """ Create a new search for places by the given category, possibly
+        constraint to the given countries.
+    """
+    if countries:
+        ccs = dbf.WeightedStrings(countries, [0.0] * len(countries))
+    else:
+        ccs = dbf.WeightedStrings([], [])
+
+    class _PoiData(dbf.SearchData):
+        penalty = 0.0
+        qualifiers = dbf.WeightedCategories(category, [0.0] * len(category))
+        countries=ccs
+
+    return dbs.PoiSearch(_PoiData())
+
+
 class SearchBuilder:
     """ Build the abstract search queries from token assignments.
     """
 class SearchBuilder:
     """ Build the abstract search queries from token assignments.
     """
diff --git a/nominatim/api/search/geocoder.py b/nominatim/api/search/geocoder.py
new file mode 100644 (file)
index 0000000..5e90d40
--- /dev/null
@@ -0,0 +1,191 @@
+# SPDX-License-Identifier: GPL-3.0-or-later
+#
+# This file is part of Nominatim. (https://nominatim.org)
+#
+# Copyright (C) 2023 by the Nominatim developer community.
+# For a full list of authors see the git log.
+"""
+Public interface to the search code.
+"""
+from typing import List, Any, Optional, Iterator, Tuple
+import itertools
+
+from nominatim.api.connection import SearchConnection
+from nominatim.api.types import SearchDetails
+from nominatim.api.results import SearchResults, add_result_details
+from nominatim.api.search.token_assignment import yield_token_assignments
+from nominatim.api.search.db_search_builder import SearchBuilder, build_poi_search, wrap_near_search
+from nominatim.api.search.db_searches import AbstractSearch
+from nominatim.api.search.query_analyzer_factory import make_query_analyzer, AbstractQueryAnalyzer
+from nominatim.api.search.query import Phrase, QueryStruct
+from nominatim.api.logging import log
+
+class ForwardGeocoder:
+    """ Main class responsible for place search.
+    """
+
+    def __init__(self, conn: SearchConnection, params: SearchDetails) -> None:
+        self.conn = conn
+        self.params = params
+        self.query_analyzer: Optional[AbstractQueryAnalyzer] = None
+
+
+    @property
+    def limit(self) -> int:
+        """ Return the configured maximum number of search results.
+        """
+        return self.params.max_results
+
+
+    async def build_searches(self,
+                             phrases: List[Phrase]) -> Tuple[QueryStruct, List[AbstractSearch]]:
+        """ Analyse the query and return the tokenized query and list of
+            possible searches over it.
+        """
+        if self.query_analyzer is None:
+            self.query_analyzer = await make_query_analyzer(self.conn)
+
+        query = await self.query_analyzer.analyze_query(phrases)
+
+        searches: List[AbstractSearch] = []
+        if query.num_token_slots() > 0:
+            # 2. Compute all possible search interpretations
+            log().section('Compute abstract searches')
+            search_builder = SearchBuilder(query, self.params)
+            num_searches = 0
+            for assignment in yield_token_assignments(query):
+                searches.extend(search_builder.build(assignment))
+                log().table_dump('Searches for assignment',
+                                 _dump_searches(searches, query, num_searches))
+                num_searches = len(searches)
+            searches.sort(key=lambda s: s.penalty)
+
+        return query, searches
+
+
+    async def execute_searches(self, query: QueryStruct,
+                               searches: List[AbstractSearch]) -> SearchResults:
+        """ Run the abstract searches against the database until a result
+            is found.
+        """
+        log().section('Execute database searches')
+        results = SearchResults()
+
+        num_results = 0
+        min_ranking = 1000.0
+        prev_penalty = 0.0
+        for i, search in enumerate(searches):
+            if search.penalty > prev_penalty and (search.penalty > min_ranking or i > 20):
+                break
+            log().table_dump(f"{i + 1}. Search", _dump_searches([search], query))
+            for result in await search.lookup(self.conn, self.params):
+                results.append(result)
+                min_ranking = min(min_ranking, result.ranking + 0.5, search.penalty + 0.3)
+            log().result_dump('Results', ((r.accuracy, r) for r in results[num_results:]))
+            num_results = len(results)
+            prev_penalty = search.penalty
+
+        if results:
+            min_ranking = min(r.ranking for r in results)
+            results = SearchResults(r for r in results if r.ranking < min_ranking + 0.5)
+
+        if results:
+            min_rank = min(r.rank_search for r in results)
+
+            results = SearchResults(r for r in results
+                                    if r.ranking + 0.05 * (r.rank_search - min_rank)
+                                       < min_ranking + 0.5)
+
+            results.sort(key=lambda r: r.accuracy - r.calculated_importance())
+            results = SearchResults(results[:self.limit])
+
+        return results
+
+
+    async def lookup_pois(self, categories: List[Tuple[str, str]],
+                          phrases: List[Phrase]) -> SearchResults:
+        """ Look up places by category. If phrase is given, a place search
+            over the phrase will be executed first and places close to the
+            results returned.
+        """
+        log().function('forward_lookup_pois', categories=categories, params=self.params)
+
+        if phrases:
+            query, searches = await self.build_searches(phrases)
+
+            if query:
+                searches = [wrap_near_search(categories, s) for s in searches[:50]]
+                results = await self.execute_searches(query, searches)
+            else:
+                results = SearchResults()
+        else:
+            search = build_poi_search(categories, self.params.countries)
+            results = await search.lookup(self.conn, self.params)
+
+        await add_result_details(self.conn, results, self.params)
+        log().result_dump('Final Results', ((r.accuracy, r) for r in results))
+
+        return results
+
+
+    async def lookup(self, phrases: List[Phrase]) -> SearchResults:
+        """ Look up a single free-text query.
+        """
+        log().function('forward_lookup', phrases=phrases, params=self.params)
+        results = SearchResults()
+
+        if self.params.is_impossible():
+            return results
+
+        query, searches = await self.build_searches(phrases)
+
+        if searches:
+            # Execute SQL until an appropriate result is found.
+            results = await self.execute_searches(query, searches[:50])
+            await add_result_details(self.conn, results, self.params)
+            log().result_dump('Final Results', ((r.accuracy, r) for r in results))
+
+        return results
+
+
+# pylint: disable=invalid-name,too-many-locals
+def _dump_searches(searches: List[AbstractSearch], query: QueryStruct,
+                   start: int = 0) -> Iterator[Optional[List[Any]]]:
+    yield ['Penalty', 'Lookups', 'Housenr', 'Postcode', 'Countries', 'Qualifier', 'Rankings']
+
+    def tk(tl: List[int]) -> str:
+        tstr = [f"{query.find_lookup_word_by_id(t)}({t})" for t in tl]
+
+        return f"[{','.join(tstr)}]"
+
+    def fmt_ranking(f: Any) -> str:
+        if not f:
+            return ''
+        ranks = ','.join((f"{tk(r.tokens)}^{r.penalty:.3g}" for r in f.rankings))
+        if len(ranks) > 100:
+            ranks = ranks[:100] + '...'
+        return f"{f.column}({ranks},def={f.default:.3g})"
+
+    def fmt_lookup(l: Any) -> str:
+        if not l:
+            return ''
+
+        return f"{l.lookup_type}({l.column}{tk(l.tokens)})"
+
+
+    def fmt_cstr(c: Any) -> str:
+        if not c:
+            return ''
+
+        return f'{c[0]}^{c[1]}'
+
+    for search in searches[start:]:
+        fields = ('name_lookups', 'name_ranking', 'countries', 'housenumbers',
+                  'postcodes', 'qualifier')
+        iters = itertools.zip_longest([f"{search.penalty:.3g}"],
+                                      *(getattr(search, attr, []) for attr in fields),
+                                      fillvalue= '')
+        for penalty, lookup, rank, cc, hnr, pc, qual in iters:
+            yield [penalty, fmt_lookup(lookup), fmt_cstr(hnr),
+                   fmt_cstr(pc), fmt_cstr(cc), fmt_cstr(qual), fmt_ranking(rank)]
+        yield None
index 9804f3ce8a1129b7a3b7699425e94ba4017f4592..35649d0ffe4cb544daf5a07a0df17ebbfe159d81 100644 (file)
@@ -7,14 +7,16 @@
 """
 Factory for creating a query analyzer for the configured tokenizer.
 """
 """
 Factory for creating a query analyzer for the configured tokenizer.
 """
-from typing import List, cast
+from typing import List, cast, TYPE_CHECKING
 from abc import ABC, abstractmethod
 from pathlib import Path
 import importlib
 
 from nominatim.api.logging import log
 from nominatim.api.connection import SearchConnection
 from abc import ABC, abstractmethod
 from pathlib import Path
 import importlib
 
 from nominatim.api.logging import log
 from nominatim.api.connection import SearchConnection
-from nominatim.api.search.query import Phrase, QueryStruct
+
+if TYPE_CHECKING:
+    from nominatim.api.search.query import Phrase, QueryStruct
 
 class AbstractQueryAnalyzer(ABC):
     """ Class for analysing incomming queries.
 
 class AbstractQueryAnalyzer(ABC):
     """ Class for analysing incomming queries.
@@ -23,7 +25,7 @@ class AbstractQueryAnalyzer(ABC):
     """
 
     @abstractmethod
     """
 
     @abstractmethod
-    async def analyze_query(self, phrases: List[Phrase]) -> QueryStruct:
+    async def analyze_query(self, phrases: List['Phrase']) -> 'QueryStruct':
         """ Analyze the given phrases and return the tokenized query.
         """
 
         """ Analyze the given phrases and return the tokenized query.
         """
 
diff --git a/test/python/api/test_api_search.py b/test/python/api/test_api_search.py
new file mode 100644 (file)
index 0000000..aa263d2
--- /dev/null
@@ -0,0 +1,159 @@
+# SPDX-License-Identifier: GPL-3.0-or-later
+#
+# This file is part of Nominatim. (https://nominatim.org)
+#
+# Copyright (C) 2023 by the Nominatim developer community.
+# For a full list of authors see the git log.
+"""
+Tests for search API calls.
+
+These tests make sure that all Python code is correct and executable.
+Functional tests can be found in the BDD test suite.
+"""
+import json
+
+import pytest
+
+import sqlalchemy as sa
+
+import nominatim.api as napi
+import nominatim.api.logging as loglib
+
+@pytest.fixture(autouse=True)
+def setup_icu_tokenizer(apiobj):
+    """ Setup the propoerties needed for using the ICU tokenizer.
+    """
+    apiobj.add_data('properties',
+                    [{'property': 'tokenizer', 'value': 'icu'},
+                     {'property': 'tokenizer_import_normalisation', 'value': ':: lower();'},
+                     {'property': 'tokenizer_import_transliteration', 'value': "'1' > '/1/'; 'ä' > 'ä '"},
+                    ])
+
+
+def test_search_no_content(apiobj, table_factory):
+    table_factory('word',
+                  definition='word_id INT, word_token TEXT, type TEXT, word TEXT, info JSONB')
+
+    assert apiobj.api.search('foo') == []
+
+
+def test_search_simple_word(apiobj, table_factory):
+    table_factory('word',
+                  definition='word_id INT, word_token TEXT, type TEXT, word TEXT, info JSONB',
+                  content=[(55, 'test', 'W', 'test', None),
+                           (2, 'test', 'w', 'test', None)])
+
+    apiobj.add_placex(place_id=444, class_='place', type='village',
+                      centroid=(1.3, 0.7))
+    apiobj.add_search_name(444, names=[2, 55])
+
+    results = apiobj.api.search('TEST')
+
+    assert [r.place_id for r in results] == [444]
+
+
+@pytest.mark.parametrize('logtype', ['text', 'html'])
+def test_search_with_debug(apiobj, table_factory, logtype):
+    table_factory('word',
+                  definition='word_id INT, word_token TEXT, type TEXT, word TEXT, info JSONB',
+                  content=[(55, 'test', 'W', 'test', None),
+                           (2, 'test', 'w', 'test', None)])
+
+    apiobj.add_placex(place_id=444, class_='place', type='village',
+                      centroid=(1.3, 0.7))
+    apiobj.add_search_name(444, names=[2, 55])
+
+    loglib.set_log_output(logtype)
+    results = apiobj.api.search('TEST')
+
+    assert loglib.get_and_disable()
+
+
+def test_address_no_content(apiobj, table_factory):
+    table_factory('word',
+                  definition='word_id INT, word_token TEXT, type TEXT, word TEXT, info JSONB')
+
+    assert apiobj.api.search_address(amenity='hotel',
+                                     street='Main St 34',
+                                     city='Happyville',
+                                     county='Wideland',
+                                     state='Praerie',
+                                     postalcode='55648',
+                                     country='xx') == []
+
+
+@pytest.mark.parametrize('atype,address,search', [('street', 26, 26),
+                                                  ('city', 16, 18),
+                                                  ('county', 12, 12),
+                                                  ('state', 8, 8)])
+def test_address_simple_places(apiobj, table_factory, atype, address, search):
+    table_factory('word',
+                  definition='word_id INT, word_token TEXT, type TEXT, word TEXT, info JSONB',
+                  content=[(55, 'test', 'W', 'test', None),
+                           (2, 'test', 'w', 'test', None)])
+
+    apiobj.add_placex(place_id=444,
+                      rank_address=address, rank_search=search,
+                      centroid=(1.3, 0.7))
+    apiobj.add_search_name(444, names=[2, 55], address_rank=address, search_rank=search)
+
+    results = apiobj.api.search_address(**{atype: 'TEST'})
+
+    assert [r.place_id for r in results] == [444]
+
+
+def test_address_country(apiobj, table_factory):
+    table_factory('word',
+                  definition='word_id INT, word_token TEXT, type TEXT, word TEXT, info JSONB',
+                  content=[(None, 'ro', 'C', 'ro', None)])
+    apiobj.add_country('ro', 'POLYGON((0 0, 0 1, 1 1, 1 0, 0 0))')
+    apiobj.add_country_name('ro', {'name': 'România'})
+
+    assert len(apiobj.api.search_address(country='ro')) == 1
+
+
+def test_category_no_categories(apiobj, table_factory):
+    table_factory('word',
+                  definition='word_id INT, word_token TEXT, type TEXT, word TEXT, info JSONB')
+
+    assert apiobj.api.search_category([], near_query='Berlin') == []
+
+
+def test_category_no_content(apiobj, table_factory):
+    table_factory('word',
+                  definition='word_id INT, word_token TEXT, type TEXT, word TEXT, info JSONB')
+
+    assert apiobj.api.search_category([('amenity', 'restaurant')]) == []
+
+
+def test_category_simple_restaurant(apiobj, table_factory):
+    table_factory('word',
+                  definition='word_id INT, word_token TEXT, type TEXT, word TEXT, info JSONB')
+
+    apiobj.add_placex(place_id=444, class_='amenity', type='restaurant',
+                      centroid=(1.3, 0.7))
+    apiobj.add_search_name(444, names=[2, 55], address_rank=16, search_rank=18)
+
+    results = apiobj.api.search_category([('amenity', 'restaurant')],
+                                         near=(1.3, 0.701), near_radius=0.015)
+
+    assert [r.place_id for r in results] == [444]
+
+
+def test_category_with_search_phrase(apiobj, table_factory):
+    table_factory('word',
+                  definition='word_id INT, word_token TEXT, type TEXT, word TEXT, info JSONB',
+                  content=[(55, 'test', 'W', 'test', None),
+                           (2, 'test', 'w', 'test', None)])
+
+    apiobj.add_placex(place_id=444, class_='place', type='village',
+                      rank_address=16, rank_search=18,
+                      centroid=(1.3, 0.7))
+    apiobj.add_search_name(444, names=[2, 55], address_rank=16, search_rank=18)
+    apiobj.add_placex(place_id=95, class_='amenity', type='restaurant',
+                      centroid=(1.3, 0.7003))
+
+    results = apiobj.api.search_category([('amenity', 'restaurant')],
+                                         near_query='TEST')
+
+    assert [r.place_id for r in results] == [95]