]> git.openstreetmap.org Git - nominatim.git/blobdiff - nominatim/tools/tiger_data.py
Merge remote-tracking branch 'upstream/master'
[nominatim.git] / nominatim / tools / tiger_data.py
index 19a1268253feaa7ff1e2e6de20be3c43f1d74025..6e37df5e9df7d2b6b3228c4ba79b8a5350865386 100644 (file)
@@ -1,3 +1,9 @@
+# SPDX-License-Identifier: GPL-2.0-only
+#
+# This file is part of Nominatim. (https://nominatim.org)
+#
+# Copyright (C) 2022 by the Nominatim developer community.
+# For a full list of authors see the git log.
 """
 Functions for importing tiger data and handling tarbar and directory files
 """
 """
 Functions for importing tiger data and handling tarbar and directory files
 """
@@ -15,33 +21,57 @@ from nominatim.indexer.place_info import PlaceInfo
 
 LOG = logging.getLogger()
 
 
 LOG = logging.getLogger()
 
-
-def handle_tarfile_or_directory(data_dir):
-    """ Handles tarfile or directory for importing tiger data
+class TigerInput:
+    """ Context manager that goes through Tiger input files which may
+        either be in a directory or gzipped together in a tar file.
     """
 
     """
 
-    tar = None
-    if data_dir.endswith('.tar.gz'):
-        try:
-            tar = tarfile.open(data_dir)
-        except tarfile.ReadError as err:
-            LOG.fatal("Cannot open '%s'. Is this a tar file?", data_dir)
-            raise UsageError("Cannot open Tiger data file.") from err
-
-        csv_files = [i for i in tar.getmembers() if i.name.endswith('.csv')]
-        LOG.warning("Found %d CSV files in tarfile with path %s", len(csv_files), data_dir)
-        if not csv_files:
-            LOG.warning("Tiger data import selected but no files in tarfile's path %s", data_dir)
-            return None, None
-    else:
-        files = os.listdir(data_dir)
-        csv_files = [os.path.join(data_dir, i) for i in files if i.endswith('.csv')]
-        LOG.warning("Found %d CSV files in path %s", len(csv_files), data_dir)
-        if not csv_files:
-            LOG.warning("Tiger data import selected but no files found in path %s", data_dir)
-            return None, None
-
-    return csv_files, tar
+    def __init__(self, data_dir):
+        self.tar_handle = None
+        self.files = []
+
+        if data_dir.endswith('.tar.gz'):
+            try:
+                self.tar_handle = tarfile.open(data_dir) # pylint: disable=consider-using-with
+            except tarfile.ReadError as err:
+                LOG.fatal("Cannot open '%s'. Is this a tar file?", data_dir)
+                raise UsageError("Cannot open Tiger data file.") from err
+
+            self.files = [i for i in self.tar_handle.getmembers() if i.name.endswith('.csv')]
+            LOG.warning("Found %d CSV files in tarfile with path %s", len(self.files), data_dir)
+        else:
+            files = os.listdir(data_dir)
+            self.files = [os.path.join(data_dir, i) for i in files if i.endswith('.csv')]
+            LOG.warning("Found %d CSV files in path %s", len(self.files), data_dir)
+
+        if not self.files:
+            LOG.warning("Tiger data import selected but no files found at %s", data_dir)
+
+
+    def __enter__(self):
+        return self
+
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        if self.tar_handle:
+            self.tar_handle.close()
+            self.tar_handle = None
+
+
+    def next_file(self):
+        """ Return a file handle to the next file to be processed.
+            Raises an IndexError if there is no file left.
+        """
+        fname = self.files.pop(0)
+
+        if self.tar_handle is not None:
+            return io.TextIOWrapper(self.tar_handle.extractfile(fname))
+
+        return open(fname, encoding='utf-8')
+
+
+    def __len__(self):
+        return len(self.files)
 
 
 def handle_threaded_sql_statements(pool, fd, analyzer):
 
 
 def handle_threaded_sql_statements(pool, fd, analyzer):
@@ -73,34 +103,27 @@ def add_tiger_data(data_dir, config, threads, tokenizer):
     """ Import tiger data from directory or tar file `data dir`.
     """
     dsn = config.get_libpq_dsn()
     """ Import tiger data from directory or tar file `data dir`.
     """
     dsn = config.get_libpq_dsn()
-    files, tar = handle_tarfile_or_directory(data_dir)
-
-    if not files:
-        return
 
 
-    with connect(dsn) as conn:
-        sql = SQLPreprocessor(conn, config)
-        sql.run_sql_file(conn, 'tiger_import_start.sql')
+    with TigerInput(data_dir) as tar:
+        if not tar:
+            return
 
 
-    # Reading files and then for each file line handling
-    # sql_query in <threads - 1> chunks.
-    place_threads = max(1, threads - 1)
+        with connect(dsn) as conn:
+            sql = SQLPreprocessor(conn, config)
+            sql.run_sql_file(conn, 'tiger_import_start.sql')
 
 
-    with WorkerPool(dsn, place_threads, ignore_sql_errors=True) as pool:
-        with tokenizer.name_analyzer() as analyzer:
-            for fname in files:
-                if not tar:
-                    fd = open(fname)
-                else:
-                    fd = io.TextIOWrapper(tar.extractfile(fname))
+        # Reading files and then for each file line handling
+        # sql_query in <threads - 1> chunks.
+        place_threads = max(1, threads - 1)
 
 
-                handle_threaded_sql_statements(pool, fd, analyzer)
+        with WorkerPool(dsn, place_threads, ignore_sql_errors=True) as pool:
+            with tokenizer.name_analyzer() as analyzer:
+                while tar:
+                    with tar.next_file() as fd:
+                        handle_threaded_sql_statements(pool, fd, analyzer)
 
 
-                fd.close()
+        print('\n')
 
 
-    if tar:
-        tar.close()
-    print('\n')
     LOG.warning("Creating indexes on Tiger data")
     with connect(dsn) as conn:
         sql = SQLPreprocessor(conn, config)
     LOG.warning("Creating indexes on Tiger data")
     with connect(dsn) as conn:
         sql = SQLPreprocessor(conn, config)