git.openstreetmap.org Git - nominatim.git/blob - nominatim/server/starlette/server.py
Merge pull request #3413 from osm-search/mtmail-patch-1
[nominatim.git] / nominatim / server / starlette / server.py
1 # SPDX-License-Identifier: GPL-2.0-only
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2023 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Server implementation using the starlette webserver framework.
9 """
10 from typing import Any, Optional, Mapping, Callable, cast, Coroutine, Dict, Awaitable
11 from pathlib import Path
12 import datetime as dt
13 import asyncio
14
15 from starlette.applications import Starlette
16 from starlette.routing import Route
17 from starlette.exceptions import HTTPException
18 from starlette.responses import Response, PlainTextResponse, HTMLResponse
19 from starlette.requests import Request
20 from starlette.middleware import Middleware
21 from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
22 from starlette.middleware.cors import CORSMiddleware
23
24 from nominatim.api import NominatimAPIAsync
25 import nominatim.api.v1 as api_impl
26 import nominatim.api.logging as loglib
27 from nominatim.config import Configuration
28
class ParamWrapper(api_impl.ASGIAdaptor):
    """ Adaptor class for server glue to Starlette framework.

        Wraps a single Starlette request and exposes the accessors used by
        the generic v1 endpoint implementations in `nominatim.api.v1`.
    """

    def __init__(self, request: Request) -> None:
        self.request = request


    def get(self, name: str, default: Optional[str] = None) -> Optional[str]:
        """ Return the value of query parameter `name`, or `default` when
            the parameter is not present.
        """
        return self.request.query_params.get(name, default=default)


    def get_header(self, name: str, default: Optional[str] = None) -> Optional[str]:
        """ Return the value of HTTP header `name`, or `default` when the
            header is not present.
        """
        return self.request.headers.get(name, default)


    def error(self, msg: str, status: int = 400) -> HTTPException:
        """ Create (but do not raise) an HTTPException carrying `msg` and
            `status`. The content-type header is taken from
            `self.content_type` (presumably set up by ASGIAdaptor).
        """
        return HTTPException(status, detail=msg,
                             headers={'content-type': self.content_type})


    def create_response(self, status: int, output: str, num_results: int) -> Response:
        """ Create the final response for the request.

            The number of results is stored in the request state so that
            FileLoggingMiddleware can pick it up for its log line.
        """
        self.request.state.num_results = num_results
        return Response(output, status_code=status, media_type=self.content_type)


    def base_uri(self) -> str:
        """ Return the base URL of the application: scheme, host, port and
            the mount point (root path). Default ports (80 for http, 443 for
            https) are omitted.
        """
        url = self.request.url
        scheme = url.scheme
        host = url.hostname
        port = url.port
        # 'root_path' is optional in the ASGI connection scope and
        # defaults to the empty string when the server does not send it.
        root = self.request.scope.get('root_path', '')
        # URL.hostname strips the brackets around IPv6 literals; they must be
        # put back, otherwise the host cannot be told apart from the port.
        if host and ':' in host:
            host = f'[{host}]'
        if (scheme == 'http' and port == 80) or (scheme == 'https' and port == 443):
            port = None
        if port is not None:
            return f"{scheme}://{host}:{port}{root}"

        return f"{scheme}://{host}{root}"


    def config(self) -> Configuration:
        """ Return the configuration of the API object attached to the
            application state.
        """
        return cast(Configuration, self.request.app.state.API.config)
70
71
def _wrap_endpoint(func: api_impl.EndpointFunc)\
        -> Callable[[Request], Coroutine[Any, Any, Response]]:
    """ Adapt a generic Nominatim endpoint function to a Starlette route
        handler.

        The returned coroutine function passes the API object stored in the
        application state, together with a ParamWrapper around the incoming
        request, to `func` and returns its result as the response.
    """
    async def _handler(request: Request) -> Response:
        api = request.app.state.API
        params = ParamWrapper(request)
        result = await func(api, params)
        return cast(Response, result)

    return _handler
78
79
class FileLoggingMiddleware(BaseHTTPMiddleware):
    """ Middleware to log selected requests into a file.

        Only responses with status 200 for the 'reverse', 'search', 'lookup'
        and 'details' endpoints are logged. Each entry is a single line:

            [<UTC start time>] <duration in s> <number of results> <type> "<query string>"
    """

    def __init__(self, app: Starlette, file_name: str = ''):
        super().__init__(app)
        # Opened once and kept for the lifetime of the middleware; line
        # buffering makes every entry reach the file as soon as it is written.
        self.fd = open(file_name, 'a', buffering=1, encoding='utf8') # pylint: disable=R1732

    async def dispatch(self, request: Request,
                       call_next: RequestResponseEndpoint) -> Response:
        """ Forward the request and append a log line for qualifying responses.
        """
        start = dt.datetime.now(tz=dt.timezone.utc)
        response = await call_next(request)

        if response.status_code != 200:
            return response

        finish = dt.datetime.now(tz=dt.timezone.utc)

        for endpoint in ('reverse', 'search', 'lookup', 'details'):
            if request.url.path.startswith('/' + endpoint):
                qtype = endpoint
                break
        else:
            return response

        duration = (finish - start).total_seconds()
        # The raw query string is not guaranteed to be valid UTF-8. A decoding
        # error here must not turn an already successful response into an error.
        params = request.scope['query_string'].decode('utf8', errors='replace')

        # num_results is set by ParamWrapper.create_response; default to 0
        # when the endpoint did not set it.
        self.fd.write(f"[{start.replace(tzinfo=None).isoformat(sep=' ', timespec='milliseconds')}] "
                      f"{duration:.4f} {getattr(request.state, 'num_results', 0)} "
                      f'{qtype} "{params}"\n')

        return response
113
114
async def timeout_error(request: Request, #pylint: disable=unused-argument
                        _: Exception) -> Response:
    """ Exception handler for queries that ran into a timeout.

        A final comment is written to the request logger and the logger is
        then disabled. When the logger returned any output, that output is
        sent back as an HTML response (no explicit status code is set).
        Otherwise a plain-text error with status 503 is returned.
    """
    loglib.log().comment('Aborted: Query took too long to process.')
    logdata = loglib.get_and_disable()

    if not logdata:
        return PlainTextResponse("Query took too long to process.", status_code=503)

    return HTMLResponse(logdata)
126
127
def get_application(project_dir: Path,
                    environ: Optional[Mapping[str, str]] = None,
                    debug: bool = True) -> Starlette:
    """ Create a Nominatim Starlette ASGI application.

        Arguments:
          project_dir: Project directory used for the configuration and
                       for the API object.
          environ: Optional environment settings, passed on to both the
                   Configuration and NominatimAPIAsync.
          debug: Value for Starlette's debug mode.

        The returned application holds the NominatimAPIAsync instance in
        `app.state.API`; the instance is closed when the application shuts down.
    """
    config = Configuration(project_dir, environ)

    # Every v1 endpoint is served under '/<name>'. When legacy URLs are
    # enabled, the same handler is also available under '/<name>.php'.
    routes = []
    legacy_urls = config.get_bool('SERVE_LEGACY_URLS')
    for name, func in api_impl.ROUTES:
        endpoint = _wrap_endpoint(func)
        routes.append(Route(f"/{name}", endpoint=endpoint))
        if legacy_urls:
            routes.append(Route(f"/{name}.php", endpoint=endpoint))

    middleware = []
    if config.get_bool('CORS_NOACCESSCONTROL'):
        middleware.append(Middleware(CORSMiddleware,
                                     allow_origins=['*'],
                                     allow_methods=['GET', 'OPTIONS'],
                                     max_age=86400))

    log_file = config.LOG_FILE
    if log_file:
        middleware.append(Middleware(FileLoggingMiddleware, file_name=log_file))

    # Both the builtin and the asyncio timeout exception are mapped to the
    # same handler.
    exceptions: Dict[Any, Callable[[Request, Exception], Awaitable[Response]]] = {
        TimeoutError: timeout_error,
        asyncio.TimeoutError: timeout_error
    }

    async def _shutdown() -> None:
        # 'app' is looked up when the hook runs, i.e. after it has been
        # assigned below.
        await app.state.API.close()

    app = Starlette(debug=debug, routes=routes, middleware=middleware,
                    exception_handlers=exceptions,
                    on_shutdown=[_shutdown])

    app.state.API = NominatimAPIAsync(project_dir, environ)

    return app
169
170
def run_wsgi() -> Starlette:
    """ Entry point for uvicorn.

        Builds the application for the project in the current working
        directory with debug mode switched off. Despite the function name,
        the returned object is a Starlette (ASGI) application.
    """
    project_dir = Path('.')
    return get_application(project_dir, debug=False)