]> git.openstreetmap.org Git - nominatim.git/commitdiff
add property cache for API
authorSarah Hoffmann <lonvia@denofr.de>
Sat, 28 Jan 2023 21:24:36 +0000 (22:24 +0100)
committerSarah Hoffmann <lonvia@denofr.de>
Mon, 30 Jan 2023 08:36:17 +0000 (09:36 +0100)
This caches results from querying nominatim_properties.

nominatim/api/connection.py
nominatim/api/core.py
nominatim/api/status.py
test/python/api/test_api_connection.py [new file with mode: 0644]

index b397f62413aa2686825725449f05c3def7239ade..79a5e3470d7b43340500aaa64c8739986e3467d4 100644 (file)
@@ -7,7 +7,7 @@
 """
 Extended SQLAlchemy connection class that also includes access to the schema.
 """
-from typing import Any, Mapping, Sequence, Union
+from typing import Any, Mapping, Sequence, Union, Dict, cast
 
 import sqlalchemy as sa
 from sqlalchemy.ext.asyncio import AsyncConnection
@@ -22,9 +22,11 @@ class SearchConnection:
     """
 
     def __init__(self, conn: AsyncConnection,
-                 tables: SearchTables) -> None:
+                 tables: SearchTables,
+                 properties: Dict[str, Any]) -> None:
         self.connection = conn
         self.t = tables # pylint: disable=invalid-name
+        self._property_cache = properties
 
 
     async def scalar(self, sql: sa.sql.base.Executable,
@@ -41,3 +43,44 @@ class SearchConnection:
         """ Execute a 'execute()' query on the connection.
         """
         return await self.connection.execute(sql, params)
+
+
+    async def get_property(self, name: str, cached: bool = True) -> str:
+        """ Get a property from Nominatim's property table.
+
+            Property values are normally cached so that they are only
+            retrieved from the database when they are queried for the
+            first time with this function. Set 'cached' to False to force
+            reading the property from the database.
+
+            Raises a ValueError if the property does not exist.
+        """
+        if name.startswith('DB:'):
+            raise ValueError(f"Illegal property value '{name}'.")
+
+        if cached and name in self._property_cache:
+            return cast(str, self._property_cache[name])
+
+        sql = sa.select(self.t.properties.c.value)\
+            .where(self.t.properties.c.property == name)
+        value = await self.connection.scalar(sql)
+
+        if value is None:
+            raise ValueError(f"Property '{name}' not found in database.")
+
+        self._property_cache[name] = cast(str, value)
+
+        return cast(str, value)
+
+
+    async def get_db_property(self, name: str) -> Any:
+        """ Get a setting from the database. At the moment, only
+            'server_version', the version of the database software, can
+            be retrieved with this function.
+
+            Raises a ValueError if the property does not exist.
+        """
+        if name != 'server_version':
+            raise ValueError(f"DB setting '{name}' not found in database.")
+
+        return self._property_cache['DB:server_version']
index a1f0e48df43df5c12ce1927588d65cb9dc4b2f08..54f02a938e77517b2a0e23fa5819911fd0d2d434 100644 (file)
@@ -7,7 +7,7 @@
 """
 Implementation of classes for API access via libraries.
 """
-from typing import Mapping, Optional, Any, AsyncIterator
+from typing import Mapping, Optional, Any, AsyncIterator, Dict
 import asyncio
 import contextlib
 from pathlib import Path
@@ -32,6 +32,7 @@ class NominatimAPIAsync:
         self._engine_lock = asyncio.Lock()
         self._engine: Optional[sa_asyncio.AsyncEngine] = None
         self._tables: Optional[SearchTables] = None
+        self._property_cache: Dict[str, Any] = {'DB:server_version': 0}
 
 
     async def setup_database(self) -> None:
@@ -64,11 +65,11 @@ class NominatimAPIAsync:
             try:
                 async with engine.begin() as conn:
                     result = await conn.scalar(sa.text('SHOW server_version_num'))
-                    self.server_version = int(result)
+                    server_version = int(result)
             except asyncpg.PostgresError:
-                self.server_version = 0
+                server_version = 0
 
-            if self.server_version >= 110000:
+            if server_version >= 110000:
                 @sa.event.listens_for(engine.sync_engine, "connect")
                 def _on_connect(dbapi_con: Any, _: Any) -> None:
                     cursor = dbapi_con.cursor()
@@ -76,6 +77,8 @@ class NominatimAPIAsync:
                 # Make sure that all connections get the new settings
                 await self.close()
 
+            self._property_cache['DB:server_version'] = server_version
+
             self._tables = SearchTables(sa.MetaData(), engine.name) # pylint: disable=no-member
             self._engine = engine
 
@@ -104,7 +107,7 @@ class NominatimAPIAsync:
         assert self._tables is not None
 
         async with self._engine.begin() as conn:
-            yield SearchConnection(conn, self._tables)
+            yield SearchConnection(conn, self._tables, self._property_cache)
 
 
     async def status(self) -> StatusResult:
index b6cd69a86b75de17d3c6289606f4f74b639fd350..61e36cc36488c097f58dd77d604b3975879500b4 100644 (file)
@@ -7,7 +7,7 @@
 """
 Classes and function releated to status call.
 """
-from typing import Optional, cast
+from typing import Optional
 import datetime as dt
 import dataclasses
 
@@ -37,10 +37,10 @@ async def get_status(conn: SearchConnection) -> StatusResult:
     status.data_updated = await conn.scalar(sql)
 
     # Database version
-    sql = sa.select(conn.t.properties.c.value)\
-            .where(conn.t.properties.c.property == 'database_version')
-    verstr = await conn.scalar(sql)
-    if verstr is not None:
-        status.database_version = version.parse_version(cast(str, verstr))
+    try:
+        verstr = await conn.get_property('database_version')
+        status.database_version = version.parse_version(verstr)
+    except ValueError:
+        pass
 
     return status
diff --git a/test/python/api/test_api_connection.py b/test/python/api/test_api_connection.py
new file mode 100644 (file)
index 0000000..5609cb0
--- /dev/null
@@ -0,0 +1,93 @@
+# SPDX-License-Identifier: GPL-3.0-or-later
+#
+# This file is part of Nominatim. (https://nominatim.org)
+#
+# Copyright (C) 2023 by the Nominatim developer community.
+# For a full list of authors see the git log.
+"""
+Tests for enhanced connection class for API functions.
+"""
+from pathlib import Path
+import pytest
+import pytest_asyncio
+
+import sqlalchemy as sa
+
+from nominatim.api import NominatimAPIAsync
+
+@pytest_asyncio.fixture
+async def apiobj(temp_db):
+    """ Create an asynchronous SQLAlchemy engine for the test DB.
+    """
+    api = NominatimAPIAsync(Path('/invalid'), {})
+    yield api
+    await api.close()
+
+
+@pytest.mark.asyncio
+async def test_run_scalar(apiobj, table_factory):
+    table_factory('foo', definition='that TEXT', content=(('a', ),))
+
+    async with apiobj.begin() as conn:
+        assert await conn.scalar(sa.text('SELECT * FROM foo')) == 'a'
+
+
+@pytest.mark.asyncio
+async def test_run_execute(apiobj, table_factory):
+    table_factory('foo', definition='that TEXT', content=(('a', ),))
+
+    async with apiobj.begin() as conn:
+        result = await conn.execute(sa.text('SELECT * FROM foo'))
+        assert result.fetchone()[0] == 'a'
+
+
+@pytest.mark.asyncio
+async def test_get_property_existing_cached(apiobj, table_factory):
+    table_factory('nominatim_properties',
+                  definition='property TEXT, value TEXT',
+                  content=(('dbv', '96723'), ))
+
+    async with apiobj.begin() as conn:
+        assert await conn.get_property('dbv') == '96723'
+
+        await conn.execute(sa.text('TRUNCATE nominatim_properties'))
+
+        assert await conn.get_property('dbv') == '96723'
+
+
+@pytest.mark.asyncio
+async def test_get_property_existing_uncached(apiobj, table_factory):
+    table_factory('nominatim_properties',
+                  definition='property TEXT, value TEXT',
+                  content=(('dbv', '96723'), ))
+
+    async with apiobj.begin() as conn:
+        assert await conn.get_property('dbv') == '96723'
+
+        await conn.execute(sa.text("UPDATE nominatim_properties SET value = '1'"))
+
+        assert await conn.get_property('dbv', cached=False) == '1'
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize('param', ['foo', 'DB:server_version'])
+async def test_get_property_missing(apiobj, table_factory, param):
+    table_factory('nominatim_properties',
+                  definition='property TEXT, value TEXT')
+
+    async with apiobj.begin() as conn:
+        with pytest.raises(ValueError):
+            await conn.get_property(param)
+
+
+@pytest.mark.asyncio
+async def test_get_db_property_existing(apiobj):
+    async with apiobj.begin() as conn:
+        assert await conn.get_db_property('server_version') > 0
+
+
+@pytest.mark.asyncio
+async def test_get_db_property_existing(apiobj):
+    async with apiobj.begin() as conn:
+        with pytest.raises(ValueError):
+            await conn.get_db_property('dfkgjd.rijg')