From 16b6484c650eef161e64153f748643ca61553753 Mon Sep 17 00:00:00 2001 From: Sarah Hoffmann Date: Sat, 28 Jan 2023 22:24:36 +0100 Subject: [PATCH] add property cache for API This caches results from querying nominatim_properties. --- nominatim/api/connection.py | 47 ++++++++++++- nominatim/api/core.py | 13 ++-- nominatim/api/status.py | 12 ++-- test/python/api/test_api_connection.py | 93 ++++++++++++++++++++++++++ 4 files changed, 152 insertions(+), 13 deletions(-) create mode 100644 test/python/api/test_api_connection.py diff --git a/nominatim/api/connection.py b/nominatim/api/connection.py index b397f624..79a5e347 100644 --- a/nominatim/api/connection.py +++ b/nominatim/api/connection.py @@ -7,7 +7,7 @@ """ Extended SQLAlchemy connection class that also includes access to the schema. """ -from typing import Any, Mapping, Sequence, Union +from typing import Any, Mapping, Sequence, Union, Dict, cast import sqlalchemy as sa from sqlalchemy.ext.asyncio import AsyncConnection @@ -22,9 +22,11 @@ class SearchConnection: """ def __init__(self, conn: AsyncConnection, - tables: SearchTables) -> None: + tables: SearchTables, + properties: Dict[str, Any]) -> None: self.connection = conn self.t = tables # pylint: disable=invalid-name + self._property_cache = properties async def scalar(self, sql: sa.sql.base.Executable, @@ -41,3 +43,44 @@ class SearchConnection: """ Execute a 'execute()' query on the connection. """ return await self.connection.execute(sql, params) + + + async def get_property(self, name: str, cached: bool = True) -> str: + """ Get a property from Nominatim's property table. + + Property values are normally cached so that they are only + retrieved from the database when they are queried for the + first time with this function. Set 'cached' to False to force + reading the property from the database. + + Raises a ValueError if the property does not exist. + """ + if name.startswith('DB:'): + raise ValueError(f"Illegal property value '{name}'.") + + if cached and name in self._property_cache: + return cast(str, self._property_cache[name]) + + sql = sa.select(self.t.properties.c.value)\ + .where(self.t.properties.c.property == name) + value = await self.connection.scalar(sql) + + if value is None: + raise ValueError(f"Property '{name}' not found in database.") + + self._property_cache[name] = cast(str, value) + + return cast(str, value) + + + async def get_db_property(self, name: str) -> Any: + """ Get a setting from the database. At the moment, only + 'server_version', the version of the database software, can + be retrieved with this function. + + Raises a ValueError if the property does not exist. + """ + if name != 'server_version': + raise ValueError(f"DB setting '{name}' not found in database.") + + return self._property_cache['DB:server_version'] diff --git a/nominatim/api/core.py b/nominatim/api/core.py index a1f0e48d..54f02a93 100644 --- a/nominatim/api/core.py +++ b/nominatim/api/core.py @@ -7,7 +7,7 @@ """ Implementation of classes for API access via libraries. """ -from typing import Mapping, Optional, Any, AsyncIterator +from typing import Mapping, Optional, Any, AsyncIterator, Dict import asyncio import contextlib from pathlib import Path @@ -32,6 +32,7 @@ class NominatimAPIAsync: self._engine_lock = asyncio.Lock() self._engine: Optional[sa_asyncio.AsyncEngine] = None self._tables: Optional[SearchTables] = None + self._property_cache: Dict[str, Any] = {'DB:server_version': 0} async def setup_database(self) -> None: @@ -64,11 +65,11 @@ class NominatimAPIAsync: try: async with engine.begin() as conn: result = await conn.scalar(sa.text('SHOW server_version_num')) - self.server_version = int(result) + server_version = int(result) except asyncpg.PostgresError: - self.server_version = 0 + server_version = 0 - if self.server_version >= 110000: + if server_version >= 110000: @sa.event.listens_for(engine.sync_engine, "connect") def _on_connect(dbapi_con: Any, _: Any) -> None: cursor = dbapi_con.cursor() @@ -76,6 +77,8 @@ class NominatimAPIAsync: # Make sure that all connections get the new settings await self.close() + self._property_cache['DB:server_version'] = server_version + self._tables = SearchTables(sa.MetaData(), engine.name) # pylint: disable=no-member self._engine = engine @@ -104,7 +107,7 @@ class NominatimAPIAsync: assert self._tables is not None async with self._engine.begin() as conn: - yield SearchConnection(conn, self._tables) + yield SearchConnection(conn, self._tables, self._property_cache) async def status(self) -> StatusResult: diff --git a/nominatim/api/status.py b/nominatim/api/status.py index b6cd69a8..61e36cc3 100644 --- a/nominatim/api/status.py +++ b/nominatim/api/status.py @@ -7,7 +7,7 @@ """ Classes and function releated to status call. """ -from typing import Optional, cast +from typing import Optional import datetime as dt import dataclasses @@ -37,10 +37,10 @@ async def get_status(conn: SearchConnection) -> StatusResult: status.data_updated = await conn.scalar(sql) # Database version - sql = sa.select(conn.t.properties.c.value)\ - .where(conn.t.properties.c.property == 'database_version') - verstr = await conn.scalar(sql) - if verstr is not None: - status.database_version = version.parse_version(cast(str, verstr)) + try: + verstr = await conn.get_property('database_version') + status.database_version = version.parse_version(verstr) + except ValueError: + pass return status diff --git a/test/python/api/test_api_connection.py b/test/python/api/test_api_connection.py new file mode 100644 index 00000000..5609cb03 --- /dev/null +++ b/test/python/api/test_api_connection.py @@ -0,0 +1,93 @@ +# SPDX-License-Identifier: GPL-3.0-or-later +# +# This file is part of Nominatim. (https://nominatim.org) +# +# Copyright (C) 2023 by the Nominatim developer community. +# For a full list of authors see the git log. +""" +Tests for enhanced connection class for API functions. +""" +from pathlib import Path +import pytest +import pytest_asyncio + +import sqlalchemy as sa + +from nominatim.api import NominatimAPIAsync + +@pytest_asyncio.fixture +async def apiobj(temp_db): + """ Create an asynchronous SQLAlchemy engine for the test DB. + """ + api = NominatimAPIAsync(Path('/invalid'), {}) + yield api + await api.close() + + +@pytest.mark.asyncio +async def test_run_scalar(apiobj, table_factory): + table_factory('foo', definition='that TEXT', content=(('a', ),)) + + async with apiobj.begin() as conn: + assert await conn.scalar(sa.text('SELECT * FROM foo')) == 'a' + + +@pytest.mark.asyncio +async def test_run_execute(apiobj, table_factory): + table_factory('foo', definition='that TEXT', content=(('a', ),)) + + async with apiobj.begin() as conn: + result = await conn.execute(sa.text('SELECT * FROM foo')) + assert result.fetchone()[0] == 'a' + + +@pytest.mark.asyncio +async def test_get_property_existing_cached(apiobj, table_factory): + table_factory('nominatim_properties', + definition='property TEXT, value TEXT', + content=(('dbv', '96723'), )) + + async with apiobj.begin() as conn: + assert await conn.get_property('dbv') == '96723' + + await conn.execute(sa.text('TRUNCATE nominatim_properties')) + + assert await conn.get_property('dbv') == '96723' + + +@pytest.mark.asyncio +async def test_get_property_existing_uncached(apiobj, table_factory): + table_factory('nominatim_properties', + definition='property TEXT, value TEXT', + content=(('dbv', '96723'), )) + + async with apiobj.begin() as conn: + assert await conn.get_property('dbv') == '96723' + + await conn.execute(sa.text("UPDATE nominatim_properties SET value = '1'")) + + assert await conn.get_property('dbv', cached=False) == '1' + + +@pytest.mark.asyncio +@pytest.mark.parametrize('param', ['foo', 'DB:server_version']) +async def test_get_property_missing(apiobj, table_factory, param): + table_factory('nominatim_properties', + definition='property TEXT, value TEXT') + + async with apiobj.begin() as conn: + with pytest.raises(ValueError): + await conn.get_property(param) + + +@pytest.mark.asyncio +async def test_get_db_property_existing(apiobj): + async with apiobj.begin() as conn: + assert await conn.get_db_property('server_version') > 0 + + +@pytest.mark.asyncio +async def test_get_db_property_existing(apiobj): + async with apiobj.begin() as conn: + with pytest.raises(ValueError): + await conn.get_db_property('dfkgjd.rijg') -- 2.39.5