X-Git-Url: https://git.openstreetmap.org./nominatim.git/blobdiff_plain/6e89310a9285f1ad15d8002bf68f578eada367a0..fd9437277e14e1911b2c011fb7dabf6fcbb2d8c4:/src/nominatim_db/tools/tiger_data.py diff --git a/src/nominatim_db/tools/tiger_data.py b/src/nominatim_db/tools/tiger_data.py index 6030a38a..f4a7eba7 100644 --- a/src/nominatim_db/tools/tiger_data.py +++ b/src/nominatim_db/tools/tiger_data.py @@ -7,22 +7,22 @@ """ Functions for importing tiger data and handling tarbar and directory files """ -from typing import Any, TextIO, List, Union, cast +from typing import Any, TextIO, List, Union, cast, Iterator, Dict import csv import io import logging import os import tarfile -from psycopg2.extras import Json +from psycopg.types.json import Json -from nominatim_core.config import Configuration -from nominatim_core.db.connection import connect -from nominatim_core.db.async_connection import WorkerPool -from nominatim_core.db.sql_preprocessor import SQLPreprocessor -from nominatim_core.errors import UsageError +from ..config import Configuration +from ..db.connection import connect +from ..db.sql_preprocessor import SQLPreprocessor +from ..errors import UsageError +from ..db.query_pool import QueryPool from ..data.place_info import PlaceInfo -from ..tokenizer.base import AbstractAnalyzer, AbstractTokenizer +from ..tokenizer.base import AbstractTokenizer from . import freeze LOG = logging.getLogger() @@ -63,13 +63,13 @@ class TigerInput: self.tar_handle.close() self.tar_handle = None + def __bool__(self) -> bool: + return bool(self.files) - def next_file(self) -> TextIO: + def get_file(self, fname: Union[str, tarfile.TarInfo]) -> TextIO: """ Return a file handle to the next file to be processed. Raises an IndexError if there is no file left. """ - fname = self.files.pop(0) - if self.tar_handle is not None: extracted = self.tar_handle.extractfile(fname) assert extracted is not None @@ -78,47 +78,22 @@ class TigerInput: return open(cast(str, fname), encoding='utf-8') - def __len__(self) -> int: - return len(self.files) - - -def handle_threaded_sql_statements(pool: WorkerPool, fd: TextIO, - analyzer: AbstractAnalyzer) -> None: - """ Handles sql statement with multiplexing - """ - lines = 0 - # Using pool of database connections to execute sql statements - - sql = "SELECT tiger_line_import(%s, %s, %s, %s, %s, %s)" - - for row in csv.DictReader(fd, delimiter=';'): - try: - address = dict(street=row['street'], postcode=row['postcode']) - args = ('SRID=4326;' + row['geometry'], - int(row['from']), int(row['to']), row['interpolation'], - Json(analyzer.process_place(PlaceInfo({'address': address}))), - analyzer.normalize_postcode(row['postcode'])) - except ValueError: - continue - pool.next_free_worker().perform(sql, args=args) - - lines += 1 - if lines == 1000: - print('.', end='', flush=True) - lines = 0 + def __iter__(self) -> Iterator[Dict[str, Any]]: + """ Iterate over the lines in each file. + """ + for fname in self.files: + fd = self.get_file(fname) + yield from csv.DictReader(fd, delimiter=';') -def add_tiger_data(data_dir: str, config: Configuration, threads: int, +async def add_tiger_data(data_dir: str, config: Configuration, threads: int, tokenizer: AbstractTokenizer) -> int: """ Import tiger data from directory or tar file `data dir`. """ dsn = config.get_libpq_dsn() with connect(dsn) as conn: - is_frozen = freeze.is_frozen(conn) - conn.close() - - if is_frozen: + if freeze.is_frozen(conn): raise UsageError("Tiger cannot be imported when database frozen (Github issue #3048)") with TigerInput(data_dir) as tar: @@ -133,13 +108,30 @@ def add_tiger_data(data_dir: str, config: Configuration, threads: int, # sql_query in chunks. place_threads = max(1, threads - 1) - with WorkerPool(dsn, place_threads, ignore_sql_errors=True) as pool: + async with QueryPool(dsn, place_threads, autocommit=True) as pool: with tokenizer.name_analyzer() as analyzer: - while tar: - with tar.next_file() as fd: - handle_threaded_sql_statements(pool, fd, analyzer) - - print('\n') + lines = 0 + for row in tar: + try: + address = dict(street=row['street'], postcode=row['postcode']) + args = ('SRID=4326;' + row['geometry'], + int(row['from']), int(row['to']), row['interpolation'], + Json(analyzer.process_place(PlaceInfo({'address': address}))), + analyzer.normalize_postcode(row['postcode'])) + except ValueError: + continue + + await pool.put_query( + """SELECT tiger_line_import(%s::GEOMETRY, %s::INT, + %s::INT, %s::TEXT, %s::JSONB, %s::TEXT)""", + args) + + lines += 1 + if lines == 1000: + print('.', end='', flush=True) + lines = 0 + + print('', flush=True) LOG.warning("Creating indexes on Tiger data") with connect(dsn) as conn: