]> git.openstreetmap.org Git - nominatim.git/blobdiff - nominatim/indexer/indexer.py
fetch place info asynchronously
[nominatim.git] / nominatim / indexer / indexer.py
index 93723844d8d63582c982a4f9846058af7b512b8e..d685e83a1546097366faa41fe8bc580ad5b4e660 100644 (file)
 """
 Main work horse for indexing (computing addresses) the database.
 """
 """
 Main work horse for indexing (computing addresses) the database.
 """
-# pylint: disable=C0111
 import logging
 import select
 import logging
 import select
+import time
 
 
-import psycopg2
+import psycopg2.extras
 
 
-from .progress import ProgressLogger
-from ..db.async_connection import DBConnection
+from nominatim.indexer.progress import ProgressLogger
+from nominatim.indexer import runners
+from nominatim.db.async_connection import DBConnection
+from nominatim.db.connection import connect
 
 LOG = logging.getLogger()
 
 
 LOG = logging.getLogger()
 
-class RankRunner:
-    """ Returns SQL commands for indexing one rank within the placex table.
-    """
-
-    def __init__(self, rank):
-        self.rank = rank
-
-    def name(self):
-        return "rank {}".format(self.rank)
-
-    def sql_count_objects(self):
-        return """SELECT count(*) FROM placex
-                  WHERE rank_address = {} and indexed_status > 0
-               """.format(self.rank)
-
-    def sql_get_objects(self):
-        return """SELECT place_id FROM placex
-                  WHERE indexed_status > 0 and rank_address = {}
-                  ORDER BY geometry_sector""".format(self.rank)
+class WorkerPool:
+    """ A pool of asynchronous database connections.
 
 
-    @staticmethod
-    def sql_index_place(ids):
-        return "UPDATE placex SET indexed_status = 0 WHERE place_id IN ({})"\
-               .format(','.join((str(i) for i in ids)))
-
-
-class InterpolationRunner:
-    """ Returns SQL commands for indexing the address interpolation table
-        location_property_osmline.
+        The pool may be used as a context manager.
     """
     """
