]> git.openstreetmap.org Git - nominatim.git/blob - nominatim/tools/tiger_data.py
move complex typing annotations to extra file
[nominatim.git] / nominatim / tools / tiger_data.py
1 # SPDX-License-Identifier: GPL-2.0-only
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2022 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Functions for importing tiger data and handling tarbar and directory files
9 """
10 import csv
11 import io
12 import logging
13 import os
14 import tarfile
15
16 from psycopg2.extras import Json
17
18 from nominatim.db.connection import connect
19 from nominatim.db.async_connection import WorkerPool
20 from nominatim.db.sql_preprocessor import SQLPreprocessor
21 from nominatim.errors import UsageError
22 from nominatim.data.place_info import PlaceInfo
23
24 LOG = logging.getLogger()
25
26 class TigerInput:
27     """ Context manager that goes through Tiger input files which may
28         either be in a directory or gzipped together in a tar file.
29     """
30
31     def __init__(self, data_dir):
32         self.tar_handle = None
33         self.files = []
34
35         if data_dir.endswith('.tar.gz'):
36             try:
37                 self.tar_handle = tarfile.open(data_dir) # pylint: disable=consider-using-with
38             except tarfile.ReadError as err:
39                 LOG.fatal("Cannot open '%s'. Is this a tar file?", data_dir)
40                 raise UsageError("Cannot open Tiger data file.") from err
41
42             self.files = [i for i in self.tar_handle.getmembers() if i.name.endswith('.csv')]
43             LOG.warning("Found %d CSV files in tarfile with path %s", len(self.files), data_dir)
44         else:
45             files = os.listdir(data_dir)
46             self.files = [os.path.join(data_dir, i) for i in files if i.endswith('.csv')]
47             LOG.warning("Found %d CSV files in path %s", len(self.files), data_dir)
48
49         if not self.files:
50             LOG.warning("Tiger data import selected but no files found at %s", data_dir)
51
52
53     def __enter__(self):
54         return self
55
56
57     def __exit__(self, exc_type, exc_val, exc_tb):
58         if self.tar_handle:
59             self.tar_handle.close()
60             self.tar_handle = None
61
62
63     def next_file(self):
64         """ Return a file handle to the next file to be processed.
65             Raises an IndexError if there is no file left.
66         """
67         fname = self.files.pop(0)
68
69         if self.tar_handle is not None:
70             return io.TextIOWrapper(self.tar_handle.extractfile(fname))
71
72         return open(fname, encoding='utf-8')
73
74
75     def __len__(self):
76         return len(self.files)
77
78
79 def handle_threaded_sql_statements(pool, fd, analyzer):
80     """ Handles sql statement with multiplexing
81     """
82     lines = 0
83     # Using pool of database connections to execute sql statements
84
85     sql = "SELECT tiger_line_import(%s, %s, %s, %s, %s, %s)"
86
87     for row in csv.DictReader(fd, delimiter=';'):
88         try:
89             address = dict(street=row['street'], postcode=row['postcode'])
90             args = ('SRID=4326;' + row['geometry'],
91                     int(row['from']), int(row['to']), row['interpolation'],
92                     Json(analyzer.process_place(PlaceInfo({'address': address}))),
93                     analyzer.normalize_postcode(row['postcode']))
94         except ValueError:
95             continue
96         pool.next_free_worker().perform(sql, args=args)
97
98         lines += 1
99         if lines == 1000:
100             print('.', end='', flush=True)
101             lines = 0
102
103
104 def add_tiger_data(data_dir, config, threads, tokenizer):
105     """ Import tiger data from directory or tar file `data dir`.
106     """
107     dsn = config.get_libpq_dsn()
108
109     with TigerInput(data_dir) as tar:
110         if not tar:
111             return
112
113         with connect(dsn) as conn:
114             sql = SQLPreprocessor(conn, config)
115             sql.run_sql_file(conn, 'tiger_import_start.sql')
116
117         # Reading files and then for each file line handling
118         # sql_query in <threads - 1> chunks.
119         place_threads = max(1, threads - 1)
120
121         with WorkerPool(dsn, place_threads, ignore_sql_errors=True) as pool:
122             with tokenizer.name_analyzer() as analyzer:
123                 while tar:
124                     with tar.next_file() as fd:
125                         handle_threaded_sql_statements(pool, fd, analyzer)
126
127         print('\n')
128
129     LOG.warning("Creating indexes on Tiger data")
130     with connect(dsn) as conn:
131         sql = SQLPreprocessor(conn, config)
132         sql.run_sql_file(conn, 'tiger_import_finish.sql')