1 # SPDX-License-Identifier: GPL-3.0-or-later
3 # This file is part of Nominatim. (https://nominatim.org)
5 # Copyright (C) 2024 by the Nominatim developer community.
6 # For a full list of authors see the git log.
8 Server implementation using the starlette webserver framework.
10 from typing import Any, Optional, Mapping, Callable, cast, Coroutine, Dict, Awaitable
11 from pathlib import Path
15 from starlette.applications import Starlette
16 from starlette.routing import Route
17 from starlette.exceptions import HTTPException
18 from starlette.responses import Response, PlainTextResponse, HTMLResponse
19 from starlette.requests import Request
20 from starlette.middleware import Middleware
21 from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
22 from starlette.middleware.cors import CORSMiddleware
24 from ...config import Configuration
25 from ...core import NominatimAPIAsync
26 from ... import v1 as api_impl
27 from ...result_formatting import FormatDispatcher, load_format_dispatcher
28 from ..asgi_adaptor import ASGIAdaptor, EndpointFunc
29 from ... import logging as loglib
class ParamWrapper(ASGIAdaptor):
    """ Adaptor class for server glue to Starlette framework.

        Wraps a single Starlette request and exposes it through the
        ASGIAdaptor interface that the endpoint functions work with.
    """

    def __init__(self, request: Request) -> None:
        self.request = request

    def get(self, name: str, default: Optional[str] = None) -> Optional[str]:
        """ Return the URL query parameter `name`, or `default` when the
            parameter is not present.
        """
        return self.request.query_params.get(name, default=default)

    def get_header(self, name: str, default: Optional[str] = None) -> Optional[str]:
        """ Return the HTTP request header `name`, or `default` when the
            header is not present.
        """
        return self.request.headers.get(name, default)

    def error(self, msg: str, status: int = 400) -> HTTPException:
        """ Create (but do not raise) an HTTPException with the given
            message and status. The content-type header is taken from
            `self.content_type` (inherited from ASGIAdaptor).
        """
        return HTTPException(status, detail=msg,
                             headers={'content-type': self.content_type})

    def create_response(self, status: int, output: str, num_results: int) -> Response:
        """ Create the final response for the request.

            The number of results is stored in `request.state.num_results`,
            where the file logging middleware reads it.
        """
        self.request.state.num_results = num_results
        return Response(output, status_code=status, media_type=self.content_type)

    def base_uri(self) -> str:
        """ Return the base URL the API is reachable under, including the
            ASGI root path. The port is left out when none was given or
            when it is the default port of the scheme.
        """
        scheme = self.request.url.scheme
        host = self.request.url.hostname
        port = self.request.url.port
        root = self.request.scope['root_path']

        # `hostname` strips the brackets from IPv6 literals. They must be
        # put back, otherwise address and port cannot be told apart.
        if host is not None and ':' in host:
            host = f'[{host}]'

        if (scheme == 'http' and port == 80) or (scheme == 'https' and port == 443):
            port = None

        if port is not None:
            return f"{scheme}://{host}:{port}{root}"

        return f"{scheme}://{host}{root}"

    def config(self) -> Configuration:
        """ Return the configuration of the API object of the application.
        """
        return cast(Configuration, self.request.app.state.API.config)

    def formatting(self) -> FormatDispatcher:
        """ Return the result format dispatcher of the application.
        """
        # get_application() stores the dispatcher in `app.state.formatter`.
        # Nothing in this module sets `app.state.API.formatter`.
        return cast(FormatDispatcher, self.request.app.state.formatter)
def _wrap_endpoint(func: EndpointFunc)\
        -> Callable[[Request], Coroutine[Any, Any, Response]]:
    """ Turn a framework-independent endpoint function into a Starlette
        request handler.

        The returned coroutine function calls `func` with the API object
        kept in `request.app.state.API` and a ParamWrapper around the
        incoming request, and hands back whatever `func` produced.
    """
    async def _handler(request: Request) -> Response:
        api = request.app.state.API
        params = ParamWrapper(request)
        result = await func(api, params)
        return cast(Response, result)

    return _handler
