git.openstreetmap.org Git - nominatim.git/blobdiff - src/nominatim_api/core.py
make NominatimAPI[Async] a context manager
[nominatim.git] / src / nominatim_api / core.py
index 632c97a7a6af93d387a28bc070f22d41f67d3f33..ac5798625cc4900c8de1227892ce67da0716a0bd 100644 (file)
@@ -7,7 +7,7 @@
 """
 Implementation of classes for API access via libraries.
 """
-from typing import Mapping, Optional, Any, AsyncIterator, Dict, Sequence, List, Tuple
+from typing import Mapping, Optional, Any, AsyncIterator, Dict, Sequence, List, Tuple, cast
 import asyncio
 import sys
 import contextlib
@@ -38,6 +38,8 @@ class NominatimAPIAsync: #pylint: disable=too-many-instance-attributes
         This class shares most of the functions with its synchronous
         version. There are some additional functions or parameters,
         which are documented below.
+
+        This class should usually be used as an asynchronous context manager in an 'async with' statement.
     """
     def __init__(self, project_dir: Path,
                  environ: Optional[Mapping[str, str]] = None,
@@ -107,16 +109,16 @@ class NominatimAPIAsync: #pylint: disable=too-many-instance-attributes
                     raise UsageError(f"SQlite database '{params.get('dbname')}' does not exist.")
             else:
                 dsn = self.config.get_database_params()
-                query = {k: v for k, v in dsn.items()
+                query = {k: str(v) for k, v in dsn.items()
                          if k not in ('user', 'password', 'dbname', 'host', 'port')}
 
                 dburl = sa.engine.URL.create(
                            f'postgresql+{PGCORE_LIB}',
-                           database=dsn.get('dbname'),
-                           username=dsn.get('user'),
-                           password=dsn.get('password'),
-                           host=dsn.get('host'),
-                           port=int(dsn['port']) if 'port' in dsn else None,
+                           database=cast(str, dsn.get('dbname')),
+                           username=cast(str, dsn.get('user')),
+                           password=cast(str, dsn.get('password')),
+                           host=cast(str, dsn.get('host')),
+                           port=int(cast(str, dsn['port'])) if 'port' in dsn else None,
                            query=query)
 
             engine = sa_asyncio.create_async_engine(dburl, **extra_args)
@@ -166,6 +168,14 @@ class NominatimAPIAsync: #pylint: disable=too-many-instance-attributes
             await self._engine.dispose()
 
 
+    async def __aenter__(self) -> 'NominatimAPIAsync': # enables 'async with NominatimAPIAsync(...) as api:'
+        return self
+
+
+    async def __aexit__(self, *_: Any) -> None: # exception info is ignored; returning None lets exceptions propagate
+        await self.close() # shut the API down via close() when leaving the 'async with' block
+
+
     @contextlib.asynccontextmanager
     async def begin(self) -> AsyncIterator[SearchConnection]:
         """ Create a new connection with automatic transaction handling.
@@ -351,6 +361,8 @@ class NominatimAPI:
     """ This class provides a thin synchronous wrapper around the asynchronous
         Nominatim functions. It creates its own event loop and runs each
         synchronous function call to completion using that loop.
+
+        This class should usually be used as a context manager in a 'with' statement.
     """
 
     def __init__(self, project_dir: Path,
@@ -376,13 +388,22 @@ class NominatimAPI:
             This function also closes the asynchronous worker loop making
             the NominatimAPI object unusable.
         """
-        self._loop.run_until_complete(self._async_api.close())
-        self._loop.close()
+        if not self._loop.is_closed():
+            self._loop.run_until_complete(self._async_api.close())
+            self._loop.close()
+
+
+    def __enter__(self) -> 'NominatimAPI': # enables 'with NominatimAPI(...) as api:'
+        return self
+
+
+    def __exit__(self, *_: Any) -> None: # exception info is ignored; returning None lets exceptions propagate
+        self.close() # safe even after an explicit close(): close() skips work once the loop is closed
 
 
     @property
     def config(self) -> Configuration:
-        """ Provide read-only access to the [configuration](#Configuration)
+        """ Provide read-only access to the [configuration](Configuration.md)
             used by the API.
         """
         return self._async_api.config