]> git.openstreetmap.org Git - nominatim.git/blobdiff - nominatim/server/starlette/server.py
update osm2pgsql to 1.9.1
[nominatim.git] / nominatim / server / starlette / server.py
index 41ad899c9908769b970455faab3ffd159e8c688f..5567ac9c9b9e9986bafdfd4fdb2855c66e7e220b 100644 (file)
 #
 # This file is part of Nominatim. (https://nominatim.org)
 #
 #
 # This file is part of Nominatim. (https://nominatim.org)
 #
-# Copyright (C) 2022 by the Nominatim developer community.
+# Copyright (C) 2023 by the Nominatim developer community.
 # For a full list of authors see the git log.
 """
 Server implementation using the starlette webserver framework.
 """
 # For a full list of authors see the git log.
 """
 Server implementation using the starlette webserver framework.
 """
-from typing import Any, Type, Optional, Mapping
+from typing import Any, Optional, Mapping, Callable, cast, Coroutine
 from pathlib import Path
 from pathlib import Path
+import datetime as dt
 
 from starlette.applications import Starlette
 from starlette.routing import Route
 from starlette.exceptions import HTTPException
 from starlette.responses import Response
 from starlette.requests import Request
 
 from starlette.applications import Starlette
 from starlette.routing import Route
 from starlette.exceptions import HTTPException
 from starlette.responses import Response
 from starlette.requests import Request
+from starlette.middleware import Middleware
+from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
+from starlette.middleware.cors import CORSMiddleware
 
 from nominatim.api import NominatimAPIAsync
 
 from nominatim.api import NominatimAPIAsync
-from nominatim.apicmd.status import StatusResult
-import nominatim.result_formatter.v1 as formatting
+import nominatim.api.v1 as api_impl
+from nominatim.config import Configuration
 
 
-CONTENT_TYPE = {
-  'text': 'text/plain; charset=utf-8',
-  'xml': 'text/xml; charset=utf-8'
-}
+class ParamWrapper(api_impl.ASGIAdaptor):
+    """ Adaptor class for server glue to Starlette framework.
+    """
 
 
-FORMATTERS = {
-    StatusResult: formatting.create(StatusResult)
-}
+    def __init__(self, request: Request) -> None:
+        self.request = request
 
 
 
 
-def parse_format(request: Request, rtype: Type[Any], default: str) -> None:
-    """ Get and check the 'format' parameter and prepare the formatter.
-        `rtype` describes the expected return type and `default` the
-        format value to assume when no parameter is present.
-    """
-    fmt = request.query_params.get('format', default=default)
-    fmtter = FORMATTERS[rtype]
+    def get(self, name: str, default: Optional[str] = None) -> Optional[str]:
+        return self.request.query_params.get(name, default=default)
 
 
-    if not fmtter.supports_format(fmt):
-        raise HTTPException(400, detail="Parameter 'format' must be one of: " +
-                                        ', '.join(fmtter.list_formats()))
 
 
-    request.state.format = fmt
-    request.state.formatter = fmtter
+    def get_header(self, name: str, default: Optional[str] = None) -> Optional[str]:
+        return self.request.headers.get(name, default)
 
 
 
 
-def format_response(request: Request, result: Any) -> Response:
-    """ Render response into a string according to the formatter
-        set in `parse_format()`.
-    """
-    fmt = request.state.format
-    return Response(request.state.formatter.format(result, fmt),
-                    media_type=CONTENT_TYPE.get(fmt, 'application/json'))
+    def error(self, msg: str, status: int = 400) -> HTTPException:
+        return HTTPException(status, detail=msg,
+                             headers={'content-type': self.content_type})
+
+
+    def create_response(self, status: int, output: str, num_results: int) -> Response:
+        self.request.state.num_results = num_results
+        return Response(output, status_code=status, media_type=self.content_type)
+
+
+    def base_uri(self) -> str:
+        scheme = self.request.url.scheme
+        host = self.request.url.hostname
+        port = self.request.url.port
+        root = self.request.scope['root_path']
+        if (scheme == 'http' and port == 80) or (scheme == 'https' and port == 443):
+            port = None
+        if port is not None:
+            return f"{scheme}://{host}:{port}{root}"
+
+        return f"{scheme}://{host}{root}"
+
+
+    def config(self) -> Configuration:
+        return cast(Configuration, self.request.app.state.API.config)
+
+
+def _wrap_endpoint(func: api_impl.EndpointFunc)\
+        -> Callable[[Request], Coroutine[Any, Any, Response]]:
+    async def _callback(request: Request) -> Response:
+        return cast(Response, await func(request.app.state.API, ParamWrapper(request)))
+
+    return _callback
 
 
 
 
-async def on_status(request: Request) -> Response:
-    """ Implementation of status endpoint.
+class FileLoggingMiddleware(BaseHTTPMiddleware):
+    """ Middleware to log selected requests into a file.
     """
     """
