]> git.openstreetmap.org Git - nominatim.git/commitdiff
make NominatimAPI[Async] a context manager
authorSarah Hoffmann <lonvia@denofr.de>
Mon, 19 Aug 2024 09:31:38 +0000 (11:31 +0200)
committerSarah Hoffmann <lonvia@denofr.de>
Mon, 19 Aug 2024 09:31:38 +0000 (11:31 +0200)
If close() isn't properly called, it can lead to odd error messages
about uncaught exceptions.

src/nominatim_api/core.py
test/python/api/conftest.py
test/python/api/search/test_icu_query_analyzer.py
test/python/api/search/test_legacy_query_analyzer.py
test/python/api/search/test_query_analyzer_factory.py
test/python/api/test_api_connection.py
test/python/api/test_api_deletable_v1.py
test/python/api/test_api_polygons_v1.py

index 6c4c37d7e824b681e91f3de6ace74968cc440813..ac5798625cc4900c8de1227892ce67da0716a0bd 100644 (file)
@@ -38,6 +38,8 @@ class NominatimAPIAsync: #pylint: disable=too-many-instance-attributes
         This class shares most of the functions with its synchronous
         version. There are some additional functions or parameters,
         which are documented below.
         This class shares most of the functions with its synchronous
         version. There are some additional functions or parameters,
         which are documented below.
