#
# This file is part of Nominatim. (https://nominatim.org)
#
-# Copyright (C) 2022 by the Nominatim developer community.
+# Copyright (C) 2023 by the Nominatim developer community.
# For a full list of authors see the git log.
"""
Implementation of classes for API access via libraries.
"""
-from typing import Mapping, Optional, cast, Any
+from typing import Mapping, Optional, Any, AsyncIterator
import asyncio
+import contextlib
from pathlib import Path
-from sqlalchemy import text, event
-from sqlalchemy.engine.url import URL
-from sqlalchemy.ext.asyncio import create_async_engine
+import sqlalchemy as sa
+import sqlalchemy.ext.asyncio as sa_asyncio
import asyncpg
from nominatim.config import Configuration
def __init__(self, project_dir: Path,
environ: Optional[Mapping[str, str]] = None) -> None:
self.config = Configuration(project_dir, environ)
+ # PostgreSQL version number as reported by 'SHOW server_version_num';
+ # filled in by setup_database(), 0 while unknown.
+ self.server_version = 0
+
+ # The engine is created lazily by setup_database(); the lock ensures
+ # that concurrent first calls only set it up once.
+ self._engine_lock = asyncio.Lock()
+ self._engine: Optional[sa_asyncio.AsyncEngine] = None
+
+
+    async def setup_database(self) -> None:
+        """ Set up the engine and connection parameters.
+
+            This function will be implicitly called when the database is
+            accessed for the first time. You may also call it explicitly to
+            avoid that the first call is delayed by the setup.
+
+            The PostgreSQL server version is queried once and stored in
+            `server_version` (0 when it cannot be determined). For servers
+            of version 11 and newer, JIT is switched off on every new
+            connection.
+        """
+        async with self._engine_lock:
+            if self._engine:
+                return
+
+            dsn = self.config.get_database_params()
+
+            dburl = sa.engine.URL.create(
+                       'postgresql+asyncpg',
+                       database=dsn.get('dbname'),
+                       username=dsn.get('user'), password=dsn.get('password'),
+                       host=dsn.get('host'), port=int(dsn['port']) if 'port' in dsn else None,
+                       query={k: v for k, v in dsn.items()
+                              if k not in ('user', 'password', 'dbname', 'host', 'port')})
+            engine = sa_asyncio.create_async_engine(
+                             dburl, future=True,
+                             connect_args={'server_settings': {
+                                'DateStyle': 'sql,european',
+                                'max_parallel_workers_per_gather': '0'
+                             }})
+
+            try:
+                async with engine.begin() as conn:
+                    result = await conn.scalar(sa.text('SHOW server_version_num'))
+                    self.server_version = int(result)
+            except (asyncpg.PostgresError, sa.exc.DBAPIError):
+                # Errors raised while connecting may come straight from asyncpg,
+                # errors raised by the query are wrapped by SQLAlchemy.
+                # Either way the server version stays unknown.
+                self.server_version = 0
+
+            if self.server_version >= 110000:
+                @sa.event.listens_for(engine.sync_engine, "connect") # type: ignore[misc]
+                def _on_connect(dbapi_con: Any, _: Any) -> None:
+                    cursor = dbapi_con.cursor()
+                    cursor.execute("SET jit_above_cost TO '-1'")
+                # The connection used for the version query was opened before
+                # the listener existed. Dispose of *this* engine's pool so that
+                # all connections are reopened with the new settings.
+                # (self.close() would do nothing here because self._engine is
+                # only assigned below.)
+                await engine.dispose()
+
+            self._engine = engine
- dsn = self.config.get_database_params()
-
- dburl = URL.create(
- 'postgresql+asyncpg',
- database=dsn.get('dbname'),
- username=dsn.get('user'), password=dsn.get('password'),
- host=dsn.get('host'), port=int(dsn['port']) if 'port' in dsn else None,
- query={k: v for k, v in dsn.items()
- if k not in ('user', 'password', 'dbname', 'host', 'port')})
- self.engine = create_async_engine(
- dburl, future=True,
- connect_args={'server_settings': {
- 'DateStyle': 'sql,european',
- 'max_parallel_workers_per_gather': '0'
- }})
- asyncio.get_event_loop().run_until_complete(self._query_server_version())
- asyncio.get_event_loop().run_until_complete(self.close())
-
- if self.server_version >= 110000:
- @event.listens_for(self.engine.sync_engine, "connect") # type: ignore[misc]
- def _on_connect(dbapi_con: Any, _: Any) -> None:
- cursor = dbapi_con.cursor()
- cursor.execute("SET jit_above_cost TO '-1'")
-
-
- async def _query_server_version(self) -> None:
- try:
- async with self.engine.begin() as conn:
- result = await conn.scalar(text('SHOW server_version_num'))
- self.server_version = int(cast(str, result))
- except asyncpg.PostgresError:
- self.server_version = 0
async def close(self) -> None:
""" Close all active connections to the database. The NominatimAPIAsync
object remains usable after closing. If a new API functions is
called, new connections are created.
"""
- await self.engine.dispose()
+ if self._engine is not None:
+ await self._engine.dispose()
+
+
+    @contextlib.asynccontextmanager
+    async def begin(self) -> AsyncIterator[sa_asyncio.AsyncConnection]:
+        """ Yield a database connection wrapped in a transaction.
+
+            The transaction is managed by SQLAlchemy's AsyncEngine.begin().
+            If the engine has not been set up yet, setup_database() is run
+            first. Use this for low-level access to the database; the
+            SQLAlchemy documentation explains how to work with the
+            connection object.
+        """
+        if self._engine is None:
+            await self.setup_database()
+
+        engine = self._engine
+        # setup_database() either returns with the engine assigned or raises.
+        assert engine is not None
+
+        async with engine.begin() as connection:
+            yield connection
async def status(self) -> StatusResult:
""" Return the status of the database.
"""
- return await get_status(self.engine)
+ try:
+ async with self.begin() as conn:
+ status = await get_status(conn)
+ except asyncpg.PostgresError:
+ return StatusResult(700, 'Database connection failed')
+
+ return status
class NominatimAPI:
def __init__(self, project_dir: Path,
environ: Optional[Mapping[str, str]] = None) -> None:
- self.async_api = NominatimAPIAsync(project_dir, environ)
+ # Dedicated event loop on which all calls into the async API are run.
+ self._loop = asyncio.new_event_loop()
+ self._async_api = NominatimAPIAsync(project_dir, environ)
def close(self) -> None:
object remains usable after closing. If a new API functions is
called, new connections are created.
"""
- asyncio.get_event_loop().run_until_complete(self.async_api.close())
+ self._loop.run_until_complete(self._async_api.close())
+ # NOTE(review): status() below keeps using self._loop, and asyncio's
+ # run_until_complete() raises RuntimeError on a closed loop. So the
+ # object is presumably NOT usable after close(), contrary to the
+ # docstring above. Confirm, then either update the docstring or keep
+ # the loop open.
+ self._loop.close()
def status(self) -> StatusResult:
""" Return the status of the database.
"""
- return asyncio.get_event_loop().run_until_complete(self.async_api.status())
+ return self._loop.run_until_complete(self._async_api.status())
#
# This file is part of Nominatim. (https://nominatim.org)
#
-# Copyright (C) 2022 by the Nominatim developer community.
+# Copyright (C) 2023 by the Nominatim developer community.
# For a full list of authors see the git log.
"""
Classes and function releated to status call.
from typing import Optional, cast
import datetime as dt
-import sqlalchemy as sqla
-from sqlalchemy.ext.asyncio.engine import AsyncEngine, AsyncConnection
+import sqlalchemy as sa
+from sqlalchemy.ext.asyncio.engine import AsyncConnection
import asyncpg
from nominatim import version
async def _get_database_date(conn: AsyncConnection) -> Optional[dt.datetime]:
""" Query the database date.
"""
- sql = sqla.text('SELECT lastimportdate FROM import_status LIMIT 1')
+ sql = sa.text('SELECT lastimportdate FROM import_status LIMIT 1')
result = await conn.execute(sql)
for row in result:
async def _get_database_version(conn: AsyncConnection) -> Optional[version.NominatimVersion]:
- sql = sqla.text("""SELECT value FROM nominatim_properties
- WHERE property = 'database_version'""")
+ sql = sa.text("""SELECT value FROM nominatim_properties
+ WHERE property = 'database_version'""")
result = await conn.execute(sql)
for row in result:
return None
-async def get_status(engine: AsyncEngine) -> StatusResult:
+async def get_status(conn: AsyncConnection) -> StatusResult:
""" Execute a status API call.
"""
status = StatusResult(0, 'OK')
try:
- async with engine.begin() as conn:
- status.data_updated = await _get_database_date(conn)
- status.database_version = await _get_database_version(conn)
+ status.data_updated = await _get_database_date(conn)
+ status.database_version = await _get_database_version(conn)
except asyncpg.PostgresError:
return StatusResult(700, 'Database connection failed')