git.openstreetmap.org Git - nominatim.git/blob - src/nominatim_api/server/starlette/server.py
enable search endpoint only when search table is available
[nominatim.git] / src / nominatim_api / server / starlette / server.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2024 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Server implementation using the starlette webserver framework.
9 """
10 from typing import Any, Optional, Mapping, Callable, cast, Coroutine, Dict, \
11                    Awaitable, AsyncIterator
12 from pathlib import Path
13 import datetime as dt
14 import asyncio
15 import contextlib
16
17 from starlette.applications import Starlette
18 from starlette.routing import Route
19 from starlette.exceptions import HTTPException
20 from starlette.responses import Response, PlainTextResponse, HTMLResponse
21 from starlette.requests import Request
22 from starlette.middleware import Middleware
23 from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
24 from starlette.middleware.cors import CORSMiddleware
25
26 from ...config import Configuration
27 from ...core import NominatimAPIAsync
28 from ... import v1 as api_impl
29 from ...result_formatting import FormatDispatcher, load_format_dispatcher
30 from ..asgi_adaptor import ASGIAdaptor, EndpointFunc
31 from ... import logging as loglib
32
33
class ParamWrapper(ASGIAdaptor):
    """ Adaptor class for server glue to Starlette framework.

        Wraps a single Starlette request and exposes its query parameters,
        headers, response construction and the application-wide API
        configuration and formatter in the form expected by the generic
        ASGI endpoint implementations.
    """

    def __init__(self, request: Request) -> None:
        self.request = request

    def get(self, name: str, default: Optional[str] = None) -> Optional[str]:
        """ Return the value of query parameter `name`, or `default`
            when the parameter is not present.
        """
        return self.request.query_params.get(name, default=default)

    def get_header(self, name: str, default: Optional[str] = None) -> Optional[str]:
        """ Return the value of HTTP header `name`, or `default`
            when the header is not present.
        """
        return self.request.headers.get(name, default)

    def error(self, msg: str, status: int = 400) -> HTTPException:
        """ Create (but do not raise) an HTTPException with the given
            message and status. The content-type header is taken from
            `self.content_type` (presumably set by the base adaptor for the
            requested output format - confirm in ASGIAdaptor).
        """
        return HTTPException(status, detail=msg,
                             headers={'content-type': self.content_type})

    def create_response(self, status: int, output: str, num_results: int) -> Response:
        """ Build the final response with the given status and body.

            `num_results` is stored on the request state, where the file
            logging middleware reads it back.
        """
        self.request.state.num_results = num_results
        return Response(output, status_code=status, media_type=self.content_type)

    def base_uri(self) -> str:
        """ Return the base URI under which this request reached the server.

            The result is the scheme, host, optional port and the ASGI root
            path. The port is omitted when it is the default port for the
            scheme (80 for http, 443 for https) or unknown.
        """
        url = self.request.url
        scheme = url.scheme
        host = url.hostname
        port = url.port
        # 'root_path' is optional in the ASGI connection scope and
        # defaults to the empty string when missing.
        root = self.request.scope.get('root_path', '')

        # URL.hostname strips the brackets from IPv6 literals. Without them
        # the address cannot be separated from the port, so restore them.
        if host is not None and ':' in host:
            host = f"[{host}]"

        if (scheme == 'http' and port == 80) or (scheme == 'https' and port == 443):
            port = None
        if port is not None:
            return f"{scheme}://{host}:{port}{root}"

        return f"{scheme}://{host}{root}"

    def config(self) -> Configuration:
        """ Return the configuration of the API object stored on the app state.
        """
        return cast(Configuration, self.request.app.state.API.config)

    def formatting(self) -> FormatDispatcher:
        """ Return the result format dispatcher stored on the app state.
        """
        return cast(FormatDispatcher, self.request.app.state.formatter)
69
70     def formatting(self) -> FormatDispatcher:
71         return cast(FormatDispatcher, self.request.app.state.formatter)
72
73
def _wrap_endpoint(func: EndpointFunc)\
        -> Callable[[Request], Coroutine[Any, Any, Response]]:
    """ Turn a generic Nominatim endpoint function into a Starlette handler.

        The returned coroutine function calls `func` with the API object
        from the application state and a ParamWrapper around the incoming
        request, and hands back whatever `func` produced as the response.
    """
    async def _handler(req: Request) -> Response:
        api = req.app.state.API
        result = await func(api, ParamWrapper(req))
        return cast(Response, result)

    return _handler
80
81
class FileLoggingMiddleware(BaseHTTPMiddleware):
    """ Middleware that appends one line per successful (status 200)
        reverse, search, lookup or details request to a log file.

        Each line has the form
        `[<UTC start time>] <duration s> <num results> <type> "<query string>"`.
    """

    def __init__(self, app: Starlette, file_name: str = ''):
        super().__init__(app)
        # Line-buffered text file opened in append mode.
        # NOTE(review): the handle is never explicitly closed.
        self.fd = open(file_name, 'a', buffering=1, encoding='utf8')

    async def dispatch(self, request: Request,
                       call_next: RequestResponseEndpoint) -> Response:
        began = dt.datetime.now(tz=dt.timezone.utc)
        response = await call_next(request)

        # Only successful requests are logged.
        if response.status_code != 200:
            return response

        ended = dt.datetime.now(tz=dt.timezone.utc)

        # Prefix match, so e.g. '/search.php' also counts as 'search'.
        path = request.url.path
        kind = next((name for name in ('reverse', 'search', 'lookup', 'details')
                     if path.startswith('/' + name)), None)
        if kind is None:
            return response

        stamp = began.replace(tzinfo=None).isoformat(sep=' ', timespec='milliseconds')
        elapsed = (ended - began).total_seconds()
        count = getattr(request.state, 'num_results', 0)
        query = request.scope['query_string'].decode('utf8')

        self.fd.write(f"[{stamp}] {elapsed:.4f} {count} {kind} \"{query}\"\n")

        return response
115
116
async def timeout_error(request: Request,
                        _: Exception) -> Response:
    """ Error handler for query timeouts.

        Adds an abort comment to the request log. If the log module returns
        any collected output (presumably only when debug output was
        requested), that output is returned as HTML. Otherwise a plain-text
        message is returned.

        Both variants use status 503 so that an aborted query is never
        reported to the client as a success.
    """
    loglib.log().comment('Aborted: Query took too long to process.')
    logdata = loglib.get_and_disable()

    if logdata:
        # Previously this used the HTMLResponse default status (200).
        return HTMLResponse(logdata, status_code=503)

    return PlainTextResponse("Query took too long to process.", status_code=503)
128
129
def get_application(project_dir: Path,
                    environ: Optional[Mapping[str, str]] = None,
                    debug: bool = True) -> Starlette:
    """ Create a Nominatim Starlette ASGI application.

        Middleware (CORS headers, file logging) is selected from the
        project configuration when the application is created. The API
        object and the endpoint routes are only set up in the lifespan
        handler, when the server starts, and the API is closed again on
        shutdown.

        Parameters:
            project_dir: Directory of the Nominatim project to serve.
            environ: Optional environment mapping, passed on to the
                     configuration and API objects.
            debug: Passed on as the Starlette debug flag.

        Returns:
            The configured, not yet started Starlette application.
    """
    config = Configuration(project_dir, environ)

    middleware = []
    if config.get_bool('CORS_NOACCESSCONTROL'):
        middleware.append(Middleware(CORSMiddleware,
                                     allow_origins=['*'],
                                     allow_methods=['GET', 'OPTIONS'],
                                     max_age=86400))

    log_file = config.LOG_FILE
    if log_file:
        middleware.append(Middleware(FileLoggingMiddleware, file_name=log_file))

    exceptions: Dict[Any, Callable[[Request, Exception], Awaitable[Response]]] = {
        TimeoutError: timeout_error,
        asyncio.TimeoutError: timeout_error
    }

    @contextlib.asynccontextmanager
    async def lifespan(app: Starlette) -> AsyncIterator[Any]:
        app.state.API = NominatimAPIAsync(project_dir, environ)
        # Close the API even when route setup fails or shutdown is
        # entered through an exception.
        try:
            api_config = app.state.API.config

            legacy_urls = api_config.get_bool('SERVE_LEGACY_URLS')
            for name, func in await api_impl.get_routes(app.state.API):
                endpoint = _wrap_endpoint(func)
                app.routes.append(Route(f"/{name}", endpoint=endpoint))
                # Optionally also serve the old PHP-style endpoint names.
                if legacy_urls:
                    app.routes.append(Route(f"/{name}.php", endpoint=endpoint))

            yield
        finally:
            await app.state.API.close()

    app = Starlette(debug=debug, middleware=middleware,
                    exception_handlers=exceptions,
                    lifespan=lifespan)

    app.state.formatter = load_format_dispatcher('v1', project_dir)

    return app
176
177
def run_wsgi() -> Starlette:
    """ Application factory used as the uvicorn entry point.

        Serves the project in the current working directory, with the
        Starlette debug flag switched off.
    """
    cwd_project = Path('.')
    return get_application(cwd_project, debug=False)