]> git.openstreetmap.org Git - nominatim.git/blobdiff - nominatim/tools/database_import.py
Merge remote-tracking branch 'upstream/master'
[nominatim.git] / nominatim / tools / database_import.py
index 6e65e73ad101148c150e971710e5bdc5e47d70c2..433cd8afaca30372ab58698821fee06092748b1a 100644 (file)
@@ -14,6 +14,7 @@ import psycopg2
 from ..db.connection import connect, get_pg_env
 from ..db import utils as db_utils
 from ..db.async_connection import DBConnection
 from ..db.connection import connect, get_pg_env
 from ..db import utils as db_utils
 from ..db.async_connection import DBConnection
+from ..db.sql_preprocessor import SQLPreprocessor
 from .exec_utils import run_osm2pgsql
 from ..errors import UsageError
 from ..version import POSTGRESQL_REQUIRED_VERSION, POSTGIS_REQUIRED_VERSION
 from .exec_utils import run_osm2pgsql
 from ..errors import UsageError
 from ..version import POSTGRESQL_REQUIRED_VERSION, POSTGIS_REQUIRED_VERSION
@@ -145,7 +146,7 @@ def import_base_data(dsn, sql_dir, ignore_partitions=False):
             conn.commit()
 
 
             conn.commit()
 
 
-def import_osm_data(osm_file, options, drop=False):
+def import_osm_data(osm_file, options, drop=False, ignore_errors=False):
     """ Import the given OSM file. 'options' contains the list of
         default settings for osm2pgsql.
     """
     """ Import the given OSM file. 'options' contains the list of
         default settings for osm2pgsql.
     """
@@ -164,10 +165,11 @@ def import_osm_data(osm_file, options, drop=False):
     run_osm2pgsql(options)
 
     with connect(options['dsn']) as conn:
     run_osm2pgsql(options)
 
     with connect(options['dsn']) as conn:
-        with conn.cursor() as cur:
-            cur.execute('SELECT * FROM place LIMIT 1')
-            if cur.rowcount == 0:
-                raise UsageError('No data imported by osm2pgsql.')
+        if not ignore_errors:
+            with conn.cursor() as cur:
+                cur.execute('SELECT * FROM place LIMIT 1')
+                if cur.rowcount == 0:
+                    raise UsageError('No data imported by osm2pgsql.')
 
         if drop:
             conn.drop_table('planet_osm_nodes')
 
         if drop:
             conn.drop_table('planet_osm_nodes')
@@ -177,6 +179,32 @@ def import_osm_data(osm_file, options, drop=False):
             Path(options['flatnode_file']).unlink()
 
 
             Path(options['flatnode_file']).unlink()
 
 
+def create_tables(conn, config, sqllib_dir, reverse_only=False):
+    """ Create the set of basic tables.
+        When `reverse_only` is True, then the main table for searching will
+        be skipped and only reverse search is possible.
+    """
+    sql = SQLPreprocessor(conn, config, sqllib_dir)
+    sql.env.globals['db']['reverse_only'] = reverse_only
+
+    sql.run_sql_file(conn, 'tables.sql')
+
+
+def create_table_triggers(conn, config, sqllib_dir):
+    """ Create the triggers for the tables. The trigger functions must already
+        have been imported with refresh.create_functions().
+    """
+    sql = SQLPreprocessor(conn, config, sqllib_dir)
+    sql.run_sql_file(conn, 'table-triggers.sql')
+
+
+def create_partition_tables(conn, config, sqllib_dir):
+    """ Create tables that have explicit partitioning.
+    """
+    sql = SQLPreprocessor(conn, config, sqllib_dir)
+    sql.run_sql_file(conn, 'partition-tables.src.sql')
+
+
 def truncate_data_tables(conn, max_word_frequency=None):
     """ Truncate all data tables to prepare for a fresh load.
     """
 def truncate_data_tables(conn, max_word_frequency=None):
     """ Truncate all data tables to prepare for a fresh load.
     """
@@ -190,7 +218,8 @@ def truncate_data_tables(conn, max_word_frequency=None):
         cur.execute('TRUNCATE location_property_tiger')
         cur.execute('TRUNCATE location_property_osmline')
         cur.execute('TRUNCATE location_postcode')
         cur.execute('TRUNCATE location_property_tiger')
         cur.execute('TRUNCATE location_property_osmline')
         cur.execute('TRUNCATE location_postcode')
-        cur.execute('TRUNCATE search_name')
+        if conn.table_exists('search_name'):
+            cur.execute('TRUNCATE search_name')
         cur.execute('DROP SEQUENCE IF EXISTS seq_place')
         cur.execute('CREATE SEQUENCE seq_place start 100000')
 
         cur.execute('DROP SEQUENCE IF EXISTS seq_place')
         cur.execute('CREATE SEQUENCE seq_place start 100000')
 
@@ -256,3 +285,56 @@ def load_data(dsn, data_dir, threads):
     with connect(dsn) as conn:
         with conn.cursor() as cur:
             cur.execute('ANALYSE')
     with connect(dsn) as conn:
         with conn.cursor() as cur:
             cur.execute('ANALYSE')
+
+
+def create_search_indices(conn, config, sqllib_dir, drop=False):
+    """ Create tables that have explicit partitioning.
+    """
+
+    # If index creation failed and left an index invalid, they need to be
+    # cleaned out first, so that the script recreates them.
+    with conn.cursor() as cur:
+        cur.execute("""SELECT relname FROM pg_class, pg_index
+                       WHERE pg_index.indisvalid = false
+                             AND pg_index.indexrelid = pg_class.oid""")
+        bad_indices = [row[0] for row in list(cur)]
+        for idx in bad_indices:
+            LOG.info("Drop invalid index %s.", idx)
+            cur.execute('DROP INDEX "{}"'.format(idx))
+    conn.commit()
+
+    sql = SQLPreprocessor(conn, config, sqllib_dir)
+
+    sql.run_sql_file(conn, 'indices.sql', drop=drop)
+
+def create_country_names(conn, config):
+    """ Create search index for default country names.
+    """
+
+    with conn.cursor() as cur:
+        cur.execute("""SELECT getorcreate_country(make_standard_name('uk'), 'gb')""")
+        cur.execute("""SELECT getorcreate_country(make_standard_name('united states'), 'us')""")
+        cur.execute("""SELECT COUNT(*) FROM
+                       (SELECT getorcreate_country(make_standard_name(country_code),
+                       country_code) FROM country_name WHERE country_code is not null) AS x""")
+        cur.execute("""SELECT COUNT(*) FROM
+                       (SELECT getorcreate_country(make_standard_name(name->'name'), country_code) 
+                       FROM country_name WHERE name ? 'name') AS x""")
+        sql_statement = """SELECT COUNT(*) FROM (SELECT getorcreate_country(make_standard_name(v),
+                           country_code) FROM (SELECT country_code, skeys(name)
+                           AS k, svals(name) AS v FROM country_name) x WHERE k"""
+
+        languages = config.LANGUAGES
+
+        if languages:
+            sql_statement = "{} IN (".format(sql_statement)
+            delim = ''
+            for language in languages.split(','):
+                sql_statement = "{}{}'name:{}'".format(sql_statement, delim, language)
+                delim = ', '
+            sql_statement = '{})'.format(sql_statement)
+        else:
+            sql_statement = "{} LIKE 'name:%'".format(sql_statement)
+        sql_statement = "{}) v".format(sql_statement)
+        cur.execute(sql_statement)
+    conn.commit()