X-Git-Url: https://git.openstreetmap.org./nominatim.git/blobdiff_plain/4e0602919cfdad274a6d04c9798f2d61f1b03cf3..20d0fb35ce9d4d7c006a0e77dcf25edc2e8509b3:/src/nominatim_api/server/starlette/server.py diff --git a/src/nominatim_api/server/starlette/server.py b/src/nominatim_api/server/starlette/server.py index 60a81321..e6c97693 100644 --- a/src/nominatim_api/server/starlette/server.py +++ b/src/nominatim_api/server/starlette/server.py @@ -7,10 +7,12 @@ """ Server implementation using the starlette webserver framework. """ -from typing import Any, Optional, Mapping, Callable, cast, Coroutine, Dict, Awaitable +from typing import Any, Optional, Mapping, Callable, cast, Coroutine, Dict, \ + Awaitable, AsyncIterator from pathlib import Path import datetime as dt import asyncio +import contextlib from starlette.applications import Starlette from starlette.routing import Route @@ -24,9 +26,11 @@ from starlette.middleware.cors import CORSMiddleware from ...config import Configuration from ...core import NominatimAPIAsync from ... import v1 as api_impl +from ...result_formatting import FormatDispatcher, load_format_dispatcher from ..asgi_adaptor import ASGIAdaptor, EndpointFunc from ... import logging as loglib + class ParamWrapper(ASGIAdaptor): """ Adaptor class for server glue to Starlette framework. """ @@ -34,25 +38,20 @@ class ParamWrapper(ASGIAdaptor): def __init__(self, request: Request) -> None: self.request = request - def get(self, name: str, default: Optional[str] = None) -> Optional[str]: return self.request.query_params.get(name, default=default) - def get_header(self, name: str, default: Optional[str] = None) -> Optional[str]: return self.request.headers.get(name, default) - def error(self, msg: str, status: int = 400) -> HTTPException: return HTTPException(status, detail=msg, headers={'content-type': self.content_type}) - def create_response(self, status: int, output: str, num_results: int) -> Response: self.request.state.num_results = num_results return Response(output, status_code=status, media_type=self.content_type) - def base_uri(self) -> str: scheme = self.request.url.scheme host = self.request.url.hostname @@ -65,10 +64,12 @@ class ParamWrapper(ASGIAdaptor): return f"{scheme}://{host}{root}" - def config(self) -> Configuration: return cast(Configuration, self.request.app.state.API.config) + def formatting(self) -> FormatDispatcher: + return cast(FormatDispatcher, self.request.app.state.formatter) + def _wrap_endpoint(func: EndpointFunc)\ -> Callable[[Request], Coroutine[Any, Any, Response]]: @@ -84,7 +85,7 @@ class FileLoggingMiddleware(BaseHTTPMiddleware): def __init__(self, app: Starlette, file_name: str = ''): super().__init__(app) - self.fd = open(file_name, 'a', buffering=1, encoding='utf8') # pylint: disable=R1732 + self.fd = open(file_name, 'a', buffering=1, encoding='utf8') async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: @@ -113,7 +114,7 @@ class FileLoggingMiddleware(BaseHTTPMiddleware): return response -async def timeout_error(request: Request, #pylint: disable=unused-argument +async def timeout_error(request: Request, _: Exception) -> Response: """ Error handler for query timeouts. """ @@ -133,14 +134,6 @@ def get_application(project_dir: Path, """ config = Configuration(project_dir, environ) - routes = [] - legacy_urls = config.get_bool('SERVE_LEGACY_URLS') - for name, func in api_impl.ROUTES: - endpoint = _wrap_endpoint(func) - routes.append(Route(f"/{name}", endpoint=endpoint)) - if legacy_urls: - routes.append(Route(f"/{name}.php", endpoint=endpoint)) - middleware = [] if config.get_bool('CORS_NOACCESSCONTROL'): middleware.append(Middleware(CORSMiddleware, @@ -157,14 +150,27 @@ def get_application(project_dir: Path, asyncio.TimeoutError: timeout_error } - async def _shutdown() -> None: + @contextlib.asynccontextmanager + async def lifespan(app: Starlette) -> AsyncIterator[Any]: + app.state.API = NominatimAPIAsync(project_dir, environ) + config = app.state.API.config + + legacy_urls = config.get_bool('SERVE_LEGACY_URLS') + for name, func in await api_impl.get_routes(app.state.API): + endpoint = _wrap_endpoint(func) + app.routes.append(Route(f"/{name}", endpoint=endpoint)) + if legacy_urls: + app.routes.append(Route(f"/{name}.php", endpoint=endpoint)) + + yield + await app.state.API.close() - app = Starlette(debug=debug, routes=routes, middleware=middleware, + app = Starlette(debug=debug, middleware=middleware, exception_handlers=exceptions, - on_shutdown=[_shutdown]) + lifespan=lifespan) - app.state.API = NominatimAPIAsync(project_dir, environ) + app.state.formatter = load_format_dispatcher('v1', project_dir) return app