This caches results from querying nominatim_properties.
"""
Extended SQLAlchemy connection class that also includes access to the schema.
"""
-from typing import Any, Mapping, Sequence, Union
+from typing import Any, Mapping, Sequence, Union, Dict, cast
import sqlalchemy as sa
from sqlalchemy.ext.asyncio import AsyncConnection
"""
def __init__(self, conn: AsyncConnection,
- tables: SearchTables) -> None:
+ tables: SearchTables,
+ properties: Dict[str, Any]) -> None:
+ """ Wrap the async connection 'conn' together with the table
+ definitions 'tables' and the property cache 'properties'.
+ """
self.connection = conn
self.t = tables # pylint: disable=invalid-name
+ # Stored by reference, not copied: get_property() writes new entries
+ # into this dict, so they are visible to whoever passed it in.
+ self._property_cache = properties
async def scalar(self, sql: sa.sql.base.Executable,
""" Execute a 'execute()' query on the connection.
"""
return await self.connection.execute(sql, params)
+
+
+    async def get_property(self, name: str, cached: bool = True) -> str:
+        """ Get a property from Nominatim's property table.
+
+            Property values are normally cached so that they are only
+            retrieved from the database when they are queried for the
+            first time with this function. Set 'cached' to False to force
+            reading the property from the database; the freshly read
+            value then replaces the cached one.
+
+            Names starting with 'DB:' are reserved for internal settings
+            kept in the same cache and are rejected.
+
+            Raises a ValueError if the name is illegal or the property
+            does not exist.
+        """
+        if name.startswith('DB:'):
+            raise ValueError(f"Illegal property name '{name}'.")
+
+        if cached and name in self._property_cache:
+            return cast(str, self._property_cache[name])
+
+        sql = sa.select(self.t.properties.c.value)\
+                .where(self.t.properties.c.property == name)
+        value = await self.connection.scalar(sql)
+
+        if value is None:
+            raise ValueError(f"Property '{name}' not found in database.")
+
+        result = cast(str, value)
+        # Store even when 'cached' is False so that subsequent cached
+        # lookups see the refreshed value.
+        self._property_cache[name] = result
+
+        return result
+
+
+    async def get_db_property(self, name: str) -> Any:
+        """ Look up a setting of the database server itself.
+
+            The only supported name is currently 'server_version', which
+            yields the version of the database software as held in the
+            shared property cache under 'DB:server_version'.
+
+            Raises a ValueError for any other name.
+        """
+        if name == 'server_version':
+            return self._property_cache['DB:server_version']
+
+        raise ValueError(f"DB setting '{name}' not found in database.")
"""
Implementation of classes for API access via libraries.
"""
-from typing import Mapping, Optional, Any, AsyncIterator
+from typing import Mapping, Optional, Any, AsyncIterator, Dict
import asyncio
import contextlib
from pathlib import Path
self._engine_lock = asyncio.Lock()
self._engine: Optional[sa_asyncio.AsyncEngine] = None
self._tables: Optional[SearchTables] = None
+ self._property_cache: Dict[str, Any] = {'DB:server_version': 0}
async def setup_database(self) -> None:
try:
async with engine.begin() as conn:
result = await conn.scalar(sa.text('SHOW server_version_num'))
- self.server_version = int(result)
+ server_version = int(result)
except asyncpg.PostgresError:
- self.server_version = 0
+ server_version = 0
- if self.server_version >= 110000:
+ if server_version >= 110000:
@sa.event.listens_for(engine.sync_engine, "connect")
def _on_connect(dbapi_con: Any, _: Any) -> None:
cursor = dbapi_con.cursor()
# Make sure that all connections get the new settings
await self.close()
+ self._property_cache['DB:server_version'] = server_version
+
self._tables = SearchTables(sa.MetaData(), engine.name) # pylint: disable=no-member
self._engine = engine
assert self._tables is not None
async with self._engine.begin() as conn:
- yield SearchConnection(conn, self._tables)
+ yield SearchConnection(conn, self._tables, self._property_cache)
async def status(self) -> StatusResult:
"""
Classes and function releated to status call.
"""
-from typing import Optional, cast
+from typing import Optional
import datetime as dt
import dataclasses
status.data_updated = await conn.scalar(sql)
# Database version
- sql = sa.select(conn.t.properties.c.value)\
- .where(conn.t.properties.c.property == 'database_version')
- verstr = await conn.scalar(sql)
- if verstr is not None:
- status.database_version = version.parse_version(cast(str, verstr))
+ try:
+ verstr = await conn.get_property('database_version')
+ status.database_version = version.parse_version(verstr)
+ except ValueError:
+ pass
return status
--- /dev/null
+# SPDX-License-Identifier: GPL-3.0-or-later
+#
+# This file is part of Nominatim. (https://nominatim.org)
+#
+# Copyright (C) 2023 by the Nominatim developer community.
+# For a full list of authors see the git log.
+"""
+Tests for enhanced connection class for API functions.
+"""
+from pathlib import Path
+import pytest
+import pytest_asyncio
+
+import sqlalchemy as sa
+
+from nominatim.api import NominatimAPIAsync
+
+@pytest_asyncio.fixture
+async def apiobj(temp_db):
+    """ Yield a NominatimAPIAsync object and close it again after the test.
+
+        The project directory is a dummy path. 'temp_db' is requested but
+        not used directly; presumably it provides the test database.
+    """
+    nominatim_api = NominatimAPIAsync(Path('/invalid'), {})
+    yield nominatim_api
+    await nominatim_api.close()
+
+
+@pytest.mark.asyncio
+async def test_run_scalar(apiobj, table_factory):
+    """ scalar() yields the single value stored in a one-row table. """
+    table_factory('foo', definition='that TEXT', content=(('a', ),))
+
+    query = sa.text('SELECT * FROM foo')
+    async with apiobj.begin() as conn:
+        value = await conn.scalar(query)
+        assert value == 'a'
+
+
+@pytest.mark.asyncio
+async def test_run_execute(apiobj, table_factory):
+    """ execute() returns a result whose rows can be fetched. """
+    table_factory('foo', definition='that TEXT', content=(('a', ),))
+
+    query = sa.text('SELECT * FROM foo')
+    async with apiobj.begin() as conn:
+        rows = await conn.execute(query)
+        first_row = rows.fetchone()
+        assert first_row[0] == 'a'
+
+
+@pytest.mark.asyncio
+async def test_get_property_existing_cached(apiobj, table_factory):
+    """ Once read, a property is served from the cache even after the
+        property table has been emptied.
+    """
+    table_factory('nominatim_properties',
+                  definition='property TEXT, value TEXT',
+                  content=(('dbv', '96723'), ))
+
+    async with apiobj.begin() as conn:
+        first_read = await conn.get_property('dbv')
+        assert first_read == '96723'
+
+        await conn.execute(sa.text('TRUNCATE nominatim_properties'))
+
+        second_read = await conn.get_property('dbv')
+        assert second_read == '96723'
+
+
+@pytest.mark.asyncio
+async def test_get_property_existing_uncached(apiobj, table_factory):
+    """ cached=False bypasses the cache and reads the current value. """
+    table_factory('nominatim_properties',
+                  definition='property TEXT, value TEXT',
+                  content=(('dbv', '96723'), ))
+
+    update = sa.text("UPDATE nominatim_properties SET value = '1'")
+    async with apiobj.begin() as conn:
+        initial = await conn.get_property('dbv')
+        assert initial == '96723'
+
+        await conn.execute(update)
+
+        refreshed = await conn.get_property('dbv', cached=False)
+        assert refreshed == '1'
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize('param', ['foo', 'DB:server_version'])
+async def test_get_property_missing(apiobj, table_factory, param):
+    """ Unknown names and reserved 'DB:' names both raise ValueError. """
+    table_factory('nominatim_properties',
+                  definition='property TEXT, value TEXT')
+
+    async with apiobj.begin() as conn:
+        with pytest.raises(ValueError):
+            await conn.get_property(param)
+
+
+@pytest.mark.asyncio
+async def test_get_db_property_existing(apiobj):
+    """ The server version is a positive number once the API is set up. """
+    async with apiobj.begin() as conn:
+        server_version = await conn.get_db_property('server_version')
+        assert server_version > 0
+
+
+@pytest.mark.asyncio
+async def test_get_db_property_bad_name(apiobj):
+    """ Asking for an unsupported DB setting raises ValueError.
+
+        Must not share its name with any other test function in this
+        module, or one of them replaces the other and is never collected.
+    """
+    async with apiobj.begin() as conn:
+        with pytest.raises(ValueError):
+            await conn.get_db_property('dfkgjd.rijg')