]> git.openstreetmap.org Git - nominatim.git/commitdiff
indexer: make self.conn function-local
author Sarah Hoffmann <lonvia@denofr.de>
Mon, 19 Apr 2021 16:15:09 +0000 (18:15 +0200)
committer Sarah Hoffmann <lonvia@denofr.de>
Tue, 20 Apr 2021 12:08:37 +0000 (14:08 +0200)
Also switches to our internal connect() function, which gives us
a cursor with a scalar() function.

nominatim/indexer/indexer.py

index 7b826d96182eb69100b84339a6c4df75079bb5fe..aa1fb8efb80c4c6156f0547e1fa71c1e76cc72d4 100644 (file)
@@ -4,11 +4,10 @@ Main work horse for indexing (computing addresses) the database.
 import logging
 import select
 
 import logging
 import select
 
-import psycopg2
-
 from nominatim.indexer.progress import ProgressLogger
 from nominatim.indexer import runners
 from nominatim.db.async_connection import DBConnection
 from nominatim.indexer.progress import ProgressLogger
 from nominatim.indexer import runners
 from nominatim.db.async_connection import DBConnection
+from nominatim.db.connection import connect
 
 LOG = logging.getLogger()
 
 
 LOG = logging.getLogger()
 
@@ -20,20 +19,14 @@ class Indexer:
     def __init__(self, dsn, num_threads):
         self.dsn = dsn
         self.num_threads = num_threads
     def __init__(self, dsn, num_threads):
         self.dsn = dsn
         self.num_threads = num_threads
-        self.conn = None
         self.threads = []
 
 
     def _setup_connections(self):
         self.threads = []
 
 
     def _setup_connections(self):
-        self.conn = psycopg2.connect(self.dsn)
         self.threads = [DBConnection(self.dsn) for _ in range(self.num_threads)]
 
 
     def _close_connections(self):
         self.threads = [DBConnection(self.dsn) for _ in range(self.num_threads)]
 
 
     def _close_connections(self):
-        if self.conn:
-            self.conn.close()
-            self.conn = None
-
         for thread in self.threads:
             thread.close()
         self.threads = []
         for thread in self.threads:
             thread.close()
         self.threads = []
@@ -45,7 +38,7 @@ class Indexer:
             database will be analysed at the appropriate places to
             ensure that database statistics are updated.
         """
             database will be analysed at the appropriate places to
             ensure that database statistics are updated.
         """
-        with psycopg2.connect(self.dsn) as conn:
+        with connect(self.dsn) as conn:
             conn.autocommit = True
 
             if analyse:
             conn.autocommit = True
 
             if analyse:
@@ -128,15 +121,11 @@ class Indexer:
     def update_status_table(self):
         """ Update the status in the status table to 'indexed'.
         """
     def update_status_table(self):
         """ Update the status in the status table to 'indexed'.
         """
-        conn = psycopg2.connect(self.dsn)
-
-        try:
+        with connect(self.dsn) as conn:
             with conn.cursor() as cur:
                 cur.execute('UPDATE import_status SET indexed = true')
 
             conn.commit()
             with conn.cursor() as cur:
                 cur.execute('UPDATE import_status SET indexed = true')
 
             conn.commit()
-        finally:
-            conn.close()
 
     def _index(self, runner, batch=1):
         """ Index a single rank or table. `runner` describes the SQL to use
 
     def _index(self, runner, batch=1):
         """ Index a single rank or table. `runner` describes the SQL to use
@@ -145,36 +134,35 @@ class Indexer:
         """
         LOG.warning("Starting %s (using batch size %s)", runner.name(), batch)
 
         """
         LOG.warning("Starting %s (using batch size %s)", runner.name(), batch)
 
-        cur = self.conn.cursor()
-        cur.execute(runner.sql_count_objects())
-
-        total_tuples = cur.fetchone()[0]
-        LOG.debug("Total number of rows: %i", total_tuples)
+        with connect(self.dsn) as conn:
+            with conn.cursor() as cur:
+                total_tuples = cur.scalar(runner.sql_count_objects())
+                LOG.debug("Total number of rows: %i", total_tuples)
 
 
-        cur.close()
+            conn.commit()
 
 
-        progress = ProgressLogger(runner.name(), total_tuples)
+            progress = ProgressLogger(runner.name(), total_tuples)
 
 
-        if total_tuples > 0:
-            cur = self.conn.cursor(name='places')
-            cur.execute(runner.sql_get_objects())
+            if total_tuples > 0:
+                with conn.cursor(name='places') as cur:
+                    cur.execute(runner.sql_get_objects())
 
 
-            next_thread = self.find_free_thread()
-            while True:
-                places = [p[0] for p in cur.fetchmany(batch)]
-                if not places:
-                    break
+                    next_thread = self.find_free_thread()
+                    while True:
+                        places = [p[0] for p in cur.fetchmany(batch)]
+                        if not places:
+                            break
 
 
-                LOG.debug("Processing places: %s", str(places))
-                thread = next(next_thread)
+                        LOG.debug("Processing places: %s", str(places))
+                        thread = next(next_thread)
 
 
-                thread.perform(runner.sql_index_place(places))
-                progress.add(len(places))
+                        thread.perform(runner.sql_index_place(places))
+                        progress.add(len(places))
 
 
-            cur.close()
+            conn.commit()
 
 
-            for thread in self.threads:
-                thread.wait()
+        for thread in self.threads:
+            thread.wait()
 
         progress.done()
 
 
         progress.done()