git.openstreetmap.org Git - nominatim.git/blobdiff - test/python/test_indexing.py
Merge pull request #2155 from lonvia/port-regresh-to-python
[nominatim.git] / test / python / test_indexing.py
index e1a3a4d064433b86394a113933d39b3f89be8e9c..6b52a65ea6171d318e2d6d6d81a9a1ac51cbe2d2 100644 (file)
@@ -9,19 +9,11 @@ from nominatim.indexer.indexer import Indexer
 
 class IndexerTestDB:
 
-    def __init__(self, name):
-        self.name = name
-        self.conn = None
+    def __init__(self, conn):
         self.placex_id = itertools.count(100000)
         self.osmline_id = itertools.count(500000)
 
-    def setup(self):
-        with psycopg2.connect(database='postgres') as conn:
-            conn.set_isolation_level(0)
-            with conn.cursor() as cur:
-                cur.execute('DROP DATABASE IF EXISTS {}'.format(self.name))
-                cur.execute('CREATE DATABASE {}'.format(self.name))
-        self.conn = psycopg2.connect(database=self.name)
+        self.conn = conn
         self.conn.set_isolation_level(0)
         with self.conn.cursor() as cur:
             cur.execute("""CREATE TABLE placex (place_id BIGINT,
@@ -52,16 +44,6 @@ class IndexerTestDB:
             cur.execute("""CREATE TRIGGER osmline_update BEFORE UPDATE ON location_property_osmline
                            FOR EACH ROW EXECUTE PROCEDURE date_update()""")
 
-
-    def drop(self):
-        if self.conn:
-            self.conn.close()
-            self.conn = None
-        with psycopg2.connect(database='postgres') as conn:
-            conn.set_isolation_level(0)
-            with conn.cursor() as cur:
-                cur.execute('DROP DATABASE IF EXISTS {}'.format(self.name))
-
     def scalar(self, query):
         with self.conn.cursor() as cur:
             cur.execute(query)
@@ -100,11 +82,8 @@ class IndexerTestDB:
 
 
 @pytest.fixture
-def test_db():
-    db = IndexerTestDB('test_nominatim_python_unittest')
-    db.setup()
-    yield db
-    db.drop()
+def test_db(temp_db_conn):
+    yield IndexerTestDB(temp_db_conn)
 
 
 @pytest.mark.parametrize("threads", [1, 15])