]> git.openstreetmap.org Git - nominatim.git/commitdiff
make NominatimAPI[Async] a context manager
authorSarah Hoffmann <lonvia@denofr.de>
Mon, 19 Aug 2024 09:31:38 +0000 (11:31 +0200)
committerSarah Hoffmann <lonvia@denofr.de>
Mon, 19 Aug 2024 09:31:38 +0000 (11:31 +0200)
If close() isn't properly called, it can lead to odd error messages
about uncaught exceptions.

src/nominatim_api/core.py
test/python/api/conftest.py
test/python/api/search/test_icu_query_analyzer.py
test/python/api/search/test_legacy_query_analyzer.py
test/python/api/search/test_query_analyzer_factory.py
test/python/api/test_api_connection.py
test/python/api/test_api_deletable_v1.py
test/python/api/test_api_polygons_v1.py

index 6c4c37d7e824b681e91f3de6ace74968cc440813..ac5798625cc4900c8de1227892ce67da0716a0bd 100644 (file)
@@ -38,6 +38,8 @@ class NominatimAPIAsync: #pylint: disable=too-many-instance-attributes
         This class shares most of the functions with its synchronous
         version. There are some additional functions or parameters,
         which are documented below.
+
+        This class should usually be used as a context manager in 'with' context.
     """
     def __init__(self, project_dir: Path,
                  environ: Optional[Mapping[str, str]] = None,
@@ -166,6 +168,14 @@ class NominatimAPIAsync: #pylint: disable=too-many-instance-attributes
             await self._engine.dispose()
 
 
+    async def __aenter__(self) -> 'NominatimAPIAsync':
+        return self
+
+
+    async def __aexit__(self, *_: Any) -> None:
+        await self.close()
+
+
     @contextlib.asynccontextmanager
     async def begin(self) -> AsyncIterator[SearchConnection]:
         """ Create a new connection with automatic transaction handling.
