]> git.openstreetmap.org Git - nominatim.git/blobdiff - nominatim/server/starlette/server.py
Merge pull request #3293 from lonvia/rematch-against-country-code
[nominatim.git] / nominatim / server / starlette / server.py
index 2bcc8df51c37b0bb8ad00316551903fc6ad728ed..33ab22c7bcee7c2b1994ab56496acec886f38d30 100644 (file)
@@ -7,14 +7,14 @@
 """
 Server implementation using the starlette webserver framework.
 """
-from typing import Any, Optional, Mapping, Callable, cast, Coroutine
+from typing import Any, Optional, Mapping, Callable, cast, Coroutine, Dict, Awaitable
 from pathlib import Path
 import datetime as dt
 
 from starlette.applications import Starlette
 from starlette.routing import Route
 from starlette.exceptions import HTTPException
-from starlette.responses import Response
+from starlette.responses import Response, PlainTextResponse
 from starlette.requests import Request
 from starlette.middleware import Middleware
 from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
@@ -50,6 +50,19 @@ class ParamWrapper(api_impl.ASGIAdaptor):
         return Response(output, status_code=status, media_type=self.content_type)
 
 
+    def base_uri(self) -> str:
+        """ Return the base URI under which the API is reachable.
+
+            The URI is assembled from the scheme, host and port of the
+            current request followed by the ASGI `root_path` the
+            application is mounted at. The port is left out when it is
+            the default port of the scheme (80 for http, 443 for https).
+        """
+        url = self.request.url
+        scheme = url.scheme
+        host = url.hostname
+        port = url.port
+        # 'root_path' is optional in the ASGI connection scope; a missing
+        # key means the empty string, so do not index it directly.
+        root = self.request.scope.get('root_path', '')
+        # The parsed hostname has the brackets of an IPv6 literal removed.
+        # They must be restored or the resulting URI is invalid.
+        if host is not None and ':' in host:
+            host = f"[{host}]"
+        if (scheme == 'http' and port == 80) or (scheme == 'https' and port == 443):
+            port = None
+        if port is not None:
+            return f"{scheme}://{host}:{port}{root}"
+
+        return f"{scheme}://{host}{root}"
+
+
     def config(self) -> Configuration:
         return cast(Configuration, self.request.app.state.API.config)
 
@@ -80,7 +93,7 @@ class FileLoggingMiddleware(BaseHTTPMiddleware):
 
         finish = dt.datetime.now(tz=dt.timezone.utc)
 
-        for endpoint in ('reverse', 'search', 'lookup'):
+        for endpoint in ('reverse', 'search', 'lookup', 'details'):
             if request.url.path.startswith('/' + endpoint):
                 qtype = endpoint
                 break
@@ -97,6 +110,13 @@ class FileLoggingMiddleware(BaseHTTPMiddleware):
         return response
 
 
+async def timeout_error(request: Request, #pylint: disable=unused-argument
+                        _: Exception) -> Response:
+    """ Convert a query timeout into a plain-text HTTP 503 response.
+
+        Neither the request nor the exception object is inspected; every
+        timeout produces the same fixed message.
+    """
+    message = "Query took too long to process."
+    return PlainTextResponse(message, status_code=503)
+
+
 def get_application(project_dir: Path,
                     environ: Optional[Mapping[str, str]] = None,
                     debug: bool = True) -> Starlette:
@@ -114,16 +134,24 @@ def get_application(project_dir: Path,
 
     middleware = []
     if config.get_bool('CORS_NOACCESSCONTROL'):
-        middleware.append(Middleware(CORSMiddleware, allow_origins=['*']))
+        middleware.append(Middleware(CORSMiddleware,
+                                     allow_origins=['*'],
+                                     allow_methods=['GET', 'OPTIONS'],
+                                     max_age=86400))
 
     log_file = config.LOG_FILE
     if log_file:
         middleware.append(Middleware(FileLoggingMiddleware, file_name=log_file))
 
+    exceptions: Dict[Any, Callable[[Request, Exception], Awaitable[Response]]] = {
+        TimeoutError: timeout_error
+    }
+
     async def _shutdown() -> None:
         await app.state.API.close()
 
     app = Starlette(debug=debug, routes=routes, middleware=middleware,
+                    exception_handlers=exceptions,
                     on_shutdown=[_shutdown])
 
     app.state.API = NominatimAPIAsync(project_dir, environ)