1 # SPDX-License-Identifier: GPL-3.0-or-later
3 # This file is part of Nominatim. (https://nominatim.org)
5 # Copyright (C) 2024 by the Nominatim developer community.
6 # For a full list of authors see the git log.
8 Server implementation using the starlette webserver framework.
10 from typing import Any, Optional, Mapping, Callable, cast, Coroutine, Dict, Awaitable
11 from pathlib import Path
15 from starlette.applications import Starlette
16 from starlette.routing import Route
17 from starlette.exceptions import HTTPException
18 from starlette.responses import Response, PlainTextResponse, HTMLResponse
19 from starlette.requests import Request
20 from starlette.middleware import Middleware
21 from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
22 from starlette.middleware.cors import CORSMiddleware
24 from ...config import Configuration
25 from ...core import NominatimAPIAsync
26 from ... import v1 as api_impl
27 from ...result_formatting import FormatDispatcher
28 from ...v1.format import dispatch as formatting
29 from ..asgi_adaptor import ASGIAdaptor, EndpointFunc
30 from ... import logging as loglib
class ParamWrapper(ASGIAdaptor):
    """ Adaptor class for server glue to Starlette framework.

        Wraps a single Starlette request and exposes the parameter, header,
        response and configuration accessors used by the framework-independent
        endpoint implementations.
    """

    def __init__(self, request: Request) -> None:
        self.request = request


    def get(self, name: str, default: Optional[str] = None) -> Optional[str]:
        """ Return the value of query parameter `name`, or `default`
            when the parameter is absent.
        """
        return self.request.query_params.get(name, default=default)


    def get_header(self, name: str, default: Optional[str] = None) -> Optional[str]:
        """ Return the value of HTTP header `name`, or `default`
            when the header is absent.
        """
        return self.request.headers.get(name, default)


    def error(self, msg: str, status: int = 400) -> HTTPException:
        """ Build (but do not raise) an HTTP exception carrying `msg`
            with the given status code and the adaptor's content type.
        """
        return HTTPException(status, detail=msg,
                             headers={'content-type': self.content_type})


    def create_response(self, status: int, output: str, num_results: int) -> Response:
        """ Build the final response for `output`.

            The number of results is remembered in the request state so that
            the file logging middleware can report it.
        """
        self.request.state.num_results = num_results
        return Response(output, status_code=status, media_type=self.content_type)


    def base_uri(self) -> str:
        """ Return the base URL of the server as seen by the client:
            scheme, host, port (omitted when it is the scheme's default)
            and the ASGI root path.
        """
        scheme = self.request.url.scheme
        host = self.request.url.hostname
        port = self.request.url.port
        root = self.request.scope['root_path']

        # `hostname` strips the brackets from IPv6 literals. Restore them,
        # otherwise the host and port separator become ambiguous and the
        # resulting URI is invalid.
        if host is not None and ':' in host:
            host = f"[{host}]"

        if (scheme == 'http' and port == 80) or (scheme == 'https' and port == 443):
            port = None

        if port is not None:
            return f"{scheme}://{host}:{port}{root}"

        return f"{scheme}://{host}{root}"


    def config(self) -> Configuration:
        """ Return the configuration of the API object stored in the
            application state.
        """
        return cast(Configuration, self.request.app.state.API.config)


    def formatting(self) -> FormatDispatcher:
        """ Return the result format dispatcher imported at module level
            from the v1 format module.
        """
        return formatting
def _wrap_endpoint(func: EndpointFunc)\
        -> Callable[[Request], Coroutine[Any, Any, Response]]:
    """ Turn a framework-independent endpoint function into a Starlette
        request handler.

        The returned coroutine function calls `func` with the API object
        kept in the application state and a ParamWrapper around the request.
    """
    async def _handler(request: Request) -> Response:
        api = request.app.state.API
        params = ParamWrapper(request)
        result = await func(api, params)
        return cast(Response, result)

    return _handler
