]> git.openstreetmap.org Git - nominatim.git/blobdiff - nominatim/db/utils.py
Merge remote-tracking branch 'upstream/master'
[nominatim.git] / nominatim / db / utils.py
index b859afa8137e9254b4a4a88ce3d8236fb1383d8c..e3f0712a11b6f53551bc90e4734798e441f33142 100644 (file)
@@ -7,17 +7,21 @@
 """
 Helper functions for handling DB accesses.
 """
 """
 Helper functions for handling DB accesses.
 """
+from typing import IO, Optional, Union, Any, Iterable
 import subprocess
 import logging
 import gzip
 import io
 import subprocess
 import logging
 import gzip
 import io
+from pathlib import Path
 
 
-from nominatim.db.connection import get_pg_env
+from nominatim.db.connection import get_pg_env, Cursor
 from nominatim.errors import UsageError
 
 LOG = logging.getLogger()
 
 from nominatim.errors import UsageError
 
 LOG = logging.getLogger()
 
-def _pipe_to_proc(proc, fdesc):
+def _pipe_to_proc(proc: 'subprocess.Popen[bytes]',
+                  fdesc: Union[IO[bytes], gzip.GzipFile]) -> int:
+    assert proc.stdin is not None
     chunk = fdesc.read(2048)
     while chunk and proc.poll() is None:
         try:
     chunk = fdesc.read(2048)
     while chunk and proc.poll() is None:
         try:
@@ -28,7 +32,10 @@ def _pipe_to_proc(proc, fdesc):
 
     return len(chunk)
 
 
     return len(chunk)
 
-def execute_file(dsn, fname, ignore_errors=False, pre_code=None, post_code=None):
+def execute_file(dsn: str, fname: Path,
+                 ignore_errors: bool = False,
+                 pre_code: Optional[str] = None,
+                 post_code: Optional[str] = None) -> None:
     """ Read an SQL file and run its contents against the given database
         using psql. Use `pre_code` and `post_code` to run extra commands
         before or after executing the file. The commands are run within the
     """ Read an SQL file and run its contents against the given database
         using psql. Use `pre_code` and `post_code` to run extra commands
         before or after executing the file. The commands are run within the
@@ -42,6 +49,7 @@ def execute_file(dsn, fname, ignore_errors=False, pre_code=None, post_code=None)
         cmd.append('--quiet')
 
     with subprocess.Popen(cmd, env=get_pg_env(dsn), stdin=subprocess.PIPE) as proc:
         cmd.append('--quiet')
 
     with subprocess.Popen(cmd, env=get_pg_env(dsn), stdin=subprocess.PIPE) as proc:
+        assert proc.stdin is not None
         try:
             if not LOG.isEnabledFor(logging.INFO):
                 proc.stdin.write('set client_min_messages to WARNING;'.encode('utf-8'))
         try:
             if not LOG.isEnabledFor(logging.INFO):
                 proc.stdin.write('set client_min_messages to WARNING;'.encode('utf-8'))
@@ -76,20 +84,20 @@ class CopyBuffer:
     """ Data collector for the copy_from command.
     """
 
     """ Data collector for the copy_from command.
     """
 
-    def __init__(self):
+    def __init__(self) -> None:
         self.buffer = io.StringIO()
 
 
         self.buffer = io.StringIO()
 
 
-    def __enter__(self):
+    def __enter__(self) -> 'CopyBuffer':
         return self
 
 
         return self
 
 
-    def __exit__(self, exc_type, exc_value, traceback):
+    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
         if self.buffer is not None:
             self.buffer.close()
 
 
         if self.buffer is not None:
             self.buffer.close()
 
 
-    def add(self, *data):
+    def add(self, *data: Any) -> None:
         """ Add another row of data to the copy buffer.
         """
         first = True
         """ Add another row of data to the copy buffer.
         """
         first = True
@@ -105,7 +113,7 @@ class CopyBuffer:
         self.buffer.write('\n')
 
 
         self.buffer.write('\n')
 
 
-    def copy_out(self, cur, table, columns=None):
+    def copy_out(self, cur: Cursor, table: str, columns: Optional[Iterable[str]] = None) -> None:
         """ Copy all collected data into the given table.
         """
         if self.buffer.tell() > 0:
         """ Copy all collected data into the given table.
         """
         if self.buffer.tell() > 0: