]> git.openstreetmap.org Git - nominatim.git/blobdiff - src/nominatim_db/indexer/indexer.py
remove remaining pylint hints
[nominatim.git] / src / nominatim_db / indexer / indexer.py
index 5a219f6b2ecc3904b1129c9b4f9cff92c12132a5..d467efbd6e27b7dbaa207ee042d2ea3ac2cadcf9 100644 (file)
@@ -7,15 +7,14 @@
 """
 Main work horse for indexing (computing addresses) the database.
 """
 """
 Main work horse for indexing (computing addresses) the database.
 """
-from typing import Optional, Any, cast
+from typing import cast, List, Any, Optional
 import logging
 import time
 
 import logging
 import time
 
-import psycopg2.extras
+import psycopg
 
 
-from ..typing import DictCursorResults
-from ..db.async_connection import DBConnection, WorkerPool
-from ..db.connection import connect, Connection, Cursor
+from ..db.connection import connect, execute_scalar
+from ..db.query_pool import QueryPool
 from ..tokenizer.base import AbstractTokenizer
 from .progress import ProgressLogger
 from . import runners
 from ..tokenizer.base import AbstractTokenizer
 from .progress import ProgressLogger
 from . import runners
@@ -23,76 +22,6 @@ from . import runners
 LOG = logging.getLogger()
 
 
 LOG = logging.getLogger()
 
 
-class PlaceFetcher:
-    """ Asynchronous connection that fetches place details for processing.
-    """
-    def __init__(self, dsn: str, setup_conn: Connection) -> None:
-        self.wait_time = 0.0
-        self.current_ids: Optional[DictCursorResults] = None
-        self.conn: Optional[DBConnection] = DBConnection(dsn,
-                                               cursor_factory=psycopg2.extras.DictCursor)
-
-        with setup_conn.cursor() as cur:
-            # need to fetch those manually because register_hstore cannot
-            # fetch them on an asynchronous connection below.
-            hstore_oid = cur.scalar("SELECT 'hstore'::regtype::oid")
-            hstore_array_oid = cur.scalar("SELECT 'hstore[]'::regtype::oid")
-
-        psycopg2.extras.register_hstore(self.conn.conn, oid=hstore_oid,
-                                        array_oid=hstore_array_oid)
-
-    def close(self) -> None:
-        """ Close the underlying asynchronous connection.
-        """
-        if self.conn:
-            self.conn.close()
-            self.conn = None
-
-
-    def fetch_next_batch(self, cur: Cursor, runner: runners.Runner) -> bool:
-        """ Send a request for the next batch of places.
-            If details for the places are required, they will be fetched
-            asynchronously.
-
-            Returns true if there is still data available.
-        """
-        ids = cast(Optional[DictCursorResults], cur.fetchmany(100))
-
-        if not ids:
-            self.current_ids = None
-            return False
-
-        assert self.conn is not None
-        self.current_ids = runner.get_place_details(self.conn, ids)
-
-        return True
-
-    def get_batch(self) -> DictCursorResults:
-        """ Get the next batch of data, previously requested with
-            `fetch_next_batch`.
-        """
-        assert self.conn is not None
-        assert self.conn.cursor is not None
-
-        if self.current_ids is not None and not self.current_ids:
-            tstart = time.time()
-            self.conn.wait()
-            self.wait_time += time.time() - tstart
-            self.current_ids = cast(Optional[DictCursorResults],
-                                    self.conn.cursor.fetchall())
-
-        return self.current_ids if self.current_ids is not None else []
-
-    def __enter__(self) -> 'PlaceFetcher':
-        return self
-
-
-    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
-        assert self.conn is not None
-        self.conn.wait()
-        self.close()
-
-
 class Indexer:
     """ Main indexing routine.
     """
 class Indexer:
     """ Main indexing routine.
     """
@@ -102,7 +31,6 @@ class Indexer:
         self.tokenizer = tokenizer
         self.num_threads = num_threads
 
         self.tokenizer = tokenizer
         self.num_threads = num_threads
 
-
     def has_pending(self) -> bool:
         """ Check if any data still needs indexing.
             This function must only be used after the import has finished.
     def has_pending(self) -> bool:
         """ Check if any data still needs indexing.
             This function must only be used after the import has finished.
@@ -113,8 +41,7 @@ class Indexer:
                 cur.execute("SELECT 'a' FROM placex WHERE indexed_status > 0 LIMIT 1")
                 return cur.rowcount > 0
 
                 cur.execute("SELECT 'a' FROM placex WHERE indexed_status > 0 LIMIT 1")
                 return cur.rowcount > 0
 
-
-    def index_full(self, analyse: bool = True) -> None:
+    async def index_full(self, analyse: bool = True) -> None:
         """ Index the complete database. This will first index boundaries
             followed by all other objects. When `analyse` is True, then the
             database will be analysed at the appropriate places to
         """ Index the complete database. This will first index boundaries
             followed by all other objects. When `analyse` is True, then the
             database will be analysed at the appropriate places to
