git.openstreetmap.org Git - nominatim.git/blobdiff - test/python/tools/test_refresh_create_functions.py
Merge pull request #3424 from lonvia/importance-csc-import
[nominatim.git] / test / python / tools / test_refresh_create_functions.py
index 00b863ab1e621289c3a38be1dbd4d65a97496e6c..8d26e7554dd6d1613621133485226dad0015037d 100644
@@ -1,3 +1,9 @@
+# SPDX-License-Identifier: GPL-2.0-only
+#
+# This file is part of Nominatim. (https://nominatim.org)
+#
+# Copyright (C) 2022 by the Nominatim developer community.
+# For a full list of authors see the git log.
 """
 Tests for creating PL/pgSQL functions for Nominatim.
 """
@@ -5,47 +11,47 @@ import pytest
 
 from nominatim.tools.refresh import create_functions
 
-@pytest.fixture
-def sql_tmp_path(tmp_path, def_config):
-    def_config.lib_dir.sql = tmp_path
-    return tmp_path
-
-@pytest.fixture
-def conn(sql_preprocessor, temp_db_conn):
-    return temp_db_conn
-
-
-def test_create_functions(temp_db_cursor, conn, def_config, sql_tmp_path):
-    sqlfile = sql_tmp_path / 'functions.sql'
-    sqlfile.write_text("""CREATE OR REPLACE FUNCTION test() RETURNS INTEGER
-                          AS $$
-                          BEGIN
-                            RETURN 43;
-                          END;
-                          $$ LANGUAGE plpgsql IMMUTABLE;
-                       """)
-
-    create_functions(conn, def_config)
-
-    assert temp_db_cursor.scalar('SELECT test()') == 43
-
-
-@pytest.mark.parametrize("dbg,ret", ((True, 43), (False, 22)))
-def test_create_functions_with_template(temp_db_cursor, conn, def_config, sql_tmp_path,
-                                        dbg, ret):
-    sqlfile = sql_tmp_path / 'functions.sql'
-    sqlfile.write_text("""CREATE OR REPLACE FUNCTION test() RETURNS INTEGER
-                          AS $$
-                          BEGIN
-                            {% if debug %}
-                            RETURN 43;
-                            {% else %}
-                            RETURN 22;
-                            {% endif %}
-                          END;
-                          $$ LANGUAGE plpgsql IMMUTABLE;
-                       """)
-
-    create_functions(conn, def_config, enable_debug=dbg)
-
-    assert temp_db_cursor.scalar('SELECT test()') == ret
+class TestCreateFunctions:
+    @pytest.fixture(autouse=True)
+    def init_env(self, sql_preprocessor, temp_db_conn, def_config, tmp_path):
+        self.conn = temp_db_conn
+        self.config = def_config
+        def_config.lib_dir.sql = tmp_path
+
+
+    def write_functions(self, content):
+        sqlfile = self.config.lib_dir.sql / 'functions.sql'
+        sqlfile.write_text(content)
+
+
+    def test_create_functions(self, temp_db_cursor):
+        self.write_functions("""CREATE OR REPLACE FUNCTION test() RETURNS INTEGER
+                              AS $$
+                              BEGIN
+                                RETURN 43;
+                              END;
+                              $$ LANGUAGE plpgsql IMMUTABLE;
+                           """)
+
+        create_functions(self.conn, self.config)
+
+        assert temp_db_cursor.scalar('SELECT test()') == 43
+
+
+    @pytest.mark.parametrize("dbg,ret", ((True, 43), (False, 22)))
+    def test_create_functions_with_template(self, temp_db_cursor, dbg, ret):
+        self.write_functions("""CREATE OR REPLACE FUNCTION test() RETURNS INTEGER
+                              AS $$
+                              BEGIN
+                                {% if debug %}
+                                RETURN 43;
+                                {% else %}
+                                RETURN 22;
+                                {% endif %}
+                              END;
+                              $$ LANGUAGE plpgsql IMMUTABLE;
+                           """)
+
+        create_functions(self.conn, self.config, enable_debug=dbg)
+
+        assert temp_db_cursor.scalar('SELECT test()') == ret