]> git.openstreetmap.org Git - nominatim.git/blobdiff - nominatim/tools/replication.py
Removed unnecessary check for --prepare-database flag
[nominatim.git] / nominatim / tools / replication.py
index d6e8089161bc96cc71f1119b8252e3046af654b3..edd63e49a15931d289b2fd488737ad8d105dc532 100644 (file)
@@ -1,26 +1,40 @@
+# SPDX-License-Identifier: GPL-2.0-only
+#
+# This file is part of Nominatim. (https://nominatim.org)
+#
+# Copyright (C) 2022 by the Nominatim developer community.
+# For a full list of authors see the git log.
 """
 Functions for updating a database from a replication source.
 """
 """
 Functions for updating a database from a replication source.
 """
+from typing import ContextManager, MutableMapping, Any, Generator, cast, Iterator
+from contextlib import contextmanager
 import datetime as dt
 from enum import Enum
 import logging
 import time
 import datetime as dt
 from enum import Enum
 import logging
 import time
+import types
+import urllib.request as urlrequest
 
 
+import requests
 from nominatim.db import status
 from nominatim.db import status
+from nominatim.db.connection import Connection, connect
 from nominatim.tools.exec_utils import run_osm2pgsql
 from nominatim.errors import UsageError
 
 try:
     from osmium.replication.server import ReplicationServer
     from osmium import WriteHandler
 from nominatim.tools.exec_utils import run_osm2pgsql
 from nominatim.errors import UsageError
 
 try:
     from osmium.replication.server import ReplicationServer
     from osmium import WriteHandler
-except ModuleNotFoundError as exc:
-    logging.getLogger().fatal("pyosmium not installed. Replication functions not available.\n"
-                              "To install pyosmium via pip: pip3 install osmium")
+    from osmium import version as pyo_version
+except ImportError as exc:
+    logging.getLogger().critical("pyosmium not installed. Replication functions not available.\n"
+                                 "To install pyosmium via pip: pip3 install osmium")
     raise UsageError("replication tools not available") from exc
 
 LOG = logging.getLogger()
 
     raise UsageError("replication tools not available") from exc
 
 LOG = logging.getLogger()
 
-def init_replication(conn, base_url):
+def init_replication(conn: Connection, base_url: str,
+                     socket_timeout: int = 60) -> None:
     """ Set up replication for the server at the given base URL.
     """
     LOG.info("Using replication source: %s", base_url)
     """ Set up replication for the server at the given base URL.
     """
     LOG.info("Using replication source: %s", base_url)
@@ -29,9 +43,8 @@ def init_replication(conn, base_url):
     # margin of error to make sure we get all data
     date -= dt.timedelta(hours=3)
 
     # margin of error to make sure we get all data
     date -= dt.timedelta(hours=3)
 
-    repl = ReplicationServer(base_url)
-
-    seq = repl.timestamp_to_sequence(date)
+    with _make_replication_server(base_url, socket_timeout) as repl:
+        seq = repl.timestamp_to_sequence(date)
 
     if seq is None:
         LOG.fatal("Cannot reach the configured replication service '%s'.\n"
 
     if seq is None:
         LOG.fatal("Cannot reach the configured replication service '%s'.\n"
@@ -41,10 +54,11 @@ def init_replication(conn, base_url):
 
     status.set_status(conn, date=date, seq=seq)
 
 
     status.set_status(conn, date=date, seq=seq)
 
-    LOG.warning("Updates intialised at sequence %s (%s)", seq, date)
+    LOG.warning("Updates initialised at sequence %s (%s)", seq, date)
 
 
 
 
-def check_for_updates(conn, base_url):
+def check_for_updates(conn: Connection, base_url: str,
+                      socket_timeout: int = 60) -> int:
     """ Check if new data is available from the replication service at the
         given base URL.
     """
     """ Check if new data is available from the replication service at the
         given base URL.
     """
@@ -55,7 +69,8 @@ def check_for_updates(conn, base_url):
                   "Please run 'nominatim replication --init' first.")
         return 254
 
                   "Please run 'nominatim replication --init' first.")
         return 254
 
-    state = ReplicationServer(base_url).get_state_info()
+    with _make_replication_server(base_url, socket_timeout) as repl:
+        state = repl.get_state_info()
 
     if state is None:
         LOG.error("Cannot get state for URL %s.", base_url)
 
     if state is None:
         LOG.error("Cannot get state for URL %s.", base_url)
@@ -77,17 +92,22 @@ class UpdateState(Enum):
     NO_CHANGES = 3
 
 
     NO_CHANGES = 3
 
 
-def update(conn, options):
+def update(dsn: str, options: MutableMapping[str, Any],
+           socket_timeout: int = 60) -> UpdateState:
     """ Update database from the next batch of data. Returns the state of
         updates according to `UpdateState`.
     """
     """ Update database from the next batch of data. Returns the state of
         updates according to `UpdateState`.
     """