@@ -351,6 +361,8 @@ class NominatimAPI:
     """ This class provides a thin synchronous wrapper around the asynchronous
         Nominatim functions. It creates its own event loop and runs each
         synchronous function call to completion using that loop.
+
+        This class should usually be used as a context manager in 'with' context.
     """
 
     def __init__(self, project_dir: Path,
@@ -376,8 +388,17 @@ class NominatimAPI:
             This function also closes the asynchronous worker loop making
             the NominatimAPI object unusable.
         """
-        self._loop.run_until_complete(self._async_api.close())
-        self._loop.close()
+        if not self._loop.is_closed():
+            self._loop.run_until_complete(self._async_api.close())
+            self._loop.close()
+
+
+    def __enter__(self) -> 'NominatimAPI':
+        return self
+
+
+    def __exit__(self, *_: Any) -> None:
+        self.close()
 
 
     @property
index a902e2640a7996a5cedbbd3765cb25f593e0f3a3..0c770980acdada423eb2e8879503c53b744a1b01 100644 (file)
@@ -9,6 +9,7 @@ Helper fixtures for API call tests.
 """
 from pathlib import Path
 import pytest
+import pytest_asyncio
 import time
 import datetime as dt
 
@@ -244,3 +245,9 @@ def frontend(request, event_loop, tmp_path):
 
     for api in testapis:
         api.close()
+
+
+@pytest_asyncio.fixture
+async def api(temp_db):
+    async with napi.NominatimAPIAsync(Path('/invalid')) as api:
+        yield api
index 8e5480fcfb9ce49693fea8dc23e2f5bd5e43b476..7f88879c14fd7d8c0a856997432bc64b007b2d96 100644 (file)
@@ -40,10 +40,9 @@ async def conn(table_factory):
     table_factory('word',
                   definition='word_id INT, word_token TEXT, type TEXT, word TEXT, info JSONB')
 
-    api = NominatimAPIAsync(Path('/invalid'), {})
-    async with api.begin() as conn:
-        yield conn
-    await api.close()
+    async with NominatimAPIAsync(Path('/invalid'), {}) as api:
+        async with api.begin() as conn:
+            yield conn
 
 
 @pytest.mark.asyncio
index 92de8706f149619b03a6423069751cdfe3f8b577..0e967c10fa5f8e062825fc19e491352eee087bb1 100644 (file)
@@ -74,10 +74,9 @@ async def conn(table_factory, temp_db_cursor):
     temp_db_cursor.execute("""CREATE OR REPLACE FUNCTION make_standard_name(name TEXT)
                               RETURNS TEXT AS $$ SELECT lower(name); $$ LANGUAGE SQL;""")
 
-    api = NominatimAPIAsync(Path('/invalid'), {})
-    async with api.begin() as conn:
-        yield conn
-    await api.close()
+    async with NominatimAPIAsync(Path('/invalid'), {}) as api:
+        async with api.begin() as conn:
+            yield conn
 
 
 @pytest.mark.asyncio
index 9545a88ff21ee03d1456205a3699841e649f5de2..42220b55958116a5ab6f26448e6a8431539e486f 100644 (file)
@@ -11,41 +11,35 @@ from pathlib import Path
 
 import pytest
 
-from nominatim_api import NominatimAPIAsync
 from nominatim_api.search.query_analyzer_factory import make_query_analyzer
 from nominatim_api.search.icu_tokenizer import ICUQueryAnalyzer
 
 @pytest.mark.asyncio
-async def test_import_icu_tokenizer(table_factory):
+async def test_import_icu_tokenizer(table_factory, api):
     table_factory('nominatim_properties',
                   definition='property TEXT, value TEXT',
                   content=(('tokenizer', 'icu'),
                            ('tokenizer_import_normalisation', ':: lower();'),
                            ('tokenizer_import_transliteration', "'1' > '/1/'; 'ä' > 'ä '")))
 
-    api = NominatimAPIAsync(Path('/invalid'), {})
     async with api.begin() as conn:
         ana = await make_query_analyzer(conn)
 
         assert isinstance(ana, ICUQueryAnalyzer)
-    await api.close()
 
 
 @pytest.mark.asyncio
-async def test_import_missing_property(table_factory):
-    api = NominatimAPIAsync(Path('/invalid'), {})
+async def test_import_missing_property(table_factory, api):
     table_factory('nominatim_properties',
                   definition='property TEXT, value TEXT')
 
     async with api.begin() as conn:
         with pytest.raises(ValueError, match='Property.*not found'):
             await make_query_analyzer(conn)
-    await api.close()
 
 
 @pytest.mark.asyncio
-async def test_import_missing_module(table_factory):
-    api = NominatimAPIAsync(Path('/invalid'), {})
+async def test_import_missing_module(table_factory, api):
     table_factory('nominatim_properties',
                   definition='property TEXT, value TEXT',
                   content=(('tokenizer', 'missing'),))
@@ -53,5 +47,3 @@ async def test_import_missing_module(table_factory):
     async with api.begin() as conn:
         with pytest.raises(RuntimeError, match='Tokenizer not found'):
             await make_query_analyzer(conn)
-    await api.close()
-
index 3c4fc61b2656938ff2c2f234cf6aeed2b20b6ed8..f62b0d9e346c8357fe9ff5690741bdada8114935 100644 (file)
@@ -9,45 +9,34 @@ Tests for enhanced connection class for API functions.
 """
 from pathlib import Path
 import pytest
-import pytest_asyncio
 
 import sqlalchemy as sa
 
-from nominatim_api import NominatimAPIAsync
-
-@pytest_asyncio.fixture
-async def apiobj(temp_db):
-    """ Create an asynchronous SQLAlchemy engine for the test DB.
-    """
-    api = NominatimAPIAsync(Path('/invalid'), {})
-    yield api
-    await api.close()
-
 
 @pytest.mark.asyncio
-async def test_run_scalar(apiobj, table_factory):
+async def test_run_scalar(api, table_factory):
     table_factory('foo', definition='that TEXT', content=(('a', ),))
 
-    async with apiobj.begin() as conn:
+    async with api.begin() as conn:
         assert await conn.scalar(sa.text('SELECT * FROM foo')) == 'a'
 
 
 @pytest.mark.asyncio
-async def test_run_execute(apiobj, table_factory):
+async def test_run_execute(api, table_factory):
     table_factory('foo', definition='that TEXT', content=(('a', ),))
 
-    async with apiobj.begin() as conn:
+    async with api.begin() as conn:
         result = await conn.execute(sa.text('SELECT * FROM foo'))
         assert result.fetchone()[0] == 'a'
 
 
 @pytest.mark.asyncio
-async def test_get_property_existing_cached(apiobj, table_factory):
+async def test_get_property_existing_cached(api, table_factory):
     table_factory('nominatim_properties',
                   definition='property TEXT, value TEXT',
                   content=(('dbv', '96723'), ))
 
-    async with apiobj.begin() as conn:
+    async with api.begin() as conn:
         assert await conn.get_property('dbv') == '96723'
 
         await conn.execute(sa.text('TRUNCATE nominatim_properties'))
@@ -56,12 +45,12 @@ async def test_get_property_existing_cached(apiobj, table_factory):
 
 
 @pytest.mark.asyncio
-async def test_get_property_existing_uncached(apiobj, table_factory):
+async def test_get_property_existing_uncached(api, table_factory):
     table_factory('nominatim_properties',
                   definition='property TEXT, value TEXT',
                   content=(('dbv', '96723'), ))
 
-    async with apiobj.begin() as conn:
+    async with api.begin() as conn:
         assert await conn.get_property('dbv') == '96723'
 
         await conn.execute(sa.text("UPDATE nominatim_properties SET value = '1'"))
@@ -71,23 +60,23 @@ async def test_get_property_existing_uncached(apiobj, table_factory):
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize('param', ['foo', 'DB:server_version'])
-async def test_get_property_missing(apiobj, table_factory, param):
+async def test_get_property_missing(api, table_factory, param):
     table_factory('nominatim_properties',
                   definition='property TEXT, value TEXT')
 
-    async with apiobj.begin() as conn:
+    async with api.begin() as conn:
         with pytest.raises(ValueError):
             await conn.get_property(param)
 
 
 @pytest.mark.asyncio
-async def test_get_db_property_existing(apiobj):
-    async with apiobj.begin() as conn:
+async def test_get_db_property_existing(api):
+    async with api.begin() as conn:
         assert await conn.get_db_property('server_version') > 0
 
 
 @pytest.mark.asyncio
-async def test_get_db_property_existing(apiobj):
-    async with apiobj.begin() as conn:
+async def test_get_db_property_existing(api):
+    async with api.begin() as conn:
         with pytest.raises(ValueError):
             await conn.get_db_property('dfkgjd.rijg')
index 649dd8fc44f4619838ee7730da2e163698c0f462..9e1138869e8d632b322dce7d7306c7f9cd920f26 100644 (file)
@@ -11,19 +11,10 @@ import json
 from pathlib import Path
 
 import pytest
-import pytest_asyncio
 
 from fake_adaptor import FakeAdaptor, FakeError, FakeResponse
 
 import nominatim_api.v1.server_glue as glue
-import nominatim_api as napi
-
-@pytest_asyncio.fixture
-async def api():
-    api = napi.NominatimAPIAsync(Path('/invalid'))
-    yield api
-    await api.close()
-
 
 class TestDeletableEndPoint:
 
@@ -61,4 +52,3 @@ class TestDeletableEndPoint:
                            {'place_id': 3, 'country_code': 'cd', 'name': None,
                             'osm_id': 781, 'osm_type': 'R',
                             'class': 'landcover', 'type': 'grass'}]
-
index 558be813e4d1b8c8183cd7ae5ae1826bb4d39149..ac2b4cb9fc09cc2e8662bb2d44a856fde6d0c33b 100644 (file)
@@ -12,19 +12,10 @@ import datetime as dt
 from pathlib import Path
 
 import pytest
-import pytest_asyncio
 
 from fake_adaptor import FakeAdaptor, FakeError, FakeResponse
 
 import nominatim_api.v1.server_glue as glue
-import nominatim_api as napi
-
-@pytest_asyncio.fixture
-async def api():
-    api = napi.NominatimAPIAsync(Path('/invalid'))
-    yield api
-    await api.close()
-
 
 class TestPolygonsEndPoint: