]> git.openstreetmap.org Git - nominatim.git/blobdiff - nominatim/indexer/indexer.py
Vagrant and CI tests for Ubuntu 22.04
[nominatim.git] / nominatim / indexer / indexer.py
index d685e83a1546097366faa41fe8bc580ad5b4e660..555f8704a19c6796da4b97a724cd363d183f7f12 100644 (file)
@@ -1,80 +1,89 @@
+# SPDX-License-Identifier: GPL-2.0-only
+#
+# This file is part of Nominatim. (https://nominatim.org)
+#
+# Copyright (C) 2022 by the Nominatim developer community.
+# For a full list of authors see the git log.
 """
 Main work horse for indexing (computing addresses) the database.
 """
 import logging
 """
 Main work horse for indexing (computing addresses) the database.
 """
 import logging
-import select
 import time
 
 import psycopg2.extras
 
 from nominatim.indexer.progress import ProgressLogger
 from nominatim.indexer import runners
 import time
 
 import psycopg2.extras
 
 from nominatim.indexer.progress import ProgressLogger
 from nominatim.indexer import runners
-from nominatim.db.async_connection import DBConnection
+from nominatim.db.async_connection import DBConnection, WorkerPool
 from nominatim.db.connection import connect
 
 LOG = logging.getLogger()
 
 from nominatim.db.connection import connect
 
 LOG = logging.getLogger()
 
-class WorkerPool:
-    """ A pool of asynchronous database connections.
 
 
-        The pool may be used as a context manager.
+class PlaceFetcher:
+    """ Asynchronous connection that fetches place details for processing.
     """
     """
-    REOPEN_CONNECTIONS_AFTER = 100000
+    def __init__(self, dsn, setup_conn):
+        self.wait_time = 0
+        self.current_ids = None
+        self.conn = DBConnection(dsn, cursor_factory=psycopg2.extras.DictCursor)
 
 
-    def __init__(self, dsn, pool_size):
-        self.threads = [DBConnection(dsn) for _ in range(pool_size)]
-        self.free_workers = self._yield_free_worker()
+        with setup_conn.cursor() as cur:
+            # need to fetch those manually because register_hstore cannot
+            # fetch them on an asynchronous connection below.
+            hstore_oid = cur.scalar("SELECT 'hstore'::regtype::oid")
+            hstore_array_oid = cur.scalar("SELECT 'hstore[]'::regtype::oid")
 
 
-
-    def finish_all(self):
-        """ Wait for all connection to finish.
-        """
-        for thread in self.threads:
-            while not thread.is_done():
-                thread.wait()
-
-        self.free_workers = self._yield_free_worker()
+        psycopg2.extras.register_hstore(self.conn.conn, oid=hstore_oid,
+                                        array_oid=hstore_array_oid)
 
     def close(self):
 
     def close(self):
-        """ Close all connections and clear the pool.
+        """ Close the underlying asynchronous connection.
         """
         """
-        for thread in self.threads:
-            thread.close()
-        self.threads = []
-        self.free_workers = None
+        if self.conn:
+            self.conn.close()
+            self.conn = None
+
 
 
+    def fetch_next_batch(self, cur, runner):
+        """ Send a request for the next batch of places.
+            If details for the places are required, they will be fetched
+            asynchronously.
 
 
-    def next_free_worker(self):
-        """ Get the next free connection.
+            Returns true if there is still data available.
         """
         """
-        return next(self.free_workers)
+        ids = cur.fetchmany(100)
 
 
+        if not ids:
+            self.current_ids = None
+            return False
 
 
-    def _yield_free_worker(self):
-        ready = self.threads
-        command_stat = 0
-        while True:
-            for thread in ready:
-                if thread.is_done():
-                    command_stat += 1
-                    yield thread
+        if hasattr(runner, 'get_place_details'):
+            runner.get_place_details(self.conn, ids)
+            self.current_ids = []
+        else:
+            self.current_ids = ids
 
 
-            if command_stat > self.REOPEN_CONNECTIONS_AFTER:
-                for thread in self.threads:
-                    while not thread.is_done():
-                        thread.wait()
-                    thread.connect()
-                ready = self.threads
-                command_stat = 0
-            else:
-                _, ready, _ = select.select([], self.threads, [])
+        return True
 
 
+    def get_batch(self):
+        """ Get the next batch of data, previously requested with
+            `fetch_next_batch`.
+        """
+        if self.current_ids is not None and not self.current_ids:
+            tstart = time.time()
+            self.conn.wait()
+            self.wait_time += time.time() - tstart
+            self.current_ids = self.conn.cursor.fetchall()
+
+        return self.current_ids
 
     def __enter__(self):
         return self
 
 
     def __exit__(self, exc_type, exc_value, traceback):
 
     def __enter__(self):
         return self
 
 
     def __exit__(self, exc_type, exc_value, traceback):
+        self.conn.wait()
         self.close()
 
 
         self.close()
 
 
@@ -88,8 +97,19 @@ class Indexer:
         self.num_threads = num_threads
 
 
         self.num_threads = num_threads
 
 
+    def has_pending(self):
+        """ Check if any data still needs indexing.
+            This function must only be used after the import has finished.
+            Otherwise it will be very expensive.
+        """
+        with connect(self.dsn) as conn:
+            with conn.cursor() as cur:
+                cur.execute("SELECT 'a' FROM placex WHERE indexed_status > 0 LIMIT 1")
+                return cur.rowcount > 0
+
+
     def index_full(self, analyse=True):
     def index_full(self, analyse=True):
