git.openstreetmap.org Git - nominatim.git/blobdiff - src/nominatim_db/tools/tiger_data.py
ensure consistent country assignments
[nominatim.git] / src / nominatim_db / tools / tiger_data.py
index 6030a38aef8fdf63f73a699a59bf00fa7647bb83..f4a7eba770b618df52b91f33946a4b761110b52e 100644 (file)
@@ -7,22 +7,22 @@
 """
 Functions for importing tiger data and handling tarball and directory files
 """
-from typing import Any, TextIO, List, Union, cast
+from typing import Any, TextIO, List, Union, cast, Iterator, Dict
 import csv
 import io
 import logging
 import os
 import tarfile
 
-from psycopg2.extras import Json
+from psycopg.types.json import Json
 
-from nominatim_core.config import Configuration
-from nominatim_core.db.connection import connect
-from nominatim_core.db.async_connection import WorkerPool
-from nominatim_core.db.sql_preprocessor import SQLPreprocessor
-from nominatim_core.errors import UsageError
+from ..config import Configuration
+from ..db.connection import connect
+from ..db.sql_preprocessor import SQLPreprocessor
+from ..errors import UsageError
+from ..db.query_pool import QueryPool
 from ..data.place_info import PlaceInfo
-from ..tokenizer.base import AbstractAnalyzer, AbstractTokenizer
+from ..tokenizer.base import AbstractTokenizer
 from . import freeze
 
 LOG = logging.getLogger()
@@ -63,13 +63,13 @@ class TigerInput:
             self.tar_handle.close()
             self.tar_handle = None
 
+    def __bool__(self) -> bool:
+        return len(self.files) > 0  # truthy when at least one input file is listed
 
-    def next_file(self) -> TextIO:
+    def get_file(self, fname: Union[str, tarfile.TarInfo]) -> TextIO:
         """ Return a file handle to the file `fname`, read from the tar
             archive when one is open and from the file system otherwise.
         """
-        fname = self.files.pop(0)
-
         if self.tar_handle is not None:
             extracted = self.tar_handle.extractfile(fname)
             assert extracted is not None
@@ -78,47 +78,22 @@ class TigerInput:
         return open(cast(str, fname), encoding='utf-8')
 
 
-    def __len__(self) -> int:
-        return len(self.files)
-
-
-def handle_threaded_sql_statements(pool: WorkerPool, fd: TextIO,
-                                   analyzer: AbstractAnalyzer) -> None:
-    """ Handles sql statement with multiplexing
-    """
-    lines = 0
-    # Using pool of database connections to execute sql statements
-
-    sql = "SELECT tiger_line_import(%s, %s, %s, %s, %s, %s)"
-
-    for row in csv.DictReader(fd, delimiter=';'):
-        try:
-            address = dict(street=row['street'], postcode=row['postcode'])
-            args = ('SRID=4326;' + row['geometry'],
-                    int(row['from']), int(row['to']), row['interpolation'],
-                    Json(analyzer.process_place(PlaceInfo({'address': address}))),
-                    analyzer.normalize_postcode(row['postcode']))
-        except ValueError:
-            continue
-        pool.next_free_worker().perform(sql, args=args)
-
-        lines += 1
-        if lines == 1000:
-            print('.', end='', flush=True)
-            lines = 0
+    def __iter__(self) -> Iterator[Dict[str, Any]]:
+        """ Iterate over the CSV rows of each file, closing every file after use.
+        """
+        for fname in self.files:
+            with self.get_file(fname) as fd:
+                yield from csv.DictReader(fd, delimiter=';')
 
 
-def add_tiger_data(data_dir: str, config: Configuration, threads: int,
+async def add_tiger_data(data_dir: str, config: Configuration, threads: int,
                    tokenizer: AbstractTokenizer) -> int:
     """ Import tiger data from directory or tar file `data dir`.
     """
     dsn = config.get_libpq_dsn()
 
     with connect(dsn) as conn:
-        is_frozen = freeze.is_frozen(conn)
-        conn.close()
-
-        if is_frozen:
+        if freeze.is_frozen(conn):
             raise UsageError("Tiger cannot be imported when database frozen (Github issue #3048)")
 
     with TigerInput(data_dir) as tar:
@@ -133,13 +108,30 @@ def add_tiger_data(data_dir: str, config: Configuration, threads: int,
         # sql_query in <threads - 1> chunks.
         place_threads = max(1, threads - 1)
 
-        with WorkerPool(dsn, place_threads, ignore_sql_errors=True) as pool:
+        async with QueryPool(dsn, place_threads, autocommit=True) as pool:
             with tokenizer.name_analyzer() as analyzer:
-                while tar:
-                    with tar.next_file() as fd:
-                        handle_threaded_sql_statements(pool, fd, analyzer)
-
-        print('\n')
+                lines = 0
+                for row in tar:
+                    try:
+                        address = dict(street=row['street'], postcode=row['postcode'])
+                        args = ('SRID=4326;' + row['geometry'],
+                                int(row['from']), int(row['to']), row['interpolation'],
+                                Json(analyzer.process_place(PlaceInfo({'address': address}))),
+                                analyzer.normalize_postcode(row['postcode']))
+                    except ValueError:
+                        continue  # skip rows whose values cannot be converted
+
+                    await pool.put_query(
+                        """SELECT tiger_line_import(%s::GEOMETRY, %s::INT,
+                                                    %s::INT, %s::TEXT, %s::JSONB, %s::TEXT)""",
+                        args)
+                    # Progress output: one dot per 1000 rows queued for import.
+                    lines += 1
+                    if lines == 1000:
+                        print('.', end='', flush=True)
+                        lines = 0
+
+        print('', flush=True)
 
     LOG.warning("Creating indexes on Tiger data")
     with connect(dsn) as conn: