]> git.openstreetmap.org Git - nominatim.git/blobdiff - nominatim/db/utils.py
Documentation fix: should be "nominatim refresh"
[nominatim.git] / nominatim / db / utils.py
index b376940d804af364c049864a07649c897b515f0f..b859afa8137e9254b4a4a88ce3d8236fb1383d8c 100644 (file)
@@ -1,9 +1,16 @@
+# SPDX-License-Identifier: GPL-2.0-only
+#
+# This file is part of Nominatim. (https://nominatim.org)
+#
+# Copyright (C) 2022 by the Nominatim developer community.
+# For a full list of authors see the git log.
 """
 Helper functions for handling DB accesses.
 """
 import subprocess
 import logging
 import gzip
 """
 Helper functions for handling DB accesses.
 """
 import subprocess
 import logging
 import gzip
+import io
 
 from nominatim.db.connection import get_pg_env
 from nominatim.errors import UsageError
 
 from nominatim.db.connection import get_pg_env
 from nominatim.errors import UsageError
@@ -33,27 +40,74 @@ def execute_file(dsn, fname, ignore_errors=False, pre_code=None, post_code=None)
         cmd.extend(('-v', 'ON_ERROR_STOP=1'))
     if not LOG.isEnabledFor(logging.INFO):
         cmd.append('--quiet')
         cmd.extend(('-v', 'ON_ERROR_STOP=1'))
     if not LOG.isEnabledFor(logging.INFO):
         cmd.append('--quiet')
-    proc = subprocess.Popen(cmd, env=get_pg_env(dsn), stdin=subprocess.PIPE)
 
 
-    try:
-        if not LOG.isEnabledFor(logging.INFO):
-            proc.stdin.write('set client_min_messages to WARNING;'.encode('utf-8'))
+    with subprocess.Popen(cmd, env=get_pg_env(dsn), stdin=subprocess.PIPE) as proc:
+        try:
+            if not LOG.isEnabledFor(logging.INFO):
+                proc.stdin.write('set client_min_messages to WARNING;'.encode('utf-8'))
 
 
-        if pre_code:
-            proc.stdin.write((pre_code + ';').encode('utf-8'))
+            if pre_code:
+                proc.stdin.write((pre_code + ';').encode('utf-8'))
 
 
-        if fname.suffix == '.gz':
-            with gzip.open(str(fname), 'rb') as fdesc:
-                remain = _pipe_to_proc(proc, fdesc)
-        else:
-            with fname.open('rb') as fdesc:
-                remain = _pipe_to_proc(proc, fdesc)
+            if fname.suffix == '.gz':
+                with gzip.open(str(fname), 'rb') as fdesc:
+                    remain = _pipe_to_proc(proc, fdesc)
+            else:
+                with fname.open('rb') as fdesc:
+                    remain = _pipe_to_proc(proc, fdesc)
 
 
-        if remain == 0 and post_code:
-            proc.stdin.write((';' + post_code).encode('utf-8'))
-    finally:
-        proc.stdin.close()
-        ret = proc.wait()
+            if remain == 0 and post_code:
+                proc.stdin.write((';' + post_code).encode('utf-8'))
+        finally:
+            proc.stdin.close()
+            ret = proc.wait()
 
     if ret != 0 or remain > 0:
         raise UsageError("Failed to execute SQL file.")
 
     if ret != 0 or remain > 0:
         raise UsageError("Failed to execute SQL file.")
+
+
# Characters that must be backslash-escaped when a column value is written in
# PostgreSQL's COPY text format: the escape character itself, the column
# delimiter (tab) and the row delimiters (newline and carriage return).
_SQL_TRANSLATION = {ord('\\'): '\\\\',
                    ord('\t'): '\\t',
                    ord('\n'): '\\n',
                    ord('\r'): '\\r'}


class CopyBuffer:
    """ Data collector for the copy_from command.

        Rows are accumulated in an in-memory text buffer using the COPY
        text format (tab-separated columns, one row per line, NULL written
        as backslash-N) and sent to the database in one go by copy_out().

        Can be used as a context manager; the in-memory buffer is closed
        when the context is left.
    """

    def __init__(self):
        self.buffer = io.StringIO()


    def __enter__(self):
        return self


    def __exit__(self, exc_type, exc_value, traceback):
        # Only releases the in-memory buffer; exceptions are not suppressed.
        if self.buffer is not None:
            self.buffer.close()


    def add(self, *data):
        """ Add another row of data to the copy buffer.

            Each positional argument is one column of the row. None is
            written as SQL NULL; any other value is converted with str()
            and escaped so that it cannot be mistaken for a column or row
            delimiter.
        """
        # Build the complete row first so that a failing str() conversion
        # does not leave a partially written row in the buffer.
        row = '\t'.join('\\N' if column is None
                        else str(column).translate(_SQL_TRANSLATION)
                        for column in data)
        self.buffer.write(row + '\n')


    def copy_out(self, cur, table, columns=None):
        """ Copy all collected data into the given table.

            `cur` must provide a copy_from(file, table, columns=...) method
            (presumably a psycopg2 cursor; the defaults of that method,
            tab separator and backslash-N for NULL, match the format written
            by add()). Nothing is sent when no rows have been added. The
            buffer is not emptied afterwards.
        """
        if self.buffer.tell() > 0:
            # Rewind so copy_from() reads everything written so far.
            self.buffer.seek(0)
            cur.copy_from(self.buffer, table, columns=columns)