git.openstreetmap.org Git - nominatim.git/blobdiff - nominatim/indexer/indexer.py
CI: run BDD tests with legacy_icu tokenizer
[nominatim.git] / nominatim / indexer / indexer.py
index fa40334b7851b617f6069c0932c3f8c6b0a310d8..5ab0eac3dca5701562085ffecdce652e110f9cfd 100644 (file)
@@ -2,47 +2,93 @@
 Main work horse for indexing (computing addresses) the database.
 """
 import logging
 Main work horse for indexing (computing addresses) the database.
 """
 import logging
-import select
+import time
 
 
-import psycopg2
+import psycopg2.extras
 
 from nominatim.indexer.progress import ProgressLogger
 from nominatim.indexer import runners
 
 from nominatim.indexer.progress import ProgressLogger
 from nominatim.indexer import runners
-from nominatim.db.async_connection import DBConnection
+from nominatim.db.async_connection import DBConnection, WorkerPool
+from nominatim.db.connection import connect
 
 LOG = logging.getLogger()
 
 
 
 LOG = logging.getLogger()
 
 
-def _analyse_db_if(conn, condition):
-    if condition:
-        with conn.cursor() as cur:
-            cur.execute('ANALYSE')
+class PlaceFetcher:
+    """ Asynchronous connection that fetches place details for processing.
+    """
+    def __init__(self, dsn, setup_conn):
+        self.wait_time = 0
+        self.current_ids = None
+        self.conn = DBConnection(dsn, cursor_factory=psycopg2.extras.DictCursor)
+
+        with setup_conn.cursor() as cur:
+            # need to fetch those manually because register_hstore cannot
+            # fetch them on an asynchronous connection below.
+            hstore_oid = cur.scalar("SELECT 'hstore'::regtype::oid")
+            hstore_array_oid = cur.scalar("SELECT 'hstore[]'::regtype::oid")
+
+        psycopg2.extras.register_hstore(self.conn.conn, oid=hstore_oid,
+                                        array_oid=hstore_array_oid)
+
+    def close(self):
+        """ Close the underlying asynchronous connection.
+        """
+        if self.conn:
+            self.conn.close()
+            self.conn = None
 
 
 
 
-class Indexer:
-    """ Main indexing routine.
-    """
+    def fetch_next_batch(self, cur, runner):
+        """ Send a request for the next batch of places.
+            If details for the places are required, they will be fetched
+            asynchronously.
 
 
-    def __init__(self, dsn, num_threads):
-        self.dsn = dsn
-        self.num_threads = num_threads
-        self.conn = None
-        self.threads = []
+            Returns true if there is still data available.
+        """
+        ids = cur.fetchmany(100)
 
 
+        if not ids:
+            self.current_ids = None
+            return False
 
 
-    def _setup_connections(self):
-        self.conn = psycopg2.connect(self.dsn)
-        self.threads = [DBConnection(self.dsn) for _ in range(self.num_threads)]
+        if hasattr(runner, 'get_place_details'):
+            runner.get_place_details(self.conn, ids)
+            self.current_ids = []
+        else:
+            self.current_ids = ids
 
 
+        return True
 
 
-    def _close_connections(self):
-        if self.conn:
-            self.conn.close()
-            self.conn = None
+    def get_batch(self):
+        """ Get the next batch of data, previously requested with
+            `fetch_next_batch`.
+        """
+        if self.current_ids is not None and not self.current_ids:
+            tstart = time.time()
+            self.conn.wait()
+            self.wait_time += time.time() - tstart
+            self.current_ids = self.conn.cursor.fetchall()
+
+        return self.current_ids
+
+    def __enter__(self):
+        return self
+
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.conn.wait()
+        self.close()
 
 
-        for thread in self.threads:
-            thread.close()
-        self.threads = []
+
+class Indexer:
+    """ Main indexing routine.
+    """
+
+    def __init__(self, dsn, tokenizer, num_threads):
+        self.dsn = dsn
+        self.tokenizer = tokenizer
+        self.num_threads = num_threads
 
 
     def index_full(self, analyse=True):
 
 
     def index_full(self, analyse=True):
