]> git.openstreetmap.org Git - nominatim.git/blobdiff - nominatim/api/connection.py
Merge remote-tracking branch 'upstream/master'
[nominatim.git] / nominatim / api / connection.py
index bf2173144d72fa7deee39daa010f1e75ef5293dc..405213e97659d32fb9ff9d56c2478219690af6a4 100644 (file)
@@ -9,6 +9,7 @@ Extended SQLAlchemy connection class that also includes access to the schema.
 """
 from typing import cast, Any, Mapping, Sequence, Union, Dict, Optional, Set, \
                    Awaitable, Callable, TypeVar
 """
 from typing import cast, Any, Mapping, Sequence, Union, Dict, Optional, Set, \
                    Awaitable, Callable, TypeVar
+import asyncio
 
 import sqlalchemy as sa
 from sqlalchemy.ext.asyncio import AsyncConnection
 
 import sqlalchemy as sa
 from sqlalchemy.ext.asyncio import AsyncConnection
@@ -34,6 +35,14 @@ class SearchConnection:
         self.t = tables # pylint: disable=invalid-name
         self._property_cache = properties
         self._classtables: Optional[Set[str]] = None
         self.t = tables # pylint: disable=invalid-name
         self._property_cache = properties
         self._classtables: Optional[Set[str]] = None
+        self.query_timeout: Optional[int] = None
+
+
+    def set_query_timeout(self, timeout: Optional[int]) -> None:
+        """ Set the timeout (seconds, passed to asyncio.wait_for) after
+            which a query over this connection is cancelled; None disables it.
+        """
+        self.query_timeout = timeout
 
 
     async def scalar(self, sql: sa.sql.base.Executable,
 
 
     async def scalar(self, sql: sa.sql.base.Executable,
@@ -42,7 +51,7 @@ class SearchConnection:
         """ Execute a 'scalar()' query on the connection.
         """
         log().sql(self.connection, sql, params)
         """ Execute a 'scalar()' query on the connection.
         """
         log().sql(self.connection, sql, params)
-        return await self.connection.scalar(sql, params)
+        return await asyncio.wait_for(self.connection.scalar(sql, params), self.query_timeout)
 
 
     async def execute(self, sql: 'sa.Executable',
 
 
     async def execute(self, sql: 'sa.Executable',
@@ -51,7 +60,7 @@ class SearchConnection:
         """ Execute a 'execute()' query on the connection.
         """
         log().sql(self.connection, sql, params)
         """ Execute a 'execute()' query on the connection.
         """
         log().sql(self.connection, sql, params)
-        return await self.connection.execute(sql, params)
+        return await asyncio.wait_for(self.connection.execute(sql, params), self.query_timeout)
 
 
     async def get_property(self, name: str, cached: bool = True) -> str:
 
 
     async def get_property(self, name: str, cached: bool = True) -> str: