git.openstreetmap.org Git - nominatim.git/blob - src/nominatim_api/server/starlette/server.py
Merge pull request #3515 from lonvia/custom-result-formatting
[nominatim.git] / src / nominatim_api / server / starlette / server.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2024 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Server implementation using the starlette webserver framework.
9 """
10 from typing import Any, Optional, Mapping, Callable, cast, Coroutine, Dict, Awaitable
11 from pathlib import Path
12 import datetime as dt
13 import asyncio
14
15 from starlette.applications import Starlette
16 from starlette.routing import Route
17 from starlette.exceptions import HTTPException
18 from starlette.responses import Response, PlainTextResponse, HTMLResponse
19 from starlette.requests import Request
20 from starlette.middleware import Middleware
21 from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
22 from starlette.middleware.cors import CORSMiddleware
23
24 from ...config import Configuration
25 from ...core import NominatimAPIAsync
26 from ... import v1 as api_impl
27 from ...result_formatting import FormatDispatcher, load_format_dispatcher
28 from ..asgi_adaptor import ASGIAdaptor, EndpointFunc
29 from ... import logging as loglib
30
class ParamWrapper(ASGIAdaptor):
    """ Adaptor class for server glue to Starlette framework.

        Wraps a single Starlette request and exposes it through the
        framework-independent ASGIAdaptor interface that the API endpoint
        implementations are written against.
    """

    def __init__(self, request: Request) -> None:
        self.request = request


    def get(self, name: str, default: Optional[str] = None) -> Optional[str]:
        """ Return the query parameter `name`, or `default` when it is absent.
        """
        return self.request.query_params.get(name, default=default)


    def get_header(self, name: str, default: Optional[str] = None) -> Optional[str]:
        """ Return the HTTP header `name`, or `default` when it is absent.
        """
        return self.request.headers.get(name, default)


    def error(self, msg: str, status: int = 400) -> HTTPException:
        """ Build (but do not raise) an HTTP exception carrying `msg`.

            The error response uses the adaptor's current `content_type`.
        """
        return HTTPException(status, detail=msg,
                             headers={'content-type': self.content_type})


    def create_response(self, status: int, output: str, num_results: int) -> Response:
        """ Build the final response containing `output`.

            The number of results is stored in the request state, where the
            file-logging middleware reads it.
        """
        self.request.state.num_results = num_results
        return Response(output, status_code=status, media_type=self.content_type)


    def base_uri(self) -> str:
        """ Return the externally visible base URL of the application.

            The URL is made up of the scheme, the host, the port and the ASGI
            root path. The port is omitted when it is the default for the
            scheme.
        """
        url = self.request.url
        scheme = url.scheme
        host = url.hostname
        port = url.port
        root = self.request.scope['root_path']

        if (scheme == 'http' and port == 80) or (scheme == 'https' and port == 443):
            port = None

        # `hostname` (as computed by urllib) has the brackets of an IPv6
        # literal stripped; without them the URL built below would be invalid.
        if host is not None and ':' in host:
            host = f'[{host}]'

        if port is not None:
            return f"{scheme}://{host}:{port}{root}"

        return f"{scheme}://{host}{root}"


    def config(self) -> Configuration:
        """ Return the configuration of the API object attached to the app.
        """
        return cast(Configuration, self.request.app.state.API.config)


    def formatting(self) -> FormatDispatcher:
        """ Return the result formatter attached to the application.

            get_application() stores the dispatcher directly in the
            application state (`app.state.formatter`), not on the API object.
        """
        return cast(FormatDispatcher, self.request.app.state.formatter)
76
77
def _wrap_endpoint(func: EndpointFunc)\
        -> Callable[[Request], Coroutine[Any, Any, Response]]:
    """ Turn a framework-independent endpoint function into a Starlette
        route handler.

        The returned coroutine function fetches the API object from the
        application state and hands it to `func`, together with a
        ParamWrapper around the incoming request.
    """
    async def _handler(request: Request) -> Response:
        api = request.app.state.API
        adaptor = ParamWrapper(request)
        result = await func(api, adaptor)
        return cast(Response, result)

    return _handler
84
85
class FileLoggingMiddleware(BaseHTTPMiddleware):
    """ Middleware to log selected requests into a file.

        One line is appended for every successful (HTTP 200) request to the
        reverse, search, lookup or details endpoints. The line contains:
        - the start time (UTC, millisecond precision),
        - the processing time in seconds,
        - the number of results,
        - the endpoint name,
        - the quoted raw query string.
    """

    def __init__(self, app: Starlette, file_name: str = ''):
        super().__init__(app)
        # buffering=1 selects line buffering, so every entry is flushed as
        # soon as it is written.
        # NOTE(review): the file handle is never closed explicitly.
        self.fd = open(file_name, 'a', buffering=1, encoding='utf8') # pylint: disable=R1732

    async def dispatch(self, request: Request,
                       call_next: RequestResponseEndpoint) -> Response:
        """ Forward the request and log it if it qualifies (see class doc).
        """
        start = dt.datetime.now(tz=dt.timezone.utc)
        response = await call_next(request)

        if response.status_code != 200:
            return response

        finish = dt.datetime.now(tz=dt.timezone.utc)

        for endpoint in ('reverse', 'search', 'lookup', 'details'):
            if request.url.path.startswith('/' + endpoint):
                qtype = endpoint
                break
        else:
            return response

        duration = (finish - start).total_seconds()
        # The query string is the raw byte string sent by the client.
        # Starlette decodes it leniently for the endpoint, so a request may
        # succeed even when these bytes are not valid UTF-8. Logging must
        # never turn such a successful request into a server error.
        params = request.scope['query_string'].decode('utf8', errors='replace')

        self.fd.write(f"[{start.replace(tzinfo=None).isoformat(sep=' ', timespec='milliseconds')}] "
                      f"{duration:.4f} {getattr(request.state, 'num_results', 0)} "
                      f'{qtype} "{params}"\n')

        return response
119
120
async def timeout_error(request: Request, #pylint: disable=unused-argument
                        _: Exception) -> Response:
    """ Exception handler for queries that exceed their time budget.

        A note about the abort is added to the query log, and log collection
        is then stopped. If that produced any collected log output, it is
        returned as an HTML page. Otherwise a plain-text 503 error is sent.
    """
    loglib.log().comment('Aborted: Query took too long to process.')
    captured = loglib.get_and_disable()

    if not captured:
        return PlainTextResponse("Query took too long to process.", status_code=503)

    return HTMLResponse(captured)
132
133
def get_application(project_dir: Path,
                    environ: Optional[Mapping[str, str]] = None,
                    debug: bool = True) -> Starlette:
    """ Create a Nominatim Starlette ASGI application.

        Arguments:
            project_dir: Directory of the Nominatim project to serve.
            environ: Optional settings mapping. It is passed on both to the
                     configuration and to the API object.
            debug: Whether Starlette runs in debug mode.

        A route is registered for every v1 API endpoint. A `.php` alias is
        added as well when the SERVE_LEGACY_URLS setting is enabled. CORS
        headers and file logging of requests are switched on according to
        the configuration. The API object is closed on application shutdown.
    """
    config = Configuration(project_dir, environ)

    routes = []
    legacy_urls = config.get_bool('SERVE_LEGACY_URLS')
    for name, func in api_impl.ROUTES:
        endpoint = _wrap_endpoint(func)
        routes.append(Route(f"/{name}", endpoint=endpoint))
        if legacy_urls:
            routes.append(Route(f"/{name}.php", endpoint=endpoint))

    middleware = []
    if config.get_bool('CORS_NOACCESSCONTROL'):
        middleware.append(Middleware(CORSMiddleware,
                                     allow_origins=['*'],
                                     allow_methods=['GET', 'OPTIONS'],
                                     max_age=86400))

    log_file = config.LOG_FILE
    if log_file:
        middleware.append(Middleware(FileLoggingMiddleware, file_name=log_file))

    # Before Python 3.11, asyncio.TimeoutError is a separate class from the
    # builtin TimeoutError, so both need to be mapped.
    exceptions: Dict[Any, Callable[[Request, Exception], Awaitable[Response]]] = {
        TimeoutError: timeout_error,
        asyncio.TimeoutError: timeout_error
    }

    # `app` is only assigned below; the closure looks it up when it is
    # called at shutdown, by which time it exists.
    async def _shutdown() -> None:
        await app.state.API.close()

    app = Starlette(debug=debug, routes=routes, middleware=middleware,
                    exception_handlers=exceptions,
                    on_shutdown=[_shutdown])

    app.state.API = NominatimAPIAsync(project_dir, environ)
    app.state.formatter = load_format_dispatcher('v1', project_dir)

    return app
176
177
def run_wsgi() -> Starlette:
    """ Application factory used as the entry point for uvicorn.

        Serves the project found in the current working directory, with
        Starlette's debug mode switched off.
    """
    application = get_application(Path('.'), debug=False)
    return application