]> git.openstreetmap.org Git - nominatim.git/blobdiff - nominatim/server/starlette/server.py
update osm2pgsql to 1.9.1
[nominatim.git] / nominatim / server / starlette / server.py
index 60a78a475d8a81fa2cbde9073f853a7001fea605..5567ac9c9b9e9986bafdfd4fdb2855c66e7e220b 100644 (file)
 """
 Server implementation using the starlette webserver framework.
 """
 """
 Server implementation using the starlette webserver framework.
 """
-from typing import Any, Type, Optional, Mapping
+from typing import Any, Optional, Mapping, Callable, cast, Coroutine
 from pathlib import Path
 from pathlib import Path
+import datetime as dt
 
 from starlette.applications import Starlette
 from starlette.routing import Route
 from starlette.exceptions import HTTPException
 from starlette.responses import Response
 from starlette.requests import Request
 
 from starlette.applications import Starlette
 from starlette.routing import Route
 from starlette.exceptions import HTTPException
 from starlette.responses import Response
 from starlette.requests import Request
+from starlette.middleware import Middleware
+from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
+from starlette.middleware.cors import CORSMiddleware
 
 
-from nominatim.api import NominatimAPIAsync, StatusResult
+from nominatim.api import NominatimAPIAsync
 import nominatim.api.v1 as api_impl
 import nominatim.api.v1 as api_impl
+from nominatim.config import Configuration
 
 
-CONTENT_TYPE = {
-  'text': 'text/plain; charset=utf-8',
-  'xml': 'text/xml; charset=utf-8'
-}
-
-def parse_format(request: Request, rtype: Type[Any], default: str) -> None:
-    """ Get and check the 'format' parameter and prepare the formatter.
-        `rtype` describes the expected return type and `default` the
-        format value to assume when no parameter is present.
+class ParamWrapper(api_impl.ASGIAdaptor):
+    """ Adaptor class for server glue to Starlette framework.
     """
     """
-    fmt = request.query_params.get('format', default=default)
 
 
-    if not api_impl.supports_format(rtype, fmt):
-        raise HTTPException(400, detail="Parameter 'format' must be one of: " +
-                                        ', '.join(api_impl.list_formats(rtype)))
+    def __init__(self, request: Request) -> None:
+        self.request = request
 
 
-    request.state.format = fmt
 
 
+    def get(self, name: str, default: Optional[str] = None) -> Optional[str]:
+        return self.request.query_params.get(name, default=default)
 
 
-def format_response(request: Request, result: Any) -> Response:
-    """ Render response into a string according.
-    """
-    fmt = request.state.format
-    return Response(api_impl.format_result(result, fmt),
-                    media_type=CONTENT_TYPE.get(fmt, 'application/json'))
 
 
+    def get_header(self, name: str, default: Optional[str] = None) -> Optional[str]:
+        return self.request.headers.get(name, default)
+
+
+    def error(self, msg: str, status: int = 400) -> HTTPException:
+        return HTTPException(status, detail=msg,
+                             headers={'content-type': self.content_type})
+
+
+    def create_response(self, status: int, output: str, num_results: int) -> Response:
+        self.request.state.num_results = num_results
+        return Response(output, status_code=status, media_type=self.content_type)
+
+
+    def base_uri(self) -> str:
+        scheme = self.request.url.scheme
+        host = self.request.url.hostname
+        port = self.request.url.port
+        root = self.request.scope['root_path']
+        if (scheme == 'http' and port == 80) or (scheme == 'https' and port == 443):
+            port = None
+        if port is not None:
+            return f"{scheme}://{host}:{port}{root}"
+
+        return f"{scheme}://{host}{root}"
+
+
+    def config(self) -> Configuration:
+        return cast(Configuration, self.request.app.state.API.config)
+
+
+def _wrap_endpoint(func: api_impl.EndpointFunc)\
+        -> Callable[[Request], Coroutine[Any, Any, Response]]:
+    async def _callback(request: Request) -> Response:
+        return cast(Response, await func(request.app.state.API, ParamWrapper(request)))
+
+    return _callback
 
 
-async def on_status(request: Request) -> Response:
-    """ Implementation of status endpoint.
+
+class FileLoggingMiddleware(BaseHTTPMiddleware):
+    """ Middleware to log selected requests into a file.
     """
     """
-    parse_format(request, StatusResult, 'text')
-    result = await request.app.state.API.status()
-    response = format_response(request, result)
 
 
-    if request.state.format == 'text' and result.status:
-        response.status_code = 500
+    def __init__(self, app: Starlette, file_name: str = ''):
+        super().__init__(app)
+        self.fd = open(file_name, 'a', buffering=1, encoding='utf8') # pylint: disable=R1732
+
+    async def dispatch(self, request: Request,
+                       call_next: RequestResponseEndpoint) -> Response:
+        start = dt.datetime.now(tz=dt.timezone.utc)
+        response = await call_next(request)
+
+        if response.status_code != 200:
+            return response
+
+        finish = dt.datetime.now(tz=dt.timezone.utc)
 
 
-    return response
+        for endpoint in ('reverse', 'search', 'lookup'):
+            if request.url.path.startswith('/' + endpoint):
+                qtype = endpoint
+                break
+        else:
+            return response
 
 
+        duration = (finish - start).total_seconds()
+        params = request.scope['query_string'].decode('utf8')
+
+        self.fd.write(f"[{start.replace(tzinfo=None).isoformat(sep=' ', timespec='milliseconds')}] "
+                      f"{duration:.4f} {getattr(request.state, 'num_results', 0)} "
+                      f'{qtype} "{params}"\n')
+
+        return response
 
 
-V1_ROUTES = [
-    Route('/status', endpoint=on_status)
-]
 
 def get_application(project_dir: Path,
 
 def get_application(project_dir: Path,
-                    environ: Optional[Mapping[str, str]] = None) -> Starlette:
+                    environ: Optional[Mapping[str, str]] = None,
+                    debug: bool = True) -> Starlette:
     """ Create a Nominatim falcon ASGI application.
     """
     """ Create a Nominatim falcon ASGI application.
     """
-    app = Starlette(debug=True, routes=V1_ROUTES)
+    config = Configuration(project_dir, environ)
+
+    routes = []
+    legacy_urls = config.get_bool('SERVE_LEGACY_URLS')
+    for name, func in api_impl.ROUTES:
+        endpoint = _wrap_endpoint(func)
+        routes.append(Route(f"/{name}", endpoint=endpoint))
+        if legacy_urls:
+            routes.append(Route(f"/{name}.php", endpoint=endpoint))
+
+    middleware = []
+    if config.get_bool('CORS_NOACCESSCONTROL'):
+        middleware.append(Middleware(CORSMiddleware,
+                                     allow_origins=['*'],
+                                     allow_methods=['GET', 'OPTIONS'],
+                                     max_age=86400))
+
+    log_file = config.LOG_FILE
+    if log_file:
+        middleware.append(Middleware(FileLoggingMiddleware, file_name=log_file))
+
+    async def _shutdown() -> None:
+        await app.state.API.close()
+
+    app = Starlette(debug=debug, routes=routes, middleware=middleware,
+                    on_shutdown=[_shutdown])
 
     app.state.API = NominatimAPIAsync(project_dir, environ)
 
     return app
 
     app.state.API = NominatimAPIAsync(project_dir, environ)
 
     return app
+
+
+def run_wsgi() -> Starlette:
+    """ Entry point for uvicorn.
+    """
+    return get_application(Path('.'), debug=False)