git.openstreetmap.org Git - nominatim.git/commitdiff
move server route creation into async function
author: Sarah Hoffmann <lonvia@denofr.de>
Wed, 13 Nov 2024 20:27:14 +0000 (21:27 +0100)
committer: Sarah Hoffmann <lonvia@denofr.de>
Wed, 13 Nov 2024 20:27:14 +0000 (21:27 +0100)
src/nominatim_api/server/falcon/server.py
src/nominatim_api/server/starlette/server.py

index b58c1cfa830c5a863a8fc46445f3c62b02bdea06..e252e3c8925f3c34daaabfa3d493df12d609357f 100644 (file)
@@ -147,12 +147,36 @@ class FileLoggingMiddleware:
                       f'{resource.name} "{params}"\n')
 
 
                       f'{resource.name} "{params}"\n')
 
 
-class APIShutdown:
-    """ Middleware that closes any open database connections.
+class APIMiddleware:
+    """ Middleware managing the Nominatim database connection.
     """
 
     """
 
-    def __init__(self, api: NominatimAPIAsync) -> None:
-        self.api = api
+    def __init__(self, project_dir: Path, environ: Optional[Mapping[str, str]]) -> None:
+        self.api = NominatimAPIAsync(project_dir, environ)
+        self.app: Optional[App] = None
+
+    @property
+    def config(self) -> Configuration:
+        """ Get the configuration for Nominatim.
+        """
+        return self.api.config
+
+    def set_app(self, app: App) -> None:
+        """ Set the Falcon application this middleware is connected to.
+        """
+        self.app = app
+
+    async def process_startup(self, *_: Any) -> None:
+        """ Process the ASGI lifespan startup event.
+        """
+        assert self.app is not None
+        legacy_urls = self.api.config.get_bool('SERVE_LEGACY_URLS')
+        formatter = load_format_dispatcher('v1', self.api.config.project_dir)
+        for name, func in api_impl.ROUTES:
+            endpoint = EndpointWrapper(name, func, self.api, formatter)
+            self.app.add_route(f"/{name}", endpoint)
+            if legacy_urls:
+                self.app.add_route(f"/{name}.php", endpoint)
 
     async def process_shutdown(self, *_: Any) -> None:
         """Process the ASGI lifespan shutdown event.
 
     async def process_shutdown(self, *_: Any) -> None:
         """Process the ASGI lifespan shutdown event.
@@ -164,28 +188,22 @@ def get_application(project_dir: Path,
                     environ: Optional[Mapping[str, str]] = None) -> App:
     """ Create a Nominatim Falcon ASGI application.
     """
                     environ: Optional[Mapping[str, str]] = None) -> App:
     """ Create a Nominatim Falcon ASGI application.
     """
-    api = NominatimAPIAsync(project_dir, environ)
+    apimw = APIMiddleware(project_dir, environ)
 
 
-    middleware: List[object] = [APIShutdown(api)]
-    log_file = api.config.LOG_FILE
+    middleware: List[object] = [apimw]
+    log_file = apimw.config.LOG_FILE
     if log_file:
         middleware.append(FileLoggingMiddleware(log_file))
 
     if log_file:
         middleware.append(FileLoggingMiddleware(log_file))
 
-    app = App(cors_enable=api.config.get_bool('CORS_NOACCESSCONTROL'),
+    app = App(cors_enable=apimw.config.get_bool('CORS_NOACCESSCONTROL'),
               middleware=middleware)
               middleware=middleware)
+
+    apimw.set_app(app)
     app.add_error_handler(HTTPNominatimError, nominatim_error_handler)
     app.add_error_handler(TimeoutError, timeout_error_handler)
     # different from TimeoutError in Python <= 3.10
     app.add_error_handler(asyncio.TimeoutError, timeout_error_handler)  # type: ignore[arg-type]
 
     app.add_error_handler(HTTPNominatimError, nominatim_error_handler)
     app.add_error_handler(TimeoutError, timeout_error_handler)
     # different from TimeoutError in Python <= 3.10
     app.add_error_handler(asyncio.TimeoutError, timeout_error_handler)  # type: ignore[arg-type]
 
