]> git.openstreetmap.org Git - nominatim.git/blobdiff - nominatim/server/starlette/server.py
Merge remote-tracking branch 'upstream/master'
[nominatim.git] / nominatim / server / starlette / server.py
index bfc552f236b2d844855599de1c9d98d252ce5778..2bcc8df51c37b0bb8ad00316551903fc6ad728ed 100644 (file)
 #
 # This file is part of Nominatim. (https://nominatim.org)
 #
 #
 # This file is part of Nominatim. (https://nominatim.org)
 #
-# Copyright (C) 2022 by the Nominatim developer community.
+# Copyright (C) 2023 by the Nominatim developer community.
 # For a full list of authors see the git log.
 """
 Server implementation using the starlette webserver framework.
 """
 # For a full list of authors see the git log.
 """
 Server implementation using the starlette webserver framework.
 """
+from typing import Any, Optional, Mapping, Callable, cast, Coroutine
 from pathlib import Path
 from pathlib import Path
+import datetime as dt
 
 from starlette.applications import Starlette
 from starlette.routing import Route
 from starlette.exceptions import HTTPException
 from starlette.responses import Response
 
 from starlette.applications import Starlette
 from starlette.routing import Route
 from starlette.exceptions import HTTPException
 from starlette.responses import Response
+from starlette.requests import Request
+from starlette.middleware import Middleware
+from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
+from starlette.middleware.cors import CORSMiddleware
 
 from nominatim.api import NominatimAPIAsync
 
 from nominatim.api import NominatimAPIAsync
-from nominatim.apicmd.status import StatusResult
-import nominatim.result_formatter.v1 as formatting
+import nominatim.api.v1 as api_impl
+from nominatim.config import Configuration
 
 
-CONTENT_TYPE = {
-  'text': 'text/plain; charset=utf-8',
-  'xml': 'text/xml; charset=utf-8'
-}
+class ParamWrapper(api_impl.ASGIAdaptor):
+    """ Adaptor class for server glue to Starlette framework.
+    """
 
 
-FORMATTERS = {
-    StatusResult: formatting.create(StatusResult)
-}
+    def __init__(self, request: Request) -> None:
+        self.request = request
 
 
 
 
-def parse_format(request, rtype, default):
-    fmt = request.query_params.get('format', default=default)
-    fmtter = FORMATTERS[rtype]
+    def get(self, name: str, default: Optional[str] = None) -> Optional[str]:
+        return self.request.query_params.get(name, default=default)
 
 
-    if not fmtter.supports_format(fmt):
-        raise HTTPException(400, detail="Parameter 'format' must be one of: " +
-                                        ', '.join(fmtter.list_formats()))
 
 
-    request.state.format = fmt
-    request.state.formatter = fmtter
+    def get_header(self, name: str, default: Optional[str] = None) -> Optional[str]:
+        return self.request.headers.get(name, default)
 
 
 
 
-def format_response(request, result):
-    fmt = request.state.format
-    return Response(request.state.formatter.format(result, fmt),
-                    media_type=CONTENT_TYPE.get(fmt, 'application/json'))
+    def error(self, msg: str, status: int = 400) -> HTTPException:
+        return HTTPException(status, detail=msg,
+                             headers={'content-type': self.content_type})
 
 
 
 
-async def on_status(request):
-    parse_format(request, StatusResult, 'text')
-    result = await request.app.state.API.status()
-    return format_response(request, result)
+    def create_response(self, status: int, output: str, num_results: int) -> Response:
+        self.request.state.num_results = num_results
+        return Response(output, status_code=status, media_type=self.content_type)
 
 
 
 
-V1_ROUTES = [
-    Route('/status', endpoint=on_status)
-]
+    def config(self) -> Configuration:
+        return cast(Configuration, self.request.app.state.API.config)
 
 
-def get_application(project_dir: Path) -> Starlette:
-    app = Starlette(debug=True, routes=V1_ROUTES)
 
 
-    app.state.API = NominatimAPIAsync(project_dir)
+def _wrap_endpoint(func: api_impl.EndpointFunc)\
+        -> Callable[[Request], Coroutine[Any, Any, Response]]:
+    async def _callback(request: Request) -> Response:
+        return cast(Response, await func(request.app.state.API, ParamWrapper(request)))
+
+    return _callback
+
+
+class FileLoggingMiddleware(BaseHTTPMiddleware):
+    """ Middleware to log selected requests into a file.
+    """
+
+    def __init__(self, app: Starlette, file_name: str = ''):
+        super().__init__(app)
+        self.fd = open(file_name, 'a', buffering=1, encoding='utf8') # pylint: disable=R1732
+
+    async def dispatch(self, request: Request,
+                       call_next: RequestResponseEndpoint) -> Response:
+        start = dt.datetime.now(tz=dt.timezone.utc)
+        response = await call_next(request)
+
+        if response.status_code != 200:
+            return response
+
+        finish = dt.datetime.now(tz=dt.timezone.utc)
+
+        for endpoint in ('reverse', 'search', 'lookup'):
+            if request.url.path.startswith('/' + endpoint):
+                qtype = endpoint
+                break
+        else:
+            return response
+
+        duration = (finish - start).total_seconds()
+        params = request.scope['query_string'].decode('utf8')
+
+        self.fd.write(f"[{start.replace(tzinfo=None).isoformat(sep=' ', timespec='milliseconds')}] "
+                      f"{duration:.4f} {getattr(request.state, 'num_results', 0)} "
+                      f'{qtype} "{params}"\n')
+
+        return response
+
+
+def get_application(project_dir: Path,
+                    environ: Optional[Mapping[str, str]] = None,
+                    debug: bool = True) -> Starlette:
+    """ Create a Nominatim falcon ASGI application.
+    """
+    config = Configuration(project_dir, environ)
+
+    routes = []
+    legacy_urls = config.get_bool('SERVE_LEGACY_URLS')
+    for name, func in api_impl.ROUTES:
+        endpoint = _wrap_endpoint(func)
+        routes.append(Route(f"/{name}", endpoint=endpoint))
+        if legacy_urls:
+            routes.append(Route(f"/{name}.php", endpoint=endpoint))
+
+    middleware = []
+    if config.get_bool('CORS_NOACCESSCONTROL'):
+        middleware.append(Middleware(CORSMiddleware, allow_origins=['*']))
+
+    log_file = config.LOG_FILE
+    if log_file:
+        middleware.append(Middleware(FileLoggingMiddleware, file_name=log_file))
+
+    async def _shutdown() -> None:
+        await app.state.API.close()
+
+    app = Starlette(debug=debug, routes=routes, middleware=middleware,
+                    on_shutdown=[_shutdown])
+
+    app.state.API = NominatimAPIAsync(project_dir, environ)
 
     return app
 
     return app
+
+
+def run_wsgi() -> Starlette:
+    """ Entry point for uvicorn.
+    """
+    return get_application(Path('.'), debug=False)