@@ -51,26 +97,31 @@ class Indexer:
             database will be analysed at the appropriate places to
             ensure that database statistics are updated.
         """
             database will be analysed at the appropriate places to
             ensure that database statistics are updated.
         """
-        conn = psycopg2.connect(self.dsn)
-        conn.autocommit = True
+        with connect(self.dsn) as conn:
+            conn.autocommit = True
+
+            if analyse:
+                def _analyze():
+                    with conn.cursor() as cur:
+                        cur.execute('ANALYZE')
+            else:
+                def _analyze():
+                    pass
 
 
-        try:
             self.index_by_rank(0, 4)
             self.index_by_rank(0, 4)
-            _analyse_db_if(conn, analyse)
+            _analyze()
 
             self.index_boundaries(0, 30)
 
             self.index_boundaries(0, 30)
-            _analyse_db_if(conn, analyse)
+            _analyze()
 
             self.index_by_rank(5, 25)
 
             self.index_by_rank(5, 25)
-            _analyse_db_if(conn, analyse)
+            _analyze()
 
             self.index_by_rank(26, 30)
 
             self.index_by_rank(26, 30)
-            _analyse_db_if(conn, analyse)
+            _analyze()
 
             self.index_postcodes()
 
             self.index_postcodes()
-            _analyse_db_if(conn, analyse)
-        finally:
-            conn.close()
+            _analyze()
 
 
     def index_boundaries(self, minrank, maxrank):
 
 
     def index_boundaries(self, minrank, maxrank):
@@ -79,13 +130,9 @@ class Indexer:
         LOG.warning("Starting indexing boundaries using %s threads",
                     self.num_threads)
 
         LOG.warning("Starting indexing boundaries using %s threads",
                     self.num_threads)
 
-        self._setup_connections()
-
-        try:
+        with self.tokenizer.name_analyzer() as analyzer:
             for rank in range(max(minrank, 4), min(maxrank, 26)):
             for rank in range(max(minrank, 4), min(maxrank, 26)):
-                self.index(runners.BoundaryRunner(rank))
-        finally:
-            self._close_connections()
+                self._index(runners.BoundaryRunner(rank, analyzer))
 
     def index_by_rank(self, minrank, maxrank):
         """ Index all entries of placex in the given rank range (inclusive)
 
     def index_by_rank(self, minrank, maxrank):
         """ Index all entries of placex in the given rank range (inclusive)
@@ -98,20 +145,16 @@ class Indexer:
         LOG.warning("Starting indexing rank (%i to %i) using %i threads",
                     minrank, maxrank, self.num_threads)
 
         LOG.warning("Starting indexing rank (%i to %i) using %i threads",
                     minrank, maxrank, self.num_threads)
 
-        self._setup_connections()
-
-        try:
+        with self.tokenizer.name_analyzer() as analyzer:
             for rank in range(max(1, minrank), maxrank):
             for rank in range(max(1, minrank), maxrank):
-                self.index(runners.RankRunner(rank))
+                self._index(runners.RankRunner(rank, analyzer))
 
             if maxrank == 30:
 
             if maxrank == 30:
-                self.index(runners.RankRunner(0))
-                self.index(runners.InterpolationRunner(), 20)
-                self.index(runners.RankRunner(30), 20)
+                self._index(runners.RankRunner(0, analyzer))
+                self._index(runners.InterpolationRunner(analyzer), 20)
+                self._index(runners.RankRunner(30, analyzer), 20)
             else:
             else:
-                self.index(runners.RankRunner(maxrank))
-        finally:
-            self._close_connections()
+                self._index(runners.RankRunner(maxrank, analyzer))
 
 
     def index_postcodes(self):
 
 
     def index_postcodes(self):
@@ -119,89 +162,58 @@ class Indexer:
         """
         LOG.warning("Starting indexing postcodes using %s threads", self.num_threads)
 
         """
         LOG.warning("Starting indexing postcodes using %s threads", self.num_threads)
 
-        self._setup_connections()
+        self._index(runners.PostcodeRunner(), 20)
 
 
-        try:
-            self.index(runners.PostcodeRunner(), 20)
-        finally:
-            self._close_connections()
 
     def update_status_table(self):
         """ Update the status in the status table to 'indexed'.
         """
 
     def update_status_table(self):
         """ Update the status in the status table to 'indexed'.
         """
