]> git.openstreetmap.org Git - nominatim.git/blob - src/nominatim_db/tools/tiger_data.py
port code to psycopg3
[nominatim.git] / src / nominatim_db / tools / tiger_data.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2024 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Functions for importing tiger data and handling tarbar and directory files
9 """
10 from typing import Any, TextIO, List, Union, cast, Iterator, Dict
11 import csv
12 import io
13 import logging
14 import os
15 import tarfile
16
17 from psycopg.types.json import Json
18
19 from ..config import Configuration
20 from ..db.connection import connect
21 from ..db.sql_preprocessor import SQLPreprocessor
22 from ..errors import UsageError
23 from ..db.query_pool import QueryPool
24 from ..data.place_info import PlaceInfo
25 from ..tokenizer.base import AbstractTokenizer
26 from . import freeze
27
28 LOG = logging.getLogger()
29
30 class TigerInput:
31     """ Context manager that goes through Tiger input files which may
32         either be in a directory or gzipped together in a tar file.
33     """
34
35     def __init__(self, data_dir: str) -> None:
36         self.tar_handle = None
37         self.files: List[Union[str, tarfile.TarInfo]] = []
38
39         if data_dir.endswith('.tar.gz'):
40             try:
41                 self.tar_handle = tarfile.open(data_dir) # pylint: disable=consider-using-with
42             except tarfile.ReadError as err:
43                 LOG.fatal("Cannot open '%s'. Is this a tar file?", data_dir)
44                 raise UsageError("Cannot open Tiger data file.") from err
45
46             self.files = [i for i in self.tar_handle.getmembers() if i.name.endswith('.csv')]
47             LOG.warning("Found %d CSV files in tarfile with path %s", len(self.files), data_dir)
48         else:
49             files = os.listdir(data_dir)
50             self.files = [os.path.join(data_dir, i) for i in files if i.endswith('.csv')]
51             LOG.warning("Found %d CSV files in path %s", len(self.files), data_dir)
52
53         if not self.files:
54             LOG.warning("Tiger data import selected but no files found at %s", data_dir)
55
56
57     def __enter__(self) -> 'TigerInput':
58         return self
59
60
61     def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
62         if self.tar_handle:
63             self.tar_handle.close()
64             self.tar_handle = None
65
66     def __bool__(self) -> bool:
67         return bool(self.files)
68
69     def get_file(self, fname: Union[str, tarfile.TarInfo]) -> TextIO:
70         """ Return a file handle to the next file to be processed.
71             Raises an IndexError if there is no file left.
72         """
73         if self.tar_handle is not None:
74             extracted = self.tar_handle.extractfile(fname)
75             assert extracted is not None
76             return io.TextIOWrapper(extracted)
77
78         return open(cast(str, fname), encoding='utf-8')
79
80
81     def __iter__(self) -> Iterator[Dict[str, Any]]:
82         """ Iterate over the lines in each file.
83         """
84         for fname in self.files:
85             fd = self.get_file(fname)
86             yield from csv.DictReader(fd, delimiter=';')
87
88
89 async def add_tiger_data(data_dir: str, config: Configuration, threads: int,
90                    tokenizer: AbstractTokenizer) -> int:
91     """ Import tiger data from directory or tar file `data dir`.
92     """
93     dsn = config.get_libpq_dsn()
94
95     with connect(dsn) as conn:
96         if freeze.is_frozen(conn):
97             raise UsageError("Tiger cannot be imported when database frozen (Github issue #3048)")
98
99     with TigerInput(data_dir) as tar:
100         if not tar:
101             return 1
102
103         with connect(dsn) as conn:
104             sql = SQLPreprocessor(conn, config)
105             sql.run_sql_file(conn, 'tiger_import_start.sql')
106
107         # Reading files and then for each file line handling
108         # sql_query in <threads - 1> chunks.
109         place_threads = max(1, threads - 1)
110
111         async with QueryPool(dsn, place_threads, autocommit=True) as pool:
112             with tokenizer.name_analyzer() as analyzer:
113                 lines = 0
114                 for row in tar:
115                     try:
116                         address = dict(street=row['street'], postcode=row['postcode'])
117                         args = ('SRID=4326;' + row['geometry'],
118                                 int(row['from']), int(row['to']), row['interpolation'],
119                                 Json(analyzer.process_place(PlaceInfo({'address': address}))),
120                                 analyzer.normalize_postcode(row['postcode']))
121                     except ValueError:
122                         continue
123
124                     await pool.put_query(
125                         """SELECT tiger_line_import(%s::GEOMETRY, %s::INT,
126                                                     %s::INT, %s::TEXT, %s::JSONB, %s::TEXT)""",
127                         args)
128
129                     lines += 1
130                     if lines == 1000:
131                         print('.', end='', flush=True)
132                     lines = 0
133
134         print('', flush=True)
135
136     LOG.warning("Creating indexes on Tiger data")
137     with connect(dsn) as conn:
138         sql = SQLPreprocessor(conn, config)
139         sql.run_sql_file(conn, 'tiger_import_finish.sql')
140
141     return 0