]> git.openstreetmap.org Git - nominatim.git/blobdiff - nominatim/db/utils.py
enable API use with psycopg 3
[nominatim.git] / nominatim / db / utils.py
index 0490bbc869fe3f19a4bc87eac2fbed73d4b54d6f..9a7b4f164787b8abb03831477fe3b876357e9b25 100644 (file)
@@ -1,16 +1,27 @@
+# SPDX-License-Identifier: GPL-2.0-only
+#
+# This file is part of Nominatim. (https://nominatim.org)
+#
+# Copyright (C) 2022 by the Nominatim developer community.
+# For a full list of authors see the git log.
 """
 Helper functions for handling DB accesses.
 """
 """
 Helper functions for handling DB accesses.
 """
+from typing import IO, Optional, Union, Any, Iterable
 import subprocess
 import logging
 import gzip
 import subprocess
 import logging
 import gzip
+import io
+from pathlib import Path
 
 
-from .connection import get_pg_env
-from ..errors import UsageError
+from nominatim.db.connection import get_pg_env, Cursor
+from nominatim.errors import UsageError
 
 LOG = logging.getLogger()
 
 
 LOG = logging.getLogger()
 
-def _pipe_to_proc(proc, fdesc):
+def _pipe_to_proc(proc: 'subprocess.Popen[bytes]',
+                  fdesc: Union[IO[bytes], gzip.GzipFile]) -> int:
+    assert proc.stdin is not None
     chunk = fdesc.read(2048)
     while chunk and proc.poll() is None:
         try:
     chunk = fdesc.read(2048)
     while chunk and proc.poll() is None:
         try:
@@ -21,7 +32,10 @@ def _pipe_to_proc(proc, fdesc):
 
     return len(chunk)
 
 
     return len(chunk)
 
-def execute_file(dsn, fname, ignore_errors=False, pre_code=None, post_code=None):
+def execute_file(dsn: str, fname: Path,
+                 ignore_errors: bool = False,
+                 pre_code: Optional[str] = None,
+                 post_code: Optional[str] = None) -> None:
     """ Read an SQL file and run its contents against the given database
         using psql. Use `pre_code` and `post_code` to run extra commands
         before or after executing the file. The commands are run within the
     """ Read an SQL file and run its contents against the given database
         using psql. Use `pre_code` and `post_code` to run extra commands
         before or after executing the file. The commands are run within the
@@ -31,26 +45,77 @@ def execute_file(dsn, fname, ignore_errors=False, pre_code=None, post_code=None)
     cmd = ['psql']
     if not ignore_errors:
         cmd.extend(('-v', 'ON_ERROR_STOP=1'))
     cmd = ['psql']
     if not ignore_errors:
         cmd.extend(('-v', 'ON_ERROR_STOP=1'))
-    proc = subprocess.Popen(cmd, env=get_pg_env(dsn), stdin=subprocess.PIPE)
-
     if not LOG.isEnabledFor(logging.INFO):
     if not LOG.isEnabledFor(logging.INFO):
-        proc.stdin.write('set client_min_messages to WARNING;'.encode('utf-8'))
+        cmd.append('--quiet')
 
 
-    if pre_code:
-        proc.stdin.write((pre_code + ';').encode('utf-8'))
+    with subprocess.Popen(cmd, env=get_pg_env(dsn), stdin=subprocess.PIPE) as proc:
+        assert proc.stdin is not None
+        try:
+            if not LOG.isEnabledFor(logging.INFO):
+                proc.stdin.write('set client_min_messages to WARNING;'.encode('utf-8'))
 
 
-    if fname.suffix == '.gz':
-        with gzip.open(str(fname), 'rb') as fdesc:
-            remain = _pipe_to_proc(proc, fdesc)
-    else:
-        with fname.open('rb') as fdesc:
-            remain = _pipe_to_proc(proc, fdesc)
+            if pre_code:
+                proc.stdin.write((pre_code + ';').encode('utf-8'))
 
 
-    if remain == 0 and post_code:
-        proc.stdin.write((';' + post_code).encode('utf-8'))
+            if fname.suffix == '.gz':
+                with gzip.open(str(fname), 'rb') as fdesc:
+                    remain = _pipe_to_proc(proc, fdesc)
+            else:
+                with fname.open('rb') as fdesc:
+                    remain = _pipe_to_proc(proc, fdesc)
 
 
-    proc.stdin.close()
+            if remain == 0 and post_code:
+                proc.stdin.write((';' + post_code).encode('utf-8'))
+        finally:
+            proc.stdin.close()
+            ret = proc.wait()
 
 
-    ret = proc.wait()
     if ret != 0 or remain > 0:
         raise UsageError("Failed to execute SQL file.")
     if ret != 0 or remain > 0:
         raise UsageError("Failed to execute SQL file.")
+
+
+# List of characters that need to be quoted for the copy command.
+_SQL_TRANSLATION = {ord('\\'): '\\\\',
+                    ord('\t'): '\\t',
+                    ord('\n'): '\\n'}
+
+
+class CopyBuffer:
+    """ Data collector for the copy_from command.
+    """
+
+    def __init__(self) -> None:
+        self.buffer = io.StringIO()
+
+
+    def __enter__(self) -> 'CopyBuffer':
+        return self
+
+
+    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
+        if self.buffer is not None:
+            self.buffer.close()
+
+
+    def add(self, *data: Any) -> None:
+        """ Add another row of data to the copy buffer.
+        """
+        first = True
+        for column in data:
+            if first:
+                first = False
+            else:
+                self.buffer.write('\t')
+            if column is None:
+                self.buffer.write('\\N')
+            else:
+                self.buffer.write(str(column).translate(_SQL_TRANSLATION))
+        self.buffer.write('\n')
+
+
+    def copy_out(self, cur: Cursor, table: str, columns: Optional[Iterable[str]] = None) -> None:
+        """ Copy all collected data into the given table.
+        """
+        if self.buffer.tell() > 0:
+            self.buffer.seek(0)
+            cur.copy_from(self.buffer, table, columns=columns) # type: ignore[no-untyped-call]