]> git.openstreetmap.org Git - nominatim.git/blob - nominatim/tools/tiger_data.py
adapt annotations for SQLAlchemy 2.x
[nominatim.git] / nominatim / tools / tiger_data.py
1 # SPDX-License-Identifier: GPL-2.0-only
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2022 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Functions for importing tiger data and handling tarbar and directory files
9 """
10 from typing import Any, TextIO, List, Union, cast
11 import csv
12 import io
13 import logging
14 import os
15 import tarfile
16
17 from psycopg2.extras import Json
18
19 from nominatim.config import Configuration
20 from nominatim.db.connection import connect
21 from nominatim.db.async_connection import WorkerPool
22 from nominatim.db.sql_preprocessor import SQLPreprocessor
23 from nominatim.errors import UsageError
24 from nominatim.data.place_info import PlaceInfo
25 from nominatim.tokenizer.base import AbstractAnalyzer, AbstractTokenizer
26
27 LOG = logging.getLogger()
28
29 class TigerInput:
30     """ Context manager that goes through Tiger input files which may
31         either be in a directory or gzipped together in a tar file.
32     """
33
34     def __init__(self, data_dir: str) -> None:
35         self.tar_handle = None
36         self.files: List[Union[str, tarfile.TarInfo]] = []
37
38         if data_dir.endswith('.tar.gz'):
39             try:
40                 self.tar_handle = tarfile.open(data_dir) # pylint: disable=consider-using-with
41             except tarfile.ReadError as err:
42                 LOG.fatal("Cannot open '%s'. Is this a tar file?", data_dir)
43                 raise UsageError("Cannot open Tiger data file.") from err
44
45             self.files = [i for i in self.tar_handle.getmembers() if i.name.endswith('.csv')]
46             LOG.warning("Found %d CSV files in tarfile with path %s", len(self.files), data_dir)
47         else:
48             files = os.listdir(data_dir)
49             self.files = [os.path.join(data_dir, i) for i in files if i.endswith('.csv')]
50             LOG.warning("Found %d CSV files in path %s", len(self.files), data_dir)
51
52         if not self.files:
53             LOG.warning("Tiger data import selected but no files found at %s", data_dir)
54
55
56     def __enter__(self) -> 'TigerInput':
57         return self
58
59
60     def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
61         if self.tar_handle:
62             self.tar_handle.close()
63             self.tar_handle = None
64
65
66     def next_file(self) -> TextIO:
67         """ Return a file handle to the next file to be processed.
68             Raises an IndexError if there is no file left.
69         """
70         fname = self.files.pop(0)
71
72         if self.tar_handle is not None:
73             extracted = self.tar_handle.extractfile(fname)
74             assert extracted is not None
75             return io.TextIOWrapper(extracted)
76
77         return open(cast(str, fname), encoding='utf-8')
78
79
80     def __len__(self) -> int:
81         return len(self.files)
82
83
84 def handle_threaded_sql_statements(pool: WorkerPool, fd: TextIO,
85                                    analyzer: AbstractAnalyzer) -> None:
86     """ Handles sql statement with multiplexing
87     """
88     lines = 0
89     # Using pool of database connections to execute sql statements
90
91     sql = "SELECT tiger_line_import(%s, %s, %s, %s, %s, %s)"
92
93     for row in csv.DictReader(fd, delimiter=';'):
94         try:
95             address = dict(street=row['street'], postcode=row['postcode'])
96             args = ('SRID=4326;' + row['geometry'],
97                     int(row['from']), int(row['to']), row['interpolation'],
98                     Json(analyzer.process_place(PlaceInfo({'address': address}))),
99                     analyzer.normalize_postcode(row['postcode']))
100         except ValueError:
101             continue
102         pool.next_free_worker().perform(sql, args=args)
103
104         lines += 1
105         if lines == 1000:
106             print('.', end='', flush=True)
107             lines = 0
108
109
110 def add_tiger_data(data_dir: str, config: Configuration, threads: int,
111                    tokenizer: AbstractTokenizer) -> int:
112     """ Import tiger data from directory or tar file `data dir`.
113     """
114     dsn = config.get_libpq_dsn()
115
116     with TigerInput(data_dir) as tar:
117         if not tar:
118             return 1
119
120         with connect(dsn) as conn:
121             sql = SQLPreprocessor(conn, config)
122             sql.run_sql_file(conn, 'tiger_import_start.sql')
123
124         # Reading files and then for each file line handling
125         # sql_query in <threads - 1> chunks.
126         place_threads = max(1, threads - 1)
127
128         with WorkerPool(dsn, place_threads, ignore_sql_errors=True) as pool:
129             with tokenizer.name_analyzer() as analyzer:
130                 while tar:
131                     with tar.next_file() as fd:
132                         handle_threaded_sql_statements(pool, fd, analyzer)
133
134         print('\n')
135
136     LOG.warning("Creating indexes on Tiger data")
137     with connect(dsn) as conn:
138         sql = SQLPreprocessor(conn, config)
139         sql.run_sql_file(conn, 'tiger_import_finish.sql')
140
141     return 0