]> git.openstreetmap.org Git - nominatim.git/blobdiff - nominatim/indexer/indexer.py
properly close connections when shutting down starlette
[nominatim.git] / nominatim / indexer / indexer.py
index 555f8704a19c6796da4b97a724cd363d183f7f12..233423f03c6a202ec088cfeb0fe7ac26c79db01f 100644 (file)
@@ -7,15 +7,18 @@
 """
 Main work horse for indexing (computing addresses) the database.
 """
 """
 Main work horse for indexing (computing addresses) the database.
 """
+from typing import Optional, Any, cast
 import logging
 import time
 
 import psycopg2.extras
 
 import logging
 import time
 
 import psycopg2.extras
 
+from nominatim.tokenizer.base import AbstractTokenizer
 from nominatim.indexer.progress import ProgressLogger
 from nominatim.indexer import runners
 from nominatim.db.async_connection import DBConnection, WorkerPool
 from nominatim.indexer.progress import ProgressLogger
 from nominatim.indexer import runners
 from nominatim.db.async_connection import DBConnection, WorkerPool
-from nominatim.db.connection import connect
+from nominatim.db.connection import connect, Connection, Cursor
+from nominatim.typing import DictCursorResults
 
 LOG = logging.getLogger()
 
 
 LOG = logging.getLogger()
 
@@ -23,10 +26,11 @@ LOG = logging.getLogger()
 class PlaceFetcher:
     """ Asynchronous connection that fetches place details for processing.
     """
 class PlaceFetcher:
     """ Asynchronous connection that fetches place details for processing.
     """
-    def __init__(self, dsn, setup_conn):
-        self.wait_time = 0
-        self.current_ids = None
-        self.conn = DBConnection(dsn, cursor_factory=psycopg2.extras.DictCursor)
+    def __init__(self, dsn: str, setup_conn: Connection) -> None:
+        self.wait_time = 0.0
+        self.current_ids: Optional[DictCursorResults] = None
+        self.conn: Optional[DBConnection] = DBConnection(dsn,
+                                               cursor_factory=psycopg2.extras.DictCursor)
 
         with setup_conn.cursor() as cur:
             # need to fetch those manually because register_hstore cannot
 
         with setup_conn.cursor() as cur:
             # need to fetch those manually because register_hstore cannot
@@ -37,7 +41,7 @@ class PlaceFetcher:
         psycopg2.extras.register_hstore(self.conn.conn, oid=hstore_oid,
                                         array_oid=hstore_array_oid)
 
         psycopg2.extras.register_hstore(self.conn.conn, oid=hstore_oid,
                                         array_oid=hstore_array_oid)
 
-    def close(self):
+    def close(self) -> None:
         """ Close the underlying asynchronous connection.
         """
         if self.conn:
         """ Close the underlying asynchronous connection.
         """
         if self.conn:
@@ -45,44 +49,46 @@ class PlaceFetcher:
             self.conn = None
 
 
             self.conn = None
 
 
-    def fetch_next_batch(self, cur, runner):
+    def fetch_next_batch(self, cur: Cursor, runner: runners.Runner) -> bool:
         """ Send a request for the next batch of places.
             If details for the places are required, they will be fetched
             asynchronously.
 
             Returns true if there is still data available.
         """
         """ Send a request for the next batch of places.
             If details for the places are required, they will be fetched
             asynchronously.
 
             Returns true if there is still data available.
         """
-        ids = cur.fetchmany(100)
+        ids = cast(Optional[DictCursorResults], cur.fetchmany(100))
 
         if not ids:
             self.current_ids = None
             return False
 
 
         if not ids:
             self.current_ids = None
             return False
 
-        if hasattr(runner, 'get_place_details'):
-            runner.get_place_details(self.conn, ids)
-            self.current_ids = []
-        else:
-            self.current_ids = ids
+        assert self.conn is not None
+        self.current_ids = runner.get_place_details(self.conn, ids)
 
         return True
 
 
         return True
 
-    def get_batch(self):
+    def get_batch(self) -> DictCursorResults:
         """ Get the next batch of data, previously requested with
             `fetch_next_batch`.
         """
         """ Get the next batch of data, previously requested with
             `fetch_next_batch`.
         """
