]> git.openstreetmap.org Git - nominatim.git/blob - nominatim/server/starlette/server.py
catch special async timeout error in servers
[nominatim.git] / nominatim / server / starlette / server.py
1 # SPDX-License-Identifier: GPL-2.0-only
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2023 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Server implementation using the starlette webserver framework.
9 """
10 from typing import Any, Optional, Mapping, Callable, cast, Coroutine, Dict, Awaitable
11 from pathlib import Path
12 import datetime as dt
13 import asyncio
14
15 from starlette.applications import Starlette
16 from starlette.routing import Route
17 from starlette.exceptions import HTTPException
18 from starlette.responses import Response, PlainTextResponse
19 from starlette.requests import Request
20 from starlette.middleware import Middleware
21 from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
22 from starlette.middleware.cors import CORSMiddleware
23
24 from nominatim.api import NominatimAPIAsync
25 import nominatim.api.v1 as api_impl
26 from nominatim.config import Configuration
27
class ParamWrapper(api_impl.ASGIAdaptor):
    """ Adaptor class for server glue to Starlette framework.

        Wraps a single Starlette request and exposes it through the
        ASGIAdaptor interface that the v1 endpoint functions receive.
    """

    def __init__(self, request: Request) -> None:
        self.request = request


    def get(self, name: str, default: Optional[str] = None) -> Optional[str]:
        """ Return the query parameter `name`, or `default` when it is absent.
        """
        return self.request.query_params.get(name, default=default)


    def get_header(self, name: str, default: Optional[str] = None) -> Optional[str]:
        """ Return the HTTP header `name`, or `default` when it is absent.
        """
        return self.request.headers.get(name, default)


    def error(self, msg: str, status: int = 400) -> HTTPException:
        """ Create (but do not raise) an HTTPException with the given
            message and status code, using the adaptor's content type.
        """
        return HTTPException(status, detail=msg,
                             headers={'content-type': self.content_type})


    def create_response(self, status: int, output: str, num_results: int) -> Response:
        """ Build the final response for the endpoint.

            The number of results is stored on the request state so that
            request-logging middleware can read it later.
        """
        self.request.state.num_results = num_results
        return Response(output, status_code=status, media_type=self.content_type)


    def base_uri(self) -> str:
        """ Return the base URI of the server as seen by the client:
            scheme, host, optional port and the ASGI root path.

            The port is omitted when it is the default for the scheme
            (80 for http, 443 for https).
        """
        scheme = self.request.url.scheme
        host = self.request.url.hostname
        port = self.request.url.port
        # 'root_path' is optional in the ASGI connection scope and
        # defaults to the empty string when the server does not set it.
        root = self.request.scope.get('root_path', '')
        # The parsed hostname has the brackets of an IPv6 literal stripped;
        # put them back, otherwise the resulting URI is not valid.
        if host and ':' in host:
            host = f"[{host}]"
        if (scheme == 'http' and port == 80) or (scheme == 'https' and port == 443):
            port = None
        if port is not None:
            return f"{scheme}://{host}:{port}{root}"

        return f"{scheme}://{host}{root}"


    def config(self) -> Configuration:
        """ Return the configuration of the API object attached to the
            application state.
        """
        return cast(Configuration, self.request.app.state.API.config)
69
70
def _wrap_endpoint(func: api_impl.EndpointFunc)\
        -> Callable[[Request], Coroutine[Any, Any, Response]]:
    """ Adapt a Nominatim v1 endpoint function to a Starlette endpoint.

        The returned coroutine function calls `func` with the API object
        kept in the application state and a ParamWrapper around the
        incoming request, and hands back the produced response.
    """
    # The inner function name is kept stable: Starlette derives the
    # default route name from the endpoint's __name__.
    async def _callback(request: Request) -> Response:
        api = request.app.state.API
        params = ParamWrapper(request)
        result = await func(api, params)
        return cast(Response, result)

    return _callback
77
78
class FileLoggingMiddleware(BaseHTTPMiddleware):
    """ Middleware to log selected requests into a file.

        Only requests answered with status 200 whose path starts with
        /reverse, /search, /lookup or /details are written. Each line holds
        the start time, the duration in seconds, the number of results,
        the endpoint type and the raw query string.
    """

    def __init__(self, app: Starlette, file_name: str = ''):
        super().__init__(app)
        # Line-buffered text file: each entry ends with a newline and is
        # therefore flushed as soon as it is written.
        self.fd = open(file_name, 'a', buffering=1, encoding='utf8') # pylint: disable=R1732

    async def dispatch(self, request: Request,
                       call_next: RequestResponseEndpoint) -> Response:
        started_at = dt.datetime.now(tz=dt.timezone.utc)
        response = await call_next(request)

        if response.status_code == 200:
            finished_at = dt.datetime.now(tz=dt.timezone.utc)

            path = request.url.path
            query_type = next((kind for kind in ('reverse', 'search', 'lookup', 'details')
                               if path.startswith('/' + kind)), None)

            if query_type is not None:
                elapsed = (finished_at - started_at).total_seconds()
                query_string = request.scope['query_string'].decode('utf8')
                stamp = started_at.replace(tzinfo=None)\
                                  .isoformat(sep=' ', timespec='milliseconds')
                result_count = getattr(request.state, 'num_results', 0)
                self.fd.write(f"[{stamp}] {elapsed:.4f} {result_count} "
                              f"{query_type} \"{query_string}\"\n")

        return response
112
113
async def timeout_error(request: Request, #pylint: disable=unused-argument
                        _: Exception) -> Response:
    """ Exception handler turning a query timeout into a plain-text
        response with status 503.

        Neither the request nor the exception is inspected; both are only
        accepted to match the exception handler signature.
    """
    message = "Query took too long to process."
    return PlainTextResponse(message, status_code=503)
119
120
def get_application(project_dir: Path,
                    environ: Optional[Mapping[str, str]] = None,
                    debug: bool = True) -> Starlette:
    """ Create a Nominatim Starlette ASGI application.

        Arguments:
            project_dir: Project directory used for the configuration and
                         the Nominatim API object.
            environ: Optional environment mapping, passed on to both the
                     Configuration and NominatimAPIAsync.
            debug: Run the Starlette application in debug mode.

        The returned application carries the API object as `app.state.API`.
        That object is closed when the application shuts down.
    """
    config = Configuration(project_dir, environ)

    # Each v1 endpoint is served as '/<name>' and, when legacy URLs are
    # enabled, additionally as '/<name>.php'.
    routes = []
    legacy_urls = config.get_bool('SERVE_LEGACY_URLS')
    for name, func in api_impl.ROUTES:
        endpoint = _wrap_endpoint(func)
        routes.append(Route(f"/{name}", endpoint=endpoint))
        if legacy_urls:
            routes.append(Route(f"/{name}.php", endpoint=endpoint))

    middleware = []
    if config.get_bool('CORS_NOACCESSCONTROL'):
        middleware.append(Middleware(CORSMiddleware,
                                     allow_origins=['*'],
                                     allow_methods=['GET', 'OPTIONS'],
                                     max_age=86400))

    log_file = config.LOG_FILE
    if log_file:
        middleware.append(Middleware(FileLoggingMiddleware, file_name=log_file))

    # Both entries are listed because asyncio.TimeoutError is only an alias
    # of the builtin TimeoutError from Python 3.11 on; on older versions it
    # is a distinct class and needs its own handler.
    exceptions: Dict[Any, Callable[[Request, Exception], Awaitable[Response]]] = {
        TimeoutError: timeout_error,
        asyncio.TimeoutError: timeout_error
    }

    # 'app' is only looked up when the shutdown hook runs, at which point it
    # has been assigned below.
    async def _shutdown() -> None:
        await app.state.API.close()

    app = Starlette(debug=debug, routes=routes, middleware=middleware,
                    exception_handlers=exceptions,
                    on_shutdown=[_shutdown])

    app.state.API = NominatimAPIAsync(project_dir, environ)

    return app
162
163
def run_wsgi() -> Starlette:
    """ Entry point for uvicorn.

        Builds the application for the project in the current working
        directory with debug mode switched off. Despite the function name,
        the returned object is a Starlette (ASGI) application.
    """
    project_dir = Path('.')
    return get_application(project_dir, debug=False)