]> git.openstreetmap.org Git - nominatim.git/blobdiff - nominatim/clicmd/replication.py
move warm script to python code
[nominatim.git] / nominatim / clicmd / replication.py
index 44eec5f185e9ad03accc0b2c55f27ba71678c9cb..b795650694d207ed73b32db5f566ae8289ffad58 100644 (file)
@@ -1,6 +1,14 @@
+# SPDX-License-Identifier: GPL-2.0-only
+#
+# This file is part of Nominatim. (https://nominatim.org)
+#
+# Copyright (C) 2022 by the Nominatim developer community.
+# For a full list of authors see the git log.
 """
 Implementation of the 'replication' sub-command.
 """
+from typing import Optional
+import argparse
 import datetime as dt
 import logging
 import socket
@@ -9,13 +17,14 @@ import time
 from nominatim.db import status
 from nominatim.db.connection import connect
 from nominatim.errors import UsageError
+from nominatim.clicmd.args import NominatimArgs
 
 LOG = logging.getLogger()
 
 # Do not repeat documentation of subcommand classes.
 # pylint: disable=C0111
 # Using non-top-level imports to make pyosmium optional for replication only.
-# pylint: disable=E0012,C0415
+# pylint: disable=C0415
 
 class UpdateReplication:
     """\
@@ -35,8 +44,7 @@ class UpdateReplication:
     downloads and imports the next batch of updates.
     """
 
-    @staticmethod
-    def add_args(parser):
+    def add_args(self, parser: argparse.ArgumentParser) -> None:
         group = parser.add_argument_group('Arguments for initialisation')
         group.add_argument('--init', action='store_true',
                            help='Initialise the update process')
@@ -62,42 +70,43 @@ class UpdateReplication:
         group.add_argument('--socket-timeout', dest='socket_timeout', type=int, default=60,
                            help='Set timeout for file downloads')
 
-    @staticmethod
-    def _init_replication(args):
+
+    def _init_replication(self, args: NominatimArgs) -> int:
         from ..tools import replication, refresh
 
         LOG.warning("Initialising replication updates")
         with connect(args.config.get_libpq_dsn()) as conn:
-            replication.init_replication(conn, base_url=args.config.REPLICATION_URL)
+            replication.init_replication(conn, base_url=args.config.REPLICATION_URL,
+                                         socket_timeout=args.socket_timeout)
             if args.update_functions:
                 LOG.warning("Create functions")
                 refresh.create_functions(conn, args.config, True, False)
         return 0
 
 
-    @staticmethod
-    def _check_for_updates(args):
+    def _check_for_updates(self, args: NominatimArgs) -> int:
         from ..tools import replication
 
         with connect(args.config.get_libpq_dsn()) as conn:
-            return replication.check_for_updates(conn, base_url=args.config.REPLICATION_URL)
+            return replication.check_for_updates(conn, base_url=args.config.REPLICATION_URL,
+                                                 socket_timeout=args.socket_timeout)
+
 
-    @staticmethod
-    def _report_update(batchdate, start_import, start_index):
-        def round_time(delta):
+    def _report_update(self, batchdate: dt.datetime,
+                       start_import: dt.datetime,
+                       start_index: Optional[dt.datetime]) -> None:
+        def round_time(delta: dt.timedelta) -> dt.timedelta:
             return dt.timedelta(seconds=int(delta.total_seconds()))
 
         end = dt.datetime.now(dt.timezone.utc)
         LOG.warning("Update completed. Import: %s. %sTotal: %s. Remaining backlog: %s.",
                     round_time((start_index or end) - start_import),
-                    "Indexing: {} ".format(round_time(end - start_index))
-                    if start_index else '',
+                    f"Indexing: {round_time(end - start_index)} " if start_index else '',
                     round_time(end - start_import),
                     round_time(end - batchdate))
 
 
-    @staticmethod
-    def _compute_update_interval(args):
+    def _compute_update_interval(self, args: NominatimArgs) -> int:
         if args.catch_up:
             return 0
 
@@ -114,13 +123,13 @@ class UpdateReplication:
         return update_interval
 
 
-    @staticmethod
-    def _update(args):
+    def _update(self, args: NominatimArgs) -> None:
+        # pylint: disable=too-many-locals
         from ..tools import replication
         from ..indexer.indexer import Indexer
         from ..tokenizer import factory as tokenizer_factory
 
-        update_interval = UpdateReplication._compute_update_interval(args)
+        update_interval = self._compute_update_interval(args)
 
         params = args.osm2pgsql_options(default_cache=2000, default_threads=1)
         params.update(base_url=args.config.REPLICATION_URL,
@@ -136,11 +145,15 @@ class UpdateReplication:
             recheck_interval = args.config.get_int('REPLICATION_RECHECK_INTERVAL')
 
         tokenizer = tokenizer_factory.get_tokenizer_for_db(args.config)
+        indexer = Indexer(args.config.get_libpq_dsn(), tokenizer, args.threads or 1)
+
+        dsn = args.config.get_libpq_dsn()
 
         while True:
-            with connect(args.config.get_libpq_dsn()) as conn:
-                start = dt.datetime.now(dt.timezone.utc)
-                state = replication.update(conn, params)
+            start = dt.datetime.now(dt.timezone.utc)
+            state = replication.update(dsn, params, socket_timeout=args.socket_timeout)
+
+            with connect(dsn) as conn:
                 if state is not replication.UpdateState.NO_CHANGES:
                     status.log_status(conn, start, 'import')
                 batchdate, _, _ = status.get_status(conn)
@@ -148,11 +161,9 @@ class UpdateReplication:
 
             if state is not replication.UpdateState.NO_CHANGES and args.do_index:
                 index_start = dt.datetime.now(dt.timezone.utc)
-                indexer = Indexer(args.config.get_libpq_dsn(), tokenizer,
-                                  args.threads or 1)
                 indexer.index_full(analyse=False)
 
-                with connect(args.config.get_libpq_dsn()) as conn:
+                with connect(dsn) as conn:
                     status.set_indexed(conn, True)
                     status.log_status(conn, index_start, 'index')
                     conn.commit()
@@ -165,7 +176,8 @@ class UpdateReplication:
                     indexer.index_full(analyse=False)
 
             if LOG.isEnabledFor(logging.WARNING):
-                UpdateReplication._report_update(batchdate, start, index_start)
+                assert batchdate is not None
+                self._report_update(batchdate, start, index_start)
 
             if args.once or (args.catch_up and state is replication.UpdateState.NO_CHANGES):
                 break
@@ -175,15 +187,14 @@ class UpdateReplication:
                 time.sleep(recheck_interval)
 
 
-    @staticmethod
-    def run(args):
+    def run(self, args: NominatimArgs) -> int:
         socket.setdefaulttimeout(args.socket_timeout)
 
         if args.init:
-            return UpdateReplication._init_replication(args)
+            return self._init_replication(args)
 
         if args.check_for_updates:
-            return UpdateReplication._check_for_updates(args)
+            return self._check_for_updates(args)
 
-        UpdateReplication._update(args)
+        self._update(args)
         return 0