]> git.openstreetmap.org Git - nominatim.git/blob - src/nominatim_db/db/query_pool.py
Merge pull request #3582 from lonvia/switch-to-flake
[nominatim.git] / src / nominatim_db / db / query_pool.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2024 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 A connection pool that executes incoming queries in parallel.
9 """
10 from typing import Any, Tuple, Optional
11 import asyncio
12 import logging
13 import time
14
15 import psycopg
16
17 LOG = logging.getLogger()
18
19 QueueItem = Optional[Tuple[psycopg.abc.Query, Any]]
20
21
22 class QueryPool:
23     """ Pool to run SQL queries in parallel asynchronous execution.
24
25         All queries are run in autocommit mode. If parallel execution leads
26         to a deadlock, then the query is repeated.
27         The results of the queries is discarded.
28     """
29     def __init__(self, dsn: str, pool_size: int = 1, **conn_args: Any) -> None:
30         self.wait_time = 0.0
31         self.query_queue: 'asyncio.Queue[QueueItem]' = asyncio.Queue(maxsize=2 * pool_size)
32
33         self.pool = [asyncio.create_task(self._worker_loop(dsn, **conn_args))
34                      for _ in range(pool_size)]
35
36     async def put_query(self, query: psycopg.abc.Query, params: Any) -> None:
37         """ Schedule a query for execution.
38         """
39         tstart = time.time()
40         await self.query_queue.put((query, params))
41         self.wait_time += time.time() - tstart
42         await asyncio.sleep(0)
43
44     async def finish(self) -> None:
45         """ Wait for all queries to finish and close the pool.
46         """
47         for _ in self.pool:
48             await self.query_queue.put(None)
49
50         tstart = time.time()
51         await asyncio.wait(self.pool)
52         self.wait_time += time.time() - tstart
53
54         for task in self.pool:
55             excp = task.exception()
56             if excp is not None:
57                 raise excp
58
59     async def _worker_loop(self, dsn: str, **conn_args: Any) -> None:
60         conn_args['autocommit'] = True
61         aconn = await psycopg.AsyncConnection.connect(dsn, **conn_args)
62         async with aconn:
63             async with aconn.cursor() as cur:
64                 item = await self.query_queue.get()
65                 while item is not None:
66                     try:
67                         if item[1] is None:
68                             await cur.execute(item[0])
69                         else:
70                             await cur.execute(item[0], item[1])
71
72                         item = await self.query_queue.get()
73                     except psycopg.errors.DeadlockDetected:
74                         assert item is not None
75                         LOG.info("Deadlock detected (sql = %s, params = %s), retry.",
76                                  str(item[0]), str(item[1]))
77                         # item is still valid here, causing a retry
78
79     async def __aenter__(self) -> 'QueryPool':
80         return self
81
82     async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
83         await self.finish()