git.openstreetmap.org Git - nominatim.git/blobdiff - src/nominatim_api/server/starlette/server.py
fix style issue found by flake8
[nominatim.git] / src / nominatim_api / server / starlette / server.py
index dd35cd6e9e55072408bd77ed23636ed1c3d90fd0..48f0207ac7b2f21e07a16ca83ad627d984b74200 100644 (file)
@@ -21,37 +21,35 @@ from starlette.middleware import Middleware
 from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
 from starlette.middleware.cors import CORSMiddleware
 
 from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
 from starlette.middleware.cors import CORSMiddleware
 
-from nominatim_core.config import Configuration
+from ...config import Configuration
 from ...core import NominatimAPIAsync
 from ... import v1 as api_impl
 from ...core import NominatimAPIAsync
 from ... import v1 as api_impl
+from ...result_formatting import FormatDispatcher, load_format_dispatcher
+from ..asgi_adaptor import ASGIAdaptor, EndpointFunc
 from ... import logging as loglib
 
 from ... import logging as loglib
 
-class ParamWrapper(api_impl.ASGIAdaptor):
+
+class ParamWrapper(ASGIAdaptor):
     """ Adaptor class for server glue to Starlette framework.
     """
 
     def __init__(self, request: Request) -> None:
         self.request = request
 
     """ Adaptor class for server glue to Starlette framework.
     """
 
     def __init__(self, request: Request) -> None:
         self.request = request
 
-
     def get(self, name: str, default: Optional[str] = None) -> Optional[str]:
         return self.request.query_params.get(name, default=default)
 
     def get(self, name: str, default: Optional[str] = None) -> Optional[str]:
         return self.request.query_params.get(name, default=default)
 
-
     def get_header(self, name: str, default: Optional[str] = None) -> Optional[str]:
         return self.request.headers.get(name, default)
 
     def get_header(self, name: str, default: Optional[str] = None) -> Optional[str]:
         return self.request.headers.get(name, default)
 
-
     def error(self, msg: str, status: int = 400) -> HTTPException:
         return HTTPException(status, detail=msg,
                              headers={'content-type': self.content_type})
 
     def error(self, msg: str, status: int = 400) -> HTTPException:
         return HTTPException(status, detail=msg,
                              headers={'content-type': self.content_type})
 
-
     def create_response(self, status: int, output: str, num_results: int) -> Response:
         self.request.state.num_results = num_results
         return Response(output, status_code=status, media_type=self.content_type)
 
     def create_response(self, status: int, output: str, num_results: int) -> Response:
         self.request.state.num_results = num_results
         return Response(output, status_code=status, media_type=self.content_type)
 
-
     def base_uri(self) -> str:
         scheme = self.request.url.scheme
         host = self.request.url.hostname
     def base_uri(self) -> str:
         scheme = self.request.url.scheme
         host = self.request.url.hostname
@@ -64,12 +62,14 @@ class ParamWrapper(api_impl.ASGIAdaptor):
 
         return f"{scheme}://{host}{root}"
 
 
         return f"{scheme}://{host}{root}"
 
-
     def config(self) -> Configuration:
         return cast(Configuration, self.request.app.state.API.config)
 
     def config(self) -> Configuration:
         return cast(Configuration, self.request.app.state.API.config)
 
+    def formatting(self) -> FormatDispatcher:
+        """ Return the result formatter registered for this application.
+        """
+        # get_application() stores the dispatcher directly on app.state,
+        # not on the API object, so it has to be read from there.
+        return cast(FormatDispatcher, self.request.app.state.formatter)
+
 
 
-def _wrap_endpoint(func: api_impl.EndpointFunc)\
+def _wrap_endpoint(func: EndpointFunc)\
         -> Callable[[Request], Coroutine[Any, Any, Response]]:
     async def _callback(request: Request) -> Response:
         return cast(Response, await func(request.app.state.API, ParamWrapper(request)))
         -> Callable[[Request], Coroutine[Any, Any, Response]]:
     async def _callback(request: Request) -> Response:
         return cast(Response, await func(request.app.state.API, ParamWrapper(request)))
@@ -83,7 +83,7 @@ class FileLoggingMiddleware(BaseHTTPMiddleware):
 
     def __init__(self, app: Starlette, file_name: str = ''):
         super().__init__(app)
 
     def __init__(self, app: Starlette, file_name: str = ''):
         super().__init__(app)
-        self.fd = open(file_name, 'a', buffering=1, encoding='utf8') # pylint: disable=R1732
+        self.fd = open(file_name, 'a', buffering=1, encoding='utf8')
 
     async def dispatch(self, request: Request,
                        call_next: RequestResponseEndpoint) -> Response:
 
     async def dispatch(self, request: Request,
                        call_next: RequestResponseEndpoint) -> Response:
@@ -112,7 +112,7 @@ class FileLoggingMiddleware(BaseHTTPMiddleware):
         return response
 
 
         return response
 
 
-async def timeout_error(request: Request, #pylint: disable=unused-argument
+async def timeout_error(request: Request,
                         _: Exception) -> Response:
     """ Error handler for query timeouts.
     """
                         _: Exception) -> Response:
     """ Error handler for query timeouts.
     """
@@ -164,6 +164,7 @@ def get_application(project_dir: Path,
                     on_shutdown=[_shutdown])
 
     app.state.API = NominatimAPIAsync(project_dir, environ)
                     on_shutdown=[_shutdown])
 
     app.state.API = NominatimAPIAsync(project_dir, environ)
+    app.state.formatter = load_format_dispatcher('v1', project_dir)
 
     return app
 
 
     return app