+    REOPEN_CONNECTIONS_AFTER = 100000
 
 
-    @staticmethod
-    def name():
-        return "interpolation lines (location_property_osmline)"
-
-    @staticmethod
-    def sql_count_objects():
-        return """SELECT count(*) FROM location_property_osmline
-                  WHERE indexed_status > 0"""
-
-    @staticmethod
-    def sql_get_objects():
-        return """SELECT place_id FROM location_property_osmline
-                  WHERE indexed_status > 0
-                  ORDER BY geometry_sector"""
-
-    @staticmethod
-    def sql_index_place(ids):
-        return """UPDATE location_property_osmline
-                  SET indexed_status = 0 WHERE place_id IN ({})
-               """.format(','.join((str(i) for i in ids)))
-
-class BoundaryRunner:
-    """ Returns SQL commands for indexing the administrative boundaries
-        of a certain rank.
-    """
+    def __init__(self, dsn, pool_size):
+        self.threads = [DBConnection(dsn) for _ in range(pool_size)]
+        self.free_workers = self._yield_free_worker()
 
 
-    def __init__(self, rank):
-        self.rank = rank
 
 
-    def name(self):
-        return "boundaries rank {}".format(self.rank)
+    def finish_all(self):
+        """ Wait for all connection to finish.
+        """
+        for thread in self.threads:
+            while not thread.is_done():
+                thread.wait()
 
 
-    def sql_count_objects(self):
-        return """SELECT count(*) FROM placex
-                  WHERE indexed_status > 0
-                    AND rank_search = {}
-                    AND class = 'boundary' and type = 'administrative'
-               """.format(self.rank)
+        self.free_workers = self._yield_free_worker()
 
 
-    def sql_get_objects(self):
-        return """SELECT place_id FROM placex
-                  WHERE indexed_status > 0 and rank_search = {}
-                        and class = 'boundary' and type = 'administrative'
-                  ORDER BY partition, admin_level
-               """.format(self.rank)
+    def close(self):
+        """ Close all connections and clear the pool.
+        """
+        for thread in self.threads:
+            thread.close()
+        self.threads = []
+        self.free_workers = None
 
 
-    @staticmethod
-    def sql_index_place(ids):
-        return "UPDATE placex SET indexed_status = 0 WHERE place_id IN ({})"\
-               .format(','.join((str(i) for i in ids)))
 
 
+    def next_free_worker(self):
+        """ Get the next free connection.
+        """
+        return next(self.free_workers)
 
 
-class PostcodeRunner:
-    """ Provides the SQL commands for indexing the location_postcode table.
-    """
 
 
-    @staticmethod
-    def name():
-        return "postcodes (location_postcode)"
+    def _yield_free_worker(self):
+        ready = self.threads
+        command_stat = 0
+        while True:
+            for thread in ready:
+                if thread.is_done():
+                    command_stat += 1
+                    yield thread
 
 
-    @staticmethod
-    def sql_count_objects():
-        return 'SELECT count(*) FROM location_postcode WHERE indexed_status > 0'
+            if command_stat > self.REOPEN_CONNECTIONS_AFTER:
+                for thread in self.threads:
+                    while not thread.is_done():
+                        thread.wait()
+                    thread.connect()
+                ready = self.threads
+                command_stat = 0
+            else:
+                _, ready, _ = select.select([], self.threads, [])
 
 
-    @staticmethod
-    def sql_get_objects():
-        return """SELECT place_id FROM location_postcode
-                  WHERE indexed_status > 0
-                  ORDER BY country_code, postcode"""
 
 
-    @staticmethod
-    def sql_index_place(ids):
-        return """UPDATE location_postcode SET indexed_status = 0
-                  WHERE place_id IN ({})
-               """.format(','.join((str(i) for i in ids)))
+    def __enter__(self):
+        return self
 
 
 
 
-def _analyse_db_if(conn, condition):
-    if condition:
-        with conn.cursor() as cur:
-            cur.execute('ANALYSE')
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.close()
 
 
 class Indexer:
     """ Main indexing routine.
     """
 
 
 
 class Indexer:
     """ Main indexing routine.
     """
 
-    def __init__(self, dsn, num_threads):
+    def __init__(self, dsn, tokenizer, num_threads):
         self.dsn = dsn
         self.dsn = dsn
+        self.tokenizer = tokenizer
         self.num_threads = num_threads
         self.num_threads = num_threads
-        self.conn = None
-        self.threads = []
-
-
-    def _setup_connections(self):
-        self.conn = psycopg2.connect(self.dsn)
-        self.threads = [DBConnection(self.dsn) for _ in range(self.num_threads)]
-
-
-    def _close_connections(self):
-        if self.conn:
-            self.conn.close()
-            self.conn = None
-
-        for thread in self.threads:
-            thread.close()
-        self.threads = []
 
 
     def index_full(self, analyse=True):
 
 
     def index_full(self, analyse=True):
@@ -158,25 +94,31 @@ class Indexer:
             database will be analysed at the appropriate places to
             ensure that database statistics are updated.
         """
             database will be analysed at the appropriate places to
             ensure that database statistics are updated.
         """
-        conn = psycopg2.connect(self.dsn)
+        with connect(self.dsn) as conn:
+            conn.autocommit = True
+
+            if analyse:
+                def _analyze():
+                    with conn.cursor() as cur:
+                        cur.execute('ANALYZE')
+            else:
+                def _analyze():
+                    pass
 
 
-        try:
             self.index_by_rank(0, 4)
             self.index_by_rank(0, 4)
-            _analyse_db_if(conn, analyse)
+            _analyze()
 
             self.index_boundaries(0, 30)
 
             self.index_boundaries(0, 30)
-            _analyse_db_if(conn, analyse)
+            _analyze()
 
             self.index_by_rank(5, 25)
 
             self.index_by_rank(5, 25)
-            _analyse_db_if(conn, analyse)
+            _analyze()
 
             self.index_by_rank(26, 30)
 
             self.index_by_rank(26, 30)
-            _analyse_db_if(conn, analyse)
+            _analyze()
 
             self.index_postcodes()
 
             self.index_postcodes()