-    startdate, startseq, indexed = status.get_status(conn)
+    with connect(dsn) as conn:
+        startdate, startseq, indexed = status.get_status(conn)
+        conn.commit()
 
     if startseq is None:
         LOG.error("Replication not set up. "
                   "Please run 'nominatim replication --init' first.")
         raise UsageError("Replication not set up.")
 
 
     if startseq is None:
         LOG.error("Replication not set up. "
                   "Please run 'nominatim replication --init' first.")
         raise UsageError("Replication not set up.")
 
+    assert startdate is not None
+
     if not indexed and options['indexed_only']:
         LOG.info("Skipping update. There is data that needs indexing.")
         return UpdateState.MORE_PENDING
     if not indexed and options['indexed_only']:
         LOG.info("Skipping update. There is data that needs indexing.")
         return UpdateState.MORE_PENDING
@@ -103,24 +123,83 @@ def update(conn, options):
         options['import_file'].unlink()
 
     # Read updates into file.
         options['import_file'].unlink()
 
     # Read updates into file.
-    repl = ReplicationServer(options['base_url'])
+    with _make_replication_server(options['base_url'], socket_timeout) as repl:
+        outhandler = WriteHandler(str(options['import_file']))
+        endseq = repl.apply_diffs(outhandler, startseq + 1,
+                                  max_size=options['max_diff_size'] * 1024)
+        outhandler.close()
+
+        if endseq is None:
+            return UpdateState.NO_CHANGES
+
+        with connect(dsn) as conn:
+            run_osm2pgsql_updates(conn, options)
+
+            # Write the current status to the file
+            endstate = repl.get_state_info(endseq)
+            status.set_status(conn, endstate.timestamp if endstate else None,
+                              seq=endseq, indexed=False)
+            conn.commit()
+
+    return UpdateState.UP_TO_DATE
 
 
-    outhandler = WriteHandler(str(options['import_file']))
-    endseq = repl.apply_diffs(outhandler, startseq + 1,
-                              max_size=options['max_diff_size'] * 1024)
-    outhandler.close()
 
 
-    if endseq is None:
-        return UpdateState.NO_CHANGES
+def run_osm2pgsql_updates(conn: Connection, options: MutableMapping[str, Any]) -> None:
+    """ Run osm2pgsql in append mode.
+    """
+    # Remove any stale deletion marks.
+    with conn.cursor() as cur:
+        cur.execute('TRUNCATE place_to_be_deleted')
+    conn.commit()
 
     # Consume updates with osm2pgsql.
     options['append'] = True
     options['disable_jit'] = conn.server_version_tuple() >= (11, 0)
     run_osm2pgsql(options)
 
 
     # Consume updates with osm2pgsql.
     options['append'] = True
     options['disable_jit'] = conn.server_version_tuple() >= (11, 0)
     run_osm2pgsql(options)
 
-    # Write the current status to the file
-    endstate = repl.get_state_info(endseq)
-    status.set_status(conn, endstate.timestamp if endstate else None,
-                      seq=endseq, indexed=False)
+    # Handle deletions
+    with conn.cursor() as cur:
+        cur.execute('SELECT flush_deleted_places()')
+    conn.commit()
 
 
-    return UpdateState.UP_TO_DATE
+
+def _make_replication_server(url: str, timeout: int) -> ContextManager[ReplicationServer]:
+    """ Returns a ReplicationServer in form of a context manager.
+
+        Creates a light wrapper around older versions of pyosmium that did
+        not support the context manager interface.
+    """
+    if hasattr(ReplicationServer, '__enter__'):
+        # Patches the open_url function for pyosmium >= 3.2
+        # where the socket timeout is no longer respected.
+        def patched_open_url(self: ReplicationServer, url: urlrequest.Request) -> Any:
+            """ Download a resource from the given URL and return a byte sequence
+                of the content.
+            """
+            headers = {"User-Agent" : f"Nominatim (pyosmium/{pyo_version.pyosmium_release})"}
+
+            if self.session is not None:
+                return self.session.get(url.get_full_url(),
+                                       headers=headers, timeout=timeout or None,
+                                       stream=True)
+
+            @contextmanager
+            def _get_url_with_session() -> Iterator[requests.Response]:
+                with requests.Session() as session:
+                    request = session.get(url.get_full_url(),
+                                          headers=headers, timeout=timeout or None,
+                                          stream=True)
+                    yield request
+
+            return _get_url_with_session()
+
+        repl = ReplicationServer(url)
+        setattr(repl, 'open_url', types.MethodType(patched_open_url, repl))
+
+        return cast(ContextManager[ReplicationServer], repl)
+
+    @contextmanager
+    def get_cm() -> Generator[ReplicationServer, None, None]:
+        yield ReplicationServer(url)
+
+    return get_cm()