]> git.openstreetmap.org Git - nominatim.git/blobdiff - nominatim/server/starlette/server.py
Correct some typos
[nominatim.git] / nominatim / server / starlette / server.py
index f89e52a151dac89cf205ce25964f4483cbd5e272..c98289915269fbefa8f56dea30f25e74a7893d3b 100644 (file)
@@ -7,14 +7,15 @@
 """
 Server implementation using the starlette webserver framework.
 """
 """
 Server implementation using the starlette webserver framework.
 """
-from typing import Any, Optional, Mapping, Callable, cast, Coroutine
+from typing import Any, Optional, Mapping, Callable, cast, Coroutine, Dict, Awaitable
 from pathlib import Path
 import datetime as dt
 from pathlib import Path
 import datetime as dt
+import asyncio
 
 from starlette.applications import Starlette
 from starlette.routing import Route
 from starlette.exceptions import HTTPException
 
 from starlette.applications import Starlette
 from starlette.routing import Route
 from starlette.exceptions import HTTPException
-from starlette.responses import Response
+from starlette.responses import Response, PlainTextResponse, HTMLResponse
 from starlette.requests import Request
 from starlette.middleware import Middleware
 from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
 from starlette.requests import Request
 from starlette.middleware import Middleware
 from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
@@ -22,6 +23,7 @@ from starlette.middleware.cors import CORSMiddleware
 
 from nominatim.api import NominatimAPIAsync
 import nominatim.api.v1 as api_impl
 
 from nominatim.api import NominatimAPIAsync
 import nominatim.api.v1 as api_impl
+import nominatim.api.logging as loglib
 from nominatim.config import Configuration
 
 class ParamWrapper(api_impl.ASGIAdaptor):
 from nominatim.config import Configuration
 
 class ParamWrapper(api_impl.ASGIAdaptor):
@@ -50,6 +52,19 @@ class ParamWrapper(api_impl.ASGIAdaptor):
         return Response(output, status_code=status, media_type=self.content_type)
 
 
         return Response(output, status_code=status, media_type=self.content_type)
 
 
+    def base_uri(self) -> str:
+        scheme = self.request.url.scheme
+        host = self.request.url.hostname
+        port = self.request.url.port
+        root = self.request.scope['root_path']
+        if (scheme == 'http' and port == 80) or (scheme == 'https' and port == 443):
+            port = None
+        if port is not None:
+            return f"{scheme}://{host}:{port}{root}"
+
+        return f"{scheme}://{host}{root}"
+
+
     def config(self) -> Configuration:
         return cast(Configuration, self.request.app.state.API.config)
 
     def config(self) -> Configuration:
         return cast(Configuration, self.request.app.state.API.config)
 
@@ -80,7 +95,7 @@ class FileLoggingMiddleware(BaseHTTPMiddleware):
 
         finish = dt.datetime.now(tz=dt.timezone.utc)
 
 
         finish = dt.datetime.now(tz=dt.timezone.utc)
 
-        for endpoint in ('reverse', 'search', 'lookup'):
+        for endpoint in ('reverse', 'search', 'lookup', 'details'):
             if request.url.path.startswith('/' + endpoint):
                 qtype = endpoint
                 break
             if request.url.path.startswith('/' + endpoint):
                 qtype = endpoint
                 break
@@ -97,6 +112,19 @@ class FileLoggingMiddleware(BaseHTTPMiddleware):
         return response
 
 
         return response
 
 
+async def timeout_error(request: Request, #pylint: disable=unused-argument
+                        _: Exception) -> Response:
+    """ Error handler for query timeouts.
+    """
+    loglib.log().comment('Aborted: Query took too long to process.')
+    logdata = loglib.get_and_disable()
+
+    if logdata:
+        return HTMLResponse(logdata)
+
+    return PlainTextResponse("Query took too long to process.", status_code=503)
+
+
 def get_application(project_dir: Path,
                     environ: Optional[Mapping[str, str]] = None,
                     debug: bool = True) -> Starlette:
 def get_application(project_dir: Path,
                     environ: Optional[Mapping[str, str]] = None,
                     debug: bool = True) -> Starlette:
@@ -123,10 +151,16 @@ def get_application(project_dir: Path,
     if log_file:
         middleware.append(Middleware(FileLoggingMiddleware, file_name=log_file))
 
     if log_file:
         middleware.append(Middleware(FileLoggingMiddleware, file_name=log_file))
 
+    exceptions: Dict[Any, Callable[[Request, Exception], Awaitable[Response]]] = {
+        TimeoutError: timeout_error,
+        asyncio.TimeoutError: timeout_error
+    }
+
     async def _shutdown() -> None:
         await app.state.API.close()
 
     app = Starlette(debug=debug, routes=routes, middleware=middleware,
     async def _shutdown() -> None:
         await app.state.API.close()
 
     app = Starlette(debug=debug, routes=routes, middleware=middleware,
+                    exception_handlers=exceptions,
                     on_shutdown=[_shutdown])
 
     app.state.API = NominatimAPIAsync(project_dir, environ)
                     on_shutdown=[_shutdown])
 
     app.state.API = NominatimAPIAsync(project_dir, environ)