git.openstreetmap.org Git - nominatim.git/blobdiff - nominatim/indexer/indexer.py
Merge pull request #3301 from lonvia/fix-class-search-regression
[nominatim.git] / nominatim / indexer / indexer.py
index 61971497f957f15aeda81480f68e0d25cba32593..233423f03c6a202ec088cfeb0fe7ac26c79db01f 100644 (file)
+# SPDX-License-Identifier: GPL-2.0-only
+#
+# This file is part of Nominatim. (https://nominatim.org)
+#
+# Copyright (C) 2022 by the Nominatim developer community.
+# For a full list of authors see the git log.
 """
 Main work horse for indexing (computing addresses) the database.
 """
 """
 Main work horse for indexing (computing addresses) the database.
 """
-# pylint: disable=C0111
+from typing import Optional, Any, cast
 import logging
 import logging
-import select
+import time
 
 
-import psycopg2
+import psycopg2.extras
 
 
-from .progress import ProgressLogger
-from ..db.async_connection import DBConnection
+from nominatim.tokenizer.base import AbstractTokenizer
+from nominatim.indexer.progress import ProgressLogger
+from nominatim.indexer import runners
+from nominatim.db.async_connection import DBConnection, WorkerPool
+from nominatim.db.connection import connect, Connection, Cursor
+from nominatim.typing import DictCursorResults
 
 LOG = logging.getLogger()
 
 
 LOG = logging.getLogger()
 
-class RankRunner:
-    """ Returns SQL commands for indexing one rank within the placex table.
-    """
-
-    def __init__(self, rank):
-        self.rank = rank
-
-    def name(self):
-        return "rank {}".format(self.rank)
-
-    def sql_count_objects(self):
-        return """SELECT count(*) FROM placex
-                  WHERE rank_address = {} and indexed_status > 0
-               """.format(self.rank)
-
-    def sql_get_objects(self):
-        return """SELECT place_id FROM placex
-                  WHERE indexed_status > 0 and rank_address = {}
-                  ORDER BY geometry_sector""".format(self.rank)
-
-    @staticmethod
-    def sql_index_place(ids):
-        return "UPDATE placex SET indexed_status = 0 WHERE place_id IN ({})"\
-               .format(','.join((str(i) for i in ids)))
-
 
 
-class InterpolationRunner:
-    """ Returns SQL commands for indexing the address interpolation table
-        location_property_osmline.
+class PlaceFetcher:
+    """ Asynchronous connection that fetches place details for processing.
     """
     """
