]> git.openstreetmap.org Git - nominatim.git/blobdiff - test/python/test_tools_refresh_create_functions.py
Merge remote-tracking branch 'upstream/master'
[nominatim.git] / test / python / test_tools_refresh_create_functions.py
index ac2f221126f79c4da85249a8e5facb9c7bcd6f68..40d4c81af315a8014a6571ce5891e1f954f707a9 100644 (file)
@@ -1,98 +1,47 @@
 """
 Tests for creating PL/pgSQL functions for Nominatim.
 """
-from pathlib import Path
 import pytest
 
-from nominatim.db.connection import connect
-from nominatim.tools.refresh import _get_standard_function_sql, _get_partition_function_sql
-
-SQL_DIR = (Path(__file__) / '..' / '..' / '..' / 'lib-sql').resolve()
-
-@pytest.fixture
-def db(temp_db):
-    with connect('dbname=' + temp_db) as conn:
-        yield conn
+from nominatim.tools.refresh import create_functions
 
 @pytest.fixture
-def db_with_tables(db):
-    with db.cursor() as cur:
-        for table in ('place', 'placex', 'location_postcode'):
-            cur.execute('CREATE TABLE {} (place_id BIGINT)'.format(table))
-
-    return db
-
-
-def test_standard_functions_replace_module_default(db, def_config):
-    def_config.project_dir = Path('.')
-    sql = _get_standard_function_sql(db, def_config, SQL_DIR, False, False)
-
-    assert sql
-    assert sql.find('{modulepath}') < 0
-    assert sql.find("'{}'".format(Path('module/nominatim.so').resolve())) >= 0
-
-
-def test_standard_functions_replace_module_custom(monkeypatch, db, def_config):
-    monkeypatch.setenv('NOMINATIM_DATABASE_MODULE_PATH', 'custom')
-    sql = _get_standard_function_sql(db, def_config, SQL_DIR, False, False)
-
-    assert sql
-    assert sql.find('{modulepath}') < 0
-    assert sql.find("'custom/nominatim.so'") >= 0
-
-
-@pytest.mark.parametrize("enabled", (True, False))
-def test_standard_functions_enable_diff(db_with_tables, def_config, enabled):
-    def_config.project_dir = Path('.')
-    sql = _get_standard_function_sql(db_with_tables, def_config, SQL_DIR, enabled, False)
-
-    assert sql
-    assert (sql.find('%DIFFUPDATES%') < 0) == enabled
-
-
-@pytest.mark.parametrize("enabled", (True, False))
-def test_standard_functions_enable_debug(db_with_tables, def_config, enabled):
-    def_config.project_dir = Path('.')
-    sql = _get_standard_function_sql(db_with_tables, def_config, SQL_DIR, False, enabled)
-
-    assert sql
-    assert (sql.find('--DEBUG') < 0) == enabled
-
-
-@pytest.mark.parametrize("enabled", (True, False))
-def test_standard_functions_enable_limit_reindexing(monkeypatch, db_with_tables, def_config, enabled):
-    def_config.project_dir = Path('.')
-    monkeypatch.setenv('NOMINATIM_LIMIT_REINDEXING', 'yes' if enabled else 'no')
-    sql = _get_standard_function_sql(db_with_tables, def_config, SQL_DIR, False, False)
-
-    assert sql
-    assert (sql.find('--LIMIT INDEXING') < 0) == enabled
-
-
-@pytest.mark.parametrize("enabled", (True, False))
-def test_standard_functions_enable_tiger(monkeypatch, db_with_tables, def_config, enabled):
-    def_config.project_dir = Path('.')
-    monkeypatch.setenv('NOMINATIM_USE_US_TIGER_DATA', 'yes' if enabled else 'no')
-    sql = _get_standard_function_sql(db_with_tables, def_config, SQL_DIR, False, False)
-
-    assert sql
-    assert (sql.find('%NOTIGERDATA%') >= 0) == enabled
-
-
-@pytest.mark.parametrize("enabled", (True, False))
-def test_standard_functions_enable_aux(monkeypatch, db_with_tables, def_config, enabled):
-    def_config.project_dir = Path('.')
-    monkeypatch.setenv('NOMINATIM_USE_AUX_LOCATION_DATA', 'yes' if enabled else 'no')
-    sql = _get_standard_function_sql(db_with_tables, def_config, SQL_DIR, False, False)
-
-    assert sql
-    assert (sql.find('%NOAUXDATA%') >= 0) == enabled
-
-
-def test_partition_function(temp_db_cursor, db, def_config):
-    temp_db_cursor.execute("CREATE TABLE country_name (partition SMALLINT)")
-
-    sql = _get_partition_function_sql(db, SQL_DIR)
-
-    assert sql
-    assert sql.find('-partition-') < 0
+def conn(temp_db_conn, table_factory, monkeypatch):
+    monkeypatch.setenv('NOMINATIM_DATABASE_MODULE_PATH', '.')
+    table_factory('country_name', 'partition INT', (0, 1, 2))
+    return temp_db_conn
+
+
+def test_create_functions(temp_db_cursor, conn, def_config, tmp_path):
+    sqlfile = tmp_path / 'functions.sql'
+    sqlfile.write_text("""CREATE OR REPLACE FUNCTION test() RETURNS INTEGER
+                          AS $$
+                          BEGIN
+                            RETURN 43;
+                          END;
+                          $$ LANGUAGE plpgsql IMMUTABLE;
+                       """)
+
+    create_functions(conn, def_config, tmp_path)
+
+    assert temp_db_cursor.scalar('SELECT test()') == 43
+
+
+@pytest.mark.parametrize("dbg,ret", ((True, 43), (False, 22)))
+def test_create_functions_with_template(temp_db_cursor, conn, def_config, tmp_path, dbg, ret):
+    sqlfile = tmp_path / 'functions.sql'
+    sqlfile.write_text("""CREATE OR REPLACE FUNCTION test() RETURNS INTEGER
+                          AS $$
+                          BEGIN
+                            {% if debug %}
+                            RETURN 43;
+                            {% else %}
+                            RETURN 22;
+                            {% endif %}
+                          END;
+                          $$ LANGUAGE plpgsql IMMUTABLE;
+                       """)
+
+    create_functions(conn, def_config, tmp_path, enable_debug=dbg)
+
+    assert temp_db_cursor.scalar('SELECT test()') == ret