]> git.openstreetmap.org Git - nominatim.git/blob - src/nominatim_api/server/starlette/server.py
fefedf0ed4242980cbf4669f995ada6c66d8943b
[nominatim.git] / src / nominatim_api / server / starlette / server.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2024 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Server implementation using the starlette webserver framework.
9 """
10 from typing import Any, Optional, Mapping, Callable, cast, Coroutine, Dict, Awaitable
11 from pathlib import Path
12 import datetime as dt
13 import asyncio
14
15 from starlette.applications import Starlette
16 from starlette.routing import Route
17 from starlette.exceptions import HTTPException
18 from starlette.responses import Response, PlainTextResponse, HTMLResponse
19 from starlette.requests import Request
20 from starlette.middleware import Middleware
21 from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
22 from starlette.middleware.cors import CORSMiddleware
23
24 from ...config import Configuration
25 from ...core import NominatimAPIAsync
26 from ... import v1 as api_impl
27 from ...result_formatting import FormatDispatcher
28 from ...v1.format import dispatch as formatting
29 from ..asgi_adaptor import ASGIAdaptor, EndpointFunc
30 from ... import logging as loglib
31
class ParamWrapper(ASGIAdaptor):
    """ Adaptor class for server glue to Starlette framework.

        Wraps a single Starlette request and exposes query parameters,
        headers, error/response creation and configuration access in the
        form the generic ASGI endpoint implementations expect.
    """

    def __init__(self, request: Request) -> None:
        self.request = request


    def get(self, name: str, default: Optional[str] = None) -> Optional[str]:
        """ Return the query parameter `name`, or `default` when absent.
        """
        return self.request.query_params.get(name, default=default)


    def get_header(self, name: str, default: Optional[str] = None) -> Optional[str]:
        """ Return the HTTP header `name`, or `default` when absent.
        """
        return self.request.headers.get(name, default)


    def error(self, msg: str, status: int = 400) -> HTTPException:
        """ Build (but do not raise) an HTTPException carrying `msg` as
            detail, using this adaptor's content type for the response.
        """
        return HTTPException(status, detail=msg,
                             headers={'content-type': self.content_type})


    def create_response(self, status: int, output: str, num_results: int) -> Response:
        """ Build the final response for `output`.

            The number of results is stored on the request state so that
            middleware (e.g. the file logger) can pick it up later.
        """
        self.request.state.num_results = num_results
        return Response(output, status_code=status, media_type=self.content_type)


    def base_uri(self) -> str:
        """ Return the base URI under which the application is reachable.

            The port is left out when it is the default for the scheme.
            IPv6 address literals are enclosed in square brackets so that
            the result stays a valid URI.
        """
        url = self.request.url
        scheme = url.scheme
        host = url.hostname
        port = url.port
        # 'root_path' is optional in the ASGI scope; absence means no prefix.
        root = self.request.scope.get('root_path', '')

        # URL.hostname strips the brackets from IPv6 literals. Restore them,
        # otherwise address and port become indistinguishable.
        if host and ':' in host:
            host = f"[{host}]"

        if (scheme == 'http' and port == 80) or (scheme == 'https' and port == 443):
            port = None
        if port is not None:
            return f"{scheme}://{host}:{port}{root}"

        return f"{scheme}://{host}{root}"


    def config(self) -> Configuration:
        """ Return the configuration of the API object stored on the app.
        """
        return cast(Configuration, self.request.app.state.API.config)


    def formatting(self) -> FormatDispatcher:
        """ Return the v1 output format dispatcher.
        """
        return formatting
77
78
def _wrap_endpoint(func: EndpointFunc)\
        -> Callable[[Request], Coroutine[Any, Any, Response]]:
    """ Adapt a generic Nominatim endpoint function to a Starlette endpoint.

        The returned coroutine function calls `func` with the API object
        kept in the application state and a ParamWrapper around the
        incoming request, and hands back the result as the response.
    """
    async def _callback(request: Request) -> Response:
        api = request.app.state.API
        params = ParamWrapper(request)
        result = await func(api, params)
        return cast(Response, result)

    return _callback
85
86
class FileLoggingMiddleware(BaseHTTPMiddleware):
    """ Middleware that appends one line per successful query to a log file.

        Only responses with status 200 for paths starting with /reverse,
        /search, /lookup or /details are logged. Each line holds the UTC
        start time, the processing time in seconds, the number of results,
        the query type and the raw query string.
    """

    def __init__(self, app: Starlette, file_name: str = ''):
        super().__init__(app)
        # buffering=1 selects line buffering, so each entry is flushed at once.
        self.fd = open(file_name, 'a', buffering=1, encoding='utf8') # pylint: disable=R1732

    async def dispatch(self, request: Request,
                       call_next: RequestResponseEndpoint) -> Response:
        started = dt.datetime.now(tz=dt.timezone.utc)
        response = await call_next(request)

        if response.status_code == 200:
            finished = dt.datetime.now(tz=dt.timezone.utc)

            path = request.url.path
            qtype = next((kind for kind in ('reverse', 'search', 'lookup', 'details')
                          if path.startswith('/' + kind)), None)

            if qtype is not None:
                elapsed = (finished - started).total_seconds()
                query = request.scope['query_string'].decode('utf8')
                stamp = started.replace(tzinfo=None).isoformat(sep=' ',
                                                               timespec='milliseconds')
                num_results = getattr(request.state, 'num_results', 0)
                self.fd.write(f"[{stamp}] {elapsed:.4f} {num_results} "
                              f'{qtype} "{query}"\n')

        return response
120
121
async def timeout_error(request: Request, #pylint: disable=unused-argument
                        _: Exception) -> Response:
    """ Exception handler for query timeouts.

        Adds an abort note to the query log. If the log module returns any
        collected log data, that data is sent back as HTML. Otherwise a
        plain-text error with status 503 is returned.
    """
    loglib.log().comment('Aborted: Query took too long to process.')
    collected = loglib.get_and_disable()

    if not collected:
        return PlainTextResponse("Query took too long to process.", status_code=503)

    return HTMLResponse(collected)
133
134
def get_application(project_dir: Path,
                    environ: Optional[Mapping[str, str]] = None,
                    debug: bool = True) -> Starlette:
    """ Create a Nominatim Starlette ASGI application.

        Arguments:
            project_dir: Project directory, handed to both the configuration
                and the NominatimAPIAsync instance.
            environ: Optional mapping handed to both the configuration and
                the NominatimAPIAsync instance.
            debug: Value for Starlette's debug mode.

        Returns:
            A Starlette application with one route per v1 API endpoint,
            optional CORS and file-logging middleware, and timeout error
            handlers. The NominatimAPIAsync instance is stored as
            `app.state.API` and closed on application shutdown.
    """
    config = Configuration(project_dir, environ)

    routes = []
    legacy_urls = config.get_bool('SERVE_LEGACY_URLS')
    for name, func in api_impl.ROUTES:
        endpoint = _wrap_endpoint(func)
        routes.append(Route(f"/{name}", endpoint=endpoint))
        # Optionally also serve the same endpoint under the '.php' URL.
        if legacy_urls:
            routes.append(Route(f"/{name}.php", endpoint=endpoint))

    middleware = []
    # Allow GET/OPTIONS requests from any origin when requested by config.
    if config.get_bool('CORS_NOACCESSCONTROL'):
        middleware.append(Middleware(CORSMiddleware,
                                     allow_origins=['*'],
                                     allow_methods=['GET', 'OPTIONS'],
                                     max_age=86400))

    log_file = config.LOG_FILE
    if log_file:
        middleware.append(Middleware(FileLoggingMiddleware, file_name=log_file))

    # Both classes are registered: before Python 3.11 asyncio.TimeoutError
    # is distinct from the builtin TimeoutError.
    exceptions: Dict[Any, Callable[[Request, Exception], Awaitable[Response]]] = {
        TimeoutError: timeout_error,
        asyncio.TimeoutError: timeout_error
    }

    async def _shutdown() -> None:
        # 'app' is looked up at call time; it is bound below long before
        # any shutdown handler can run.
        await app.state.API.close()

    app = Starlette(debug=debug, routes=routes, middleware=middleware,
                    exception_handlers=exceptions,
                    on_shutdown=[_shutdown])

    app.state.API = NominatimAPIAsync(project_dir, environ)

    return app
176
177
def run_wsgi() -> Starlette:
    """ Entry point for uvicorn.

        Builds the application for the project in the current working
        directory with Starlette's debug mode switched off.
    """
    project_dir = Path('.')
    return get_application(project_dir, debug=False)