-        conn = psycopg2.connect(self.dsn)
-
-        try:
+        with connect(self.dsn) as conn:
             with conn.cursor() as cur:
                 cur.execute('UPDATE import_status SET indexed = true')
 
             conn.commit()
             with conn.cursor() as cur:
                 cur.execute('UPDATE import_status SET indexed = true')
 
             conn.commit()
-        finally:
-            conn.close()
 
 
-    def index(self, obj, batch=1):
-        """ Index a single rank or table. `obj` describes the SQL to use
+    def _index(self, runner, batch=1):
+        """ Index a single rank or table. `runner` describes the SQL to use
             for indexing. `batch` describes the number of objects that
             should be processed with a single SQL statement
         """
             for indexing. `batch` describes the number of objects that
             should be processed with a single SQL statement
         """
-        LOG.warning("Starting %s (using batch size %s)", obj.name(), batch)
-
-        cur = self.conn.cursor()
-        cur.execute(obj.sql_count_objects())
+        LOG.warning("Starting %s (using batch size %s)", runner.name(), batch)
 
 
-        total_tuples = cur.fetchone()[0]
-        LOG.debug("Total number of rows: %i", total_tuples)
+        with connect(self.dsn) as conn:
+            psycopg2.extras.register_hstore(conn)
+            with conn.cursor() as cur:
+                total_tuples = cur.scalar(runner.sql_count_objects())
+                LOG.debug("Total number of rows: %i", total_tuples)
 
 
-        cur.close()
+            conn.commit()
 
 
-        progress = ProgressLogger(obj.name(), total_tuples)
+            progress = ProgressLogger(runner.name(), total_tuples)
 
 
-        if total_tuples > 0:
-            cur = self.conn.cursor(name='places')
-            cur.execute(obj.sql_get_objects())
+            if total_tuples > 0:
+                with conn.cursor(name='places') as cur:
+                    cur.execute(runner.sql_get_objects())
 
 
-            next_thread = self.find_free_thread()
-            while True:
-                places = [p[0] for p in cur.fetchmany(batch)]
-                if not places:
-                    break
+                    with PlaceFetcher(self.dsn, conn) as fetcher:
+                        with WorkerPool(self.dsn, self.num_threads) as pool:
+                            has_more = fetcher.fetch_next_batch(cur, runner)
+                            while has_more:
+                                places = fetcher.get_batch()
 
 
-                LOG.debug("Processing places: %s", str(places))
-                thread = next(next_thread)
+                                # asynchronously get the next batch
+                                has_more = fetcher.fetch_next_batch(cur, runner)
 
 
-                thread.perform(obj.sql_index_place(places))
-                progress.add(len(places))
+                                # And insert the current batch
+                                for idx in range(0, len(places), batch):
+                                    part = places[idx:idx+batch]
+                                    LOG.debug("Processing places: %s", str(part))
+                                    runner.index_places(pool.next_free_worker(), part)
+                                    progress.add(len(part))
 
 
-            cur.close()
+                            LOG.info("Wait time: fetcher: %.2fs,  pool: %.2fs",
+                                     fetcher.wait_time, pool.wait_time)
 
 
-            for thread in self.threads:
-                thread.wait()
+                conn.commit()
 
         progress.done()
 
         progress.done()
-
-    def find_free_thread(self):
-        """ Generator that returns the next connection that is free for
-            sending a query.
-        """
-        ready = self.threads
-        command_stat = 0
-
-        while True:
-            for thread in ready:
-                if thread.is_done():
-                    command_stat += 1
-                    yield thread
-
-            # refresh the connections occasionally to avoid potential
-            # memory leaks in PostgreSQL.
-            if command_stat > 100000:
-                for thread in self.threads:
-                    while not thread.is_done():
-                        thread.wait()
-                    thread.connect()
-                command_stat = 0
-                ready = self.threads
-            else:
-                ready, _, _ = select.select(self.threads, [], [])
-
-        assert False, "Unreachable code"