]> git.openstreetmap.org Git - nominatim.git/blob - test/python/api/search/test_search_poi.py
make NominatimAPI[Async] a context manager
[nominatim.git] / test / python / api / search / test_search_poi.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2024 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Tests for running the POI searcher.
9 """
10 import pytest
11
12 import nominatim_api as napi
13 from nominatim_api.types import SearchDetails
14 from nominatim_api.search.db_searches import PoiSearch
15 from nominatim_api.search.db_search_fields import WeightedStrings, WeightedCategories
16
17
def run_search(apiobj, frontend, global_penalty, poitypes, poi_penalties=None,
               ccodes=None, details=None):
    """ Run a PoiSearch directly against the test database and return
        the result of its `lookup()`.

        Arguments:
          apiobj: database fixture the test places were added to.
          frontend: fixture factory that creates the API object for `apiobj`.
          global_penalty: penalty assigned to the search as a whole.
          poitypes: list of (class, type) tuples the search filters by.
          poi_penalties: per-category penalties matching `poitypes`;
                         0.0 for every category when omitted.
          ccodes: country codes to restrict the results to; omit for no
                  country restriction.
          details: SearchDetails for the lookup (near point, viewbox, ...);
                   a fresh `SearchDetails()` is used when omitted.
    """
    # None defaults instead of `[]` / `SearchDetails()`: those objects would
    # be created once when the function is defined and then shared by every
    # call that relies on the default.
    if poi_penalties is None:
        poi_penalties = [0.0] * len(poitypes)
    if ccodes is None:
        ccodes = []
    if details is None:
        details = SearchDetails()

    class MySearchData:
        # Stand-in for the search data PoiSearch is constructed from
        # (provides penalty, qualifiers and countries).
        penalty = global_penalty
        qualifiers = WeightedCategories(poitypes, poi_penalties)
        countries = WeightedStrings(ccodes, [0.0] * len(ccodes))

    search = PoiSearch(MySearchData())

    api = frontend(apiobj, options=['search'])

    async def run():
        # Execute the lookup on a connection from the API's async backend.
        async with api._async_api.begin() as conn:
            return await search.lookup(conn, details)

    return api._loop.run_until_complete(run())
37
38
@pytest.mark.parametrize('coord,pid', [('34.3, 56.100021', 2),
                                       ('5.0, 4.59933', 1)])
def test_simple_near_search_in_placex(apiobj, frontend, coord, pid):
    """ With no class/type table and a tiny 'near' radius (0.001), only the
        bus stop right next to `coord` is returned.
    """
    for place_id, centroid in ((1, (5.0, 4.6)), (2, (34.3, 56.1))):
        apiobj.add_placex(place_id=place_id, class_='highway', type='bus_stop',
                          centroid=centroid)

    near_details = SearchDetails.from_kwargs({'near': coord, 'near_radius': 0.001})
    found = run_search(apiobj, frontend, 0.1, [('highway', 'bus_stop')], [0.5],
                       details=near_details)

    assert [res.place_id for res in found] == [pid]
52
53
@pytest.mark.parametrize('coord,pid', [('34.3, 56.100021', 2),
                                       ('34.3, 56.4', 2),
                                       ('5.0, 4.59933', 1)])
def test_simple_near_search_in_classtype(apiobj, frontend, coord, pid):
    """ After creating the highway/bus_stop class/type table, a 'near'
        search with radius 0.5 returns only the bus stop close to `coord`.
    """
    for place_id, centroid in ((1, (5.0, 4.6)), (2, (34.3, 56.1))):
        apiobj.add_placex(place_id=place_id, class_='highway', type='bus_stop',
                          centroid=centroid)
    apiobj.add_class_type_table('highway', 'bus_stop')

    near_details = SearchDetails.from_kwargs({'near': coord, 'near_radius': 0.5})
    found = run_search(apiobj, frontend, 0.1, [('highway', 'bus_stop')], [0.5],
                       details=near_details)

    assert [res.place_id for res in found] == [pid]
69
70
class TestPoiSearchWithRestrictions:
    """ Country and viewbox restrictions on the POI search. Every test runs
        twice: once with plain placex data ("placex") and once with an
        additional class/type table ("classtype").
    """

    @pytest.fixture(autouse=True, params=["placex", "classtype"])
    def fill_database(self, apiobj, request):
        """ Add two nearby bus stops in different countries (place 1 in 'au',
            place 2 in 'nz') and store in `self.args` the 'near' search
            arguments used by the current variant.
        """
        apiobj.add_placex(place_id=1, class_='highway', type='bus_stop',
                          country_code='au',
                          centroid=(34.3, 56.10003))
        apiobj.add_placex(place_id=2, class_='highway', type='bus_stop',
                          country_code='nz',
                          centroid=(34.3, 56.1))
        if request.param == 'classtype':
            apiobj.add_class_type_table('highway', 'bus_stop')
            self.args = {'near': '34.3, 56.4', 'near_radius': 0.5}
        else:
            self.args = {'near': '34.3, 56.100021', 'near_radius': 0.001}


    def test_unrestricted(self, apiobj, frontend):
        """ Without restrictions both bus stops are found, place 1 first. """
        results = run_search(apiobj, frontend, 0.1, [('highway', 'bus_stop')], [0.5],
                             details=SearchDetails.from_kwargs(self.args))

        assert [r.place_id for r in results] == [1, 2]


    def test_restrict_country(self, apiobj, frontend):
        """ Only places in one of the requested countries are returned. """
        results = run_search(apiobj, frontend, 0.1, [('highway', 'bus_stop')], [0.5],
                             ccodes=['de', 'nz'],
                             details=SearchDetails.from_kwargs(self.args))

        assert [r.place_id for r in results] == [2]


    def test_restrict_by_viewbox(self, apiobj, frontend):
        """ A bounded viewbox excludes place 1: its y coordinate 56.10003
            lies beyond the viewbox's upper bound of 56.10001, while place 2
            (56.1) is inside.
        """
        # No country restriction here: passing ccodes would exclude place 1
        # ('au') on its own and hide a broken viewbox filter.
        args = {'bounded_viewbox': True, 'viewbox': '34.299,56.0,34.3001,56.10001'}
        args.update(self.args)
        results = run_search(apiobj, frontend, 0.1, [('highway', 'bus_stop')], [0.5],
                             details=SearchDetails.from_kwargs(args))

        assert [r.place_id for r in results] == [2]