+
+        This class should usually be used as a context manager in 'with' context.
     """
     def __init__(self, project_dir: Path,
                  environ: Optional[Mapping[str, str]] = None,
     """
     def __init__(self, project_dir: Path,
                  environ: Optional[Mapping[str, str]] = None,
@@ -166,6 +168,14 @@ class NominatimAPIAsync: #pylint: disable=too-many-instance-attributes
             await self._engine.dispose()
 
 
             await self._engine.dispose()
 
 
+    async def __aenter__(self) -> 'NominatimAPIAsync':
+        return self
+
+
+    async def __aexit__(self, *_: Any) -> None:
+        await self.close()
+
+
     @contextlib.asynccontextmanager
     async def begin(self) -> AsyncIterator[SearchConnection]:
         """ Create a new connection with automatic transaction handling.
     @contextlib.asynccontextmanager
     async def begin(self) -> AsyncIterator[SearchConnection]:
         """ Create a new connection with automatic transaction handling.
@@ -351,6 +361,8 @@ class NominatimAPI:
     """ This class provides a thin synchronous wrapper around the asynchronous
         Nominatim functions. It creates its own event loop and runs each
         synchronous function call to completion using that loop.
     """ This class provides a thin synchronous wrapper around the asynchronous
         Nominatim functions. It creates its own event loop and runs each
         synchronous function call to completion using that loop.
+
+        This class should usually be used as a context manager in 'with' context.
     """
 
     def __init__(self, project_dir: Path,
     """
 
     def __init__(self, project_dir: Path,
@@ -376,8 +388,17 @@ class NominatimAPI:
             This function also closes the asynchronous worker loop making
             the NominatimAPI object unusable.
         """
             This function also closes the asynchronous worker loop making
             the NominatimAPI object unusable.
         """
-        self._loop.run_until_complete(self._async_api.close())
-        self._loop.close()
+        if not self._loop.is_closed():
+            self._loop.run_until_complete(self._async_api.close())
+            self._loop.close()
+
+
+    def __enter__(self) -> 'NominatimAPI':
+        return self
+
+
+    def __exit__(self, *_: Any) -> None:
+        self.close()
 
 
     @property
 
 
     @property
index a902e2640a7996a5cedbbd3765cb25f593e0f3a3..0c770980acdada423eb2e8879503c53b744a1b01 100644 (file)
@@ -9,6 +9,7 @@ Helper fixtures for API call tests.
 """
 from pathlib import Path
 import pytest
 """
 from pathlib import Path
 import pytest
+import pytest_asyncio
 import time
 import datetime as dt
 
 import time
 import datetime as dt
 
@@ -244,3 +245,9 @@ def frontend(request, event_loop, tmp_path):
 
     for api in testapis:
         api.close()
 
     for api in testapis:
         api.close()
+
+
+@pytest_asyncio.fixture
+async def api(temp_db):
+    async with napi.NominatimAPIAsync(Path('/invalid')) as api:
+        yield api
index 8e5480fcfb9ce49693fea8dc23e2f5bd5e43b476..7f88879c14fd7d8c0a856997432bc64b007b2d96 100644 (file)
@@ -40,10 +40,9 @@ async def conn(table_factory):
     table_factory('word',
                   definition='word_id INT, word_token TEXT, type TEXT, word TEXT, info JSONB')
 
     table_factory('word',
                   definition='word_id INT, word_token TEXT, type TEXT, word TEXT, info JSONB')
 
-    api = NominatimAPIAsync(Path('/invalid'), {})
-    async with api.begin() as conn:
-        yield conn
-    await api.close()
+    async with NominatimAPIAsync(Path('/invalid'), {}) as api:
+        async with api.begin() as conn:
+            yield conn
 
 
 @pytest.mark.asyncio
 
 
 @pytest.mark.asyncio
index 92de8706f149619b03a6423069751cdfe3f8b577..0e967c10fa5f8e062825fc19e491352eee087bb1 100644 (file)
@@ -74,10 +74,9 @@ async def conn(table_factory, temp_db_cursor):
     temp_db_cursor.execute("""CREATE OR REPLACE FUNCTION make_standard_name(name TEXT)
                               RETURNS TEXT AS $$ SELECT lower(name); $$ LANGUAGE SQL;""")
 
     temp_db_cursor.execute("""CREATE OR REPLACE FUNCTION make_standard_name(name TEXT)
                               RETURNS TEXT AS $$ SELECT lower(name); $$ LANGUAGE SQL;""")
 
-    api = NominatimAPIAsync(Path('/invalid'), {})
-    async with api.begin() as conn:
-        yield conn
-    await api.close()
+    async with NominatimAPIAsync(Path('/invalid'), {}) as api:
+        async with api.begin() as conn:
+            yield conn
 
 
 @pytest.mark.asyncio
 
 
 @pytest.mark.asyncio
index 9545a88ff21ee03d1456205a3699841e649f5de2..42220b55958116a5ab6f26448e6a8431539e486f 100644 (file)
@@ -11,41 +11,35 @@ from pathlib import Path
 
 import pytest
 
 
 import pytest
 
-from nominatim_api import NominatimAPIAsync
 from nominatim_api.search.query_analyzer_factory import make_query_analyzer
 from nominatim_api.search.icu_tokenizer import ICUQueryAnalyzer
 
 @pytest.mark.asyncio
 from nominatim_api.search.query_analyzer_factory import make_query_analyzer
 from nominatim_api.search.icu_tokenizer import ICUQueryAnalyzer
 
 @pytest.mark.asyncio
-async def test_import_icu_tokenizer(table_factory):
+async def test_import_icu_tokenizer(table_factory, api):
     table_factory('nominatim_properties',
                   definition='property TEXT, value TEXT',
                   content=(('tokenizer', 'icu'),
                            ('tokenizer_import_normalisation', ':: lower();'),
                            ('tokenizer_import_transliteration', "'1' > '/1/'; 'ä' > 'ä '")))
 
     table_factory('nominatim_properties',
                   definition='property TEXT, value TEXT',
                   content=(('tokenizer', 'icu'),
                            ('tokenizer_import_normalisation', ':: lower();'),
                            ('tokenizer_import_transliteration', "'1' > '/1/'; 'ä' > 'ä '")))
 
-    api = NominatimAPIAsync(Path('/invalid'), {})
     async with api.begin() as conn:
         ana = await make_query_analyzer(conn)
 
         assert isinstance(ana, ICUQueryAnalyzer)
     async with api.begin() as conn:
         ana = await make_query_analyzer(conn)
 
         assert isinstance(ana, ICUQueryAnalyzer)
-    await api.close()
 
 
 @pytest.mark.asyncio
 
 
 @pytest.mark.asyncio
-async def test_import_missing_property(table_factory):
-    api = NominatimAPIAsync(Path('/invalid'), {})
+async def test_import_missing_property(table_factory, api):
     table_factory('nominatim_properties',
                   definition='property TEXT, value TEXT')
 
     async with api.begin() as conn:
         with pytest.raises(ValueError, match='Property.*not found'):
             await make_query_analyzer(conn)
     table_factory('nominatim_properties',
                   definition='property TEXT, value TEXT')
 
     async with api.begin() as conn:
         with pytest.raises(ValueError, match='Property.*not found'):
             await make_query_analyzer(conn)
-    await api.close()
 
 
 @pytest.mark.asyncio
 
 
 @pytest.mark.asyncio
-async def test_import_missing_module(table_factory):
-    api = NominatimAPIAsync(Path('/invalid'), {})
+async def test_import_missing_module(table_factory, api):
     table_factory('nominatim_properties',
                   definition='property TEXT, value TEXT',
                   content=(('tokenizer', 'missing'),))
     table_factory('nominatim_properties',
                   definition='property TEXT, value TEXT',
                   content=(('tokenizer', 'missing'),))
@@ -53,5 +47,3 @@ async def test_import_missing_module(table_factory):
     async with api.begin() as conn:
         with pytest.raises(RuntimeError, match='Tokenizer not found'):
             await make_query_analyzer(conn)
     async with api.begin() as conn:
         with pytest.raises(RuntimeError, match='Tokenizer not found'):
             await make_query_analyzer(conn)
-    await api.close()
-
index 3c4fc61b2656938ff2c2f234cf6aeed2b20b6ed8..f62b0d9e346c8357fe9ff5690741bdada8114935 100644 (file)
@@ -9,45 +9,34 @@ Tests for enhanced connection class for API functions.
 """
 from pathlib import Path
 import pytest
 """
 from pathlib import Path
 import pytest
-import pytest_asyncio
 
 import sqlalchemy as sa
 
 
 import sqlalchemy as sa
 
-from nominatim_api import NominatimAPIAsync
-
-@pytest_asyncio.fixture
-async def apiobj(temp_db):
-    """ Create an asynchronous SQLAlchemy engine for the test DB.
-    """
-    api = NominatimAPIAsync(Path('/invalid'), {})
-    yield api
-    await api.close()
-
 
 @pytest.mark.asyncio
 
 @pytest.mark.asyncio
-async def test_run_scalar(apiobj, table_factory):
+async def test_run_scalar(api, table_factory):
     table_factory('foo', definition='that TEXT', content=(('a', ),))
 
     table_factory('foo', definition='that TEXT', content=(('a', ),))
 
-    async with apiobj.begin() as conn:
+    async with api.begin() as conn:
         assert await conn.scalar(sa.text('SELECT * FROM foo')) == 'a'
 
 
 @pytest.mark.asyncio
         assert await conn.scalar(sa.text('SELECT * FROM foo')) == 'a'
 
 
 @pytest.mark.asyncio
-async def test_run_execute(apiobj, table_factory):
+async def test_run_execute(api, table_factory):
     table_factory('foo', definition='that TEXT', content=(('a', ),))
 
     table_factory('foo', definition='that TEXT', content=(('a', ),))
 
-    async with apiobj.begin() as conn:
+    async with api.begin() as conn:
         result = await conn.execute(sa.text('SELECT * FROM foo'))
         assert result.fetchone()[0] == 'a'
 
 
 @pytest.mark.asyncio
         result = await conn.execute(sa.text('SELECT * FROM foo'))
         assert result.fetchone()[0] == 'a'
 
 
 @pytest.mark.asyncio
-async def test_get_property_existing_cached(apiobj, table_factory):
+async def test_get_property_existing_cached(api, table_factory):
     table_factory('nominatim_properties',
                   definition='property TEXT, value TEXT',
                   content=(('dbv', '96723'), ))
 
     table_factory('nominatim_properties',
                   definition='property TEXT, value TEXT',
                   content=(('dbv', '96723'), ))
 
-    async with apiobj.begin() as conn:
+    async with api.begin() as conn:
         assert await conn.get_property('dbv') == '96723'
 
         await conn.execute(sa.text('TRUNCATE nominatim_properties'))
         assert await conn.get_property('dbv') == '96723'
 
         await conn.execute(sa.text('TRUNCATE nominatim_properties'))
@@ -56,12 +45,12 @@ async def test_get_property_existing_cached(apiobj, table_factory):
 
 
 @pytest.mark.asyncio
 
 
 @pytest.mark.asyncio
-async def test_get_property_existing_uncached(apiobj, table_factory):
+async def test_get_property_existing_uncached(api, table_factory):
     table_factory('nominatim_properties',
                   definition='property TEXT, value TEXT',
                   content=(('dbv', '96723'), ))
 
     table_factory('nominatim_properties',
                   definition='property TEXT, value TEXT',
                   content=(('dbv', '96723'), ))
 
-    async with apiobj.begin() as conn:
+    async with api.begin() as conn:
         assert await conn.get_property('dbv') == '96723'
 
         await conn.execute(sa.text("UPDATE nominatim_properties SET value = '1'"))
         assert await conn.get_property('dbv') == '96723'
 
         await conn.execute(sa.text("UPDATE nominatim_properties SET value = '1'"))
@@ -71,23 +60,23 @@ async def test_get_property_existing_uncached(apiobj, table_factory):
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize('param', ['foo', 'DB:server_version'])
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize('param', ['foo', 'DB:server_version'])
-async def test_get_property_missing(apiobj, table_factory, param):
+async def test_get_property_missing(api, table_factory, param):
     table_factory('nominatim_properties',
                   definition='property TEXT, value TEXT')
 
     table_factory('nominatim_properties',
                   definition='property TEXT, value TEXT')
 
-    async with apiobj.begin() as conn:
+    async with api.begin() as conn:
         with pytest.raises(ValueError):
             await conn.get_property(param)
 
 
 @pytest.mark.asyncio
         with pytest.raises(ValueError):
             await conn.get_property(param)
 
 
 @pytest.mark.asyncio
-async def test_get_db_property_existing(apiobj):
-    async with apiobj.begin() as conn:
+async def test_get_db_property_existing(api):
+    async with api.begin() as conn:
         assert await conn.get_db_property('server_version') > 0
 
 
 @pytest.mark.asyncio
         assert await conn.get_db_property('server_version') > 0
 
 
 @pytest.mark.asyncio
-async def test_get_db_property_existing(apiobj):
-    async with apiobj.begin() as conn:
+async def test_get_db_property_existing(api):
+    async with api.begin() as conn:
         with pytest.raises(ValueError):
             await conn.get_db_property('dfkgjd.rijg')
         with pytest.raises(ValueError):
             await conn.get_db_property('dfkgjd.rijg')
index 649dd8fc44f4619838ee7730da2e163698c0f462..9e1138869e8d632b322dce7d7306c7f9cd920f26 100644 (file)
@@ -11,19 +11,10 @@ import json
 from pathlib import Path
 
 import pytest
 from pathlib import Path
 
 import pytest
-import pytest_asyncio
 
 from fake_adaptor import FakeAdaptor, FakeError, FakeResponse
 
 import nominatim_api.v1.server_glue as glue
 
 from fake_adaptor import FakeAdaptor, FakeError, FakeResponse
 
 import nominatim_api.v1.server_glue as glue
-import nominatim_api as napi
-
-@pytest_asyncio.fixture
-async def api():
-    api = napi.NominatimAPIAsync(Path('/invalid'))
-    yield api
-    await api.close()
-
 
 class TestDeletableEndPoint:
 
 
 class TestDeletableEndPoint:
 
@@ -61,4 +52,3 @@ class TestDeletableEndPoint:
                            {'place_id': 3, 'country_code': 'cd', 'name': None,
                             'osm_id': 781, 'osm_type': 'R',
                             'class': 'landcover', 'type': 'grass'}]
                            {'place_id': 3, 'country_code': 'cd', 'name': None,
                             'osm_id': 781, 'osm_type': 'R',
                             'class': 'landcover', 'type': 'grass'}]
-
index 558be813e4d1b8c8183cd7ae5ae1826bb4d39149..ac2b4cb9fc09cc2e8662bb2d44a856fde6d0c33b 100644 (file)
@@ -12,19 +12,10 @@ import datetime as dt
 from pathlib import Path
 
 import pytest
 from pathlib import Path
 
 import pytest
-import pytest_asyncio
 
 from fake_adaptor import FakeAdaptor, FakeError, FakeResponse
 
 import nominatim_api.v1.server_glue as glue
 
 from fake_adaptor import FakeAdaptor, FakeError, FakeResponse
 
 import nominatim_api.v1.server_glue as glue
-import nominatim_api as napi
-
-@pytest_asyncio.fixture
-async def api():
-    api = napi.NominatimAPIAsync(Path('/invalid'))
-    yield api
-    await api.close()
-
 
 class TestPolygonsEndPoint:
 
 
 class TestPolygonsEndPoint: