git.openstreetmap.org Git - nominatim.git/blobdiff - nominatim/db/utils.py
Merge pull request #2186 from lonvia/port-import-to-python
[nominatim.git] / nominatim / db / utils.py
index 575f301082e044e2d04a3a31fbcc69c8dfd2f8ac..0a2e2c067cb1b291a57cd655f9a5cc86f0f71efc 100644 (file)
@@ -21,27 +21,39 @@ def _pipe_to_proc(proc, fdesc):
 
     return len(chunk)
 
-def execute_file(dsn, fname, ignore_errors=False):
+def execute_file(dsn, fname, ignore_errors=False, pre_code=None, post_code=None):
     """ Read an SQL file and run its contents against the given database
-        using psql.
+        using psql. Use `pre_code` and `post_code` to run extra commands
+        before or after executing the file. The commands are run within the
+        same session, so they may be used to wrap the file execution in a
+        transaction. Raises UsageError if the execution fails.
     """
     cmd = ['psql']
     if not ignore_errors:
         cmd.extend(('-v', 'ON_ERROR_STOP=1'))
+    if not LOG.isEnabledFor(logging.INFO):
+        cmd.append('--quiet')
     proc = subprocess.Popen(cmd, env=get_pg_env(dsn), stdin=subprocess.PIPE)
 
-    if not LOG.isEnabledFor(logging.INFO):
-        proc.stdin.write('set client_min_messages to WARNING;'.encode('utf-8'))
+    try:
+        if not LOG.isEnabledFor(logging.INFO):
+            proc.stdin.write('set client_min_messages to WARNING;'.encode('utf-8'))
+        # The ';' separators start a fresh line so a trailing '--' comment can't swallow them.
+        if pre_code:
+            proc.stdin.write((pre_code + '\n;\n').encode('utf-8'))
 
-    if fname.suffix == '.gz':
-        with gzip.open(str(fname), 'rb') as fdesc:
-            remain = _pipe_to_proc(proc, fdesc)
-    else:
-        with fname.open('rb') as fdesc:
-            remain = _pipe_to_proc(proc, fdesc)
+        if fname.suffix == '.gz':
+            with gzip.open(str(fname), 'rb') as fdesc:
+                remain = _pipe_to_proc(proc, fdesc)
+        else:
+            with fname.open('rb') as fdesc:
+                remain = _pipe_to_proc(proc, fdesc)
 
-    proc.stdin.close()
+        if remain == 0 and post_code:
+            proc.stdin.write(('\n;' + post_code).encode('utf-8'))
+    finally:
+        proc.stdin.close()
+        ret = proc.wait()
 
-    ret = proc.wait()
     if ret != 0 or remain > 0:
         raise UsageError("Failed to execute SQL file.")