]> git.openstreetmap.org Git - nominatim.git/blob - src/nominatim_core/db/utils.py
adapt bdd tests to new layout
[nominatim.git] / src / nominatim_core / db / utils.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2024 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Helper functions for handling DB accesses.
9 """
10 from typing import IO, Optional, Union, Any, Iterable
11 import subprocess
12 import logging
13 import gzip
14 import io
15 from pathlib import Path
16
17 from .connection import get_pg_env, Cursor
18 from ..errors import UsageError
19
20 LOG = logging.getLogger()
21
22 def _pipe_to_proc(proc: 'subprocess.Popen[bytes]',
23                   fdesc: Union[IO[bytes], gzip.GzipFile]) -> int:
24     assert proc.stdin is not None
25     chunk = fdesc.read(2048)
26     while chunk and proc.poll() is None:
27         try:
28             proc.stdin.write(chunk)
29         except BrokenPipeError as exc:
30             raise UsageError("Failed to execute SQL file.") from exc
31         chunk = fdesc.read(2048)
32
33     return len(chunk)
34
35 def execute_file(dsn: str, fname: Path,
36                  ignore_errors: bool = False,
37                  pre_code: Optional[str] = None,
38                  post_code: Optional[str] = None) -> None:
39     """ Read an SQL file and run its contents against the given database
40         using psql. Use `pre_code` and `post_code` to run extra commands
41         before or after executing the file. The commands are run within the
42         same session, so they may be used to wrap the file execution in a
43         transaction.
44     """
45     cmd = ['psql']
46     if not ignore_errors:
47         cmd.extend(('-v', 'ON_ERROR_STOP=1'))
48     if not LOG.isEnabledFor(logging.INFO):
49         cmd.append('--quiet')
50
51     with subprocess.Popen(cmd, env=get_pg_env(dsn), stdin=subprocess.PIPE) as proc:
52         assert proc.stdin is not None
53         try:
54             if not LOG.isEnabledFor(logging.INFO):
55                 proc.stdin.write('set client_min_messages to WARNING;'.encode('utf-8'))
56
57             if pre_code:
58                 proc.stdin.write((pre_code + ';').encode('utf-8'))
59
60             if fname.suffix == '.gz':
61                 with gzip.open(str(fname), 'rb') as fdesc:
62                     remain = _pipe_to_proc(proc, fdesc)
63             else:
64                 with fname.open('rb') as fdesc:
65                     remain = _pipe_to_proc(proc, fdesc)
66
67             if remain == 0 and post_code:
68                 proc.stdin.write((';' + post_code).encode('utf-8'))
69         finally:
70             proc.stdin.close()
71             ret = proc.wait()
72
73     if ret != 0 or remain > 0:
74         raise UsageError("Failed to execute SQL file.")
75
76
77 # List of characters that need to be quoted for the copy command.
78 _SQL_TRANSLATION = {ord('\\'): '\\\\',
79                     ord('\t'): '\\t',
80                     ord('\n'): '\\n'}
81
82
83 class CopyBuffer:
84     """ Data collector for the copy_from command.
85     """
86
87     def __init__(self) -> None:
88         self.buffer = io.StringIO()
89
90
91     def __enter__(self) -> 'CopyBuffer':
92         return self
93
94
95     def size(self) -> int:
96         """ Return the number of bytes the buffer currently contains.
97         """
98         return self.buffer.tell()
99
100     def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
101         if self.buffer is not None:
102             self.buffer.close()
103
104
105     def add(self, *data: Any) -> None:
106         """ Add another row of data to the copy buffer.
107         """
108         first = True
109         for column in data:
110             if first:
111                 first = False
112             else:
113                 self.buffer.write('\t')
114             if column is None:
115                 self.buffer.write('\\N')
116             else:
117                 self.buffer.write(str(column).translate(_SQL_TRANSLATION))
118         self.buffer.write('\n')
119
120
121     def copy_out(self, cur: Cursor, table: str, columns: Optional[Iterable[str]] = None) -> None:
122         """ Copy all collected data into the given table.
123
124             The buffer is empty and reusable after this operation.
125         """
126         if self.buffer.tell() > 0:
127             self.buffer.seek(0)
128             cur.copy_from(self.buffer, table, columns=columns)
129             self.buffer = io.StringIO()