]> git.openstreetmap.org Git - nominatim.git/blobdiff - nominatim/indexer/indexer.py
add type annotations for indexer
[nominatim.git] / nominatim / indexer / indexer.py
index 2dd8220b1da16723b14a00b3b77301e1c8060815..4f7675309cbaa91068f777a97789b5b2e809c5ac 100644 (file)
@@ -1,77 +1,95 @@
+# SPDX-License-Identifier: GPL-2.0-only
+#
+# This file is part of Nominatim. (https://nominatim.org)
+#
+# Copyright (C) 2022 by the Nominatim developer community.
+# For a full list of authors see the git log.
 """
 Main work horse for indexing (computing addresses) the database.
 """
 """
 Main work horse for indexing (computing addresses) the database.
 """
+from typing import Optional, Any, cast
 import logging
 import logging
-import select
+import time
 
 
+import psycopg2.extras
+
+from nominatim.tokenizer.base import AbstractTokenizer
 from nominatim.indexer.progress import ProgressLogger
 from nominatim.indexer import runners
 from nominatim.indexer.progress import ProgressLogger
 from nominatim.indexer import runners
-from nominatim.db.async_connection import DBConnection
-from nominatim.db.connection import connect
+from nominatim.db.async_connection import DBConnection, WorkerPool
+from nominatim.db.connection import connect, Connection, Cursor
+from nominatim.typing import DictCursorResults
 
 LOG = logging.getLogger()
 
 
 LOG = logging.getLogger()
 
-class WorkerPool:
-    """ A pool of asynchronous database connections.
 
 
-        The pool may be used as a context manager.
+class PlaceFetcher:
+    """ Asynchronous connection that fetches place details for processing.
     """
     """
-    REOPEN_CONNECTIONS_AFTER = 100000
+    def __init__(self, dsn: str, setup_conn: Connection) -> None:
+        self.wait_time = 0.0
+        self.current_ids: Optional[DictCursorResults] = None
+        self.conn: Optional[DBConnection] = DBConnection(dsn,
+                                               cursor_factory=psycopg2.extras.DictCursor)
+
+        with setup_conn.cursor() as cur:
+            # Look up the hstore OIDs on this synchronous setup connection,
+            # because register_hstore cannot query them on the async one below.
+            hstore_oid = cur.scalar("SELECT 'hstore'::regtype::oid")
+            hstore_array_oid = cur.scalar("SELECT 'hstore[]'::regtype::oid")
+
+        psycopg2.extras.register_hstore(self.conn.conn, oid=hstore_oid,
+                                        array_oid=hstore_array_oid)
+
+    def close(self) -> None:
+        """ Close the underlying asynchronous connection.
+        """
+        if self.conn:
+            self.conn.close()
+            self.conn = None
 
 
-    def __init__(self, dsn, pool_size):
-        self.threads = [DBConnection(dsn) for _ in range(pool_size)]
-        self.free_workers = self._yield_free_worker()
 
 
+    def fetch_next_batch(self, cur: Cursor, runner: runners.Runner) -> bool:
+        """ Send a request for the next batch of places.
+            If details for the places are required, they will be fetched
+            asynchronously.
 
 
-    def finish_all(self):
-        """ Wait for all connection to finish.
+            Returns False once `cur` is exhausted, True otherwise.
         """
         """
-        for thread in self.threads:
-            while not thread.is_done():
-                thread.wait()
+        ids = cast(Optional[DictCursorResults], cur.fetchmany(100))
 
 
-        self.free_workers = self._yield_free_worker()
+        if not ids:
+            self.current_ids = None
+            return False
 
 
-    def close(self):
-        """ Close all connections and clear the pool.
-        """
-        for thread in self.threads:
-            thread.close()
-        self.threads = []
-        self.free_workers = None
+        assert self.conn is not None
+        self.current_ids = runner.get_place_details(self.conn, ids)
 
 
+        return True
 
 
-    def next_free_worker(self):
-        """ Get the next free connection.
+    def get_batch(self) -> DictCursorResults:
+        """ Get the next batch of data, previously requested with
+            `fetch_next_batch`.
         """
         """
-        return next(self.free_workers)
-
-
-    def _yield_free_worker(self):
-        ready = self.threads
-        command_stat = 0
-        while True:
-            for thread in ready:
-                if thread.is_done():
-                    command_stat += 1
-                    yield thread
-
-            if command_stat > self.REOPEN_CONNECTIONS_AFTER:
-                for thread in self.threads:
-                    while not thread.is_done():
-                        thread.wait()
-                    thread.connect()
-                ready = self.threads
-                command_stat = 0
-            else:
-                _, ready, _ = select.select([], self.threads, [])
-
-
-    def __enter__(self):
+        assert self.conn is not None
+        assert self.conn.cursor is not None
+
+        if self.current_ids is not None and not self.current_ids:
+            tstart = time.time()
+            self.conn.wait()
+            self.wait_time += time.time() - tstart
+            self.current_ids = cast(Optional[DictCursorResults],
+                                    self.conn.cursor.fetchall())
+
+        return self.current_ids if self.current_ids is not None else []
+
+    def __enter__(self) -> 'PlaceFetcher':
         return self
 
 
         return self
 
 
-    def __exit__(self, exc_type, exc_value, traceback):
+    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
+        assert self.conn is not None
+        self.conn.wait()
         self.close()
 
 
         self.close()
 
 
@@ -79,14 +97,25 @@ class Indexer:
     """ Main indexing routine.
     """
 
     """ Main indexing routine.
     """
 
-    def __init__(self, dsn, tokenizer, num_threads):
+    def __init__(self, dsn: str, tokenizer: AbstractTokenizer, num_threads: int):
         self.dsn = dsn
         self.tokenizer = tokenizer
         self.num_threads = num_threads
 
 
         self.dsn = dsn
         self.tokenizer = tokenizer
         self.num_threads = num_threads
 
 
-    def index_full(self, analyse=True):
-        """ Index the complete database. This will first index boudnaries
+    def has_pending(self) -> bool:
+        """ Check if any data still needs indexing.
+            This function must only be used after the import has finished.
+            Otherwise it will be very expensive.
+        """
+        with connect(self.dsn) as conn:
+            with conn.cursor() as cur:
+                cur.execute("SELECT 'a' FROM placex WHERE indexed_status > 0 LIMIT 1")
+                return cur.rowcount > 0
+
+
+    def index_full(self, analyse: bool = True) -> None:
+        """ Index the complete database. This will first index boundaries
             followed by all other objects. When `analyse` is True, then the
             database will be analysed at the appropriate places to
             ensure that database statistics are updated.
             followed by all other objects. When `analyse` is True, then the
             database will be analysed at the appropriate places to
             ensure that database statistics are updated.
@@ -94,13 +123,10 @@ class Indexer:
         with connect(self.dsn) as conn:
             conn.autocommit = True
 
         with connect(self.dsn) as conn:
             conn.autocommit = True
 
-            if analyse:
-                def _analyze():
+            def _analyze() -> None:
+                if analyse:
                     with conn.cursor() as cur:
                         cur.execute('ANALYZE')
                     with conn.cursor() as cur:
                         cur.execute('ANALYZE')
-            else:
-                def _analyze():
-                    pass
 
             self.index_by_rank(0, 4)
             _analyze()
 
             self.index_by_rank(0, 4)
             _analyze()
@@ -118,7 +144,7 @@ class Indexer:
             _analyze()
 
 
             _analyze()
 
 
-    def index_boundaries(self, minrank, maxrank):
+    def index_boundaries(self, minrank: int, maxrank: int) -> None:
         """ Index only administrative boundaries within the given rank range.
         """
         LOG.warning("Starting indexing boundaries using %s threads",
         """ Index only administrative boundaries within the given rank range.
         """
         LOG.warning("Starting indexing boundaries using %s threads",
@@ -128,7 +154,7 @@ class Indexer:
             for rank in range(max(minrank, 4), min(maxrank, 26)):
                 self._index(runners.BoundaryRunner(rank, analyzer))
 
             for rank in range(max(minrank, 4), min(maxrank, 26)):
                 self._index(runners.BoundaryRunner(rank, analyzer))
 
-    def index_by_rank(self, minrank, maxrank):
+    def index_by_rank(self, minrank: int, maxrank: int) -> None:
         """ Index all entries of placex in the given rank range (inclusive)
             in order of their address rank.
 
         """ Index all entries of placex in the given rank range (inclusive)
             in order of their address rank.
 
@@ -140,18 +166,15 @@ class Indexer:
                     minrank, maxrank, self.num_threads)
 
         with self.tokenizer.name_analyzer() as analyzer:
                     minrank, maxrank, self.num_threads)
 
         with self.tokenizer.name_analyzer() as analyzer:
-            for rank in range(max(1, minrank), maxrank):
-                self._index(runners.RankRunner(rank, analyzer))
+            for rank in range(max(1, minrank), maxrank + 1):
+                self._index(runners.RankRunner(rank, analyzer), 20 if rank == 30 else 1)
 
             if maxrank == 30:
                 self._index(runners.RankRunner(0, analyzer))
 
             if maxrank == 30:
                 self._index(runners.RankRunner(0, analyzer))
-                self._index(runners.InterpolationRunner(), 20)
-                self._index(runners.RankRunner(30, analyzer), 20)
-            else:
-                self._index(runners.RankRunner(maxrank, analyzer))
+                self._index(runners.InterpolationRunner(analyzer), 20)
 
 
 
 
-    def index_postcodes(self):
+    def index_postcodes(self) -> None:
         """Index the entries ofthe location_postcode table.
         """
         LOG.warning("Starting indexing postcodes using %s threads", self.num_threads)
         """Index the entries ofthe location_postcode table.
         """
         LOG.warning("Starting indexing postcodes using %s threads", self.num_threads)
@@ -159,7 +182,7 @@ class Indexer:
         self._index(runners.PostcodeRunner(), 20)
 
 
         self._index(runners.PostcodeRunner(), 20)
 
 
-    def update_status_table(self):
+    def update_status_table(self) -> None:
         """ Update the status in the status table to 'indexed'.
         """
         with connect(self.dsn) as conn:
         """ Update the status in the status table to 'indexed'.
         """
         with connect(self.dsn) as conn:
@@ -168,7 +191,7 @@ class Indexer:
 
             conn.commit()
 
 
             conn.commit()
 
-    def _index(self, runner, batch=1):
+    def _index(self, runner: runners.Runner, batch: int = 1) -> None:
         """ Index a single rank or table. `runner` describes the SQL to use
             for indexing. `batch` describes the number of objects that
             should be processed with a single SQL statement
         """ Index a single rank or table. `runner` describes the SQL to use
             for indexing. `batch` describes the number of objects that
             should be processed with a single SQL statement
@@ -176,6 +199,7 @@ class Indexer:
         LOG.warning("Starting %s (using batch size %s)", runner.name(), batch)
 
         with connect(self.dsn) as conn:
         LOG.warning("Starting %s (using batch size %s)", runner.name(), batch)
 
         with connect(self.dsn) as conn:
+            psycopg2.extras.register_hstore(conn)
             with conn.cursor() as cur:
                 total_tuples = cur.scalar(runner.sql_count_objects())
                 LOG.debug("Total number of rows: %i", total_tuples)
             with conn.cursor() as cur:
                 total_tuples = cur.scalar(runner.sql_count_objects())
                 LOG.debug("Total number of rows: %i", total_tuples)
@@ -188,19 +212,24 @@ class Indexer:
                 with conn.cursor(name='places') as cur:
                     cur.execute(runner.sql_get_objects())
 
                 with conn.cursor(name='places') as cur:
                     cur.execute(runner.sql_get_objects())
 
-                    with WorkerPool(self.dsn, self.num_threads) as pool:
-                        while True:
-                            places = [p for p in cur.fetchmany(batch)]
-                            if not places:
-                                break
+                    with PlaceFetcher(self.dsn, conn) as fetcher:
+                        with WorkerPool(self.dsn, self.num_threads) as pool:
+                            has_more = fetcher.fetch_next_batch(cur, runner)
+                            while has_more:
+                                places = fetcher.get_batch()
 
 
-                            LOG.debug("Processing places: %s", str(places))
-                            worker = pool.next_free_worker()
+                                # Request the next batch now; its details are fetched asynchronously.
+                                has_more = fetcher.fetch_next_batch(cur, runner)
 
 
-                            runner.index_places(worker, places)
-                            progress.add(len(places))
+                                # ... and index the current batch in chunks of `batch` places.
+                                for idx in range(0, len(places), batch):
+                                    part = places[idx:idx + batch]
+                                    LOG.debug("Processing places: %s", str(part))
+                                    runner.index_places(pool.next_free_worker(), part)
+                                    progress.add(len(part))
 
 
-                        pool.finish_all()
+                            LOG.info("Wait time: fetcher: %.2fs,  pool: %.2fs",
+                                     fetcher.wait_time, pool.wait_time)
 
                 conn.commit()
 
 
                 conn.commit()