git.openstreetmap.org Git - nominatim.git/blobdiff - src/nominatim_db/tools/database_import.py
make DB helper functions free functions
[nominatim.git] / src / nominatim_db / tools / database_import.py
index c4b3023a8585d57d28ffee15f06d1c3ddc4f60f5..2398d4043a734ceb6b3cf45971e998133fe6456a 100644 (file)
@@ -19,7 +19,8 @@ from psycopg2 import sql as pysql
 
 from ..errors import UsageError
 from ..config import Configuration
-from ..db.connection import connect, get_pg_env, Connection
+from ..db.connection import connect, get_pg_env, Connection, server_version_tuple,\
+                            postgis_version_tuple, drop_tables, table_exists, execute_scalar
 from ..db.async_connection import DBConnection
 from ..db.sql_preprocessor import SQLPreprocessor
 from .exec_utils import run_osm2pgsql
@@ -51,10 +52,10 @@ def check_existing_database_plugins(dsn: str) -> None:
     """ Check that the database has the required plugins installed."""
     with connect(dsn) as conn:
         _require_version('PostgreSQL server',
-                         conn.server_version_tuple(),
+                         server_version_tuple(conn),
                          POSTGRESQL_REQUIRED_VERSION)
         _require_version('PostGIS',
-                         conn.postgis_version_tuple(),
+                         postgis_version_tuple(conn),
                          POSTGIS_REQUIRED_VERSION)
         _require_loaded('hstore', conn)
 
@@ -80,31 +81,30 @@ def setup_database_skeleton(dsn: str, rouser: Optional[str] = None) -> None:
 
     with connect(dsn) as conn:
         _require_version('PostgreSQL server',
-                         conn.server_version_tuple(),
+                         server_version_tuple(conn),
                          POSTGRESQL_REQUIRED_VERSION)
 
         if rouser is not None:
-            with conn.cursor() as cur:
-                cnt = cur.scalar('SELECT count(*) FROM pg_user where usename = %s',
+            cnt = execute_scalar(conn, 'SELECT count(*) FROM pg_user where usename = %s',
                                  (rouser, ))
-                if cnt == 0:
-                    LOG.fatal("Web user '%s' does not exist. Create it with:\n"
-                              "\n      createuser %s", rouser, rouser)
-                    raise UsageError('Missing read-only user.')
+            if cnt == 0:
+                LOG.fatal("Web user '%s' does not exist. Create it with:\n"
+                          "\n      createuser %s", rouser, rouser)
+                raise UsageError('Missing read-only user.')
 
         # Create extensions.
         with conn.cursor() as cur:
             cur.execute('CREATE EXTENSION IF NOT EXISTS hstore')
             cur.execute('CREATE EXTENSION IF NOT EXISTS postgis')
 
-            postgis_version = conn.postgis_version_tuple()
+            postgis_version = postgis_version_tuple(conn)
             if postgis_version[0] >= 3:
                 cur.execute('CREATE EXTENSION IF NOT EXISTS postgis_raster')
 
         conn.commit()
 
         _require_version('PostGIS',
-                         conn.postgis_version_tuple(),
+                         postgis_version_tuple(conn),
                          POSTGIS_REQUIRED_VERSION)
 
 
@@ -141,7 +141,8 @@ def import_osm_data(osm_files: Union[Path, Sequence[Path]],
                     raise UsageError('No data imported by osm2pgsql.')
 
         if drop:
-            conn.drop_table('planet_osm_nodes')
+            drop_tables(conn, 'planet_osm_nodes')
+            conn.commit()
 
     if drop and options['flatnode_file']:
         Path(options['flatnode_file']).unlink()
@@ -184,7 +185,7 @@ def truncate_data_tables(conn: Connection) -> None:
         cur.execute('TRUNCATE location_property_tiger')
         cur.execute('TRUNCATE location_property_osmline')
         cur.execute('TRUNCATE location_postcode')
-        if conn.table_exists('search_name'):
+        if table_exists(conn, 'search_name'):
             cur.execute('TRUNCATE search_name')
         cur.execute('DROP SEQUENCE IF EXISTS seq_place')
         cur.execute('CREATE SEQUENCE seq_place start 100000')