]> git.openstreetmap.org Git - nominatim.git/blobdiff - nominatim/tools/tiger_data.py
Merge pull request #2562 from lonvia/copyright-headers
[nominatim.git] / nominatim / tools / tiger_data.py
index c1de3615f7f953a65984b2bf0d644552b4d30cf0..8610880ff9f8f8104c1bf96f0dc144d4a7f957ef 100644 (file)
@@ -1,15 +1,23 @@
+# SPDX-License-Identifier: GPL-2.0-only
+#
+# This file is part of Nominatim. (https://nominatim.org)
+#
+# Copyright (C) 2022 by the Nominatim developer community.
+# For a full list of authors see the git log.
 """
 Functions for importing tiger data and handling tarbar and directory files
 """
 """
 Functions for importing tiger data and handling tarbar and directory files
 """
+import csv
+import io
 import logging
 import os
 import tarfile
 import logging
 import os
 import tarfile
-import selectors
 
 from nominatim.db.connection import connect
 
 from nominatim.db.connection import connect
-from nominatim.db.async_connection import DBConnection
+from nominatim.db.async_connection import WorkerPool
 from nominatim.db.sql_preprocessor import SQLPreprocessor
 from nominatim.db.sql_preprocessor import SQLPreprocessor
-
+from nominatim.errors import UsageError
+from nominatim.indexer.place_info import PlaceInfo
 
 LOG = logging.getLogger()
 
 
 LOG = logging.getLogger()
 
@@ -20,101 +28,86 @@ def handle_tarfile_or_directory(data_dir):
 
     tar = None
     if data_dir.endswith('.tar.gz'):
 
     tar = None
     if data_dir.endswith('.tar.gz'):
-        tar = tarfile.open(data_dir)
-        sql_files = [i for i in tar.getmembers() if i.name.endswith('.sql')]
-        LOG.warning("Found %d SQL files in tarfile with path %s", len(sql_files), data_dir)
-        if not sql_files:
+        try:
+            tar = tarfile.open(data_dir)
+        except tarfile.ReadError as err:
+            LOG.fatal("Cannot open '%s'. Is this a tar file?", data_dir)
+            raise UsageError("Cannot open Tiger data file.") from err
+
+        csv_files = [i for i in tar.getmembers() if i.name.endswith('.csv')]
+        LOG.warning("Found %d CSV files in tarfile with path %s", len(csv_files), data_dir)
+        if not csv_files:
             LOG.warning("Tiger data import selected but no files in tarfile's path %s", data_dir)
             return None, None
     else:
         files = os.listdir(data_dir)
             LOG.warning("Tiger data import selected but no files in tarfile's path %s", data_dir)
             return None, None
     else:
         files = os.listdir(data_dir)
-        sql_files = [os.path.join(data_dir, i) for i in files if i.endswith('.sql')]
-        LOG.warning("Found %d SQL files in path %s", len(sql_files), data_dir)
-        if not sql_files:
+        csv_files = [os.path.join(data_dir, i) for i in files if i.endswith('.csv')]
+        LOG.warning("Found %d CSV files in path %s", len(csv_files), data_dir)
+        if not csv_files:
             LOG.warning("Tiger data import selected but no files found in path %s", data_dir)
             return None, None
 
             LOG.warning("Tiger data import selected but no files found in path %s", data_dir)
             return None, None
 
-    return sql_files, tar
+    return csv_files, tar
 
 
 
 
-def handle_threaded_sql_statements(sel, file):
+def handle_threaded_sql_statements(pool, fd, analyzer):
     """ Handles sql statement with multiplexing
     """
     """ Handles sql statement with multiplexing
     """
-
     lines = 0
     lines = 0
-    end_of_file = False
     # Using pool of database connections to execute sql statements
     # Using pool of database connections to execute sql statements
-    while not end_of_file:
-        for key, _ in sel.select(1):
-            conn = key.data
-            try:
-                if conn.is_done():
-                    sql_query = file.readline()
-                    lines += 1
-                    if not sql_query:
-                        end_of_file = True
-                        break
-                    conn.perform(sql_query)
-                    if lines == 1000:
-                        print('. ', end='', flush=True)
-                        lines = 0
-            except Exception as exc: # pylint: disable=broad-except
-                LOG.info('Wrong SQL statement: %s', exc)
-
-def handle_unregister_connection_pool(sel, place_threads):
-    """ Handles unregistering pool of connections
-    """
 
 
-    while place_threads > 0:
-        for key, _ in sel.select(1):
-            conn = key.data
-            sel.unregister(conn)
-            try:
-                conn.wait()
-            except Exception as exc: # pylint: disable=broad-except
-                LOG.info('Wrong SQL statement: %s', exc)
-            conn.close()
-            place_threads -= 1
-
-def add_tiger_data(dsn, data_dir, threads, config, sqllib_dir):
-    """ Import tiger data from directory or tar file
-    """
+    sql = "SELECT tiger_line_import(%s, %s, %s, %s, %s, %s)"
+
+    for row in csv.DictReader(fd, delimiter=';'):
+        try:
+            address = dict(street=row['street'], postcode=row['postcode'])
+            args = ('SRID=4326;' + row['geometry'],
+                    int(row['from']), int(row['to']), row['interpolation'],
+                    PlaceInfo({'address': address}).analyze(analyzer),
+                    analyzer.normalize_postcode(row['postcode']))
+        except ValueError:
+            continue
+        pool.next_free_worker().perform(sql, args=args)
 
 
-    sql_files, tar = handle_tarfile_or_directory(data_dir)
+        lines += 1
+        if lines == 1000:
+            print('.', end='', flush=True)
+            lines = 0
 
 
-    if not sql_files:
+
+def add_tiger_data(data_dir, config, threads, tokenizer):
+    """ Import tiger data from directory or tar file `data dir`.
+    """
+    dsn = config.get_libpq_dsn()
+    files, tar = handle_tarfile_or_directory(data_dir)
+
+    if not files:
         return
 
     with connect(dsn) as conn:
         return
 
     with connect(dsn) as conn:
-        sql = SQLPreprocessor(conn, config, sqllib_dir)
+        sql = SQLPreprocessor(conn, config)
         sql.run_sql_file(conn, 'tiger_import_start.sql')
 
         sql.run_sql_file(conn, 'tiger_import_start.sql')
 
-    # Reading sql_files and then for each file line handling
+    # Reading files and then for each file line handling
     # sql_query in <threads - 1> chunks.
     # sql_query in <threads - 1> chunks.
-    sel = selectors.DefaultSelector()
     place_threads = max(1, threads - 1)
 
     place_threads = max(1, threads - 1)
 
-    # Creates a pool of database connections
-    for _ in range(place_threads):
-        conn = DBConnection(dsn)
-        conn.connect()
-        sel.register(conn, selectors.EVENT_WRITE, conn)
-
-    for sql_file in sql_files:
-        if not tar:
-            file = open(sql_file)
-        else:
-            file = tar.extractfile(sql_file)
+    with WorkerPool(dsn, place_threads, ignore_sql_errors=True) as pool:
+        with tokenizer.name_analyzer() as analyzer:
+            for fname in files:
+                if not tar:
+                    fd = open(fname)
+                else:
+                    fd = io.TextIOWrapper(tar.extractfile(fname))
 
 
-        handle_threaded_sql_statements(sel, file)
+                handle_threaded_sql_statements(pool, fd, analyzer)
 
 
-    # Unregistering pool of database connections
-    handle_unregister_connection_pool(sel, place_threads)
+                fd.close()
 
     if tar:
         tar.close()
     print('\n')
     LOG.warning("Creating indexes on Tiger data")
     with connect(dsn) as conn:
 
     if tar:
         tar.close()
     print('\n')
     LOG.warning("Creating indexes on Tiger data")
     with connect(dsn) as conn:
-        sql = SQLPreprocessor(conn, config, sqllib_dir)
+        sql = SQLPreprocessor(conn, config)
         sql.run_sql_file(conn, 'tiger_import_finish.sql')
         sql.run_sql_file(conn, 'tiger_import_finish.sql')