]> git.openstreetmap.org Git - nominatim.git/blobdiff - src/nominatim_api/server/starlette/server.py
Merge remote-tracking branch 'upstream/master'
[nominatim.git] / src / nominatim_api / server / starlette / server.py
index 5f5cf055072a4250c4adf619e86570f55901019b..e6c97693dd8ab03f71bdbfa07129b83c044d7342 100644 (file)
@@ -7,10 +7,12 @@
 """
 Server implementation using the starlette webserver framework.
 """
-from typing import Any, Optional, Mapping, Callable, cast, Coroutine, Dict, Awaitable
+from typing import Any, Optional, Mapping, Callable, cast, Coroutine, Dict, \
+                   Awaitable, AsyncIterator
 from pathlib import Path
 import datetime as dt
 import asyncio
+import contextlib
 
 from starlette.applications import Starlette
 from starlette.routing import Route
@@ -24,34 +26,32 @@ from starlette.middleware.cors import CORSMiddleware
 from ...config import Configuration
 from ...core import NominatimAPIAsync
 from ... import v1 as api_impl
+from ...result_formatting import FormatDispatcher, load_format_dispatcher
+from ..asgi_adaptor import ASGIAdaptor, EndpointFunc
 from ... import logging as loglib
 
-class ParamWrapper(api_impl.ASGIAdaptor):
+
+class ParamWrapper(ASGIAdaptor):
     """ Adaptor class for server glue to Starlette framework.
     """
 
     def __init__(self, request: Request) -> None:
         self.request = request
 
-
     def get(self, name: str, default: Optional[str] = None) -> Optional[str]:
         return self.request.query_params.get(name, default=default)
 
-
     def get_header(self, name: str, default: Optional[str] = None) -> Optional[str]:
         return self.request.headers.get(name, default)
 
-
     def error(self, msg: str, status: int = 400) -> HTTPException:
         return HTTPException(status, detail=msg,
                              headers={'content-type': self.content_type})
 
-
     def create_response(self, status: int, output: str, num_results: int) -> Response:
         self.request.state.num_results = num_results
         return Response(output, status_code=status, media_type=self.content_type)
 
-
     def base_uri(self) -> str:
         scheme = self.request.url.scheme
         host = self.request.url.hostname
@@ -64,12 +64,14 @@ class ParamWrapper(api_impl.ASGIAdaptor):
 
         return f"{scheme}://{host}{root}"
 
-
     def config(self) -> Configuration:
         return cast(Configuration, self.request.app.state.API.config)
 
+    def formatting(self) -> FormatDispatcher:
+        return cast(FormatDispatcher, self.request.app.state.formatter)
+
 
-def _wrap_endpoint(func: api_impl.EndpointFunc)\
+def _wrap_endpoint(func: EndpointFunc)\
         -> Callable[[Request], Coroutine[Any, Any, Response]]:
     async def _callback(request: Request) -> Response:
         return cast(Response, await func(request.app.state.API, ParamWrapper(request)))
@@ -83,7 +85,7 @@ class FileLoggingMiddleware(BaseHTTPMiddleware):
 
     def __init__(self, app: Starlette, file_name: str = ''):
         super().__init__(app)
-        self.fd = open(file_name, 'a', buffering=1, encoding='utf8') # pylint: disable=R1732
+        self.fd = open(file_name, 'a', buffering=1, encoding='utf8')
 
     async def dispatch(self, request: Request,
                        call_next: RequestResponseEndpoint) -> Response:
@@ -112,7 +114,7 @@ class FileLoggingMiddleware(BaseHTTPMiddleware):
         return response
 
 
-async def timeout_error(request: Request, #pylint: disable=unused-argument
+async def timeout_error(request: Request,
                         _: Exception) -> Response:
     """ Error handler for query timeouts.
     """
@@ -132,14 +134,6 @@ def get_application(project_dir: Path,
     """
     config = Configuration(project_dir, environ)
 
-    routes = []
-    legacy_urls = config.get_bool('SERVE_LEGACY_URLS')
-    for name, func in api_impl.ROUTES:
-        endpoint = _wrap_endpoint(func)
-        routes.append(Route(f"/{name}", endpoint=endpoint))
-        if legacy_urls:
-            routes.append(Route(f"/{name}.php", endpoint=endpoint))
-
     middleware = []
     if config.get_bool('CORS_NOACCESSCONTROL'):
         middleware.append(Middleware(CORSMiddleware,
@@ -156,14 +150,27 @@ def get_application(project_dir: Path,
         asyncio.TimeoutError: timeout_error
     }
 
-    async def _shutdown() -> None:
+    @contextlib.asynccontextmanager
+    async def lifespan(app: Starlette) -> AsyncIterator[Any]:
+        app.state.API = NominatimAPIAsync(project_dir, environ)
+        config = app.state.API.config
+
+        legacy_urls = config.get_bool('SERVE_LEGACY_URLS')
+        for name, func in await api_impl.get_routes(app.state.API):
+            endpoint = _wrap_endpoint(func)
+            app.routes.append(Route(f"/{name}", endpoint=endpoint))
+            if legacy_urls:
+                app.routes.append(Route(f"/{name}.php", endpoint=endpoint))
+
+        yield
+
         await app.state.API.close()
 
-    app = Starlette(debug=debug, routes=routes, middleware=middleware,
+    app = Starlette(debug=debug, middleware=middleware,
                     exception_handlers=exceptions,
-                    on_shutdown=[_shutdown])
+                    lifespan=lifespan)
 
-    app.state.API = NominatimAPIAsync(project_dir, environ)
+    app.state.formatter = load_format_dispatcher('v1', project_dir)
 
     return app