]> git.openstreetmap.org Git - nominatim.git/blob - src/nominatim_api/server/starlette/server.py
60a81321ba5ca26d5f920162c1942be561df1510
[nominatim.git] / src / nominatim_api / server / starlette / server.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2024 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Server implementation using the starlette webserver framework.
9 """
10 from typing import Any, Optional, Mapping, Callable, cast, Coroutine, Dict, Awaitable
11 from pathlib import Path
12 import datetime as dt
13 import asyncio
14
15 from starlette.applications import Starlette
16 from starlette.routing import Route
17 from starlette.exceptions import HTTPException
18 from starlette.responses import Response, PlainTextResponse, HTMLResponse
19 from starlette.requests import Request
20 from starlette.middleware import Middleware
21 from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
22 from starlette.middleware.cors import CORSMiddleware
23
24 from ...config import Configuration
25 from ...core import NominatimAPIAsync
26 from ... import v1 as api_impl
27 from ..asgi_adaptor import ASGIAdaptor, EndpointFunc
28 from ... import logging as loglib
29
class ParamWrapper(ASGIAdaptor):
    """ Adaptor class for server glue to Starlette framework.

        Wraps a single Starlette request and exposes its query
        parameters, headers and response construction through the
        ASGIAdaptor interface.
    """

    def __init__(self, request: Request) -> None:
        self.request = request


    def get(self, name: str, default: Optional[str] = None) -> Optional[str]:
        """ Return the value of the query parameter `name` or `default`
            when the request does not carry that parameter.
        """
        return self.request.query_params.get(name, default=default)


    def get_header(self, name: str, default: Optional[str] = None) -> Optional[str]:
        """ Return the value of the request header `name` or `default`
            when the header is not set.
        """
        return self.request.headers.get(name, default)


    def error(self, msg: str, status: int = 400) -> HTTPException:
        """ Build (but do not raise) a Starlette HTTPException with the
            given message and status code. The content type of the
            error is taken from `self.content_type`.
        """
        # NOTE(review): content_type is not set in this class; presumably
        # it is maintained by ASGIAdaptor - confirm against the base class.
        return HTTPException(status, detail=msg,
                             headers={'content-type': self.content_type})


    def create_response(self, status: int, output: str, num_results: int) -> Response:
        """ Create the Starlette response for `output` with the given
            status code. The number of results is saved in the request
            state, where FileLoggingMiddleware picks it up for logging.
        """
        self.request.state.num_results = num_results
        return Response(output, status_code=status, media_type=self.content_type)


    def base_uri(self) -> str:
        """ Return the base URI (scheme, host, optional port and root
            path) under which the application was requested. The port
            is left out when it is the default port of the scheme.
        """
        url = self.request.url
        scheme = url.scheme
        host = url.hostname
        port = url.port
        # 'root_path' is an optional key of the ASGI scope, so fall back
        # to the empty default instead of failing with a KeyError.
        root = self.request.scope.get('root_path', '')

        # The hostname is reported without the brackets of an IPv6
        # literal. They must be restored to get a valid URI.
        if host is not None and ':' in host:
            host = f"[{host}]"

        if (scheme == 'http' and port == 80) or (scheme == 'https' and port == 443):
            port = None
        if port is not None:
            return f"{scheme}://{host}:{port}{root}"

        return f"{scheme}://{host}{root}"


    def config(self) -> Configuration:
        """ Return the configuration of the API object attached to
            the application state.
        """
        return cast(Configuration, self.request.app.state.API.config)
71
72
def _wrap_endpoint(func: EndpointFunc)\
        -> Callable[[Request], Coroutine[Any, Any, Response]]:
    """ Turn an endpoint function into a Starlette request handler.

        The returned coroutine function calls `func` with the API object
        from the application state and a ParamWrapper around the request.
    """
    async def _callback(request: Request) -> Response:
        api = request.app.state.API
        adaptor = ParamWrapper(request)
        result = await func(api, adaptor)
        return cast(Response, result)

    return _callback
79
80
class FileLoggingMiddleware(BaseHTTPMiddleware):
    """ Middleware to log selected requests into a file.

        A line is written for every request that finished with HTTP
        status 200 and whose path starts with one of the endpoint names
        'reverse', 'search', 'lookup' or 'details'. All other requests
        pass through without being logged.
    """

    def __init__(self, app: Starlette, file_name: str = ''):
        super().__init__(app)
        # The file is opened line-buffered in append mode and kept open
        # in self.fd; this class never closes it explicitly.
        self.fd = open(file_name, 'a', buffering=1, encoding='utf8') # pylint: disable=R1732

    async def dispatch(self, request: Request,
                       call_next: RequestResponseEndpoint) -> Response:
        """ Forward the request to the next handler and append a log
            line with start time, duration, number of results, endpoint
            name and query string when the request qualifies for logging.
        """
        start = dt.datetime.now(tz=dt.timezone.utc)
        response = await call_next(request)

        if response.status_code != 200:
            return response

        finish = dt.datetime.now(tz=dt.timezone.utc)

        for endpoint in ('reverse', 'search', 'lookup', 'details'):
            if request.url.path.startswith('/' + endpoint):
                qtype = endpoint
                break
        else:
            # Not one of the endpoints that are logged.
            return response

        duration = (finish - start).total_seconds()
        # The query string arrives as raw bytes from the client and is
        # not guaranteed to be valid UTF-8. Escape undecodable bytes
        # instead of raising, so that logging can never turn a
        # successfully answered request into a server error.
        params = request.scope['query_string'].decode('utf8', errors='backslashreplace')
        timestamp = start.replace(tzinfo=None).isoformat(sep=' ', timespec='milliseconds')

        # num_results is set on the request state by
        # ParamWrapper.create_response(); fall back to 0 when missing.
        self.fd.write(f"[{timestamp}] "
                      f"{duration:.4f} {getattr(request.state, 'num_results', 0)} "
                      f'{qtype} "{params}"\n')

        return response
114
115
async def timeout_error(request: Request, #pylint: disable=unused-argument
                        _: Exception) -> Response:
    """ Error handler for query timeouts.

        Adds a note about the abort to the debug log. When debug output
        was collected, it is returned as an HTML page. Otherwise a
        plain-text error with status 503 is sent.
    """
    loglib.log().comment('Aborted: Query took too long to process.')
    debug_output = loglib.get_and_disable()

    if not debug_output:
        return PlainTextResponse("Query took too long to process.", status_code=503)

    return HTMLResponse(debug_output)
127
128
def get_application(project_dir: Path,
                    environ: Optional[Mapping[str, str]] = None,
                    debug: bool = True) -> Starlette:
    """ Create a Nominatim Starlette ASGI application.

        One route is registered for each entry in the ROUTES list of the
        API implementation and, when SERVE_LEGACY_URLS is enabled, an
        additional one with a '.php' suffix. CORS and file logging
        middleware are added depending on the CORS_NOACCESSCONTROL and
        LOG_FILE settings. The API object is attached to the application
        state and closed again when the application shuts down.
    """
    config = Configuration(project_dir, environ)

    # URL suffixes under which each endpoint is made available.
    if config.get_bool('SERVE_LEGACY_URLS'):
        url_suffixes = ('', '.php')
    else:
        url_suffixes = ('',)

    routes = []
    for name, func in api_impl.ROUTES:
        # Plain and legacy URL share the same wrapped endpoint.
        endpoint = _wrap_endpoint(func)
        routes.extend(Route(f"/{name}{suffix}", endpoint=endpoint)
                      for suffix in url_suffixes)

    middleware = []
    if config.get_bool('CORS_NOACCESSCONTROL'):
        cors = Middleware(CORSMiddleware,
                          allow_origins=['*'],
                          allow_methods=['GET', 'OPTIONS'],
                          max_age=86400)
        middleware.append(cors)

    log_file = config.LOG_FILE
    if log_file:
        middleware.append(Middleware(FileLoggingMiddleware, file_name=log_file))

    # Both the builtin and the asyncio timeout are handled the same way.
    exceptions: Dict[Any, Callable[[Request, Exception], Awaitable[Response]]] = {
        TimeoutError: timeout_error,
        asyncio.TimeoutError: timeout_error
    }

    async def _shutdown() -> None:
        # 'app' is bound further down, before this handler can ever run.
        await app.state.API.close()

    app = Starlette(debug=debug, routes=routes, middleware=middleware,
                    exception_handlers=exceptions,
                    on_shutdown=[_shutdown])

    app.state.API = NominatimAPIAsync(project_dir, environ)

    return app
170
171
def run_wsgi() -> Starlette:
    """ Entry point for uvicorn.

        Creates the application for the project in the current working
        directory with debug mode switched off.
    """
    project_dir = Path('.')
    return get_application(project_dir, debug=False)