git.openstreetmap.org Git - nominatim.git/blob - src/nominatim_api/server/starlette/server.py
move server route creation into async function
[nominatim.git] / src / nominatim_api / server / starlette / server.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2024 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Server implementation using the starlette webserver framework.
9 """
10 from typing import Any, Optional, Mapping, Callable, cast, Coroutine, Dict, Awaitable
11 from pathlib import Path
12 import datetime as dt
13 import asyncio
14 import contextlib
15
16 from starlette.applications import Starlette
17 from starlette.routing import Route
18 from starlette.exceptions import HTTPException
19 from starlette.responses import Response, PlainTextResponse, HTMLResponse
20 from starlette.requests import Request
21 from starlette.middleware import Middleware
22 from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
23 from starlette.middleware.cors import CORSMiddleware
24
25 from ...config import Configuration
26 from ...core import NominatimAPIAsync
27 from ... import v1 as api_impl
28 from ...result_formatting import FormatDispatcher, load_format_dispatcher
29 from ..asgi_adaptor import ASGIAdaptor, EndpointFunc
30 from ... import logging as loglib
31
32
class ParamWrapper(ASGIAdaptor):
    """ Adaptor class for server glue to Starlette framework.

        Exposes query parameters, headers and application state of a
        Starlette request through the ASGIAdaptor interface.
    """

    def __init__(self, request: Request) -> None:
        self.request = request

    def get(self, name: str, default: Optional[str] = None) -> Optional[str]:
        """ Return the query parameter `name` or `default` when it is absent.
        """
        return self.request.query_params.get(name, default=default)

    def get_header(self, name: str, default: Optional[str] = None) -> Optional[str]:
        """ Return the request header `name` or `default` when it is absent.
        """
        return self.request.headers.get(name, default)

    def error(self, msg: str, status: int = 400) -> HTTPException:
        """ Create (but do not raise) an HTTP exception carrying `msg`.

            The content type header uses `self.content_type`, which is
            presumably provided by the ASGIAdaptor base class.
        """
        return HTTPException(status, detail=msg,
                             headers={'content-type': self.content_type})

    def create_response(self, status: int, output: str, num_results: int) -> Response:
        """ Build the final response and remember the number of results.
        """
        # Stored on the request state so that FileLoggingMiddleware can log it.
        self.request.state.num_results = num_results
        return Response(output, status_code=status, media_type=self.content_type)

    def base_uri(self) -> str:
        """ Return the base URL under which the application is reachable.

            Default ports (80 for http, 443 for https) are omitted and the
            ASGI root path is appended.
        """
        url = self.request.url
        scheme = url.scheme
        host = url.hostname
        port = url.port
        # 'root_path' is optional in the ASGI scope; it defaults to ''.
        root = self.request.scope.get('root_path', '')
        # URL.hostname strips the brackets from IPv6 literals. Put them back
        # so that the result is a valid URL again.
        if host and ':' in host:
            host = f'[{host}]'
        if (scheme == 'http' and port == 80) or (scheme == 'https' and port == 443):
            port = None
        if port is not None:
            return f"{scheme}://{host}:{port}{root}"

        return f"{scheme}://{host}{root}"

    def config(self) -> Configuration:
        """ Return the configuration of the API object kept in the app state.
        """
        return cast(Configuration, self.request.app.state.API.config)

    def formatting(self) -> FormatDispatcher:
        """ Return the output format dispatcher kept in the app state.
        """
        return cast(FormatDispatcher, self.request.app.state.formatter)
71
72
def _wrap_endpoint(func: EndpointFunc)\
        -> Callable[[Request], Coroutine[Any, Any, Response]]:
    """ Turn a framework-independent API endpoint into a Starlette endpoint.

        The returned coroutine function hands the API object from the
        application state and a ParamWrapper around the request to `func`.
    """
    async def _callback(request: Request) -> Response:
        api = request.app.state.API
        params = ParamWrapper(request)
        result = await func(api, params)
        return cast(Response, result)

    return _callback
79
80
class FileLoggingMiddleware(BaseHTTPMiddleware):
    """ Middleware appending one line per successful query to a log file.

        A request is logged only when it was answered with status 200 and
        its path starts with '/reverse', '/search', '/lookup' or '/details'.
        Each line contains the UTC start time, the processing time in
        seconds, the number of results, the query type and the raw query
        string.
    """

    def __init__(self, app: Starlette, file_name: str = ''):
        super().__init__(app)
        # Opened line-buffered, so every entry is flushed on its newline.
        # NOTE(review): the handle is never closed explicitly.
        self.fd = open(file_name, mode='a', buffering=1, encoding='utf8')

    async def dispatch(self, request: Request,
                       call_next: RequestResponseEndpoint) -> Response:
        """ Forward the request and write a log entry if it qualifies.
        """
        started = dt.datetime.now(tz=dt.timezone.utc)
        response = await call_next(request)

        if response.status_code == 200:
            finished = dt.datetime.now(tz=dt.timezone.utc)
            path = request.url.path
            qtype = next((kind for kind in ('reverse', 'search', 'lookup', 'details')
                          if path.startswith('/' + kind)), None)
            if qtype is not None:
                elapsed = (finished - started).total_seconds()
                query = request.scope['query_string'].decode('utf8')
                timestamp = started.replace(tzinfo=None)\
                                   .isoformat(sep=' ', timespec='milliseconds')
                # Set by ParamWrapper.create_response; 0 when never set.
                num_results = getattr(request.state, 'num_results', 0)
                self.fd.write(f'[{timestamp}] {elapsed:.4f} {num_results} '
                              f'{qtype} "{query}"\n')

        return response
114
115
async def timeout_error(request: Request,
                        _: Exception) -> Response:
    """ Exception handler for requests whose processing timed out.

        A note about the abort is added to the loglib log first. If loglib
        had collected output, that output is returned as an HTML page.
        Otherwise a plain-text response with status 503 is sent.
    """
    loglib.log().comment('Aborted: Query took too long to process.')
    captured = loglib.get_and_disable()

    if not captured:
        return PlainTextResponse("Query took too long to process.", status_code=503)

    return HTMLResponse(captured)
127
128
def get_application(project_dir: Path,
                    environ: Optional[Mapping[str, str]] = None,
                    debug: bool = True) -> Starlette:
    """ Create a Nominatim Starlette ASGI application.

        Parameters:
            project_dir: Project directory used for the configuration,
                the API object and the output format dispatcher.
            environ: Optional environment mapping passed on to
                Configuration and NominatimAPIAsync.
            debug: Whether Starlette runs in debug mode.

        The NominatimAPIAsync object and the API routes are only created
        when the application's lifespan starts. The API object is closed
        again when the lifespan ends.
    """
    from typing import AsyncIterator

    config = Configuration(project_dir, environ)

    middleware = []
    if config.get_bool('CORS_NOACCESSCONTROL'):
        middleware.append(Middleware(CORSMiddleware,
                                     allow_origins=['*'],
                                     allow_methods=['GET', 'OPTIONS'],
                                     max_age=86400))

    log_file = config.LOG_FILE
    if log_file:
        middleware.append(Middleware(FileLoggingMiddleware, file_name=log_file))

    # Both are registered because asyncio.TimeoutError is a separate class
    # from the builtin TimeoutError before Python 3.11.
    exceptions: Dict[Any, Callable[[Request, Exception], Awaitable[Response]]] = {
        TimeoutError: timeout_error,
        asyncio.TimeoutError: timeout_error
    }

    @contextlib.asynccontextmanager
    async def lifespan(app: Starlette) -> AsyncIterator[None]:
        app.state.API = NominatimAPIAsync(project_dir, environ)
        try:
            legacy_urls = app.state.API.config.get_bool('SERVE_LEGACY_URLS')
            for name, func in api_impl.ROUTES:
                endpoint = _wrap_endpoint(func)
                app.routes.append(Route(f"/{name}", endpoint=endpoint))
                if legacy_urls:
                    # Additionally serve the route as '/<name>.php'.
                    app.routes.append(Route(f"/{name}.php", endpoint=endpoint))

            yield
        finally:
            # Close the API object even when route setup failed or the
            # lifespan was ended by an exception.
            await app.state.API.close()

    app = Starlette(debug=debug, middleware=middleware,
                    exception_handlers=exceptions,
                    lifespan=lifespan)

    app.state.formatter = load_format_dispatcher('v1', project_dir)

    return app
175
176
def run_wsgi() -> Starlette:
    """ Entry point for uvicorn.

        Builds the application for the project in the current working
        directory with Starlette debug mode switched off. Despite its name,
        the function returns an ASGI (Starlette) application.
    """
    project_dir = Path('.')
    return get_application(project_dir, debug=False)