]> git.openstreetmap.org Git - nominatim.git/blob - nominatim/db/utils.py
type annotations for DB utils
[nominatim.git] / nominatim / db / utils.py
1 # SPDX-License-Identifier: GPL-2.0-only
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2022 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Helper functions for handling DB accesses.
9 """
10 from typing import IO, Optional, Union
11 import subprocess
12 import logging
13 import gzip
14 import io
15 from pathlib import Path
16
17 from nominatim.db.connection import get_pg_env
18 from nominatim.errors import UsageError
19
20 LOG = logging.getLogger()
21
22 def _pipe_to_proc(proc: subprocess.Popen[bytes],
23                   fdesc: Union[IO[bytes], gzip.GzipFile]) -> int:
24     assert proc.stdin is not None
25     chunk = fdesc.read(2048)
26     while chunk and proc.poll() is None:
27         try:
28             proc.stdin.write(chunk)
29         except BrokenPipeError as exc:
30             raise UsageError("Failed to execute SQL file.") from exc
31         chunk = fdesc.read(2048)
32
33     return len(chunk)
34
35 def execute_file(dsn: str, fname: Path,
36                  ignore_errors: bool = False,
37                  pre_code: Optional[str] = None,
38                  post_code: Optional[str] = None) -> None:
39     """ Read an SQL file and run its contents against the given database
40         using psql. Use `pre_code` and `post_code` to run extra commands
41         before or after executing the file. The commands are run within the
42         same session, so they may be used to wrap the file execution in a
43         transaction.
44     """
45     cmd = ['psql']
46     if not ignore_errors:
47         cmd.extend(('-v', 'ON_ERROR_STOP=1'))
48     if not LOG.isEnabledFor(logging.INFO):
49         cmd.append('--quiet')
50
51     with subprocess.Popen(cmd, env=get_pg_env(dsn), stdin=subprocess.PIPE) as proc:
52         assert proc.stdin is not None
53         try:
54             if not LOG.isEnabledFor(logging.INFO):
55                 proc.stdin.write('set client_min_messages to WARNING;'.encode('utf-8'))
56
57             if pre_code:
58                 proc.stdin.write((pre_code + ';').encode('utf-8'))
59
60             if fname.suffix == '.gz':
61                 with gzip.open(str(fname), 'rb') as fdesc:
62                     remain = _pipe_to_proc(proc, fdesc)
63             else:
64                 with fname.open('rb') as fdesc:
65                     remain = _pipe_to_proc(proc, fdesc)
66
67             if remain == 0 and post_code:
68                 proc.stdin.write((';' + post_code).encode('utf-8'))
69         finally:
70             proc.stdin.close()
71             ret = proc.wait()
72
73     if ret != 0 or remain > 0:
74         raise UsageError("Failed to execute SQL file.")
75
76
77 # List of characters that need to be quoted for the copy command.
78 _SQL_TRANSLATION = {ord('\\'): '\\\\',
79                     ord('\t'): '\\t',
80                     ord('\n'): '\\n'}
81
82
83 class CopyBuffer:
84     """ Data collector for the copy_from command.
85     """
86
87     def __init__(self):
88         self.buffer = io.StringIO()
89
90
91     def __enter__(self):
92         return self
93
94
95     def __exit__(self, exc_type, exc_value, traceback):
96         if self.buffer is not None:
97             self.buffer.close()
98
99
100     def add(self, *data):
101         """ Add another row of data to the copy buffer.
102         """
103         first = True
104         for column in data:
105             if first:
106                 first = False
107             else:
108                 self.buffer.write('\t')
109             if column is None:
110                 self.buffer.write('\\N')
111             else:
112                 self.buffer.write(str(column).translate(_SQL_TRANSLATION))
113         self.buffer.write('\n')
114
115
116     def copy_out(self, cur, table, columns=None):
117         """ Copy all collected data into the given table.
118         """
119         if self.buffer.tell() > 0:
120             self.buffer.seek(0)
121             cur.copy_from(self.buffer, table, columns=columns)