From: Sarah Hoffmann Date: Mon, 19 Aug 2024 09:31:38 +0000 (+0200) Subject: make NominatimAPI[Async] a context manager X-Git-Tag: deploy~1^2~2^2~2 X-Git-Url: https://git.openstreetmap.org./nominatim.git/commitdiff_plain/c2594aca4075fe7b7a70d45a068523ca3d372ab2?ds=sidebyside make NominatimAPI[Async] a context manager If close() isn't properly called, it can lead to odd error messages about uncaught exceptions. --- diff --git a/src/nominatim_api/core.py b/src/nominatim_api/core.py index 6c4c37d7..ac579862 100644 --- a/src/nominatim_api/core.py +++ b/src/nominatim_api/core.py @@ -38,6 +38,8 @@ class NominatimAPIAsync: #pylint: disable=too-many-instance-attributes This class shares most of the functions with its synchronous version. There are some additional functions or parameters, which are documented below. + + This class should usually be used as a context manager in an 'async with' block. """ def __init__(self, project_dir: Path, environ: Optional[Mapping[str, str]] = None, @@ -166,6 +168,14 @@ class NominatimAPIAsync: #pylint: disable=too-many-instance-attributes await self._engine.dispose() + async def __aenter__(self) -> 'NominatimAPIAsync': + return self + + + async def __aexit__(self, *_: Any) -> None: + await self.close() + + @contextlib.asynccontextmanager async def begin(self) -> AsyncIterator[SearchConnection]: """ Create a new connection with automatic transaction handling. @@ -351,6 +361,8 @@ class NominatimAPI: """ This class provides a thin synchronous wrapper around the asynchronous Nominatim functions. It creates its own event loop and runs each synchronous function call to completion using that loop. + + This class should usually be used as a context manager in a 'with' block. """ def __init__(self, project_dir: Path, @@ -376,8 +388,17 @@ class NominatimAPI: This function also closes the asynchronous worker loop making the NominatimAPI object unusable.
""" - self._loop.run_until_complete(self._async_api.close()) - self._loop.close() + if not self._loop.is_closed(): + self._loop.run_until_complete(self._async_api.close()) + self._loop.close() + + + def __enter__(self) -> 'NominatimAPI': + return self + + + def __exit__(self, *_: Any) -> None: + self.close() @property diff --git a/test/python/api/conftest.py b/test/python/api/conftest.py index a902e264..0c770980 100644 --- a/test/python/api/conftest.py +++ b/test/python/api/conftest.py @@ -9,6 +9,7 @@ Helper fixtures for API call tests. """ from pathlib import Path import pytest +import pytest_asyncio import time import datetime as dt @@ -244,3 +245,9 @@ def frontend(request, event_loop, tmp_path): for api in testapis: api.close() + + +@pytest_asyncio.fixture +async def api(temp_db): + async with napi.NominatimAPIAsync(Path('/invalid')) as api: + yield api diff --git a/test/python/api/search/test_icu_query_analyzer.py b/test/python/api/search/test_icu_query_analyzer.py index 8e5480fc..7f88879c 100644 --- a/test/python/api/search/test_icu_query_analyzer.py +++ b/test/python/api/search/test_icu_query_analyzer.py @@ -40,10 +40,9 @@ async def conn(table_factory): table_factory('word', definition='word_id INT, word_token TEXT, type TEXT, word TEXT, info JSONB') - api = NominatimAPIAsync(Path('/invalid'), {}) - async with api.begin() as conn: - yield conn - await api.close() + async with NominatimAPIAsync(Path('/invalid'), {}) as api: + async with api.begin() as conn: + yield conn @pytest.mark.asyncio diff --git a/test/python/api/search/test_legacy_query_analyzer.py b/test/python/api/search/test_legacy_query_analyzer.py index 92de8706..0e967c10 100644 --- a/test/python/api/search/test_legacy_query_analyzer.py +++ b/test/python/api/search/test_legacy_query_analyzer.py @@ -74,10 +74,9 @@ async def conn(table_factory, temp_db_cursor): temp_db_cursor.execute("""CREATE OR REPLACE FUNCTION make_standard_name(name TEXT) RETURNS TEXT AS $$ SELECT lower(name); $$ LANGUAGE
SQL;""") - api = NominatimAPIAsync(Path('/invalid'), {}) - async with api.begin() as conn: - yield conn - await api.close() + async with NominatimAPIAsync(Path('/invalid'), {}) as api: + async with api.begin() as conn: + yield conn @pytest.mark.asyncio diff --git a/test/python/api/search/test_query_analyzer_factory.py b/test/python/api/search/test_query_analyzer_factory.py index 9545a88f..42220b55 100644 --- a/test/python/api/search/test_query_analyzer_factory.py +++ b/test/python/api/search/test_query_analyzer_factory.py @@ -11,41 +11,35 @@ from pathlib import Path import pytest -from nominatim_api import NominatimAPIAsync from nominatim_api.search.query_analyzer_factory import make_query_analyzer from nominatim_api.search.icu_tokenizer import ICUQueryAnalyzer @pytest.mark.asyncio -async def test_import_icu_tokenizer(table_factory): +async def test_import_icu_tokenizer(table_factory, api): table_factory('nominatim_properties', definition='property TEXT, value TEXT', content=(('tokenizer', 'icu'), ('tokenizer_import_normalisation', ':: lower();'), ('tokenizer_import_transliteration', "'1' > '/1/'; 'ä' > 'ä '"))) - api = NominatimAPIAsync(Path('/invalid'), {}) async with api.begin() as conn: ana = await make_query_analyzer(conn) assert isinstance(ana, ICUQueryAnalyzer) - await api.close() @pytest.mark.asyncio -async def test_import_missing_property(table_factory): - api = NominatimAPIAsync(Path('/invalid'), {}) +async def test_import_missing_property(table_factory, api): table_factory('nominatim_properties', definition='property TEXT, value TEXT') async with api.begin() as conn: with pytest.raises(ValueError, match='Property.*not found'): await make_query_analyzer(conn) - await api.close() @pytest.mark.asyncio -async def test_import_missing_module(table_factory): - api = NominatimAPIAsync(Path('/invalid'), {}) +async def test_import_missing_module(table_factory, api): table_factory('nominatim_properties', definition='property TEXT, value TEXT',
content=(('tokenizer', 'missing'),)) @@ -53,5 +47,3 @@ async def test_import_missing_module(table_factory): async with api.begin() as conn: with pytest.raises(RuntimeError, match='Tokenizer not found'): await make_query_analyzer(conn) - await api.close() - diff --git a/test/python/api/test_api_connection.py b/test/python/api/test_api_connection.py index 3c4fc61b..f62b0d9e 100644 --- a/test/python/api/test_api_connection.py +++ b/test/python/api/test_api_connection.py @@ -9,45 +9,34 @@ Tests for enhanced connection class for API functions. """ from pathlib import Path import pytest -import pytest_asyncio import sqlalchemy as sa -from nominatim_api import NominatimAPIAsync - -@pytest_asyncio.fixture -async def apiobj(temp_db): - """ Create an asynchronous SQLAlchemy engine for the test DB. - """ - api = NominatimAPIAsync(Path('/invalid'), {}) - yield api - await api.close() - @pytest.mark.asyncio -async def test_run_scalar(apiobj, table_factory): +async def test_run_scalar(api, table_factory): table_factory('foo', definition='that TEXT', content=(('a', ),)) - async with apiobj.begin() as conn: + async with api.begin() as conn: assert await conn.scalar(sa.text('SELECT * FROM foo')) == 'a' @pytest.mark.asyncio -async def test_run_execute(apiobj, table_factory): +async def test_run_execute(api, table_factory): table_factory('foo', definition='that TEXT', content=(('a', ),)) - async with apiobj.begin() as conn: + async with api.begin() as conn: result = await conn.execute(sa.text('SELECT * FROM foo')) assert result.fetchone()[0] == 'a' @pytest.mark.asyncio -async def test_get_property_existing_cached(apiobj, table_factory): +async def test_get_property_existing_cached(api, table_factory): table_factory('nominatim_properties', definition='property TEXT, value TEXT', content=(('dbv', '96723'), )) - async with apiobj.begin() as conn: + async with api.begin() as conn: assert await conn.get_property('dbv') == '96723' await conn.execute(sa.text('TRUNCATE
nominatim_properties')) @@ -56,12 +45,12 @@ async def test_get_property_existing_cached(apiobj, table_factory): @pytest.mark.asyncio -async def test_get_property_existing_uncached(apiobj, table_factory): +async def test_get_property_existing_uncached(api, table_factory): table_factory('nominatim_properties', definition='property TEXT, value TEXT', content=(('dbv', '96723'), )) - async with apiobj.begin() as conn: + async with api.begin() as conn: assert await conn.get_property('dbv') == '96723' await conn.execute(sa.text("UPDATE nominatim_properties SET value = '1'")) @@ -71,23 +60,23 @@ async def test_get_property_existing_uncached(apiobj, table_factory): @pytest.mark.asyncio @pytest.mark.parametrize('param', ['foo', 'DB:server_version']) -async def test_get_property_missing(apiobj, table_factory, param): +async def test_get_property_missing(api, table_factory, param): table_factory('nominatim_properties', definition='property TEXT, value TEXT') - async with apiobj.begin() as conn: + async with api.begin() as conn: with pytest.raises(ValueError): await conn.get_property(param) @pytest.mark.asyncio -async def test_get_db_property_existing(apiobj): - async with apiobj.begin() as conn: +async def test_get_db_property_existing(api): + async with api.begin() as conn: assert await conn.get_db_property('server_version') > 0 @pytest.mark.asyncio -async def test_get_db_property_existing(apiobj): - async with apiobj.begin() as conn: +async def test_get_db_property_missing(api): + async with api.begin() as conn: with pytest.raises(ValueError): await conn.get_db_property('dfkgjd.rijg') diff --git a/test/python/api/test_api_deletable_v1.py b/test/python/api/test_api_deletable_v1.py index 649dd8fc..9e113886 100644 --- a/test/python/api/test_api_deletable_v1.py +++ b/test/python/api/test_api_deletable_v1.py @@ -11,19 +11,10 @@ import json from pathlib import Path import pytest -import pytest_asyncio from fake_adaptor import FakeAdaptor, FakeError, FakeResponse import
nominatim_api.v1.server_glue as glue -import nominatim_api as napi - -@pytest_asyncio.fixture -async def api(): - api = napi.NominatimAPIAsync(Path('/invalid')) - yield api - await api.close() - class TestDeletableEndPoint: @@ -61,4 +52,3 @@ class TestDeletableEndPoint: {'place_id': 3, 'country_code': 'cd', 'name': None, 'osm_id': 781, 'osm_type': 'R', 'class': 'landcover', 'type': 'grass'}] - diff --git a/test/python/api/test_api_polygons_v1.py b/test/python/api/test_api_polygons_v1.py index 558be813..ac2b4cb9 100644 --- a/test/python/api/test_api_polygons_v1.py +++ b/test/python/api/test_api_polygons_v1.py @@ -12,19 +12,10 @@ import datetime as dt from pathlib import Path import pytest -import pytest_asyncio from fake_adaptor import FakeAdaptor, FakeError, FakeResponse import nominatim_api.v1.server_glue as glue -import nominatim_api as napi - -@pytest_asyncio.fixture -async def api(): - api = napi.NominatimAPIAsync(Path('/invalid')) - yield api - await api.close() - class TestPolygonsEndPoint: