]> git.openstreetmap.org Git - nominatim.git/blobdiff - src/nominatim_db/tools/migration.py
allow None and str for project_dir in NominatimAPI init
[nominatim.git] / src / nominatim_db / tools / migration.py
index 0712f187ff51dd1b659b2995c3789387d5b2b75b..5483653230bb6c2eff94109a0c87e6539727c0d9 100644 (file)
@@ -10,12 +10,13 @@ Functions for database migration to newer software versions.
 from typing import List, Tuple, Callable, Any
 import logging
 
 from typing import List, Tuple, Callable, Any
 import logging
 
-from psycopg2 import sql as pysql
+from psycopg import sql as pysql
 
 
-from nominatim_core.errors import UsageError
-from nominatim_core.config import Configuration
-from nominatim_core.db import properties
-from nominatim_core.db.connection import connect, Connection
+from ..errors import UsageError
+from ..config import Configuration
+from ..db import properties
+from ..db.connection import connect, Connection, server_version_tuple,\
+                            table_has_column, table_exists, execute_scalar, register_hstore
 from ..version import NominatimVersion, NOMINATIM_VERSION, parse_version
 from ..tokenizer import factory as tokenizer_factory
 from . import refresh
 from ..version import NominatimVersion, NOMINATIM_VERSION, parse_version
 from ..tokenizer import factory as tokenizer_factory
 from . import refresh
@@ -29,7 +30,8 @@ def migrate(config: Configuration, paths: Any) -> int:
         if necesssary.
     """
     with connect(config.get_libpq_dsn()) as conn:
         if necesssary.
     """
     with connect(config.get_libpq_dsn()) as conn:
-        if conn.table_exists('nominatim_properties'):
+        register_hstore(conn)
+        if table_exists(conn, 'nominatim_properties'):
             db_version_str = properties.get_property(conn, 'database_version')
         else:
             db_version_str = None
             db_version_str = properties.get_property(conn, 'database_version')
         else:
             db_version_str = None
@@ -72,16 +74,15 @@ def _guess_version(conn: Connection) -> NominatimVersion:
         Only migrations for 3.6 and later are supported, so bail out
         when the version seems older.
     """
         Only migrations for 3.6 and later are supported, so bail out
         when the version seems older.
     """
-    with conn.cursor() as cur:
-        # In version 3.6, the country_name table was updated. Check for that.
-        cnt = cur.scalar("""SELECT count(*) FROM
-                            (SELECT svals(name) FROM  country_name
-                             WHERE country_code = 'gb')x;
-                         """)
-        if cnt < 100:
-            LOG.fatal('It looks like your database was imported with a version '
-                      'prior to 3.6.0. Automatic migration not possible.')
-            raise UsageError('Migration not possible.')
+    # In version 3.6, the country_name table was updated. Check for that.
+    cnt = execute_scalar(conn, """SELECT count(*) FROM
+                                  (SELECT svals(name) FROM  country_name
+                                  WHERE country_code = 'gb')x;
+                               """)
+    if cnt < 100:
+        LOG.fatal('It looks like your database was imported with a version '
+                  'prior to 3.6.0. Automatic migration not possible.')
+        raise UsageError('Migration not possible.')
 
     return NominatimVersion(3, 5, 0, 99)
 
 
     return NominatimVersion(3, 5, 0, 99)
 
@@ -125,7 +126,7 @@ def import_status_timestamp_change(conn: Connection, **_: Any) -> None:
 def add_nominatim_property_table(conn: Connection, config: Configuration, **_: Any) -> None:
     """ Add nominatim_property table.
     """
 def add_nominatim_property_table(conn: Connection, config: Configuration, **_: Any) -> None:
     """ Add nominatim_property table.
     """