+    def __init__(self, dsn: str, setup_conn: Connection) -> None:
+        self.wait_time = 0.0
+        self.current_ids: Optional[DictCursorResults] = None
+        self.conn: Optional[DBConnection] = DBConnection(dsn,
+                                               cursor_factory=psycopg2.extras.DictCursor)
+
+        with setup_conn.cursor() as cur:
+            # The hstore type OIDs must be looked up here on the setup
+            # connection because register_hstore cannot query them itself
+            # on the asynchronous connection used below.
+            hstore_oid = cur.scalar("SELECT 'hstore'::regtype::oid")
+            hstore_array_oid = cur.scalar("SELECT 'hstore[]'::regtype::oid")
+
+        psycopg2.extras.register_hstore(self.conn.conn, oid=hstore_oid,
+                                        array_oid=hstore_array_oid)
+
+    def close(self) -> None:
+        """ Close the underlying asynchronous connection.
+        """
+        if self.conn:
+            self.conn.close()
+            self.conn = None
 
 
-    @staticmethod
-    def name():
-        return "interpolation lines (location_property_osmline)"
-
-    @staticmethod
-    def sql_count_objects():
-        return """SELECT count(*) FROM location_property_osmline
-                  WHERE indexed_status > 0"""
-
-    @staticmethod
-    def sql_get_objects():
-        return """SELECT place_id FROM location_property_osmline
-                  WHERE indexed_status > 0
-                  ORDER BY geometry_sector"""
-
-    @staticmethod
-    def sql_index_place(ids):
-        return """UPDATE location_property_osmline
-                  SET indexed_status = 0 WHERE place_id IN ({})
-               """.format(','.join((str(i) for i in ids)))
-
-class BoundaryRunner:
-    """ Returns SQL commands for indexing the administrative boundaries
-        of a certain rank.
-    """
 
 
-    def __init__(self, rank):
-        self.rank = rank
+    def fetch_next_batch(self, cur: Cursor, runner: runners.Runner) -> bool:
+        """ Send a request for the next batch of places.
+            If details for the places are required, they will be fetched
+            asynchronously.
+
+            Returns True if there is still data available.
+        """
+        ids = cast(Optional[DictCursorResults], cur.fetchmany(100))
 
 
-    def name(self):
-        return "boundaries rank {}".format(self.rank)
+        if not ids:
+            self.current_ids = None
+            return False
 
 
-    def sql_count_objects(self):
-        return """SELECT count(*) FROM placex
-                  WHERE indexed_status > 0
-                    AND rank_search = {}
-                    AND class = 'boundary' and type = 'administrative'
-               """.format(self.rank)
+        assert self.conn is not None
+        self.current_ids = runner.get_place_details(self.conn, ids)
 
 
-    def sql_get_objects(self):
-        return """SELECT place_id FROM placex
-                  WHERE indexed_status > 0 and rank_search = {}
-                        and class = 'boundary' and type = 'administrative'
-                  ORDER BY partition, admin_level
-               """.format(self.rank)
+        return True
 
 
-    @staticmethod
-    def sql_index_place(ids):
-        return "UPDATE placex SET indexed_status = 0 WHERE place_id IN ({})"\
-               .format(','.join((str(i) for i in ids)))
+    def get_batch(self) -> DictCursorResults:
+        """ Get the next batch of data, previously requested with
+            `fetch_next_batch`.
+        """
+        assert self.conn is not None
+        assert self.conn.cursor is not None
 
 
+        if self.current_ids is not None and not self.current_ids:
+            tstart = time.time()
+            self.conn.wait()
+            self.wait_time += time.time() - tstart
+            self.current_ids = cast(Optional[DictCursorResults],
+                                    self.conn.cursor.fetchall())
 
 
-class PostcodeRunner:
-    """ Provides the SQL commands for indexing the location_postcode table.
-    """
+        return self.current_ids if self.current_ids is not None else []
 
 
-    @staticmethod
-    def name():
-        return "postcodes (location_postcode)"
+    def __enter__(self) -> 'PlaceFetcher':
+        return self
 
 
-    @staticmethod
-    def sql_count_objects():
-        return 'SELECT count(*) FROM location_postcode WHERE indexed_status > 0'
 
 
-    @staticmethod
-    def sql_get_objects():
-        return """SELECT place_id FROM location_postcode
-                  WHERE indexed_status > 0
-                  ORDER BY country_code, postcode"""
+    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
+        assert self.conn is not None
+        self.conn.wait()
+        self.close()
 
 
-    @staticmethod
-    def sql_index_place(ids):
-        return """UPDATE location_postcode SET indexed_status = 0
-                  WHERE place_id IN ({})
-               """.format(','.join((str(i) for i in ids)))
 
 class Indexer:
     """ Main indexing routine.
     """
 
 
 class Indexer:
     """ Main indexing routine.
     """
 
-    def __init__(self, dsn, num_threads):
-        self.conn = psycopg2.connect(dsn)
-        self.threads = [DBConnection(dsn) for _ in range(num_threads)]
+    def __init__(self, dsn: str, tokenizer: AbstractTokenizer, num_threads: int):
+        self.dsn = dsn
+        self.tokenizer = tokenizer
+        self.num_threads = num_threads
 
 
 
 
-    def index_full(self, analyse=True):
-        """ Index the complete database. This will first index boudnaries
+    def has_pending(self) -> bool:
+        """ Check if any data still needs indexing.
+            This function must only be used after the import has finished.
+            Otherwise it will be very expensive.
+        """
+        with connect(self.dsn) as conn:
+            with conn.cursor() as cur:
+                cur.execute("SELECT 'a' FROM placex WHERE indexed_status > 0 LIMIT 1")
+                return cur.rowcount > 0
+
+
+    def index_full(self, analyse: bool = True) -> None:
+        """ Index the complete database. This will first index boundaries
             followed by all other objects. When `analyse` is True, then the
             database will be analysed at the appropriate places to
             ensure that database statistics are updated.
         """
             followed by all other objects. When `analyse` is True, then the
             database will be analysed at the appropriate places to
             ensure that database statistics are updated.
         """