+        assert self.conn is not None
+        assert self.conn.cursor is not None
+
         if self.current_ids is not None and not self.current_ids:
             tstart = time.time()
             self.conn.wait()
             self.wait_time += time.time() - tstart
         if self.current_ids is not None and not self.current_ids:
             tstart = time.time()
             self.conn.wait()
             self.wait_time += time.time() - tstart
-            self.current_ids = self.conn.cursor.fetchall()
+            self.current_ids = cast(Optional[DictCursorResults],
+                                    self.conn.cursor.fetchall())
 
 
-        return self.current_ids
+        return self.current_ids if self.current_ids is not None else []
 
 
-    def __enter__(self):
+    def __enter__(self) -> 'PlaceFetcher':
         return self
 
 
         return self
 
 
-    def __exit__(self, exc_type, exc_value, traceback):
+    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
+        assert self.conn is not None
         self.conn.wait()
         self.close()
 
         self.conn.wait()
         self.close()
 
@@ -91,13 +97,13 @@ class Indexer:
     """ Main indexing routine.
     """
 
     """ Main indexing routine.
     """
 
-    def __init__(self, dsn, tokenizer, num_threads):
+    def __init__(self, dsn: str, tokenizer: AbstractTokenizer, num_threads: int):
         self.dsn = dsn
         self.tokenizer = tokenizer
         self.num_threads = num_threads
 
 
         self.dsn = dsn
         self.tokenizer = tokenizer
         self.num_threads = num_threads
 
 
-    def has_pending(self):
+    def has_pending(self) -> bool:
         """ Check if any data still needs indexing.
             This function must only be used after the import has finished.
             Otherwise it will be very expensive.
         """ Check if any data still needs indexing.
             This function must only be used after the import has finished.
             Otherwise it will be very expensive.
@@ -108,7 +114,7 @@ class Indexer:
                 return cur.rowcount > 0
 
 
                 return cur.rowcount > 0
 
 
-    def index_full(self, analyse=True):
+    def index_full(self, analyse: bool = True) -> None:
         """ Index the complete database. This will first index boundaries
             followed by all other objects. When `analyse` is True, then the
             database will be analysed at the appropriate places to
         """ Index the complete database. This will first index boundaries
             followed by all other objects. When `analyse` is True, then the
             database will be analysed at the appropriate places to
@@ -117,66 +123,72 @@ class Indexer:
         with connect(self.dsn) as conn:
             conn.autocommit = True
 
         with connect(self.dsn) as conn:
             conn.autocommit = True
 
-            def _analyze():
+            def _analyze() -> None:
                 if analyse:
                     with conn.cursor() as cur:
                         cur.execute('ANALYZE')
 
                 if analyse:
                     with conn.cursor() as cur:
                         cur.execute('ANALYZE')
 
-            self.index_by_rank(0, 4)
-            _analyze()
+            if self.index_by_rank(0, 4) > 0:
+                _analyze()
 
 
-            self.index_boundaries(0, 30)
-            _analyze()
+            if self.index_boundaries(0, 30) > 100:
+                _analyze()
 
 
-            self.index_by_rank(5, 25)
-            _analyze()
+            if self.index_by_rank(5, 25) > 100:
+                _analyze()
 
 
-            self.index_by_rank(26, 30)
-            _analyze()
+            if self.index_by_rank(26, 30) > 1000:
+                _analyze()
 
 
-            self.index_postcodes()
-            _analyze()
+            if self.index_postcodes() > 100:
+                _analyze()
 
 
 
 
-    def index_boundaries(self, minrank, maxrank):
+    def index_boundaries(self, minrank: int, maxrank: int) -> int:
         """ Index only administrative boundaries within the given rank range.
         """
         """ Index only administrative boundaries within the given rank range.
         """
+        total = 0
         LOG.warning("Starting indexing boundaries using %s threads",
                     self.num_threads)
 
         with self.tokenizer.name_analyzer() as analyzer:
             for rank in range(max(minrank, 4), min(maxrank, 26)):
         LOG.warning("Starting indexing boundaries using %s threads",
                     self.num_threads)
 
         with self.tokenizer.name_analyzer() as analyzer:
             for rank in range(max(minrank, 4), min(maxrank, 26)):