class FileLoggingMiddleware(BaseHTTPMiddleware):
    """ Middleware to log selected requests into a file.

        Only requests that finish with HTTP status 200 and whose path starts
        with one of the reverse, search, lookup or details endpoints are
        recorded. Each log line contains:
        - the start time
        - the processing time in seconds
        - the number of results
        - the endpoint name
        - the raw query string
    """

    def __init__(self, app: Starlette, file_name: str = ''):
        super().__init__(app)
        # Append in line-buffered text mode so every entry is flushed as soon
        # as it is written. The handle is not explicitly closed in this class.
        self.fd = open(file_name, 'a', buffering=1, encoding='utf8') # pylint: disable=R1732

    async def dispatch(self, request: Request,
                       call_next: RequestResponseEndpoint) -> Response:
        started = dt.datetime.now(tz=dt.timezone.utc)
        response = await call_next(request)

        # Unsuccessful requests are passed through without logging.
        if response.status_code != 200:
            return response

        finished = dt.datetime.now(tz=dt.timezone.utc)

        path = request.url.path
        qtype = next((kind for kind in ('reverse', 'search', 'lookup', 'details')
                      if path.startswith('/' + kind)), None)
        if qtype is None:
            return response

        elapsed = (finished - started).total_seconds()
        query = request.scope['query_string'].decode('utf8')
        num_results = getattr(request.state, 'num_results', 0)
        stamp = started.replace(tzinfo=None).isoformat(sep=' ', timespec='milliseconds')

        self.fd.write(f'[{stamp}] {elapsed:.4f} {num_results} {qtype} "{query}"\n')

        return response
async def timeout_error(request: Request, #pylint: disable=unused-argument
                        _: Exception) -> Response:
    """ Error handler for query timeouts.

        Any log output collected by loglib for the request is returned
        as an HTML response. Without such output, a plain-text response
        with status 503 is sent.
    """
    loglib.log().comment('Aborted: Query took too long to process.')
    collected = loglib.get_and_disable()

    if not collected:
        return PlainTextResponse("Query took too long to process.", status_code=503)

    return HTMLResponse(collected)
def get_application(project_dir: Path,
                    environ: Optional[Mapping[str, str]] = None,
                    debug: bool = True) -> Starlette:
    """ Create a Nominatim Starlette ASGI application.

        Parameters:
            project_dir: Nominatim project directory, passed to both the
                configuration and the API object.
            environ: Optional environment mapping passed to the configuration
                and the API object.
            debug: Passed through as Starlette's `debug` flag.

        Returns:
            The configured Starlette application. The API object is kept in
            `app.state.API` and closed when the application shuts down.
    """
    config = Configuration(project_dir, environ)

    routes = []
    legacy_urls = config.get_bool('SERVE_LEGACY_URLS')
    for name, func in api_impl.ROUTES:
        endpoint = _wrap_endpoint(func)
        routes.append(Route(f"/{name}", endpoint=endpoint))
        # Optionally also serve the endpoint under its '.php' URL.
        if legacy_urls:
            routes.append(Route(f"/{name}.php", endpoint=endpoint))

    middleware = []
    if config.get_bool('CORS_NOACCESSCONTROL'):
        middleware.append(Middleware(CORSMiddleware,
                                     allow_origins=['*'],
                                     allow_methods=['GET', 'OPTIONS'],
                                     max_age=86400))

    log_file = config.LOG_FILE
    if log_file:
        middleware.append(Middleware(FileLoggingMiddleware, file_name=log_file))

    # Both timeout exception types map to the same handler.
    exceptions: Dict[Any, Callable[[Request, Exception], Awaitable[Response]]] = {
        TimeoutError: timeout_error,
        asyncio.TimeoutError: timeout_error
    }

    # `app` is resolved when the hook runs, i.e. after it has been assigned below.
    async def _shutdown() -> None:
        await app.state.API.close()

    app = Starlette(debug=debug, routes=routes, middleware=middleware,
                    exception_handlers=exceptions,
                    on_shutdown=[_shutdown])

    app.state.API = NominatimAPIAsync(project_dir, environ)

    return app
def run_wsgi() -> Starlette:
    """ Entry point for uvicorn.

        Builds the application for the project in the current working
        directory, with Starlette's debug mode switched off.
    """
    project_dir = Path('.')
    return get_application(project_dir, debug=False)