-        self.index_by_rank(0, 4)
-        self._analyse_db_if(analyse)
+        with connect(self.dsn) as conn:
+            conn.autocommit = True
+
+            def _analyze() -> None:
+                if analyse:
+                    with conn.cursor() as cur:
+                        cur.execute('ANALYZE')
 
 
-        self.index_boundaries(0, 30)
-        self._analyse_db_if(analyse)
+            if self.index_by_rank(0, 4) > 0:
+                _analyze()
 
 
-        self.index_by_rank(5, 25)
-        self._analyse_db_if(analyse)
+            if self.index_boundaries(0, 30) > 100:
+                _analyze()
 
 
-        self.index_by_rank(26, 30)
-        self._analyse_db_if(analyse)
+            if self.index_by_rank(5, 25) > 100:
+                _analyze()
 
 
-        self.index_postcodes()
-        self._analyse_db_if(analyse)
+            if self.index_by_rank(26, 30) > 1000:
+                _analyze()
 
 
-    def _analyse_db_if(self, condition):
-        if condition:
-            with self.conn.cursor() as cur:
-                cur.execute('ANALYSE')
+            if self.index_postcodes() > 100:
+                _analyze()
 
 
-    def index_boundaries(self, minrank, maxrank):
+
+    def index_boundaries(self, minrank: int, maxrank: int) -> int:
         """ Index only administrative boundaries within the given rank range.
         """
         """ Index only administrative boundaries within the given rank range.
         """
+        total = 0
         LOG.warning("Starting indexing boundaries using %s threads",
         LOG.warning("Starting indexing boundaries using %s threads",
-                    len(self.threads))
+                    self.num_threads)
+
+        with self.tokenizer.name_analyzer() as analyzer:
+            for rank in range(max(minrank, 4), min(maxrank, 26)):
+                total += self._index(runners.BoundaryRunner(rank, analyzer))
 
 
-        for rank in range(max(minrank, 4), min(maxrank, 26)):
-            self.index(BoundaryRunner(rank))
+        return total
 
 
-    def index_by_rank(self, minrank, maxrank):
+    def index_by_rank(self, minrank: int, maxrank: int) -> int:
         """ Index all entries of placex in the given rank range (inclusive)
             in order of their address rank.
 
             When rank 30 is requested then also interpolations and
             places with address rank 0 will be indexed.
         """
         """ Index all entries of placex in the given rank range (inclusive)
             in order of their address rank.
 
             When rank 30 is requested then also interpolations and
             places with address rank 0 will be indexed.
         """
+        total = 0
         maxrank = min(maxrank, 30)
         LOG.warning("Starting indexing rank (%i to %i) using %i threads",
         maxrank = min(maxrank, 30)
         LOG.warning("Starting indexing rank (%i to %i) using %i threads",
-                    minrank, maxrank, len(self.threads))
+                    minrank, maxrank, self.num_threads)
+
+        with self.tokenizer.name_analyzer() as analyzer:
+            for rank in range(max(1, minrank), maxrank + 1):
+                total += self._index(runners.RankRunner(rank, analyzer), 20 if rank == 30 else 1)
 
 
-        for rank in range(max(1, minrank), maxrank):
-            self.index(RankRunner(rank))
+            if maxrank == 30:
+                total += self._index(runners.RankRunner(0, analyzer))
+                total += self._index(runners.InterpolationRunner(analyzer), 20)
 
 
-        if maxrank == 30:
-            self.index(RankRunner(0))
-            self.index(InterpolationRunner(), 20)
-            self.index(RankRunner(30), 20)
-        else:
-            self.index(RankRunner(maxrank))
+        return total
 
 
 
 
-    def index_postcodes(self):
-        """Index the entries ofthe location_postcode table.
+    def index_postcodes(self) -> int:
+        """Index the entries of the location_postcode table.
         """
         """