-                self._index(runners.BoundaryRunner(rank, analyzer))
+                total += self._index(runners.BoundaryRunner(rank, analyzer))
+
+        return total
 
 
-    def index_by_rank(self, minrank, maxrank):
+    def index_by_rank(self, minrank: int, maxrank: int) -> int:
         """ Index all entries of placex in the given rank range (inclusive)
             in order of their address rank.
 
             When rank 30 is requested then also interpolations and
             places with address rank 0 will be indexed.
         """
         """ Index all entries of placex in the given rank range (inclusive)
             in order of their address rank.
 
             When rank 30 is requested then also interpolations and
             places with address rank 0 will be indexed.
         """
+        total = 0
         maxrank = min(maxrank, 30)
         LOG.warning("Starting indexing rank (%i to %i) using %i threads",
                     minrank, maxrank, self.num_threads)
 
         with self.tokenizer.name_analyzer() as analyzer:
             for rank in range(max(1, minrank), maxrank + 1):
         maxrank = min(maxrank, 30)
         LOG.warning("Starting indexing rank (%i to %i) using %i threads",
                     minrank, maxrank, self.num_threads)
 
         with self.tokenizer.name_analyzer() as analyzer:
             for rank in range(max(1, minrank), maxrank + 1):
-                self._index(runners.RankRunner(rank, analyzer), 20 if rank == 30 else 1)
+                total += self._index(runners.RankRunner(rank, analyzer), 20 if rank == 30 else 1)
 
             if maxrank == 30:
 
             if maxrank == 30:
-                self._index(runners.RankRunner(0, analyzer))
-                self._index(runners.InterpolationRunner(analyzer), 20)
+                total += self._index(runners.RankRunner(0, analyzer))
+                total += self._index(runners.InterpolationRunner(analyzer), 20)
+
+        return total
 
 
 
 
-    def index_postcodes(self):
-        """Index the entries ofthe location_postcode table.
+    def index_postcodes(self) -> int:
+        """Index the entries of the location_postcode table.
         """
         LOG.warning("Starting indexing postcodes using %s threads", self.num_threads)
 
         """
         LOG.warning("Starting indexing postcodes using %s threads", self.num_threads)
 
-        self._index(runners.PostcodeRunner(), 20)
+        return self._index(runners.PostcodeRunner(), 20)
 
 
 
 
-    def update_status_table(self):
+    def update_status_table(self) -> None:
         """ Update the status in the status table to 'indexed'.
         """
         with connect(self.dsn) as conn:
         """ Update the status in the status table to 'indexed'.
         """
         with connect(self.dsn) as conn:
@@ -185,7 +197,7 @@ class Indexer:
 
             conn.commit()
 
 
             conn.commit()
 
-    def _index(self, runner, batch=1):
+    def _index(self, runner: runners.Runner, batch: int = 1) -> int:
         """ Index a single rank or table. `runner` describes the SQL to use
             for indexing. `batch` describes the number of objects that
             should be processed with a single SQL statement
         """ Index a single rank or table. `runner` describes the SQL to use
             for indexing. `batch` describes the number of objects that
             should be processed with a single SQL statement
@@ -215,7 +227,7 @@ class Indexer:
                                 # asynchronously get the next batch
                                 has_more = fetcher.fetch_next_batch(cur, runner)
 
                                 # asynchronously get the next batch
                                 has_more = fetcher.fetch_next_batch(cur, runner)
 
-                                # And insert the curent batch
+                                # And insert the current batch
                                 for idx in range(0, len(places), batch):
                                     part = places[idx:idx + batch]
                                     LOG.debug("Processing places: %s", str(part))
                                 for idx in range(0, len(places), batch):
                                     part = places[idx:idx + batch]
                                     LOG.debug("Processing places: %s", str(part))
@@ -227,4 +239,4 @@ class Indexer:
 
                 conn.commit()
 
 
                 conn.commit()
 
-        progress.done()
+        return progress.done()