-    if not conn.table_exists('nominatim_properties'):
+    if not table_exists(conn, 'nominatim_properties'):
         with conn.cursor() as cur:
             cur.execute(pysql.SQL("""CREATE TABLE nominatim_properties (
                                         property TEXT,
         with conn.cursor() as cur:
             cur.execute(pysql.SQL("""CREATE TABLE nominatim_properties (
                                         property TEXT,
@@ -189,13 +190,9 @@ def install_legacy_tokenizer(conn: Connection, config: Configuration, **_: Any)
         configuration for the backwards-compatible legacy tokenizer
     """
     if properties.get_property(conn, 'tokenizer') is None:
         configuration for the backwards-compatible legacy tokenizer
     """
     if properties.get_property(conn, 'tokenizer') is None:
-        with conn.cursor() as cur:
-            for table in ('placex', 'location_property_osmline'):
-                has_column = cur.scalar("""SELECT count(*) FROM information_schema.columns
-                                           WHERE table_name = %s
-                                           and column_name = 'token_info'""",
-                                        (table, ))
-                if has_column == 0:
+        for table in ('placex', 'location_property_osmline'):
+            if not table_has_column(conn, table, 'token_info'):
+                with conn.cursor() as cur:
                     cur.execute(pysql.SQL('ALTER TABLE {} ADD COLUMN token_info JSONB')
                                 .format(pysql.Identifier(table)))
         tokenizer = tokenizer_factory.create_tokenizer(config, init_db=False,
                     cur.execute(pysql.SQL('ALTER TABLE {} ADD COLUMN token_info JSONB')
                                 .format(pysql.Identifier(table)))
         tokenizer = tokenizer_factory.create_tokenizer(config, init_db=False,
@@ -212,7 +209,7 @@ def create_tiger_housenumber_index(conn: Connection, **_: Any) -> None:
         The inclusion is needed for efficient lookup of housenumbers in
         full address searches.
     """
         The inclusion is needed for efficient lookup of housenumbers in
         full address searches.
     """
-    if conn.server_version_tuple() >= (11, 0, 0):
+    if server_version_tuple(conn) >= (11, 0, 0):
         with conn.cursor() as cur:
             cur.execute(""" CREATE INDEX IF NOT EXISTS
                                 idx_location_property_tiger_housenumber_migrated
         with conn.cursor() as cur:
             cur.execute(""" CREATE INDEX IF NOT EXISTS
                                 idx_location_property_tiger_housenumber_migrated
@@ -239,7 +236,7 @@ def add_step_column_for_interpolation(conn: Connection, **_: Any) -> None:
         Also converts the data into the stricter format which requires that
         startnumbers comply with the odd/even requirements.
     """
         Also converts the data into the stricter format which requires that
         startnumbers comply with the odd/even requirements.
     """
-    if conn.table_has_column('location_property_osmline', 'step'):
+    if table_has_column(conn, 'location_property_osmline', 'step'):
         return
 
     with conn.cursor() as cur:
         return
 
     with conn.cursor() as cur:
@@ -271,7 +268,7 @@ def add_step_column_for_interpolation(conn: Connection, **_: Any) -> None:
 def add_step_column_for_tiger(conn: Connection, **_: Any) -> None:
     """ Add a new column 'step' to the tiger data table.
     """
 def add_step_column_for_tiger(conn: Connection, **_: Any) -> None:
     """ Add a new column 'step' to the tiger data table.
     """
-    if conn.table_has_column('location_property_tiger', 'step'):
+    if table_has_column(conn, 'location_property_tiger', 'step'):
         return
 
     with conn.cursor() as cur:
         return
 
     with conn.cursor() as cur:
@@ -287,7 +284,7 @@ def add_derived_name_column_for_country_names(conn: Connection, **_: Any) -> Non
     """ Add a new column 'derived_name' which in the future takes the
         country names as imported from OSM data.
     """
     """ Add a new column 'derived_name' which in the future takes the
         country names as imported from OSM data.
     """
-    if not conn.table_has_column('country_name', 'derived_name'):
+    if not table_has_column(conn, 'country_name', 'derived_name'):
         with conn.cursor() as cur:
             cur.execute("ALTER TABLE country_name ADD COLUMN derived_name public.HSTORE")
 
         with conn.cursor() as cur:
             cur.execute("ALTER TABLE country_name ADD COLUMN derived_name public.HSTORE")
 
@@ -297,12 +294,9 @@ def mark_internal_country_names(conn: Connection, config: Configuration, **_: An
     """ Names from the country table should be marked as internal to prevent
         them from being deleted. Only necessary for ICU tokenizer.
     """
     """ Names from the country table should be marked as internal to prevent
         them from being deleted. Only necessary for ICU tokenizer.
     """
-    import psycopg2.extras # pylint: disable=import-outside-toplevel
-
     tokenizer = tokenizer_factory.get_tokenizer_for_db(config)
     with tokenizer.name_analyzer() as analyzer:
         with conn.cursor() as cur:
     tokenizer = tokenizer_factory.get_tokenizer_for_db(config)
     with tokenizer.name_analyzer() as analyzer:
         with conn.cursor() as cur:
-            psycopg2.extras.register_hstore(cur)
             cur.execute("SELECT country_code, name FROM country_name")
 
             for country_code, names in cur:
             cur.execute("SELECT country_code, name FROM country_name")
 
             for country_code, names in cur:
@@ -319,7 +313,7 @@ def add_place_deletion_todo_table(conn: Connection, **_: Any) -> None:
         The table is only necessary when updates are possible, i.e.
         the database is not in freeze mode.
     """
         The table is only necessary when updates are possible, i.e.
         the database is not in freeze mode.
     """
-    if conn.table_exists('place'):
+    if table_exists(conn, 'place'):
         with conn.cursor() as cur:
             cur.execute("""CREATE TABLE IF NOT EXISTS place_to_be_deleted (
                              osm_type CHAR(1),
         with conn.cursor() as cur:
             cur.execute("""CREATE TABLE IF NOT EXISTS place_to_be_deleted (
                              osm_type CHAR(1),
@@ -333,7 +327,7 @@ def add_place_deletion_todo_table(conn: Connection, **_: Any) -> None:
 def split_pending_index(conn: Connection, **_: Any) -> None:
     """ Reorganise indexes for pending updates.
     """
 def split_pending_index(conn: Connection, **_: Any) -> None:
     """ Reorganise indexes for pending updates.
     """
-    if conn.table_exists('place'):
+    if table_exists(conn, 'place'):
         with conn.cursor() as cur:
             cur.execute("""CREATE INDEX IF NOT EXISTS idx_placex_rank_address_sector
                            ON placex USING BTREE (rank_address, geometry_sector)
         with conn.cursor() as cur:
             cur.execute("""CREATE INDEX IF NOT EXISTS idx_placex_rank_address_sector
                            ON placex USING BTREE (rank_address, geometry_sector)
@@ -349,7 +343,7 @@ def split_pending_index(conn: Connection, **_: Any) -> None:
 def enable_forward_dependencies(conn: Connection, **_: Any) -> None:
     """ Create indexes for updates with forward dependency tracking (long-running).
     """
 def enable_forward_dependencies(conn: Connection, **_: Any) -> None:
     """ Create indexes for updates with forward dependency tracking (long-running).
     """
-    if conn.table_exists('planet_osm_ways'):
+    if table_exists(conn, 'planet_osm_ways'):
         with conn.cursor() as cur:
             cur.execute("""SELECT * FROM pg_indexes
                            WHERE tablename = 'planet_osm_ways'
         with conn.cursor() as cur:
             cur.execute("""SELECT * FROM pg_indexes
                            WHERE tablename = 'planet_osm_ways'
@@ -398,7 +392,7 @@ def create_postcode_area_lookup_index(conn: Connection, **_: Any) -> None:
 def create_postcode_parent_index(conn: Connection, **_: Any) -> None:
     """ Create index needed for updating postcodes when a parent changes.
     """
 def create_postcode_parent_index(conn: Connection, **_: Any) -> None:
     """ Create index needed for updating postcodes when a parent changes.
     """
-    if conn.table_exists('planet_osm_ways'):
+    if table_exists(conn, 'planet_osm_ways'):
         with conn.cursor() as cur:
             cur.execute("""CREATE INDEX IF NOT EXISTS
                              idx_location_postcode_parent_place_id
         with conn.cursor() as cur:
             cur.execute("""CREATE INDEX IF NOT EXISTS
                              idx_location_postcode_parent_place_id