class FileLoggingMiddleware(BaseHTTPMiddleware):
    """ Middleware to log selected requests into a file.

        Only requests answered with HTTP status 200 whose path starts with
        '/reverse', '/search', '/lookup' or '/details' are logged. Because
        of the prefix match, the '.php' variants of these paths are logged
        as well. Each entry is one line of the form

            [<start time>] <duration in seconds> <number of results> <endpoint> "<query string>"

        The start time is taken in UTC and written without timezone suffix.
        The number of results comes from `request.state.num_results`
        (0 when unset).
    """

    def __init__(self, app: Starlette, file_name: str = ''):
        super().__init__(app)
        # Line buffering: every complete entry is flushed to the file.
        self.fd = open(file_name, 'a', buffering=1, encoding='utf8')

    async def dispatch(self, request: Request,
                       call_next: RequestResponseEndpoint) -> Response:
        start = dt.datetime.now(tz=dt.timezone.utc)
        response = await call_next(request)

        if response.status_code != 200:
            return response

        finish = dt.datetime.now(tz=dt.timezone.utc)

        for endpoint in ('reverse', 'search', 'lookup', 'details'):
            if request.url.path.startswith('/' + endpoint):
                qtype = endpoint
                break
        else:
            # Not one of the endpoints that get logged.
            return response

        duration = (finish - start).total_seconds()
        # The raw query string is a byte string that is not guaranteed to be
        # valid UTF-8. The response is already computed at this point, so a
        # decoding problem must not turn it into a server error.
        params = request.scope['query_string'].decode('utf8', errors='replace')

        self.fd.write(f"[{start.replace(tzinfo=None).isoformat(sep=' ', timespec='milliseconds')}] "
                      f"{duration:.4f} {getattr(request.state, 'num_results', 0)} "
                      f'{qtype} "{params}"\n')

        return response
async def timeout_error(request: Request,
                        _: Exception) -> Response:
    """ Exception handler for queries that ran into a timeout.

        Notes the abort in the request log and then collects and disables
        that log. When the log yields any content, it is sent back as an
        HTML response. Otherwise the client gets a plain-text error message
        with HTTP status 503.
    """
    loglib.log().comment('Aborted: Query took too long to process.')
    collected_log = loglib.get_and_disable()

    if not collected_log:
        return PlainTextResponse("Query took too long to process.", status_code=503)

    return HTMLResponse(collected_log)
def get_application(project_dir: Path,
                    environ: Optional[Mapping[str, str]] = None,
                    debug: bool = True) -> Starlette:
    """ Create a Nominatim ASGI application based on Starlette.

        Arguments:
            project_dir: Directory of the Nominatim project to serve.
            environ: Optional mapping with environment settings. It is
                     passed unchanged to Configuration and NominatimAPIAsync.
            debug: Switch Starlette's debug mode on or off.

        Returns:
            The configured Starlette application. The API object is available
            as `app.state.API` and the v1 result formatter as
            `app.state.formatter`.
    """
    config = Configuration(project_dir, environ)

    # One route per v1 endpoint, optionally also under the legacy '.php' name.
    routes = []
    legacy_urls = config.get_bool('SERVE_LEGACY_URLS')
    for name, func in api_impl.ROUTES:
        endpoint = _wrap_endpoint(func)
        routes.append(Route(f"/{name}", endpoint=endpoint))
        if legacy_urls:
            routes.append(Route(f"/{name}.php", endpoint=endpoint))

    middleware = []
    if config.get_bool('CORS_NOACCESSCONTROL'):
        middleware.append(Middleware(CORSMiddleware,
                                     allow_origins=['*'],
                                     allow_methods=['GET', 'OPTIONS'],
                                     max_age=86400))

    # Request logging is only enabled when a log file is configured.
    log_file = config.LOG_FILE
    if log_file:
        middleware.append(Middleware(FileLoggingMiddleware, file_name=log_file))

    # Before Python 3.11, asyncio.TimeoutError is distinct from the builtin
    # TimeoutError, so both need to be mapped to the handler.
    exceptions: Dict[Any, Callable[[Request, Exception], Awaitable[Response]]] = {
        TimeoutError: timeout_error,
        asyncio.TimeoutError: timeout_error
    }

    # Refers to `app`, which is bound below before the shutdown hook can run.
    async def _shutdown() -> None:
        await app.state.API.close()

    app = Starlette(debug=debug, routes=routes, middleware=middleware,
                    exception_handlers=exceptions,
                    on_shutdown=[_shutdown])

    app.state.API = NominatimAPIAsync(project_dir, environ)
    app.state.formatter = load_format_dispatcher('v1', project_dir)

    return app
def run_wsgi() -> Starlette:
    """ Entry point for uvicorn.

        Builds the application for the project in the current working
        directory with debug mode switched off. Despite the name, the
        returned Starlette object is an ASGI application.
    """
    project_dir = Path('.')
    return get_application(project_dir, debug=False)