]> git.openstreetmap.org Git - nominatim.git/blobdiff - src/nominatim_api/server/starlette/server.py
Merge pull request #3590 from lonvia/lookup-per-osm-type
[nominatim.git] / src / nominatim_api / server / starlette / server.py
index 48f0207ac7b2f21e07a16ca83ad627d984b74200..e6c97693dd8ab03f71bdbfa07129b83c044d7342 100644 (file)
@@ -7,10 +7,12 @@
 """
 Server implementation using the starlette webserver framework.
 """
 """
 Server implementation using the starlette webserver framework.
 """
-from typing import Any, Optional, Mapping, Callable, cast, Coroutine, Dict, Awaitable
+from typing import Any, Optional, Mapping, Callable, cast, Coroutine, Dict, \
+                   Awaitable, AsyncIterator
 from pathlib import Path
 import datetime as dt
 import asyncio
 from pathlib import Path
 import datetime as dt
 import asyncio
+import contextlib
 
 from starlette.applications import Starlette
 from starlette.routing import Route
 
 from starlette.applications import Starlette
 from starlette.routing import Route
@@ -66,7 +68,7 @@ class ParamWrapper(ASGIAdaptor):
         return cast(Configuration, self.request.app.state.API.config)
 
     def formatting(self) -> FormatDispatcher:
         return cast(Configuration, self.request.app.state.API.config)
 
     def formatting(self) -> FormatDispatcher:
-        return cast(FormatDispatcher, self.request.app.state.API.formatter)
+        return cast(FormatDispatcher, self.request.app.state.formatter)
 
 
 def _wrap_endpoint(func: EndpointFunc)\
 
 
 def _wrap_endpoint(func: EndpointFunc)\
@@ -132,14 +134,6 @@ def get_application(project_dir: Path,
     """
     config = Configuration(project_dir, environ)
 
     """
     config = Configuration(project_dir, environ)
 
-    routes = []
-    legacy_urls = config.get_bool('SERVE_LEGACY_URLS')
-    for name, func in api_impl.ROUTES:
-        endpoint = _wrap_endpoint(func)
-        routes.append(Route(f"/{name}", endpoint=endpoint))
-        if legacy_urls:
-            routes.append(Route(f"/{name}.php", endpoint=endpoint))
-
     middleware = []
     if config.get_bool('CORS_NOACCESSCONTROL'):
         middleware.append(Middleware(CORSMiddleware,
     middleware = []
     if config.get_bool('CORS_NOACCESSCONTROL'):
         middleware.append(Middleware(CORSMiddleware,
@@ -156,14 +150,26 @@ def get_application(project_dir: Path,
         asyncio.TimeoutError: timeout_error
     }
 
         asyncio.TimeoutError: timeout_error
     }
 
-    async def _shutdown() -> None:
+    @contextlib.asynccontextmanager
+    async def lifespan(app: Starlette) -> AsyncIterator[Any]:
+        app.state.API = NominatimAPIAsync(project_dir, environ)
+        config = app.state.API.config
+
+        legacy_urls = config.get_bool('SERVE_LEGACY_URLS')
+        for name, func in await api_impl.get_routes(app.state.API):
+            endpoint = _wrap_endpoint(func)
+            app.routes.append(Route(f"/{name}", endpoint=endpoint))
+            if legacy_urls:
+                app.routes.append(Route(f"/{name}.php", endpoint=endpoint))
+
+        yield
+
         await app.state.API.close()
 
         await app.state.API.close()
 
-    app = Starlette(debug=debug, routes=routes, middleware=middleware,
+    app = Starlette(debug=debug, middleware=middleware,
                     exception_handlers=exceptions,
                     exception_handlers=exceptions,
-                    on_shutdown=[_shutdown])
+                    lifespan=lifespan)
 
 
-    app.state.API = NominatimAPIAsync(project_dir, environ)
     app.state.formatter = load_format_dispatcher('v1', project_dir)
 
     return app
     app.state.formatter = load_format_dispatcher('v1', project_dir)
 
     return app