-        self.index(PostcodeRunner(), 20)
+        LOG.warning("Starting indexing postcodes using %s threads", self.num_threads)
 
 
-    def update_status_table(self):
+        return self._index(runners.PostcodeRunner(), 20)
+
+
+    def update_status_table(self) -> None:
         """ Update the status in the status table to 'indexed'.
         """
         """ Update the status in the status table to 'indexed'.
         """
-        with self.conn.cursor() as cur:
-            cur.execute('UPDATE import_status SET indexed = true')
-        self.conn.commit()
+        with connect(self.dsn) as conn:
+            with conn.cursor() as cur:
+                cur.execute('UPDATE import_status SET indexed = true')
+
+            conn.commit()
 
 
-    def index(self, obj, batch=1):
-        """ Index a single rank or table. `obj` describes the SQL to use
+    def _index(self, runner: runners.Runner, batch: int = 1) -> int:
+        """ Index a single rank or table. `runner` describes the SQL to use
             for indexing. `batch` describes the number of objects that
             should be processed with a single SQL statement
         """
             for indexing. `batch` describes the number of objects that
             should be processed with a single SQL statement
         """
-        LOG.warning("Starting %s (using batch size %s)", obj.name(), batch)
+        LOG.warning("Starting %s (using batch size %s)", runner.name(), batch)
 
 
-        cur = self.conn.cursor()
-        cur.execute(obj.sql_count_objects())
+        with connect(self.dsn) as conn:
+            psycopg2.extras.register_hstore(conn)
+            with conn.cursor() as cur:
+                total_tuples = cur.scalar(runner.sql_count_objects())
+                LOG.debug("Total number of rows: %i", total_tuples)
 
 
-        total_tuples = cur.fetchone()[0]
-        LOG.debug("Total number of rows: %i", total_tuples)
+            conn.commit()
 
 
-        cur.close()
+            progress = ProgressLogger(runner.name(), total_tuples)
 
 
-        progress = ProgressLogger(obj.name(), total_tuples)
+            if total_tuples > 0:
+                with conn.cursor(name='places') as cur:
+                    cur.execute(runner.sql_get_objects())
 
 
-        if total_tuples > 0:
-            cur = self.conn.cursor(name='places')
-            cur.execute(obj.sql_get_objects())
+                    with PlaceFetcher(self.dsn, conn) as fetcher:
+                        with WorkerPool(self.dsn, self.num_threads) as pool:
+                            has_more = fetcher.fetch_next_batch(cur, runner)
+                            while has_more:
+                                places = fetcher.get_batch()
 
 
-            next_thread = self.find_free_thread()
-            while True:
-                places = [p[0] for p in cur.fetchmany(batch)]
-                if not places:
-                    break
+                                # asynchronously get the next batch
+                                has_more = fetcher.fetch_next_batch(cur, runner)
 
 
-                LOG.debug("Processing places: %s", str(places))
-                thread = next(next_thread)
+                                # And insert the current batch
+                                for idx in range(0, len(places), batch):
+                                    part = places[idx:idx + batch]
+                                    LOG.debug("Processing places: %s", str(part))
+                                    runner.index_places(pool.next_free_worker(), part)
+                                    progress.add(len(part))
 
 
-                thread.perform(obj.sql_index_place(places))
-                progress.add(len(places))
+                            LOG.info("Wait time: fetcher: %.2fs,  pool: %.2fs",
+                                     fetcher.wait_time, pool.wait_time)
 
 
-            cur.close()
+                conn.commit()
 
 
-            for thread in self.threads:
-                thread.wait()
-
-        progress.done()
-
-    def find_free_thread(self):
-        """ Generator that returns the next connection that is free for
-            sending a query.
-        """
-        ready = self.threads
-        command_stat = 0
-
-        while True:
-            for thread in ready:
-                if thread.is_done():
-                    command_stat += 1
-                    yield thread
-
-            # refresh the connections occasionaly to avoid potential
-            # memory leaks in Postgresql.
-            if command_stat > 100000:
-                for thread in self.threads:
-                    while not thread.is_done():
-                        thread.wait()
-                    thread.connect()
-                command_stat = 0
-                ready = self.threads
-            else:
-                ready, _, _ = select.select(self.threads, [], [])
-
-        assert False, "Unreachable code"
+        return progress.done()