]> git.openstreetmap.org Git - nominatim.git/blobdiff - nominatim/server/starlette/server.py
Merge remote-tracking branch 'upstream/master'
[nominatim.git] / nominatim / server / starlette / server.py
index f81b122f274e17ddf0e0565139b09e003663a05e..2bcc8df51c37b0bb8ad00316551903fc6ad728ed 100644 (file)
@@ -9,6 +9,7 @@ Server implementation using the starlette webserver framework.
 """
 from typing import Any, Optional, Mapping, Callable, cast, Coroutine
 from pathlib import Path
 """
 from typing import Any, Optional, Mapping, Callable, cast, Coroutine
 from pathlib import Path
+import datetime as dt
 
 from starlette.applications import Starlette
 from starlette.routing import Route
 
 from starlette.applications import Starlette
 from starlette.routing import Route
@@ -16,6 +17,7 @@ from starlette.exceptions import HTTPException
 from starlette.responses import Response
 from starlette.requests import Request
 from starlette.middleware import Middleware
 from starlette.responses import Response
 from starlette.requests import Request
 from starlette.middleware import Middleware
+from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
 from starlette.middleware.cors import CORSMiddleware
 
 from nominatim.api import NominatimAPIAsync
 from starlette.middleware.cors import CORSMiddleware
 
 from nominatim.api import NominatimAPIAsync
@@ -43,7 +45,8 @@ class ParamWrapper(api_impl.ASGIAdaptor):
                              headers={'content-type': self.content_type})
 
 
                              headers={'content-type': self.content_type})
 
 
-    def create_response(self, status: int, output: str) -> Response:
+    def create_response(self, status: int, output: str, num_results: int) -> Response:
+        self.request.state.num_results = num_results
         return Response(output, status_code=status, media_type=self.content_type)
 
 
         return Response(output, status_code=status, media_type=self.content_type)
 
 
@@ -59,6 +62,41 @@ def _wrap_endpoint(func: api_impl.EndpointFunc)\
     return _callback
 
 
     return _callback
 
 
+class FileLoggingMiddleware(BaseHTTPMiddleware):
+    """ Middleware to log selected requests into a file.
+    """
+
+    def __init__(self, app: Starlette, file_name: str = ''):
+        super().__init__(app)
+        self.fd = open(file_name, 'a', buffering=1, encoding='utf8') # pylint: disable=R1732
+
+    async def dispatch(self, request: Request,
+                       call_next: RequestResponseEndpoint) -> Response:
+        start = dt.datetime.now(tz=dt.timezone.utc)
+        response = await call_next(request)
+
+        if response.status_code != 200:
+            return response
+
+        finish = dt.datetime.now(tz=dt.timezone.utc)
+
+        for endpoint in ('reverse', 'search', 'lookup'):
+            if request.url.path.startswith('/' + endpoint):
+                qtype = endpoint
+                break
+        else:
+            return response
+
+        duration = (finish - start).total_seconds()
+        params = request.scope['query_string'].decode('utf8')
+
+        self.fd.write(f"[{start.replace(tzinfo=None).isoformat(sep=' ', timespec='milliseconds')}] "
+                      f"{duration:.4f} {getattr(request.state, 'num_results', 0)} "
+                      f'{qtype} "{params}"\n')
+
+        return response
+
+
 def get_application(project_dir: Path,
                     environ: Optional[Mapping[str, str]] = None,
                     debug: bool = True) -> Starlette:
 def get_application(project_dir: Path,
                     environ: Optional[Mapping[str, str]] = None,
                     debug: bool = True) -> Starlette:
@@ -78,6 +116,10 @@ def get_application(project_dir: Path,
     if config.get_bool('CORS_NOACCESSCONTROL'):
         middleware.append(Middleware(CORSMiddleware, allow_origins=['*']))
 
     if config.get_bool('CORS_NOACCESSCONTROL'):
         middleware.append(Middleware(CORSMiddleware, allow_origins=['*']))
 
+    log_file = config.LOG_FILE
+    if log_file:
+        middleware.append(Middleware(FileLoggingMiddleware, file_name=log_file))
+
     async def _shutdown() -> None:
         await app.state.API.close()
 
     async def _shutdown() -> None:
         await app.state.API.close()