-            _analyse_db_if(conn, analyse)
-        finally:
-            conn.close()
+            _analyze()
 
 
     def index_boundaries(self, minrank, maxrank):
 
 
     def index_boundaries(self, minrank, maxrank):
@@ -185,13 +127,9 @@ class Indexer:
         LOG.warning("Starting indexing boundaries using %s threads",
                     self.num_threads)
 
         LOG.warning("Starting indexing boundaries using %s threads",
                     self.num_threads)
 
-        self._setup_connections()
-
-        try:
+        with self.tokenizer.name_analyzer() as analyzer:
             for rank in range(max(minrank, 4), min(maxrank, 26)):
             for rank in range(max(minrank, 4), min(maxrank, 26)):
-                self.index(BoundaryRunner(rank))
-        finally:
-            self._close_connections()
+                self._index(runners.BoundaryRunner(rank, analyzer))
 
     def index_by_rank(self, minrank, maxrank):
         """ Index all entries of placex in the given rank range (inclusive)
 
     def index_by_rank(self, minrank, maxrank):
         """ Index all entries of placex in the given rank range (inclusive)
@@ -204,20 +142,16 @@ class Indexer:
         LOG.warning("Starting indexing rank (%i to %i) using %i threads",
                     minrank, maxrank, self.num_threads)
 
         LOG.warning("Starting indexing rank (%i to %i) using %i threads",
                     minrank, maxrank, self.num_threads)
 
-        self._setup_connections()
-
-        try:
+        with self.tokenizer.name_analyzer() as analyzer:
             for rank in range(max(1, minrank), maxrank):
             for rank in range(max(1, minrank), maxrank):
-                self.index(RankRunner(rank))
+                self._index(runners.RankRunner(rank, analyzer))
 
             if maxrank == 30:
 
             if maxrank == 30:
-                self.index(RankRunner(0))
-                self.index(InterpolationRunner(), 20)
-                self.index(RankRunner(30), 20)
+                self._index(runners.RankRunner(0, analyzer))
+                self._index(runners.InterpolationRunner(analyzer), 20)
+                self._index(runners.RankRunner(30, analyzer), 20)
             else:
             else:
-                self.index(RankRunner(maxrank))
-        finally:
-            self._close_connections()
+                self._index(runners.RankRunner(maxrank, analyzer))
 
 
     def index_postcodes(self):
 
 
     def index_postcodes(self):
@@ -225,89 +159,95 @@ class Indexer:
         """
         LOG.warning("Starting indexing postcodes using %s threads", self.num_threads)
 
         """
         LOG.warning("Starting indexing postcodes using %s threads", self.num_threads)
 
-        self._setup_connections()
+        self._index(runners.PostcodeRunner(), 20)
 
 
-        try:
-            self.index(PostcodeRunner(), 20)
-        finally:
-            self._close_connections()
 
     def update_status_table(self):
         """ Update the status in the status table to 'indexed'.
         """
 
     def update_status_table(self):
         """ Update the status in the status table to 'indexed'.
         """
-        conn = psycopg2.connect(self.dsn)
-
-        try:
+        with connect(self.dsn) as conn:
             with conn.cursor() as cur:
                 cur.execute('UPDATE import_status SET indexed = true')
 
             conn.commit()
             with conn.cursor() as cur:
                 cur.execute('UPDATE import_status SET indexed = true')
 
             conn.commit()
-        finally:
-            conn.close()
 
 
-    def index(self, obj, batch=1):
-        """ Index a single rank or table. `obj` describes the SQL to use
+    def _index(self, runner, batch=1):
+        """ Index a single rank or table. `runner` describes the SQL to use
             for indexing. `batch` describes the number of objects that
             should be processed with a single SQL statement
         """
             for indexing. `batch` describes the number of objects that
             should be processed with a single SQL statement
         """
