X-Git-Url: https://git.openstreetmap.org./nominatim.git/blobdiff_plain/300612c5a8ebfa9eff99b7f88cfcf4e0ed2fbbfc..2ca83efc36a96cfa070be61c7422d255044130f3:/nominatim/tools/tiger_data.py diff --git a/nominatim/tools/tiger_data.py b/nominatim/tools/tiger_data.py index e78dcd8f..4a32bb1e 100644 --- a/nominatim/tools/tiger_data.py +++ b/nominatim/tools/tiger_data.py @@ -7,6 +7,7 @@ """ Functions for importing tiger data and handling tarbar and directory files """ +from typing import Any, TextIO, List, Union, cast import csv import io import logging @@ -15,11 +16,13 @@ import tarfile from psycopg2.extras import Json +from nominatim.config import Configuration from nominatim.db.connection import connect from nominatim.db.async_connection import WorkerPool from nominatim.db.sql_preprocessor import SQLPreprocessor from nominatim.errors import UsageError from nominatim.data.place_info import PlaceInfo +from nominatim.tokenizer.base import AbstractAnalyzer, AbstractTokenizer LOG = logging.getLogger() @@ -28,9 +31,9 @@ class TigerInput: either be in a directory or gzipped together in a tar file. """ - def __init__(self, data_dir): + def __init__(self, data_dir: str) -> None: self.tar_handle = None - self.files = [] + self.files: List[Union[str, tarfile.TarInfo]] = [] if data_dir.endswith('.tar.gz'): try: @@ -50,33 +53,36 @@ class TigerInput: LOG.warning("Tiger data import selected but no files found at %s", data_dir) - def __enter__(self): + def __enter__(self) -> 'TigerInput': return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: if self.tar_handle: self.tar_handle.close() self.tar_handle = None - def next_file(self): + def next_file(self) -> TextIO: """ Return a file handle to the next file to be processed. Raises an IndexError if there is no file left. """ fname = self.files.pop(0) if self.tar_handle is not None: - return io.TextIOWrapper(self.tar_handle.extractfile(fname)) + extracted = self.tar_handle.extractfile(fname) + assert extracted is not None + return io.TextIOWrapper(extracted) - return open(fname, encoding='utf-8') + return open(cast(str, fname), encoding='utf-8') - def __len__(self): + def __len__(self) -> int: return len(self.files) -def handle_threaded_sql_statements(pool, fd, analyzer): +def handle_threaded_sql_statements(pool: WorkerPool, fd: TextIO, + analyzer: AbstractAnalyzer) -> None: """ Handles sql statement with multiplexing """ lines = 0 @@ -101,14 +107,15 @@ def handle_threaded_sql_statements(pool, fd, analyzer): lines = 0 -def add_tiger_data(data_dir, config, threads, tokenizer): +def add_tiger_data(data_dir: str, config: Configuration, threads: int, + tokenizer: AbstractTokenizer) -> int: """ Import tiger data from directory or tar file `data dir`. """ dsn = config.get_libpq_dsn() with TigerInput(data_dir) as tar: if not tar: - return + return 1 with connect(dsn) as conn: sql = SQLPreprocessor(conn, config) @@ -130,3 +137,5 @@ def add_tiger_data(data_dir, config, threads, tokenizer): with connect(dsn) as conn: sql = SQLPreprocessor(conn, config) sql.run_sql_file(conn, 'tiger_import_finish.sql') + + return 0