]> git.openstreetmap.org Git - nominatim.git/blobdiff - nominatim/tools/database_import.py
Merge remote-tracking branch 'upstream/master'
[nominatim.git] / nominatim / tools / database_import.py
index a6df275517a4134af973fc10943f83bfbd8404c6..433cd8afaca30372ab58698821fee06092748b1a 100644 (file)
@@ -9,16 +9,33 @@ import shutil
 from pathlib import Path
 
 import psutil
 from pathlib import Path
 
 import psutil
+import psycopg2
 
 from ..db.connection import connect, get_pg_env
 from ..db import utils as db_utils
 from ..db.async_connection import DBConnection
 
 from ..db.connection import connect, get_pg_env
 from ..db import utils as db_utils
 from ..db.async_connection import DBConnection
+from ..db.sql_preprocessor import SQLPreprocessor
 from .exec_utils import run_osm2pgsql
 from ..errors import UsageError
 from ..version import POSTGRESQL_REQUIRED_VERSION, POSTGIS_REQUIRED_VERSION
 
 LOG = logging.getLogger()
 
 from .exec_utils import run_osm2pgsql
 from ..errors import UsageError
 from ..version import POSTGRESQL_REQUIRED_VERSION, POSTGIS_REQUIRED_VERSION
 
 LOG = logging.getLogger()
 
+def setup_database_skeleton(dsn, data_dir, no_partitions, rouser=None):
+    """ Create a new database for Nominatim and populate it with the
+        essential extensions and data.
+    """
+    LOG.warning('Creating database')
+    create_db(dsn, rouser)
+
+    LOG.warning('Setting up database')
+    with connect(dsn) as conn:
+        setup_extensions(conn)
+
+    LOG.warning('Loading basic data')
+    import_base_data(dsn, data_dir, no_partitions)
+
+
 def create_db(dsn, rouser=None):
     """ Create a new database for the given DSN. Fails when the database
         already exists or the PostgreSQL version is too old.
 def create_db(dsn, rouser=None):
     """ Create a new database for the given DSN. Fails when the database
         already exists or the PostgreSQL version is too old.
@@ -72,7 +89,7 @@ def setup_extensions(conn):
         raise UsageError('PostGIS version is too old.')
 
 
         raise UsageError('PostGIS version is too old.')
 
 
-def install_module(src_dir, project_dir, module_dir):
+def install_module(src_dir, project_dir, module_dir, conn=None):
     """ Copy the normalization module from src_dir into the project
         directory under the '/module' directory. If 'module_dir' is set, then
         use the module from there instead and check that it is accessible
     """ Copy the normalization module from src_dir into the project
         directory under the '/module' directory. If 'module_dir' is set, then
         use the module from there instead and check that it is accessible
@@ -80,6 +97,9 @@ def install_module(src_dir, project_dir, module_dir):
 
         The function detects when the installation is run from the
         build directory. It doesn't touch the module in that case.
 
         The function detects when the installation is run from the
         build directory. It doesn't touch the module in that case.
+
+        If 'conn' is given, then the function also tests if the module
+        can be access via the given database.
     """
     if not module_dir:
         module_dir = project_dir / 'module'
     """
     if not module_dir:
         module_dir = project_dir / 'module'
@@ -99,19 +119,17 @@ def install_module(src_dir, project_dir, module_dir):
     else:
         LOG.info("Using custom path for database module at '%s'", module_dir)
 
     else:
         LOG.info("Using custom path for database module at '%s'", module_dir)
 
-    return module_dir
-
-
-def check_module_dir_path(conn, path):
-    """ Check that the normalisation module can be found and executed
-        from the given path.
-    """
-    with conn.cursor() as cur:
-        cur.execute("""CREATE FUNCTION nominatim_test_import_func(text)
-                       RETURNS text AS '{}/nominatim.so', 'transliteration'
-                       LANGUAGE c IMMUTABLE STRICT;
-                       DROP FUNCTION nominatim_test_import_func(text)
-                    """.format(path))
+    if conn is not None:
+        with conn.cursor() as cur:
+            try:
+                cur.execute("""CREATE FUNCTION nominatim_test_import_func(text)
+                               RETURNS text AS '{}/nominatim.so', 'transliteration'
+                               LANGUAGE c IMMUTABLE STRICT;
+                               DROP FUNCTION nominatim_test_import_func(text)
+                            """.format(module_dir))
+            except psycopg2.DatabaseError as err:
+                LOG.fatal("Error accessing database module: %s", err)
+                raise UsageError("Database module cannot be accessed.") from err
 
 
 def import_base_data(dsn, sql_dir, ignore_partitions=False):
 
 
 def import_base_data(dsn, sql_dir, ignore_partitions=False):
@@ -128,7 +146,7 @@ def import_base_data(dsn, sql_dir, ignore_partitions=False):
             conn.commit()
 
 
             conn.commit()
 
 
-def import_osm_data(osm_file, options, drop=False):
+def import_osm_data(osm_file, options, drop=False, ignore_errors=False):
     """ Import the given OSM file. 'options' contains the list of
         default settings for osm2pgsql.
     """
     """ Import the given OSM file. 'options' contains the list of
         default settings for osm2pgsql.
     """
@@ -147,10 +165,11 @@ def import_osm_data(osm_file, options, drop=False):
     run_osm2pgsql(options)
 
     with connect(options['dsn']) as conn:
     run_osm2pgsql(options)
 
     with connect(options['dsn']) as conn:
-        with conn.cursor() as cur:
-            cur.execute('SELECT * FROM place LIMIT 1')
-            if cur.rowcount == 0:
-                raise UsageError('No data imported by osm2pgsql.')
+        if not ignore_errors:
+            with conn.cursor() as cur:
+                cur.execute('SELECT * FROM place LIMIT 1')
+                if cur.rowcount == 0:
+                    raise UsageError('No data imported by osm2pgsql.')
 
         if drop:
             conn.drop_table('planet_osm_nodes')
 
         if drop:
             conn.drop_table('planet_osm_nodes')
@@ -160,6 +179,32 @@ def import_osm_data(osm_file, options, drop=False):
             Path(options['flatnode_file']).unlink()
 
 
             Path(options['flatnode_file']).unlink()
 
 
+def create_tables(conn, config, sqllib_dir, reverse_only=False):
+    """ Create the set of basic tables.
+        When `reverse_only` is True, then the main table for searching will
+        be skipped and only reverse search is possible.
+    """
+    sql = SQLPreprocessor(conn, config, sqllib_dir)
+    sql.env.globals['db']['reverse_only'] = reverse_only
+
+    sql.run_sql_file(conn, 'tables.sql')
+
+
+def create_table_triggers(conn, config, sqllib_dir):
+    """ Create the triggers for the tables. The trigger functions must already
+        have been imported with refresh.create_functions().
+    """
+    sql = SQLPreprocessor(conn, config, sqllib_dir)
+    sql.run_sql_file(conn, 'table-triggers.sql')
+
+
+def create_partition_tables(conn, config, sqllib_dir):
+    """ Create tables that have explicit partitioning.
+    """
+    sql = SQLPreprocessor(conn, config, sqllib_dir)
+    sql.run_sql_file(conn, 'partition-tables.src.sql')
+
+
 def truncate_data_tables(conn, max_word_frequency=None):
     """ Truncate all data tables to prepare for a fresh load.
     """
 def truncate_data_tables(conn, max_word_frequency=None):
     """ Truncate all data tables to prepare for a fresh load.
     """
@@ -173,8 +218,9 @@ def truncate_data_tables(conn, max_word_frequency=None):
         cur.execute('TRUNCATE location_property_tiger')
         cur.execute('TRUNCATE location_property_osmline')
         cur.execute('TRUNCATE location_postcode')
         cur.execute('TRUNCATE location_property_tiger')
         cur.execute('TRUNCATE location_property_osmline')
         cur.execute('TRUNCATE location_postcode')
-        cur.execute('TRUNCATE search_name')
-        cur.execute('DROP SEQUENCE seq_place')
+        if conn.table_exists('search_name'):
+            cur.execute('TRUNCATE search_name')
+        cur.execute('DROP SEQUENCE IF EXISTS seq_place')
         cur.execute('CREATE SEQUENCE seq_place start 100000')
 
         cur.execute("""SELECT tablename FROM pg_tables
         cur.execute('CREATE SEQUENCE seq_place start 100000')
 
         cur.execute("""SELECT tablename FROM pg_tables
@@ -239,3 +285,56 @@ def load_data(dsn, data_dir, threads):
     with connect(dsn) as conn:
         with conn.cursor() as cur:
             cur.execute('ANALYSE')
     with connect(dsn) as conn:
         with conn.cursor() as cur:
             cur.execute('ANALYSE')
+
+
+def create_search_indices(conn, config, sqllib_dir, drop=False):
+    """ Create tables that have explicit partitioning.
+    """
+
+    # If index creation failed and left an index invalid, they need to be
+    # cleaned out first, so that the script recreates them.
+    with conn.cursor() as cur:
+        cur.execute("""SELECT relname FROM pg_class, pg_index
+                       WHERE pg_index.indisvalid = false
+                             AND pg_index.indexrelid = pg_class.oid""")
+        bad_indices = [row[0] for row in list(cur)]
+        for idx in bad_indices:
+            LOG.info("Drop invalid index %s.", idx)
+            cur.execute('DROP INDEX "{}"'.format(idx))
+    conn.commit()
+
+    sql = SQLPreprocessor(conn, config, sqllib_dir)
+
+    sql.run_sql_file(conn, 'indices.sql', drop=drop)
+
+def create_country_names(conn, config):
+    """ Create search index for default country names.
+    """
+
+    with conn.cursor() as cur:
+        cur.execute("""SELECT getorcreate_country(make_standard_name('uk'), 'gb')""")
+        cur.execute("""SELECT getorcreate_country(make_standard_name('united states'), 'us')""")
+        cur.execute("""SELECT COUNT(*) FROM
+                       (SELECT getorcreate_country(make_standard_name(country_code),
+                       country_code) FROM country_name WHERE country_code is not null) AS x""")
+        cur.execute("""SELECT COUNT(*) FROM
+                       (SELECT getorcreate_country(make_standard_name(name->'name'), country_code) 
+                       FROM country_name WHERE name ? 'name') AS x""")
+        sql_statement = """SELECT COUNT(*) FROM (SELECT getorcreate_country(make_standard_name(v),
+                           country_code) FROM (SELECT country_code, skeys(name)
+                           AS k, svals(name) AS v FROM country_name) x WHERE k"""
+
+        languages = config.LANGUAGES
+
+        if languages:
+            sql_statement = "{} IN (".format(sql_statement)
+            delim = ''
+            for language in languages.split(','):
+                sql_statement = "{}{}'name:{}'".format(sql_statement, delim, language)
+                delim = ', '
+            sql_statement = '{})'.format(sql_statement)
+        else:
+            sql_statement = "{} LIKE 'name:%'".format(sql_statement)
+        sql_statement = "{}) v".format(sql_statement)
+        cur.execute(sql_statement)
+    conn.commit()