1 # SPDX-License-Identifier: GPL-3.0-or-later
3 # This file is part of Nominatim. (https://nominatim.org)
5 # Copyright (C) 2024 by the Nominatim developer community.
6 # For a full list of authors see the git log.
8 Server implementation using the starlette webserver framework.
10 from typing import Any, Optional, Mapping, Callable, cast, Coroutine, Dict, Awaitable
11 from pathlib import Path
15 from starlette.applications import Starlette
16 from starlette.routing import Route
17 from starlette.exceptions import HTTPException
18 from starlette.responses import Response, PlainTextResponse, HTMLResponse
19 from starlette.requests import Request
20 from starlette.middleware import Middleware
21 from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
22 from starlette.middleware.cors import CORSMiddleware
24 from ...config import Configuration
25 from ...core import NominatimAPIAsync
26 from ... import v1 as api_impl
27 from ...result_formatting import FormatDispatcher, load_format_dispatcher
28 from ..asgi_adaptor import ASGIAdaptor, EndpointFunc
29 from ... import logging as loglib
class ParamWrapper(ASGIAdaptor):
    """ Adaptor class for server glue to Starlette framework.

        Wraps a single Starlette request and exposes its query parameters,
        headers and the application state through the ASGIAdaptor
        interface used by the framework-independent API implementation.
    """

    def __init__(self, request: Request) -> None:
        self.request = request


    def get(self, name: str, default: Optional[str] = None) -> Optional[str]:
        """ Return the query parameter `name`, or `default` if it is absent.
        """
        return self.request.query_params.get(name, default=default)


    def get_header(self, name: str, default: Optional[str] = None) -> Optional[str]:
        """ Return the HTTP header `name`, or `default` if it is absent.
        """
        return self.request.headers.get(name, default)


    def error(self, msg: str, status: int = 400) -> HTTPException:
        """ Create (but do not raise) an HTTPException with the given
            message and status, tagged with the current content type.
        """
        return HTTPException(status, detail=msg,
                             headers={'content-type': self.content_type})


    def create_response(self, status: int, output: str, num_results: int) -> Response:
        """ Build the final response for `output`.

            The number of results is stored in the request state, where
            FileLoggingMiddleware picks it up for the query log.
        """
        self.request.state.num_results = num_results
        return Response(output, status_code=status, media_type=self.content_type)


    def base_uri(self) -> str:
        """ Return the base URL of the application: scheme, host, port
            (omitted when it is the scheme's default) and ASGI root path.
        """
        scheme = self.request.url.scheme
        host = self.request.url.hostname
        port = self.request.url.port
        # 'root_path' is optional in the ASGI scope; its default is ''.
        root = self.request.scope.get('root_path', '')

        # The URL parser returns IPv6 literals without brackets; they must
        # be re-bracketed or the port/path would be ambiguous.
        if host is not None and ':' in host:
            host = f"[{host}]"

        if port is None or (scheme == 'http' and port == 80) \
           or (scheme == 'https' and port == 443):
            return f"{scheme}://{host}{root}"

        return f"{scheme}://{host}:{port}{root}"


    def config(self) -> Configuration:
        """ Return the configuration of the API object on the app state.
        """
        return cast(Configuration, self.request.app.state.API.config)


    def formatting(self) -> FormatDispatcher:
        """ Return the result format dispatcher of the application.

            get_application() stores the dispatcher directly on the
            application state (app.state.formatter), not on the API object.
        """
        return cast(FormatDispatcher, self.request.app.state.formatter)
def _wrap_endpoint(func: EndpointFunc)\
        -> Callable[[Request], Coroutine[Any, Any, Response]]:
    """ Adapt a framework-independent API endpoint to a Starlette handler.

        The returned coroutine function calls `func` with the API object
        kept in the application state and the request wrapped in a
        ParamWrapper.
    """
    async def _callback(request: Request) -> Response:
        # Keep the inner name: Starlette derives the route name from it.
        api = request.app.state.API
        params = ParamWrapper(request)
        result = await func(api, params)
        return cast(Response, result)

    return _callback
