]> git.openstreetmap.org Git - nominatim.git/blobdiff - src/nominatim_db/tools/migration.py
restrict use of os.environ in Configuration
[nominatim.git] / src / nominatim_db / tools / migration.py
index e6803c7dea23521a805aa939584d028f66b4a7b5..5483653230bb6c2eff94109a0c87e6539727c0d9 100644 (file)
@@ -10,12 +10,13 @@ Functions for database migration to newer software versions.
 from typing import List, Tuple, Callable, Any
 import logging
 
-from psycopg2 import sql as pysql
+from psycopg import sql as pysql
 
 from ..errors import UsageError
 from ..config import Configuration
 from ..db import properties
-from ..db.connection import connect, Connection
+from ..db.connection import connect, Connection, server_version_tuple,\
+                            table_has_column, table_exists, execute_scalar, register_hstore
 from ..version import NominatimVersion, NOMINATIM_VERSION, parse_version
 from ..tokenizer import factory as tokenizer_factory
 from . import refresh
@@ -29,7 +30,8 @@ def migrate(config: Configuration, paths: Any) -> int:
         if necesssary.
     """
     with connect(config.get_libpq_dsn()) as conn:
-        if conn.table_exists('nominatim_properties'):
+        register_hstore(conn)
+        if table_exists(conn, 'nominatim_properties'):
             db_version_str = properties.get_property(conn, 'database_version')
         else:
             db_version_str = None
@@ -72,16 +74,15 @@ def _guess_version(conn: Connection) -> NominatimVersion:
         Only migrations for 3.6 and later are supported, so bail out
         when the version seems older.
     """
-    with conn.cursor() as cur:
-        # In version 3.6, the country_name table was updated. Check for that.
-        cnt = cur.scalar("""SELECT count(*) FROM
-                            (SELECT svals(name) FROM  country_name
-                             WHERE country_code = 'gb')x;
-                         """)
-        if cnt < 100:
-            LOG.fatal('It looks like your database was imported with a version '
-                      'prior to 3.6.0. Automatic migration not possible.')
-            raise UsageError('Migration not possible.')
+    # In version 3.6, the country_name table was updated. Check for that.
+    cnt = execute_scalar(conn, """SELECT count(*) FROM
+                                  (SELECT svals(name) FROM  country_name
+                                  WHERE country_code = 'gb')x;
+                               """)
+    if cnt < 100:
+        LOG.fatal('It looks like your database was imported with a version '
+                  'prior to 3.6.0. Automatic migration not possible.')
+        raise UsageError('Migration not possible.')
 
     return NominatimVersion(3, 5, 0, 99)
 
@@ -125,7 +126,7 @@ def import_status_timestamp_change(conn: Connection, **_: Any) -> None:
 def add_nominatim_property_table(conn: Connection, config: Configuration, **_: Any) -> None:
     """ Add nominatim_property table.
     """