-        LOG.warning("Starting %s (using batch size %s)", obj.name(), batch)
+        LOG.warning("Starting %s (using batch size %s)", runner.name(), batch)
+
+        with connect(self.dsn) as conn:
+            psycopg2.extras.register_hstore(conn)
+            with conn.cursor() as cur:
+                total_tuples = cur.scalar(runner.sql_count_objects())
+                LOG.debug("Total number of rows: %i", total_tuples)
 
 
-        cur = self.conn.cursor()
-        cur.execute(obj.sql_count_objects())
+                # need to fetch those manually because register_hstore cannot
+                # fetch them on an asynchronous connection below.
+                hstore_oid = cur.scalar("SELECT 'hstore'::regtype::oid")
+                hstore_array_oid = cur.scalar("SELECT 'hstore[]'::regtype::oid")
 
 
-        total_tuples = cur.fetchone()[0]
-        LOG.debug("Total number of rows: %i", total_tuples)
+            conn.commit()
 
 
-        cur.close()
+            progress = ProgressLogger(runner.name(), total_tuples)
 
 
-        progress = ProgressLogger(obj.name(), total_tuples)
+            fetcher_wait = 0
+            pool_wait = 0
 
 
-        if total_tuples > 0:
-            cur = self.conn.cursor(name='places')
-            cur.execute(obj.sql_get_objects())
+            if total_tuples > 0:
+                with conn.cursor(name='places') as cur:
+                    cur.execute(runner.sql_get_objects())
 
 
-            next_thread = self.find_free_thread()
-            while True:
-                places = [p[0] for p in cur.fetchmany(batch)]
-                if not places:
-                    break
+                    fetcher = DBConnection(self.dsn, cursor_factory=psycopg2.extras.DictCursor)
+                    psycopg2.extras.register_hstore(fetcher.conn,
+                                                    oid=hstore_oid,
+                                                    array_oid=hstore_array_oid)
 
 
-                LOG.debug("Processing places: %s", str(places))
-                thread = next(next_thread)
+                    with WorkerPool(self.dsn, self.num_threads) as pool:
+                        places = self._fetch_next_batch(cur, fetcher, runner)
+                        while places is not None:
+                            if not places:
+                                t0 = time.time()
+                                fetcher.wait()
+                                fetcher_wait += time.time() - t0
+                                places = fetcher.cursor.fetchall()
 
 
-                thread.perform(obj.sql_index_place(places))
-                progress.add(len(places))
+                            # asynchronously get the next batch
+                            next_places = self._fetch_next_batch(cur, fetcher, runner)
 
 
-            cur.close()
+                            # And insert the curent batch
+                            for idx in range(0, len(places), batch):
+                                t0 = time.time()
+                                worker = pool.next_free_worker()
+                                pool_wait += time.time() - t0
+                                part = places[idx:idx+batch]
+                                LOG.debug("Processing places: %s", str(part))
+                                runner.index_places(worker, part)
+                                progress.add(len(part))
 
 
-            for thread in self.threads:
-                thread.wait()
+                            places = next_places
+
+                        pool.finish_all()
+
+                    fetcher.wait()
+                    fetcher.close()
+
+                conn.commit()
 
         progress.done()
 
         progress.done()
+        LOG.warning("Wait time: fetcher: {}s,  pool: {}s".format(fetcher_wait, pool_wait))
 
 
-    def find_free_thread(self):
-        """ Generator that returns the next connection that is free for
-            sending a query.
-        """
-        ready = self.threads
-        command_stat = 0
 
 
-        while True:
-            for thread in ready:
-                if thread.is_done():
-                    command_stat += 1
-                    yield thread
+    def _fetch_next_batch(self, cur, fetcher, runner):
+        ids = cur.fetchmany(100)
 
 
-            # refresh the connections occasionaly to avoid potential
-            # memory leaks in Postgresql.
-            if command_stat > 100000:
-                for thread in self.threads:
-                    while not thread.is_done():
-                        thread.wait()
-                    thread.connect()
-                command_stat = 0
-                ready = self.threads
-            else:
-                ready, _, _ = select.select(self.threads, [], [])
+        if not ids:
+            return None
+
+        if not hasattr(runner, 'get_place_details'):
+            return ids
 
 
-        assert False, "Unreachable code"
+        runner.get_place_details(fetcher, ids)
+        return []