class FileLoggingMiddleware(BaseHTTPMiddleware):
    """ Middleware to log selected requests into a file.

        A line is appended only for responses with status 200 whose path
        starts with one of the endpoint names 'reverse', 'search',
        'lookup' or 'details'.
    """

    def __init__(self, app: Starlette, file_name: str = ''):
        super().__init__(app)
        # Line buffered (buffering=1), so every log line is flushed on write.
        self.fd = open(file_name, 'a', buffering=1, encoding='utf8') # pylint: disable=R1732

    async def dispatch(self, request: Request,
                       call_next: RequestResponseEndpoint) -> Response:
        started_at = dt.datetime.now(tz=dt.timezone.utc)
        response = await call_next(request)

        if response.status_code != 200:
            return response

        finished_at = dt.datetime.now(tz=dt.timezone.utc)

        path = request.url.path
        qtype = next((name for name in ('reverse', 'search', 'lookup', 'details')
                      if path.startswith('/' + name)), None)
        if qtype is None:
            return response

        elapsed = (finished_at - started_at).total_seconds()
        query = request.scope['query_string'].decode('utf8')
        stamp = started_at.replace(tzinfo=None).isoformat(sep=' ', timespec='milliseconds')
        found = getattr(request.state, 'num_results', 0)

        self.fd.write(f"[{stamp}] {elapsed:.4f} {found} {qtype} \"{query}\"\n")

        return response
async def timeout_error(request: Request, #pylint: disable=unused-argument
                        _: Exception) -> Response:
    """ Error handler for query timeouts.

        Records the abort in the request log and then collects the log
        via loglib.get_and_disable().
        - If any log output was returned, it is sent as an HTML page.
        - Otherwise a plain-text 503 response is returned.
    """
    loglib.log().comment('Aborted: Query took too long to process.')
    collected = loglib.get_and_disable()

    if not collected:
        return PlainTextResponse("Query took too long to process.", status_code=503)

    return HTMLResponse(collected)
def get_application(project_dir: Path,
                    environ: Optional[Mapping[str, str]] = None,
                    debug: bool = True) -> Starlette:
    """ Create a Nominatim Starlette ASGI application.

        Arguments:
            project_dir: Project directory; passed to the Configuration,
                         the NominatimAPIAsync object and the 'v1' format
                         dispatcher loader.
            environ: Optional environment mapping, passed to Configuration
                     and NominatimAPIAsync.
            debug: Passed through as Starlette's `debug` flag.

        Returns:
            The Starlette application with all v1 API routes, optional
            CORS and file-logging middleware, and timeout error handlers.
    """
    config = Configuration(project_dir, environ)

    routes = []
    legacy_urls = config.get_bool('SERVE_LEGACY_URLS')
    for name, func in api_impl.ROUTES:
        endpoint = _wrap_endpoint(func)
        routes.append(Route(f"/{name}", endpoint=endpoint))
        # Optionally also serve the endpoint under its legacy '.php' URL.
        if legacy_urls:
            routes.append(Route(f"/{name}.php", endpoint=endpoint))

    middleware = []
    if config.get_bool('CORS_NOACCESSCONTROL'):
        # Accept cross-origin requests from any origin; browsers may cache
        # the preflight answer for one day (max_age is in seconds).
        middleware.append(Middleware(CORSMiddleware,
                                     allow_origins=['*'],
                                     allow_methods=['GET', 'OPTIONS'],
                                     max_age=86400))

    log_file = config.LOG_FILE
    if log_file:
        middleware.append(Middleware(FileLoggingMiddleware, file_name=log_file))

    # Both the builtin and the asyncio timeout exception are mapped to the
    # same handler.
    exceptions: Dict[Any, Callable[[Request, Exception], Awaitable[Response]]] = {
        TimeoutError: timeout_error,
        asyncio.TimeoutError: timeout_error
    }

    async def _shutdown() -> None:
        # 'app' is resolved at call time, after it has been assigned below.
        await app.state.API.close()

    app = Starlette(debug=debug, routes=routes, middleware=middleware,
                    exception_handlers=exceptions,
                    on_shutdown=[_shutdown])

    app.state.API = NominatimAPIAsync(project_dir, environ)
    app.state.formatter = load_format_dispatcher('v1', project_dir)

    return app
def run_wsgi() -> Starlette:
    """ Entry point for uvicorn.

        Despite its name, this returns a Starlette (ASGI) application. The
        app uses the relative project directory '.' and has debug mode
        switched off.
    """
    project_dir = Path('.')
    return get_application(project_dir, debug=False)