]> git.openstreetmap.org Git - nominatim.git/blobdiff - nominatim/tools/tiger_data.py
add type annotations for Tiger import function
[nominatim.git] / nominatim / tools / tiger_data.py
index c655f91d73bf096f0af4fcf7cbc7ce75c4415bd0..4988e33c8552150dab9efbd56274046d939f6304 100644 (file)
+# SPDX-License-Identifier: GPL-2.0-only
+#
+# This file is part of Nominatim. (https://nominatim.org)
+#
+# Copyright (C) 2022 by the Nominatim developer community.
+# For a full list of authors see the git log.
 """
 Functions for importing tiger data and handling tarbar and directory files
 """
 """
 Functions for importing tiger data and handling tarbar and directory files
 """
+from typing import Any, TextIO, List, Union, cast
+import csv
+import io
 import logging
 import os
 import tarfile
 import logging
 import os
 import tarfile
-import selectors
 
 
-from ..db.connection import connect
-from ..db.async_connection import DBConnection
-from ..db.sql_preprocessor import SQLPreprocessor
+from psycopg2.extras import Json
 
 
+from nominatim.config import Configuration
+from nominatim.db.connection import connect
+from nominatim.db.async_connection import WorkerPool
+from nominatim.db.sql_preprocessor import SQLPreprocessor
+from nominatim.errors import UsageError
+from nominatim.data.place_info import PlaceInfo
+from nominatim.tokenizer.base import AbstractAnalyzer, AbstractTokenizer
 
 LOG = logging.getLogger()
 
 
 LOG = logging.getLogger()
 
-
-def handle_tarfile_or_directory(data_dir):
-    """ Handles tarfile or directory for importing tiger data
+class TigerInput:
+    """ Context manager that goes through Tiger input files which may
+        either be in a directory or gzipped together in a tar file.
     """
 
     """
 
-    tar = None
-    if data_dir.endswith('.tar.gz'):
-        tar = tarfile.open(data_dir)
-        sql_files = [i for i in tar.getmembers() if i.name.endswith('.sql')]
-        LOG.warning("Found %d SQL files in tarfile with path %s", len(sql_files), data_dir)
-        if not sql_files:
-            LOG.warning("Tiger data import selected but no files in tarfile's path %s", data_dir)
-            return None, None
-    else:
-        files = os.listdir(data_dir)
-        sql_files = [os.path.join(data_dir, i) for i in files if i.endswith('.sql')]
-        LOG.warning("Found %d SQL files in path %s", len(sql_files), data_dir)
-        if not sql_files:
-            LOG.warning("Tiger data import selected but no files found in path %s", data_dir)
-            return None, None
-
-    return sql_files, tar
-
-
-def handle_threaded_sql_statements(sel, file):
-    """ Handles sql statement with multiplexing
-    """
+    def __init__(self, data_dir: str) -> None:
+        self.tar_handle = None
+        self.files: List[Union[str, tarfile.TarInfo]] = []
 
 
-    lines = 0
-    end_of_file = False
-    # Using pool of database connections to execute sql statements
-    while not end_of_file:
-        for key, _ in sel.select(1):
-            conn = key.data
+        if data_dir.endswith('.tar.gz'):
             try:
             try:
-                if conn.is_done():
-                    sql_query = file.readline()
-                    lines += 1
-                    if not sql_query:
-                        end_of_file = True
-                        break
-                    conn.perform(sql_query)
-                    if lines == 1000:
-                        print('. ', end='', flush=True)
-                        lines = 0
-            except Exception as exc: # pylint: disable=broad-except
-                LOG.info('Wrong SQL statement: %s', exc)
-
-def handle_unregister_connection_pool(sel, place_threads):
-    """ Handles unregistering pool of connections
-    """
+                self.tar_handle = tarfile.open(data_dir) # pylint: disable=consider-using-with
+            except tarfile.ReadError as err:
+                LOG.fatal("Cannot open '%s'. Is this a tar file?", data_dir)
+                raise UsageError("Cannot open Tiger data file.") from err
 
 
-    while place_threads > 0:
-        for key, _ in sel.select(1):
-            conn = key.data
-            sel.unregister(conn)
-            try:
-                conn.wait()
-            except Exception as exc: # pylint: disable=broad-except
-                LOG.info('Wrong SQL statement: %s', exc)
-            conn.close()
-            place_threads -= 1
-
-def add_tiger_data(dsn, data_dir, threads, config, sqllib_dir):
-    """ Import tiger data from directory or tar file
+            self.files = [i for i in self.tar_handle.getmembers() if i.name.endswith('.csv')]
+            LOG.warning("Found %d CSV files in tarfile with path %s", len(self.files), data_dir)
+        else:
+            files = os.listdir(data_dir)
+            self.files = [os.path.join(data_dir, i) for i in files if i.endswith('.csv')]
+            LOG.warning("Found %d CSV files in path %s", len(self.files), data_dir)
+
+        if not self.files:
+            LOG.warning("Tiger data import selected but no files found at %s", data_dir)
+
+
+    def __enter__(self) -> 'TigerInput':
+        return self
+
+
+    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
+        if self.tar_handle:
+            self.tar_handle.close()
+            self.tar_handle = None
+
+
+    def next_file(self) -> TextIO:
+        """ Return a file handle to the next file to be processed.
+            Raises an IndexError if there is no file left.
+        """
+        fname = self.files.pop(0)
+
+        if self.tar_handle is not None:
+            extracted = self.tar_handle.extractfile(fname)
+            assert extracted is not None
+            return io.TextIOWrapper(extracted)
+
+        return open(cast(str, fname), encoding='utf-8')
+
+
+    def __len__(self) -> int:
+        return len(self.files)
+
+
+def handle_threaded_sql_statements(pool: WorkerPool, fd: TextIO,
+                                   analyzer: AbstractAnalyzer) -> None:
+    """ Handles sql statement with multiplexing
     """
     """
