X-Git-Url: https://git.openstreetmap.org./nominatim.git/blobdiff_plain/1801db523b7d125a408cf041d24cb41d9ef72632..8c7d285e03a212479fdfab91a397ff4b1eff26c6:/nominatim/nominatim.py diff --git a/nominatim/nominatim.py b/nominatim/nominatim.py index 54d9b208..e8600ca8 100755 --- a/nominatim/nominatim.py +++ b/nominatim/nominatim.py @@ -32,12 +32,19 @@ import psycopg2 from psycopg2.extras import wait_select import select +from indexer.progress import ProgressLogger + log = logging.getLogger() def make_connection(options, asynchronous=False): - return psycopg2.connect(dbname=options.dbname, user=options.user, - password=options.password, host=options.host, - port=options.port, async_=asynchronous) + params = {'dbname' : options.dbname, + 'user' : options.user, + 'password' : options.password, + 'host' : options.host, + 'port' : options.port, + 'async' : asynchronous} + + return psycopg2.connect(**params) class RankRunner(object): @@ -50,24 +57,19 @@ class RankRunner(object): def name(self): return "rank {}".format(self.rank) - def sql_index_sectors(self): - return """SELECT geometry_sector, count(*) FROM placex + def sql_count_objects(self): + return """SELECT count(*) FROM placex WHERE rank_search = {} and indexed_status > 0 - GROUP BY geometry_sector - ORDER BY geometry_sector""".format(self.rank) + """.format(self.rank) - def sql_nosector_places(self): + def sql_get_objects(self): return """SELECT place_id FROM placex WHERE indexed_status > 0 and rank_search = {} ORDER BY geometry_sector""".format(self.rank) - def sql_sector_places(self): - return """SELECT place_id FROM placex - WHERE indexed_status > 0 and rank_search = {} - and geometry_sector = %s""".format(self.rank) - - def sql_index_place(self): - return "UPDATE placex SET indexed_status = 0 WHERE place_id = %s" + def sql_index_place(self, ids): + return "UPDATE placex SET indexed_status = 0 WHERE place_id IN ({})"\ + .format(','.join((str(i) for i in ids))) class InterpolationRunner(object): @@ -78,25 +80,19 @@ class InterpolationRunner(object): def name(self): return "interpolation lines (location_property_osmline)" - def sql_index_sectors(self): - return """SELECT geometry_sector, count(*) FROM location_property_osmline - WHERE indexed_status > 0 - GROUP BY geometry_sector - ORDER BY geometry_sector""" + def sql_count_objects(self): + return """SELECT count(*) FROM location_property_osmline + WHERE indexed_status > 0""" - def sql_nosector_places(self): + def sql_get_objects(self): return """SELECT place_id FROM location_property_osmline WHERE indexed_status > 0 ORDER BY geometry_sector""" - def sql_sector_places(self): - return """SELECT place_id FROM location_property_osmline - WHERE indexed_status > 0 and geometry_sector = %s - ORDER BY geometry_sector""" - - def sql_index_place(self): + def sql_index_place(self, ids): return """UPDATE location_property_osmline - SET indexed_status = 0 WHERE place_id = %s""" + SET indexed_status = 0 WHERE place_id IN ({})"""\ + .format(','.join((str(i) for i in ids))) class DBConnection(object): @@ -104,19 +100,48 @@ class DBConnection(object): """ def __init__(self, options): + self.current_query = None + self.current_params = None + + self.conn = None + self.connect() + + def connect(self): + if self.conn is not None: + self.cursor.close() + self.conn.close() + self.conn = make_connection(options, asynchronous=True) self.wait() self.cursor = self.conn.cursor() - - self.current_query = None - self.current_params = None + # Disable JIT and parallel workers as they are known to cause problems. + # Update pg_settings instead of using SET because it does not yield + # errors on older versions of Postgres where the settings are not + # implemented. + self.perform( + """ UPDATE pg_settings SET setting = -1 WHERE name = 'jit_above_cost'; + UPDATE pg_settings SET setting = 0 + WHERE name = 'max_parallel_workers_per_gather';""") + self.wait() def wait(self): """ Block until any pending operation is done. """ - wait_select(self.conn) - self.current_query = None + while True: + try: + wait_select(self.conn) + self.current_query = None + return + except psycopg2.extensions.TransactionRollbackError as e: + if e.pgcode == '40P01': + log.info("Deadlock detected (params = {}), retry." + .format(self.current_params)) + self.cursor.execute(self.current_query, self.current_params) + else: + raise + except psycopg2.errors.DeadlockDetected: + self.cursor.execute(self.current_query, self.current_params) def perform(self, sql, args=None): """ Send SQL query to the server. Returns immediately without @@ -150,6 +175,8 @@ class DBConnection(object): self.cursor.execute(self.current_query, self.current_params) else: raise + except psycopg2.errors.DeadlockDetected: + self.cursor.execute(self.current_query, self.current_params) return False @@ -174,98 +201,75 @@ class Indexer(object): self.index(RankRunner(rank)) if self.maxrank == 30: - self.index(InterpolationRunner()) + self.index(InterpolationRunner(), 20) - self.index(RankRunner(self.maxrank)) + self.index(RankRunner(self.maxrank), 20) - def index(self, obj): + def index(self, obj, batch=1): """ Index a single rank or table. `obj` describes the SQL to use - for indexing. + for indexing. `batch` describes the number of objects that + should be processed with a single SQL statement """ log.warning("Starting {}".format(obj.name())) - cur = self.conn.cursor(name='main') - cur.execute(obj.sql_index_sectors()) + cur = self.conn.cursor() + cur.execute(obj.sql_count_objects()) - total_tuples = 0 - for r in cur: - total_tuples += r[1] - log.debug("Total number of rows; {}".format(total_tuples)) + total_tuples = cur.fetchone()[0] + log.debug("Total number of rows: {}".format(total_tuples)) - cur.scroll(0, mode='absolute') + cur.close() next_thread = self.find_free_thread() - done_tuples = 0 - rank_start_time = datetime.now() - - sector_sql = obj.sql_sector_places() - index_sql = obj.sql_index_place() - min_grouped_tuples = total_tuples - len(self.threads) * 1000 - - next_info = 100 if log.isEnabledFor(logging.INFO) else total_tuples + 1 - - for r in cur: - sector = r[0] - - # Should we do the remaining ones together? - do_all = done_tuples > min_grouped_tuples + progress = ProgressLogger(obj.name(), total_tuples) - pcur = self.conn.cursor(name='places') + cur = self.conn.cursor(name='places') + cur.execute(obj.sql_get_objects()) - if do_all: - pcur.execute(obj.sql_nosector_places()) - else: - pcur.execute(sector_sql, (sector, )) - - for place in pcur: - place_id = place[0] - log.debug("Processing place {}".format(place_id)) - thread = next(next_thread) - - thread.perform(index_sql, (place_id,)) - done_tuples += 1 - - if done_tuples >= next_info: - now = datetime.now() - done_time = (now - rank_start_time).total_seconds() - tuples_per_sec = done_tuples / done_time - log.info("Done {} in {} @ {:.3f} per second - {} ETA (seconds): {:.2f}" - .format(done_tuples, int(done_time), - tuples_per_sec, obj.name(), - (total_tuples - done_tuples)/tuples_per_sec)) - next_info += int(tuples_per_sec) + while True: + places = [p[0] for p in cur.fetchmany(batch)] + if len(places) == 0: + break - pcur.close() + log.debug("Processing places: {}".format(places)) + thread = next(next_thread) - if do_all: - break + thread.perform(obj.sql_index_place(places)) + progress.add(len(places)) cur.close() for t in self.threads: t.wait() - rank_end_time = datetime.now() - diff_seconds = (rank_end_time-rank_start_time).total_seconds() - - log.warning("Done {}/{} in {} @ {:.3f} per second - FINISHED {}\n".format( - done_tuples, total_tuples, int(diff_seconds), - done_tuples/diff_seconds, obj.name())) + progress.done() def find_free_thread(self): """ Generator that returns the next connection that is free for sending a query. """ ready = self.threads + command_stat = 0 while True: for thread in ready: if thread.is_done(): + command_stat += 1 yield thread - ready, _, _ = select.select(self.threads, [], []) + # refresh the connections occasionaly to avoid potential + # memory leaks in Postgresql. + if command_stat > 100000: + for t in self.threads: + while not t.is_done(): + t.wait() + t.connect() + command_stat = 0 + ready = self.threads + else: + ready, _, _ = select.select(self.threads, [], []) - assert(False, "Unreachable code") + assert False, "Unreachable code" def nominatim_arg_parser():