-    parse_format(request, StatusResult, 'text')
-    result = await request.app.state.API.status()
-    response = format_response(request, result)
 
 
-    if request.state.format == 'text' and result.status:
-        response.status_code = 500
+    def __init__(self, app: Starlette, file_name: str = ''):
+        super().__init__(app)
+        self.fd = open(file_name, 'a', buffering=1, encoding='utf8') # pylint: disable=R1732
 
 
-    return response
+    async def dispatch(self, request: Request,
+                       call_next: RequestResponseEndpoint) -> Response:
+        start = dt.datetime.now(tz=dt.timezone.utc)
+        response = await call_next(request)
 
 
+        if response.status_code != 200:
+            return response
+
+        finish = dt.datetime.now(tz=dt.timezone.utc)
+
+        for endpoint in ('reverse', 'search', 'lookup'):
+            if request.url.path.startswith('/' + endpoint):
+                qtype = endpoint
+                break
+        else:
+            return response
+
+        duration = (finish - start).total_seconds()
+        params = request.scope['query_string'].decode('utf8')
+
+        self.fd.write(f"[{start.replace(tzinfo=None).isoformat(sep=' ', timespec='milliseconds')}] "
+                      f"{duration:.4f} {getattr(request.state, 'num_results', 0)} "
+                      f'{qtype} "{params}"\n')
+
+        return response
 
 
-V1_ROUTES = [
-    Route('/status', endpoint=on_status)
-]
 
 def get_application(project_dir: Path,
 
 def get_application(project_dir: Path,
-                    environ: Optional[Mapping[str, str]] = None) -> Starlette:
+                    environ: Optional[Mapping[str, str]] = None,
+                    debug: bool = True) -> Starlette:
     """ Create a Nominatim falcon ASGI application.
     """
     """ Create a Nominatim falcon ASGI application.
     """
-    app = Starlette(debug=True, routes=V1_ROUTES)
+    config = Configuration(project_dir, environ)
+
+    routes = []
+    legacy_urls = config.get_bool('SERVE_LEGACY_URLS')
+    for name, func in api_impl.ROUTES:
+        endpoint = _wrap_endpoint(func)
+        routes.append(Route(f"/{name}", endpoint=endpoint))
+        if legacy_urls:
+            routes.append(Route(f"/{name}.php", endpoint=endpoint))
+
+    middleware = []
+    if config.get_bool('CORS_NOACCESSCONTROL'):
+        middleware.append(Middleware(CORSMiddleware,
+                                     allow_origins=['*'],
+                                     allow_methods=['GET', 'OPTIONS'],
+                                     max_age=86400))
+
+    log_file = config.LOG_FILE
+    if log_file:
+        middleware.append(Middleware(FileLoggingMiddleware, file_name=log_file))
+
+    async def _shutdown() -> None:
+        await app.state.API.close()
+
+    app = Starlette(debug=debug, routes=routes, middleware=middleware,
+                    on_shutdown=[_shutdown])
 
     app.state.API = NominatimAPIAsync(project_dir, environ)
 
     return app
 
     app.state.API = NominatimAPIAsync(project_dir, environ)
 
     return app
+
+
+def run_wsgi() -> Starlette:
+    """ Entry point for uvicorn.
+    """
+    return get_application(Path('.'), debug=False)