From e56957f047ffacbce41eb6af914a192274996955 Mon Sep 17 00:00:00 2001 From: Sarah Hoffmann Date: Tue, 24 Jan 2023 10:56:22 +0100 Subject: [PATCH] api: delay setup of initial database connection Defer database setup until the first call to a function. Needs an additional lock because the setup still needs to be done sequentially. --- nominatim/api.py | 127 +++++++++++++++++++++++++------------ nominatim/apicmd/status.py | 19 +++--- 2 files changed, 94 insertions(+), 52 deletions(-) diff --git a/nominatim/api.py b/nominatim/api.py index 10cca533..4ce89595 100644 --- a/nominatim/api.py +++ b/nominatim/api.py @@ -2,18 +2,18 @@ # # This file is part of Nominatim. (https://nominatim.org) # -# Copyright (C) 2022 by the Nominatim developer community. +# Copyright (C) 2023 by the Nominatim developer community. # For a full list of authors see the git log. """ Implementation of classes for API access via libraries. """ -from typing import Mapping, Optional, cast, Any +from typing import Mapping, Optional, Any, AsyncIterator import asyncio +import contextlib from pathlib import Path -from sqlalchemy import text, event -from sqlalchemy.engine.url import URL -from sqlalchemy.ext.asyncio import create_async_engine +import sqlalchemy as sa +import sqlalchemy.ext.asyncio as sa_asyncio import asyncpg from nominatim.config import Configuration @@ -25,52 +25,93 @@ class NominatimAPIAsync: def __init__(self, project_dir: Path, environ: Optional[Mapping[str, str]] = None) -> None: self.config = Configuration(project_dir, environ) + self.server_version = 0 + + self._engine_lock = asyncio.Lock() + self._engine: Optional[sa_asyncio.AsyncEngine] = None + + + async def setup_database(self) -> None: + """ Set up the engine and connection parameters. + + This function will be implicitly called when the database is + accessed for the first time. You may also call it explicitly to + avoid that the first call is delayed by the setup. + """ + async with self._engine_lock: + if self._engine: + return + + dsn = self.config.get_database_params() + + dburl = sa.engine.URL.create( + 'postgresql+asyncpg', + database=dsn.get('dbname'), + username=dsn.get('user'), password=dsn.get('password'), + host=dsn.get('host'), port=int(dsn['port']) if 'port' in dsn else None, + query={k: v for k, v in dsn.items() + if k not in ('user', 'password', 'dbname', 'host', 'port')}) + engine = sa_asyncio.create_async_engine( + dburl, future=True, + connect_args={'server_settings': { + 'DateStyle': 'sql,european', + 'max_parallel_workers_per_gather': '0' + }}) + + try: + async with engine.begin() as conn: + result = await conn.scalar(sa.text('SHOW server_version_num')) + self.server_version = int(result) + except asyncpg.PostgresError: + self.server_version = 0 + + if self.server_version >= 110000: + @sa.event.listens_for(engine.sync_engine, "connect") # type: ignore[misc] + def _on_connect(dbapi_con: Any, _: Any) -> None: + cursor = dbapi_con.cursor() + cursor.execute("SET jit_above_cost TO '-1'") + # Make sure that all connections get the new settings + await self.close() + + self._engine = engine - dsn = self.config.get_database_params() - - dburl = URL.create( - 'postgresql+asyncpg', - database=dsn.get('dbname'), - username=dsn.get('user'), password=dsn.get('password'), - host=dsn.get('host'), port=int(dsn['port']) if 'port' in dsn else None, - query={k: v for k, v in dsn.items() - if k not in ('user', 'password', 'dbname', 'host', 'port')}) - self.engine = create_async_engine( - dburl, future=True, - connect_args={'server_settings': { - 'DateStyle': 'sql,european', - 'max_parallel_workers_per_gather': '0' - }}) - asyncio.get_event_loop().run_until_complete(self._query_server_version()) - asyncio.get_event_loop().run_until_complete(self.close()) - - if self.server_version >= 110000: - @event.listens_for(self.engine.sync_engine, "connect") # type: ignore[misc] - def _on_connect(dbapi_con: Any, _: Any) -> None: - cursor = dbapi_con.cursor() - cursor.execute("SET jit_above_cost TO '-1'") - - - async def _query_server_version(self) -> None: - try: - async with self.engine.begin() as conn: - result = await conn.scalar(text('SHOW server_version_num')) - self.server_version = int(cast(str, result)) - except asyncpg.PostgresError: - self.server_version = 0 async def close(self) -> None: """ Close all active connections to the database. The NominatimAPIAsync object remains usable after closing. If a new API functions is called, new connections are created. """ - await self.engine.dispose() + if self._engine is not None: + await self._engine.dispose() + + + @contextlib.asynccontextmanager + async def begin(self) -> AsyncIterator[sa_asyncio.AsyncConnection]: + """ Create a new connection with automatic transaction handling. + + This function may be used to get low-level access to the database. + Refer to the documentation of SQLAlchemy for details how to use + the connection object. + """ + if self._engine is None: + await self.setup_database() + + assert self._engine is not None + + async with self._engine.begin() as conn: + yield conn async def status(self) -> StatusResult: """ Return the status of the database. """ - return await get_status(self.engine) + try: + async with self.begin() as conn: + status = await get_status(conn) + except asyncpg.PostgresError: + return StatusResult(700, 'Database connection failed') + + return status class NominatimAPI: @@ -79,7 +120,8 @@ class NominatimAPI: def __init__(self, project_dir: Path, environ: Optional[Mapping[str, str]] = None) -> None: - self.async_api = NominatimAPIAsync(project_dir, environ) + self._loop = asyncio.new_event_loop() + self._async_api = NominatimAPIAsync(project_dir, environ) def close(self) -> None: @@ -87,10 +129,11 @@ class NominatimAPI: object remains usable after closing. If a new API functions is called, new connections are created. """ - asyncio.get_event_loop().run_until_complete(self.async_api.close()) + self._loop.run_until_complete(self._async_api.close()) + self._loop.close() def status(self) -> StatusResult: """ Return the status of the database. """ - return asyncio.get_event_loop().run_until_complete(self.async_api.status()) + return self._loop.run_until_complete(self._async_api.status()) diff --git a/nominatim/apicmd/status.py b/nominatim/apicmd/status.py index 85071db9..560953d3 100644 --- a/nominatim/apicmd/status.py +++ b/nominatim/apicmd/status.py @@ -2,7 +2,7 @@ # # This file is part of Nominatim. (https://nominatim.org) # -# Copyright (C) 2022 by the Nominatim developer community. +# Copyright (C) 2023 by the Nominatim developer community. # For a full list of authors see the git log. """ Classes and function releated to status call. @@ -10,8 +10,8 @@ Classes and function releated to status call. from typing import Optional, cast import datetime as dt -import sqlalchemy as sqla -from sqlalchemy.ext.asyncio.engine import AsyncEngine, AsyncConnection +import sqlalchemy as sa +from sqlalchemy.ext.asyncio.engine import AsyncConnection import asyncpg from nominatim import version @@ -31,7 +31,7 @@ class StatusResult: async def _get_database_date(conn: AsyncConnection) -> Optional[dt.datetime]: """ Query the database date. """ - sql = sqla.text('SELECT lastimportdate FROM import_status LIMIT 1') + sql = sa.text('SELECT lastimportdate FROM import_status LIMIT 1') result = await conn.execute(sql) for row in result: @@ -41,8 +41,8 @@ async def _get_database_date(conn: AsyncConnection) -> Optional[dt.datetime]: async def _get_database_version(conn: AsyncConnection) -> Optional[version.NominatimVersion]: - sql = sqla.text("""SELECT value FROM nominatim_properties - WHERE property = 'database_version'""") + sql = sa.text("""SELECT value FROM nominatim_properties + WHERE property = 'database_version'""") result = await conn.execute(sql) for row in result: @@ -51,14 +51,13 @@ async def _get_database_version(conn: AsyncConnection) -> Optional[version.Nomin return None -async def get_status(engine: AsyncEngine) -> StatusResult: +async def get_status(conn: AsyncConnection) -> StatusResult: """ Execute a status API call. """ status = StatusResult(0, 'OK') try: - async with engine.begin() as conn: - status.data_updated = await _get_database_date(conn) - status.database_version = await _get_database_version(conn) + status.data_updated = await _get_database_date(conn) + status.database_version = await _get_database_version(conn) except asyncpg.PostgresError: return StatusResult(700, 'Database connection failed') -- 2.39.5