+
+
class WorkerPool:
    """ A pool of asynchronous database connections.

        The pool may be used as a context manager. On leaving the context
        the pool waits for all connections to finish and then closes them;
        the connections are closed even when waiting raises an exception.

        Attributes:
            threads: The connections managed by the pool.
            free_workers: Iterator that hands out connections which report
                          being done (see `next_free_worker()`).
            wait_time: Accumulated time in seconds spent blocking in
                       select() while no connection was available.
    """
    # Once more than this many connections have been handed out by the
    # iterator, all connections are reconnected (checked after each full
    # pass over the ready connections).
    REOPEN_CONNECTIONS_AFTER = 100000

    def __init__(self, dsn: str, pool_size: int, ignore_sql_errors: bool = False) -> None:
        """ Create `pool_size` connections for the given DSN.

            `ignore_sql_errors` is handed through unchanged to every
            DBConnection that is created.
        """
        self.threads = [DBConnection(dsn, ignore_sql_errors=ignore_sql_errors)
                        for _ in range(pool_size)]
        self.free_workers = self._yield_free_worker()
        self.wait_time = 0.0


    def finish_all(self) -> None:
        """ Wait for all connections to finish.
        """
        for thread in self.threads:
            while not thread.is_done():
                thread.wait()

        # Start over with a fresh iterator (and a fresh command counter).
        self.free_workers = self._yield_free_worker()


    def close(self) -> None:
        """ Close all connections and clear the pool.
        """
        for thread in self.threads:
            thread.close()
        self.threads = []
        self.free_workers = iter([])


    def next_free_worker(self) -> 'DBConnection':
        """ Get the next free connection.

            Raises StopIteration when the pool has no connections
            (for example after `close()`).
        """
        return next(self.free_workers)


    def _yield_free_worker(self) -> Iterator['DBConnection']:
        """ Generator that endlessly yields connections that are done.

            When none of the candidates is done, block in select() until
            at least one connection is reported as writable and add the
            time spent blocking to `wait_time`.
        """
        if not self.threads:
            # select() on three empty lists without a timeout would block
            # forever. End the iteration instead, which is the same
            # behaviour as for a closed pool.
            return

        ready = self.threads
        command_stat = 0
        while True:
            for thread in ready:
                if thread.is_done():
                    command_stat += 1
                    yield thread

            if command_stat > self.REOPEN_CONNECTIONS_AFTER:
                self._reconnect_threads()
                ready = self.threads
                command_stat = 0
            else:
                # Use the monotonic clock so that wall-clock adjustments
                # cannot distort the measured waiting time.
                tstart = time.monotonic()
                _, ready, _ = select.select([], self.threads, [])
                self.wait_time += time.monotonic() - tstart


    def _reconnect_threads(self) -> None:
        """ Wait for each connection to finish, then call connect() on it.
        """
        for thread in self.threads:
            while not thread.is_done():
                thread.wait()
            thread.connect()


    def __enter__(self) -> 'WorkerPool':
        return self


    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        # Always close the connections, even when waiting for the
        # outstanding work fails. Otherwise the connections would leak.
        try:
            self.finish_all()
        finally:
            self.close()