X-Git-Url: https://git.openstreetmap.org./nominatim.git/blobdiff_plain/6e89310a9285f1ad15d8002bf68f578eada367a0..9f417d5f64aecf72d4a3bd22d6e9d6c8fac76a51:/src/nominatim_api/server/starlette/server.py diff --git a/src/nominatim_api/server/starlette/server.py b/src/nominatim_api/server/starlette/server.py index dd35cd6e..e6c97693 100644 --- a/src/nominatim_api/server/starlette/server.py +++ b/src/nominatim_api/server/starlette/server.py @@ -7,10 +7,12 @@ """ Server implementation using the starlette webserver framework. """ -from typing import Any, Optional, Mapping, Callable, cast, Coroutine, Dict, Awaitable +from typing import Any, Optional, Mapping, Callable, cast, Coroutine, Dict, \ + Awaitable, AsyncIterator from pathlib import Path import datetime as dt import asyncio +import contextlib from starlette.applications import Starlette from starlette.routing import Route @@ -21,37 +23,35 @@ from starlette.middleware import Middleware from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint from starlette.middleware.cors import CORSMiddleware -from nominatim_core.config import Configuration +from ...config import Configuration from ...core import NominatimAPIAsync from ... import v1 as api_impl +from ...result_formatting import FormatDispatcher, load_format_dispatcher +from ..asgi_adaptor import ASGIAdaptor, EndpointFunc from ... import logging as loglib -class ParamWrapper(api_impl.ASGIAdaptor): + +class ParamWrapper(ASGIAdaptor): """ Adaptor class for server glue to Starlette framework. """ def __init__(self, request: Request) -> None: self.request = request - def get(self, name: str, default: Optional[str] = None) -> Optional[str]: return self.request.query_params.get(name, default=default) - def get_header(self, name: str, default: Optional[str] = None) -> Optional[str]: return self.request.headers.get(name, default) - def error(self, msg: str, status: int = 400) -> HTTPException: return HTTPException(status, detail=msg, headers={'content-type': self.content_type}) - def create_response(self, status: int, output: str, num_results: int) -> Response: self.request.state.num_results = num_results return Response(output, status_code=status, media_type=self.content_type) - def base_uri(self) -> str: scheme = self.request.url.scheme host = self.request.url.hostname @@ -64,12 +64,14 @@ class ParamWrapper(api_impl.ASGIAdaptor): return f"{scheme}://{host}{root}" - def config(self) -> Configuration: return cast(Configuration, self.request.app.state.API.config) + def formatting(self) -> FormatDispatcher: + return cast(FormatDispatcher, self.request.app.state.formatter) + -def _wrap_endpoint(func: api_impl.EndpointFunc)\ +def _wrap_endpoint(func: EndpointFunc)\ -> Callable[[Request], Coroutine[Any, Any, Response]]: async def _callback(request: Request) -> Response: return cast(Response, await func(request.app.state.API, ParamWrapper(request))) @@ -83,7 +85,7 @@ class FileLoggingMiddleware(BaseHTTPMiddleware): def __init__(self, app: Starlette, file_name: str = ''): super().__init__(app) - self.fd = open(file_name, 'a', buffering=1, encoding='utf8') # pylint: disable=R1732 + self.fd = open(file_name, 'a', buffering=1, encoding='utf8') async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: @@ -112,7 +114,7 @@ class FileLoggingMiddleware(BaseHTTPMiddleware): return response -async def timeout_error(request: Request, #pylint: disable=unused-argument +async def timeout_error(request: Request, _: Exception) -> Response: """ Error handler for query timeouts. """ @@ -132,14 +134,6 @@ def get_application(project_dir: Path, """ config = Configuration(project_dir, environ) - routes = [] - legacy_urls = config.get_bool('SERVE_LEGACY_URLS') - for name, func in api_impl.ROUTES: - endpoint = _wrap_endpoint(func) - routes.append(Route(f"/{name}", endpoint=endpoint)) - if legacy_urls: - routes.append(Route(f"/{name}.php", endpoint=endpoint)) - middleware = [] if config.get_bool('CORS_NOACCESSCONTROL'): middleware.append(Middleware(CORSMiddleware, @@ -156,14 +150,27 @@ def get_application(project_dir: Path, asyncio.TimeoutError: timeout_error } - async def _shutdown() -> None: + @contextlib.asynccontextmanager + async def lifespan(app: Starlette) -> AsyncIterator[Any]: + app.state.API = NominatimAPIAsync(project_dir, environ) + config = app.state.API.config + + legacy_urls = config.get_bool('SERVE_LEGACY_URLS') + for name, func in await api_impl.get_routes(app.state.API): + endpoint = _wrap_endpoint(func) + app.routes.append(Route(f"/{name}", endpoint=endpoint)) + if legacy_urls: + app.routes.append(Route(f"/{name}.php", endpoint=endpoint)) + + yield + await app.state.API.close() - app = Starlette(debug=debug, routes=routes, middleware=middleware, + app = Starlette(debug=debug, middleware=middleware, exception_handlers=exceptions, - on_shutdown=[_shutdown]) + lifespan=lifespan) - app.state.API = NominatimAPIAsync(project_dir, environ) + app.state.formatter = load_format_dispatcher('v1', project_dir) return app