X-Git-Url: https://git.openstreetmap.org./nominatim.git/blobdiff_plain/d97ca9fcb2ae54080d6b92c7e85ff111308bc1cc..38369ca3cfe6e52bb6f7589c714a04294497520e:/nominatim/server/starlette/server.py?ds=sidebyside
diff --git a/nominatim/server/starlette/server.py b/nominatim/server/starlette/server.py
index 2bcc8df5..33ab22c7 100644
--- a/nominatim/server/starlette/server.py
+++ b/nominatim/server/starlette/server.py
@@ -7,14 +7,15 @@
 """
 Server implementation using the starlette webserver framework.
 """
-from typing import Any, Optional, Mapping, Callable, cast, Coroutine
+from typing import Any, Optional, Mapping, Callable, cast, Coroutine, Dict, Awaitable
 from pathlib import Path
 import datetime as dt
+import asyncio
 
 from starlette.applications import Starlette
 from starlette.routing import Route
 from starlette.exceptions import HTTPException
-from starlette.responses import Response
+from starlette.responses import Response, PlainTextResponse
 from starlette.requests import Request
 from starlette.middleware import Middleware
 from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
@@ -50,6 +51,27 @@ class ParamWrapper(api_impl.ASGIAdaptor):
         return Response(output, status_code=status, media_type=self.content_type)
 
 
+    def base_uri(self) -> str:
+        """ Return the base URI of the server, including the ASGI root path.
+            Default ports (80 for http, 443 for https) are left out
+            and IPv6 literals are enclosed in square brackets.
+        """
+        scheme = self.request.url.scheme
+        host = self.request.url.hostname
+        port = self.request.url.port
+        # 'root_path' is optional in the ASGI scope; default to no prefix.
+        root = self.request.scope.get('root_path', '')
+        if host is not None and ':' in host:
+            # urlsplit() strips the brackets from IPv6 literals; restore them.
+            host = f"[{host}]"
+        if (scheme == 'http' and port == 80) or (scheme == 'https' and port == 443):
+            port = None
+        if port is not None:
+            return f"{scheme}://{host}:{port}{root}"
+
+        return f"{scheme}://{host}{root}"
+
+
     def config(self) -> Configuration:
         return cast(Configuration, self.request.app.state.API.config)
 
@@ -80,7 +102,7 @@ class FileLoggingMiddleware(BaseHTTPMiddleware):
 
         finish = dt.datetime.now(tz=dt.timezone.utc)
 
-        for endpoint in ('reverse', 'search', 'lookup'):
+        for endpoint in ('reverse', 'search', 'lookup', 'details'):
             if request.url.path.startswith('/' + endpoint):
                 qtype = endpoint
                 break
@@ -97,6 +119,13 @@ class FileLoggingMiddleware(BaseHTTPMiddleware):
         return response
 
 
+async def timeout_error(request: Request, #pylint: disable=unused-argument
+                        _: Exception) -> Response:
+    """ Error handler for query timeouts.
+    """
+    return PlainTextResponse("Query took too long to process.", status_code=503)
+
+
 def get_application(project_dir: Path,
                     environ: Optional[Mapping[str, str]] = None,
                     debug: bool = True) -> Starlette:
@@ -114,16 +143,27 @@ def get_application(project_dir: Path,
 
     middleware = []
     if config.get_bool('CORS_NOACCESSCONTROL'):
-        middleware.append(Middleware(CORSMiddleware, allow_origins=['*']))
+        middleware.append(Middleware(CORSMiddleware,
+                                     allow_origins=['*'],
+                                     allow_methods=['GET', 'OPTIONS'],
+                                     max_age=86400))
 
     log_file = config.LOG_FILE
     if log_file:
         middleware.append(Middleware(FileLoggingMiddleware, file_name=log_file))
 
+    # Before Python 3.11, asyncio.TimeoutError (raised by asyncio.wait_for)
+    # is not the builtin TimeoutError, so register the handler for both.
+    exceptions: Dict[Any, Callable[[Request, Exception], Awaitable[Response]]] = {
+        TimeoutError: timeout_error,
+        asyncio.TimeoutError: timeout_error
+    }
+
     async def _shutdown() -> None:
         await app.state.API.close()
 
     app = Starlette(debug=debug, routes=routes, middleware=middleware,
+                    exception_handlers=exceptions,
                     on_shutdown=[_shutdown])
 
     app.state.API = NominatimAPIAsync(project_dir, environ)