-    if not conn.table_exists('nominatim_properties'):
+    if not table_exists(conn, 'nominatim_properties'):
         with conn.cursor() as cur:
             cur.execute(pysql.SQL("""CREATE TABLE nominatim_properties (
                                         property TEXT,
@@ -189,13 +190,9 @@ def install_legacy_tokenizer(conn: Connection, config: Configuration, **_: Any)
         configuration for the backwards-compatible legacy tokenizer
     """
     if properties.get_property(conn, 'tokenizer') is None:
-        with conn.cursor() as cur:
-            for table in ('placex', 'location_property_osmline'):
-                has_column = cur.scalar("""SELECT count(*) FROM information_schema.columns
-                                           WHERE table_name = %s
-                                           and column_name = 'token_info'""",
-                                        (table, ))
-                if has_column == 0:
+        for table in ('placex', 'location_property_osmline'):
+            if not table_has_column(conn, table, 'token_info'):
+                with conn.cursor() as cur:
                     cur.execute(pysql.SQL('ALTER TABLE {} ADD COLUMN token_info JSONB')
                                 .format(pysql.Identifier(table)))
         tokenizer = tokenizer_factory.create_tokenizer(config, init_db=False,
@@ -212,7 +209,7 @@ def create_tiger_housenumber_index(conn: Connection, **_: Any) -> None:
         The inclusion is needed for efficient lookup of housenumbers in
         full address searches.
     """
-    if conn.server_version_tuple() >= (11, 0, 0):
+    if server_version_tuple(conn) >= (11, 0, 0):
         with conn.cursor() as cur:
             cur.execute(""" CREATE INDEX IF NOT EXISTS
                                 idx_location_property_tiger_housenumber_migrated
@@ -239,7 +236,7 @@ def add_step_column_for_interpolation(conn: Connection, **_: Any) -> None:
         Also converts the data into the stricter format which requires that
         startnumbers comply with the odd/even requirements.
     """
-    if conn.table_has_column('location_property_osmline', 'step'):
+    if table_has_column(conn, 'location_property_osmline', 'step'):
         return
 
     with conn.cursor() as cur:
@@ -271,7 +268,7 @@ def add_step_column_for_interpolation(conn: Connection, **_: Any) -> None:
 def add_step_column_for_tiger(conn: Connection, **_: Any) -> None:
     """ Add a new column 'step' to the tiger data table.
     """
-    if conn.table_has_column('location_property_tiger', 'step'):
+    if table_has_column(conn, 'location_property_tiger', 'step'):
         return
 
     with conn.cursor() as cur:
@@ -287,7 +284,7 @@ def add_derived_name_column_for_country_names(conn: Connection, **_: Any) -> Non
     """ Add a new column 'derived_name' which in the future takes the
         country names as imported from OSM data.
     """
-    if not conn.table_has_column('country_name', 'derived_name'):
+    if not table_has_column(conn, 'country_name', 'derived_name'):
         with conn.cursor() as cur:
             cur.execute("ALTER TABLE country_name ADD COLUMN derived_name public.HSTORE")
 
@@ -297,12 +294,9 @@ def mark_internal_country_names(conn: Connection, config: Configuration, **_: An
     """ Names from the country table should be marked as internal to prevent
         them from being deleted. Only necessary for ICU tokenizer.
     """
-    import psycopg2.extras # pylint: disable=import-outside-toplevel
-
     tokenizer = tokenizer_factory.get_tokenizer_for_db(config)
     with tokenizer.name_analyzer() as analyzer:
         with conn.cursor() as cur:
-            psycopg2.extras.register_hstore(cur)
             cur.execute("SELECT country_code, name FROM country_name")
 
             for country_code, names in cur:
@@ -319,7 +313,7 @@ def add_place_deletion_todo_table(conn: Connection, **_: Any) -> None:
         The table is only necessary when updates are possible, i.e.
         the database is not in freeze mode.
     """
-    if conn.table_exists('place'):
+    if table_exists(conn, 'place'):
         with conn.cursor() as cur:
             cur.execute("""CREATE TABLE IF NOT EXISTS place_to_be_deleted (
                              osm_type CHAR(1),
@@ -333,7 +327,7 @@ def add_place_deletion_todo_table(conn: Connection, **_: Any) -> None:
 def split_pending_index(conn: Connection, **_: Any) -> None:
     """ Reorganise indexes for pending updates.
     """
-    if conn.table_exists('place'):
+    if table_exists(conn, 'place'):
         with conn.cursor() as cur:
             cur.execute("""CREATE INDEX IF NOT EXISTS idx_placex_rank_address_sector
                            ON placex USING BTREE (rank_address, geometry_sector)
@@ -349,7 +343,7 @@ def split_pending_index(conn: Connection, **_: Any) -> None:
 def enable_forward_dependencies(conn: Connection, **_: Any) -> None:
     """ Create indexes for updates with forward dependency tracking (long-running).
     """
-    if conn.table_exists('planet_osm_ways'):
+    if table_exists(conn, 'planet_osm_ways'):
         with conn.cursor() as cur:
             cur.execute("""SELECT * FROM pg_indexes
                            WHERE tablename = 'planet_osm_ways'
@@ -398,7 +392,7 @@ def create_postcode_area_lookup_index(conn: Connection, **_: Any) -> None:
 def create_postcode_parent_index(conn: Connection, **_: Any) -> None:
     """ Create index needed for updating postcodes when a parent changes.
     """
-    if conn.table_exists('planet_osm_ways'):
+    if table_exists(conn, 'planet_osm_ways'):
         with conn.cursor() as cur:
             cur.execute("""CREATE INDEX IF NOT EXISTS
                              idx_location_postcode_parent_place_id