]> git.openstreetmap.org Git - nominatim.git/blob - nominatim/api/connection.py
add a timeout for DB queries
[nominatim.git] / nominatim / api / connection.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2023 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Extended SQLAlchemy connection class that also includes access to the schema.
9 """
10 from typing import cast, Any, Mapping, Sequence, Union, Dict, Optional, Set, \
11                    Awaitable, Callable, TypeVar
12 import asyncio
13
14 import sqlalchemy as sa
15 from sqlalchemy.ext.asyncio import AsyncConnection
16
17 from nominatim.typing import SaFromClause
18 from nominatim.db.sqlalchemy_schema import SearchTables
19 from nominatim.db.sqlalchemy_types import Geometry
20 from nominatim.api.logging import log
21
22 T = TypeVar('T')
23
24 class SearchConnection:
25     """ An extended SQLAlchemy connection class, that also contains
26         then table definitions. The underlying asynchronous SQLAlchemy
27         connection can be accessed with the 'connection' property.
28         The 't' property is the collection of Nominatim tables.
29     """
30
31     def __init__(self, conn: AsyncConnection,
32                  tables: SearchTables,
33                  properties: Dict[str, Any]) -> None:
34         self.connection = conn
35         self.t = tables # pylint: disable=invalid-name
36         self._property_cache = properties
37         self._classtables: Optional[Set[str]] = None
38         self.query_timeout: Optional[int] = None
39
40
41     def set_query_timeout(self, timeout: Optional[int]) -> None:
42         """ Set the timeout after which a query over this connection
43             is cancelled.
44         """
45         self.query_timeout = timeout
46
47
48     async def scalar(self, sql: sa.sql.base.Executable,
49                      params: Union[Mapping[str, Any], None] = None
50                     ) -> Any:
51         """ Execute a 'scalar()' query on the connection.
52         """
53         log().sql(self.connection, sql, params)
54         async with asyncio.timeout(self.query_timeout):
55             return await self.connection.scalar(sql, params)
56
57
58     async def execute(self, sql: 'sa.Executable',
59                       params: Union[Mapping[str, Any], Sequence[Mapping[str, Any]], None] = None
60                      ) -> 'sa.Result[Any]':
61         """ Execute a 'execute()' query on the connection.
62         """
63         log().sql(self.connection, sql, params)
64         async with asyncio.timeout(self.query_timeout):
65             return await self.connection.execute(sql, params)
66
67
68     async def get_property(self, name: str, cached: bool = True) -> str:
69         """ Get a property from Nominatim's property table.
70
71             Property values are normally cached so that they are only
72             retrieved from the database when they are queried for the
73             first time with this function. Set 'cached' to False to force
74             reading the property from the database.
75
76             Raises a ValueError if the property does not exist.
77         """
78         lookup_name = f'DBPROP:{name}'
79
80         if cached and lookup_name in self._property_cache:
81             return cast(str, self._property_cache[lookup_name])
82
83         sql = sa.select(self.t.properties.c.value)\
84             .where(self.t.properties.c.property == name)
85         value = await self.connection.scalar(sql)
86
87         if value is None:
88             raise ValueError(f"Property '{name}' not found in database.")
89
90         self._property_cache[lookup_name] = cast(str, value)
91
92         return cast(str, value)
93
94
95     async def get_db_property(self, name: str) -> Any:
96         """ Get a setting from the database. At the moment, only
97             'server_version', the version of the database software, can
98             be retrieved with this function.
99
100             Raises a ValueError if the property does not exist.
101         """
102         if name != 'server_version':
103             raise ValueError(f"DB setting '{name}' not found in database.")
104
105         return self._property_cache['DB:server_version']
106
107
108     async def get_cached_value(self, group: str, name: str,
109                                factory: Callable[[], Awaitable[T]]) -> T:
110         """ Access the cache for this Nominatim instance.
111             Each cache value needs to belong to a group and have a name.
112             This function is for internal API use only.
113
114             `factory` is an async callback function that produces
115             the value if it is not already cached.
116
117             Returns the cached value or the result of factory (also caching
118             the result).
119         """
120         full_name = f'{group}:{name}'
121
122         if full_name in self._property_cache:
123             return cast(T, self._property_cache[full_name])
124
125         value = await factory()
126         self._property_cache[full_name] = value
127
128         return value
129
130
131     async def get_class_table(self, cls: str, typ: str) -> Optional[SaFromClause]:
132         """ Lookup up if there is a classtype table for the given category
133             and return a SQLAlchemy table for it, if it exists.
134         """
135         if self._classtables is None:
136             res = await self.execute(sa.text("""SELECT tablename FROM pg_tables
137                                                 WHERE tablename LIKE 'place_classtype_%'
138                                              """))
139             self._classtables = {r[0] for r in res}
140
141         tablename = f"place_classtype_{cls}_{typ}"
142
143         if tablename not in self._classtables:
144             return None
145
146         if tablename in self.t.meta.tables:
147             return self.t.meta.tables[tablename]
148
149         return sa.Table(tablename, self.t.meta,
150                         sa.Column('place_id', sa.BigInteger),
151                         sa.Column('centroid', Geometry))