From 0a26ca7104940447ea98a1edf0e7c443da09d37f Mon Sep 17 00:00:00 2001 From: Sarah Hoffmann Date: Sun, 19 Jan 2020 21:56:20 +0100 Subject: [PATCH] switch to threading --- nominatim/nominatim.py | 112 ++++++++++++++++++----------------------- 1 file changed, 49 insertions(+), 63 deletions(-) diff --git a/nominatim/nominatim.py b/nominatim/nominatim.py index 6b25cf5c..61907060 100644 --- a/nominatim/nominatim.py +++ b/nominatim/nominatim.py @@ -30,7 +30,8 @@ import getpass from datetime import datetime import psycopg2 from psycopg2.extras import wait_select -import select +import threading +from queue import Queue log = logging.getLogger() @@ -39,53 +40,44 @@ def make_connection(options, asynchronous=False): password=options.password, host=options.host, port=options.port, async_=asynchronous) -class IndexingThread(object): +class IndexingThread(threading.Thread): - def __init__(self, thread_num, options): - log.debug("Creating thread {}".format(thread_num)) - self.thread_num = thread_num - self.conn = make_connection(options, asynchronous=True) - self.wait() + def __init__(self, queue, barrier, options): + super().__init__() + self.conn = make_connection(options) + self.conn.autocommit = True self.cursor = self.conn.cursor() self.perform("SET lc_messages TO 'C'") - self.wait() self.perform(InterpolationRunner.prepare()) - self.wait() self.perform(RankRunner.prepare()) - self.wait() - - self.current_query = None - self.current_params = None + self.queue = queue + self.barrier = barrier - def wait(self): - wait_select(self.conn) - self.current_query = None + def run(self): + sql = None + while True: + item = self.queue.get() + if item is None: + break + elif isinstance(item, str): + sql = item + self.barrier.wait() + else: + self.perform(sql, (item,)) def perform(self, sql, args=None): - self.current_query = sql - self.current_params = args - self.cursor.execute(sql, args) - - def fileno(self): - return self.conn.fileno() - - def is_done(self): - if self.current_query is None: - return True - - try: - if self.conn.poll() == psycopg2.extensions.POLL_OK: - self.current_query = None - return True - except psycopg2.extensions.TransactionRollbackError as e: - if e.pgcode is None: - raise RuntimeError("Postgres exception has no error code") - if e.pgcode == '40P01': - log.info("Deadlock detected, retry.") - self.cursor.execute(self.current_query, self.current_params) - else: - raise + while True: + try: + self.cursor.execute(sql, args) + return + except psycopg2.extensions.TransactionRollbackError as e: + if e.pgcode is None: + raise RuntimeError("Postgres exception has no error code") + if e.pgcode == '40P01': + log.info("Deadlock detected, retry.") + else: + raise @@ -96,11 +88,12 @@ class Indexer(object): self.conn = make_connection(options) self.threads = [] - self.poll = select.poll() + self.queue = Queue(maxsize=1000) + self.barrier = threading.Barrier(options.threads + 1) for i in range(options.threads): - t = IndexingThread(i, options) + t = IndexingThread(self.queue, self.barrier, options) self.threads.append(t) - self.poll.register(t, select.EPOLLIN) + t.start() def run(self): log.info("Starting indexing rank ({} to {}) using {} threads".format( @@ -114,9 +107,20 @@ class Indexer(object): self.index(InterpolationRunner()) self.index(RankRunner(30)) + self.queue_all(None) + for t in self.threads: + t.join() + + def queue_all(self, item): + for t in self.threads: + self.queue.put(item) + def index(self, obj): log.info("Starting {}".format(obj.name())) + self.queue_all(obj.sql_index_place()) + self.barrier.wait() + cur = self.conn.cursor(name="main") cur.execute(obj.sql_index_sectors()) @@ -127,7 +131,6 @@ class Indexer(object): cur.scroll(0, mode='absolute') - next_thread = self.find_free_thread() done_tuples = 0 rank_start_time = datetime.now() for r in cur: @@ -146,9 +149,8 @@ class Indexer(object): for place in pcur: place_id = place[0] log.debug("Processing place {}".format(place_id)) - thread = next(next_thread) - thread.perform(obj.sql_index_place(), (place_id,)) + self.queue.put(place_id) done_tuples += 1 pcur.close() @@ -158,8 +160,8 @@ class Indexer(object): cur.close() - for t in self.threads: - t.wait() + self.queue_all("") + self.barrier.wait() rank_end_time = datetime.now() diff_seconds = (rank_end_time-rank_start_time).total_seconds() @@ -168,22 +170,6 @@ class Indexer(object): done_tuples, int(diff_seconds), done_tuples/diff_seconds, obj.name())) - def find_free_thread(self): - thread_lookup = { t.fileno() : t for t in self.threads} - - done_fids = [ t.fileno() for t in self.threads ] - - while True: - for fid in done_fids: - thread = thread_lookup[fid] - if thread.is_done(): - yield thread - else: - print("not good", fid) - - done_fids = [ x[0] for x in self.poll.poll()] - - assert(False, "Unreachable code") class RankRunner(object): -- 2.39.5