+    lines = 0
+    # Using pool of database connections to execute sql statements
 
 
-    sql_files, tar = handle_tarfile_or_directory(data_dir)
+    sql = "SELECT tiger_line_import(%s, %s, %s, %s, %s, %s)"
 
 
-    if not sql_files:
-        return
+    for row in csv.DictReader(fd, delimiter=';'):
+        try:
+            address = dict(street=row['street'], postcode=row['postcode'])
+            args = ('SRID=4326;' + row['geometry'],
+                    int(row['from']), int(row['to']), row['interpolation'],
+                    Json(analyzer.process_place(PlaceInfo({'address': address}))),
+                    analyzer.normalize_postcode(row['postcode']))
+        except ValueError:
+            continue
+        pool.next_free_worker().perform(sql, args=args)
 
 
-    with connect(dsn) as conn:
-        sql = SQLPreprocessor(conn, config, sqllib_dir)
-        sql.run_sql_file(conn, 'tiger_import_start.sql')
+        lines += 1
+        if lines == 1000:
+            print('.', end='', flush=True)
+            lines = 0
 
 
-    # Reading sql_files and then for each file line handling
-    # sql_query in <threads - 1> chunks.
-    sel = selectors.DefaultSelector()
-    place_threads = max(1, threads - 1)
 
 
-    # Creates a pool of database connections
-    for _ in range(place_threads):
-        conn = DBConnection(dsn)
-        conn.connect()
-        sel.register(conn, selectors.EVENT_WRITE, conn)
+def add_tiger_data(data_dir: str, config: Configuration, threads: int,
+                   tokenizer: AbstractTokenizer) -> None:
+    """ Import tiger data from directory or tar file `data dir`.
+    """
+    dsn = config.get_libpq_dsn()
 
 
-    for sql_file in sql_files:
+    with TigerInput(data_dir) as tar:
         if not tar:
         if not tar:
-            file = open(sql_file)
-        else:
-            file = tar.extractfile(sql_file)
+            return
+
+        with connect(dsn) as conn:
+            sql = SQLPreprocessor(conn, config)
+            sql.run_sql_file(conn, 'tiger_import_start.sql')
+
+        # Reading files and then for each file line handling
+        # sql_query in <threads - 1> chunks.
+        place_threads = max(1, threads - 1)
 
 
-        handle_threaded_sql_statements(sel, file)
+        with WorkerPool(dsn, place_threads, ignore_sql_errors=True) as pool:
+            with tokenizer.name_analyzer() as analyzer:
+                while tar:
+                    with tar.next_file() as fd:
+                        handle_threaded_sql_statements(pool, fd, analyzer)
 
 
-    # Unregistering pool of database connections
-    handle_unregister_connection_pool(sel, place_threads)
+        print('\n')
 
 
-    if tar:
-        tar.close()
-    print('\n')
     LOG.warning("Creating indexes on Tiger data")
     with connect(dsn) as conn:
     LOG.warning("Creating indexes on Tiger data")
     with connect(dsn) as conn:
-        sql = SQLPreprocessor(conn, config, sqllib_dir)
+        sql = SQLPreprocessor(conn, config)
         sql.run_sql_file(conn, 'tiger_import_finish.sql')
         sql.run_sql_file(conn, 'tiger_import_finish.sql')