1 # SPDX-License-Identifier: GPL-3.0-or-later
3 # This file is part of Nominatim. (https://nominatim.org)
5 # Copyright (C) 2024 by the Nominatim developer community.
6 # For a full list of authors see the git log.
"""
Server implementation using the starlette webserver framework.
"""
10 from typing import Any, Optional, Mapping, Callable, cast, Coroutine, Dict, \
11 Awaitable, AsyncIterator
12 from pathlib import Path
17 from starlette.applications import Starlette
18 from starlette.routing import Route
19 from starlette.exceptions import HTTPException
20 from starlette.responses import Response, PlainTextResponse, HTMLResponse
21 from starlette.requests import Request
22 from starlette.middleware import Middleware
23 from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
24 from starlette.middleware.cors import CORSMiddleware
26 from ...config import Configuration
27 from ...core import NominatimAPIAsync
28 from ... import v1 as api_impl
29 from ...result_formatting import FormatDispatcher, load_format_dispatcher
30 from ..asgi_adaptor import ASGIAdaptor, EndpointFunc
31 from ... import logging as loglib
class ParamWrapper(ASGIAdaptor):
    """ Adaptor class for server glue to Starlette framework.

        Wraps a single Starlette request and exposes query parameters,
        headers, response construction and application state through the
        ASGIAdaptor interface.
    """

    def __init__(self, request: Request) -> None:
        self.request = request

    def get(self, name: str, default: Optional[str] = None) -> Optional[str]:
        """ Return the query parameter `name`, or `default` if it is absent.
        """
        return self.request.query_params.get(name, default=default)

    def get_header(self, name: str, default: Optional[str] = None) -> Optional[str]:
        """ Return the HTTP header `name`, or `default` if it is absent.
        """
        return self.request.headers.get(name, default)

    def error(self, msg: str, status: int = 400) -> HTTPException:
        """ Build (but do not raise) an HTTPException carrying `msg`,
            using the content type of the current response format.
        """
        return HTTPException(status, detail=msg,
                             headers={'content-type': self.content_type})

    def create_response(self, status: int, output: str, num_results: int) -> Response:
        """ Build the final response. The number of results is remembered
            in the request state so that logging middleware can pick it up.
        """
        self.request.state.num_results = num_results
        return Response(output, status_code=status, media_type=self.content_type)

    def base_uri(self) -> str:
        """ Return the externally visible base URI of the server,
            including the ASGI root path. Default ports (80 for http,
            443 for https) are left out.
        """
        scheme = self.request.url.scheme
        host = self.request.url.hostname
        port = self.request.url.port
        # 'root_path' is optional in the ASGI scope and defaults to ''.
        root = self.request.scope.get('root_path', '')

        # urlsplit() strips the brackets from IPv6 literals; they must be
        # restored or the resulting URI is ambiguous/invalid.
        if host and ':' in host:
            host = f"[{host}]"

        if port is None or (scheme, port) in (('http', 80), ('https', 443)):
            return f"{scheme}://{host}{root}"

        return f"{scheme}://{host}:{port}{root}"

    def config(self) -> Configuration:
        """ Return the configuration of the NominatimAPIAsync instance
            stored in the application state.
        """
        return cast(Configuration, self.request.app.state.API.config)

    def formatting(self) -> FormatDispatcher:
        """ Return the result formatter stored in the application state.
        """
        return cast(FormatDispatcher, self.request.app.state.formatter)
def _wrap_endpoint(func: EndpointFunc)\
        -> Callable[[Request], Coroutine[Any, Any, Response]]:
    """ Turn a framework-independent endpoint function into a Starlette
        request handler.

        The returned coroutine function calls `func` with the API object
        kept in the application state and a ParamWrapper around the request.
    """
    async def _callback(request: Request) -> Response:
        return cast(Response, await func(request.app.state.API, ParamWrapper(request)))

    return _callback
class FileLoggingMiddleware(BaseHTTPMiddleware):
    """ Middleware to log selected requests into a file.

        Only requests that finish with status 200 and whose path starts
        with one of the endpoints in `_LOGGED_ENDPOINTS` are logged. Each
        line holds the start time (UTC), the duration in seconds, the
        number of results recorded in the request state, the endpoint
        name and the raw query string.
    """

    # Endpoint names (matched as path prefixes) that get logged.
    _LOGGED_ENDPOINTS = ('reverse', 'search', 'lookup', 'details')

    def __init__(self, app: Starlette, file_name: str = ''):
        super().__init__(app)
        # buffering=1 selects line buffering so entries are flushed per line.
        self.fd = open(file_name, 'a', buffering=1, encoding='utf8')

    async def dispatch(self, request: Request,
                       call_next: RequestResponseEndpoint) -> Response:
        start = dt.datetime.now(tz=dt.timezone.utc)
        response = await call_next(request)

        if response.status_code != 200:
            return response

        finish = dt.datetime.now(tz=dt.timezone.utc)

        qtype = next((endpoint for endpoint in self._LOGGED_ENDPOINTS
                      if request.url.path.startswith('/' + endpoint)), None)
        if qtype is None:
            return response

        duration = (finish - start).total_seconds()
        params = request.scope['query_string'].decode('utf8')

        self.fd.write(f"[{start.replace(tzinfo=None).isoformat(sep=' ', timespec='milliseconds')}] "
                      f"{duration:.4f} {getattr(request.state, 'num_results', 0)} "
                      f'{qtype} "{params}"\n')

        return response
async def timeout_error(request: Request,
                        _: Exception) -> Response:
    """ Error handler for query timeouts.

        Adds an abort note to the request log and collects it. If the log
        yields any content, it is returned as an HTML page. Otherwise a
        plain-text 503 response is sent.
    """
    loglib.log().comment('Aborted: Query took too long to process.')
    logdata = loglib.get_and_disable()

    if logdata:
        return HTMLResponse(logdata)

    return PlainTextResponse("Query took too long to process.", status_code=503)
def get_application(project_dir: Path,
                    environ: Optional[Mapping[str, str]] = None,
                    debug: bool = True) -> Starlette:
    """ Create a Nominatim starlette ASGI application.

        `project_dir` and `environ` are used to read the configuration.
        At startup they are also used to create the NominatimAPIAsync
        instance that the endpoints use. `debug` is passed on to Starlette.
    """
    config = Configuration(project_dir, environ)

    middleware = []
    if config.get_bool('CORS_NOACCESSCONTROL'):
        middleware.append(Middleware(CORSMiddleware,
                                     allow_origins=['*'],
                                     allow_methods=['GET', 'OPTIONS'],
                                     max_age=86400))

    log_file = config.LOG_FILE
    # Only install the file logger when a log file is configured;
    # open('') would fail otherwise.
    if log_file:
        middleware.append(Middleware(FileLoggingMiddleware, file_name=log_file))

    exceptions: Dict[Any, Callable[[Request, Exception], Awaitable[Response]]] = {
        TimeoutError: timeout_error,
        asyncio.TimeoutError: timeout_error
    }

    @contextlib.asynccontextmanager
    async def lifespan(app: Starlette) -> AsyncIterator[Any]:
        api = NominatimAPIAsync(project_dir, environ)
        app.state.API = api
        # try/finally: close the API even when route setup or the
        # serving phase ends with an exception.
        try:
            legacy_urls = api.config.get_bool('SERVE_LEGACY_URLS')
            for name, func in await api_impl.get_routes(api):
                endpoint = _wrap_endpoint(func)
                app.routes.append(Route(f"/{name}", endpoint=endpoint))
                if legacy_urls:
                    app.routes.append(Route(f"/{name}.php", endpoint=endpoint))

            yield
        finally:
            await api.close()

    app = Starlette(debug=debug, middleware=middleware,
                    exception_handlers=exceptions,
                    lifespan=lifespan)

    app.state.formatter = load_format_dispatcher('v1', project_dir)

    return app
def run_wsgi() -> Starlette:
    """ Entry point for uvicorn.

        Despite its name, this returns an ASGI (Starlette) application. It
        serves the project in the current working directory with debug
        mode switched off.
    """
    current_dir = Path('.')
    return get_application(current_dir, debug=False)