from datetime import datetime
import psycopg2
from psycopg2.extras import wait_select
-import threading
-from queue import Queue
+import select
log = logging.getLogger()
password=options.password, host=options.host,
port=options.port, async_=asynchronous)
-class IndexingThread(threading.Thread):
+class IndexingThread(object):
- def __init__(self, queue, barrier, options):
- super().__init__()
- self.conn = make_connection(options)
- self.conn.autocommit = True
+ def __init__(self, thread_num, options):
+ log.debug("Creating thread {}".format(thread_num))
+ self.thread_num = thread_num
+ self.conn = make_connection(options, asynchronous=True)
+ self.wait()
self.cursor = self.conn.cursor()
self.perform("SET lc_messages TO 'C'")
+ self.wait()
self.perform(InterpolationRunner.prepare())
+ self.wait()
self.perform(RankRunner.prepare())
- self.queue = queue
- self.barrier = barrier
+ self.wait()
- def run(self):
- sql = None
- while True:
- item = self.queue.get()
- if item is None:
- break
- elif isinstance(item, str):
- sql = item
- self.barrier.wait()
- else:
- self.perform(sql, (item,))
+ self.current_query = None
+ self.current_params = None
+
+    def wait(self):
+        """Block until the connection has finished its pending work.
+
+        A deadlock (SQLSTATE 40P01) on the pending query is retried by
+        re-sending the same query and waiting again, matching the retry
+        done in is_done(); any other rollback error is propagated.
+        """
+        while True:
+            try:
+                wait_select(self.conn)
+            except psycopg2.extensions.TransactionRollbackError as e:
+                if e.pgcode is None:
+                    raise RuntimeError("Postgres exception has no error code") from e
+                if e.pgcode != '40P01':
+                    raise
+                log.info("Deadlock detected, retry.")
+                # Re-issue the query remembered by perform() and keep waiting.
+                self.cursor.execute(self.current_query, self.current_params)
+            else:
+                self.current_query = None
+                return
def perform(self, sql, args=None):
- while True:
- try:
- self.cursor.execute(sql, args)
- return
- except psycopg2.extensions.TransactionRollbackError as e:
- if e.pgcode is None:
- raise RuntimeError("Postgres exception has no error code")
- if e.pgcode == '40P01':
- log.info("Deadlock detected, retry.")
- else:
- raise
+ self.current_query = sql
+ self.current_params = args
+ self.cursor.execute(sql, args)
+
+    def fileno(self):
+        """Expose the socket descriptor of the underlying connection.
+
+        Providing fileno() lets the worker object itself be registered
+        with a select.poll() object.
+        """
+        connection = self.conn
+        return connection.fileno()
+
+    def is_done(self):
+        """Check without blocking whether this worker can take a new query.
+
+        Returns True when no query is pending or the pending query has
+        completed, and False while it is still running. On a deadlock
+        (SQLSTATE 40P01) the query is re-sent and False is returned, because
+        it is running again. Other rollback errors are propagated.
+        """
+        if self.current_query is None:
+            return True
+
+        try:
+            if self.conn.poll() == psycopg2.extensions.POLL_OK:
+                self.current_query = None
+                return True
+        except psycopg2.extensions.TransactionRollbackError as e:
+            if e.pgcode is None:
+                raise RuntimeError("Postgres exception has no error code") from e
+            if e.pgcode == '40P01':
+                log.info("Deadlock detected, retry.")
+                # Re-issue the query remembered by perform().
+                self.cursor.execute(self.current_query, self.current_params)
+            else:
+                raise
+
+        # Still busy (or just restarted after a deadlock).
+        return False
self.conn = make_connection(options)
self.threads = []
- self.queue = Queue(maxsize=1000)
- self.barrier = threading.Barrier(options.threads + 1)
+ self.poll = select.poll()
for i in range(options.threads):
- t = IndexingThread(self.queue, self.barrier, options)
+ t = IndexingThread(i, options)
self.threads.append(t)
- t.start()
+ self.poll.register(t, select.EPOLLIN)
def run(self):
log.info("Starting indexing rank ({} to {}) using {} threads".format(
self.index(InterpolationRunner())
self.index(RankRunner(30))
- self.queue_all(None)
- for t in self.threads:
- t.join()
-
- def queue_all(self, item):
- for t in self.threads:
- self.queue.put(item)
-
def index(self, obj):
log.info("Starting {}".format(obj.name()))
- self.queue_all(obj.sql_index_place())
- self.barrier.wait()
-
cur = self.conn.cursor(name="main")
cur.execute(obj.sql_index_sectors())
cur.scroll(0, mode='absolute')
+ next_thread = self.find_free_thread()
done_tuples = 0
rank_start_time = datetime.now()
for r in cur:
for place in pcur:
place_id = place[0]
log.debug("Processing place {}".format(place_id))
+ thread = next(next_thread)
- self.queue.put(place_id)
+ thread.perform(obj.sql_index_place(), (place_id,))
done_tuples += 1
pcur.close()
cur.close()
- self.queue_all("")
- self.barrier.wait()
+ for t in self.threads:
+ t.wait()
rank_end_time = datetime.now()
diff_seconds = (rank_end_time-rank_start_time).total_seconds()
done_tuples, int(diff_seconds),
done_tuples/diff_seconds, obj.name()))
+    def find_free_thread(self):
+        """Generate worker threads that are ready to accept a new query.
+
+        Every worker is offered once at the start. After that, the generator
+        blocks on self.poll and yields the workers whose connection has
+        finished its query. It never terminates, so callers take items
+        with next().
+        """
+        thread_lookup = {t.fileno(): t for t in self.threads}
+
+        # Initially all workers are candidates, in the order of self.threads.
+        ready_fids = list(thread_lookup)
+
+        while True:
+            for fid in ready_fids:
+                thread = thread_lookup[fid]
+                if thread.is_done():
+                    yield thread
+                else:
+                    # Readable but not finished (partial result or a deadlock
+                    # retry); it is picked up again by a later poll().
+                    log.debug("Thread with fd {} is not ready yet".format(fid))
+
+            ready_fids = [fid for fid, _ in self.poll.poll()]
class RankRunner(object):