]> git.openstreetmap.org Git - nominatim.git/blob - src/nominatim_db/tools/tiger_data.py
make DB helper functions free functions
[nominatim.git] / src / nominatim_db / tools / tiger_data.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2024 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Functions for importing tiger data and handling tarbar and directory files
9 """
10 from typing import Any, TextIO, List, Union, cast
11 import csv
12 import io
13 import logging
14 import os
15 import tarfile
16
17 from psycopg2.extras import Json
18
19 from ..config import Configuration
20 from ..db.connection import connect
21 from ..db.async_connection import WorkerPool
22 from ..db.sql_preprocessor import SQLPreprocessor
23 from ..errors import UsageError
24 from ..data.place_info import PlaceInfo
25 from ..tokenizer.base import AbstractAnalyzer, AbstractTokenizer
26 from . import freeze
27
28 LOG = logging.getLogger()
29
30 class TigerInput:
31     """ Context manager that goes through Tiger input files which may
32         either be in a directory or gzipped together in a tar file.
33     """
34
35     def __init__(self, data_dir: str) -> None:
36         self.tar_handle = None
37         self.files: List[Union[str, tarfile.TarInfo]] = []
38
39         if data_dir.endswith('.tar.gz'):
40             try:
41                 self.tar_handle = tarfile.open(data_dir) # pylint: disable=consider-using-with
42             except tarfile.ReadError as err:
43                 LOG.fatal("Cannot open '%s'. Is this a tar file?", data_dir)
44                 raise UsageError("Cannot open Tiger data file.") from err
45
46             self.files = [i for i in self.tar_handle.getmembers() if i.name.endswith('.csv')]
47             LOG.warning("Found %d CSV files in tarfile with path %s", len(self.files), data_dir)
48         else:
49             files = os.listdir(data_dir)
50             self.files = [os.path.join(data_dir, i) for i in files if i.endswith('.csv')]
51             LOG.warning("Found %d CSV files in path %s", len(self.files), data_dir)
52
53         if not self.files:
54             LOG.warning("Tiger data import selected but no files found at %s", data_dir)
55
56
57     def __enter__(self) -> 'TigerInput':
58         return self
59
60
61     def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
62         if self.tar_handle:
63             self.tar_handle.close()
64             self.tar_handle = None
65
66
67     def next_file(self) -> TextIO:
68         """ Return a file handle to the next file to be processed.
69             Raises an IndexError if there is no file left.
70         """
71         fname = self.files.pop(0)
72
73         if self.tar_handle is not None:
74             extracted = self.tar_handle.extractfile(fname)
75             assert extracted is not None
76             return io.TextIOWrapper(extracted)
77
78         return open(cast(str, fname), encoding='utf-8')
79
80
81     def __len__(self) -> int:
82         return len(self.files)
83
84
85 def handle_threaded_sql_statements(pool: WorkerPool, fd: TextIO,
86                                    analyzer: AbstractAnalyzer) -> None:
87     """ Handles sql statement with multiplexing
88     """
89     lines = 0
90     # Using pool of database connections to execute sql statements
91
92     sql = "SELECT tiger_line_import(%s, %s, %s, %s, %s, %s)"
93
94     for row in csv.DictReader(fd, delimiter=';'):
95         try:
96             address = dict(street=row['street'], postcode=row['postcode'])
97             args = ('SRID=4326;' + row['geometry'],
98                     int(row['from']), int(row['to']), row['interpolation'],
99                     Json(analyzer.process_place(PlaceInfo({'address': address}))),
100                     analyzer.normalize_postcode(row['postcode']))
101         except ValueError:
102             continue
103         pool.next_free_worker().perform(sql, args=args)
104
105         lines += 1
106         if lines == 1000:
107             print('.', end='', flush=True)
108             lines = 0
109
110
111 def add_tiger_data(data_dir: str, config: Configuration, threads: int,
112                    tokenizer: AbstractTokenizer) -> int:
113     """ Import tiger data from directory or tar file `data dir`.
114     """
115     dsn = config.get_libpq_dsn()
116
117     with connect(dsn) as conn:
118         is_frozen = freeze.is_frozen(conn)
119         conn.close()
120
121         if is_frozen:
122             raise UsageError("Tiger cannot be imported when database frozen (Github issue #3048)")
123
124     with TigerInput(data_dir) as tar:
125         if not tar:
126             return 1
127
128         with connect(dsn) as conn:
129             sql = SQLPreprocessor(conn, config)
130             sql.run_sql_file(conn, 'tiger_import_start.sql')
131
132         # Reading files and then for each file line handling
133         # sql_query in <threads - 1> chunks.
134         place_threads = max(1, threads - 1)
135
136         with WorkerPool(dsn, place_threads, ignore_sql_errors=True) as pool:
137             with tokenizer.name_analyzer() as analyzer:
138                 while tar:
139                     with tar.next_file() as fd:
140                         handle_threaded_sql_statements(pool, fd, analyzer)
141
142         print('\n')
143
144     LOG.warning("Creating indexes on Tiger data")
145     with connect(dsn) as conn:
146         sql = SQLPreprocessor(conn, config)
147         sql.run_sql_file(conn, 'tiger_import_finish.sql')
148
149     return 0