@@ -128,36 +55,60 @@ class Indexer:
                     with conn.cursor() as cur:
                         cur.execute('ANALYZE')
 
                     with conn.cursor() as cur:
                         cur.execute('ANALYZE')
 
-            if self.index_by_rank(0, 4) > 0:
-                _analyze()
+            while True:
+                if await self.index_by_rank(0, 4) > 0:
+                    _analyze()
 
 
-            if self.index_boundaries(0, 30) > 100:
-                _analyze()
+                if await self.index_boundaries(0, 30) > 100:
+                    _analyze()
 
 
-            if self.index_by_rank(5, 25) > 100:
-                _analyze()
+                if await self.index_by_rank(5, 25) > 100:
+                    _analyze()
 
 
-            if self.index_by_rank(26, 30) > 1000:
-                _analyze()
+                if await self.index_by_rank(26, 30) > 1000:
+                    _analyze()
 
 
-            if self.index_postcodes() > 100:
-                _analyze()
+                if await self.index_postcodes() > 100:
+                    _analyze()
 
 
+                if not self.has_pending():
+                    break
 
 
-    def index_boundaries(self, minrank: int, maxrank: int) -> int:
+    async def index_boundaries(self, minrank: int, maxrank: int) -> int:
         """ Index only administrative boundaries within the given rank range.
         """
         total = 0
         LOG.warning("Starting indexing boundaries using %s threads",
                     self.num_threads)
 
         """ Index only administrative boundaries within the given rank range.
         """
         total = 0
         LOG.warning("Starting indexing boundaries using %s threads",
                     self.num_threads)
 
+        minrank = max(minrank, 4)
+        maxrank = min(maxrank, 25)
+
+        # Precompute number of rows to process for all rows
+        with connect(self.dsn) as conn:
+            hstore_info = psycopg.types.TypeInfo.fetch(conn, "hstore")
+            if hstore_info is None:
+                raise RuntimeError('Hstore extension is requested but not installed.')
+            psycopg.types.hstore.register_hstore(hstore_info)
+
+            with conn.cursor() as cur:
+                cur = conn.execute(""" SELECT rank_search, count(*)
+                                       FROM placex
+                                       WHERE rank_search between %s and %s
+                                             AND class = 'boundary' and type = 'administrative'
+                                             AND indexed_status > 0
+                                       GROUP BY rank_search""",
+                                   (minrank, maxrank))
+                total_tuples = {row.rank_search: row.count for row in cur}
+
         with self.tokenizer.name_analyzer() as analyzer:
         with self.tokenizer.name_analyzer() as analyzer:
-            for rank in range(max(minrank, 4), min(maxrank, 26)):
-                total += self._index(runners.BoundaryRunner(rank, analyzer))
+            for rank in range(minrank, maxrank + 1):
+                total += await self._index(runners.BoundaryRunner(rank, analyzer),
+                                           total_tuples=total_tuples.get(rank, 0))
 
         return total
 
 
         return total
 
-    def index_by_rank(self, minrank: int, maxrank: int) -> int:
+    async def index_by_rank(self, minrank: int, maxrank: int) -> int:
         """ Index all entries of placex in the given rank range (inclusive)
             in order of their address rank.
 
         """ Index all entries of placex in the given rank range (inclusive)
             in order of their address rank.
 
@@ -169,24 +120,45 @@ class Indexer:
         LOG.warning("Starting indexing rank (%i to %i) using %i threads",
                     minrank, maxrank, self.num_threads)
 
         LOG.warning("Starting indexing rank (%i to %i) using %i threads",
                     minrank, maxrank, self.num_threads)
 
+        # Precompute number of rows to process for all rows
+        with connect(self.dsn) as conn:
+            hstore_info = psycopg.types.TypeInfo.fetch(conn, "hstore")
+            if hstore_info is None:
+                raise RuntimeError('Hstore extension is requested but not installed.')
+            psycopg.types.hstore.register_hstore(hstore_info)
+
+            with conn.cursor() as cur:
+                cur = conn.execute(""" SELECT rank_address, count(*)
+                                       FROM placex
+                                       WHERE rank_address between %s and %s
+                                             AND indexed_status > 0
+                                       GROUP BY rank_address""",
+                                   (minrank, maxrank))
+                total_tuples = {row.rank_address: row.count for row in cur}
+
         with self.tokenizer.name_analyzer() as analyzer:
             for rank in range(max(1, minrank), maxrank + 1):
         with self.tokenizer.name_analyzer() as analyzer:
             for rank in range(max(1, minrank), maxrank + 1):
-                total += self._index(runners.RankRunner(rank, analyzer), 20 if rank == 30 else 1)
+                if rank >= 30:
+                    batch = 20
+                elif rank >= 26:
+                    batch = 5
+                else:
+                    batch = 1
+                total += await self._index(runners.RankRunner(rank, analyzer),
+                                           batch=batch, total_tuples=total_tuples.get(rank, 0))
 
             if maxrank == 30:
 
             if maxrank == 30:
-                total += self._index(runners.RankRunner(0, analyzer))
-                total += self._index(runners.InterpolationRunner(analyzer), 20)
+                total += await self._index(runners.RankRunner(0, analyzer))
+                total += await self._index(runners.InterpolationRunner(analyzer), batch=20)
 
         return total
 
 
         return total
 
-
-    def index_postcodes(self) -> int:
+    async def index_postcodes(self) -> int:
         """Index the entries of the location_postcode table.
         """
         LOG.warning("Starting indexing postcodes using %s threads", self.num_threads)
 
         """Index the entries of the location_postcode table.
         """
         LOG.warning("Starting indexing postcodes using %s threads", self.num_threads)
 
-        return self._index(runners.PostcodeRunner(), 20)
-
+        return await self._index(runners.PostcodeRunner(), batch=20)
 
     def update_status_table(self) -> None:
         """ Update the status in the status table to 'indexed'.
 
     def update_status_table(self) -> None:
         """ Update the status in the status table to 'indexed'.
@@ -197,46 +169,63 @@ class Indexer:
 
             conn.commit()
 
 
             conn.commit()
 
-    def _index(self, runner: runners.Runner, batch: int = 1) -> int:
+    async def _index(self, runner: runners.Runner, batch: int = 1,
+                     total_tuples: Optional[int] = None) -> int:
         """ Index a single rank or table. `runner` describes the SQL to use
             for indexing. `batch` describes the number of objects that
         """ Index a single rank or table. `runner` describes the SQL to use
             for indexing. `batch` describes the number of objects that
-            should be processed with a single SQL statement
+            should be processed with a single SQL statement.
+
+            `total_tuples` may contain the total number of rows to process.
+            When not supplied, the value will be computed using the
+            approriate runner function.
         """
         LOG.warning("Starting %s (using batch size %s)", runner.name(), batch)
 
         """
         LOG.warning("Starting %s (using batch size %s)", runner.name(), batch)
 
-        with connect(self.dsn) as conn:
-            psycopg2.extras.register_hstore(conn)
-            with conn.cursor() as cur:
-                total_tuples = cur.scalar(runner.sql_count_objects())
-                LOG.debug("Total number of rows: %i", total_tuples)
+        if total_tuples is None:
+            total_tuples = self._prepare_indexing(runner)
 
 
-            conn.commit()
-
-            progress = ProgressLogger(runner.name(), total_tuples)
+        progress = ProgressLogger(runner.name(), total_tuples)
 
 
-            if total_tuples > 0:
-                with conn.cursor(name='places') as cur:
-                    cur.execute(runner.sql_get_objects())
+        if total_tuples > 0:
+            async with await psycopg.AsyncConnection.connect(
+                                 self.dsn, row_factory=psycopg.rows.dict_row) as aconn, \
+                       QueryPool(self.dsn, self.num_threads, autocommit=True) as pool:
+                fetcher_time = 0.0
+                tstart = time.time()
+                async with aconn.cursor(name='places') as cur:
+                    query = runner.index_places_query(batch)
+                    params: List[Any] = []
+                    num_places = 0
+                    async for place in cur.stream(runner.sql_get_objects()):
+                        fetcher_time += time.time() - tstart
 
 
-                    with PlaceFetcher(self.dsn, conn) as fetcher:
-                        with WorkerPool(self.dsn, self.num_threads) as pool:
-                            has_more = fetcher.fetch_next_batch(cur, runner)
-                            while has_more:
-                                places = fetcher.get_batch()
+                        params.extend(runner.index_places_params(place))
+                        num_places += 1
 
 
-                                # asynchronously get the next batch
-                                has_more = fetcher.fetch_next_batch(cur, runner)
+                        if num_places >= batch:
+                            LOG.debug("Processing places: %s", str(params))
+                            await pool.put_query(query, params)
+                            progress.add(num_places)
+                            params = []
+                            num_places = 0
 
 
-                                # And insert the current batch
-                                for idx in range(0, len(places), batch):
-                                    part = places[idx:idx + batch]
-                                    LOG.debug("Processing places: %s", str(part))
-                                    runner.index_places(pool.next_free_worker(), part)
-                                    progress.add(len(part))
+                        tstart = time.time()
 
 
-                            LOG.info("Wait time: fetcher: %.2fs,  pool: %.2fs",
-                                     fetcher.wait_time, pool.wait_time)
+                if num_places > 0:
+                    await pool.put_query(runner.index_places_query(num_places), params)
 
 
-                conn.commit()
+            LOG.info("Wait time: fetcher: %.2fs,  pool: %.2fs",
+                     fetcher_time, pool.wait_time)
 
         return progress.done()
 
         return progress.done()
+
+    def _prepare_indexing(self, runner: runners.Runner) -> int:
+        with connect(self.dsn) as conn:
+            hstore_info = psycopg.types.TypeInfo.fetch(conn, "hstore")
+            if hstore_info is None:
+                raise RuntimeError('Hstore extension is requested but not installed.')
+            psycopg.types.hstore.register_hstore(hstore_info)
+
+            total_tuples = execute_scalar(conn, runner.sql_count_objects())
+            LOG.debug("Total number of rows: %i", total_tuples)
+        return cast(int, total_tuples)