-    legacy_urls = api.config.get_bool('SERVE_LEGACY_URLS')
-    formatter = load_format_dispatcher('v1', project_dir)
-    for name, func in api_impl.ROUTES:
-        endpoint = EndpointWrapper(name, func, api, formatter)
-        app.add_route(f"/{name}", endpoint)
-        if legacy_urls:
-            app.add_route(f"/{name}.php", endpoint)
-
     return app
 
 
     return app
 
 
index 48f0207ac7b2f21e07a16ca83ad627d984b74200..9d01492043d750ea2a204f96b2300f917c71923a 100644 (file)
@@ -11,6 +11,7 @@ from typing import Any, Optional, Mapping, Callable, cast, Coroutine, Dict, Awai
 from pathlib import Path
 import datetime as dt
 import asyncio
 from pathlib import Path
 import datetime as dt
 import asyncio
+import contextlib
 
 from starlette.applications import Starlette
 from starlette.routing import Route
 
 from starlette.applications import Starlette
 from starlette.routing import Route
@@ -66,7 +67,7 @@ class ParamWrapper(ASGIAdaptor):
         return cast(Configuration, self.request.app.state.API.config)
 
     def formatting(self) -> FormatDispatcher:
         return cast(Configuration, self.request.app.state.API.config)
 
     def formatting(self) -> FormatDispatcher:
-        return cast(FormatDispatcher, self.request.app.state.API.formatter)
+        return cast(FormatDispatcher, self.request.app.state.formatter)
 
 
 def _wrap_endpoint(func: EndpointFunc)\
 
 
 def _wrap_endpoint(func: EndpointFunc)\
@@ -132,14 +133,6 @@ def get_application(project_dir: Path,
     """
     config = Configuration(project_dir, environ)
 
     """
     config = Configuration(project_dir, environ)
 
-    routes = []
-    legacy_urls = config.get_bool('SERVE_LEGACY_URLS')
-    for name, func in api_impl.ROUTES:
-        endpoint = _wrap_endpoint(func)
-        routes.append(Route(f"/{name}", endpoint=endpoint))
-        if legacy_urls:
-            routes.append(Route(f"/{name}.php", endpoint=endpoint))
-
     middleware = []
     if config.get_bool('CORS_NOACCESSCONTROL'):
         middleware.append(Middleware(CORSMiddleware,
     middleware = []
     if config.get_bool('CORS_NOACCESSCONTROL'):
         middleware.append(Middleware(CORSMiddleware,
@@ -156,14 +149,26 @@ def get_application(project_dir: Path,
         asyncio.TimeoutError: timeout_error
     }
 
         asyncio.TimeoutError: timeout_error
     }
 
-    async def _shutdown() -> None:
+    @contextlib.asynccontextmanager
+    async def lifespan(app: Starlette) -> None:
+        app.state.API = NominatimAPIAsync(project_dir, environ)
+        config = app.state.API.config
+
+        legacy_urls = config.get_bool('SERVE_LEGACY_URLS')
+        for name, func in api_impl.ROUTES:
+            endpoint = _wrap_endpoint(func)
+            app.routes.append(Route(f"/{name}", endpoint=endpoint))
+            if legacy_urls:
+                app.routes.append(Route(f"/{name}.php", endpoint=endpoint))
+
+        yield
+
         await app.state.API.close()
 
         await app.state.API.close()
 
-    app = Starlette(debug=debug, routes=routes, middleware=middleware,
+    app = Starlette(debug=debug, middleware=middleware,
                     exception_handlers=exceptions,
                     exception_handlers=exceptions,
-                    on_shutdown=[_shutdown])
+                    lifespan=lifespan)
 
 
-    app.state.API = NominatimAPIAsync(project_dir, environ)
     app.state.formatter = load_format_dispatcher('v1', project_dir)
 
     return app
     app.state.formatter = load_format_dispatcher('v1', project_dir)
 
     return app