From: Sarah Hoffmann

    def function(self, func: str, **kwargs: Any) -> None:
+        self._timestamp()
        self._write(f"<h1>Debug output for {func}()</h1>\n<p>Parameters:<dl>")
        for name, value in kwargs.items():
            self._write(f'<dt>{name}</dt><dd>{self._python_var(value)}</dd>')
        self._write('</dl></p>')


    def section(self, heading: str) -> None:
+        self._timestamp()
        self._write(f"<h2>{heading}</h2>")


    def comment(self, text: str) -> None:
+        self._timestamp()
        self._write(f"<p>{text}</p>")


    def var_dump(self, heading: str, var: Any) -> None:
+        self._timestamp()
+        if callable(var):
+            var = var()
+
        self._write(f'<h5>{heading}</h5>{self._python_var(var)}')


+    def table_dump(self, heading: str, rows: Iterator[Optional[List[Any]]]) -> None:
+        self._timestamp()
+        head = next(rows)
+        assert head
+        self._write(f'<table><thead><tr><th colspan="{len(head)}">{heading}</th></tr><tr>')
+        for cell in head:
+            self._write(f'<th>{cell}</th>')
+        self._write('</tr></thead><tbody>')
+        for row in rows:
+            if row is not None:
+                self._write('<tr>')
+                for cell in row:
+                    self._write(f'<td>{cell}</td>')
+                self._write('</tr>')
+        self._write('</tbody></table>')
+
+
+    def result_dump(self, heading: str, results: Iterator[Tuple[Any, Any]]) -> None:
+        """ Print a list of search results generated by the generator function.
+        """
+        self._timestamp()
+        def format_osm(osm_object: Optional[Tuple[str, int]]) -> str:
+            if not osm_object:
+                return '-'
+
+            t, i = osm_object
+            if t == 'N':
+                fullt = 'node'
+            elif t == 'W':
+                fullt = 'way'
+            elif t == 'R':
+                fullt = 'relation'
+            else:
+                return f'{t}{i}'
+
+            return f'<a href="https://www.openstreetmap.org/{fullt}/{i}">{t}{i}</a>'
+
+        self._write(f'<h5>{heading}</h5><p><dl>')
+        total = 0
+        for rank, res in results:
+            self._write(f'<dt>[{rank:.3f}]</dt>  <dd>{res.source_table.name}(')
+            self._write(f"{_debug_name(res)}, type=({','.join(res.category)}), ")
+            self._write(f"rank={res.rank_address}, ")
+            self._write(f"osm={format_osm(res.osm_object)}, ")
+            self._write(f'cc={res.country_code}, ')
+            self._write(f'importance={res.importance or -1:.5f})</dd>')
+            total += 1
+        self._write(f'</dl><b>TOTAL:</b> {total}</p>')


    def _python_var(self, var: Any) -> str:
        if CODE_HIGHLIGHT:
            fmt = highlight(str(var), PythonLexer(), HtmlFormatter(nowrap=True))
            return f'<div class="highlight"><code class="lang-python">{fmt}</code></div>'

        return f'<code class="lang-python">{html.escape(str(var))}</code>'
@@ -155,9 +235,47 @@ class TextLogger(BaseLogger):
def var_dump(self, heading: str, var: Any) -> None:
+ if callable(var):
+ var = var()
+
self._write(f'{heading}:\n {self._python_var(var)}\n\n')
+ def table_dump(self, heading: str, rows: Iterator[Optional[List[Any]]]) -> None:
+ self._write(f'{heading}:\n')
+ data = [list(map(self._python_var, row)) if row else None for row in rows]
+ assert data[0] is not None
+ num_cols = len(data[0])
+
+ maxlens = [max(len(d[i]) for d in data if d) for i in range(num_cols)]
+ tablewidth = sum(maxlens) + 3 * num_cols + 1
+ row_format = '| ' +' | '.join(f'{{:<{l}}}' for l in maxlens) + ' |\n'
+ self._write('-'*tablewidth + '\n')
+ self._write(row_format.format(*data[0]))
+ self._write('-'*tablewidth + '\n')
+ for row in data[1:]:
+ if row:
+ self._write(row_format.format(*row))
+ else:
+ self._write('-'*tablewidth + '\n')
+ if data[-1]:
+ self._write('-'*tablewidth + '\n')
+
+
+ def result_dump(self, heading: str, results: Iterator[Tuple[Any, Any]]) -> None:
+ self._write(f'{heading}:\n')
+ total = 0
+ for rank, res in results:
+ self._write(f'[{rank:.3f}] {res.source_table.name}(')
+ self._write(f"{_debug_name(res)}, type=({','.join(res.category)}), ")
+ self._write(f"rank={res.rank_address}, ")
+ self._write(f"osm={''.join(map(str, res.osm_object or []))}, ")
+ self._write(f'cc={res.country_code}, ')
+ self._write(f'importance={res.importance or -1:.5f})\n')
+ total += 1
+ self._write(f'TOTAL: {total}\n\n')
+
+
def sql(self, conn: AsyncConnection, statement: 'sa.Executable') -> None:
sqlstr = '\n| '.join(textwrap.wrap(self.format_sql(conn, statement), width=78))
self._write(f"| {sqlstr}\n\n")
@@ -242,6 +360,26 @@ HTML_HEADER: str = """
padding: 3pt;
border: solid lightgrey 0.1pt
}
+
+ table, th, tbody {
+ border: thin solid;
+ border-collapse: collapse;
+ }
+ td {
+ border-right: thin solid;
+ padding-left: 3pt;
+ padding-right: 3pt;
+ }
+
+ .timestamp {
+ font-size: 0.8em;
+ color: darkblue;
+ width: calc(100% - 5pt);
+ text-align: right;
+ position: absolute;
+ left: 0;
+ margin-top: -5px;
+ }
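To make the new text table renderer concrete, here is a minimal, self-contained sketch of the same layout logic as TextLogger.table_dump(). render_table is a hypothetical stand-in, not part of the patch; the real method streams through self._write() and formats cells with self._python_var() instead of str():

    from typing import Any, Iterator, List, Optional

    def render_table(rows: Iterator[Optional[List[Any]]]) -> str:
        # First row is the header, None rows become separator lines,
        # column widths are fitted to the longest cell in each column.
        data = [list(map(str, row)) if row else None for row in rows]
        assert data[0] is not None
        num_cols = len(data[0])
        maxlens = [max(len(d[i]) for d in data if d) for i in range(num_cols)]
        tablewidth = sum(maxlens) + 3 * num_cols + 1
        row_format = '| ' + ' | '.join(f'{{:<{l}}}' for l in maxlens) + ' |\n'
        out = ['-' * tablewidth + '\n',
               row_format.format(*data[0]),
               '-' * tablewidth + '\n']
        for row in data[1:]:
            out.append(row_format.format(*row) if row else '-' * tablewidth + '\n')
        if data[-1]:
            out.append('-' * tablewidth + '\n')
        return ''.join(out)

    print(render_table(iter([['token', 'count'], ['main street', '2312'], ['baker st', '17']])))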
diff --git a/nominatim/api/lookup.py b/nominatim/api/lookup.py
index 82352702..0e1fd9ce 100644
--- a/nominatim/api/lookup.py
+++ b/nominatim/api/lookup.py
@@ -189,13 +189,13 @@ async def get_detailed_place(conn: SearchConnection, place: ntyp.PlaceRef,
if indexed_date is not None:
result.indexed_date = indexed_date.replace(tzinfo=dt.timezone.utc)
- await nres.add_result_details(conn, result, details)
+ await nres.add_result_details(conn, [result], details)
return result
async def get_simple_place(conn: SearchConnection, place: ntyp.PlaceRef,
- details: ntyp.LookupDetails) -> Optional[nres.SearchResult]:
+ details: ntyp.LookupDetails) -> Optional[nres.SearchResult]:
""" Retrieve a place as a simple search result from the database.
"""
log().function('get_simple_place', place=place, details=details)
@@ -234,6 +234,6 @@ async def get_simple_place(conn: SearchConnection, place: ntyp.PlaceRef,
assert result is not None
result.bbox = getattr(row, 'bbox', None)
- await nres.add_result_details(conn, result, details)
+ await nres.add_result_details(conn, [result], details)
return result
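The signature change from a single result to a list is what enables the batched address lookup in results.py below: one get_addressdata() query now serves all results, and the returned rows are fanned back out by result_place_id. A rough pure-Python sketch of that fan-out, with plain dicts standing in for result objects (hypothetical names, not part of the patch):

    from typing import Dict, List, Tuple

    def group_address_rows(results: List[dict], rows: List[Tuple[int, str]]) -> None:
        # Rows arrive ordered by result_place_id; each run of rows is
        # attached to the result that owns that place_id.
        by_id: Dict[int, dict] = {r['place_id']: r for r in results}
        current = None
        for result_place_id, line in rows:
            if current is None or result_place_id != current['place_id']:
                current = by_id[result_place_id]  # must exist, like the assert in the patch
                current['address_rows'] = []      # fresh list per result
            current['address_rows'].append(line)

    results = [{'place_id': 1}, {'place_id': 2}]
    group_address_rows(results, [(1, 'Main St'), (1, 'Springfield'), (2, 'Oak Ave')])
    print(results[0]['address_rows'])  # ['Main St', 'Springfield']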
diff --git a/nominatim/api/results.py b/nominatim/api/results.py
index 98b13380..c661b508 100644
--- a/nominatim/api/results.py
+++ b/nominatim/api/results.py
@@ -11,7 +11,7 @@ Data classes are part of the public API while the functions are for
internal use only. That's why they are implemented as free-standing functions
instead of member functions.
"""
-from typing import Optional, Tuple, Dict, Sequence, TypeVar, Type, List
+from typing import Optional, Tuple, Dict, Sequence, TypeVar, Type, List, Any
import enum
import dataclasses
import datetime as dt
@@ -27,6 +27,24 @@ from nominatim.api.localization import Locales
# This file defines complex result data classes.
# pylint: disable=too-many-instance-attributes
+def _mingle_name_tags(names: Optional[Dict[str, str]]) -> Optional[Dict[str, str]]:
+ """ Mix-in names from linked places, so that they show up
+ as standard names where necessary.
+ """
+ if not names:
+ return None
+
+ out = {}
+ for k, v in names.items():
+ if k.startswith('_place_'):
+ outkey = k[7:]
+ out[k if outkey in names else outkey] = v
+ else:
+ out[k] = v
+
+ return out
+
+
class SourceTable(enum.Enum):
""" Enumeration of kinds of results.
"""
@@ -103,6 +121,9 @@ class BaseResult:
place_id : Optional[int] = None
osm_object: Optional[Tuple[str, int]] = None
+ locale_name: Optional[str] = None
+ display_name: Optional[str] = None
+
names: Optional[Dict[str, str]] = None
address: Optional[Dict[str, str]] = None
extratags: Optional[Dict[str, str]] = None
@@ -146,6 +167,19 @@ class BaseResult:
"""
return self.importance or (0.7500001 - (self.rank_search/40.0))
+
+ def localize(self, locales: Locales) -> None:
+ """ Fill the locale_name and the display_name field for the
+ place and, if available, its address information.
+ """
+ self.locale_name = locales.display_name(self.names)
+ if self.address_rows:
+ self.display_name = ', '.join(self.address_rows.localize(locales))
+ else:
+ self.display_name = self.locale_name
+
+
+
BaseResultT = TypeVar('BaseResultT', bound=BaseResult)
@dataclasses.dataclass
@@ -178,6 +212,15 @@ class SearchResult(BaseResult):
""" A search result for forward geocoding.
"""
bbox: Optional[Bbox] = None
+ accuracy: float = 0.0
+
+
+ @property
+ def ranking(self) -> float:
+ """ Return the ranking, a combined measure of accuracy and importance.
+ """
+ return (self.accuracy if self.accuracy is not None else 1) \
+ - self.calculated_importance()
class SearchResults(List[SearchResult]):
@@ -185,6 +228,12 @@ class SearchResults(List[SearchResult]):
May be empty when no result was found.
"""
+ def localize(self, locales: Locales) -> None:
+ """ Apply the given locales to all results.
+ """
+ for result in self:
+ result.localize(locales)
+
def _filter_geometries(row: SaRow) -> Dict[str, str]:
return {k[9:]: v for k, v in row._mapping.items() # pylint: disable=W0212
@@ -204,7 +253,7 @@ def create_from_placex_row(row: Optional[SaRow],
place_id=row.place_id,
osm_object=(row.osm_type, row.osm_id),
category=(row.class_, row.type),
- names=row.name,
+ names=_mingle_name_tags(row.name),
address=row.address,
extratags=row.extratags,
housenumber=row.housenumber,
@@ -305,24 +354,45 @@ def create_from_postcode_row(row: Optional[SaRow],
geometry=_filter_geometries(row))
-async def add_result_details(conn: SearchConnection, result: BaseResult,
+def create_from_country_row(row: Optional[SaRow],
+ class_type: Type[BaseResultT]) -> Optional[BaseResultT]:
+ """ Construct a new result and add the data from the result row
+ from the fallback country tables. 'class_type' defines
+ the type of result to return. Returns None if the row is None.
+ """
+ if row is None:
+ return None
+
+ return class_type(source_table=SourceTable.COUNTRY,
+ category=('place', 'country'),
+ centroid=Point.from_wkb(row.centroid.data),
+ names=row.name,
+ rank_address=4, rank_search=4,
+ country_code=row.country_code)
+
+
+async def add_result_details(conn: SearchConnection, results: List[BaseResultT],
details: LookupDetails) -> None:
""" Retrieve more details from the database according to the
parameters specified in 'details'.
"""
- log().section('Query details for result')
- if details.address_details:
- log().comment('Query address details')
- await complete_address_details(conn, result)
- if details.linked_places:
- log().comment('Query linked places')
- await complete_linked_places(conn, result)
- if details.parented_places:
- log().comment('Query parent places')
- await complete_parented_places(conn, result)
- if details.keywords:
- log().comment('Query keywords')
- await complete_keywords(conn, result)
+ if results:
+ log().section('Query details for result')
+ if details.address_details:
+ log().comment('Query address details')
+ await complete_address_details(conn, results)
+ if details.linked_places:
+ log().comment('Query linked places')
+ for result in results:
+ await complete_linked_places(conn, result)
+ if details.parented_places:
+ log().comment('Query parent places')
+ for result in results:
+ await complete_parented_places(conn, result)
+ if details.keywords:
+ log().comment('Query keywords')
+ for result in results:
+ await complete_keywords(conn, result)
def _result_row_to_address_row(row: SaRow) -> AddressLine:
@@ -332,10 +402,8 @@ def _result_row_to_address_row(row: SaRow) -> AddressLine:
if hasattr(row, 'place_type') and row.place_type:
extratags['place'] = row.place_type
- names = row.name
+ names = _mingle_name_tags(row.name) or {}
if getattr(row, 'housenumber', None) is not None:
- if names is None:
- names = {}
names['housenumber'] = row.housenumber
return AddressLine(place_id=row.place_id,
@@ -350,35 +418,60 @@ def _result_row_to_address_row(row: SaRow) -> AddressLine:
distance=row.distance)
-async def complete_address_details(conn: SearchConnection, result: BaseResult) -> None:
+async def complete_address_details(conn: SearchConnection, results: List[BaseResultT]) -> None:
""" Retrieve information about places that make up the address of the result.
"""
- housenumber = -1
- if result.source_table in (SourceTable.TIGER, SourceTable.OSMLINE):
- if result.housenumber is not None:
- housenumber = int(result.housenumber)
- elif result.extratags is not None and 'startnumber' in result.extratags:
- # details requests do not come with a specific house number
- housenumber = int(result.extratags['startnumber'])
-
- sfn = sa.func.get_addressdata(result.place_id, housenumber)\
- .table_valued( # type: ignore[no-untyped-call]
- sa.column('place_id', type_=sa.Integer),
- 'osm_type',
- sa.column('osm_id', type_=sa.BigInteger),
- sa.column('name', type_=conn.t.types.Composite),
- 'class', 'type', 'place_type',
- sa.column('admin_level', type_=sa.Integer),
- sa.column('fromarea', type_=sa.Boolean),
- sa.column('isaddress', type_=sa.Boolean),
- sa.column('rank_address', type_=sa.SmallInteger),
- sa.column('distance', type_=sa.Float))
- sql = sa.select(sfn).order_by(sa.column('rank_address').desc(),
- sa.column('isaddress').desc())
-
- result.address_rows = AddressLines()
+ def get_hnr(result: BaseResult) -> Tuple[int, int]:
+ housenumber = -1
+ if result.source_table in (SourceTable.TIGER, SourceTable.OSMLINE):
+ if result.housenumber is not None:
+ housenumber = int(result.housenumber)
+ elif result.extratags is not None and 'startnumber' in result.extratags:
+ # details requests do not come with a specific house number
+ housenumber = int(result.extratags['startnumber'])
+ assert result.place_id
+ return result.place_id, housenumber
+
+ data: List[Tuple[Any, ...]] = [get_hnr(r) for r in results if r.place_id]
+
+ if not data:
+ return
+
+ values = sa.values(sa.column('place_id', type_=sa.Integer),
+ sa.column('housenumber', type_=sa.Integer),
+ name='places',
+ literal_binds=True).data(data)
+
+ sfn = sa.func.get_addressdata(values.c.place_id, values.c.housenumber)\
+ .table_valued( # type: ignore[no-untyped-call]
+ sa.column('place_id', type_=sa.Integer),
+ 'osm_type',
+ sa.column('osm_id', type_=sa.BigInteger),
+ sa.column('name', type_=conn.t.types.Composite),
+ 'class', 'type', 'place_type',
+ sa.column('admin_level', type_=sa.Integer),
+ sa.column('fromarea', type_=sa.Boolean),
+ sa.column('isaddress', type_=sa.Boolean),
+ sa.column('rank_address', type_=sa.SmallInteger),
+ sa.column('distance', type_=sa.Float),
+ joins_implicitly=True)
+
+ sql = sa.select(values.c.place_id.label('result_place_id'), sfn)\
+ .order_by(values.c.place_id,
+ sa.column('rank_address').desc(),
+ sa.column('isaddress').desc())
+
+ current_result = None
for row in await conn.execute(sql):
- result.address_rows.append(_result_row_to_address_row(row))
+ if current_result is None or row.result_place_id != current_result.place_id:
+ for result in results:
+ if result.place_id == row.result_place_id:
+ current_result = result
+ break
+ else:
+ assert False
+ current_result.address_rows = AddressLines()
+ current_result.address_rows.append(_result_row_to_address_row(row))
# pylint: disable=consider-using-f-string
@@ -413,6 +506,9 @@ async def complete_linked_places(conn: SearchConnection, result: BaseResult) ->
async def complete_keywords(conn: SearchConnection, result: BaseResult) -> None:
""" Retrieve information about the search terms used for this place.
+
+ Requires that the query analyzer was initialised to get access to
+ the word table.
"""
t = conn.t.search_name
sql = sa.select(t.c.name_vector, t.c.nameaddress_vector)\
@@ -420,10 +516,11 @@ async def complete_keywords(conn: SearchConnection, result: BaseResult) -> None:
result.name_keywords = []
result.address_keywords = []
- for name_tokens, address_tokens in await conn.execute(sql):
- t = conn.t.word
- sel = sa.select(t.c.word_id, t.c.word_token, t.c.word)
+ t = conn.t.meta.tables['word']
+ sel = sa.select(t.c.word_id, t.c.word_token, t.c.word)
+
+ for name_tokens, address_tokens in await conn.execute(sql):
for row in await conn.execute(sel.where(t.c.word_id == sa.any_(name_tokens))):
result.name_keywords.append(WordInfo(*row))
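The effect of _mingle_name_tags() is easiest to see with concrete input: a '_place_*' name inherited from a linked place is promoted to the unprefixed key, but only when the object has no name of its own under that key. The sketch below repeats the patch's logic verbatim so it can be run stand-alone (the place names are made up for illustration):

    from typing import Dict, Optional

    def _mingle_name_tags(names: Optional[Dict[str, str]]) -> Optional[Dict[str, str]]:
        if not names:
            return None
        out = {}
        for k, v in names.items():
            if k.startswith('_place_'):
                outkey = k[7:]
                out[k if outkey in names else outkey] = v
            else:
                out[k] = v
        return out

    # Linked-place name is promoted because there is no own 'name':
    print(_mingle_name_tags({'_place_name': 'Bergen'}))
    # {'name': 'Bergen'}
    # The object's own name wins; the linked name keeps its prefixed key:
    print(_mingle_name_tags({'name': 'Bergen kommune', '_place_name': 'Bergen'}))
    # {'name': 'Bergen kommune', '_place_name': 'Bergen'}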
diff --git a/nominatim/api/reverse.py b/nominatim/api/reverse.py
index d6976c06..10c97cad 100644
--- a/nominatim/api/reverse.py
+++ b/nominatim/api/reverse.py
@@ -548,6 +548,6 @@ class ReverseGeocoder:
result.distance = row.distance
if hasattr(row, 'bbox'):
result.bbox = Bbox.from_wkb(row.bbox.data)
- await nres.add_result_details(self.conn, result, self.params)
+ await nres.add_result_details(self.conn, [result], self.params)
return result
diff --git a/nominatim/api/search/__init__.py b/nominatim/api/search/__init__.py
new file mode 100644
index 00000000..f60cbe1e
--- /dev/null
+++ b/nominatim/api/search/__init__.py
@@ -0,0 +1,15 @@
+# SPDX-License-Identifier: GPL-3.0-or-later
+#
+# This file is part of Nominatim. (https://nominatim.org)
+#
+# Copyright (C) 2023 by the Nominatim developer community.
+# For a full list of authors see the git log.
+"""
+Module for forward search.
+"""
+# pylint: disable=useless-import-alias
+
+from .geocoder import (ForwardGeocoder as ForwardGeocoder)
+from .query import (Phrase as Phrase,
+ PhraseType as PhraseType)
+from .query_analyzer_factory import (make_query_analyzer as make_query_analyzer)
diff --git a/nominatim/api/search/db_search_builder.py b/nominatim/api/search/db_search_builder.py
new file mode 100644
index 00000000..9ff8c03c
--- /dev/null
+++ b/nominatim/api/search/db_search_builder.py
@@ -0,0 +1,375 @@
+# SPDX-License-Identifier: GPL-3.0-or-later
+#
+# This file is part of Nominatim. (https://nominatim.org)
+#
+# Copyright (C) 2023 by the Nominatim developer community.
+# For a full list of authors see the git log.
+"""
+Conversion from token assignment to an abstract DB search.
+"""
+from typing import Optional, List, Tuple, Iterator
+import heapq
+
+from nominatim.api.types import SearchDetails, DataLayer
+from nominatim.api.search.query import QueryStruct, Token, TokenType, TokenRange, BreakType
+from nominatim.api.search.token_assignment import TokenAssignment
+import nominatim.api.search.db_search_fields as dbf
+import nominatim.api.search.db_searches as dbs
+from nominatim.api.logging import log
+
+
+def wrap_near_search(categories: List[Tuple[str, str]],
+ search: dbs.AbstractSearch) -> dbs.NearSearch:
+ """ Create a new search that wraps the given search in a search
+ for near places of the given category.
+ """
+ return dbs.NearSearch(penalty=search.penalty,
+ categories=dbf.WeightedCategories(categories,
+ [0.0] * len(categories)),
+ search=search)
+
+
+def build_poi_search(category: List[Tuple[str, str]],
+ countries: Optional[List[str]]) -> dbs.PoiSearch:
+ """ Create a new search for places by the given category, possibly
+ constraint to the given countries.
+ """
+ if countries:
+ ccs = dbf.WeightedStrings(countries, [0.0] * len(countries))
+ else:
+ ccs = dbf.WeightedStrings([], [])
+
+ class _PoiData(dbf.SearchData):
+ penalty = 0.0
+ qualifiers = dbf.WeightedCategories(category, [0.0] * len(category))
+ countries=ccs
+
+ return dbs.PoiSearch(_PoiData())
+
+
+class SearchBuilder:
+ """ Build the abstract search queries from token assignments.
+ """
+
+ def __init__(self, query: QueryStruct, details: SearchDetails) -> None:
+ self.query = query
+ self.details = details
+
+
+ @property
+ def configured_for_country(self) -> bool:
+ """ Return true if the search details are configured to
+ allow countries in the result.
+ """
+ return self.details.min_rank <= 4 and self.details.max_rank >= 4 \
+ and self.details.layer_enabled(DataLayer.ADDRESS)
+
+
+ @property
+ def configured_for_postcode(self) -> bool:
+ """ Return true if the search details are configured to
+ allow postcodes in the result.
+ """
+ return self.details.min_rank <= 5 and self.details.max_rank >= 11\
+ and self.details.layer_enabled(DataLayer.ADDRESS)
+
+
+ @property
+ def configured_for_housenumbers(self) -> bool:
+ """ Return true if the search details are configured to
+ allow addresses in the result.
+ """
+ return self.details.max_rank >= 30 \
+ and self.details.layer_enabled(DataLayer.ADDRESS)
+
+
+ def build(self, assignment: TokenAssignment) -> Iterator[dbs.AbstractSearch]:
+ """ Yield all possible abstract searches for the given token assignment.
+ """
+ sdata = self.get_search_data(assignment)
+ if sdata is None:
+ return
+
+ categories = self.get_search_categories(assignment)
+
+ if assignment.name is None:
+ if categories and not sdata.postcodes:
+ sdata.qualifiers = categories
+ categories = None
+ builder = self.build_poi_search(sdata)
+            elif assignment.housenumber:
+                hnr_tokens = self.query.get_tokens(assignment.housenumber,
+                                                   TokenType.HOUSENUMBER)
+                builder = self.build_housenumber_search(sdata, hnr_tokens, assignment.address)
+            else:
+                builder = self.build_special_search(sdata, assignment.address,
+                                                    bool(categories))
+ else:
+ builder = self.build_name_search(sdata, assignment.name, assignment.address,
+ bool(categories))
+
+ if categories:
+ penalty = min(categories.penalties)
+ categories.penalties = [p - penalty for p in categories.penalties]
+ for search in builder:
+ yield dbs.NearSearch(penalty, categories, search)
+ else:
+ yield from builder
+
+
+ def build_poi_search(self, sdata: dbf.SearchData) -> Iterator[dbs.AbstractSearch]:
+ """ Build abstract search query for a simple category search.
+ This kind of search requires an additional geographic constraint.
+ """
+ if not sdata.housenumbers \
+ and ((self.details.viewbox and self.details.bounded_viewbox) or self.details.near):
+ yield dbs.PoiSearch(sdata)
+
+
+ def build_special_search(self, sdata: dbf.SearchData,
+ address: List[TokenRange],
+ is_category: bool) -> Iterator[dbs.AbstractSearch]:
+ """ Build abstract search queries for searches that do not involve
+ a named place.
+ """
+ if sdata.qualifiers:
+ # No special searches over qualifiers supported.
+ return
+
+ if sdata.countries and not address and not sdata.postcodes \
+ and self.configured_for_country:
+ yield dbs.CountrySearch(sdata)
+
+ if sdata.postcodes and (is_category or self.configured_for_postcode):
+ penalty = 0.0 if sdata.countries else 0.1
+ if address:
+ sdata.lookups = [dbf.FieldLookup('nameaddress_vector',
+ [t.token for r in address
+ for t in self.query.get_partials_list(r)],
+ 'restrict')]
+ penalty += 0.2
+ yield dbs.PostcodeSearch(penalty, sdata)
+
+
+ def build_housenumber_search(self, sdata: dbf.SearchData, hnrs: List[Token],
+ address: List[TokenRange]) -> Iterator[dbs.AbstractSearch]:
+ """ Build a simple address search for special entries where the
+ housenumber is the main name token.
+ """
+ partial_tokens: List[int] = []
+ for trange in address:
+ partial_tokens.extend(t.token for t in self.query.get_partials_list(trange))
+
+ sdata.lookups = [dbf.FieldLookup('name_vector', [t.token for t in hnrs], 'lookup_any'),
+ dbf.FieldLookup('nameaddress_vector', partial_tokens, 'lookup_all')
+ ]
+ yield dbs.PlaceSearch(0.05, sdata, sum(t.count for t in hnrs))
+
+
+ def build_name_search(self, sdata: dbf.SearchData,
+ name: TokenRange, address: List[TokenRange],
+ is_category: bool) -> Iterator[dbs.AbstractSearch]:
+ """ Build abstract search queries for simple name or address searches.
+ """
+ if is_category or not sdata.housenumbers or self.configured_for_housenumbers:
+ ranking = self.get_name_ranking(name)
+ name_penalty = ranking.normalize_penalty()
+ if ranking.rankings:
+ sdata.rankings.append(ranking)
+ for penalty, count, lookup in self.yield_lookups(name, address):
+ sdata.lookups = lookup
+ yield dbs.PlaceSearch(penalty + name_penalty, sdata, count)
+
+
+ def yield_lookups(self, name: TokenRange, address: List[TokenRange])\
+ -> Iterator[Tuple[float, int, List[dbf.FieldLookup]]]:
+ """ Yield all variants how the given name and address should best
+ be searched for. This takes into account how frequent the terms
+ are and tries to find a lookup that optimizes index use.
+ """
+ penalty = 0.0 # extra penalty currently unused
+
+ name_partials = self.query.get_partials_list(name)
+ exp_name_count = min(t.count for t in name_partials)
+ addr_partials = []
+ for trange in address:
+ addr_partials.extend(self.query.get_partials_list(trange))
+ addr_tokens = [t.token for t in addr_partials]
+ partials_indexed = all(t.is_indexed for t in name_partials) \
+ and all(t.is_indexed for t in addr_partials)
+
+ if (len(name_partials) > 3 or exp_name_count < 1000) and partials_indexed:
+ # Lookup by name partials, use address partials to restrict results.
+ lookup = [dbf.FieldLookup('name_vector',
+ [t.token for t in name_partials], 'lookup_all')]
+ if addr_tokens:
+ lookup.append(dbf.FieldLookup('nameaddress_vector', addr_tokens, 'restrict'))
+ yield penalty, exp_name_count, lookup
+ return
+
+ exp_addr_count = min(t.count for t in addr_partials) if addr_partials else exp_name_count
+ if exp_addr_count < 1000 and partials_indexed:
+ # Lookup by address partials and restrict results through name terms.
+ yield penalty, exp_addr_count,\
+ [dbf.FieldLookup('name_vector', [t.token for t in name_partials], 'restrict'),
+ dbf.FieldLookup('nameaddress_vector', addr_tokens, 'lookup_all')]
+ return
+
+        # Partial term too frequent. Try looking up by rare full names first.
+ name_fulls = self.query.get_tokens(name, TokenType.WORD)
+ rare_names = list(filter(lambda t: t.count < 1000, name_fulls))
+ # At this point drop unindexed partials from the address.
+ # This might yield wrong results, nothing we can do about that.
+ if not partials_indexed:
+ addr_tokens = [t.token for t in addr_partials if t.is_indexed]
+ log().var_dump('before', penalty)
+ penalty += 1.2 * sum(t.penalty for t in addr_partials if not t.is_indexed)
+ log().var_dump('after', penalty)
+ if rare_names:
+ # Any of the full names applies with all of the partials from the address
+ lookup = [dbf.FieldLookup('name_vector', [t.token for t in rare_names], 'lookup_any')]
+ if addr_tokens:
+ lookup.append(dbf.FieldLookup('nameaddress_vector', addr_tokens, 'restrict'))
+ yield penalty, sum(t.count for t in rare_names), lookup
+
+ # To catch remaining results, lookup by name and address
+ if all(t.is_indexed for t in name_partials):
+ lookup = [dbf.FieldLookup('name_vector',
+ [t.token for t in name_partials], 'lookup_all')]
+ else:
+ # we don't have the partials, try with the non-rare names
+ non_rare_names = [t.token for t in name_fulls if t.count >= 1000]
+ if not non_rare_names:
+ return
+ lookup = [dbf.FieldLookup('name_vector', non_rare_names, 'lookup_any')]
+ if addr_tokens:
+ lookup.append(dbf.FieldLookup('nameaddress_vector', addr_tokens, 'lookup_all'))
+ yield penalty + 0.1 * max(0, 5 - len(name_partials) - len(addr_tokens)),\
+ min(exp_name_count, exp_addr_count), lookup
+
+
+ def get_name_ranking(self, trange: TokenRange) -> dbf.FieldRanking:
+ """ Create a ranking expression for a name term in the given range.
+ """
+ name_fulls = self.query.get_tokens(trange, TokenType.WORD)
+ ranks = [dbf.RankedTokens(t.penalty, [t.token]) for t in name_fulls]
+ ranks.sort(key=lambda r: r.penalty)
+ # Fallback, sum of penalty for partials
+ name_partials = self.query.get_partials_list(trange)
+ default = sum(t.penalty for t in name_partials) + 0.2
+ return dbf.FieldRanking('name_vector', default, ranks)
+
+
+ def get_addr_ranking(self, trange: TokenRange) -> dbf.FieldRanking:
+ """ Create a list of ranking expressions for an address term
+ for the given ranges.
+ """
+ todo: List[Tuple[int, int, dbf.RankedTokens]] = []
+ heapq.heappush(todo, (0, trange.start, dbf.RankedTokens(0.0, [])))
+ ranks: List[dbf.RankedTokens] = []
+
+ while todo: # pylint: disable=too-many-nested-blocks
+ neglen, pos, rank = heapq.heappop(todo)
+ for tlist in self.query.nodes[pos].starting:
+ if tlist.ttype in (TokenType.PARTIAL, TokenType.WORD):
+ if tlist.end < trange.end:
+ chgpenalty = PENALTY_WORDCHANGE[self.query.nodes[tlist.end].btype]
+ if tlist.ttype == TokenType.PARTIAL:
+ penalty = rank.penalty + chgpenalty \
+ + max(t.penalty for t in tlist.tokens)
+ heapq.heappush(todo, (neglen - 1, tlist.end,
+ dbf.RankedTokens(penalty, rank.tokens)))
+ else:
+ for t in tlist.tokens:
+ heapq.heappush(todo, (neglen - 1, tlist.end,
+ rank.with_token(t, chgpenalty)))
+ elif tlist.end == trange.end:
+ if tlist.ttype == TokenType.PARTIAL:
+ ranks.append(dbf.RankedTokens(rank.penalty
+ + max(t.penalty for t in tlist.tokens),
+ rank.tokens))
+ else:
+ ranks.extend(rank.with_token(t, 0.0) for t in tlist.tokens)
+ if len(ranks) >= 10:
+ # Too many variants, bail out and only add
+ # Worst-case Fallback: sum of penalty of partials
+ name_partials = self.query.get_partials_list(trange)
+ default = sum(t.penalty for t in name_partials) + 0.2
+ ranks.append(dbf.RankedTokens(rank.penalty + default, []))
+ # Bail out of outer loop
+ todo.clear()
+ break
+
+ ranks.sort(key=lambda r: len(r.tokens))
+ default = ranks[0].penalty + 0.3
+ del ranks[0]
+ ranks.sort(key=lambda r: r.penalty)
+
+ return dbf.FieldRanking('nameaddress_vector', default, ranks)
+
+
+ def get_search_data(self, assignment: TokenAssignment) -> Optional[dbf.SearchData]:
+ """ Collect the tokens for the non-name search fields in the
+ assignment.
+ """
+ sdata = dbf.SearchData()
+ sdata.penalty = assignment.penalty
+ if assignment.country:
+ tokens = self.query.get_tokens(assignment.country, TokenType.COUNTRY)
+ if self.details.countries:
+ tokens = [t for t in tokens if t.lookup_word in self.details.countries]
+ if not tokens:
+ return None
+ sdata.set_strings('countries', tokens)
+ elif self.details.countries:
+ sdata.countries = dbf.WeightedStrings(self.details.countries,
+ [0.0] * len(self.details.countries))
+ if assignment.housenumber:
+ sdata.set_strings('housenumbers',
+ self.query.get_tokens(assignment.housenumber,
+ TokenType.HOUSENUMBER))
+ if assignment.postcode:
+ sdata.set_strings('postcodes',
+ self.query.get_tokens(assignment.postcode,
+ TokenType.POSTCODE))
+ if assignment.qualifier:
+ sdata.set_qualifiers(self.query.get_tokens(assignment.qualifier,
+ TokenType.QUALIFIER))
+
+ if assignment.address:
+ sdata.set_ranking([self.get_addr_ranking(r) for r in assignment.address])
+ else:
+ sdata.rankings = []
+
+ return sdata
+
+
+ def get_search_categories(self,
+ assignment: TokenAssignment) -> Optional[dbf.WeightedCategories]:
+ """ Collect tokens for category search or use the categories
+ requested per parameter.
+ Returns None if no category search is requested.
+ """
+ if assignment.category:
+ tokens = [t for t in self.query.get_tokens(assignment.category,
+ TokenType.CATEGORY)
+ if not self.details.categories
+ or t.get_category() in self.details.categories]
+ return dbf.WeightedCategories([t.get_category() for t in tokens],
+ [t.penalty for t in tokens])
+
+ if self.details.categories:
+ return dbf.WeightedCategories(self.details.categories,
+ [0.0] * len(self.details.categories))
+
+ return None
+
+
+PENALTY_WORDCHANGE = {
+ BreakType.START: 0.0,
+ BreakType.END: 0.0,
+ BreakType.PHRASE: 0.0,
+ BreakType.WORD: 0.1,
+ BreakType.PART: 0.2,
+ BreakType.TOKEN: 0.4
+}
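The heart of yield_lookups() is a cost decision: search via whichever token vector is expected to hit the fewest rows, and only fall back to rare full names when all partial terms are too frequent. A simplified stand-alone distillation of that decision (hypothetical helper; it ignores the len(name_partials) > 3 shortcut and the penalty bookkeeping):

    def pick_lookup_strategy(exp_name_count: int, exp_addr_count: int,
                             partials_indexed: bool) -> str:
        # Thresholds mirror the patch: 1000 expected rows is the cut-off
        # between "cheap enough to scan" and "too frequent".
        if exp_name_count < 1000 and partials_indexed:
            return 'lookup_all on name_vector, restrict by nameaddress_vector'
        if exp_addr_count < 1000 and partials_indexed:
            return 'lookup_all on nameaddress_vector, restrict by name_vector'
        return 'lookup_any on rare full names, then name+address as catch-all'

    print(pick_lookup_strategy(50, 200000, True))
    # lookup_all on name_vector, restrict by nameaddress_vector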
diff --git a/nominatim/api/search/db_search_fields.py b/nominatim/api/search/db_search_fields.py
new file mode 100644
index 00000000..325e08df
--- /dev/null
+++ b/nominatim/api/search/db_search_fields.py
@@ -0,0 +1,212 @@
+# SPDX-License-Identifier: GPL-3.0-or-later
+#
+# This file is part of Nominatim. (https://nominatim.org)
+#
+# Copyright (C) 2023 by the Nominatim developer community.
+# For a full list of authors see the git log.
+"""
+Data structures for more complex fields in abstract search descriptions.
+"""
+from typing import List, Tuple, Iterator, cast
+import dataclasses
+
+import sqlalchemy as sa
+from sqlalchemy.dialects.postgresql import ARRAY
+
+from nominatim.typing import SaFromClause, SaColumn, SaExpression
+from nominatim.api.search.query import Token
+
+@dataclasses.dataclass
+class WeightedStrings:
+ """ A list of strings together with a penalty.
+ """
+ values: List[str]
+ penalties: List[float]
+
+ def __bool__(self) -> bool:
+ return bool(self.values)
+
+
+ def __iter__(self) -> Iterator[Tuple[str, float]]:
+ return iter(zip(self.values, self.penalties))
+
+
+ def get_penalty(self, value: str, default: float = 1000.0) -> float:
+ """ Get the penalty for the given value. Returns the given default
+ if the value does not exist.
+ """
+ try:
+ return self.penalties[self.values.index(value)]
+ except ValueError:
+ pass
+ return default
+
+
+@dataclasses.dataclass
+class WeightedCategories:
+ """ A list of class/type tuples together with a penalty.
+ """
+ values: List[Tuple[str, str]]
+ penalties: List[float]
+
+ def __bool__(self) -> bool:
+ return bool(self.values)
+
+
+ def __iter__(self) -> Iterator[Tuple[Tuple[str, str], float]]:
+ return iter(zip(self.values, self.penalties))
+
+
+ def get_penalty(self, value: Tuple[str, str], default: float = 1000.0) -> float:
+ """ Get the penalty for the given value. Returns the given default
+ if the value does not exist.
+ """
+ try:
+ return self.penalties[self.values.index(value)]
+ except ValueError:
+ pass
+ return default
+
+
+ def sql_restrict(self, table: SaFromClause) -> SaExpression:
+ """ Return an SQLAlcheny expression that restricts the
+ class and type columns of the given table to the values
+ in the list.
+ Must not be used with an empty list.
+ """
+ assert self.values
+ if len(self.values) == 1:
+ return sa.and_(table.c.class_ == self.values[0][0],
+ table.c.type == self.values[0][1])
+
+ return sa.or_(*(sa.and_(table.c.class_ == c, table.c.type == t)
+ for c, t in self.values))
+
+
+@dataclasses.dataclass(order=True)
+class RankedTokens:
+ """ List of tokens together with the penalty of using it.
+ """
+ penalty: float
+ tokens: List[int]
+
+ def with_token(self, t: Token, transition_penalty: float) -> 'RankedTokens':
+ """ Create a new RankedTokens list with the given token appended.
+            The token's penalty as well as the given transition penalty
+ are added to the overall penalty.
+ """
+ return RankedTokens(self.penalty + t.penalty + transition_penalty,
+ self.tokens + [t.token])
+
+
+@dataclasses.dataclass
+class FieldRanking:
+ """ A list of rankings to be applied sequentially until one matches.
+ The matched ranking determines the penalty. If none matches a
+ default penalty is applied.
+ """
+ column: str
+ default: float
+ rankings: List[RankedTokens]
+
+ def normalize_penalty(self) -> float:
+ """ Reduce the default and ranking penalties, such that the minimum
+ penalty is 0. Return the penalty that was subtracted.
+ """
+ if self.rankings:
+ min_penalty = min(self.default, min(r.penalty for r in self.rankings))
+ else:
+ min_penalty = self.default
+ if min_penalty > 0.0:
+ self.default -= min_penalty
+ for ranking in self.rankings:
+ ranking.penalty -= min_penalty
+ return min_penalty
+
+
+ def sql_penalty(self, table: SaFromClause) -> SaColumn:
+ """ Create an SQL expression for the rankings.
+ """
+ assert self.rankings
+
+ col = table.c[self.column]
+
+ return sa.case(*((col.contains(r.tokens),r.penalty) for r in self.rankings),
+ else_=self.default)
+
+
+@dataclasses.dataclass
+class FieldLookup:
+ """ A list of tokens to be searched for. The column names the database
+ column to search in and the lookup_type the operator that is applied.
+ 'lookup_all' requires all tokens to match. 'lookup_any' requires
+        one of the tokens to match. 'restrict' requires all tokens to match
+        but avoids the use of indexes.
+ """
+ column: str
+ tokens: List[int]
+ lookup_type: str
+
+ def sql_condition(self, table: SaFromClause) -> SaColumn:
+ """ Create an SQL expression for the given match condition.
+ """
+ col = table.c[self.column]
+ if self.lookup_type == 'lookup_all':
+ return col.contains(self.tokens)
+ if self.lookup_type == 'lookup_any':
+ return cast(SaColumn, col.overlap(self.tokens))
+
+ return sa.func.array_cat(col, sa.text('ARRAY[]::integer[]'),
+ type_=ARRAY(sa.Integer())).contains(self.tokens)
+
+
+class SearchData:
+ """ Search fields derived from query and token assignment
+ to be used with the SQL queries.
+ """
+ penalty: float
+
+ lookups: List[FieldLookup] = []
+ rankings: List[FieldRanking]
+
+ housenumbers: WeightedStrings = WeightedStrings([], [])
+ postcodes: WeightedStrings = WeightedStrings([], [])
+ countries: WeightedStrings = WeightedStrings([], [])
+
+ qualifiers: WeightedCategories = WeightedCategories([], [])
+
+
+ def set_strings(self, field: str, tokens: List[Token]) -> None:
+ """ Set on of the WeightedStrings properties from the given
+ token list. Adapt the global penalty, so that the
+ minimum penalty is 0.
+ """
+ if tokens:
+ min_penalty = min(t.penalty for t in tokens)
+ self.penalty += min_penalty
+ wstrs = WeightedStrings([t.lookup_word for t in tokens],
+ [t.penalty - min_penalty for t in tokens])
+
+ setattr(self, field, wstrs)
+
+
+ def set_qualifiers(self, tokens: List[Token]) -> None:
+ """ Set the qulaifier field from the given tokens.
+ """
+ if tokens:
+ min_penalty = min(t.penalty for t in tokens)
+ self.penalty += min_penalty
+ self.qualifiers = WeightedCategories([t.get_category() for t in tokens],
+ [t.penalty - min_penalty for t in tokens])
+
+
+ def set_ranking(self, rankings: List[FieldRanking]) -> None:
+ """ Set the list of rankings and normalize the ranking.
+ """
+ self.rankings = []
+ for ranking in rankings:
+ if ranking.rankings:
+ self.penalty += ranking.normalize_penalty()
+ self.rankings.append(ranking)
+ else:
+ self.penalty += ranking.default
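normalize_penalty() does a small but easily-missed rebasing: the cheapest ranking option is shifted to penalty 0 and the subtracted amount migrates into the search's global penalty, so the relative order of the rankings is preserved while the base cost stays comparable across searches. A stand-alone rendition of the same arithmetic (free function instead of a method, hypothetical names):

    import dataclasses
    from typing import List, Tuple

    @dataclasses.dataclass
    class RankedTokens:
        penalty: float
        tokens: List[int]

    def normalize_penalty(default: float,
                          rankings: List[RankedTokens]) -> Tuple[float, float]:
        # Returns (shift, new_default); the shift is what the caller
        # adds to SearchData.penalty.
        if rankings:
            min_penalty = min(default, min(r.penalty for r in rankings))
        else:
            min_penalty = default
        if min_penalty > 0.0:
            default -= min_penalty
            for ranking in rankings:
                ranking.penalty -= min_penalty
        return min_penalty, default

    ranks = [RankedTokens(0.25, [1]), RankedTokens(0.75, [2])]
    print(normalize_penalty(1.0, ranks))   # (0.25, 0.75)
    print([r.penalty for r in ranks])      # [0.0, 0.5]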
diff --git a/nominatim/api/search/db_searches.py b/nominatim/api/search/db_searches.py
new file mode 100644
index 00000000..76ff368f
--- /dev/null
+++ b/nominatim/api/search/db_searches.py
@@ -0,0 +1,709 @@
+# SPDX-License-Identifier: GPL-3.0-or-later
+#
+# This file is part of Nominatim. (https://nominatim.org)
+#
+# Copyright (C) 2023 by the Nominatim developer community.
+# For a full list of authors see the git log.
+"""
+Implementation of the actual database accesses for forward search.
+"""
+from typing import List, Tuple, AsyncIterator
+import abc
+
+import sqlalchemy as sa
+from sqlalchemy.dialects.postgresql import ARRAY, array_agg
+
+from nominatim.typing import SaFromClause, SaScalarSelect, SaColumn, \
+ SaExpression, SaSelect, SaRow
+from nominatim.api.connection import SearchConnection
+from nominatim.api.types import SearchDetails, DataLayer, GeometryFormat, Bbox
+import nominatim.api.results as nres
+from nominatim.api.search.db_search_fields import SearchData, WeightedCategories
+
+#pylint: disable=singleton-comparison,not-callable
+#pylint: disable=too-many-branches,too-many-arguments,too-many-locals,too-many-statements
+
+def _select_placex(t: SaFromClause) -> SaSelect:
+ return sa.select(t.c.place_id, t.c.osm_type, t.c.osm_id, t.c.name,
+ t.c.class_, t.c.type,
+ t.c.address, t.c.extratags,
+ t.c.housenumber, t.c.postcode, t.c.country_code,
+ t.c.importance, t.c.wikipedia,
+ t.c.parent_place_id, t.c.rank_address, t.c.rank_search,
+ t.c.centroid,
+ t.c.geometry.ST_Expand(0).label('bbox'))
+
+
+def _add_geometry_columns(sql: SaSelect, col: SaColumn, details: SearchDetails) -> SaSelect:
+ if not details.geometry_output:
+ return sql
+
+ out = []
+
+ if details.geometry_simplification > 0.0:
+ col = col.ST_SimplifyPreserveTopology(details.geometry_simplification)
+
+ if details.geometry_output & GeometryFormat.GEOJSON:
+ out.append(col.ST_AsGeoJSON().label('geometry_geojson'))
+ if details.geometry_output & GeometryFormat.TEXT:
+ out.append(col.ST_AsText().label('geometry_text'))
+ if details.geometry_output & GeometryFormat.KML:
+ out.append(col.ST_AsKML().label('geometry_kml'))
+ if details.geometry_output & GeometryFormat.SVG:
+ out.append(col.ST_AsSVG().label('geometry_svg'))
+
+ return sql.add_columns(*out)
+
+
+def _make_interpolation_subquery(table: SaFromClause, inner: SaFromClause,
+ numerals: List[int], details: SearchDetails) -> SaScalarSelect:
+ all_ids = array_agg(table.c.place_id) # type: ignore[no-untyped-call]
+ sql = sa.select(all_ids).where(table.c.parent_place_id == inner.c.place_id)
+
+ if len(numerals) == 1:
+ sql = sql.where(sa.between(numerals[0], table.c.startnumber, table.c.endnumber))\
+ .where((numerals[0] - table.c.startnumber) % table.c.step == 0)
+ else:
+ sql = sql.where(sa.or_(
+ *(sa.and_(sa.between(n, table.c.startnumber, table.c.endnumber),
+ (n - table.c.startnumber) % table.c.step == 0)
+ for n in numerals)))
+
+ if details.excluded:
+ sql = sql.where(table.c.place_id.not_in(details.excluded))
+
+ return sql.scalar_subquery()
+
+
+def _filter_by_layer(table: SaFromClause, layers: DataLayer) -> SaColumn:
+ orexpr: List[SaExpression] = []
+ if layers & DataLayer.ADDRESS and layers & DataLayer.POI:
+ orexpr.append(table.c.rank_address.between(1, 30))
+ elif layers & DataLayer.ADDRESS:
+ orexpr.append(table.c.rank_address.between(1, 29))
+ orexpr.append(sa.and_(table.c.rank_address == 30,
+ sa.or_(table.c.housenumber != None,
+ table.c.address.has_key('housename'))))
+ elif layers & DataLayer.POI:
+ orexpr.append(sa.and_(table.c.rank_address == 30,
+ table.c.class_.not_in(('place', 'building'))))
+
+ if layers & DataLayer.MANMADE:
+ exclude = []
+ if not layers & DataLayer.RAILWAY:
+ exclude.append('railway')
+ if not layers & DataLayer.NATURAL:
+ exclude.extend(('natural', 'water', 'waterway'))
+ orexpr.append(sa.and_(table.c.class_.not_in(tuple(exclude)),
+ table.c.rank_address == 0))
+ else:
+ include = []
+ if layers & DataLayer.RAILWAY:
+ include.append('railway')
+ if layers & DataLayer.NATURAL:
+ include.extend(('natural', 'water', 'waterway'))
+ orexpr.append(sa.and_(table.c.class_.in_(tuple(include)),
+ table.c.rank_address == 0))
+
+ if len(orexpr) == 1:
+ return orexpr[0]
+
+ return sa.or_(*orexpr)
+
+
+def _interpolated_position(table: SaFromClause, nr: SaColumn) -> SaColumn:
+ pos = sa.cast(nr - table.c.startnumber, sa.Float) / (table.c.endnumber - table.c.startnumber)
+ return sa.case(
+ (table.c.endnumber == table.c.startnumber, table.c.linegeo.ST_Centroid()),
+ else_=table.c.linegeo.ST_LineInterpolatePoint(pos)).label('centroid')
+
+
+async def _get_placex_housenumbers(conn: SearchConnection,
+ place_ids: List[int],
+ details: SearchDetails) -> AsyncIterator[nres.SearchResult]:
+ t = conn.t.placex
+ sql = _select_placex(t).where(t.c.place_id.in_(place_ids))
+
+ sql = _add_geometry_columns(sql, t.c.geometry, details)
+
+ for row in await conn.execute(sql):
+ result = nres.create_from_placex_row(row, nres.SearchResult)
+ assert result
+ result.bbox = Bbox.from_wkb(row.bbox.data)
+ yield result
+
+
+async def _get_osmline(conn: SearchConnection, place_ids: List[int],
+ numerals: List[int],
+ details: SearchDetails) -> AsyncIterator[nres.SearchResult]:
+ t = conn.t.osmline
+ values = sa.values(sa.Column('nr', sa.Integer()), name='housenumber')\
+ .data([(n,) for n in numerals])
+ sql = sa.select(t.c.place_id, t.c.osm_id,
+ t.c.parent_place_id, t.c.address,
+ values.c.nr.label('housenumber'),
+ _interpolated_position(t, values.c.nr),
+ t.c.postcode, t.c.country_code)\
+ .where(t.c.place_id.in_(place_ids))\
+ .join(values, values.c.nr.between(t.c.startnumber, t.c.endnumber))
+
+ if details.geometry_output:
+ sub = sql.subquery()
+ sql = _add_geometry_columns(sa.select(sub), sub.c.centroid, details)
+
+ for row in await conn.execute(sql):
+ result = nres.create_from_osmline_row(row, nres.SearchResult)
+ assert result
+ yield result
+
+
+async def _get_tiger(conn: SearchConnection, place_ids: List[int],
+ numerals: List[int], osm_id: int,
+ details: SearchDetails) -> AsyncIterator[nres.SearchResult]:
+ t = conn.t.tiger
+ values = sa.values(sa.Column('nr', sa.Integer()), name='housenumber')\
+ .data([(n,) for n in numerals])
+ sql = sa.select(t.c.place_id, t.c.parent_place_id,
+ sa.literal('W').label('osm_type'),
+ sa.literal(osm_id).label('osm_id'),
+ values.c.nr.label('housenumber'),
+ _interpolated_position(t, values.c.nr),
+ t.c.postcode)\
+ .where(t.c.place_id.in_(place_ids))\
+ .join(values, values.c.nr.between(t.c.startnumber, t.c.endnumber))
+
+ if details.geometry_output:
+ sub = sql.subquery()
+ sql = _add_geometry_columns(sa.select(sub), sub.c.centroid, details)
+
+ for row in await conn.execute(sql):
+ result = nres.create_from_tiger_row(row, nres.SearchResult)
+ assert result
+ yield result
+
+
+class AbstractSearch(abc.ABC):
+ """ Encapuslation of a single lookup in the database.
+ """
+
+ def __init__(self, penalty: float) -> None:
+ self.penalty = penalty
+
+ @abc.abstractmethod
+ async def lookup(self, conn: SearchConnection,
+ details: SearchDetails) -> nres.SearchResults:
+ """ Find results for the search in the database.
+ """
+
+
+class NearSearch(AbstractSearch):
+ """ Category search of a place type near the result of another search.
+ """
+ def __init__(self, penalty: float, categories: WeightedCategories,
+ search: AbstractSearch) -> None:
+ super().__init__(penalty)
+ self.search = search
+ self.categories = categories
+
+
+ async def lookup(self, conn: SearchConnection,
+ details: SearchDetails) -> nres.SearchResults:
+ """ Find results for the search in the database.
+ """
+ results = nres.SearchResults()
+ base = await self.search.lookup(conn, details)
+
+ if not base:
+ return results
+
+ base.sort(key=lambda r: (r.accuracy, r.rank_search))
+ max_accuracy = base[0].accuracy + 0.5
+ base = nres.SearchResults(r for r in base if r.source_table == nres.SourceTable.PLACEX
+ and r.accuracy <= max_accuracy
+ and r.bbox and r.bbox.area < 20)
+
+ if base:
+ baseids = [b.place_id for b in base[:5] if b.place_id]
+
+ for category, penalty in self.categories:
+ await self.lookup_category(results, conn, baseids, category, penalty, details)
+ if len(results) >= details.max_results:
+ break
+
+ return results
+
+
+ async def lookup_category(self, results: nres.SearchResults,
+ conn: SearchConnection, ids: List[int],
+ category: Tuple[str, str], penalty: float,
+ details: SearchDetails) -> None:
+ """ Find places of the given category near the list of
+ place ids and add the results to 'results'.
+ """
+ table = await conn.get_class_table(*category)
+
+ t = conn.t.placex.alias('p')
+ tgeom = conn.t.placex.alias('pgeom')
+
+ sql = _select_placex(t).where(tgeom.c.place_id.in_(ids))\
+ .where(t.c.class_ == category[0])\
+ .where(t.c.type == category[1])
+
+ if table is None:
+ # No classtype table available, do a simplified lookup in placex.
+ sql = sql.join(tgeom, t.c.geometry.ST_DWithin(tgeom.c.centroid, 0.01))\
+ .order_by(tgeom.c.centroid.ST_Distance(t.c.centroid))
+ else:
+ # Use classtype table. We can afford to use a larger
+ # radius for the lookup.
+ sql = sql.join(table, t.c.place_id == table.c.place_id)\
+ .join(tgeom,
+ sa.case((sa.and_(tgeom.c.rank_address < 9,
+ tgeom.c.geometry.ST_GeometryType().in_(
+ ('ST_Polygon', 'ST_MultiPolygon'))),
+ tgeom.c.geometry.ST_Contains(table.c.centroid)),
+ else_ = tgeom.c.centroid.ST_DWithin(table.c.centroid, 0.05)))\
+ .order_by(tgeom.c.centroid.ST_Distance(table.c.centroid))
+
+ if details.countries:
+ sql = sql.where(t.c.country_code.in_(details.countries))
+ if details.min_rank > 0:
+ sql = sql.where(t.c.rank_address >= details.min_rank)
+ if details.max_rank < 30:
+ sql = sql.where(t.c.rank_address <= details.max_rank)
+ if details.excluded:
+ sql = sql.where(t.c.place_id.not_in(details.excluded))
+ if details.layers is not None:
+ sql = sql.where(_filter_by_layer(t, details.layers))
+
+ for row in await conn.execute(sql.limit(details.max_results)):
+ result = nres.create_from_placex_row(row, nres.SearchResult)
+ assert result
+ result.accuracy = self.penalty + penalty
+ result.bbox = Bbox.from_wkb(row.bbox.data)
+ results.append(result)
+
+
+
+class PoiSearch(AbstractSearch):
+ """ Category search in a geographic area.
+ """
+ def __init__(self, sdata: SearchData) -> None:
+ super().__init__(sdata.penalty)
+ self.categories = sdata.qualifiers
+ self.countries = sdata.countries
+
+
+ async def lookup(self, conn: SearchConnection,
+ details: SearchDetails) -> nres.SearchResults:
+ """ Find results for the search in the database.
+ """
+ t = conn.t.placex
+
+ rows: List[SaRow] = []
+
+ if details.near and details.near_radius is not None and details.near_radius < 0.2:
+ # simply search in placex table
+ sql = _select_placex(t) \
+ .where(t.c.linked_place_id == None) \
+ .where(t.c.geometry.ST_DWithin(details.near.sql_value(),
+ details.near_radius)) \
+ .order_by(t.c.centroid.ST_Distance(details.near.sql_value()))
+
+ if self.countries:
+ sql = sql.where(t.c.country_code.in_(self.countries.values))
+
+ if details.viewbox is not None and details.bounded_viewbox:
+ sql = sql.where(t.c.geometry.intersects(details.viewbox.sql_value()))
+
+ classtype = self.categories.values
+ if len(classtype) == 1:
+ sql = sql.where(t.c.class_ == classtype[0][0]) \
+ .where(t.c.type == classtype[0][1])
+ else:
+ sql = sql.where(sa.or_(*(sa.and_(t.c.class_ == cls, t.c.type == typ)
+ for cls, typ in classtype)))
+
+ rows.extend(await conn.execute(sql.limit(details.max_results)))
+ else:
+ # use the class type tables
+ for category in self.categories.values:
+ table = await conn.get_class_table(*category)
+ if table is not None:
+ sql = _select_placex(t)\
+ .join(table, t.c.place_id == table.c.place_id)\
+ .where(t.c.class_ == category[0])\
+ .where(t.c.type == category[1])
+
+ if details.viewbox is not None and details.bounded_viewbox:
+ sql = sql.where(table.c.centroid.intersects(details.viewbox.sql_value()))
+
+ if details.near:
+ sql = sql.order_by(table.c.centroid.ST_Distance(details.near.sql_value()))\
+ .where(table.c.centroid.ST_DWithin(details.near.sql_value(),
+ details.near_radius or 0.5))
+
+ if self.countries:
+ sql = sql.where(t.c.country_code.in_(self.countries.values))
+
+ rows.extend(await conn.execute(sql.limit(details.max_results)))
+
+ results = nres.SearchResults()
+ for row in rows:
+ result = nres.create_from_placex_row(row, nres.SearchResult)
+ assert result
+ result.accuracy = self.penalty + self.categories.get_penalty((row.class_, row.type))
+ result.bbox = Bbox.from_wkb(row.bbox.data)
+ results.append(result)
+
+ return results
+
+
+class CountrySearch(AbstractSearch):
+ """ Search for a country name or country code.
+ """
+ def __init__(self, sdata: SearchData) -> None:
+ super().__init__(sdata.penalty)
+ self.countries = sdata.countries
+
+
+ async def lookup(self, conn: SearchConnection,
+ details: SearchDetails) -> nres.SearchResults:
+ """ Find results for the search in the database.
+ """
+ t = conn.t.placex
+
+ sql = _select_placex(t)\
+ .where(t.c.country_code.in_(self.countries.values))\
+ .where(t.c.rank_address == 4)
+
+ sql = _add_geometry_columns(sql, t.c.geometry, details)
+
+ if details.excluded:
+ sql = sql.where(t.c.place_id.not_in(details.excluded))
+
+ if details.viewbox is not None and details.bounded_viewbox:
+ sql = sql.where(t.c.geometry.intersects(details.viewbox.sql_value()))
+
+ if details.near is not None and details.near_radius is not None:
+ sql = sql.where(t.c.geometry.ST_DWithin(details.near.sql_value(),
+ details.near_radius))
+
+ results = nres.SearchResults()
+ for row in await conn.execute(sql):
+ result = nres.create_from_placex_row(row, nres.SearchResult)
+ assert result
+ result.accuracy = self.penalty + self.countries.get_penalty(row.country_code, 5.0)
+ results.append(result)
+
+ return results or await self.lookup_in_country_table(conn, details)
+
+
+ async def lookup_in_country_table(self, conn: SearchConnection,
+ details: SearchDetails) -> nres.SearchResults:
+ """ Look up the country in the fallback country tables.
+ """
+        # Avoid the fallback search when this is a 'more' search (i.e. a
+        # follow-up query with excluded place ids). Country results usually
+        # are in the first batch of results and it is not possible to
+        # exclude these fallbacks.
+ if details.excluded:
+ return nres.SearchResults()
+
+ t = conn.t.country_name
+ tgrid = conn.t.country_grid
+
+ sql = sa.select(tgrid.c.country_code,
+ tgrid.c.geometry.ST_Centroid().ST_Collect().ST_Centroid()
+ .label('centroid'))\
+ .where(tgrid.c.country_code.in_(self.countries.values))\
+ .group_by(tgrid.c.country_code)
+
+ if details.viewbox is not None and details.bounded_viewbox:
+ sql = sql.where(tgrid.c.geometry.intersects(details.viewbox.sql_value()))
+ if details.near is not None and details.near_radius is not None:
+ sql = sql.where(tgrid.c.geometry.ST_DWithin(details.near.sql_value(),
+ details.near_radius))
+
+ sub = sql.subquery('grid')
+
+ sql = sa.select(t.c.country_code,
+ (t.c.name
+ + sa.func.coalesce(t.c.derived_name,
+ sa.cast('', type_=conn.t.types.Composite))
+ ).label('name'),
+ sub.c.centroid)\
+ .join(sub, t.c.country_code == sub.c.country_code)
+
+ results = nres.SearchResults()
+ for row in await conn.execute(sql):
+ result = nres.create_from_country_row(row, nres.SearchResult)
+ assert result
+ result.accuracy = self.penalty + self.countries.get_penalty(row.country_code, 5.0)
+ results.append(result)
+
+ return results
+
+
+
+class PostcodeSearch(AbstractSearch):
+ """ Search for a postcode.
+ """
+ def __init__(self, extra_penalty: float, sdata: SearchData) -> None:
+ super().__init__(sdata.penalty + extra_penalty)
+ self.countries = sdata.countries
+ self.postcodes = sdata.postcodes
+ self.lookups = sdata.lookups
+ self.rankings = sdata.rankings
+
+
+ async def lookup(self, conn: SearchConnection,
+ details: SearchDetails) -> nres.SearchResults:
+ """ Find results for the search in the database.
+ """
+ t = conn.t.postcode
+
+ sql = sa.select(t.c.place_id, t.c.parent_place_id,
+ t.c.rank_search, t.c.rank_address,
+ t.c.postcode, t.c.country_code,
+ t.c.geometry.label('centroid'))\
+ .where(t.c.postcode.in_(self.postcodes.values))
+
+ sql = _add_geometry_columns(sql, t.c.geometry, details)
+
+ penalty: SaExpression = sa.literal(self.penalty)
+
+ if details.viewbox is not None:
+ if details.bounded_viewbox:
+ sql = sql.where(t.c.geometry.intersects(details.viewbox.sql_value()))
+ else:
+ penalty += sa.case((t.c.geometry.intersects(details.viewbox.sql_value()), 0.0),
+ (t.c.geometry.intersects(details.viewbox_x2.sql_value()), 1.0),
+ else_=2.0)
+
+ if details.near is not None:
+ if details.near_radius is not None:
+ sql = sql.where(t.c.geometry.ST_DWithin(details.near.sql_value(),
+ details.near_radius))
+ sql = sql.order_by(t.c.geometry.ST_Distance(details.near.sql_value()))
+
+ if self.countries:
+ sql = sql.where(t.c.country_code.in_(self.countries.values))
+
+ if details.excluded:
+ sql = sql.where(t.c.place_id.not_in(details.excluded))
+
+ if self.lookups:
+ assert len(self.lookups) == 1
+ assert self.lookups[0].lookup_type == 'restrict'
+ tsearch = conn.t.search_name
+ sql = sql.where(tsearch.c.place_id == t.c.parent_place_id)\
+ .where(sa.func.array_cat(tsearch.c.name_vector,
+ tsearch.c.nameaddress_vector,
+ type_=ARRAY(sa.Integer))
+ .contains(self.lookups[0].tokens))
+
+ for ranking in self.rankings:
+ penalty += ranking.sql_penalty(conn.t.search_name)
+ penalty += sa.case(*((t.c.postcode == v, p) for v, p in self.postcodes),
+ else_=1.0)
+
+
+ sql = sql.add_columns(penalty.label('accuracy'))
+ sql = sql.order_by('accuracy')
+
+ results = nres.SearchResults()
+ for row in await conn.execute(sql.limit(details.max_results)):
+ result = nres.create_from_postcode_row(row, nres.SearchResult)
+ assert result
+ result.accuracy = row.accuracy
+ results.append(result)
+
+ return results
+
+
+
+class PlaceSearch(AbstractSearch):
+ """ Generic search for an address or named place.
+ """
+ def __init__(self, extra_penalty: float, sdata: SearchData, expected_count: int) -> None:
+ super().__init__(sdata.penalty + extra_penalty)
+ self.countries = sdata.countries
+ self.postcodes = sdata.postcodes
+ self.housenumbers = sdata.housenumbers
+ self.qualifiers = sdata.qualifiers
+ self.lookups = sdata.lookups
+ self.rankings = sdata.rankings
+ self.expected_count = expected_count
+
+
+ async def lookup(self, conn: SearchConnection,
+ details: SearchDetails) -> nres.SearchResults:
+ """ Find results for the search in the database.
+ """
+ t = conn.t.placex.alias('p')
+ tsearch = conn.t.search_name.alias('s')
+ limit = details.max_results
+
+ sql = sa.select(t.c.place_id, t.c.osm_type, t.c.osm_id, t.c.name,
+ t.c.class_, t.c.type,
+ t.c.address, t.c.extratags,
+ t.c.housenumber, t.c.postcode, t.c.country_code,
+ t.c.wikipedia,
+ t.c.parent_place_id, t.c.rank_address, t.c.rank_search,
+ t.c.centroid,
+ t.c.geometry.ST_Expand(0).label('bbox'))\
+ .where(t.c.place_id == tsearch.c.place_id)
+
+
+ sql = _add_geometry_columns(sql, t.c.geometry, details)
+
+ penalty: SaExpression = sa.literal(self.penalty)
+ for ranking in self.rankings:
+ penalty += ranking.sql_penalty(tsearch)
+
+ for lookup in self.lookups:
+ sql = sql.where(lookup.sql_condition(tsearch))
+
+ if self.countries:
+ sql = sql.where(tsearch.c.country_code.in_(self.countries.values))
+
+ if self.postcodes:
+ # if a postcode is given, don't search for state or country level objects
+ sql = sql.where(tsearch.c.address_rank > 9)
+ tpc = conn.t.postcode
+ if self.expected_count > 1000:
+ # Many results expected. Restrict by postcode.
+ sql = sql.where(sa.select(tpc.c.postcode)
+ .where(tpc.c.postcode.in_(self.postcodes.values))
+ .where(tsearch.c.centroid.ST_DWithin(tpc.c.geometry, 0.12))
+ .exists())
+
+            # Fewer results, only have a preference for close postcodes
+ pc_near = sa.select(sa.func.min(tpc.c.geometry.ST_Distance(tsearch.c.centroid)))\
+ .where(tpc.c.postcode.in_(self.postcodes.values))\
+ .scalar_subquery()
+ penalty += sa.case((t.c.postcode.in_(self.postcodes.values), 0.0),
+ else_=sa.func.coalesce(pc_near, 2.0))
+
+ if details.viewbox is not None:
+ if details.bounded_viewbox:
+ sql = sql.where(tsearch.c.centroid.intersects(details.viewbox.sql_value()))
+ else:
+ penalty += sa.case((t.c.geometry.intersects(details.viewbox.sql_value()), 0.0),
+ (t.c.geometry.intersects(details.viewbox_x2.sql_value()), 1.0),
+ else_=2.0)
+
+ if details.near is not None:
+ if details.near_radius is not None:
+ sql = sql.where(tsearch.c.centroid.ST_DWithin(details.near.sql_value(),
+ details.near_radius))
+ sql = sql.add_columns(-tsearch.c.centroid.ST_Distance(details.near.sql_value())
+ .label('importance'))
+ sql = sql.order_by(sa.desc(sa.text('importance')))
+ else:
+ sql = sql.order_by(penalty - sa.case((tsearch.c.importance > 0, tsearch.c.importance),
+ else_=0.75001-(sa.cast(tsearch.c.search_rank, sa.Float())/40)))
+ sql = sql.add_columns(t.c.importance)
+
+
+ sql = sql.add_columns(penalty.label('accuracy'))\
+ .order_by(sa.text('accuracy'))
+
+ if self.housenumbers:
+ hnr_regexp = f"\\m({'|'.join(self.housenumbers.values)})\\M"
+ sql = sql.where(tsearch.c.address_rank.between(16, 30))\
+ .where(sa.or_(tsearch.c.address_rank < 30,
+ t.c.housenumber.regexp_match(hnr_regexp, flags='i')))
+
+            # Cross-check for housenumbers. This needs to be done on a rather
+            # large set; worst case there are 40,000 main streets in OSM.
+ inner = sql.limit(10000).subquery()
+
+ # Housenumbers from placex
+ thnr = conn.t.placex.alias('hnr')
+ pid_list = array_agg(thnr.c.place_id) # type: ignore[no-untyped-call]
+ place_sql = sa.select(pid_list)\
+ .where(thnr.c.parent_place_id == inner.c.place_id)\
+ .where(thnr.c.housenumber.regexp_match(hnr_regexp, flags='i'))\
+ .where(thnr.c.linked_place_id == None)\
+ .where(thnr.c.indexed_status == 0)
+
+ if details.excluded:
+ place_sql = place_sql.where(thnr.c.place_id.not_in(details.excluded))
+ if self.qualifiers:
+ place_sql = place_sql.where(self.qualifiers.sql_restrict(thnr))
+
+ numerals = [int(n) for n in self.housenumbers.values if n.isdigit()]
+ interpol_sql: SaExpression
+ tiger_sql: SaExpression
+ if numerals and \
+ (not self.qualifiers or ('place', 'house') in self.qualifiers.values):
+ # Housenumbers from interpolations
+ interpol_sql = _make_interpolation_subquery(conn.t.osmline, inner,
+ numerals, details)
+ # Housenumbers from Tiger
+ tiger_sql = sa.case((inner.c.country_code == 'us',
+ _make_interpolation_subquery(conn.t.tiger, inner,
+ numerals, details)
+ ), else_=None)
+ else:
+ interpol_sql = sa.literal_column('NULL')
+ tiger_sql = sa.literal_column('NULL')
+
+ unsort = sa.select(inner, place_sql.scalar_subquery().label('placex_hnr'),
+ interpol_sql.label('interpol_hnr'),
+ tiger_sql.label('tiger_hnr')).subquery('unsort')
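+            # Prefer housenumbers from placex, then interpolations, then Tiger.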
+ sql = sa.select(unsort)\
+ .order_by(sa.case((unsort.c.placex_hnr != None, 1),
+ (unsort.c.interpol_hnr != None, 2),
+ (unsort.c.tiger_hnr != None, 3),
+ else_=4),
+ unsort.c.accuracy)
+ else:
+ sql = sql.where(t.c.linked_place_id == None)\
+ .where(t.c.indexed_status == 0)
+ if self.qualifiers:
+ sql = sql.where(self.qualifiers.sql_restrict(t))
+ if details.excluded:
+ sql = sql.where(tsearch.c.place_id.not_in(details.excluded))
+ if details.min_rank > 0:
+ sql = sql.where(sa.or_(tsearch.c.address_rank >= details.min_rank,
+ tsearch.c.search_rank >= details.min_rank))
+ if details.max_rank < 30:
+ sql = sql.where(sa.or_(tsearch.c.address_rank <= details.max_rank,
+ tsearch.c.search_rank <= details.max_rank))
+ if details.layers is not None:
+ sql = sql.where(_filter_by_layer(t, details.layers))
+
+
+ results = nres.SearchResults()
+ for row in await conn.execute(sql.limit(limit)):
+ result = nres.create_from_placex_row(row, nres.SearchResult)
+ assert result
+ result.bbox = Bbox.from_wkb(row.bbox.data)
+ result.accuracy = row.accuracy
+            if not details.excluded or result.place_id not in details.excluded:
+ results.append(result)
+
+ if self.housenumbers and row.rank_address < 30:
+ if row.placex_hnr:
+ subs = _get_placex_housenumbers(conn, row.placex_hnr, details)
+ elif row.interpol_hnr:
+ subs = _get_osmline(conn, row.interpol_hnr, numerals, details)
+ elif row.tiger_hnr:
+ subs = _get_tiger(conn, row.tiger_hnr, numerals, row.osm_id, details)
+ else:
+ subs = None
+
+ if subs is not None:
+ async for sub in subs:
+ assert sub.housenumber
+ sub.accuracy = result.accuracy
+ if not any(nr in self.housenumbers.values
+ for nr in sub.housenumber.split(';')):
+ sub.accuracy += 0.6
+ results.append(sub)
+
+ result.accuracy += 1.0 # penalty for missing housenumber
+
+ return results
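
As an illustrative sketch (not part of the patch): a PlaceSearch is driven
with an open SearchConnection `conn` and a SearchData instance `sdata`
prepared by the search builder; the SearchDetails constructor call here is
an assumption for brevity.

    search = PlaceSearch(0.0, sdata, expected_count=50)
    results = await search.lookup(conn, SearchDetails(max_results=10))
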
diff --git a/nominatim/api/search/geocoder.py b/nominatim/api/search/geocoder.py
new file mode 100644
index 00000000..0ef649d9
--- /dev/null
+++ b/nominatim/api/search/geocoder.py
@@ -0,0 +1,191 @@
+# SPDX-License-Identifier: GPL-3.0-or-later
+#
+# This file is part of Nominatim. (https://nominatim.org)
+#
+# Copyright (C) 2023 by the Nominatim developer community.
+# For a full list of authors see the git log.
+"""
+Public interface to the search code.
+"""
+from typing import List, Any, Optional, Iterator, Tuple
+import itertools
+
+from nominatim.api.connection import SearchConnection
+from nominatim.api.types import SearchDetails
+from nominatim.api.results import SearchResults, add_result_details
+from nominatim.api.search.token_assignment import yield_token_assignments
+from nominatim.api.search.db_search_builder import SearchBuilder, build_poi_search, wrap_near_search
+from nominatim.api.search.db_searches import AbstractSearch
+from nominatim.api.search.query_analyzer_factory import make_query_analyzer, AbstractQueryAnalyzer
+from nominatim.api.search.query import Phrase, QueryStruct
+from nominatim.api.logging import log
+
+class ForwardGeocoder:
+ """ Main class responsible for place search.
+ """
+
+ def __init__(self, conn: SearchConnection, params: SearchDetails) -> None:
+ self.conn = conn
+ self.params = params
+ self.query_analyzer: Optional[AbstractQueryAnalyzer] = None
+
+
+ @property
+ def limit(self) -> int:
+ """ Return the configured maximum number of search results.
+ """
+ return self.params.max_results
+
+
+ async def build_searches(self,
+ phrases: List[Phrase]) -> Tuple[QueryStruct, List[AbstractSearch]]:
+ """ Analyse the query and return the tokenized query and list of
+ possible searches over it.
+ """
+ if self.query_analyzer is None:
+ self.query_analyzer = await make_query_analyzer(self.conn)
+
+ query = await self.query_analyzer.analyze_query(phrases)
+
+ searches: List[AbstractSearch] = []
+ if query.num_token_slots() > 0:
+ # 2. Compute all possible search interpretations
+ log().section('Compute abstract searches')
+ search_builder = SearchBuilder(query, self.params)
+ num_searches = 0
+ for assignment in yield_token_assignments(query):
+ searches.extend(search_builder.build(assignment))
+ log().table_dump('Searches for assignment',
+ _dump_searches(searches, query, num_searches))
+ num_searches = len(searches)
+ searches.sort(key=lambda s: s.penalty)
+
+ return query, searches
+
+
+ async def execute_searches(self, query: QueryStruct,
+ searches: List[AbstractSearch]) -> SearchResults:
+ """ Run the abstract searches against the database until a result
+ is found.
+ """
+ log().section('Execute database searches')
+ results = SearchResults()
+
+ num_results = 0
+ min_ranking = 1000.0
+ prev_penalty = 0.0
+ for i, search in enumerate(searches):
+ if search.penalty > prev_penalty and (search.penalty > min_ranking or i > 20):
+ break
+ log().table_dump(f"{i + 1}. Search", _dump_searches([search], query))
+ for result in await search.lookup(self.conn, self.params):
+ results.append(result)
+ min_ranking = min(min_ranking, result.ranking + 0.5, search.penalty + 0.3)
+ log().result_dump('Results', ((r.accuracy, r) for r in results[num_results:]))
+ num_results = len(results)
+ prev_penalty = search.penalty
+
+ if results:
+ min_ranking = min(r.ranking for r in results)
+ results = SearchResults(r for r in results if r.ranking < min_ranking + 0.5)
+
+ if results:
+ min_rank = min(r.rank_search for r in results)
+
+ results = SearchResults(r for r in results
+ if r.ranking + 0.05 * (r.rank_search - min_rank)
+ < min_ranking + 0.5)
+
+ results.sort(key=lambda r: r.accuracy - r.calculated_importance())
+ results = SearchResults(results[:self.limit])
+
+ return results
+
+
+ async def lookup_pois(self, categories: List[Tuple[str, str]],
+ phrases: List[Phrase]) -> SearchResults:
+ """ Look up places by category. If phrase is given, a place search
+ over the phrase will be executed first and places close to the
+ results returned.
+ """
+ log().function('forward_lookup_pois', categories=categories, params=self.params)
+
+ if phrases:
+ query, searches = await self.build_searches(phrases)
+
+ if query:
+ searches = [wrap_near_search(categories, s) for s in searches[:50]]
+ results = await self.execute_searches(query, searches)
+ else:
+ results = SearchResults()
+ else:
+ search = build_poi_search(categories, self.params.countries)
+ results = await search.lookup(self.conn, self.params)
+
+ await add_result_details(self.conn, results, self.params)
+ log().result_dump('Final Results', ((r.accuracy, r) for r in results))
+
+ return results
+
+
+ async def lookup(self, phrases: List[Phrase]) -> SearchResults:
+ """ Look up a single free-text query.
+ """
+ log().function('forward_lookup', phrases=phrases, params=self.params)
+ results = SearchResults()
+
+ if self.params.is_impossible():
+ return results
+
+ query, searches = await self.build_searches(phrases)
+
+ if searches:
+ # Execute SQL until an appropriate result is found.
+ results = await self.execute_searches(query, searches[:50])
+ await add_result_details(self.conn, results, self.params)
+ log().result_dump('Final Results', ((r.accuracy, r) for r in results))
+
+ return results
+
+
+# pylint: disable=invalid-name,too-many-locals
+def _dump_searches(searches: List[AbstractSearch], query: QueryStruct,
+ start: int = 0) -> Iterator[Optional[List[Any]]]:
+ yield ['Penalty', 'Lookups', 'Housenr', 'Postcode', 'Countries', 'Qualifier', 'Rankings']
+
+ def tk(tl: List[int]) -> str:
+ tstr = [f"{query.find_lookup_word_by_id(t)}({t})" for t in tl]
+
+ return f"[{','.join(tstr)}]"
+
+ def fmt_ranking(f: Any) -> str:
+ if not f:
+ return ''
+ ranks = ','.join((f"{tk(r.tokens)}^{r.penalty:.3g}" for r in f.rankings))
+ if len(ranks) > 100:
+ ranks = ranks[:100] + '...'
+ return f"{f.column}({ranks},def={f.default:.3g})"
+
+ def fmt_lookup(l: Any) -> str:
+ if not l:
+ return ''
+
+ return f"{l.lookup_type}({l.column}{tk(l.tokens)})"
+
+
+ def fmt_cstr(c: Any) -> str:
+ if not c:
+ return ''
+
+ return f'{c[0]}^{c[1]}'
+
+ for search in searches[start:]:
+ fields = ('lookups', 'rankings', 'countries', 'housenumbers',
+ 'postcodes', 'qualifier')
+ iters = itertools.zip_longest([f"{search.penalty:.3g}"],
+ *(getattr(search, attr, []) for attr in fields),
+                                      fillvalue='')
+ for penalty, lookup, rank, cc, hnr, pc, qual in iters:
+ yield [penalty, fmt_lookup(lookup), fmt_cstr(hnr),
+ fmt_cstr(pc), fmt_cstr(cc), fmt_cstr(qual), fmt_ranking(rank)]
+ yield None
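
A rough usage sketch (assumes an open SearchConnection `conn`; Phrase and
PhraseType come from nominatim.api.search.query, and the SearchDetails
construction is illustrative):

    geocoder = ForwardGeocoder(conn, SearchDetails(max_results=10))
    phrases = [Phrase(PhraseType.NONE, '21 pilkington avenue, birmingham')]
    results = await geocoder.lookup(phrases)
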
diff --git a/nominatim/api/search/icu_tokenizer.py b/nominatim/api/search/icu_tokenizer.py
new file mode 100644
index 00000000..17e67905
--- /dev/null
+++ b/nominatim/api/search/icu_tokenizer.py
@@ -0,0 +1,291 @@
+# SPDX-License-Identifier: GPL-3.0-or-later
+#
+# This file is part of Nominatim. (https://nominatim.org)
+#
+# Copyright (C) 2023 by the Nominatim developer community.
+# For a full list of authors see the git log.
+"""
+Implementation of query analysis for the ICU tokenizer.
+"""
+from typing import Tuple, Dict, List, Optional, NamedTuple, Iterator, Any, cast
+from copy import copy
+from collections import defaultdict
+import dataclasses
+import difflib
+
+from icu import Transliterator
+
+import sqlalchemy as sa
+
+from nominatim.typing import SaRow
+from nominatim.api.connection import SearchConnection
+from nominatim.api.logging import log
+from nominatim.api.search import query as qmod
+from nominatim.api.search.query_analyzer_factory import AbstractQueryAnalyzer
+
+
+DB_TO_TOKEN_TYPE = {
+ 'W': qmod.TokenType.WORD,
+ 'w': qmod.TokenType.PARTIAL,
+ 'H': qmod.TokenType.HOUSENUMBER,
+ 'P': qmod.TokenType.POSTCODE,
+ 'C': qmod.TokenType.COUNTRY
+}
+
+
+class QueryPart(NamedTuple):
+ """ Normalized and transliterated form of a single term in the query.
+ When the term came out of a split during the transliteration,
+ the normalized string is the full word before transliteration.
+ The word number keeps track of the word before transliteration
+        and can be used to identify partially transliterated terms.
+ """
+ token: str
+ normalized: str
+ word_number: int
+
+
+QueryParts = List[QueryPart]
+WordDict = Dict[str, List[qmod.TokenRange]]
+
+def yield_words(terms: List[QueryPart], start: int) -> Iterator[Tuple[str, qmod.TokenRange]]:
+ """ Return all combinations of words in the terms list after the
+ given position.
+ """
+ total = len(terms)
+ for first in range(start, total):
+ word = terms[first].token
+ yield word, qmod.TokenRange(first, first + 1)
+ for last in range(first + 1, min(first + 20, total)):
+ word = ' '.join((word, terms[last].token))
+ yield word, qmod.TokenRange(first, last + 1)
+
+
+@dataclasses.dataclass
+class ICUToken(qmod.Token):
+ """ Specialised token for ICU tokenizer.
+ """
+ word_token: str
+ info: Optional[Dict[str, Any]]
+
+ def get_category(self) -> Tuple[str, str]:
+ assert self.info
+ return self.info.get('class', ''), self.info.get('type', '')
+
+
+ def rematch(self, norm: str) -> None:
+ """ Check how well the token matches the given normalized string
+ and add a penalty, if necessary.
+ """
+ if not self.lookup_word:
+ return
+
+ seq = difflib.SequenceMatcher(a=self.lookup_word, b=norm)
+ distance = 0
+ for tag, afrom, ato, bfrom, bto in seq.get_opcodes():
+ if tag == 'delete' and (afrom == 0 or ato == len(self.lookup_word)):
+ distance += 1
+ elif tag == 'replace':
+ distance += max((ato-afrom), (bto-bfrom))
+ elif tag != 'equal':
+ distance += abs((ato-afrom) - (bto-bfrom))
+ self.penalty += (distance/len(self.lookup_word))
+
+
+ @staticmethod
+ def from_db_row(row: SaRow) -> 'ICUToken':
+ """ Create a ICUToken from the row of the word table.
+ """
+ count = 1 if row.info is None else row.info.get('count', 1)
+
+ penalty = 0.0
+ if row.type == 'w':
+ penalty = 0.3
+ elif row.type == 'H':
+ penalty = sum(0.1 for c in row.word_token if c != ' ' and not c.isdigit())
+ if all(not c.isdigit() for c in row.word_token):
+ penalty += 0.2 * (len(row.word_token) - 1)
+
+ if row.info is None:
+ lookup_word = row.word
+ else:
+ lookup_word = row.info.get('lookup', row.word)
+ if lookup_word:
+ lookup_word = lookup_word.split('@', 1)[0]
+ else:
+ lookup_word = row.word_token
+
+ return ICUToken(penalty=penalty, token=row.word_id, count=count,
+ lookup_word=lookup_word, is_indexed=True,
+ word_token=row.word_token, info=row.info)
+
+
+
+class ICUQueryAnalyzer(AbstractQueryAnalyzer):
+ """ Converter for query strings into a tokenized query
+ using the tokens created by a ICU tokenizer.
+ """
+
+ def __init__(self, conn: SearchConnection) -> None:
+ self.conn = conn
+
+
+ async def setup(self) -> None:
+ """ Set up static data structures needed for the analysis.
+ """
+ rules = await self.conn.get_property('tokenizer_import_normalisation')
+ self.normalizer = Transliterator.createFromRules("normalization", rules)
+ rules = await self.conn.get_property('tokenizer_import_transliteration')
+ self.transliterator = Transliterator.createFromRules("transliteration", rules)
+
+ if 'word' not in self.conn.t.meta.tables:
+ sa.Table('word', self.conn.t.meta,
+ sa.Column('word_id', sa.Integer),
+ sa.Column('word_token', sa.Text, nullable=False),
+ sa.Column('type', sa.Text, nullable=False),
+ sa.Column('word', sa.Text),
+ sa.Column('info', self.conn.t.types.Json))
+
+
+ async def analyze_query(self, phrases: List[qmod.Phrase]) -> qmod.QueryStruct:
+ """ Analyze the given list of phrases and return the
+ tokenized query.
+ """
+ log().section('Analyze query (using ICU tokenizer)')
+ normalized = list(filter(lambda p: p.text,
+ (qmod.Phrase(p.ptype, self.normalizer.transliterate(p.text))
+ for p in phrases)))
+ query = qmod.QueryStruct(normalized)
+ log().var_dump('Normalized query', query.source)
+ if not query.source:
+ return query
+
+ parts, words = self.split_query(query)
+ log().var_dump('Transliterated query', lambda: _dump_transliterated(query, parts))
+
+ for row in await self.lookup_in_db(list(words.keys())):
+ for trange in words[row.word_token]:
+ token = ICUToken.from_db_row(row)
+ if row.type == 'S':
+ if row.info['op'] in ('in', 'near'):
+ if trange.start == 0:
+ query.add_token(trange, qmod.TokenType.CATEGORY, token)
+ else:
+ query.add_token(trange, qmod.TokenType.QUALIFIER, token)
+ if trange.start == 0 or trange.end == query.num_token_slots():
+ token = copy(token)
+ token.penalty += 0.1 * (query.num_token_slots())
+ query.add_token(trange, qmod.TokenType.CATEGORY, token)
+ else:
+ query.add_token(trange, DB_TO_TOKEN_TYPE[row.type], token)
+
+ self.add_extra_tokens(query, parts)
+ self.rerank_tokens(query, parts)
+
+ log().table_dump('Word tokens', _dump_word_tokens(query))
+
+ return query
+
+
+ def split_query(self, query: qmod.QueryStruct) -> Tuple[QueryParts, WordDict]:
+ """ Transliterate the phrases and split them into tokens.
+
+ Returns the list of transliterated tokens together with their
+ normalized form and a dictionary of words for lookup together
+ with their position.
+ """
+ parts: QueryParts = []
+ phrase_start = 0
+ words = defaultdict(list)
+ wordnr = 0
+ for phrase in query.source:
+ query.nodes[-1].ptype = phrase.ptype
+ for word in phrase.text.split(' '):
+ trans = self.transliterator.transliterate(word)
+ if trans:
+ for term in trans.split(' '):
+ if term:
+ parts.append(QueryPart(term, word, wordnr))
+ query.add_node(qmod.BreakType.TOKEN, phrase.ptype)
+ query.nodes[-1].btype = qmod.BreakType.WORD
+ wordnr += 1
+ query.nodes[-1].btype = qmod.BreakType.PHRASE
+
+ for word, wrange in yield_words(parts, phrase_start):
+ words[word].append(wrange)
+
+ phrase_start = len(parts)
+ query.nodes[-1].btype = qmod.BreakType.END
+
+ return parts, words
+
+
+ async def lookup_in_db(self, words: List[str]) -> 'sa.Result[Any]':
+ """ Return the token information from the database for the
+ given word tokens.
+ """
+ t = self.conn.t.meta.tables['word']
+ return await self.conn.execute(t.select().where(t.c.word_token.in_(words)))
+
+
+ def add_extra_tokens(self, query: qmod.QueryStruct, parts: QueryParts) -> None:
+ """ Add tokens to query that are not saved in the database.
+ """
+ for part, node, i in zip(parts, query.nodes, range(1000)):
+            if len(part.token) <= 4 and part.token.isdigit()\
+ and not node.has_tokens(i+1, qmod.TokenType.HOUSENUMBER):
+ query.add_token(qmod.TokenRange(i, i+1), qmod.TokenType.HOUSENUMBER,
+ ICUToken(0.5, 0, 1, part.token, True, part.token, None))
+
+
+ def rerank_tokens(self, query: qmod.QueryStruct, parts: QueryParts) -> None:
+ """ Add penalties to tokens that depend on presence of other token.
+ """
+ for i, node, tlist in query.iter_token_lists():
+ if tlist.ttype == qmod.TokenType.POSTCODE:
+ for repl in node.starting:
+ if repl.end == tlist.end and repl.ttype != qmod.TokenType.POSTCODE \
+ and (repl.ttype != qmod.TokenType.HOUSENUMBER
+ or len(tlist.tokens[0].lookup_word) > 4):
+ repl.add_penalty(0.39)
+ elif tlist.ttype == qmod.TokenType.HOUSENUMBER:
+ if any(c.isdigit() for c in tlist.tokens[0].lookup_word):
+ for repl in node.starting:
+ if repl.end == tlist.end and repl.ttype != qmod.TokenType.HOUSENUMBER \
+ and (repl.ttype != qmod.TokenType.HOUSENUMBER
+ or len(tlist.tokens[0].lookup_word) <= 3):
+ repl.add_penalty(0.5 - tlist.tokens[0].penalty)
+ elif tlist.ttype not in (qmod.TokenType.COUNTRY, qmod.TokenType.PARTIAL):
+ norm = parts[i].normalized
+ for j in range(i + 1, tlist.end):
+ if parts[j - 1].word_number != parts[j].word_number:
+ norm += ' ' + parts[j].normalized
+ for token in tlist.tokens:
+ cast(ICUToken, token).rematch(norm)
+
+
+def _dump_transliterated(query: qmod.QueryStruct, parts: QueryParts) -> str:
+ out = query.nodes[0].btype.value
+ for node, part in zip(query.nodes[1:], parts):
+ out += part.token + node.btype.value
+ return out
+
+
+def _dump_word_tokens(query: qmod.QueryStruct) -> Iterator[List[Any]]:
+ yield ['type', 'token', 'word_token', 'lookup_word', 'penalty', 'count', 'info']
+ for node in query.nodes:
+ for tlist in node.starting:
+ for token in tlist.tokens:
+ t = cast(ICUToken, token)
+ yield [tlist.ttype.name, t.token, t.word_token or '',
+ t.lookup_word or '', t.penalty, t.count, t.info]
+
+
+async def create_query_analyzer(conn: SearchConnection) -> AbstractQueryAnalyzer:
+ """ Create and set up a new query analyzer for a database based
+ on the ICU tokenizer.
+ """
+ out = ICUQueryAnalyzer(conn)
+ await out.setup()
+
+ return out
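
For illustration (a sketch; assumes an open SearchConnection `conn` on a
database imported with the ICU tokenizer):

    analyzer = await create_query_analyzer(conn)
    query = await analyzer.analyze_query(
        [qmod.Phrase(qmod.PhraseType.NONE, 'hauptstr 34, berlin')])
    # query.nodes now carries the tokens found in the word table
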
diff --git a/nominatim/api/search/legacy_tokenizer.py b/nominatim/api/search/legacy_tokenizer.py
new file mode 100644
index 00000000..96975704
--- /dev/null
+++ b/nominatim/api/search/legacy_tokenizer.py
@@ -0,0 +1,263 @@
+# SPDX-License-Identifier: GPL-3.0-or-later
+#
+# This file is part of Nominatim. (https://nominatim.org)
+#
+# Copyright (C) 2023 by the Nominatim developer community.
+# For a full list of authors see the git log.
+"""
+Implementation of query analysis for the legacy tokenizer.
+"""
+from typing import Tuple, Dict, List, Optional, Iterator, Any, cast
+from copy import copy
+from collections import defaultdict
+import dataclasses
+
+import sqlalchemy as sa
+
+from nominatim.typing import SaRow
+from nominatim.api.connection import SearchConnection
+from nominatim.api.logging import log
+from nominatim.api.search import query as qmod
+from nominatim.api.search.query_analyzer_factory import AbstractQueryAnalyzer
+
+def yield_words(terms: List[str], start: int) -> Iterator[Tuple[str, qmod.TokenRange]]:
+ """ Return all combinations of words in the terms list after the
+ given position.
+ """
+ total = len(terms)
+ for first in range(start, total):
+ word = terms[first]
+ yield word, qmod.TokenRange(first, first + 1)
+ for last in range(first + 1, min(first + 20, total)):
+ word = ' '.join((word, terms[last]))
+ yield word, qmod.TokenRange(first, last + 1)
+
+
+@dataclasses.dataclass
+class LegacyToken(qmod.Token):
+ """ Specialised token for legacy tokenizer.
+ """
+ word_token: str
+ category: Optional[Tuple[str, str]]
+ country: Optional[str]
+ operator: Optional[str]
+
+ @property
+ def info(self) -> Dict[str, Any]:
+ """ Dictionary of additional propoerties of the token.
+ Should only be used for debugging purposes.
+ """
+ return {'category': self.category,
+ 'country': self.country,
+ 'operator': self.operator}
+
+
+ def get_category(self) -> Tuple[str, str]:
+ assert self.category
+ return self.category
+
+
+class LegacyQueryAnalyzer(AbstractQueryAnalyzer):
+ """ Converter for query strings into a tokenized query
+ using the tokens created by a legacy tokenizer.
+ """
+
+ def __init__(self, conn: SearchConnection) -> None:
+ self.conn = conn
+
+ async def setup(self) -> None:
+ """ Set up static data structures needed for the analysis.
+ """
+ self.max_word_freq = int(await self.conn.get_property('tokenizer_maxwordfreq'))
+ if 'word' not in self.conn.t.meta.tables:
+ sa.Table('word', self.conn.t.meta,
+ sa.Column('word_id', sa.Integer),
+ sa.Column('word_token', sa.Text, nullable=False),
+ sa.Column('word', sa.Text),
+ sa.Column('class', sa.Text),
+ sa.Column('type', sa.Text),
+ sa.Column('country_code', sa.Text),
+ sa.Column('search_name_count', sa.Integer),
+ sa.Column('operator', sa.Text))
+
+
+ async def analyze_query(self, phrases: List[qmod.Phrase]) -> qmod.QueryStruct:
+ """ Analyze the given list of phrases and return the
+ tokenized query.
+ """
+ log().section('Analyze query (using Legacy tokenizer)')
+
+ normalized = []
+ if phrases:
+ for row in await self.conn.execute(sa.select(*(sa.func.make_standard_name(p.text)
+ for p in phrases))):
+ normalized = [qmod.Phrase(p.ptype, r) for r, p in zip(row, phrases) if r]
+ break
+
+ query = qmod.QueryStruct(normalized)
+ log().var_dump('Normalized query', query.source)
+ if not query.source:
+ return query
+
+ parts, words = self.split_query(query)
+ lookup_words = list(words.keys())
+ log().var_dump('Split query', parts)
+ log().var_dump('Extracted words', lookup_words)
+
+ for row in await self.lookup_in_db(lookup_words):
+ for trange in words[row.word_token.strip()]:
+ token, ttype = self.make_token(row)
+ if ttype == qmod.TokenType.CATEGORY:
+ if trange.start == 0:
+ query.add_token(trange, qmod.TokenType.CATEGORY, token)
+ elif ttype == qmod.TokenType.QUALIFIER:
+ query.add_token(trange, qmod.TokenType.QUALIFIER, token)
+ if trange.start == 0 or trange.end == query.num_token_slots():
+ token = copy(token)
+ token.penalty += 0.1 * (query.num_token_slots())
+ query.add_token(trange, qmod.TokenType.CATEGORY, token)
+ elif ttype != qmod.TokenType.PARTIAL or trange.start + 1 == trange.end:
+ query.add_token(trange, ttype, token)
+
+ self.add_extra_tokens(query, parts)
+ self.rerank_tokens(query)
+
+ log().table_dump('Word tokens', _dump_word_tokens(query))
+
+ return query
+
+
+ def split_query(self, query: qmod.QueryStruct) -> Tuple[List[str],
+ Dict[str, List[qmod.TokenRange]]]:
+ """ Transliterate the phrases and split them into tokens.
+
+ Returns a list of transliterated tokens and a dictionary
+ of words for lookup together with their position.
+ """
+ parts: List[str] = []
+ phrase_start = 0
+ words = defaultdict(list)
+ for phrase in query.source:
+ query.nodes[-1].ptype = phrase.ptype
+ for trans in phrase.text.split(' '):
+ if trans:
+ for term in trans.split(' '):
+ if term:
+ parts.append(trans)
+ query.add_node(qmod.BreakType.TOKEN, phrase.ptype)
+ query.nodes[-1].btype = qmod.BreakType.WORD
+ query.nodes[-1].btype = qmod.BreakType.PHRASE
+ for word, wrange in yield_words(parts, phrase_start):
+ words[word].append(wrange)
+ phrase_start = len(parts)
+ query.nodes[-1].btype = qmod.BreakType.END
+
+ return parts, words
+
+
+ async def lookup_in_db(self, words: List[str]) -> 'sa.Result[Any]':
+ """ Return the token information from the database for the
+ given word tokens.
+ """
+ t = self.conn.t.meta.tables['word']
+
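+        # Legacy word tokens for full words carry a leading blank, so each
+        # term is looked up in both spellings.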
+ sql = t.select().where(t.c.word_token.in_(words + [' ' + w for w in words]))
+
+ return await self.conn.execute(sql)
+
+
+ def make_token(self, row: SaRow) -> Tuple[LegacyToken, qmod.TokenType]:
+ """ Create a LegacyToken from the row of the word table.
+ Also determines the type of token.
+ """
+ penalty = 0.0
+ is_indexed = True
+
+ rowclass = getattr(row, 'class')
+
+ if row.country_code is not None:
+ ttype = qmod.TokenType.COUNTRY
+ lookup_word = row.country_code
+ elif rowclass is not None:
+ if rowclass == 'place' and row.type == 'house':
+ ttype = qmod.TokenType.HOUSENUMBER
+ lookup_word = row.word_token[1:]
+ elif rowclass == 'place' and row.type == 'postcode':
+ ttype = qmod.TokenType.POSTCODE
+ lookup_word = row.word_token[1:]
+ else:
+ ttype = qmod.TokenType.CATEGORY if row.operator in ('in', 'near')\
+ else qmod.TokenType.QUALIFIER
+ lookup_word = row.word
+ elif row.word_token.startswith(' '):
+ ttype = qmod.TokenType.WORD
+ lookup_word = row.word or row.word_token[1:]
+ else:
+ ttype = qmod.TokenType.PARTIAL
+ lookup_word = row.word_token
+ penalty = 0.21
+ if row.search_name_count > self.max_word_freq:
+ is_indexed = False
+
+ return LegacyToken(penalty=penalty, token=row.word_id,
+ count=row.search_name_count or 1,
+ lookup_word=lookup_word,
+ word_token=row.word_token.strip(),
+ category=(rowclass, row.type) if rowclass is not None else None,
+ country=row.country_code,
+ operator=row.operator,
+ is_indexed=is_indexed),\
+ ttype
+
+
+ def add_extra_tokens(self, query: qmod.QueryStruct, parts: List[str]) -> None:
+ """ Add tokens to query that are not saved in the database.
+ """
+ for part, node, i in zip(parts, query.nodes, range(1000)):
+ if len(part) <= 4 and part.isdigit()\
+ and not node.has_tokens(i+1, qmod.TokenType.HOUSENUMBER):
+ query.add_token(qmod.TokenRange(i, i+1), qmod.TokenType.HOUSENUMBER,
+ LegacyToken(penalty=0.5, token=0, count=1,
+ lookup_word=part, word_token=part,
+ category=None, country=None,
+ operator=None, is_indexed=True))
+
+
+ def rerank_tokens(self, query: qmod.QueryStruct) -> None:
+ """ Add penalties to tokens that depend on presence of other token.
+ """
+ for _, node, tlist in query.iter_token_lists():
+ if tlist.ttype == qmod.TokenType.POSTCODE:
+ for repl in node.starting:
+ if repl.end == tlist.end and repl.ttype != qmod.TokenType.POSTCODE \
+ and (repl.ttype != qmod.TokenType.HOUSENUMBER
+ or len(tlist.tokens[0].lookup_word) > 4):
+ repl.add_penalty(0.39)
+ elif tlist.ttype == qmod.TokenType.HOUSENUMBER:
+ if any(c.isdigit() for c in tlist.tokens[0].lookup_word):
+ for repl in node.starting:
+ if repl.end == tlist.end and repl.ttype != qmod.TokenType.HOUSENUMBER \
+ and (repl.ttype != qmod.TokenType.HOUSENUMBER
+ or len(tlist.tokens[0].lookup_word) <= 3):
+ repl.add_penalty(0.5 - tlist.tokens[0].penalty)
+
+
+
+def _dump_word_tokens(query: qmod.QueryStruct) -> Iterator[List[Any]]:
+ yield ['type', 'token', 'word_token', 'lookup_word', 'penalty', 'count', 'info']
+ for node in query.nodes:
+ for tlist in node.starting:
+ for token in tlist.tokens:
+ t = cast(LegacyToken, token)
+ yield [tlist.ttype.name, t.token, t.word_token or '',
+ t.lookup_word or '', t.penalty, t.count, t.info]
+
+
+async def create_query_analyzer(conn: SearchConnection) -> AbstractQueryAnalyzer:
+ """ Create and set up a new query analyzer for a database based
+        on the legacy tokenizer.
+ """
+ out = LegacyQueryAnalyzer(conn)
+ await out.setup()
+
+ return out
diff --git a/nominatim/api/search/query.py b/nominatim/api/search/query.py
new file mode 100644
index 00000000..f2b18f87
--- /dev/null
+++ b/nominatim/api/search/query.py
@@ -0,0 +1,276 @@
+# SPDX-License-Identifier: GPL-3.0-or-later
+#
+# This file is part of Nominatim. (https://nominatim.org)
+#
+# Copyright (C) 2023 by the Nominatim developer community.
+# For a full list of authors see the git log.
+"""
+Datastructures for a tokenized query.
+"""
+from typing import List, Tuple, Optional, NamedTuple, Iterator
+from abc import ABC, abstractmethod
+import dataclasses
+import enum
+
+class BreakType(enum.Enum):
+ """ Type of break between tokens.
+ """
+ START = '<'
+ """ Begin of the query. """
+ END = '>'
+ """ End of the query. """
+ PHRASE = ','
+ """ Break between two phrases. """
+ WORD = ' '
+ """ Break between words. """
+ PART = '-'
+ """ Break inside a word, for example a hyphen or apostrophe. """
+ TOKEN = '`'
+ """ Break created as a result of tokenization.
+ This may happen in languages without spaces between words.
+ """
+
+
+class TokenType(enum.Enum):
+ """ Type of token.
+ """
+ WORD = enum.auto()
+ """ Full name of a place. """
+ PARTIAL = enum.auto()
+ """ Word term without breaks, does not necessarily represent a full name. """
+ HOUSENUMBER = enum.auto()
+ """ Housenumber term. """
+ POSTCODE = enum.auto()
+ """ Postal code term. """
+ COUNTRY = enum.auto()
+ """ Country name or reference. """
+ QUALIFIER = enum.auto()
+ """ Special term used together with name (e.g. _Hotel_ Bellevue). """
+ CATEGORY = enum.auto()
+ """ Special term used as searchable object(e.g. supermarket in ...). """
+
+
+class PhraseType(enum.Enum):
+ """ Designation of a phrase.
+ """
+ NONE = 0
+ """ No specific designation (i.e. source is free-form query). """
+ AMENITY = enum.auto()
+ """ Contains name or type of a POI. """
+ STREET = enum.auto()
+ """ Contains a street name optionally with a housenumber. """
+ CITY = enum.auto()
+ """ Contains the postal city. """
+ COUNTY = enum.auto()
+ """ Contains the equivalent of a county. """
+ STATE = enum.auto()
+ """ Contains a state or province. """
+ POSTCODE = enum.auto()
+ """ Contains a postal code. """
+ COUNTRY = enum.auto()
+ """ Contains the country name or code. """
+
+ def compatible_with(self, ttype: TokenType) -> bool:
+ """ Check if the given token type can be used with the phrase type.
+ """
+ if self == PhraseType.NONE:
+ return True
+ if self == PhraseType.AMENITY:
+ return ttype in (TokenType.WORD, TokenType.PARTIAL,
+ TokenType.QUALIFIER, TokenType.CATEGORY)
+ if self == PhraseType.STREET:
+ return ttype in (TokenType.WORD, TokenType.PARTIAL, TokenType.HOUSENUMBER)
+ if self == PhraseType.POSTCODE:
+ return ttype == TokenType.POSTCODE
+ if self == PhraseType.COUNTRY:
+ return ttype == TokenType.COUNTRY
+
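+        # remaining phrase types (CITY, COUNTY, STATE) accept name terms only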
+ return ttype in (TokenType.WORD, TokenType.PARTIAL)
+
+
+@dataclasses.dataclass
+class Token(ABC):
+ """ Base type for tokens.
+ Specific query analyzers must implement the concrete token class.
+ """
+
+ penalty: float
+ token: int
+ count: int
+ lookup_word: str
+ is_indexed: bool
+
+
+ @abstractmethod
+ def get_category(self) -> Tuple[str, str]:
+ """ Return the category restriction for qualifier terms and
+ category objects.
+ """
+
+
+class TokenRange(NamedTuple):
+ """ Indexes of query nodes over which a token spans.
+ """
+ start: int
+ end: int
+
+ def replace_start(self, new_start: int) -> 'TokenRange':
+ """ Return a new token range with the new start.
+ """
+ return TokenRange(new_start, self.end)
+
+
+ def replace_end(self, new_end: int) -> 'TokenRange':
+ """ Return a new token range with the new end.
+ """
+ return TokenRange(self.start, new_end)
+
+
+ def split(self, index: int) -> Tuple['TokenRange', 'TokenRange']:
+ """ Split the span into two spans at the given index.
+ The index must be within the span.
+ """
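+        # e.g. TokenRange(1, 5).split(3) == (TokenRange(1, 3), TokenRange(3, 5))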
+ return self.replace_end(index), self.replace_start(index)
+
+
+@dataclasses.dataclass
+class TokenList:
+ """ List of all tokens of a given type going from one breakpoint to another.
+ """
+ end: int
+ ttype: TokenType
+ tokens: List[Token]
+
+
+ def add_penalty(self, penalty: float) -> None:
+ """ Add the given penalty to all tokens in the list.
+ """
+ for token in self.tokens:
+ token.penalty += penalty
+
+
+@dataclasses.dataclass
+class QueryNode:
+ """ A node of the querry representing a break between terms.
+ """
+ btype: BreakType
+ ptype: PhraseType
+ starting: List[TokenList] = dataclasses.field(default_factory=list)
+
+ def has_tokens(self, end: int, *ttypes: TokenType) -> bool:
+ """ Check if there are tokens of the given types ending at the
+ given node.
+ """
+ return any(tl.end == end and tl.ttype in ttypes for tl in self.starting)
+
+
+ def get_tokens(self, end: int, ttype: TokenType) -> Optional[List[Token]]:
+ """ Get the list of tokens of the given type starting at this node
+ and ending at the node 'end'. Returns 'None' if no such
+ tokens exist.
+ """
+ for tlist in self.starting:
+ if tlist.end == end and tlist.ttype == ttype:
+ return tlist.tokens
+ return None
+
+
+@dataclasses.dataclass
+class Phrase:
+ """ A normalized query part. Phrases may be typed which means that
+ they then represent a specific part of the address.
+ """
+ ptype: PhraseType
+ text: str
+
+
+class QueryStruct:
+ """ A tokenized search query together with the normalized source
+ from which the tokens have been parsed.
+
+ The query contains a list of nodes that represent the breaks
+ between words. Tokens span between nodes, which don't necessarily
+ need to be direct neighbours. Thus the query is represented as a
+ directed acyclic graph.
+
+ When created, a query contains a single node: the start of the
+ query. Further nodes can be added by appending to 'nodes'.
+ """
+
+ def __init__(self, source: List[Phrase]) -> None:
+ self.source = source
+ self.nodes: List[QueryNode] = \
+ [QueryNode(BreakType.START, source[0].ptype if source else PhraseType.NONE)]
+
+
+ def num_token_slots(self) -> int:
+ """ Return the length of the query in vertice steps.
+ """
+ return len(self.nodes) - 1
+
+
+ def add_node(self, btype: BreakType, ptype: PhraseType) -> None:
+ """ Append a new break node with the given break type.
+ The phrase type denotes the type for any tokens starting
+ at the node.
+ """
+ self.nodes.append(QueryNode(btype, ptype))
+
+
+ def add_token(self, trange: TokenRange, ttype: TokenType, token: Token) -> None:
+ """ Add a token to the query. 'start' and 'end' are the indexes of the
+ nodes from which to which the token spans. The indexes must exist
+ and are expected to be in the same phrase.
+ 'ttype' denotes the type of the token and 'token' the token to
+ be inserted.
+
+ If the token type is not compatible with the phrase it should
+ be added to, then the token is silently dropped.
+ """
+ snode = self.nodes[trange.start]
+ if snode.ptype.compatible_with(ttype):
+ tlist = snode.get_tokens(trange.end, ttype)
+ if tlist is None:
+ snode.starting.append(TokenList(trange.end, ttype, [token]))
+ else:
+ tlist.append(token)
+
+
+ def get_tokens(self, trange: TokenRange, ttype: TokenType) -> List[Token]:
+ """ Get the list of tokens of a given type, spanning the given
+ nodes. The nodes must exist. If no tokens exist, an
+ empty list is returned.
+ """
+ return self.nodes[trange.start].get_tokens(trange.end, ttype) or []
+
+
+ def get_partials_list(self, trange: TokenRange) -> List[Token]:
+ """ Create a list of partial tokens between the given nodes.
+ The list is composed of the first token of type PARTIAL
+ going to the subsequent node. Such PARTIAL tokens are
+ assumed to exist.
+ """
+ return [next(iter(self.get_tokens(TokenRange(i, i+1), TokenType.PARTIAL)))
+ for i in range(trange.start, trange.end)]
+
+
+ def iter_token_lists(self) -> Iterator[Tuple[int, QueryNode, TokenList]]:
+ """ Iterator over all token lists in the query.
+ """
+ for i, node in enumerate(self.nodes):
+ for tlist in node.starting:
+ yield i, node, tlist
+
+
+ def find_lookup_word_by_id(self, token: int) -> str:
+ """ Find the first token with the given token ID and return
+ its lookup word. Returns 'None' if no such token exists.
+ The function is very slow and must only be used for
+ debugging.
+ """
+ for node in self.nodes:
+ for tlist in node.starting:
+ for t in tlist.tokens:
+ if t.token == token:
+ return f"[{tlist.ttype.name[0]}]{t.lookup_word}"
+ return 'None'
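
To make the node structure concrete, a small sketch (my_token stands in
for an instance of any concrete Token subclass):

    query = QueryStruct([Phrase(PhraseType.NONE, 'foo bar')])
    query.add_node(BreakType.WORD, PhraseType.NONE)  # node 1, after 'foo'
    query.add_node(BreakType.END, PhraseType.NONE)   # node 2, after 'bar'
    # a token spanning both words starts at node 0 and ends at node 2
    query.add_token(TokenRange(0, 2), TokenType.WORD, my_token)
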
diff --git a/nominatim/api/search/query_analyzer_factory.py b/nominatim/api/search/query_analyzer_factory.py
new file mode 100644
index 00000000..35649d0f
--- /dev/null
+++ b/nominatim/api/search/query_analyzer_factory.py
@@ -0,0 +1,45 @@
+# SPDX-License-Identifier: GPL-3.0-or-later
+#
+# This file is part of Nominatim. (https://nominatim.org)
+#
+# Copyright (C) 2023 by the Nominatim developer community.
+# For a full list of authors see the git log.
+"""
+Factory for creating a query analyzer for the configured tokenizer.
+"""
+from typing import List, cast, TYPE_CHECKING
+from abc import ABC, abstractmethod
+from pathlib import Path
+import importlib
+
+from nominatim.api.logging import log
+from nominatim.api.connection import SearchConnection
+
+if TYPE_CHECKING:
+ from nominatim.api.search.query import Phrase, QueryStruct
+
+class AbstractQueryAnalyzer(ABC):
+ """ Class for analysing incomming queries.
+
+ Query analyzers are tied to the tokenizer used on import.
+ """
+
+ @abstractmethod
+ async def analyze_query(self, phrases: List['Phrase']) -> 'QueryStruct':
+ """ Analyze the given phrases and return the tokenized query.
+ """
+
+
+async def make_query_analyzer(conn: SearchConnection) -> AbstractQueryAnalyzer:
+ """ Create a query analyzer for the tokenizer used by the database.
+ """
+ name = await conn.get_property('tokenizer')
+
+ src_file = Path(__file__).parent / f'{name}_tokenizer.py'
+ if not src_file.is_file():
+ log().comment(f"No tokenizer named '{name}' available. Database not set up properly.")
+ raise RuntimeError('Tokenizer not found')
+
+ module = importlib.import_module(f'nominatim.api.search.{name}_tokenizer')
+
+ return cast(AbstractQueryAnalyzer, await module.create_query_analyzer(conn))
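
For example (sketch): a database imported with the ICU tokenizer stores
'icu' as its tokenizer property, so the factory resolves and loads
nominatim/api/search/icu_tokenizer.py:

    analyzer = await make_query_analyzer(conn)   # -> ICUQueryAnalyzer
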
diff --git a/nominatim/api/search/token_assignment.py b/nominatim/api/search/token_assignment.py
new file mode 100644
index 00000000..11da2359
--- /dev/null
+++ b/nominatim/api/search/token_assignment.py
@@ -0,0 +1,357 @@
+# SPDX-License-Identifier: GPL-3.0-or-later
+#
+# This file is part of Nominatim. (https://nominatim.org)
+#
+# Copyright (C) 2023 by the Nominatim developer community.
+# For a full list of authors see the git log.
+"""
+Create query interpretations where each vertex in the query is assigned
+a specific function (expressed as a token type).
+"""
+from typing import Optional, List, Iterator
+import dataclasses
+
+import nominatim.api.search.query as qmod
+from nominatim.api.logging import log
+
+# pylint: disable=too-many-return-statements,too-many-branches
+
+@dataclasses.dataclass
+class TypedRange:
+ """ A token range for a specific type of tokens.
+ """
+ ttype: qmod.TokenType
+ trange: qmod.TokenRange
+
+
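+# Penalty applied when a new typed range starts at a break of the given
+# kind (see _TokenSequence.advance()).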
+PENALTY_TOKENCHANGE = {
+ qmod.BreakType.START: 0.0,
+ qmod.BreakType.END: 0.0,
+ qmod.BreakType.PHRASE: 0.0,
+ qmod.BreakType.WORD: 0.1,
+ qmod.BreakType.PART: 0.2,
+ qmod.BreakType.TOKEN: 0.4
+}
+
+TypedRangeSeq = List[TypedRange]
+
+@dataclasses.dataclass
+class TokenAssignment: # pylint: disable=too-many-instance-attributes
+ """ Representation of a possible assignment of token types
+ to the tokens in a tokenized query.
+ """
+ penalty: float = 0.0
+ name: Optional[qmod.TokenRange] = None
+ address: List[qmod.TokenRange] = dataclasses.field(default_factory=list)
+ housenumber: Optional[qmod.TokenRange] = None
+ postcode: Optional[qmod.TokenRange] = None
+ country: Optional[qmod.TokenRange] = None
+ category: Optional[qmod.TokenRange] = None
+ qualifier: Optional[qmod.TokenRange] = None
+
+
+ @staticmethod
+ def from_ranges(ranges: TypedRangeSeq) -> 'TokenAssignment':
+ """ Create a new token assignment from a sequence of typed spans.
+ """
+ out = TokenAssignment()
+ for token in ranges:
+ if token.ttype == qmod.TokenType.PARTIAL:
+ out.address.append(token.trange)
+ elif token.ttype == qmod.TokenType.HOUSENUMBER:
+ out.housenumber = token.trange
+ elif token.ttype == qmod.TokenType.POSTCODE:
+ out.postcode = token.trange
+ elif token.ttype == qmod.TokenType.COUNTRY:
+ out.country = token.trange
+ elif token.ttype == qmod.TokenType.CATEGORY:
+ out.category = token.trange
+ elif token.ttype == qmod.TokenType.QUALIFIER:
+ out.qualifier = token.trange
+ return out
+
+
+class _TokenSequence:
+ """ Working state used to put together the token assignements.
+
+ Represents an intermediate state while traversing the tokenized
+ query.
+ """
+ def __init__(self, seq: TypedRangeSeq,
+ direction: int = 0, penalty: float = 0.0) -> None:
+ self.seq = seq
+ self.direction = direction
+ self.penalty = penalty
+
+
+ def __str__(self) -> str:
+ seq = ''.join(f'[{r.trange.start} - {r.trange.end}: {r.ttype.name}]' for r in self.seq)
+ return f'{seq} (dir: {self.direction}, penalty: {self.penalty})'
+
+
+ @property
+ def end_pos(self) -> int:
+ """ Return the index of the global end of the current sequence.
+ """
+ return self.seq[-1].trange.end if self.seq else 0
+
+
+ def has_types(self, *ttypes: qmod.TokenType) -> bool:
+ """ Check if the current sequence contains any typed ranges of
+ the given types.
+ """
+ return any(s.ttype in ttypes for s in self.seq)
+
+
+ def is_final(self) -> bool:
+ """ Return true when the sequence cannot be extended by any
+ form of token anymore.
+ """
+ # Country and category must be the final term for left-to-right
+ return len(self.seq) > 1 and \
+ self.seq[-1].ttype in (qmod.TokenType.COUNTRY, qmod.TokenType.CATEGORY)
+
+
+ def appendable(self, ttype: qmod.TokenType) -> Optional[int]:
+ """ Check if the give token type is appendable to the existing sequence.
+
+ Returns None if the token type is not appendable, otherwise the
+ new direction of the sequence after adding such a type. The
+ token is not added.
+ """
+ if ttype == qmod.TokenType.WORD:
+ return None
+
+ if not self.seq:
+ # Append unconditionally to the empty list
+ if ttype == qmod.TokenType.COUNTRY:
+ return -1
+ if ttype in (qmod.TokenType.HOUSENUMBER, qmod.TokenType.QUALIFIER):
+ return 1
+ return self.direction
+
+ # Name tokens are always acceptable and don't change direction
+ if ttype == qmod.TokenType.PARTIAL:
+ return self.direction
+
+ # Other tokens may only appear once
+ if self.has_types(ttype):
+ return None
+
+ if ttype == qmod.TokenType.HOUSENUMBER:
+ if self.direction == 1:
+ if len(self.seq) == 1 and self.seq[0].ttype == qmod.TokenType.QUALIFIER:
+ return None
+ if len(self.seq) > 2 \
+ or self.has_types(qmod.TokenType.POSTCODE, qmod.TokenType.COUNTRY):
+ return None # direction left-to-right: housenumber must come before anything
+ elif self.direction == -1 \
+ or self.has_types(qmod.TokenType.POSTCODE, qmod.TokenType.COUNTRY):
+ return -1 # force direction right-to-left if after other terms
+
+ return self.direction
+
+ if ttype == qmod.TokenType.POSTCODE:
+ if self.direction == -1:
+ if self.has_types(qmod.TokenType.HOUSENUMBER, qmod.TokenType.QUALIFIER):
+ return None
+ return -1
+ if self.direction == 1:
+ return None if self.has_types(qmod.TokenType.COUNTRY) else 1
+ if self.has_types(qmod.TokenType.HOUSENUMBER, qmod.TokenType.QUALIFIER):
+ return 1
+ return self.direction
+
+ if ttype == qmod.TokenType.COUNTRY:
+ return None if self.direction == -1 else 1
+
+ if ttype == qmod.TokenType.CATEGORY:
+ return self.direction
+
+ if ttype == qmod.TokenType.QUALIFIER:
+ if self.direction == 1:
+ if (len(self.seq) == 1
+ and self.seq[0].ttype in (qmod.TokenType.PARTIAL, qmod.TokenType.CATEGORY)) \
+ or (len(self.seq) == 2
+ and self.seq[0].ttype == qmod.TokenType.CATEGORY
+ and self.seq[1].ttype == qmod.TokenType.PARTIAL):
+ return 1
+ return None
+ if self.direction == -1:
+ return -1
+
+ tempseq = self.seq[1:] if self.seq[0].ttype == qmod.TokenType.CATEGORY else self.seq
+ if len(tempseq) == 0:
+ return 1
+ if len(tempseq) == 1 and self.seq[0].ttype == qmod.TokenType.HOUSENUMBER:
+ return None
+ if len(tempseq) > 1 or self.has_types(qmod.TokenType.POSTCODE, qmod.TokenType.COUNTRY):
+ return -1
+ return 0
+
+ return None
+
+
+ def advance(self, ttype: qmod.TokenType, end_pos: int,
+ btype: qmod.BreakType) -> Optional['_TokenSequence']:
+ """ Return a new token sequence state with the given token type
+ extended.
+ """
+ newdir = self.appendable(ttype)
+ if newdir is None:
+ return None
+
+ if not self.seq:
+ newseq = [TypedRange(ttype, qmod.TokenRange(0, end_pos))]
+ new_penalty = 0.0
+ else:
+ last = self.seq[-1]
+ if btype != qmod.BreakType.PHRASE and last.ttype == ttype:
+ # extend the existing range
+ newseq = self.seq[:-1] + [TypedRange(ttype, last.trange.replace_end(end_pos))]
+ new_penalty = 0.0
+ else:
+ # start a new range
+ newseq = list(self.seq) + [TypedRange(ttype,
+ qmod.TokenRange(last.trange.end, end_pos))]
+ new_penalty = PENALTY_TOKENCHANGE[btype]
+
+ return _TokenSequence(newseq, newdir, self.penalty + new_penalty)
+
+
+ def _adapt_penalty_from_priors(self, priors: int, new_dir: int) -> bool:
+ if priors == 2:
+ self.penalty += 1.0
+ elif priors > 2:
+ if self.direction == 0:
+ self.direction = new_dir
+ else:
+ return False
+
+ return True
+
+
+ def recheck_sequence(self) -> bool:
+ """ Check that the sequence is a fully valid token assignment
+            and adapt direction and penalties further if necessary.
+
+            This function catches some impossible assignments that need
+            forward context and can therefore not be excluded when building
+ the assignment.
+ """
+ # housenumbers may not be further than 2 words from the beginning.
+ # If there are two words in front, give it a penalty.
+ hnrpos = next((i for i, tr in enumerate(self.seq)
+ if tr.ttype == qmod.TokenType.HOUSENUMBER),
+ None)
+ if hnrpos is not None:
+ if self.direction != -1:
+ priors = sum(1 for t in self.seq[:hnrpos] if t.ttype == qmod.TokenType.PARTIAL)
+ if not self._adapt_penalty_from_priors(priors, -1):
+ return False
+ if self.direction != 1:
+ priors = sum(1 for t in self.seq[hnrpos+1:] if t.ttype == qmod.TokenType.PARTIAL)
+ if not self._adapt_penalty_from_priors(priors, 1):
+ return False
+
+ return True
+
+
+ def get_assignments(self, query: qmod.QueryStruct) -> Iterator[TokenAssignment]:
+ """ Yield possible assignments for the current sequence.
+
+ This function splits up general name assignments into name
+ and address and yields all possible variants of that.
+ """
+ base = TokenAssignment.from_ranges(self.seq)
+
+        # Postcode search (postcode-only search is covered in the next case)
+ if base.postcode is not None and base.address:
+ if (base.postcode.start == 0 and self.direction != -1)\
+ or (base.postcode.end == query.num_token_slots() and self.direction != 1):
+ log().comment('postcode search')
+ # ,.*?)\s+)??' + r + r'(?:\s+(?P.*))?') for r in ( + r"(?P [NS])\s*" + _deg('lat') + r"[\s,]+" + r"(?P [EW])\s*" + _deg('lon'), + _deg('lat') + r"\s*(?P [NS])[\s,]+" + _deg('lon') + r"\s*(?P [EW])", + r"(?P [NS])\s*" + _deg_min('lat') + r"[\s,]+" + r"(?P [EW])\s*" + _deg_min('lon'), + _deg_min('lat') + r"\s*(?P [NS])[\s,]+" + _deg_min('lon') + r"\s*(?P [EW])", + r"(?P [NS])\s*" + _deg_min_sec('lat') + r"[\s,]+" + r"(?P [EW])\s*" + _deg_min_sec('lon'), + _deg_min_sec('lat') + r"\s*(?P [NS])[\s,]+" + _deg_min_sec('lon') + r"\s*(?P [EW])", + r"\[?(?P [+-]?\d+\.\d+)[\s,]+(?P [+-]?\d+\.\d+)\]?" +)] + +def extract_coords_from_query(query: str) -> Tuple[str, Optional[float], Optional[float]]: + """ Look for something that is formated like a coordinate at the + beginning or end of the query. If found, extract the coordinate and + return the remaining query (or the empty string if the query + consisted of nothing but a coordinate). + + Only the first match will be returned. + """ + for regex in COORD_REGEX: + match = regex.fullmatch(query) + if match is None: + continue + groups = match.groupdict() + if not groups['pre'] or not groups['post']: + x = float(groups['lon_deg']) \ + + float(groups.get('lon_min', 0.0)) / 60.0 \ + + float(groups.get('lon_sec', 0.0)) / 3600.0 + if groups.get('ew') == 'W': + x = -x + y = float(groups['lat_deg']) \ + + float(groups.get('lat_min', 0.0)) / 60.0 \ + + float(groups.get('lat_sec', 0.0)) / 3600.0 + if groups.get('ns') == 'S': + y = -y + return groups['pre'] or groups['post'] or '', x, y + + return query, None, None + + +CATEGORY_REGEX = re.compile(r'(?P .*?)\[(?P[a-zA-Z_]+)=(?P [a-zA-Z_]+)\](?P .*)') + +def extract_category_from_query(query: str) -> Tuple[str, Optional[str], Optional[str]]: + """ Extract a hidden category specification of the form '[key=value]' from + the query. If found, extract key and value and + return the remaining query (or the empty string if the query + consisted of nothing but a category). + + Only the first match will be returned. + """ + match = CATEGORY_REGEX.search(query) + if match is not None: + return (match.group('pre').strip() + ' ' + match.group('post').strip()).strip(), \ + match.group('cls'), match.group('typ') + + return query, None, None diff --git a/nominatim/api/v1/server_glue.py b/nominatim/api/v1/server_glue.py index ccf8f7d1..865e1331 100644 --- a/nominatim/api/v1/server_glue.py +++ b/nominatim/api/v1/server_glue.py @@ -11,8 +11,11 @@ Combine with the scaffolding provided for the various Python ASGI frameworks. from typing import Optional, Any, Type, Callable, NoReturn, Dict, cast from functools import reduce import abc +import dataclasses import math +from urllib.parse import urlencode +from nominatim.errors import UsageError from nominatim.config import Configuration import nominatim.api as napi import nominatim.api.logging as loglib @@ -182,7 +185,7 @@ class ASGIAdaptor(abc.ABC): """ Return the accepted languages. 
""" return self.get('accept-language')\ - or self.get_header('http_accept_language')\ + or self.get_header('accept-language')\ or self.config().DEFAULT_LANGUAGE @@ -305,6 +308,8 @@ async def details_endpoint(api: napi.NominatimAPIAsync, params: ASGIAdaptor) -> if result is None: params.raise_error('No place with that OSM ID found.', status=404) + result.localize(locales) + output = formatting.format_result(result, fmt, {'locales': locales, 'group_hierarchy': params.get_bool('group_hierarchy', False), @@ -319,7 +324,6 @@ async def reverse_endpoint(api: napi.NominatimAPIAsync, params: ASGIAdaptor) -> fmt = params.parse_format(napi.ReverseResults, 'xml') debug = params.setup_debugging() coord = napi.Point(params.get_float('lon'), params.get_float('lat')) - locales = napi.Locales.from_accept_languages(params.get_accepted_languages()) details = params.parse_geometry_details(fmt) details['max_rank'] = helpers.zoom_to_rank(params.get_int('zoom', 18)) @@ -330,11 +334,23 @@ async def reverse_endpoint(api: napi.NominatimAPIAsync, params: ASGIAdaptor) -> if debug: return params.build_response(loglib.get_and_disable()) - fmt_options = {'locales': locales, + if fmt == 'xml': + queryparts = {'lat': str(coord.lat), 'lon': str(coord.lon), 'format': 'xml'} + zoom = params.get('zoom', None) + if zoom: + queryparts['zoom'] = zoom + query = urlencode(queryparts) + else: + query = '' + + fmt_options = {'query': query, 'extratags': params.get_bool('extratags', False), 'namedetails': params.get_bool('namedetails', False), 'addressdetails': params.get_bool('addressdetails', True)} + if result: + result.localize(napi.Locales.from_accept_languages(params.get_accepted_languages())) + output = formatting.format_result(napi.ReverseResults([result] if result else []), fmt, fmt_options) @@ -346,7 +362,6 @@ async def lookup_endpoint(api: napi.NominatimAPIAsync, params: ASGIAdaptor) -> A """ fmt = params.parse_format(napi.SearchResults, 'xml') debug = params.setup_debugging() - locales = napi.Locales.from_accept_languages(params.get_accepted_languages()) details = params.parse_geometry_details(fmt) places = [] @@ -355,6 +370,9 @@ async def lookup_endpoint(api: napi.NominatimAPIAsync, params: ASGIAdaptor) -> A if len(oid) > 1 and oid[0] in 'RNWrnw' and oid[1:].isdigit(): places.append(napi.OsmID(oid[0], int(oid[1:]))) + if len(places) > params.config().get_int('LOOKUP_MAX_COUNT'): + params.raise_error('Too many object IDs.') + if places: results = await api.lookup(places, **details) else: @@ -363,20 +381,125 @@ async def lookup_endpoint(api: napi.NominatimAPIAsync, params: ASGIAdaptor) -> A if debug: return params.build_response(loglib.get_and_disable()) - fmt_options = {'locales': locales, - 'extratags': params.get_bool('extratags', False), + fmt_options = {'extratags': params.get_bool('extratags', False), 'namedetails': params.get_bool('namedetails', False), 'addressdetails': params.get_bool('addressdetails', True)} + results.localize(napi.Locales.from_accept_languages(params.get_accepted_languages())) + + output = formatting.format_result(results, fmt, fmt_options) + + return params.build_response(output) + + +async def _unstructured_search(query: str, api: napi.NominatimAPIAsync, + details: Dict[str, Any]) -> napi.SearchResults: + if not query: + return napi.SearchResults() + + # Extract special format for coordinates from query. 
+ query, x, y = helpers.extract_coords_from_query(query) + if x is not None: + assert y is not None + details['near'] = napi.Point(x, y) + details['near_radius'] = 0.1 + + # If no query is left, revert to reverse search. + if x is not None and not query: + result = await api.reverse(details['near'], **details) + if not result: + return napi.SearchResults() + + return napi.SearchResults( + [napi.SearchResult(**{f.name: getattr(result, f.name) + for f in dataclasses.fields(napi.SearchResult) + if hasattr(result, f.name)})]) + + query, cls, typ = helpers.extract_category_from_query(query) + if cls is not None: + assert typ is not None + return await api.search_category([(cls, typ)], near_query=query, **details) + + return await api.search(query, **details) + + +async def search_endpoint(api: napi.NominatimAPIAsync, params: ASGIAdaptor) -> Any: + """ Server glue for /search endpoint. See API docs for details. + """ + fmt = params.parse_format(napi.SearchResults, 'jsonv2') + debug = params.setup_debugging() + details = params.parse_geometry_details(fmt) + + details['countries'] = params.get('countrycodes', None) + details['excluded'] = params.get('exclude_place_ids', None) + details['viewbox'] = params.get('viewbox', None) or params.get('viewboxlbrt', None) + details['bounded_viewbox'] = params.get_bool('bounded', False) + details['dedupe'] = params.get_bool('dedupe', True) + + max_results = max(1, min(50, params.get_int('limit', 10))) + details['max_results'] = max_results + min(10, max_results) \ + if details['dedupe'] else max_results + + details['min_rank'], details['max_rank'] = \ + helpers.feature_type_to_rank(params.get('featureType', '')) + if params.get('featureType', None) is not None: + details['layers'] = napi.DataLayer.ADDRESS + + query = params.get('q', None) + queryparts = {} + try: + if query is not None: + queryparts['q'] = query + results = await _unstructured_search(query, api, details) + else: + for key in ('amenity', 'street', 'city', 'county', 'state', 'postalcode', 'country'): + details[key] = params.get(key, None) + if details[key]: + queryparts[key] = details[key] + query = ', '.join(queryparts.values()) + + results = await api.search_address(**details) + except UsageError as err: + params.raise_error(str(err)) + + results.localize(napi.Locales.from_accept_languages(params.get_accepted_languages())) + + if details['dedupe'] and len(results) > 1: + results = helpers.deduplicate_results(results, max_results) + + if debug: + return params.build_response(loglib.get_and_disable()) + + if fmt == 'xml': + helpers.extend_query_parts(queryparts, details, + params.get('featureType', ''), + params.get_bool('namedetails', False), + params.get_bool('extratags', False), + (str(r.place_id) for r in results if r.place_id)) + queryparts['format'] = fmt + + moreurl = urlencode(queryparts) + else: + moreurl = '' + + fmt_options = {'query': query, 'more_url': moreurl, + 'exclude_place_ids': queryparts.get('exclude_place_ids'), + 'viewbox': queryparts.get('viewbox'), + 'extratags': params.get_bool('extratags', False), + 'namedetails': params.get_bool('namedetails', False), + 'addressdetails': params.get_bool('addressdetails', False)} + output = formatting.format_result(results, fmt, fmt_options) return params.build_response(output) + EndpointFunc = Callable[[napi.NominatimAPIAsync, ASGIAdaptor], Any] ROUTES = [ ('status', status_endpoint), ('details', details_endpoint), ('reverse', reverse_endpoint), - ('lookup', lookup_endpoint) + ('lookup', lookup_endpoint), + ('search', search_endpoint) 
] diff --git a/nominatim/cli.py b/nominatim/cli.py index d34ef118..6a89a8de 100644 --- a/nominatim/cli.py +++ b/nominatim/cli.py @@ -251,7 +251,7 @@ class AdminServe: return 0 -def get_set_parser(**kwargs: Any) -> CommandlineParser: +def get_set_parser() -> CommandlineParser: """\ Initializes the parser and adds various subcommands for nominatim cli. @@ -273,14 +273,11 @@ def get_set_parser(**kwargs: Any) -> CommandlineParser: parser.add_subcommand('export', QueryExport()) parser.add_subcommand('serve', AdminServe()) - if kwargs.get('phpcgi_path'): - parser.add_subcommand('search', clicmd.APISearch()) - parser.add_subcommand('reverse', clicmd.APIReverse()) - parser.add_subcommand('lookup', clicmd.APILookup()) - parser.add_subcommand('details', clicmd.APIDetails()) - parser.add_subcommand('status', clicmd.APIStatus()) - else: - parser.parser.epilog = 'php-cgi not found. Query commands not available.' + parser.add_subcommand('search', clicmd.APISearch()) + parser.add_subcommand('reverse', clicmd.APIReverse()) + parser.add_subcommand('lookup', clicmd.APILookup()) + parser.add_subcommand('details', clicmd.APIDetails()) + parser.add_subcommand('status', clicmd.APIStatus()) return parser @@ -290,6 +287,4 @@ def nominatim(**kwargs: Any) -> int: Command-line tools for importing, updating, administrating and querying the Nominatim database. """ - parser = get_set_parser(**kwargs) - - return parser.run(**kwargs) + return get_set_parser().run(**kwargs) diff --git a/nominatim/clicmd/api.py b/nominatim/clicmd/api.py index fef6bdf6..f2f1826b 100644 --- a/nominatim/clicmd/api.py +++ b/nominatim/clicmd/api.py @@ -7,7 +7,7 @@ """ Subcommand definitions for API calls from the command line. """ -from typing import Mapping, Dict +from typing import Mapping, Dict, Any import argparse import logging import json @@ -18,7 +18,7 @@ from nominatim.errors import UsageError from nominatim.clicmd.args import NominatimArgs import nominatim.api as napi import nominatim.api.v1 as api_output -from nominatim.api.v1.helpers import zoom_to_rank +from nominatim.api.v1.helpers import zoom_to_rank, deduplicate_results import nominatim.api.logging as loglib # Do not repeat documentation of subcommand classes. 
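For illustration only (this block is not part of the diff): the reworked 'search' subcommand below boils down to the following minimal sketch of the new Python search API, which the server endpoint above uses as well. The project directory and the query are made-up placeholders; max_results shows the over-fetch arithmetic used by both code paths, which requests up to ten extra rows so that deduplication can still fill the requested limit.

    from pathlib import Path
    import nominatim.api as napi

    api = napi.NominatimAPI(Path('my-project'))   # placeholder project dir

    limit = 10
    results = api.search('rathaus, vaduz',        # placeholder query
                         countries='li',
                         address_details=True,    # needed for display names
                         max_results=limit + min(limit, 10))

    for result in results:
        print(result.place_id, result.country_code)
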
@@ -27,6 +27,7 @@ import nominatim.api.logging as loglib LOG = logging.getLogger() STRUCTURED_QUERY = ( + ('amenity', 'name and/or type of POI'), ('street', 'housenumber and street'), ('city', 'city, town or village'), ('county', 'county'), @@ -97,7 +98,7 @@ class APISearch: help='Limit search results to one or more countries') group.add_argument('--exclude_place_ids', metavar='ID,..', help='List of search object to be excluded') - group.add_argument('--limit', type=int, + group.add_argument('--limit', type=int, default=10, help='Limit the number of returned results') group.add_argument('--viewbox', metavar='X1,Y1,X2,Y2', help='Preferred area to find search results') @@ -110,30 +111,58 @@ class APISearch: def run(self, args: NominatimArgs) -> int: - params: Dict[str, object] + if args.format == 'debug': + loglib.set_log_output('text') + + api = napi.NominatimAPI(args.project_dir) + + params: Dict[str, Any] = {'max_results': args.limit + min(args.limit, 10), + 'address_details': True, # needed for display name + 'geometry_output': args.get_geometry_output(), + 'geometry_simplification': args.polygon_threshold, + 'countries': args.countrycodes, + 'excluded': args.exclude_place_ids, + 'viewbox': args.viewbox, + 'bounded_viewbox': args.bounded + } + if args.query: - params = dict(q=args.query) + results = api.search(args.query, **params) else: - params = {k: getattr(args, k) for k, _ in STRUCTURED_QUERY if getattr(args, k)} - - for param, _ in EXTRADATA_PARAMS: - if getattr(args, param): - params[param] = '1' - for param in ('format', 'countrycodes', 'exclude_place_ids', 'limit', 'viewbox'): - if getattr(args, param): - params[param] = getattr(args, param) - if args.lang: - params['accept-language'] = args.lang - if args.polygon_output: - params['polygon_' + args.polygon_output] = '1' - if args.polygon_threshold: - params['polygon_threshold'] = args.polygon_threshold - if args.bounded: - params['bounded'] = '1' - if not args.dedupe: - params['dedupe'] = '0' - - return _run_api('search', args, params) + results = api.search_address(amenity=args.amenity, + street=args.street, + city=args.city, + county=args.county, + state=args.state, + postalcode=args.postalcode, + country=args.country, + **params) + + for result in results: + result.localize(args.get_locales(api.config.DEFAULT_LANGUAGE)) + + if args.dedupe and len(results) > 1: + results = deduplicate_results(results, args.limit) + + if args.format == 'debug': + print(loglib.get_and_disable()) + return 0 + + output = api_output.format_result( + results, + args.format, + {'extratags': args.extratags, + 'namedetails': args.namedetails, + 'addressdetails': args.addressdetails}) + if args.format != 'xml': + # reformat the result, so it is pretty-printed + json.dump(json.loads(output), sys.stdout, indent=4, ensure_ascii=False) + else: + sys.stdout.write(output) + sys.stdout.write('\n') + + return 0 + class APIReverse: """\ @@ -179,11 +208,11 @@ class APIReverse: return 0 if result: + result.localize(args.get_locales(api.config.DEFAULT_LANGUAGE)) output = api_output.format_result( napi.ReverseResults([result]), args.format, - {'locales': args.get_locales(api.config.DEFAULT_LANGUAGE), - 'extratags': args.extratags, + {'extratags': args.extratags, 'namedetails': args.namedetails, 'addressdetails': args.addressdetails}) if args.format != 'xml': @@ -236,11 +265,13 @@ class APILookup: geometry_output=args.get_geometry_output(), geometry_simplification=args.polygon_threshold or 0.0) + for result in results: + 
result.localize(args.get_locales(api.config.DEFAULT_LANGUAGE)) + output = api_output.format_result( results, args.format, - {'locales': args.get_locales(api.config.DEFAULT_LANGUAGE), - 'extratags': args.extratags, + {'extratags': args.extratags, 'namedetails': args.namedetails, 'addressdetails': args.addressdetails}) if args.format != 'xml': @@ -320,10 +351,13 @@ class APIDetails: if result: + locales = args.get_locales(api.config.DEFAULT_LANGUAGE) + result.localize(locales) + output = api_output.format_result( result, 'json', - {'locales': args.get_locales(api.config.DEFAULT_LANGUAGE), + {'locales': locales, 'group_hierarchy': args.group_hierarchy}) # reformat the result, so it is pretty-printed json.dump(json.loads(output), sys.stdout, indent=4, ensure_ascii=False) diff --git a/nominatim/clicmd/args.py b/nominatim/clicmd/args.py index bf3109ac..10316165 100644 --- a/nominatim/clicmd/args.py +++ b/nominatim/clicmd/args.py @@ -147,6 +147,7 @@ class NominatimArgs: # Arguments to 'search' query: Optional[str] + amenity: Optional[str] street: Optional[str] city: Optional[str] county: Optional[str] @@ -155,7 +156,7 @@ class NominatimArgs: postalcode: Optional[str] countrycodes: Optional[str] exclude_place_ids: Optional[str] - limit: Optional[int] + limit: int viewbox: Optional[str] bounded: bool dedupe: bool diff --git a/nominatim/clicmd/replication.py b/nominatim/clicmd/replication.py index ad201663..b7956506 100644 --- a/nominatim/clicmd/replication.py +++ b/nominatim/clicmd/replication.py @@ -147,10 +147,13 @@ class UpdateReplication: tokenizer = tokenizer_factory.get_tokenizer_for_db(args.config) indexer = Indexer(args.config.get_libpq_dsn(), tokenizer, args.threads or 1) + dsn = args.config.get_libpq_dsn() + while True: - with connect(args.config.get_libpq_dsn()) as conn: - start = dt.datetime.now(dt.timezone.utc) - state = replication.update(conn, params, socket_timeout=args.socket_timeout) + start = dt.datetime.now(dt.timezone.utc) + state = replication.update(dsn, params, socket_timeout=args.socket_timeout) + + with connect(dsn) as conn: if state is not replication.UpdateState.NO_CHANGES: status.log_status(conn, start, 'import') batchdate, _, _ = status.get_status(conn) @@ -160,7 +163,7 @@ class UpdateReplication: index_start = dt.datetime.now(dt.timezone.utc) indexer.index_full(analyse=False) - with connect(args.config.get_libpq_dsn()) as conn: + with connect(dsn) as conn: status.set_indexed(conn, True) status.log_status(conn, index_start, 'index') conn.commit() diff --git a/nominatim/db/sql_preprocessor.py b/nominatim/db/sql_preprocessor.py index 31b4a8c0..2e11f571 100644 --- a/nominatim/db/sql_preprocessor.py +++ b/nominatim/db/sql_preprocessor.py @@ -57,9 +57,11 @@ def _setup_postgresql_features(conn: Connection) -> Dict[str, Any]: """ pg_version = conn.server_version_tuple() postgis_version = conn.postgis_version_tuple() + pg11plus = pg_version >= (11, 0, 0) + ps3 = postgis_version >= (3, 0) return { - 'has_index_non_key_column': pg_version >= (11, 0, 0), - 'spgist_geom' : 'SPGIST' if postgis_version >= (3, 0) else 'GIST' + 'has_index_non_key_column': pg11plus, + 'spgist_geom' : 'SPGIST' if pg11plus and ps3 else 'GIST' } class SQLPreprocessor: diff --git a/nominatim/db/sqlalchemy_schema.py b/nominatim/db/sqlalchemy_schema.py index 26bbefcf..550f1f12 100644 --- a/nominatim/db/sqlalchemy_schema.py +++ b/nominatim/db/sqlalchemy_schema.py @@ -113,13 +113,6 @@ class SearchTables: sa.Column('postcode', sa.Text), sa.Column('country_code', sa.String(2))) - self.word = sa.Table('word', 
meta, - sa.Column('word_id', sa.Integer), - sa.Column('word_token', sa.Text, nullable=False), - sa.Column('type', sa.Text, nullable=False), - sa.Column('word', sa.Text), - sa.Column('info', self.types.Json)) - self.country_name = sa.Table('country_name', meta, sa.Column('country_code', sa.String(2)), sa.Column('name', self.types.Composite), diff --git a/nominatim/server/starlette/server.py b/nominatim/server/starlette/server.py index a2a70ebf..f81b122f 100644 --- a/nominatim/server/starlette/server.py +++ b/nominatim/server/starlette/server.py @@ -78,7 +78,11 @@ def get_application(project_dir: Path, if config.get_bool('CORS_NOACCESSCONTROL'): middleware.append(Middleware(CORSMiddleware, allow_origins=['*'])) - app = Starlette(debug=debug, routes=routes, middleware=middleware) + async def _shutdown() -> None: + await app.state.API.close() + + app = Starlette(debug=debug, routes=routes, middleware=middleware, + on_shutdown=[_shutdown]) app.state.API = NominatimAPIAsync(project_dir, environ) diff --git a/nominatim/tools/refresh.py b/nominatim/tools/refresh.py index 5dd98c0e..c6df9982 100644 --- a/nominatim/tools/refresh.py +++ b/nominatim/tools/refresh.py @@ -217,7 +217,7 @@ def setup_website(basedir: Path, config: Configuration, conn: Connection) -> Non basedir.mkdir() assert config.project_dir is not None - template = dedent(f"""\ + basedata = dedent(f"""\ Non for php_name, conf_name, var_type in PHP_CONST_DEFS: varout = _quote_php_variable(var_type, config, conf_name) - template += f"@define('CONST_{php_name}', {varout});\n" + basedata += f"@define('CONST_{php_name}', {varout});\n" - template += f"\nrequire_once('{config.lib_dir.php}/website/{{}}');\n" + template = "\nrequire_once(CONST_LibDir.'/website/{}');\n" search_name_table_exists = bool(conn and conn.table_exists('search_name')) for script in WEBSITE_SCRIPTS: if not search_name_table_exists and script == 'search.php': - (basedir / script).write_text(template.format('reverse-only-search.php'), 'utf-8') + out = template.format('reverse-only-search.php') else: - (basedir / script).write_text(template.format(script), 'utf-8') + out = template.format(script) + + (basedir / script).write_text(basedata + out, 'utf-8') def invalidate_osm_object(osm_type: str, osm_id: int, conn: Connection, diff --git a/nominatim/tools/replication.py b/nominatim/tools/replication.py index 846b9c34..edd63e49 100644 --- a/nominatim/tools/replication.py +++ b/nominatim/tools/replication.py @@ -18,7 +18,7 @@ import urllib.request as urlrequest import requests from nominatim.db import status -from nominatim.db.connection import Connection +from nominatim.db.connection import Connection, connect from nominatim.tools.exec_utils import run_osm2pgsql from nominatim.errors import UsageError @@ -92,12 +92,14 @@ class UpdateState(Enum): NO_CHANGES = 3 -def update(conn: Connection, options: MutableMapping[str, Any], +def update(dsn: str, options: MutableMapping[str, Any], socket_timeout: int = 60) -> UpdateState: """ Update database from the next batch of data. Returns the state of updates according to `UpdateState`. """ - startdate, startseq, indexed = status.get_status(conn) + with connect(dsn) as conn: + startdate, startseq, indexed = status.get_status(conn) + conn.commit() if startseq is None: LOG.error("Replication not set up. 
" @@ -130,12 +132,14 @@ def update(conn: Connection, options: MutableMapping[str, Any], if endseq is None: return UpdateState.NO_CHANGES - run_osm2pgsql_updates(conn, options) + with connect(dsn) as conn: + run_osm2pgsql_updates(conn, options) - # Write the current status to the file - endstate = repl.get_state_info(endseq) - status.set_status(conn, endstate.timestamp if endstate else None, - seq=endseq, indexed=False) + # Write the current status to the file + endstate = repl.get_state_info(endseq) + status.set_status(conn, endstate.timestamp if endstate else None, + seq=endseq, indexed=False) + conn.commit() return UpdateState.UP_TO_DATE diff --git a/nominatim/typing.py b/nominatim/typing.py index bc4c5534..d988fe04 100644 --- a/nominatim/typing.py +++ b/nominatim/typing.py @@ -63,8 +63,10 @@ else: TypeAlias = str SaSelect: TypeAlias = 'sa.Select[Any]' +SaScalarSelect: TypeAlias = 'sa.ScalarSelect[Any]' SaRow: TypeAlias = 'sa.Row[Any]' SaColumn: TypeAlias = 'sa.ColumnElement[Any]' +SaExpression: TypeAlias = 'sa.ColumnElement[bool]' SaLabel: TypeAlias = 'sa.Label[Any]' SaFromClause: TypeAlias = 'sa.FromClause' SaSelectable: TypeAlias = 'sa.Selectable' diff --git a/test/bdd/api/search/params.feature b/test/bdd/api/search/params.feature index 053dbbcd..d5512f5b 100644 --- a/test/bdd/api/search/params.feature +++ b/test/bdd/api/search/params.feature @@ -187,10 +187,6 @@ Feature: Search queries Then a HTTP 400 is returned Scenario: Restrict to feature type country - When sending xml search query "fürstentum" - Then results contain - | ID | class | - | 1 | building | When sending xml search query "fürstentum" | featureType | | country | @@ -200,7 +196,7 @@ Feature: Search queries Scenario: Restrict to feature type state When sending xml search query "Wangerberg" - Then more than 1 result is returned + Then at least 1 result is returned When sending xml search query "Wangerberg" | featureType | | state | @@ -208,9 +204,7 @@ Feature: Search queries Scenario: Restrict to feature type city When sending xml search query "vaduz" - Then results contain - | ID | place_rank | - | 1 | 30 | + Then at least 1 result is returned When sending xml search query "vaduz" | featureType | | city | @@ -358,6 +352,7 @@ Feature: Search queries | svg | | geokml | + @v1-api-php-only Scenario: Search along a route When sending json search query "rathaus" with address Then result addresses contain diff --git a/test/bdd/api/search/queries.feature b/test/bdd/api/search/queries.feature index d378d3f8..f0474460 100644 --- a/test/bdd/api/search/queries.feature +++ b/test/bdd/api/search/queries.feature @@ -97,6 +97,7 @@ Feature: Search queries | class | type | | club | scout | + @v1-api-php-only Scenario: With multiple amenity search only the first is used When sending json search query "[club=scout] [church] vaduz" Then results contain @@ -119,6 +120,7 @@ Feature: Search queries | class | type | | leisure | firepit | + @v1-api-php-only Scenario: Arbitrary key/value search near given coordinate and named place When sending json search query "[leisure=firepit] ebenholz 47° 9â² 26â³ N 9° 36â² 45â³ E" Then results contain @@ -184,7 +186,6 @@ Feature: Search queries Then result addresses contain | ID | house_number | | 0 | 11 | - | 1 | 11 a | Scenario Outline: Coordinate searches with white spaces When sending json search query "" diff --git a/test/bdd/api/search/simple.feature b/test/bdd/api/search/simple.feature index b9323c5a..11cd4801 100644 --- a/test/bdd/api/search/simple.feature +++ b/test/bdd/api/search/simple.feature @@ 
-146,9 +146,6 @@ Feature: Simple Tests | foo | foo | | FOO | FOO | | __world | __world | - | $me | \$me | - | m1[4] | m1\[4\] | - | d_r[$d] | d_r\[\$d\] | Scenario Outline: Wrapping of illegal jsonp search requests When sending json search query "Tokyo" diff --git a/test/bdd/db/query/normalization.feature b/test/bdd/db/query/normalization.feature index e5a7a592..5e94cd3e 100644 --- a/test/bdd/db/query/normalization.feature +++ b/test/bdd/db/query/normalization.feature @@ -209,8 +209,8 @@ Feature: Import and search of names When importing And sending search query "Main St " Then results contain - | osm | display_name | - | N1 | , Main St | + | ID | osm | display_name | + | 0 | N1 | , Main St | Examples: | nr-list | nr | diff --git a/test/bdd/steps/steps_api_queries.py b/test/bdd/steps/steps_api_queries.py index 550cf531..55bb2084 100644 --- a/test/bdd/steps/steps_api_queries.py +++ b/test/bdd/steps/steps_api_queries.py @@ -265,7 +265,10 @@ def check_page_error(context, fmt): @then(u'result header contains') def check_header_attr(context): + context.execute_steps("Then a HTTP 200 is returned") for line in context.table: + assert line['attr'] in context.response.header, \ + f"Field '{line['attr']}' missing in header. Full header:\n{context.response.header}" value = context.response.header[line['attr']] assert re.fullmatch(line['value'], value) is not None, \ f"Attribute '{line['attr']}': expected: '{line['value']}', got '{value}'" diff --git a/test/python/api/conftest.py b/test/python/api/conftest.py index d8a6dfa0..cfe14e1e 100644 --- a/test/python/api/conftest.py +++ b/test/python/api/conftest.py @@ -12,6 +12,8 @@ import pytest import time import datetime as dt +import sqlalchemy as sa + import nominatim.api as napi from nominatim.db.sql_preprocessor import SQLPreprocessor import nominatim.api.logging as loglib @@ -129,6 +131,34 @@ class APITester: 'geometry': 'SRID=4326;' + geometry}) + def add_country_name(self, country_code, names, partition=0): + self.add_data('country_name', + {'country_code': country_code, + 'name': names, + 'partition': partition}) + + + def add_search_name(self, place_id, **kw): + centroid = kw.get('centroid', (23.0, 34.0)) + self.add_data('search_name', + {'place_id': place_id, + 'importance': kw.get('importance', 0.00001), + 'search_rank': kw.get('search_rank', 30), + 'address_rank': kw.get('address_rank', 30), + 'name_vector': kw.get('names', []), + 'nameaddress_vector': kw.get('address', []), + 'country_code': kw.get('country_code', 'xx'), + 'centroid': 'SRID=4326;POINT(%f %f)' % centroid}) + + + def add_class_type_table(self, cls, typ): + self.async_to_sync( + self.exec_async(sa.text(f"""CREATE TABLE place_classtype_{cls}_{typ} + AS (SELECT place_id, centroid FROM placex + WHERE class = '{cls}' AND type = '{typ}') + """))) + + async def exec_async(self, sql, *args, **kwargs): async with self.api._async_api.begin() as conn: return await conn.execute(sql, *args, **kwargs) diff --git a/test/python/api/search/test_api_search_query.py b/test/python/api/search/test_api_search_query.py new file mode 100644 index 00000000..f8c9c2dc --- /dev/null +++ b/test/python/api/search/test_api_search_query.py @@ -0,0 +1,101 @@ +# SPDX-License-Identifier: GPL-3.0-or-later +# +# This file is part of Nominatim. (https://nominatim.org) +# +# Copyright (C) 2023 by the Nominatim developer community. +# For a full list of authors see the git log. +""" +Tests for tokenized query data structures. 
+""" +import pytest + +from nominatim.api.search import query + +class MyToken(query.Token): + + def get_category(self): + return 'this', 'that' + + +def mktoken(tid: int): + return MyToken(3.0, tid, 1, 'foo', True) + + +@pytest.mark.parametrize('ptype,ttype', [('NONE', 'WORD'), + ('AMENITY', 'QUALIFIER'), + ('STREET', 'PARTIAL'), + ('CITY', 'WORD'), + ('COUNTRY', 'COUNTRY'), + ('POSTCODE', 'POSTCODE')]) +def test_phrase_compatible(ptype, ttype): + assert query.PhraseType[ptype].compatible_with(query.TokenType[ttype]) + + +@pytest.mark.parametrize('ptype', ['COUNTRY', 'POSTCODE']) +def test_phrase_incompatible(ptype): + assert not query.PhraseType[ptype].compatible_with(query.TokenType.PARTIAL) + + +def test_query_node_empty(): + qn = query.QueryNode(query.BreakType.PHRASE, query.PhraseType.NONE) + + assert not qn.has_tokens(3, query.TokenType.PARTIAL) + assert qn.get_tokens(3, query.TokenType.WORD) is None + + +def test_query_node_with_content(): + qn = query.QueryNode(query.BreakType.PHRASE, query.PhraseType.NONE) + qn.starting.append(query.TokenList(2, query.TokenType.PARTIAL, [mktoken(100), mktoken(101)])) + qn.starting.append(query.TokenList(2, query.TokenType.WORD, [mktoken(1000)])) + + assert not qn.has_tokens(3, query.TokenType.PARTIAL) + assert not qn.has_tokens(2, query.TokenType.COUNTRY) + assert qn.has_tokens(2, query.TokenType.PARTIAL) + assert qn.has_tokens(2, query.TokenType.WORD) + + assert qn.get_tokens(3, query.TokenType.PARTIAL) is None + assert qn.get_tokens(2, query.TokenType.COUNTRY) is None + assert len(qn.get_tokens(2, query.TokenType.PARTIAL)) == 2 + assert len(qn.get_tokens(2, query.TokenType.WORD)) == 1 + + +def test_query_struct_empty(): + q = query.QueryStruct([]) + + assert q.num_token_slots() == 0 + + +def test_query_struct_with_tokens(): + q = query.QueryStruct([query.Phrase(query.PhraseType.NONE, 'foo bar')]) + q.add_node(query.BreakType.WORD, query.PhraseType.NONE) + q.add_node(query.BreakType.END, query.PhraseType.NONE) + + assert q.num_token_slots() == 2 + + q.add_token(query.TokenRange(0, 1), query.TokenType.PARTIAL, mktoken(1)) + q.add_token(query.TokenRange(1, 2), query.TokenType.PARTIAL, mktoken(2)) + q.add_token(query.TokenRange(1, 2), query.TokenType.WORD, mktoken(99)) + q.add_token(query.TokenRange(1, 2), query.TokenType.WORD, mktoken(98)) + + assert q.get_tokens(query.TokenRange(0, 2), query.TokenType.WORD) == [] + assert len(q.get_tokens(query.TokenRange(1, 2), query.TokenType.WORD)) == 2 + + partials = q.get_partials_list(query.TokenRange(0, 2)) + + assert len(partials) == 2 + assert [t.token for t in partials] == [1, 2] + + assert q.find_lookup_word_by_id(4) == 'None' + assert q.find_lookup_word_by_id(99) == '[W]foo' + + +def test_query_struct_incompatible_token(): + q = query.QueryStruct([query.Phrase(query.PhraseType.COUNTRY, 'foo bar')]) + q.add_node(query.BreakType.WORD, query.PhraseType.COUNTRY) + q.add_node(query.BreakType.END, query.PhraseType.NONE) + + q.add_token(query.TokenRange(0, 1), query.TokenType.PARTIAL, mktoken(1)) + q.add_token(query.TokenRange(1, 2), query.TokenType.COUNTRY, mktoken(100)) + + assert q.get_tokens(query.TokenRange(0, 1), query.TokenType.PARTIAL) == [] + assert len(q.get_tokens(query.TokenRange(1, 2), query.TokenType.COUNTRY)) == 1 diff --git a/test/python/api/search/test_db_search_builder.py b/test/python/api/search/test_db_search_builder.py new file mode 100644 index 00000000..9631850e --- /dev/null +++ b/test/python/api/search/test_db_search_builder.py @@ -0,0 +1,395 @@ +# SPDX-License-Identifier: 
GPL-3.0-or-later +# +# This file is part of Nominatim. (https://nominatim.org) +# +# Copyright (C) 2023 by the Nominatim developer community. +# For a full list of authors see the git log. +""" +Tests for creating abstract searches from token assignments. +""" +import pytest + +from nominatim.api.search.query import Token, TokenRange, BreakType, PhraseType, TokenType, QueryStruct, Phrase +from nominatim.api.search.db_search_builder import SearchBuilder +from nominatim.api.search.token_assignment import TokenAssignment +from nominatim.api.types import SearchDetails +import nominatim.api.search.db_searches as dbs + +class MyToken(Token): + def get_category(self): + return 'this', 'that' + + +def make_query(*args): + q = None + + for tlist in args: + if q is None: + q = QueryStruct([Phrase(PhraseType.NONE, '')]) + else: + q.add_node(BreakType.WORD, PhraseType.NONE) + + start = len(q.nodes) - 1 + for end, ttype, tinfo in tlist: + for tid, word in tinfo: + q.add_token(TokenRange(start, end), ttype, + MyToken(0.5 if ttype == TokenType.PARTIAL else 0.0, tid, 1, word, True)) + + q.add_node(BreakType.END, PhraseType.NONE) + + return q + + +def test_country_search(): + q = make_query([(1, TokenType.COUNTRY, [(2, 'de'), (3, 'en')])]) + builder = SearchBuilder(q, SearchDetails()) + + searches = list(builder.build(TokenAssignment(country=TokenRange(0, 1)))) + + assert len(searches) == 1 + + search = searches[0] + + assert isinstance(search, dbs.CountrySearch) + assert set(search.countries.values) == {'de', 'en'} + + +def test_country_search_with_country_restriction(): + q = make_query([(1, TokenType.COUNTRY, [(2, 'de'), (3, 'en')])]) + builder = SearchBuilder(q, SearchDetails.from_kwargs({'countries': 'en,fr'})) + + searches = list(builder.build(TokenAssignment(country=TokenRange(0, 1)))) + + assert len(searches) == 1 + + search = searches[0] + + assert isinstance(search, dbs.CountrySearch) + assert set(search.countries.values) == {'en'} + + +def test_country_search_with_conflicting_country_restriction(): + q = make_query([(1, TokenType.COUNTRY, [(2, 'de'), (3, 'en')])]) + builder = SearchBuilder(q, SearchDetails.from_kwargs({'countries': 'fr'})) + + searches = list(builder.build(TokenAssignment(country=TokenRange(0, 1)))) + + assert len(searches) == 0 + + +def test_postcode_search_simple(): + q = make_query([(1, TokenType.POSTCODE, [(34, '2367')])]) + builder = SearchBuilder(q, SearchDetails()) + + searches = list(builder.build(TokenAssignment(postcode=TokenRange(0, 1)))) + + assert len(searches) == 1 + search = searches[0] + + assert isinstance(search, dbs.PostcodeSearch) + assert search.postcodes.values == ['2367'] + assert not search.countries.values + assert not search.lookups + assert not search.rankings + + +def test_postcode_with_country(): + q = make_query([(1, TokenType.POSTCODE, [(34, '2367')])], + [(2, TokenType.COUNTRY, [(1, 'xx')])]) + builder = SearchBuilder(q, SearchDetails()) + + searches = list(builder.build(TokenAssignment(postcode=TokenRange(0, 1), + country=TokenRange(1, 2)))) + + assert len(searches) == 1 + search = searches[0] + + assert isinstance(search, dbs.PostcodeSearch) + assert search.postcodes.values == ['2367'] + assert search.countries.values == ['xx'] + assert not search.lookups + assert not search.rankings + + +def test_postcode_with_address(): + q = make_query([(1, TokenType.POSTCODE, [(34, '2367')])], + [(2, TokenType.PARTIAL, [(100, 'word')])]) + builder = SearchBuilder(q, SearchDetails()) + + searches = list(builder.build(TokenAssignment(postcode=TokenRange(0, 
1), + address=[TokenRange(1, 2)]))) + + assert len(searches) == 1 + search = searches[0] + + assert isinstance(search, dbs.PostcodeSearch) + assert search.postcodes.values == ['2367'] + assert not search.countries + assert search.lookups + assert not search.rankings + + +def test_postcode_with_address_with_full_word(): + q = make_query([(1, TokenType.POSTCODE, [(34, '2367')])], + [(2, TokenType.PARTIAL, [(100, 'word')]), + (2, TokenType.WORD, [(1, 'full')])]) + builder = SearchBuilder(q, SearchDetails()) + + searches = list(builder.build(TokenAssignment(postcode=TokenRange(0, 1), + address=[TokenRange(1, 2)]))) + + assert len(searches) == 1 + search = searches[0] + + assert isinstance(search, dbs.PostcodeSearch) + assert search.postcodes.values == ['2367'] + assert not search.countries + assert search.lookups + assert len(search.rankings) == 1 + + +@pytest.mark.parametrize('kwargs', [{'viewbox': '0,0,1,1', 'bounded_viewbox': True}, + {'near': '10,10'}]) +def test_category_only(kwargs): + q = make_query([(1, TokenType.CATEGORY, [(2, 'foo')])]) + builder = SearchBuilder(q, SearchDetails.from_kwargs(kwargs)) + + searches = list(builder.build(TokenAssignment(category=TokenRange(0, 1)))) + + assert len(searches) == 1 + + search = searches[0] + + assert isinstance(search, dbs.PoiSearch) + assert search.categories.values == [('this', 'that')] + + +@pytest.mark.parametrize('kwargs', [{'viewbox': '0,0,1,1'}, + {}]) +def test_category_skipped(kwargs): + q = make_query([(1, TokenType.CATEGORY, [(2, 'foo')])]) + builder = SearchBuilder(q, SearchDetails.from_kwargs(kwargs)) + + searches = list(builder.build(TokenAssignment(category=TokenRange(0, 1)))) + + assert len(searches) == 0 + + +def test_name_only_search(): + q = make_query([(1, TokenType.PARTIAL, [(1, 'a')]), + (1, TokenType.WORD, [(100, 'a')])]) + builder = SearchBuilder(q, SearchDetails()) + + searches = list(builder.build(TokenAssignment(name=TokenRange(0, 1)))) + + assert len(searches) == 1 + search = searches[0] + + assert isinstance(search, dbs.PlaceSearch) + assert not search.postcodes.values + assert not search.countries.values + assert not search.housenumbers.values + assert not search.qualifiers.values + assert len(search.lookups) == 1 + assert len(search.rankings) == 1 + + +def test_name_with_qualifier(): + q = make_query([(1, TokenType.PARTIAL, [(1, 'a')]), + (1, TokenType.WORD, [(100, 'a')])], + [(2, TokenType.QUALIFIER, [(55, 'hotel')])]) + builder = SearchBuilder(q, SearchDetails()) + + searches = list(builder.build(TokenAssignment(name=TokenRange(0, 1), + qualifier=TokenRange(1, 2)))) + + assert len(searches) == 1 + search = searches[0] + + assert isinstance(search, dbs.PlaceSearch) + assert not search.postcodes.values + assert not search.countries.values + assert not search.housenumbers.values + assert search.qualifiers.values == [('this', 'that')] + assert len(search.lookups) == 1 + assert len(search.rankings) == 1 + + +def test_name_with_housenumber_search(): + q = make_query([(1, TokenType.PARTIAL, [(1, 'a')]), + (1, TokenType.WORD, [(100, 'a')])], + [(2, TokenType.HOUSENUMBER, [(66, '66')])]) + builder = SearchBuilder(q, SearchDetails()) + + searches = list(builder.build(TokenAssignment(name=TokenRange(0, 1), + housenumber=TokenRange(1, 2)))) + + assert len(searches) == 1 + search = searches[0] + + assert isinstance(search, dbs.PlaceSearch) + assert not search.postcodes.values + assert not search.countries.values + assert search.housenumbers.values == ['66'] + assert len(search.lookups) == 1 + assert len(search.rankings) == 
1 + + +def test_name_and_address(): + q = make_query([(1, TokenType.PARTIAL, [(1, 'a')]), + (1, TokenType.WORD, [(100, 'a')])], + [(2, TokenType.PARTIAL, [(2, 'b')]), + (2, TokenType.WORD, [(101, 'b')])], + [(3, TokenType.PARTIAL, [(3, 'c')]), + (3, TokenType.WORD, [(102, 'c')])] + ) + builder = SearchBuilder(q, SearchDetails()) + + searches = list(builder.build(TokenAssignment(name=TokenRange(0, 1), + address=[TokenRange(1, 2), + TokenRange(2, 3)]))) + + assert len(searches) == 1 + search = searches[0] + + assert isinstance(search, dbs.PlaceSearch) + assert not search.postcodes.values + assert not search.countries.values + assert not search.housenumbers.values + assert len(search.lookups) == 2 + assert len(search.rankings) == 3 + + +def test_name_and_complex_address(): + q = make_query([(1, TokenType.PARTIAL, [(1, 'a')]), + (1, TokenType.WORD, [(100, 'a')])], + [(2, TokenType.PARTIAL, [(2, 'b')]), + (3, TokenType.WORD, [(101, 'bc')])], + [(3, TokenType.PARTIAL, [(3, 'c')])], + [(4, TokenType.PARTIAL, [(4, 'd')]), + (4, TokenType.WORD, [(103, 'd')])] + ) + builder = SearchBuilder(q, SearchDetails()) + + searches = list(builder.build(TokenAssignment(name=TokenRange(0, 1), + address=[TokenRange(1, 2), + TokenRange(2, 4)]))) + + assert len(searches) == 1 + search = searches[0] + + assert isinstance(search, dbs.PlaceSearch) + assert not search.postcodes.values + assert not search.countries.values + assert not search.housenumbers.values + assert len(search.lookups) == 2 + assert len(search.rankings) == 2 + + +def test_name_only_near_search(): + q = make_query([(1, TokenType.CATEGORY, [(88, 'g')])], + [(2, TokenType.PARTIAL, [(1, 'a')]), + (2, TokenType.WORD, [(100, 'a')])]) + builder = SearchBuilder(q, SearchDetails()) + + searches = list(builder.build(TokenAssignment(name=TokenRange(1, 2), + category=TokenRange(0, 1)))) + + assert len(searches) == 1 + search = searches[0] + + assert isinstance(search, dbs.NearSearch) + assert isinstance(search.search, dbs.PlaceSearch) + + +def test_name_only_search_with_category(): + q = make_query([(1, TokenType.PARTIAL, [(1, 'a')]), + (1, TokenType.WORD, [(100, 'a')])]) + builder = SearchBuilder(q, SearchDetails.from_kwargs({'categories': [('foo', 'bar')]})) + + searches = list(builder.build(TokenAssignment(name=TokenRange(0, 1)))) + + assert len(searches) == 1 + search = searches[0] + + assert isinstance(search, dbs.NearSearch) + assert isinstance(search.search, dbs.PlaceSearch) + + +def test_name_only_search_with_countries(): + q = make_query([(1, TokenType.PARTIAL, [(1, 'a')]), + (1, TokenType.WORD, [(100, 'a')])]) + builder = SearchBuilder(q, SearchDetails.from_kwargs({'countries': 'de,en'})) + + searches = list(builder.build(TokenAssignment(name=TokenRange(0, 1)))) + + assert len(searches) == 1 + search = searches[0] + + assert isinstance(search, dbs.PlaceSearch) + assert not search.postcodes.values + assert set(search.countries.values) == {'de', 'en'} + assert not search.housenumbers.values + + +def make_counted_searches(name_part, name_full, address_part, address_full): + q = QueryStruct([Phrase(PhraseType.NONE, '')]) + for i in range(2): + q.add_node(BreakType.WORD, PhraseType.NONE) + q.add_node(BreakType.END, PhraseType.NONE) + + q.add_token(TokenRange(0, 1), TokenType.PARTIAL, + MyToken(0.5, 1, name_part, 'name_part', True)) + q.add_token(TokenRange(0, 1), TokenType.WORD, + MyToken(0, 101, name_full, 'name_full', True)) + q.add_token(TokenRange(1, 2), TokenType.PARTIAL, + MyToken(0.5, 2, address_part, 'address_part', True)) + 
q.add_token(TokenRange(1, 2), TokenType.WORD, + MyToken(0, 102, address_full, 'address_full', True)) + + builder = SearchBuilder(q, SearchDetails()) + + return list(builder.build(TokenAssignment(name=TokenRange(0, 1), + address=[TokenRange(1, 2)]))) + + +def test_infrequent_partials_in_name(): + searches = make_counted_searches(1, 1, 1, 1) + + assert len(searches) == 1 + search = searches[0] + + assert isinstance(search, dbs.PlaceSearch) + assert len(search.lookups) == 2 + assert len(search.rankings) == 2 + + assert set((l.column, l.lookup_type) for l in search.lookups) == \ + {('name_vector', 'lookup_all'), ('nameaddress_vector', 'restrict')} + + +def test_frequent_partials_in_name_but_not_in_address(): + searches = make_counted_searches(10000, 1, 1, 1) + + assert len(searches) == 1 + search = searches[0] + + assert isinstance(search, dbs.PlaceSearch) + assert len(search.lookups) == 2 + assert len(search.rankings) == 2 + + assert set((l.column, l.lookup_type) for l in search.lookups) == \ + {('nameaddress_vector', 'lookup_all'), ('name_vector', 'restrict')} + + +def test_frequent_partials_in_name_and_address(): + searches = make_counted_searches(10000, 1, 10000, 1) + + assert len(searches) == 2 + + assert all(isinstance(s, dbs.PlaceSearch) for s in searches) + searches.sort(key=lambda s: s.penalty) + + assert set((l.column, l.lookup_type) for l in searches[0].lookups) == \ + {('name_vector', 'lookup_any'), ('nameaddress_vector', 'restrict')} + assert set((l.column, l.lookup_type) for l in searches[1].lookups) == \ + {('nameaddress_vector', 'lookup_all'), ('name_vector', 'lookup_all')} diff --git a/test/python/api/search/test_icu_query_analyzer.py b/test/python/api/search/test_icu_query_analyzer.py new file mode 100644 index 00000000..78cd2c4d --- /dev/null +++ b/test/python/api/search/test_icu_query_analyzer.py @@ -0,0 +1,186 @@ +# SPDX-License-Identifier: GPL-3.0-or-later +# +# This file is part of Nominatim. (https://nominatim.org) +# +# Copyright (C) 2023 by the Nominatim developer community. +# For a full list of authors see the git log. +""" +Tests for query analyzer for ICU tokenizer. +""" +from pathlib import Path + +import pytest +import pytest_asyncio + +from nominatim.api import NominatimAPIAsync +from nominatim.api.search.query import Phrase, PhraseType, TokenType, BreakType +import nominatim.api.search.icu_tokenizer as tok +from nominatim.api.logging import set_log_output, get_and_disable + +async def add_word(conn, word_id, word_token, wtype, word, info = None): + t = conn.t.meta.tables['word'] + await conn.execute(t.insert(), {'word_id': word_id, + 'word_token': word_token, + 'type': wtype, + 'word': word, + 'info': info}) + + +def make_phrase(query): + return [Phrase(PhraseType.NONE, s) for s in query.split(',')] + +@pytest_asyncio.fixture +async def conn(table_factory): + """ Create an asynchronous SQLAlchemy engine for the test DB. 
+ """ + table_factory('nominatim_properties', + definition='property TEXT, value TEXT', + content=(('tokenizer_import_normalisation', ':: lower();'), + ('tokenizer_import_transliteration', "'1' > '/1/'; 'ä' > 'ä '"))) + table_factory('word', + definition='word_id INT, word_token TEXT, type TEXT, word TEXT, info JSONB') + + api = NominatimAPIAsync(Path('/invalid'), {}) + async with api.begin() as conn: + yield conn + await api.close() + + +@pytest.mark.asyncio +async def test_empty_phrase(conn): + ana = await tok.create_query_analyzer(conn) + + query = await ana.analyze_query([]) + + assert len(query.source) == 0 + assert query.num_token_slots() == 0 + + +@pytest.mark.asyncio +async def test_single_phrase_with_unknown_terms(conn): + ana = await tok.create_query_analyzer(conn) + + await add_word(conn, 1, 'foo', 'w', 'FOO') + + query = await ana.analyze_query(make_phrase('foo BAR')) + + assert len(query.source) == 1 + assert query.source[0].ptype == PhraseType.NONE + assert query.source[0].text == 'foo bar' + + assert query.num_token_slots() == 2 + assert len(query.nodes[0].starting) == 1 + assert not query.nodes[1].starting + + +@pytest.mark.asyncio +async def test_multiple_phrases(conn): + ana = await tok.create_query_analyzer(conn) + + await add_word(conn, 1, 'one', 'w', 'one') + await add_word(conn, 2, 'two', 'w', 'two') + await add_word(conn, 100, 'one two', 'W', 'one two') + await add_word(conn, 3, 'three', 'w', 'three') + + query = await ana.analyze_query(make_phrase('one two,three')) + + assert len(query.source) == 2 + + +@pytest.mark.asyncio +async def test_splitting_in_transliteration(conn): + ana = await tok.create_query_analyzer(conn) + + await add_word(conn, 1, 'mä', 'W', 'ma') + await add_word(conn, 2, 'fo', 'W', 'fo') + + query = await ana.analyze_query(make_phrase('mäfo')) + + assert query.num_token_slots() == 2 + assert query.nodes[0].starting + assert query.nodes[1].starting + assert query.nodes[1].btype == BreakType.TOKEN + + +@pytest.mark.asyncio +@pytest.mark.parametrize('term,order', [('23456', ['POSTCODE', 'HOUSENUMBER', 'WORD', 'PARTIAL']), + ('3', ['HOUSENUMBER', 'POSTCODE', 'WORD', 'PARTIAL']) + ]) +async def test_penalty_postcodes_and_housenumbers(conn, term, order): + ana = await tok.create_query_analyzer(conn) + + await add_word(conn, 1, term, 'P', None) + await add_word(conn, 2, term, 'H', term) + await add_word(conn, 3, term, 'w', term) + await add_word(conn, 4, term, 'W', term) + + query = await ana.analyze_query(make_phrase(term)) + + assert query.num_token_slots() == 1 + + torder = [(tl.tokens[0].penalty, tl.ttype) for tl in query.nodes[0].starting] + torder.sort() + + assert [t[1] for t in torder] == [TokenType[o] for o in order] + +@pytest.mark.asyncio +async def test_category_words_only_at_beginning(conn): + ana = await tok.create_query_analyzer(conn) + + await add_word(conn, 1, 'foo', 'S', 'FOO', {'op': 'in'}) + await add_word(conn, 2, 'bar', 'w', 'BAR') + + query = await ana.analyze_query(make_phrase('foo BAR foo')) + + assert query.num_token_slots() == 3 + assert len(query.nodes[0].starting) == 1 + assert query.nodes[0].starting[0].ttype == TokenType.CATEGORY + assert not query.nodes[2].starting + + +@pytest.mark.asyncio +async def test_qualifier_words(conn): + ana = await tok.create_query_analyzer(conn) + + await add_word(conn, 1, 'foo', 'S', None, {'op': '-'}) + await add_word(conn, 2, 'bar', 'w', None) + + query = await ana.analyze_query(make_phrase('foo BAR foo BAR foo')) + + assert query.num_token_slots() == 5 + assert set(t.ttype for t in 
query.nodes[0].starting) == {TokenType.CATEGORY, TokenType.QUALIFIER} + assert set(t.ttype for t in query.nodes[2].starting) == {TokenType.QUALIFIER} + assert set(t.ttype for t in query.nodes[4].starting) == {TokenType.CATEGORY, TokenType.QUALIFIER} + + +@pytest.mark.asyncio +async def test_add_unknown_housenumbers(conn): + ana = await tok.create_query_analyzer(conn) + + await add_word(conn, 1, '23', 'H', '23') + + query = await ana.analyze_query(make_phrase('466 23 99834 34a')) + + assert query.num_token_slots() == 4 + assert query.nodes[0].starting[0].ttype == TokenType.HOUSENUMBER + assert len(query.nodes[0].starting[0].tokens) == 1 + assert query.nodes[0].starting[0].tokens[0].token == 0 + assert query.nodes[1].starting[0].ttype == TokenType.HOUSENUMBER + assert len(query.nodes[1].starting[0].tokens) == 1 + assert query.nodes[1].starting[0].tokens[0].token == 1 + assert not query.nodes[2].starting + assert not query.nodes[3].starting + + +@pytest.mark.asyncio +@pytest.mark.parametrize('logtype', ['text', 'html']) +async def test_log_output(conn, logtype): + + ana = await tok.create_query_analyzer(conn) + + await add_word(conn, 1, 'foo', 'w', 'FOO') + + set_log_output(logtype) + await ana.analyze_query(make_phrase('foo')) + + assert get_and_disable() diff --git a/test/python/api/search/test_legacy_query_analyzer.py b/test/python/api/search/test_legacy_query_analyzer.py new file mode 100644 index 00000000..c2115853 --- /dev/null +++ b/test/python/api/search/test_legacy_query_analyzer.py @@ -0,0 +1,245 @@ +# SPDX-License-Identifier: GPL-3.0-or-later +# +# This file is part of Nominatim. (https://nominatim.org) +# +# Copyright (C) 2023 by the Nominatim developer community. +# For a full list of authors see the git log. +""" +Tests for query analyzer for legacy tokenizer. +""" +from pathlib import Path + +import pytest +import pytest_asyncio + +from nominatim.api import NominatimAPIAsync +from nominatim.api.search.query import Phrase, PhraseType, TokenType, BreakType +import nominatim.api.search.legacy_tokenizer as tok +from nominatim.api.logging import set_log_output, get_and_disable + + +async def add_word(conn, word_id, word_token, word, count): + t = conn.t.meta.tables['word'] + await conn.execute(t.insert(), {'word_id': word_id, + 'word_token': word_token, + 'search_name_count': count, + 'word': word}) + + +async def add_housenumber(conn, word_id, hnr): + t = conn.t.meta.tables['word'] + await conn.execute(t.insert(), {'word_id': word_id, + 'word_token': ' ' + hnr, + 'word': hnr, + 'class': 'place', + 'type': 'house'}) + + +async def add_postcode(conn, word_id, postcode): + t = conn.t.meta.tables['word'] + await conn.execute(t.insert(), {'word_id': word_id, + 'word_token': ' ' + postcode, + 'word': postcode, + 'class': 'place', + 'type': 'postcode'}) + + +async def add_special_term(conn, word_id, word_token, cls, typ, op): + t = conn.t.meta.tables['word'] + await conn.execute(t.insert(), {'word_id': word_id, + 'word_token': word_token, + 'word': word_token, + 'class': cls, + 'type': typ, + 'operator': op}) + + +def make_phrase(query): + return [Phrase(PhraseType.NONE, s) for s in query.split(',')] + + +@pytest_asyncio.fixture +async def conn(table_factory, temp_db_cursor): + """ Create an asynchronous SQLAlchemy engine for the test DB. 
+ """ + table_factory('nominatim_properties', + definition='property TEXT, value TEXT', + content=(('tokenizer_maxwordfreq', '10000'), )) + table_factory('word', + definition="""word_id INT, word_token TEXT, word TEXT, + class TEXT, type TEXT, country_code TEXT, + search_name_count INT, operator TEXT + """) + + temp_db_cursor.execute("""CREATE OR REPLACE FUNCTION make_standard_name(name TEXT) + RETURNS TEXT AS $$ SELECT lower(name); $$ LANGUAGE SQL;""") + + api = NominatimAPIAsync(Path('/invalid'), {}) + async with api.begin() as conn: + yield conn + await api.close() + + +@pytest.mark.asyncio +async def test_empty_phrase(conn): + ana = await tok.create_query_analyzer(conn) + + query = await ana.analyze_query([]) + + assert len(query.source) == 0 + assert query.num_token_slots() == 0 + + +@pytest.mark.asyncio +async def test_single_phrase_with_unknown_terms(conn): + ana = await tok.create_query_analyzer(conn) + + await add_word(conn, 1, 'foo', 'FOO', 3) + + query = await ana.analyze_query(make_phrase('foo BAR')) + + assert len(query.source) == 1 + assert query.source[0].ptype == PhraseType.NONE + assert query.source[0].text == 'foo bar' + + assert query.num_token_slots() == 2 + assert len(query.nodes[0].starting) == 1 + assert not query.nodes[1].starting + + +@pytest.mark.asyncio +async def test_multiple_phrases(conn): + ana = await tok.create_query_analyzer(conn) + + await add_word(conn, 1, 'one', 'one', 13) + await add_word(conn, 2, 'two', 'two', 45) + await add_word(conn, 100, 'one two', 'one two', 3) + await add_word(conn, 3, 'three', 'three', 4584) + + query = await ana.analyze_query(make_phrase('one two,three')) + + assert len(query.source) == 2 + + +@pytest.mark.asyncio +async def test_housenumber_token(conn): + ana = await tok.create_query_analyzer(conn) + + await add_housenumber(conn, 556, '45 a') + + query = await ana.analyze_query(make_phrase('45 A')) + + assert query.num_token_slots() == 2 + assert len(query.nodes[0].starting) == 2 + + query.nodes[0].starting.sort(key=lambda tl: tl.end) + + hn1 = query.nodes[0].starting[0] + assert hn1.ttype == TokenType.HOUSENUMBER + assert hn1.end == 1 + assert hn1.tokens[0].token == 0 + + hn2 = query.nodes[0].starting[1] + assert hn2.ttype == TokenType.HOUSENUMBER + assert hn2.end == 2 + assert hn2.tokens[0].token == 556 + + +@pytest.mark.asyncio +async def test_postcode_token(conn): + ana = await tok.create_query_analyzer(conn) + + await add_postcode(conn, 34, '45ax') + + query = await ana.analyze_query(make_phrase('45AX')) + + assert query.num_token_slots() == 1 + assert [tl.ttype for tl in query.nodes[0].starting] == [TokenType.POSTCODE] + + +@pytest.mark.asyncio +async def test_partial_tokens(conn): + ana = await tok.create_query_analyzer(conn) + + await add_word(conn, 1, ' foo', 'foo', 99) + await add_word(conn, 1, 'foo', 'FOO', 99) + await add_word(conn, 1, 'bar', 'FOO', 990000) + + query = await ana.analyze_query(make_phrase('foo bar')) + + assert query.num_token_slots() == 2 + + first = query.nodes[0].starting + first.sort(key=lambda tl: tl.tokens[0].penalty) + assert [tl.ttype for tl in first] == [TokenType.WORD, TokenType.PARTIAL] + assert all(tl.tokens[0].lookup_word == 'foo' for tl in first) + + second = query.nodes[1].starting + assert [tl.ttype for tl in second] == [TokenType.PARTIAL] + assert not second[0].tokens[0].is_indexed + + +@pytest.mark.asyncio +@pytest.mark.parametrize('term,order', [('23456', ['POSTCODE', 'HOUSENUMBER', 'WORD', 'PARTIAL']), + ('3', ['HOUSENUMBER', 'POSTCODE', 'WORD', 'PARTIAL']) + ]) +async def 
test_penalty_postcodes_and_housenumbers(conn, term, order): + ana = await tok.create_query_analyzer(conn) + + await add_postcode(conn, 1, term) + await add_housenumber(conn, 2, term) + await add_word(conn, 3, term, term, 5) + await add_word(conn, 4, ' ' + term, term, 1) + + query = await ana.analyze_query(make_phrase(term)) + + assert query.num_token_slots() == 1 + + torder = [(tl.tokens[0].penalty, tl.ttype) for tl in query.nodes[0].starting] + print(query.nodes[0].starting) + torder.sort() + + assert [t[1] for t in torder] == [TokenType[o] for o in order] + + +@pytest.mark.asyncio +async def test_category_words_only_at_beginning(conn): + ana = await tok.create_query_analyzer(conn) + + await add_special_term(conn, 1, 'foo', 'amenity', 'restaurant', 'in') + await add_word(conn, 2, ' bar', 'BAR', 1) + + query = await ana.analyze_query(make_phrase('foo BAR foo')) + + assert query.num_token_slots() == 3 + assert len(query.nodes[0].starting) == 1 + assert query.nodes[0].starting[0].ttype == TokenType.CATEGORY + assert not query.nodes[2].starting + + +@pytest.mark.asyncio +async def test_qualifier_words(conn): + ana = await tok.create_query_analyzer(conn) + + await add_special_term(conn, 1, 'foo', 'amenity', 'restaurant', '-') + await add_word(conn, 2, ' bar', 'w', None) + + query = await ana.analyze_query(make_phrase('foo BAR foo BAR foo')) + + assert query.num_token_slots() == 5 + assert set(t.ttype for t in query.nodes[0].starting) == {TokenType.CATEGORY, TokenType.QUALIFIER} + assert set(t.ttype for t in query.nodes[2].starting) == {TokenType.QUALIFIER} + assert set(t.ttype for t in query.nodes[4].starting) == {TokenType.CATEGORY, TokenType.QUALIFIER} + + +@pytest.mark.asyncio +@pytest.mark.parametrize('logtype', ['text', 'html']) +async def test_log_output(conn, logtype): + ana = await tok.create_query_analyzer(conn) + + await add_word(conn, 1, 'foo', 'FOO', 99) + + set_log_output(logtype) + await ana.analyze_query(make_phrase('foo')) + + assert get_and_disable() diff --git a/test/python/api/search/test_query_analyzer_factory.py b/test/python/api/search/test_query_analyzer_factory.py new file mode 100644 index 00000000..2d113e3e --- /dev/null +++ b/test/python/api/search/test_query_analyzer_factory.py @@ -0,0 +1,57 @@ +# SPDX-License-Identifier: GPL-3.0-or-later +# +# This file is part of Nominatim. (https://nominatim.org) +# +# Copyright (C) 2023 by the Nominatim developer community. +# For a full list of authors see the git log. +""" +Tests for query analyzer creation. 
+""" +from pathlib import Path + +import pytest + +from nominatim.api import NominatimAPIAsync +from nominatim.api.search.query_analyzer_factory import make_query_analyzer +from nominatim.api.search.icu_tokenizer import ICUQueryAnalyzer + +@pytest.mark.asyncio +async def test_import_icu_tokenizer(table_factory): + table_factory('nominatim_properties', + definition='property TEXT, value TEXT', + content=(('tokenizer', 'icu'), + ('tokenizer_import_normalisation', ':: lower();'), + ('tokenizer_import_transliteration', "'1' > '/1/'; 'ä' > 'ä '"))) + + api = NominatimAPIAsync(Path('/invalid'), {}) + async with api.begin() as conn: + ana = await make_query_analyzer(conn) + + assert isinstance(ana, ICUQueryAnalyzer) + await api.close() + + +@pytest.mark.asyncio +async def test_import_missing_property(table_factory): + api = NominatimAPIAsync(Path('/invalid'), {}) + table_factory('nominatim_properties', + definition='property TEXT, value TEXT') + + async with api.begin() as conn: + with pytest.raises(ValueError, match='Property.*not found'): + await make_query_analyzer(conn) + await api.close() + + +@pytest.mark.asyncio +async def test_import_missing_module(table_factory): + api = NominatimAPIAsync(Path('/invalid'), {}) + table_factory('nominatim_properties', + definition='property TEXT, value TEXT', + content=(('tokenizer', 'missing'),)) + + async with api.begin() as conn: + with pytest.raises(RuntimeError, match='Tokenizer not found'): + await make_query_analyzer(conn) + await api.close() + diff --git a/test/python/api/search/test_search_country.py b/test/python/api/search/test_search_country.py new file mode 100644 index 00000000..bb0abc39 --- /dev/null +++ b/test/python/api/search/test_search_country.py @@ -0,0 +1,61 @@ +# SPDX-License-Identifier: GPL-3.0-or-later +# +# This file is part of Nominatim. (https://nominatim.org) +# +# Copyright (C) 2023 by the Nominatim developer community. +# For a full list of authors see the git log. +""" +Tests for running the country searcher. 
+""" +import pytest + +import nominatim.api as napi +from nominatim.api.types import SearchDetails +from nominatim.api.search.db_searches import CountrySearch +from nominatim.api.search.db_search_fields import WeightedStrings + + +def run_search(apiobj, global_penalty, ccodes, + country_penalties=None, details=SearchDetails()): + if country_penalties is None: + country_penalties = [0.0] * len(ccodes) + + class MySearchData: + penalty = global_penalty + countries = WeightedStrings(ccodes, country_penalties) + + search = CountrySearch(MySearchData()) + + async def run(): + async with apiobj.api._async_api.begin() as conn: + return await search.lookup(conn, details) + + return apiobj.async_to_sync(run()) + + +def test_find_from_placex(apiobj): + apiobj.add_placex(place_id=55, class_='boundary', type='administrative', + rank_search=4, rank_address=4, + name={'name': 'Lolaland'}, + country_code='yw', + centroid=(10, 10), + geometry='POLYGON((9.5 9.5, 9.5 10.5, 10.5 10.5, 10.5 9.5, 9.5 9.5))') + + results = run_search(apiobj, 0.5, ['de', 'yw'], [0.0, 0.3]) + + assert len(results) == 1 + assert results[0].place_id == 55 + assert results[0].accuracy == 0.8 + +def test_find_from_fallback_countries(apiobj): + apiobj.add_country('ro', 'POLYGON((0 0, 0 1, 1 1, 1 0, 0 0))') + apiobj.add_country_name('ro', {'name': 'România'}) + + results = run_search(apiobj, 0.0, ['ro']) + + assert len(results) == 1 + assert results[0].names == {'name': 'România'} + + +def test_find_none(apiobj): + assert len(run_search(apiobj, 0.0, ['xx'])) == 0 diff --git a/test/python/api/search/test_search_near.py b/test/python/api/search/test_search_near.py new file mode 100644 index 00000000..cfbdadb2 --- /dev/null +++ b/test/python/api/search/test_search_near.py @@ -0,0 +1,102 @@ +# SPDX-License-Identifier: GPL-3.0-or-later +# +# This file is part of Nominatim. (https://nominatim.org) +# +# Copyright (C) 2023 by the Nominatim developer community. +# For a full list of authors see the git log. +""" +Tests for running the near searcher. 
+""" +import pytest + +import nominatim.api as napi +from nominatim.api.types import SearchDetails +from nominatim.api.search.db_searches import NearSearch, PlaceSearch +from nominatim.api.search.db_search_fields import WeightedStrings, WeightedCategories,\ + FieldLookup, FieldRanking, RankedTokens + + +def run_search(apiobj, global_penalty, cat, cat_penalty=None, + details=SearchDetails()): + + class PlaceSearchData: + penalty = 0.0 + postcodes = WeightedStrings([], []) + countries = WeightedStrings([], []) + housenumbers = WeightedStrings([], []) + qualifiers = WeightedStrings([], []) + lookups = [FieldLookup('name_vector', [56], 'lookup_all')] + rankings = [] + + place_search = PlaceSearch(0.0, PlaceSearchData(), 2) + + if cat_penalty is None: + cat_penalty = [0.0] * len(cat) + + near_search = NearSearch(0.1, WeightedCategories(cat, cat_penalty), place_search) + + async def run(): + async with apiobj.api._async_api.begin() as conn: + return await near_search.lookup(conn, details) + + results = apiobj.async_to_sync(run()) + results.sort(key=lambda r: r.accuracy) + + return results + + +def test_no_results_inner_query(apiobj): + assert not run_search(apiobj, 0.4, [('this', 'that')]) + + +class TestNearSearch: + + @pytest.fixture(autouse=True) + def fill_database(self, apiobj): + apiobj.add_placex(place_id=100, country_code='us', + centroid=(5.6, 4.3)) + apiobj.add_search_name(100, names=[56], country_code='us', + centroid=(5.6, 4.3)) + apiobj.add_placex(place_id=101, country_code='mx', + centroid=(-10.3, 56.9)) + apiobj.add_search_name(101, names=[56], country_code='mx', + centroid=(-10.3, 56.9)) + + + def test_near_in_placex(self, apiobj): + apiobj.add_placex(place_id=22, class_='amenity', type='bank', + centroid=(5.6001, 4.2994)) + apiobj.add_placex(place_id=23, class_='amenity', type='bench', + centroid=(5.6001, 4.2994)) + + results = run_search(apiobj, 0.1, [('amenity', 'bank')]) + + assert [r.place_id for r in results] == [22] + + + def test_multiple_types_near_in_placex(self, apiobj): + apiobj.add_placex(place_id=22, class_='amenity', type='bank', + importance=0.002, + centroid=(5.6001, 4.2994)) + apiobj.add_placex(place_id=23, class_='amenity', type='bench', + importance=0.001, + centroid=(5.6001, 4.2994)) + + results = run_search(apiobj, 0.1, [('amenity', 'bank'), + ('amenity', 'bench')]) + + assert [r.place_id for r in results] == [22, 23] + + + def test_near_in_classtype(self, apiobj): + apiobj.add_placex(place_id=22, class_='amenity', type='bank', + centroid=(5.6, 4.34)) + apiobj.add_placex(place_id=23, class_='amenity', type='bench', + centroid=(5.6, 4.34)) + apiobj.add_class_type_table('amenity', 'bank') + apiobj.add_class_type_table('amenity', 'bench') + + results = run_search(apiobj, 0.1, [('amenity', 'bank')]) + + assert [r.place_id for r in results] == [22] + diff --git a/test/python/api/search/test_search_places.py b/test/python/api/search/test_search_places.py new file mode 100644 index 00000000..df369b81 --- /dev/null +++ b/test/python/api/search/test_search_places.py @@ -0,0 +1,385 @@ +# SPDX-License-Identifier: GPL-3.0-or-later +# +# This file is part of Nominatim. (https://nominatim.org) +# +# Copyright (C) 2023 by the Nominatim developer community. +# For a full list of authors see the git log. +""" +Tests for running the generic place searcher. 
+""" +import pytest + +import nominatim.api as napi +from nominatim.api.types import SearchDetails +from nominatim.api.search.db_searches import PlaceSearch +from nominatim.api.search.db_search_fields import WeightedStrings, WeightedCategories,\ + FieldLookup, FieldRanking, RankedTokens + +def run_search(apiobj, global_penalty, lookup, ranking, count=2, + hnrs=[], pcs=[], ccodes=[], quals=[], + details=SearchDetails()): + class MySearchData: + penalty = global_penalty + postcodes = WeightedStrings(pcs, [0.0] * len(pcs)) + countries = WeightedStrings(ccodes, [0.0] * len(ccodes)) + housenumbers = WeightedStrings(hnrs, [0.0] * len(hnrs)) + qualifiers = WeightedCategories(quals, [0.0] * len(quals)) + lookups = lookup + rankings = ranking + + search = PlaceSearch(0.0, MySearchData(), count) + + async def run(): + async with apiobj.api._async_api.begin() as conn: + return await search.lookup(conn, details) + + results = apiobj.async_to_sync(run()) + results.sort(key=lambda r: r.accuracy) + + return results + + +class TestNameOnlySearches: + + @pytest.fixture(autouse=True) + def fill_database(self, apiobj): + apiobj.add_placex(place_id=100, country_code='us', + centroid=(5.6, 4.3)) + apiobj.add_search_name(100, names=[1,2,10,11], country_code='us', + centroid=(5.6, 4.3)) + apiobj.add_placex(place_id=101, country_code='mx', + centroid=(-10.3, 56.9)) + apiobj.add_search_name(101, names=[1,2,20,21], country_code='mx', + centroid=(-10.3, 56.9)) + + + @pytest.mark.parametrize('lookup_type', ['lookup_all', 'restrict']) + @pytest.mark.parametrize('rank,res', [([10], [100, 101]), + ([20], [101, 100])]) + def test_lookup_all_match(self, apiobj, lookup_type, rank, res): + lookup = FieldLookup('name_vector', [1,2], lookup_type) + ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, rank)]) + + results = run_search(apiobj, 0.1, [lookup], [ranking]) + + assert [r.place_id for r in results] == res + + + @pytest.mark.parametrize('lookup_type', ['lookup_all', 'restrict']) + def test_lookup_all_partial_match(self, apiobj, lookup_type): + lookup = FieldLookup('name_vector', [1,20], lookup_type) + ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, [21])]) + + results = run_search(apiobj, 0.1, [lookup], [ranking]) + + assert len(results) == 1 + assert results[0].place_id == 101 + + @pytest.mark.parametrize('rank,res', [([10], [100, 101]), + ([20], [101, 100])]) + def test_lookup_any_match(self, apiobj, rank, res): + lookup = FieldLookup('name_vector', [11,21], 'lookup_any') + ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, rank)]) + + results = run_search(apiobj, 0.1, [lookup], [ranking]) + + assert [r.place_id for r in results] == res + + + def test_lookup_any_partial_match(self, apiobj): + lookup = FieldLookup('name_vector', [20], 'lookup_all') + ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, [21])]) + + results = run_search(apiobj, 0.1, [lookup], [ranking]) + + assert len(results) == 1 + assert results[0].place_id == 101 + + + @pytest.mark.parametrize('cc,res', [('us', 100), ('mx', 101)]) + def test_lookup_restrict_country(self, apiobj, cc, res): + lookup = FieldLookup('name_vector', [1,2], 'lookup_all') + ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, [10])]) + + results = run_search(apiobj, 0.1, [lookup], [ranking], ccodes=[cc]) + + assert [r.place_id for r in results] == [res] + + + def test_lookup_restrict_placeid(self, apiobj): + lookup = FieldLookup('name_vector', [1,2], 'lookup_all') + ranking = FieldRanking('name_vector', 0.9, 
+    def test_lookup_restrict_placeid(self, apiobj):
+        lookup = FieldLookup('name_vector', [1,2], 'lookup_all')
+        ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, [10])])
+
+        results = run_search(apiobj, 0.1, [lookup], [ranking],
+                             details=SearchDetails(excluded=[101]))
+
+        assert [r.place_id for r in results] == [100]
+
+
+    @pytest.mark.parametrize('geom', [napi.GeometryFormat.GEOJSON,
+                                      napi.GeometryFormat.KML,
+                                      napi.GeometryFormat.SVG,
+                                      napi.GeometryFormat.TEXT])
+    def test_return_geometries(self, apiobj, geom):
+        lookup = FieldLookup('name_vector', [20], 'lookup_all')
+        ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, [21])])
+
+        results = run_search(apiobj, 0.1, [lookup], [ranking],
+                             details=SearchDetails(geometry_output=geom))
+
+        assert geom.name.lower() in results[0].geometry
+
+
+    @pytest.mark.parametrize('viewbox', ['5.0,4.0,6.0,5.0', '5.7,4.0,6.0,5.0'])
+    def test_prefer_viewbox(self, apiobj, viewbox):
+        lookup = FieldLookup('name_vector', [1, 2], 'lookup_all')
+        ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, [21])])
+
+        results = run_search(apiobj, 0.1, [lookup], [ranking])
+        assert [r.place_id for r in results] == [101, 100]
+
+        results = run_search(apiobj, 0.1, [lookup], [ranking],
+                             details=SearchDetails.from_kwargs({'viewbox': viewbox}))
+        assert [r.place_id for r in results] == [100, 101]
+
+
+    def test_force_viewbox(self, apiobj):
+        lookup = FieldLookup('name_vector', [1, 2], 'lookup_all')
+
+        details=SearchDetails.from_kwargs({'viewbox': '5.0,4.0,6.0,5.0',
+                                           'bounded_viewbox': True})
+
+        results = run_search(apiobj, 0.1, [lookup], [], details=details)
+        assert [r.place_id for r in results] == [100]
+
+
+    def test_prefer_near(self, apiobj):
+        lookup = FieldLookup('name_vector', [1, 2], 'lookup_all')
+        ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, [21])])
+
+        results = run_search(apiobj, 0.1, [lookup], [ranking])
+        assert [r.place_id for r in results] == [101, 100]
+
+        results = run_search(apiobj, 0.1, [lookup], [ranking],
+                             details=SearchDetails.from_kwargs({'near': '5.6,4.3'}))
+        results.sort(key=lambda r: -r.importance)
+        assert [r.place_id for r in results] == [100, 101]
+
+
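+    # Unlike the viewbox/near preferences above, 'near_radius' (and
+    # 'bounded_viewbox') is a hard filter: results outside the area are
+    # dropped instead of merely being ranked lower.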
+    def test_force_near(self, apiobj):
+        lookup = FieldLookup('name_vector', [1, 2], 'lookup_all')
+
+        details=SearchDetails.from_kwargs({'near': '5.6,4.3',
+                                           'near_radius': 0.11})
+
+        results = run_search(apiobj, 0.1, [lookup], [], details=details)
+
+        assert [r.place_id for r in results] == [100]
+
+
+class TestStreetWithHousenumber:
+
+    @pytest.fixture(autouse=True)
+    def fill_database(self, apiobj):
+        apiobj.add_placex(place_id=1, class_='place', type='house',
+                          parent_place_id=1000,
+                          housenumber='20 a', country_code='es')
+        apiobj.add_placex(place_id=2, class_='place', type='house',
+                          parent_place_id=1000,
+                          housenumber='21;22', country_code='es')
+        apiobj.add_placex(place_id=1000, class_='highway', type='residential',
+                          rank_search=26, rank_address=26,
+                          country_code='es')
+        apiobj.add_search_name(1000, names=[1,2,10,11],
+                               search_rank=26, address_rank=26,
+                               country_code='es')
+        apiobj.add_placex(place_id=91, class_='place', type='house',
+                          parent_place_id=2000,
+                          housenumber='20', country_code='pt')
+        apiobj.add_placex(place_id=92, class_='place', type='house',
+                          parent_place_id=2000,
+                          housenumber='22', country_code='pt')
+        apiobj.add_placex(place_id=93, class_='place', type='house',
+                          parent_place_id=2000,
+                          housenumber='24', country_code='pt')
+        apiobj.add_placex(place_id=2000, class_='highway', type='residential',
+                          rank_search=26, rank_address=26,
+                          country_code='pt')
+        apiobj.add_search_name(2000, names=[1,2,20,21],
+                               search_rank=26, address_rank=26,
+                               country_code='pt')
+
+
+    @pytest.mark.parametrize('hnr,res', [('20', [91, 1]), ('20 a', [1]),
+                                         ('21', [2]), ('22', [2, 92]),
+                                         ('24', [93]), ('25', [])])
+    def test_lookup_by_single_housenumber(self, apiobj, hnr, res):
+        lookup = FieldLookup('name_vector', [1,2], 'lookup_all')
+        ranking = FieldRanking('name_vector', 0.3, [RankedTokens(0.0, [10])])
+
+        results = run_search(apiobj, 0.1, [lookup], [ranking], hnrs=[hnr])
+
+        assert [r.place_id for r in results] == res + [1000, 2000]
+
+
+    @pytest.mark.parametrize('cc,res', [('es', [2, 1000]), ('pt', [92, 2000])])
+    def test_lookup_with_country_restriction(self, apiobj, cc, res):
+        lookup = FieldLookup('name_vector', [1,2], 'lookup_all')
+        ranking = FieldRanking('name_vector', 0.3, [RankedTokens(0.0, [10])])
+
+        results = run_search(apiobj, 0.1, [lookup], [ranking], hnrs=['22'],
+                             ccodes=[cc])
+
+        assert [r.place_id for r in results] == res
+
+
+    def test_lookup_exclude_housenumber_placeid(self, apiobj):
+        lookup = FieldLookup('name_vector', [1,2], 'lookup_all')
+        ranking = FieldRanking('name_vector', 0.3, [RankedTokens(0.0, [10])])
+
+        results = run_search(apiobj, 0.1, [lookup], [ranking], hnrs=['22'],
+                             details=SearchDetails(excluded=[92]))
+
+        assert [r.place_id for r in results] == [2, 1000, 2000]
+
+
+    def test_lookup_exclude_street_placeid(self, apiobj):
+        lookup = FieldLookup('name_vector', [1,2], 'lookup_all')
+        ranking = FieldRanking('name_vector', 0.3, [RankedTokens(0.0, [10])])
+
+        results = run_search(apiobj, 0.1, [lookup], [ranking], hnrs=['22'],
+                             details=SearchDetails(excluded=[1000]))
+
+        assert [r.place_id for r in results] == [2, 92, 2000]
+
+
+    @pytest.mark.parametrize('geom', [napi.GeometryFormat.GEOJSON,
+                                      napi.GeometryFormat.KML,
+                                      napi.GeometryFormat.SVG,
+                                      napi.GeometryFormat.TEXT])
+    def test_return_geometries(self, apiobj, geom):
+        lookup = FieldLookup('name_vector', [1, 2], 'lookup_all')
+
+        results = run_search(apiobj, 0.1, [lookup], [], hnrs=['20', '21', '22'],
+                             details=SearchDetails(geometry_output=geom))
+
+        assert results
+        assert all(geom.name.lower() in r.geometry for r in results)
+
+
+class TestInterpolations:
+
+    @pytest.fixture(autouse=True)
+    def fill_database(self, apiobj):
+        apiobj.add_placex(place_id=990, class_='highway', type='service',
+                          rank_search=27, rank_address=27,
+                          centroid=(10.0, 10.0),
+                          geometry='LINESTRING(9.995 10, 10.005 10)')
+        apiobj.add_search_name(990, names=[111],
+                               search_rank=27, address_rank=27)
+        apiobj.add_placex(place_id=991, class_='place', type='house',
+                          parent_place_id=990,
+                          rank_search=30, rank_address=30,
+                          housenumber='23',
+                          centroid=(10.0, 10.00002))
+        apiobj.add_osmline(place_id=992,
+                           parent_place_id=990,
+                           startnumber=21, endnumber=29, step=2,
+                           centroid=(10.0, 10.00001),
+                           geometry='LINESTRING(9.995 10.00001, 10.005 10.00001)')
+
+
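+    # The interpolation covers the odd numbers 21-29 only (step=2 starting
+    # at 21), so '22' matches nothing, while for '23' the exact house wins.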
+    @pytest.mark.parametrize('hnr,res', [('21', [992]), ('22', []), ('23', [991])])
+    def test_lookup_housenumber(self, apiobj, hnr, res):
+        lookup = FieldLookup('name_vector', [111], 'lookup_all')
+
+        results = run_search(apiobj, 0.1, [lookup], [], hnrs=[hnr])
+
+        assert [r.place_id for r in results] == res + [990]
+
+
+class TestTiger:
+
+    @pytest.fixture(autouse=True)
+    def fill_database(self, apiobj):
+        apiobj.add_placex(place_id=990, class_='highway', type='service',
+                          rank_search=27, rank_address=27,
+                          country_code='us',
+                          centroid=(10.0, 10.0),
+                          geometry='LINESTRING(9.995 10, 10.005 10)')
+        apiobj.add_search_name(990, names=[111], country_code='us',
+                               search_rank=27, address_rank=27)
+        apiobj.add_placex(place_id=991, class_='place', type='house',
+                          parent_place_id=990,
+                          rank_search=30, rank_address=30,
+                          housenumber='23',
+                          country_code='us',
+                          centroid=(10.0, 10.00002))
+        apiobj.add_tiger(place_id=992,
+                         parent_place_id=990,
+                         startnumber=21, endnumber=29, step=2,
+                         centroid=(10.0, 10.00001),
+                         geometry='LINESTRING(9.995 10.00001, 10.005 10.00001)')
+
+
+    @pytest.mark.parametrize('hnr,res', [('21', [992]), ('22', []), ('23', [991])])
+    def test_lookup_housenumber(self, apiobj, hnr, res):
+        lookup = FieldLookup('name_vector', [111], 'lookup_all')
+
+        results = run_search(apiobj, 0.1, [lookup], [], hnrs=[hnr])
+
+        assert [r.place_id for r in results] == res + [990]
+
+
+class TestLayersRank30:
+
+    @pytest.fixture(autouse=True)
+    def fill_database(self, apiobj):
+        apiobj.add_placex(place_id=223, class_='place', type='house',
+                          housenumber='1',
+                          rank_address=30,
+                          rank_search=30)
+        apiobj.add_search_name(223, names=[34],
+                               importance=0.0009,
+                               address_rank=30, search_rank=30)
+        apiobj.add_placex(place_id=224, class_='amenity', type='toilet',
+                          rank_address=30,
+                          rank_search=30)
+        apiobj.add_search_name(224, names=[34],
+                               importance=0.0008,
+                               address_rank=30, search_rank=30)
+        apiobj.add_placex(place_id=225, class_='man_made', type='tower',
+                          rank_address=0,
+                          rank_search=30)
+        apiobj.add_search_name(225, names=[34],
+                               importance=0.0007,
+                               address_rank=0, search_rank=30)
+        apiobj.add_placex(place_id=226, class_='railway', type='station',
+                          rank_address=0,
+                          rank_search=30)
+        apiobj.add_search_name(226, names=[34],
+                               importance=0.0006,
+                               address_rank=0, search_rank=30)
+        apiobj.add_placex(place_id=227, class_='natural', type='cave',
+                          rank_address=0,
+                          rank_search=30)
+        apiobj.add_search_name(227, names=[34],
+                               importance=0.0005,
+                               address_rank=0, search_rank=30)
+
+
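+    # The DataLayer flags select rank-30 objects by their class: ADDRESS and
+    # POI split the addressable places, while MANMADE, RAILWAY and NATURAL
+    # match the corresponding placex classes of non-address objects.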
+""" +import pytest + +import nominatim.api as napi +from nominatim.api.types import SearchDetails +from nominatim.api.search.db_searches import PoiSearch +from nominatim.api.search.db_search_fields import WeightedStrings, WeightedCategories + + +def run_search(apiobj, global_penalty, poitypes, poi_penalties=None, + ccodes=[], details=SearchDetails()): + if poi_penalties is None: + poi_penalties = [0.0] * len(poitypes) + + class MySearchData: + penalty = global_penalty + qualifiers = WeightedCategories(poitypes, poi_penalties) + countries = WeightedStrings(ccodes, [0.0] * len(ccodes)) + + search = PoiSearch(MySearchData()) + + async def run(): + async with apiobj.api._async_api.begin() as conn: + return await search.lookup(conn, details) + + return apiobj.async_to_sync(run()) + + +@pytest.mark.parametrize('coord,pid', [('34.3, 56.100021', 2), + ('5.0, 4.59933', 1)]) +def test_simple_near_search_in_placex(apiobj, coord, pid): + apiobj.add_placex(place_id=1, class_='highway', type='bus_stop', + centroid=(5.0, 4.6)) + apiobj.add_placex(place_id=2, class_='highway', type='bus_stop', + centroid=(34.3, 56.1)) + + details = SearchDetails.from_kwargs({'near': coord, 'near_radius': 0.001}) + + results = run_search(apiobj, 0.1, [('highway', 'bus_stop')], [0.5], details=details) + + assert [r.place_id for r in results] == [pid] + + +@pytest.mark.parametrize('coord,pid', [('34.3, 56.100021', 2), + ('34.3, 56.4', 2), + ('5.0, 4.59933', 1)]) +def test_simple_near_search_in_classtype(apiobj, coord, pid): + apiobj.add_placex(place_id=1, class_='highway', type='bus_stop', + centroid=(5.0, 4.6)) + apiobj.add_placex(place_id=2, class_='highway', type='bus_stop', + centroid=(34.3, 56.1)) + apiobj.add_class_type_table('highway', 'bus_stop') + + details = SearchDetails.from_kwargs({'near': coord, 'near_radius': 0.5}) + + results = run_search(apiobj, 0.1, [('highway', 'bus_stop')], [0.5], details=details) + + assert [r.place_id for r in results] == [pid] + + +class TestPoiSearchWithRestrictions: + + @pytest.fixture(autouse=True, params=["placex", "classtype"]) + def fill_database(self, apiobj, request): + apiobj.add_placex(place_id=1, class_='highway', type='bus_stop', + country_code='au', + centroid=(34.3, 56.10003)) + apiobj.add_placex(place_id=2, class_='highway', type='bus_stop', + country_code='nz', + centroid=(34.3, 56.1)) + if request.param == 'classtype': + apiobj.add_class_type_table('highway', 'bus_stop') + self.args = {'near': '34.3, 56.4', 'near_radius': 0.5} + else: + self.args = {'near': '34.3, 56.100021', 'near_radius': 0.001} + + + def test_unrestricted(self, apiobj): + results = run_search(apiobj, 0.1, [('highway', 'bus_stop')], [0.5], + details=SearchDetails.from_kwargs(self.args)) + + assert [r.place_id for r in results] == [1, 2] + + + def test_restict_country(self, apiobj): + results = run_search(apiobj, 0.1, [('highway', 'bus_stop')], [0.5], + ccodes=['de', 'nz'], + details=SearchDetails.from_kwargs(self.args)) + + assert [r.place_id for r in results] == [2] + + + def test_restrict_by_viewbox(self, apiobj): + args = {'bounded_viewbox': True, 'viewbox': '34.299,56.0,34.3001,56.10001'} + args.update(self.args) + results = run_search(apiobj, 0.1, [('highway', 'bus_stop')], [0.5], + ccodes=['de', 'nz'], + details=SearchDetails.from_kwargs(args)) + + assert [r.place_id for r in results] == [2] diff --git a/test/python/api/search/test_search_postcode.py b/test/python/api/search/test_search_postcode.py new file mode 100644 index 00000000..a43bc897 --- /dev/null +++ 
+    def test_unrestricted(self, apiobj):
+        results = run_search(apiobj, 0.1, [('highway', 'bus_stop')], [0.5],
+                             details=SearchDetails.from_kwargs(self.args))
+
+        assert [r.place_id for r in results] == [1, 2]
+
+
+    def test_restrict_country(self, apiobj):
+        results = run_search(apiobj, 0.1, [('highway', 'bus_stop')], [0.5],
+                             ccodes=['de', 'nz'],
+                             details=SearchDetails.from_kwargs(self.args))
+
+        assert [r.place_id for r in results] == [2]
+
+
+    def test_restrict_by_viewbox(self, apiobj):
+        args = {'bounded_viewbox': True, 'viewbox': '34.299,56.0,34.3001,56.10001'}
+        args.update(self.args)
+        results = run_search(apiobj, 0.1, [('highway', 'bus_stop')], [0.5],
+                             ccodes=['de', 'nz'],
+                             details=SearchDetails.from_kwargs(args))
+
+        assert [r.place_id for r in results] == [2]
diff --git a/test/python/api/search/test_search_postcode.py b/test/python/api/search/test_search_postcode.py
new file mode 100644
index 00000000..a43bc897
--- /dev/null
+++ b/test/python/api/search/test_search_postcode.py
@@ -0,0 +1,97 @@
+# SPDX-License-Identifier: GPL-3.0-or-later
+#
+# This file is part of Nominatim. (https://nominatim.org)
+#
+# Copyright (C) 2023 by the Nominatim developer community.
+# For a full list of authors see the git log.
+"""
+Tests for running the postcode searcher.
+"""
+import pytest
+
+import nominatim.api as napi
+from nominatim.api.types import SearchDetails
+from nominatim.api.search.db_searches import PostcodeSearch
+from nominatim.api.search.db_search_fields import WeightedStrings, FieldLookup, \
+    FieldRanking, RankedTokens
+
+def run_search(apiobj, global_penalty, pcs, pc_penalties=None,
+               ccodes=[], lookup=[], ranking=[], details=SearchDetails()):
+    if pc_penalties is None:
+        pc_penalties = [0.0] * len(pcs)
+
+    class MySearchData:
+        penalty = global_penalty
+        postcodes = WeightedStrings(pcs, pc_penalties)
+        countries = WeightedStrings(ccodes, [0.0] * len(ccodes))
+        lookups = lookup
+        rankings = ranking
+
+    search = PostcodeSearch(0.0, MySearchData())
+
+    async def run():
+        async with apiobj.api._async_api.begin() as conn:
+            return await search.lookup(conn, details)
+
+    return apiobj.async_to_sync(run())
+
+
+def test_postcode_only_search(apiobj):
+    apiobj.add_postcode(place_id=100, country_code='ch', postcode='12345')
+    apiobj.add_postcode(place_id=101, country_code='pl', postcode='12 345')
+
+    results = run_search(apiobj, 0.3, ['12345', '12 345'], [0.0, 0.1])
+
+    assert len(results) == 2
+    assert [r.place_id for r in results] == [100, 101]
+
+
+def test_postcode_with_country(apiobj):
+    apiobj.add_postcode(place_id=100, country_code='ch', postcode='12345')
+    apiobj.add_postcode(place_id=101, country_code='pl', postcode='12 345')
+
+    results = run_search(apiobj, 0.3, ['12345', '12 345'], [0.0, 0.1],
+                         ccodes=['de', 'pl'])
+
+    assert len(results) == 1
+    assert results[0].place_id == 101
+
+
+class TestPostcodeSearchWithAddress:
+
+    @pytest.fixture(autouse=True)
+    def fill_database(self, apiobj):
+        apiobj.add_postcode(place_id=100, country_code='ch',
+                            parent_place_id=1000, postcode='12345')
+        apiobj.add_postcode(place_id=101, country_code='pl',
+                            parent_place_id=2000, postcode='12345')
+        apiobj.add_placex(place_id=1000, class_='place', type='village',
+                          rank_search=22, rank_address=22,
+                          country_code='ch')
+        apiobj.add_search_name(1000, names=[1,2,10,11],
+                               search_rank=22, address_rank=22,
+                               country_code='ch')
+        apiobj.add_placex(place_id=2000, class_='place', type='village',
+                          rank_search=22, rank_address=22,
+                          country_code='pl')
+        apiobj.add_search_name(2000, names=[1,2,20,21],
+                               search_rank=22, address_rank=22,
+                               country_code='pl')
+
+
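+    # Both postcodes have a village as parent place, so the postcode search
+    # can additionally match name tokens of that parent (see below).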
+    def test_lookup_both(self, apiobj):
+        lookup = FieldLookup('name_vector', [1,2], 'restrict')
+        ranking = FieldRanking('name_vector', 0.3, [RankedTokens(0.0, [10])])
+
+        results = run_search(apiobj, 0.1, ['12345'], lookup=[lookup], ranking=[ranking])
+
+        assert [r.place_id for r in results] == [100, 101]
+
+
+    def test_restrict_by_name(self, apiobj):
+        lookup = FieldLookup('name_vector', [10], 'restrict')
+
+        results = run_search(apiobj, 0.1, ['12345'], lookup=[lookup])
+
+        assert [r.place_id for r in results] == [100]
+
diff --git a/test/python/api/search/test_token_assignment.py b/test/python/api/search/test_token_assignment.py
new file mode 100644
index 00000000..dc123403
--- /dev/null
+++ b/test/python/api/search/test_token_assignment.py
@@ -0,0 +1,343 @@
+# SPDX-License-Identifier: GPL-3.0-or-later
+#
+# This file is part of Nominatim. (https://nominatim.org)
+#
+# Copyright (C) 2023 by the Nominatim developer community.
+# For a full list of authors see the git log.
+"""
+Test for creation of token assignments from tokenized queries.
+"""
+import pytest
+
+from nominatim.api.search.query import QueryStruct, Phrase, PhraseType, BreakType, TokenType, TokenRange, Token
+from nominatim.api.search.token_assignment import yield_token_assignments, TokenAssignment, PENALTY_TOKENCHANGE
+
+class MyToken(Token):
+    def get_category(self):
+        return 'this', 'that'
+
+
+def make_query(*args):
+    q = None
+    dummy = MyToken(3.0, 45, 1, 'foo', True)
+
+    for btype, ptype, tlist in args:
+        if q is None:
+            q = QueryStruct([Phrase(ptype, '')])
+        else:
+            q.add_node(btype, ptype)
+
+        start = len(q.nodes) - 1
+        for end, ttype in tlist:
+            q.add_token(TokenRange(start, end), ttype, dummy)
+
+    q.add_node(BreakType.END, PhraseType.NONE)
+
+    return q
+
+
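+# make_query() builds a QueryStruct node by node: each argument describes one
+# node as (break type, phrase type, [(end node, token type), ...]), where the
+# token list adds tokens spanning from the new node to the given end node.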
+def check_assignments(actual, *expected):
+    todo = list(expected)
+    for assignment in actual:
+        assert assignment in todo, f"Unexpected assignment: {assignment}"
+        todo.remove(assignment)
+
+    assert not todo, f"Missing assignments: {todo}"
+
+
+def test_query_with_missing_tokens():
+    q = QueryStruct([Phrase(PhraseType.NONE, '')])
+    q.add_node(BreakType.END, PhraseType.NONE)
+
+    assert list(yield_token_assignments(q)) == []
+
+
+def test_one_word_query():
+    q = make_query((BreakType.START, PhraseType.NONE,
+                    [(1, TokenType.PARTIAL),
+                     (1, TokenType.WORD),
+                     (1, TokenType.HOUSENUMBER)]))
+
+    res = list(yield_token_assignments(q))
+    assert res == [TokenAssignment(name=TokenRange(0, 1))]
+
+
+def test_single_postcode():
+    q = make_query((BreakType.START, PhraseType.NONE,
+                    [(1, TokenType.POSTCODE)]))
+
+    res = list(yield_token_assignments(q))
+    assert res == [TokenAssignment(postcode=TokenRange(0, 1))]
+
+
+def test_single_country_name():
+    q = make_query((BreakType.START, PhraseType.NONE,
+                    [(1, TokenType.COUNTRY)]))
+
+    res = list(yield_token_assignments(q))
+    assert res == [TokenAssignment(country=TokenRange(0, 1))]
+
+
+def test_single_word_poi_search():
+    q = make_query((BreakType.START, PhraseType.NONE,
+                    [(1, TokenType.CATEGORY),
+                     (1, TokenType.QUALIFIER)]))
+
+    res = list(yield_token_assignments(q))
+    assert res == [TokenAssignment(category=TokenRange(0, 1))]
+
+
+@pytest.mark.parametrize('btype', [BreakType.WORD, BreakType.PART, BreakType.TOKEN])
+def test_multiple_simple_words(btype):
+    q = make_query((BreakType.START, PhraseType.NONE, [(1, TokenType.PARTIAL)]),
+                   (btype, PhraseType.NONE, [(2, TokenType.PARTIAL)]),
+                   (btype, PhraseType.NONE, [(3, TokenType.PARTIAL)]))
+
+    penalty = PENALTY_TOKENCHANGE[btype]
+
+    check_assignments(yield_token_assignments(q),
+                      TokenAssignment(name=TokenRange(0, 3)),
+                      TokenAssignment(penalty=penalty, name=TokenRange(0, 2),
+                                      address=[TokenRange(2, 3)]),
+                      TokenAssignment(penalty=penalty, name=TokenRange(0, 1),
+                                      address=[TokenRange(1, 3)]),
+                      TokenAssignment(penalty=penalty, name=TokenRange(1, 3),
+                                      address=[TokenRange(0, 1)]),
+                      TokenAssignment(penalty=penalty, name=TokenRange(2, 3),
+                                      address=[TokenRange(0, 2)])
+                     )
+
+
+def test_multiple_words_respect_phrase_break():
+    q = make_query((BreakType.START, PhraseType.NONE, [(1, TokenType.PARTIAL)]),
+                   (BreakType.PHRASE, PhraseType.NONE, [(2, TokenType.PARTIAL)]))
+
+    check_assignments(yield_token_assignments(q),
+                      TokenAssignment(name=TokenRange(0, 1),
+                                      address=[TokenRange(1, 2)]),
+                      TokenAssignment(name=TokenRange(1, 2),
+                                      address=[TokenRange(0, 1)]))
+
+
+def test_housenumber_and_street():
+    q = make_query((BreakType.START, PhraseType.NONE, [(1, TokenType.HOUSENUMBER)]),
+                   (BreakType.PHRASE, PhraseType.NONE, [(2, TokenType.PARTIAL)]))
+
+    check_assignments(yield_token_assignments(q),
+                      TokenAssignment(name=TokenRange(1, 2),
+                                      housenumber=TokenRange(0, 1)),
+                      TokenAssignment(address=[TokenRange(1, 2)],
+                                      housenumber=TokenRange(0, 1)))
+
+
+def test_housenumber_and_street_backwards():
+    q = make_query((BreakType.START, PhraseType.NONE, [(1, TokenType.PARTIAL)]),
+                   (BreakType.PHRASE, PhraseType.NONE, [(2, TokenType.HOUSENUMBER)]))
+
+    check_assignments(yield_token_assignments(q),
+                      TokenAssignment(name=TokenRange(0, 1),
+                                      housenumber=TokenRange(1, 2)),
+                      TokenAssignment(address=[TokenRange(0, 1)],
+                                      housenumber=TokenRange(1, 2)))
+
+
+def test_housenumber_and_postcode():
+    q = make_query((BreakType.START, PhraseType.NONE, [(1, TokenType.PARTIAL)]),
+                   (BreakType.WORD, PhraseType.NONE, [(2, TokenType.HOUSENUMBER)]),
+                   (BreakType.WORD, PhraseType.NONE, [(3, TokenType.PARTIAL)]),
+                   (BreakType.WORD, PhraseType.NONE, [(4, TokenType.POSTCODE)]))
+
+    check_assignments(yield_token_assignments(q),
+                      TokenAssignment(penalty=pytest.approx(0.3),
+                                      name=TokenRange(0, 1),
+                                      housenumber=TokenRange(1, 2),
+                                      address=[TokenRange(2, 3)],
+                                      postcode=TokenRange(3, 4)),
+                      TokenAssignment(penalty=pytest.approx(0.3),
+                                      housenumber=TokenRange(1, 2),
+                                      address=[TokenRange(0, 1), TokenRange(2, 3)],
+                                      postcode=TokenRange(3, 4)))
+
+def test_postcode_and_housenumber():
+    q = make_query((BreakType.START, PhraseType.NONE, [(1, TokenType.PARTIAL)]),
+                   (BreakType.WORD, PhraseType.NONE, [(2, TokenType.POSTCODE)]),
+                   (BreakType.WORD, PhraseType.NONE, [(3, TokenType.PARTIAL)]),
+                   (BreakType.WORD, PhraseType.NONE, [(4, TokenType.HOUSENUMBER)]))
+
+    check_assignments(yield_token_assignments(q),
+                      TokenAssignment(penalty=pytest.approx(0.3),
+                                      name=TokenRange(2, 3),
+                                      housenumber=TokenRange(3, 4),
+                                      address=[TokenRange(0, 1)],
+                                      postcode=TokenRange(1, 2)),
+                      TokenAssignment(penalty=pytest.approx(0.3),
+                                      housenumber=TokenRange(3, 4),
+                                      address=[TokenRange(0, 1), TokenRange(2, 3)],
+                                      postcode=TokenRange(1, 2)))
+
+
+def test_country_housenumber_postcode():
+    q = make_query((BreakType.START, PhraseType.NONE, [(1, TokenType.COUNTRY)]),
+                   (BreakType.WORD, PhraseType.NONE, [(2, TokenType.PARTIAL)]),
+                   (BreakType.WORD, PhraseType.NONE, [(3, TokenType.HOUSENUMBER)]),
+                   (BreakType.WORD, PhraseType.NONE, [(4, TokenType.POSTCODE)]))
+
+    check_assignments(yield_token_assignments(q))
+
+
+@pytest.mark.parametrize('ttype', [TokenType.POSTCODE, TokenType.COUNTRY,
+                                   TokenType.CATEGORY, TokenType.QUALIFIER])
+def test_housenumber_with_only_special_terms(ttype):
+    q = make_query((BreakType.START, PhraseType.NONE, [(1, TokenType.HOUSENUMBER)]),
+                   (BreakType.WORD, PhraseType.NONE, [(2, ttype)]))
+
+    check_assignments(yield_token_assignments(q))
+
+
+@pytest.mark.parametrize('ttype', [TokenType.POSTCODE, TokenType.HOUSENUMBER, TokenType.COUNTRY])
+def test_multiple_special_tokens(ttype):
+    q = make_query((BreakType.START, PhraseType.NONE, [(1, ttype)]),
+                   (BreakType.PHRASE, PhraseType.NONE, [(2, TokenType.PARTIAL)]),
+                   (BreakType.PHRASE, PhraseType.NONE, [(3, ttype)]))
+
+    check_assignments(yield_token_assignments(q))
+
+
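+# Postcode, housenumber and country tokens may appear at most once in a
+# query; a second token of the same type makes the query unassignable.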
+def test_housenumber_many_phrases():
+    q = make_query((BreakType.START, PhraseType.NONE, [(1, TokenType.PARTIAL)]),
+                   (BreakType.PHRASE, PhraseType.NONE, [(2, TokenType.PARTIAL)]),
+                   (BreakType.PHRASE, PhraseType.NONE, [(3, TokenType.PARTIAL)]),
+                   (BreakType.PHRASE, PhraseType.NONE, [(4, TokenType.HOUSENUMBER)]),
+                   (BreakType.WORD, PhraseType.NONE, [(5, TokenType.PARTIAL)]))
+
+    check_assignments(yield_token_assignments(q),
+                      TokenAssignment(penalty=0.1,
+                                      name=TokenRange(4, 5),
+                                      housenumber=TokenRange(3, 4),
+                                      address=[TokenRange(0, 1), TokenRange(1, 2),
+                                               TokenRange(2, 3)]),
+                      TokenAssignment(penalty=0.1,
+                                      housenumber=TokenRange(3, 4),
+                                      address=[TokenRange(0, 1), TokenRange(1, 2),
+                                               TokenRange(2, 3), TokenRange(4, 5)]))
+
+
+def test_country_at_beginning():
+    q = make_query((BreakType.START, PhraseType.NONE, [(1, TokenType.COUNTRY)]),
+                   (BreakType.WORD, PhraseType.NONE, [(2, TokenType.PARTIAL)]))
+
+    check_assignments(yield_token_assignments(q),
+                      TokenAssignment(penalty=0.1, name=TokenRange(1, 2),
+                                      country=TokenRange(0, 1)))
+
+
+def test_country_at_end():
+    q = make_query((BreakType.START, PhraseType.NONE, [(1, TokenType.PARTIAL)]),
+                   (BreakType.WORD, PhraseType.NONE, [(2, TokenType.COUNTRY)]))
+
+    check_assignments(yield_token_assignments(q),
+                      TokenAssignment(penalty=0.1, name=TokenRange(0, 1),
+                                      country=TokenRange(1, 2)))
+
+
+def test_country_in_middle():
+    q = make_query((BreakType.START, PhraseType.NONE, [(1, TokenType.PARTIAL)]),
+                   (BreakType.WORD, PhraseType.NONE, [(2, TokenType.COUNTRY)]),
+                   (BreakType.WORD, PhraseType.NONE, [(3, TokenType.PARTIAL)]))
+
+    check_assignments(yield_token_assignments(q))
+
+
+def test_postcode_with_designation():
+    q = make_query((BreakType.START, PhraseType.NONE, [(1, TokenType.POSTCODE)]),
+                   (BreakType.PHRASE, PhraseType.NONE, [(2, TokenType.PARTIAL)]))
+
+    check_assignments(yield_token_assignments(q),
+                      TokenAssignment(penalty=0.1, name=TokenRange(1, 2),
+                                      postcode=TokenRange(0, 1)),
+                      TokenAssignment(postcode=TokenRange(0, 1),
+                                      address=[TokenRange(1, 2)]))
+
+
+def test_postcode_with_designation_backwards():
+    q = make_query((BreakType.START, PhraseType.NONE, [(1, TokenType.PARTIAL)]),
+                   (BreakType.PHRASE, PhraseType.NONE, [(2, TokenType.POSTCODE)]))
+
+    check_assignments(yield_token_assignments(q),
+                      TokenAssignment(name=TokenRange(0, 1),
+                                      postcode=TokenRange(1, 2)),
+                      TokenAssignment(penalty=0.1, postcode=TokenRange(1, 2),
+                                      address=[TokenRange(0, 1)]))
+
+
+def test_category_at_beginning():
+    q = make_query((BreakType.START, PhraseType.NONE, [(1, TokenType.CATEGORY)]),
+                   (BreakType.WORD, PhraseType.NONE, [(2, TokenType.PARTIAL)]))
+
+    check_assignments(yield_token_assignments(q),
+                      TokenAssignment(penalty=0.1, name=TokenRange(1, 2),
+                                      category=TokenRange(0, 1)))
+
+
+def test_category_at_end():
+    q = make_query((BreakType.START, PhraseType.NONE, [(1, TokenType.PARTIAL)]),
+                   (BreakType.WORD, PhraseType.NONE, [(2, TokenType.CATEGORY)]))
+
+    check_assignments(yield_token_assignments(q),
+                      TokenAssignment(penalty=0.1, name=TokenRange(0, 1),
+                                      category=TokenRange(1, 2)))
+
+
+def test_category_in_middle():
+    q = make_query((BreakType.START, PhraseType.NONE, [(1, TokenType.PARTIAL)]),
+                   (BreakType.WORD, PhraseType.NONE, [(2, TokenType.CATEGORY)]),
+                   (BreakType.WORD, PhraseType.NONE, [(3, TokenType.PARTIAL)]))
+
+    check_assignments(yield_token_assignments(q))
+
+
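+# Category tokens are only accepted at the very beginning or end of the
+# query; anywhere in between, no assignment is produced.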
+def test_qualifier_at_beginning():
+    q = make_query((BreakType.START, PhraseType.NONE, [(1, TokenType.QUALIFIER)]),
+                   (BreakType.WORD, PhraseType.NONE, [(2, TokenType.PARTIAL)]),
+                   (BreakType.WORD, PhraseType.NONE, [(3, TokenType.PARTIAL)]))
+
+
+    check_assignments(yield_token_assignments(q),
+                      TokenAssignment(penalty=0.1, name=TokenRange(1, 3),
+                                      qualifier=TokenRange(0, 1)),
+                      TokenAssignment(penalty=0.2, name=TokenRange(1, 2),
+                                      qualifier=TokenRange(0, 1),
+                                      address=[TokenRange(2, 3)]))
+
+
+def test_qualifier_after_name():
+    q = make_query((BreakType.START, PhraseType.NONE, [(1, TokenType.PARTIAL)]),
+                   (BreakType.WORD, PhraseType.NONE, [(2, TokenType.PARTIAL)]),
+                   (BreakType.WORD, PhraseType.NONE, [(3, TokenType.QUALIFIER)]),
+                   (BreakType.WORD, PhraseType.NONE, [(4, TokenType.PARTIAL)]),
+                   (BreakType.WORD, PhraseType.NONE, [(5, TokenType.PARTIAL)]))
+
+
+    check_assignments(yield_token_assignments(q),
+                      TokenAssignment(penalty=0.2, name=TokenRange(0, 2),
+                                      qualifier=TokenRange(2, 3),
+                                      address=[TokenRange(3, 5)]),
+                      TokenAssignment(penalty=0.2, name=TokenRange(3, 5),
+                                      qualifier=TokenRange(2, 3),
+                                      address=[TokenRange(0, 2)]))
+
+
+def test_qualifier_before_housenumber():
+    q = make_query((BreakType.START, PhraseType.NONE, [(1, TokenType.QUALIFIER)]),
+                   (BreakType.WORD, PhraseType.NONE, [(2, TokenType.HOUSENUMBER)]),
+                   (BreakType.WORD, PhraseType.NONE, [(3, TokenType.PARTIAL)]))
+
+    check_assignments(yield_token_assignments(q))
+
+
+def test_qualifier_after_housenumber():
+    q = make_query((BreakType.START, PhraseType.NONE, [(1, TokenType.HOUSENUMBER)]),
+                   (BreakType.WORD, PhraseType.NONE, [(2, TokenType.QUALIFIER)]),
+                   (BreakType.WORD, PhraseType.NONE, [(3, TokenType.PARTIAL)]))
+
+    check_assignments(yield_token_assignments(q))
diff --git a/test/python/api/test_api_search.py b/test/python/api/test_api_search.py
new file mode 100644
index 00000000..aa263d24
--- /dev/null
+++ b/test/python/api/test_api_search.py
@@ -0,0 +1,159 @@
+# SPDX-License-Identifier: GPL-3.0-or-later
+#
+# This file is part of Nominatim. (https://nominatim.org)
+#
+# Copyright (C) 2023 by the Nominatim developer community.
+# For a full list of authors see the git log.
+"""
+Tests for search API calls.
+
+These tests make sure that all Python code is correct and executable.
+Functional tests can be found in the BDD test suite.
+"""
+import json
+
+import pytest
+
+import sqlalchemy as sa
+
+import nominatim.api as napi
+import nominatim.api.logging as loglib
+
+@pytest.fixture(autouse=True)
+def setup_icu_tokenizer(apiobj):
+    """ Setup the properties needed for using the ICU tokenizer.
+    """
+    apiobj.add_data('properties',
+                    [{'property': 'tokenizer', 'value': 'icu'},
+                     {'property': 'tokenizer_import_normalisation', 'value': ':: lower();'},
+                     {'property': 'tokenizer_import_transliteration', 'value': "'1' > '/1/'; 'ä' > 'ä '"},
+                    ])
+
+
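+# The tests below create a minimal ICU 'word' table; in the ICU tokenizer
+# scheme, type 'W' denotes a full word and 'w' a partial term.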
+ """ + apiobj.add_data('properties', + [{'property': 'tokenizer', 'value': 'icu'}, + {'property': 'tokenizer_import_normalisation', 'value': ':: lower();'}, + {'property': 'tokenizer_import_transliteration', 'value': "'1' > '/1/'; 'ä' > 'ä '"}, + ]) + + +def test_search_no_content(apiobj, table_factory): + table_factory('word', + definition='word_id INT, word_token TEXT, type TEXT, word TEXT, info JSONB') + + assert apiobj.api.search('foo') == [] + + +def test_search_simple_word(apiobj, table_factory): + table_factory('word', + definition='word_id INT, word_token TEXT, type TEXT, word TEXT, info JSONB', + content=[(55, 'test', 'W', 'test', None), + (2, 'test', 'w', 'test', None)]) + + apiobj.add_placex(place_id=444, class_='place', type='village', + centroid=(1.3, 0.7)) + apiobj.add_search_name(444, names=[2, 55]) + + results = apiobj.api.search('TEST') + + assert [r.place_id for r in results] == [444] + + +@pytest.mark.parametrize('logtype', ['text', 'html']) +def test_search_with_debug(apiobj, table_factory, logtype): + table_factory('word', + definition='word_id INT, word_token TEXT, type TEXT, word TEXT, info JSONB', + content=[(55, 'test', 'W', 'test', None), + (2, 'test', 'w', 'test', None)]) + + apiobj.add_placex(place_id=444, class_='place', type='village', + centroid=(1.3, 0.7)) + apiobj.add_search_name(444, names=[2, 55]) + + loglib.set_log_output(logtype) + results = apiobj.api.search('TEST') + + assert loglib.get_and_disable() + + +def test_address_no_content(apiobj, table_factory): + table_factory('word', + definition='word_id INT, word_token TEXT, type TEXT, word TEXT, info JSONB') + + assert apiobj.api.search_address(amenity='hotel', + street='Main St 34', + city='Happyville', + county='Wideland', + state='Praerie', + postalcode='55648', + country='xx') == [] + + +@pytest.mark.parametrize('atype,address,search', [('street', 26, 26), + ('city', 16, 18), + ('county', 12, 12), + ('state', 8, 8)]) +def test_address_simple_places(apiobj, table_factory, atype, address, search): + table_factory('word', + definition='word_id INT, word_token TEXT, type TEXT, word TEXT, info JSONB', + content=[(55, 'test', 'W', 'test', None), + (2, 'test', 'w', 'test', None)]) + + apiobj.add_placex(place_id=444, + rank_address=address, rank_search=search, + centroid=(1.3, 0.7)) + apiobj.add_search_name(444, names=[2, 55], address_rank=address, search_rank=search) + + results = apiobj.api.search_address(**{atype: 'TEST'}) + + assert [r.place_id for r in results] == [444] + + +def test_address_country(apiobj, table_factory): + table_factory('word', + definition='word_id INT, word_token TEXT, type TEXT, word TEXT, info JSONB', + content=[(None, 'ro', 'C', 'ro', None)]) + apiobj.add_country('ro', 'POLYGON((0 0, 0 1, 1 1, 1 0, 0 0))') + apiobj.add_country_name('ro', {'name': 'România'}) + + assert len(apiobj.api.search_address(country='ro')) == 1 + + +def test_category_no_categories(apiobj, table_factory): + table_factory('word', + definition='word_id INT, word_token TEXT, type TEXT, word TEXT, info JSONB') + + assert apiobj.api.search_category([], near_query='Berlin') == [] + + +def test_category_no_content(apiobj, table_factory): + table_factory('word', + definition='word_id INT, word_token TEXT, type TEXT, word TEXT, info JSONB') + + assert apiobj.api.search_category([('amenity', 'restaurant')]) == [] + + +def test_category_simple_restaurant(apiobj, table_factory): + table_factory('word', + definition='word_id INT, word_token TEXT, type TEXT, word TEXT, info JSONB') + + apiobj.add_placex(place_id=444, 
+def test_category_simple_restaurant(apiobj, table_factory):
+    table_factory('word',
+                  definition='word_id INT, word_token TEXT, type TEXT, word TEXT, info JSONB')
+
+    apiobj.add_placex(place_id=444, class_='amenity', type='restaurant',
+                      centroid=(1.3, 0.7))
+    apiobj.add_search_name(444, names=[2, 55], address_rank=16, search_rank=18)
+
+    results = apiobj.api.search_category([('amenity', 'restaurant')],
+                                         near=(1.3, 0.701), near_radius=0.015)
+
+    assert [r.place_id for r in results] == [444]
+
+
+def test_category_with_search_phrase(apiobj, table_factory):
+    table_factory('word',
+                  definition='word_id INT, word_token TEXT, type TEXT, word TEXT, info JSONB',
+                  content=[(55, 'test', 'W', 'test', None),
+                           (2, 'test', 'w', 'test', None)])
+
+    apiobj.add_placex(place_id=444, class_='place', type='village',
+                      rank_address=16, rank_search=18,
+                      centroid=(1.3, 0.7))
+    apiobj.add_search_name(444, names=[2, 55], address_rank=16, search_rank=18)
+    apiobj.add_placex(place_id=95, class_='amenity', type='restaurant',
+                      centroid=(1.3, 0.7003))
+
+    results = apiobj.api.search_category([('amenity', 'restaurant')],
+                                         near_query='TEST')
+
+    assert [r.place_id for r in results] == [95]
diff --git a/test/python/api/test_helpers_v1.py b/test/python/api/test_helpers_v1.py
new file mode 100644
index 00000000..45f538de
--- /dev/null
+++ b/test/python/api/test_helpers_v1.py
@@ -0,0 +1,112 @@
+# SPDX-License-Identifier: GPL-3.0-or-later
+#
+# This file is part of Nominatim. (https://nominatim.org)
+#
+# Copyright (C) 2023 by the Nominatim developer community.
+# For a full list of authors see the git log.
+"""
+Tests for the helper functions for v1 API.
+"""
+import pytest
+
+import nominatim.api.v1.helpers as helper
+
+@pytest.mark.parametrize('inp', ['', 'abc', '12 23', 'abc -78.90, 12.456 def'])
+def test_extract_coords_no_coords(inp):
+    query, x, y = helper.extract_coords_from_query(inp)
+
+    assert query == inp
+    assert x is None
+    assert y is None
+
+
+def test_extract_coords_null_island():
+    assert ('', 0.0, 0.0) == helper.extract_coords_from_query('0.0 -0.0')
+
+
+def test_extract_coords_with_text_before():
+    assert ('abc', 12.456, -78.90) == helper.extract_coords_from_query('abc -78.90, 12.456')
+
+
+def test_extract_coords_with_text_after():
+    assert ('abc', 12.456, -78.90) == helper.extract_coords_from_query('-78.90, 12.456 abc')
+
+@pytest.mark.parametrize('inp', [' [12.456,-78.90] ', ' 12.456,-78.90 '])
+def test_extract_coords_with_spaces(inp):
+    assert ('', -78.90, 12.456) == helper.extract_coords_from_query(inp)
+
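+# extract_coords_from_query() expects coordinates in 'lat, lon' order and
+# returns (remaining query, x, y), i.e. x is the longitude and y the latitude.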
+@pytest.mark.parametrize('inp', ['40 26.767 N 79 58.933 W',
+                                 '40° 26.767′ N 79° 58.933′ W',
+                                 "40° 26.767' N 79° 58.933' W",
+                                 "40° 26.767'\n"
+                                 " N 79° 58.933' W",
+                                 'N 40 26.767, W 79 58.933',
+                                 'N 40°26.767′, W 79°58.933′',
+                                 ' N 40°26.767′, W 79°58.933′',
+                                 "N 40°26.767', W 79°58.933'",
+
+                                 '40 26 46 N 79 58 56 W',
+                                 '40° 26′ 46″ N 79° 58′ 56″ W',
+                                 '40° 26′ 46.00″ N 79° 58′ 56.00″ W',
+                                 '40°26′46″N 79°58′56″W',
+                                 'N 40 26 46 W 79 58 56',
+                                 'N 40° 26′ 46″, W 79° 58′ 56″',
+                                 'N 40° 26\' 46", W 79° 58\' 56"',
+                                 'N 40° 26\' 46", W 79° 58\' 56"',
+
+                                 '40.446 -79.982',
+                                 '40.446,-79.982',
+                                 '40.446° N 79.982° W',
+                                 'N 40.446° W 79.982°',
+
+                                 '[40.446 -79.982]',
+                                 '[40.446,-79.982]',
+                                 ' 40.446 , -79.982 ',
+                                 ' 40.446 , -79.982 ',
+                                 ' 40.446 , -79.982 ',
+                                 ' 40.446, -79.982 '])
+def test_extract_coords_formats(inp):
+    query, x, y = helper.extract_coords_from_query(inp)
+
+    assert query == ''
+    assert pytest.approx(x, abs=0.001) == -79.982
+    assert pytest.approx(y, abs=0.001) == 40.446
+
+    query, x, y = helper.extract_coords_from_query('foo bar ' + inp)
+
+    assert query == 'foo bar'
+    assert pytest.approx(x, abs=0.001) == -79.982
+    assert pytest.approx(y, abs=0.001) == 40.446
+
+    query, x, y = helper.extract_coords_from_query(inp + ' x')
+
+    assert query == 'x'
+    assert pytest.approx(x, abs=0.001) == -79.982
+    assert pytest.approx(y, abs=0.001) == 40.446
+
+
+def test_extract_coords_formats_southeast():
+    query, x, y = helper.extract_coords_from_query('S 40 26.767, E 79 58.933')
+
+    assert query == ''
+    assert pytest.approx(x, abs=0.001) == 79.982
+    assert pytest.approx(y, abs=0.001) == -40.446
+
+
+@pytest.mark.parametrize('inp', ['[shop=fish] foo bar',
+                                 'foo [shop=fish] bar',
+                                 'foo [shop=fish]bar',
+                                 'foo bar [shop=fish]'])
+def test_extract_category_good(inp):
+    query, cls, typ = helper.extract_category_from_query(inp)
+
+    assert query == 'foo bar'
+    assert cls == 'shop'
+    assert typ == 'fish'
+
+def test_extract_category_only():
+    assert helper.extract_category_from_query('[shop=market]') == ('', 'shop', 'market')
+
+@pytest.mark.parametrize('inp', ['house []', 'nothing', '[352]'])
+def test_extract_category_no_match(inp):
+    assert helper.extract_category_from_query(inp) == (inp, None, None)
diff --git a/test/python/api/test_result_formatting_v1.py b/test/python/api/test_result_formatting_v1.py
index e0fcc025..0c54667e 100644
--- a/test/python/api/test_result_formatting_v1.py
+++ b/test/python/api/test_result_formatting_v1.py
@@ -75,11 +75,14 @@ def test_search_details_minimal():
            {'category': 'place',
             'type': 'thing',
             'admin_level': 15,
+            'names': {},
             'localname': '',
             'calculated_importance': pytest.approx(0.0000001),
             'rank_address': 30,
             'rank_search': 30,
             'isarea': False,
+            'addresstags': {},
+            'extratags': {},
             'centroid': {'type': 'Point', 'coordinates': [1.0, 2.0]},
             'geometry': {'type': 'Point', 'coordinates': [1.0, 2.0]},
            }
@@ -108,6 +111,7 @@ def test_search_details_full():
                       country_code='ll',
                       indexed_date = import_date
                       )
+    search.localize(napi.Locales())
 
     result = api_impl.format_result(search, 'json', {})
 
diff --git a/test/python/api/test_result_formatting_v1_reverse.py b/test/python/api/test_result_formatting_v1_reverse.py
index 6e94cf10..d9d43953 100644
--- a/test/python/api/test_result_formatting_v1_reverse.py
+++ b/test/python/api/test_result_formatting_v1_reverse.py
@@ -101,6 +101,7 @@ def test_format_reverse_with_address(fmt):
                                               rank_address=10,
                                               distance=0.0)
                              ]))
+    reverse.localize(napi.Locales())
 
     raw = api_impl.format_result(napi.ReverseResults([reverse]), fmt,
                                  {'addressdetails': True})
@@ -164,6 +165,8 @@ def test_format_reverse_geocodejson_special_parts():
                                               distance=0.0)
                              ]))
 
+    reverse.localize(napi.Locales())
+
     raw = api_impl.format_result(napi.ReverseResults([reverse]), 'geocodejson',
                                  {'addressdetails': True})
 
diff --git a/test/python/api/test_server_glue_v1.py b/test/python/api/test_server_glue_v1.py
index c0ca69dd..a731e720 100644
--- a/test/python/api/test_server_glue_v1.py
+++ b/test/python/api/test_server_glue_v1.py
@@ -32,9 +32,9 @@ FakeResponse = namedtuple('FakeResponse', ['status', 'output', 'content_type'])
 
 class FakeAdaptor(glue.ASGIAdaptor):
 
-    def __init__(self, params={}, headers={}, config=None):
-        self.params = params
-        self.headers = headers
+    def __init__(self, params=None, headers=None, config=None):
+        self.params = params or {}
+        self.headers = headers or {}
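+        # fresh dicts per instance: the mutable default arguments used before
+        # were shared between all FakeAdaptor objects across tests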
         self._config = config or Configuration(None)
 
 
@@ -123,7 +123,7 @@ def test_accepted_languages_from_param():
 
 
 def test_accepted_languages_from_header():
-    a = FakeAdaptor(headers={'http_accept_language': 'de'})
+    a = FakeAdaptor(headers={'accept-language': 'de'})
 
     assert a.get_accepted_languages() == 'de'
 
@@ -135,13 +135,13 @@ def test_accepted_languages_from_default(monkeypatch):
 
 def test_accepted_languages_param_over_header():
     a = FakeAdaptor(params={'accept-language': 'de'},
-                    headers={'http_accept_language': 'en'})
+                    headers={'accept-language': 'en'})
 
     assert a.get_accepted_languages() == 'de'
 
 
 def test_accepted_languages_header_over_default(monkeypatch):
     monkeypatch.setenv('NOMINATIM_DEFAULT_LANGUAGE', 'en')
-    a = FakeAdaptor(headers={'http_accept_language': 'de'})
+    a = FakeAdaptor(headers={'accept-language': 'de'})
 
     assert a.get_accepted_languages() == 'de'
 
@@ -197,14 +197,14 @@ def test_raise_error_during_debug():
     loglib.log().section('Ongoing')
 
     with pytest.raises(FakeError) as excinfo:
-        a.raise_error('bad state')
+        a.raise_error('badstate')
 
     content = ET.fromstring(excinfo.value.msg)
 
     assert content.tag == 'html'
 
     assert '>Ongoing<' in excinfo.value.msg
-    assert 'bad state' in excinfo.value.msg
+    assert 'badstate' in excinfo.value.msg
 
 
 # ASGIAdaptor.build_response
@@ -386,6 +386,63 @@ class TestDetailsEndpoint:
         await glue.details_endpoint(napi.NominatimAPIAsync(Path('/invalid')), a)
 
 
+# reverse_endpoint()
+class TestReverseEndPoint:
+
+    @pytest.fixture(autouse=True)
+    def patch_reverse_func(self, monkeypatch):
+        self.result = napi.ReverseResult(napi.SourceTable.PLACEX,
+                                         ('place', 'thing'),
+                                         napi.Point(1.0, 2.0))
+        async def _reverse(*args, **kwargs):
+            return self.result
+
+        monkeypatch.setattr(napi.NominatimAPIAsync, 'reverse', _reverse)
+
+
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize('params', [{}, {'lat': '3.4'}, {'lon': '6.7'}])
+    async def test_reverse_no_params(self, params):
+        a = FakeAdaptor()
+        a.params = params
+        a.params['format'] = 'xml'
+
+        with pytest.raises(FakeError, match='^400 -- (?s:.*)missing'):
+            await glue.reverse_endpoint(napi.NominatimAPIAsync(Path('/invalid')), a)
+
+
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize('params', [{'lat': '45.6', 'lon': '4563'}])
+    async def test_reverse_success(self, params):
+        a = FakeAdaptor()
+        a.params = params
+        a.params['format'] = 'json'
+
+        res = await glue.reverse_endpoint(napi.NominatimAPIAsync(Path('/invalid')), a)
+
+        assert res == ''
+
+
+    @pytest.mark.asyncio
+    async def test_reverse_success_no_format(self):
+        a = FakeAdaptor()
+        a.params['lat'] = '56.3'
+        a.params['lon'] = '6.8'
+
+        assert await glue.reverse_endpoint(napi.NominatimAPIAsync(Path('/invalid')), a)
+
+
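+    # A free-text query consisting only of coordinates is routed to the
+    # reverse geocoder even when it arrives through the search endpoint.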
+    @pytest.mark.asyncio
+    async def test_reverse_from_search(self):
+        a = FakeAdaptor()
+        a.params['q'] = '34.6 2.56'
+        a.params['format'] = 'json'
+
+        res = await glue.search_endpoint(napi.NominatimAPIAsync(Path('/invalid')), a)
+
+        assert len(json.loads(res.output)) == 1
+
+
 # lookup_endpoint()
 
 class TestLookupEndpoint:
@@ -444,3 +501,111 @@ class TestLookupEndpoint:
         res = await glue.lookup_endpoint(napi.NominatimAPIAsync(Path('/invalid')), a)
 
         assert len(json.loads(res.output)) == 1
+
+
+# search_endpoint()
+
+class TestSearchEndPointSearch:
+
+    @pytest.fixture(autouse=True)
+    def patch_lookup_func(self, monkeypatch):
+        self.results = [napi.SearchResult(napi.SourceTable.PLACEX,
+                                          ('place', 'thing'),
+                                          napi.Point(1.0, 2.0))]
+        async def _search(*args, **kwargs):
+            return napi.SearchResults(self.results)
+
+        monkeypatch.setattr(napi.NominatimAPIAsync, 'search', _search)
+
+
+    @pytest.mark.asyncio
+    async def test_search_free_text(self):
+        a = FakeAdaptor()
+        a.params['q'] = 'something'
+
+        res = await glue.search_endpoint(napi.NominatimAPIAsync(Path('/invalid')), a)
+
+        assert len(json.loads(res.output)) == 1
+
+
+    @pytest.mark.asyncio
+    async def test_search_free_text_xml(self):
+        a = FakeAdaptor()
+        a.params['q'] = 'something'
+        a.params['format'] = 'xml'
+
+        res = await glue.search_endpoint(napi.NominatimAPIAsync(Path('/invalid')), a)
+
+        assert res.status == 200
+        assert res.output.index('something') > 0
+
+
+    @pytest.mark.asyncio
+    async def test_search_free_and_structured(self):
+        a = FakeAdaptor()
+        a.params['q'] = 'something'
+        a.params['city'] = 'ignored'
+
+        res = await glue.search_endpoint(napi.NominatimAPIAsync(Path('/invalid')), a)
+
+        assert len(json.loads(res.output)) == 1
+
+
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize('dedupe,numres', [(True, 1), (False, 2)])
+    async def test_search_dedupe(self, dedupe, numres):
+        self.results = self.results * 2
+        a = FakeAdaptor()
+        a.params['q'] = 'something'
+        if not dedupe:
+            a.params['dedupe'] = '0'
+
+        res = await glue.search_endpoint(napi.NominatimAPIAsync(Path('/invalid')), a)
+
+        assert len(json.loads(res.output)) == numres
+
+
+class TestSearchEndPointSearchAddress:
+
+    @pytest.fixture(autouse=True)
+    def patch_lookup_func(self, monkeypatch):
+        self.results = [napi.SearchResult(napi.SourceTable.PLACEX,
+                                          ('place', 'thing'),
+                                          napi.Point(1.0, 2.0))]
+        async def _search(*args, **kwargs):
+            return napi.SearchResults(self.results)
+
+        monkeypatch.setattr(napi.NominatimAPIAsync, 'search_address', _search)
+
+
+    @pytest.mark.asyncio
+    async def test_search_structured(self):
+        a = FakeAdaptor()
+        a.params['street'] = 'something'
+
+        res = await glue.search_endpoint(napi.NominatimAPIAsync(Path('/invalid')), a)
+
+        assert len(json.loads(res.output)) == 1
+
+
+class TestSearchEndPointSearchCategory:
+
+    @pytest.fixture(autouse=True)
+    def patch_lookup_func(self, monkeypatch):
+        self.results = [napi.SearchResult(napi.SourceTable.PLACEX,
+                                          ('place', 'thing'),
+                                          napi.Point(1.0, 2.0))]
+        async def _search(*args, **kwargs):
+            return napi.SearchResults(self.results)
+
+        monkeypatch.setattr(napi.NominatimAPIAsync, 'search_category', _search)
+
+
+    @pytest.mark.asyncio
+    async def test_search_category(self):
+        a = FakeAdaptor()
+        a.params['q'] = '[shop=fog]'
+
+        res = await glue.search_endpoint(napi.NominatimAPIAsync(Path('/invalid')), a)
+
+        assert len(json.loads(res.output)) == 1
diff --git a/test/python/cli/test_cmd_api.py b/test/python/cli/test_cmd_api.py
index 2d7897a3..05e3c4f0 100644
--- a/test/python/cli/test_cmd_api.py
+++ b/test/python/cli/test_cmd_api.py
@@ -14,42 +14,6 @@
 import nominatim.clicmd.api
 import nominatim.api as napi
 
-@pytest.mark.parametrize("endpoint", (('search', 'reverse', 'lookup', 'details', 'status')))
-def test_no_api_without_phpcgi(endpoint):
-    assert nominatim.cli.nominatim(module_dir='MODULE NOT AVAILABLE',
-                                   osm2pgsql_path='OSM2PGSQL NOT AVAILABLE',
-                                   phpcgi_path=None,
-                                   cli_args=[endpoint]) == 1
-
-
-@pytest.mark.parametrize("params", [('search', '--query', 'new'),
-                                    ('search', '--city', 'Berlin')])
-class TestCliApiCallPhp:
-
-    @pytest.fixture(autouse=True)
-    def setup_cli_call(self, params, cli_call, mock_func_factory, tmp_path):
-        self.mock_run_api = mock_func_factory(nominatim.clicmd.api, 'run_api_script')
-
-        def _run():
-            return cli_call(*params, '--project-dir', str(tmp_path))
-
-        self.run_nominatim = _run
-
-
-    def test_api_commands_simple(self, tmp_path, params):
-        (tmp_path / 'website').mkdir()
-        (tmp_path / 'website' / (params[0] + '.php')).write_text('')
-
-        assert self.run_nominatim() == 0
-
-        assert self.mock_run_api.called == 1
-        assert self.mock_run_api.last_args[0] == params[0]
-
-
-    def test_bad_project_dir(self):
-        assert self.run_nominatim() == 1
-
-
 class TestCliStatusCall:
 
     @pytest.fixture(autouse=True)
@@ -181,72 +145,26 @@ class TestCliLookupCall:
         assert 'namedetails' not in out[0]
 
 
-class TestCliApiCommonParameters:
-
-    @pytest.fixture(autouse=True)
-    def setup_website_dir(self, cli_call, project_env):
-        self.cli_call = cli_call
-        self.project_dir = project_env.project_dir
-        (self.project_dir / 'website').mkdir()
-
-
-    def expect_param(self, param, expected):
-        (self.project_dir / 'website' / ('search.php')).write_text(f"""