from datetime import datetime
import psycopg2
from psycopg2.extras import wait_select
-import select
+import threading
+from queue import Queue
log = logging.getLogger()
password=options.password, host=options.host,
port=options.port, async_=asynchronous)
class IndexingThread(threading.Thread):
    """Worker thread that runs indexing statements on its own connection.

    Every worker opens a separate database connection and reads its work
    from a queue that is shared with the coordinating thread. Items taken
    from the queue are interpreted as follows:

    * ``None`` -- leave the work loop and terminate the thread.
    * a ``str`` -- remember it as the SQL statement for the following
      work items, then wait on the shared barrier.
    * anything else -- a single parameter value; the remembered SQL
      statement is executed with it.
    """

    def __init__(self, queue, barrier, options):
        """Open the connection and run the per-connection setup statements.

        :param queue: shared queue the worker takes its items from.
        :param barrier: barrier the worker waits on after every string item.
        :param options: passed unchanged to ``make_connection``.
        """
        super().__init__()
        self.conn = make_connection(options)
        # Autocommit keeps each statement out of a long-lived transaction,
        # so perform() can simply re-execute a statement that failed.
        self.conn.autocommit = True
        self.cursor = self.conn.cursor()
        self.perform("SET lc_messages TO 'C'")
        self.perform(InterpolationRunner.prepare())
        self.perform(RankRunner.prepare())
        self.queue = queue
        self.barrier = barrier

    def run(self):
        """Process queue items until the ``None`` sentinel is received.

        If processing fails, the shared barrier is aborted before the
        exception propagates: a worker that dies never reaches the barrier,
        and without the abort every other participant would block in
        ``barrier.wait()`` forever. The connection is closed when the loop
        ends, whether normally or through an error.
        """
        sql = None
        try:
            while True:
                item = self.queue.get()
                if item is None:
                    break
                if isinstance(item, str):
                    # A string switches the active statement; the barrier
                    # lines this worker up with the other participants.
                    sql = item
                    self.barrier.wait()
                else:
                    self.perform(sql, (item,))
        except Exception:
            # Wake up everybody blocked on the barrier (they receive
            # BrokenBarrierError) instead of leaving them hanging.
            self.barrier.abort()
            raise
        finally:
            self.conn.close()

    def perform(self, sql, args=None):
        """Execute ``sql`` with ``args``, retrying while a deadlock is reported.

        :param sql: statement handed to ``cursor.execute``.
        :param args: optional parameters for the statement.
        :raises RuntimeError: when a transaction rollback error carries no
            error code. Rollback errors with any code other than ``40P01``
            are re-raised unchanged.
        """
        while True:
            try:
                self.cursor.execute(sql, args)
                return
            except psycopg2.extensions.TransactionRollbackError as e:
                if e.pgcode is None:
                    raise RuntimeError(
                        "Postgres exception has no error code") from e
                if e.pgcode != '40P01':
                    raise
                # 40P01 is the deadlock error code: run the statement again.
                log.info("Deadlock detected, retry.")
self.conn = make_connection(options)
self.threads = []
- self.poll = select.poll()
+ self.queue = Queue(maxsize=1000)
+ self.barrier = threading.Barrier(options.threads + 1)
for i in range(options.threads):
- t = IndexingThread(i, options)
+ t = IndexingThread(self.queue, self.barrier, options)
self.threads.append(t)
- self.poll.register(t, select.EPOLLIN)
+ t.start()
def run(self):
log.info("Starting indexing rank ({} to {}) using {} threads".format(
self.index(InterpolationRunner())
self.index(RankRunner(30))
+ self.queue_all(None)
+ for t in self.threads:
+ t.join()
+
def queue_all(self, item):
    """Put one copy of ``item`` on the shared queue for every worker thread.

    The thread objects themselves are not used; they only determine how
    many copies are enqueued.
    """
    for _ in self.threads:
        self.queue.put(item)
+
def index(self, obj):
log.info("Starting {}".format(obj.name()))
+ self.queue_all(obj.sql_index_place())
+ self.barrier.wait()
+
cur = self.conn.cursor(name="main")
cur.execute(obj.sql_index_sectors())
cur.scroll(0, mode='absolute')
- next_thread = self.find_free_thread()
done_tuples = 0
rank_start_time = datetime.now()
for r in cur:
for place in pcur:
place_id = place[0]
log.debug("Processing place {}".format(place_id))
- thread = next(next_thread)
- thread.perform(obj.sql_index_place(), (place_id,))
+ self.queue.put(place_id)
done_tuples += 1
pcur.close()
cur.close()
- for t in self.threads:
- t.wait()
+ self.queue_all("")
+ self.barrier.wait()
rank_end_time = datetime.now()
diff_seconds = (rank_end_time-rank_start_time).total_seconds()
done_tuples, int(diff_seconds),
done_tuples/diff_seconds, obj.name()))
- def find_free_thread(self):
- thread_lookup = { t.fileno() : t for t in self.threads}
-
- done_fids = [ t.fileno() for t in self.threads ]
-
- while True:
- for fid in done_fids:
- thread = thread_lookup[fid]
- if thread.is_done():
- yield thread
- else:
- print("not good", fid)
-
- done_fids = [ x[0] for x in self.poll.poll()]
-
- assert(False, "Unreachable code")
class RankRunner(object):