-        """ Index the complete database. This will first index boudnaries
+        """ Index the complete database. This will first index boundaries
             followed by all other objects. When `analyse` is True, then the
             database will be analysed at the appropriate places to
             ensure that database statistics are updated.
             followed by all other objects. When `analyse` is True, then the
             database will be analysed at the appropriate places to
             ensure that database statistics are updated.
@@ -97,13 +117,10 @@ class Indexer:
         with connect(self.dsn) as conn:
             conn.autocommit = True
 
         with connect(self.dsn) as conn:
             conn.autocommit = True
 
-            if analyse:
-                def _analyze():
+            def _analyze():
+                if analyse:
                     with conn.cursor() as cur:
                         cur.execute('ANALYZE')
                     with conn.cursor() as cur:
                         cur.execute('ANALYZE')
-            else:
-                def _analyze():
-                    pass
 
             self.index_by_rank(0, 4)
             _analyze()
 
             self.index_by_rank(0, 4)
             _analyze()
@@ -143,15 +160,12 @@ class Indexer:
                     minrank, maxrank, self.num_threads)
 
         with self.tokenizer.name_analyzer() as analyzer:
                     minrank, maxrank, self.num_threads)
 
         with self.tokenizer.name_analyzer() as analyzer:
-            for rank in range(max(1, minrank), maxrank):
-                self._index(runners.RankRunner(rank, analyzer))
+            for rank in range(max(1, minrank), maxrank + 1):
+                self._index(runners.RankRunner(rank, analyzer), 20 if rank == 30 else 1)
 
             if maxrank == 30:
                 self._index(runners.RankRunner(0, analyzer))
                 self._index(runners.InterpolationRunner(analyzer), 20)
 
             if maxrank == 30:
                 self._index(runners.RankRunner(0, analyzer))
                 self._index(runners.InterpolationRunner(analyzer), 20)
-                self._index(runners.RankRunner(30, analyzer), 20)
-            else:
-                self._index(runners.RankRunner(maxrank, analyzer))
 
 
     def index_postcodes(self):
 
 
     def index_postcodes(self):
@@ -184,70 +198,33 @@ class Indexer:
                 total_tuples = cur.scalar(runner.sql_count_objects())
                 LOG.debug("Total number of rows: %i", total_tuples)
 
                 total_tuples = cur.scalar(runner.sql_count_objects())
                 LOG.debug("Total number of rows: %i", total_tuples)
 
-                # need to fetch those manually because register_hstore cannot
-                # fetch them on an asynchronous connection below.
-                hstore_oid = cur.scalar("SELECT 'hstore'::regtype::oid")
-                hstore_array_oid = cur.scalar("SELECT 'hstore[]'::regtype::oid")
-
             conn.commit()
 
             progress = ProgressLogger(runner.name(), total_tuples)
 
             conn.commit()
 
             progress = ProgressLogger(runner.name(), total_tuples)
 
-            fetcher_wait = 0
-            pool_wait = 0
-
             if total_tuples > 0:
                 with conn.cursor(name='places') as cur:
                     cur.execute(runner.sql_get_objects())
 
             if total_tuples > 0:
                 with conn.cursor(name='places') as cur:
                     cur.execute(runner.sql_get_objects())
 
-                    fetcher = DBConnection(self.dsn, cursor_factory=psycopg2.extras.DictCursor)
-                    psycopg2.extras.register_hstore(fetcher.conn,
-                                                    oid=hstore_oid,
-                                                    array_oid=hstore_array_oid)
-
-                    with WorkerPool(self.dsn, self.num_threads) as pool:
-                        places = self._fetch_next_batch(cur, fetcher, runner)
-                        while places is not None:
-                            if not places:
-                                t0 = time.time()
-                                fetcher.wait()
-                                fetcher_wait += time.time() - t0
-                                places = fetcher.cursor.fetchall()
+                    with PlaceFetcher(self.dsn, conn) as fetcher:
+                        with WorkerPool(self.dsn, self.num_threads) as pool:
+                            has_more = fetcher.fetch_next_batch(cur, runner)
+                            while has_more:
+                                places = fetcher.get_batch()
 
 
-                            # asynchronously get the next batch
-                            next_places = self._fetch_next_batch(cur, fetcher, runner)
+                                # asynchronously get the next batch
+                                has_more = fetcher.fetch_next_batch(cur, runner)
 
 
-                            # And insert the curent batch
-                            for idx in range(0, len(places), batch):
-                                t0 = time.time()
-                                worker = pool.next_free_worker()
-                                pool_wait += time.time() - t0
-                                part = places[idx:idx+batch]
-                                LOG.debug("Processing places: %s", str(part))
-                                runner.index_places(worker, part)
-                                progress.add(len(part))
+                                # And insert the curent batch
+                                for idx in range(0, len(places), batch):
+                                    part = places[idx:idx + batch]
+                                    LOG.debug("Processing places: %s", str(part))
+                                    runner.index_places(pool.next_free_worker(), part)
+                                    progress.add(len(part))
 
 
-                            places = next_places
-
-                        pool.finish_all()
-
-                    fetcher.wait()
-                    fetcher.close()
+                            LOG.info("Wait time: fetcher: %.2fs,  pool: %.2fs",
+                                     fetcher.wait_time, pool.wait_time)
 
                 conn.commit()
 
         progress.done()
 
                 conn.commit()
 
         progress.done()
-        LOG.warning("Wait time: fetcher: {}s,  pool: {}s".format(fetcher_wait, pool_wait))
-
-
-    def _fetch_next_batch(self, cur, fetcher, runner):
-        ids = cur.fetchmany(100)
-
-        if not ids:
-            return None
-
-        if not hasattr(runner, 'get_place_details'):
-            return ids
-
-        runner.get_place_details(fetcher, ids)
-        return []