]> git.openstreetmap.org Git - nominatim.git/blobdiff - nominatim/tools/replication.py
remove tests that differ between lua and gazetteer versions
[nominatim.git] / nominatim / tools / replication.py
index 1c46c50c4522cda6b397bae3d65d4a8b6305d880..846b9c34dd301644080b6aa8b7c240e81739531f 100644 (file)
@@ -33,7 +33,8 @@ except ImportError as exc:
 
 LOG = logging.getLogger()
 
 
 LOG = logging.getLogger()
 
-def init_replication(conn: Connection, base_url: str) -> None:
+def init_replication(conn: Connection, base_url: str,
+                     socket_timeout: int = 60) -> None:
     """ Set up replication for the server at the given base URL.
     """
     LOG.info("Using replication source: %s", base_url)
     """ Set up replication for the server at the given base URL.
     """
     LOG.info("Using replication source: %s", base_url)
@@ -42,9 +43,8 @@ def init_replication(conn: Connection, base_url: str) -> None:
     # margin of error to make sure we get all data
     date -= dt.timedelta(hours=3)
 
     # margin of error to make sure we get all data
     date -= dt.timedelta(hours=3)
 
-    repl = ReplicationServer(base_url)
-
-    seq = repl.timestamp_to_sequence(date)
+    with _make_replication_server(base_url, socket_timeout) as repl:
+        seq = repl.timestamp_to_sequence(date)
 
     if seq is None:
         LOG.fatal("Cannot reach the configured replication service '%s'.\n"
 
     if seq is None:
         LOG.fatal("Cannot reach the configured replication service '%s'.\n"
@@ -57,7 +57,8 @@ def init_replication(conn: Connection, base_url: str) -> None:
     LOG.warning("Updates initialised at sequence %s (%s)", seq, date)
 
 
     LOG.warning("Updates initialised at sequence %s (%s)", seq, date)
 
 
-def check_for_updates(conn: Connection, base_url: str) -> int:
+def check_for_updates(conn: Connection, base_url: str,
+                      socket_timeout: int = 60) -> int:
     """ Check if new data is available from the replication service at the
         given base URL.
     """
     """ Check if new data is available from the replication service at the
         given base URL.
     """
@@ -68,7 +69,8 @@ def check_for_updates(conn: Connection, base_url: str) -> int:
                   "Please run 'nominatim replication --init' first.")
         return 254
 
                   "Please run 'nominatim replication --init' first.")
         return 254
 
-    state = ReplicationServer(base_url).get_state_info()
+    with _make_replication_server(base_url, socket_timeout) as repl:
+        state = repl.get_state_info()
 
     if state is None:
         LOG.error("Cannot get state for URL %s.", base_url)
 
     if state is None:
         LOG.error("Cannot get state for URL %s.", base_url)
@@ -128,10 +130,7 @@ def update(conn: Connection, options: MutableMapping[str, Any],
         if endseq is None:
             return UpdateState.NO_CHANGES
 
         if endseq is None:
             return UpdateState.NO_CHANGES
 
-        # Consume updates with osm2pgsql.
-        options['append'] = True
-        options['disable_jit'] = conn.server_version_tuple() >= (11, 0)
-        run_osm2pgsql(options)
+        run_osm2pgsql_updates(conn, options)
 
         # Write the current status to the file
         endstate = repl.get_state_info(endseq)
 
         # Write the current status to the file
         endstate = repl.get_state_info(endseq)
@@ -141,6 +140,25 @@ def update(conn: Connection, options: MutableMapping[str, Any],
     return UpdateState.UP_TO_DATE
 
 
     return UpdateState.UP_TO_DATE
 
 
+def run_osm2pgsql_updates(conn: Connection, options: MutableMapping[str, Any]) -> None:
+    """ Run osm2pgsql in append mode.
+    """
+    # Remove any stale deletion marks.
+    with conn.cursor() as cur:
+        cur.execute('TRUNCATE place_to_be_deleted')
+    conn.commit()
+
+    # Consume updates with osm2pgsql.
+    options['append'] = True
+    options['disable_jit'] = conn.server_version_tuple() >= (11, 0)
+    run_osm2pgsql(options)
+
+    # Handle deletions
+    with conn.cursor() as cur:
+        cur.execute('SELECT flush_deleted_places()')
+    conn.commit()
+
+
 def _make_replication_server(url: str, timeout: int) -> ContextManager[ReplicationServer]:
     """ Returns a ReplicationServer in form of a context manager.
 
 def _make_replication_server(url: str, timeout: int) -> ContextManager[ReplicationServer]:
     """ Returns a ReplicationServer in form of a context manager.
 
@@ -154,25 +172,25 @@ def _make_replication_server(url: str, timeout: int) -> ContextManager[Replicati
             """ Download a resource from the given URL and return a byte sequence
                 of the content.
             """
             """ Download a resource from the given URL and return a byte sequence
                 of the content.
             """
-            get_params = {
-                'headers': {"User-Agent" : f"Nominatim (pyosmium/{pyo_version.pyosmium_release})"},
-                'timeout': timeout or None,
-                'stream': True
-            }
+            headers = {"User-Agent" : f"Nominatim (pyosmium/{pyo_version.pyosmium_release})"}
 
             if self.session is not None:
 
             if self.session is not None:
-                return self.session.get(url.get_full_url(), **get_params)
+                return self.session.get(url.get_full_url(),
+                                       headers=headers, timeout=timeout or None,
+                                       stream=True)
 
             @contextmanager
             def _get_url_with_session() -> Iterator[requests.Response]:
                 with requests.Session() as session:
 
             @contextmanager
             def _get_url_with_session() -> Iterator[requests.Response]:
                 with requests.Session() as session:
-                    request = session.get(url.get_full_url(), **get_params) # type: ignore
+                    request = session.get(url.get_full_url(),
+                                          headers=headers, timeout=timeout or None,
+                                          stream=True)
                     yield request
 
             return _get_url_with_session()
 
         repl = ReplicationServer(url)
                     yield request
 
             return _get_url_with_session()
 
         repl = ReplicationServer(url)
-        repl.open_url = types.MethodType(patched_open_url, repl)
+        setattr(repl, 'open_url', types.MethodType(patched_open_url, repl))
 
         return cast(ContextManager[ReplicationServer], repl)
 
 
         return cast(ContextManager[ReplicationServer], repl)