X-Git-Url: https://git.openstreetmap.org./nominatim.git/blobdiff_plain/6e89310a9285f1ad15d8002bf68f578eada367a0..1f0796778754d8df0dfab9dd01302e26a397f064:/src/nominatim_api/server/starlette/server.py?ds=sidebyside diff --git a/src/nominatim_api/server/starlette/server.py b/src/nominatim_api/server/starlette/server.py index dd35cd6e..48f0207a 100644 --- a/src/nominatim_api/server/starlette/server.py +++ b/src/nominatim_api/server/starlette/server.py @@ -21,37 +21,35 @@ from starlette.middleware import Middleware from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint from starlette.middleware.cors import CORSMiddleware -from nominatim_core.config import Configuration +from ...config import Configuration from ...core import NominatimAPIAsync from ... import v1 as api_impl +from ...result_formatting import FormatDispatcher, load_format_dispatcher +from ..asgi_adaptor import ASGIAdaptor, EndpointFunc from ... import logging as loglib -class ParamWrapper(api_impl.ASGIAdaptor): + +class ParamWrapper(ASGIAdaptor): """ Adaptor class for server glue to Starlette framework. """ def __init__(self, request: Request) -> None: self.request = request - def get(self, name: str, default: Optional[str] = None) -> Optional[str]: return self.request.query_params.get(name, default=default) - def get_header(self, name: str, default: Optional[str] = None) -> Optional[str]: return self.request.headers.get(name, default) - def error(self, msg: str, status: int = 400) -> HTTPException: return HTTPException(status, detail=msg, headers={'content-type': self.content_type}) - def create_response(self, status: int, output: str, num_results: int) -> Response: self.request.state.num_results = num_results return Response(output, status_code=status, media_type=self.content_type) - def base_uri(self) -> str: scheme = self.request.url.scheme host = self.request.url.hostname @@ -64,12 +62,14 @@ class ParamWrapper(api_impl.ASGIAdaptor): return f"{scheme}://{host}{root}" - def config(self) -> Configuration: return cast(Configuration, self.request.app.state.API.config) + def formatting(self) -> FormatDispatcher: + return cast(FormatDispatcher, self.request.app.state.API.formatter) + -def _wrap_endpoint(func: api_impl.EndpointFunc)\ +def _wrap_endpoint(func: EndpointFunc)\ -> Callable[[Request], Coroutine[Any, Any, Response]]: async def _callback(request: Request) -> Response: return cast(Response, await func(request.app.state.API, ParamWrapper(request))) @@ -83,7 +83,7 @@ class FileLoggingMiddleware(BaseHTTPMiddleware): def __init__(self, app: Starlette, file_name: str = ''): super().__init__(app) - self.fd = open(file_name, 'a', buffering=1, encoding='utf8') # pylint: disable=R1732 + self.fd = open(file_name, 'a', buffering=1, encoding='utf8') async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: @@ -112,7 +112,7 @@ class FileLoggingMiddleware(BaseHTTPMiddleware): return response -async def timeout_error(request: Request, #pylint: disable=unused-argument +async def timeout_error(request: Request, _: Exception) -> Response: """ Error handler for query timeouts. """ @@ -164,6 +164,7 @@ def get_application(project_dir: Path, on_shutdown=[_shutdown]) app.state.API = NominatimAPIAsync(project_dir, environ) + app.state.formatter = load_format_dispatcher('v1', project_dir) return app