]> git.openstreetmap.org Git - nominatim.git/commitdiff
move setup function to python
authorSarah Hoffmann <lonvia@denofr.de>
Fri, 26 Feb 2021 14:02:39 +0000 (15:02 +0100)
committerSarah Hoffmann <lonvia@denofr.de>
Fri, 26 Feb 2021 14:02:39 +0000 (15:02 +0100)
There are still back-calls to PHP for some of the sub-steps.
These need some larger refactoring to be moved to Python.

nominatim/cli.py
nominatim/clicmd/__init__.py
nominatim/clicmd/setup.py [new file with mode: 0644]
nominatim/clicmd/transition.py
nominatim/indexer/indexer.py
nominatim/tools/database_import.py
test/python/conftest.py
test/python/test_cli.py
test/python/test_tools_check_database.py
test/python/test_tools_database_import.py

index eb652d646b93fa4894dae33f1d11ad77d54e3f84..35c6c1f09da01757efa572cd623dc26827e7967a 100644 (file)
@@ -111,72 +111,6 @@ class CommandlineParser:
 # pylint: disable=E0012,C0415
 
 
 # pylint: disable=E0012,C0415
 
 
-class SetupAll:
-    """\
-    Create a new Nominatim database from an OSM file.
-    """
-
-    @staticmethod
-    def add_args(parser):
-        group_name = parser.add_argument_group('Required arguments')
-        group = group_name.add_mutually_exclusive_group(required=True)
-        group.add_argument('--osm-file',
-                           help='OSM file to be imported.')
-        group.add_argument('--continue', dest='continue_at',
-                           choices=['load-data', 'indexing', 'db-postprocess'],
-                           help='Continue an import that was interrupted')
-        group = parser.add_argument_group('Optional arguments')
-        group.add_argument('--osm2pgsql-cache', metavar='SIZE', type=int,
-                           help='Size of cache to be used by osm2pgsql (in MB)')
-        group.add_argument('--reverse-only', action='store_true',
-                           help='Do not create tables and indexes for searching')
-        group.add_argument('--enable-debug-statements', action='store_true',
-                           help='Include debug warning statements in SQL code')
-        group.add_argument('--no-partitions', action='store_true',
-                           help="""Do not partition search indices
-                                   (speeds up import of single country extracts)""")
-        group.add_argument('--no-updates', action='store_true',
-                           help="""Do not keep tables that are only needed for
-                                   updating the database later""")
-        group = parser.add_argument_group('Expert options')
-        group.add_argument('--ignore-errors', action='store_true',
-                           help='Continue import even when errors in SQL are present')
-        group.add_argument('--index-noanalyse', action='store_true',
-                           help='Do not perform analyse operations during index')
-
-
-    @staticmethod
-    def run(args):
-        params = ['setup.php']
-        if args.osm_file:
-            params.extend(('--all', '--osm-file', args.osm_file))
-        else:
-            if args.continue_at == 'load-data':
-                params.append('--load-data')
-            if args.continue_at in ('load-data', 'indexing'):
-                params.append('--index')
-            params.extend(('--create-search-indices', '--create-country-names',
-                           '--setup-website'))
-        if args.osm2pgsql_cache:
-            params.extend(('--osm2pgsql-cache', args.osm2pgsql_cache))
-        if args.reverse_only:
-            params.append('--reverse-only')
-        if args.enable_debug_statements:
-            params.append('--enable-debug-statements')
-        if args.no_partitions:
-            params.append('--no-partitions')
-        if args.no_updates:
-            params.append('--drop')
-        if args.ignore_errors:
-            params.append('--ignore-errors')
-        if args.index_noanalyse:
-            params.append('--index-noanalyse')
-        if args.threads:
-            params.extend(('--threads', args.threads))
-
-        return run_legacy_script(*params, nominatim_env=args)
-
-
 class SetupSpecialPhrases:
     """\
     Maintain special phrases.
 class SetupSpecialPhrases:
     """\
     Maintain special phrases.
@@ -334,7 +268,7 @@ def nominatim(**kwargs):
     """
     parser = CommandlineParser('nominatim', nominatim.__doc__)
 
     """
     parser = CommandlineParser('nominatim', nominatim.__doc__)
 
-    parser.add_subcommand('import', SetupAll)
+    parser.add_subcommand('import', clicmd.SetupAll)
     parser.add_subcommand('freeze', clicmd.SetupFreeze)
     parser.add_subcommand('replication', clicmd.UpdateReplication)
 
     parser.add_subcommand('freeze', clicmd.SetupFreeze)
     parser.add_subcommand('replication', clicmd.UpdateReplication)
 
index 78570b1b4f9b92cfc72a23a0c2e74e51300693e4..9101e0c08973cc7877849d1b9248e2788cd8457b 100644 (file)
@@ -2,6 +2,7 @@
 Subcommand definitions for the command-line tool.
 """
 
 Subcommand definitions for the command-line tool.
 """
 
+from .setup import SetupAll
 from .replication import UpdateReplication
 from .api import APISearch, APIReverse, APILookup, APIDetails, APIStatus
 from .index import UpdateIndex
 from .replication import UpdateReplication
 from .api import APISearch, APIReverse, APILookup, APIDetails, APIStatus
 from .index import UpdateIndex
diff --git a/nominatim/clicmd/setup.py b/nominatim/clicmd/setup.py
new file mode 100644 (file)
index 0000000..8f717cb
--- /dev/null
@@ -0,0 +1,140 @@
+"""
+Implementation of the 'import' subcommand.
+"""
+import logging
+from pathlib import Path
+
+import psutil
+
+from ..tools.exec_utils import run_legacy_script
+from ..db.connection import connect
+from ..db import status
+from ..errors import UsageError
+
+# Do not repeat documentation of subcommand classes.
+# pylint: disable=C0111
+# Using non-top-level imports to avoid potentially unused imports.
+# pylint: disable=E0012,C0415
+
+LOG = logging.getLogger()
+
+class SetupAll:
+    """\
+    Create a new Nominatim database from an OSM file.
+    """
+
+    @staticmethod
+    def add_args(parser):
+        group_name = parser.add_argument_group('Required arguments')
+        group = group_name.add_mutually_exclusive_group(required=True)
+        group.add_argument('--osm-file', metavar='FILE',
+                           help='OSM file to be imported.')
+        group.add_argument('--continue', dest='continue_at',
+                           choices=['load-data', 'indexing', 'db-postprocess'],
+                           help='Continue an import that was interrupted')
+        group = parser.add_argument_group('Optional arguments')
+        group.add_argument('--osm2pgsql-cache', metavar='SIZE', type=int,
+                           help='Size of cache to be used by osm2pgsql (in MB)')
+        group.add_argument('--reverse-only', action='store_true',
+                           help='Do not create tables and indexes for searching')
+        group.add_argument('--no-partitions', action='store_true',
+                           help="""Do not partition search indices
+                                   (speeds up import of single country extracts)""")
+        group.add_argument('--no-updates', action='store_true',
+                           help="""Do not keep tables that are only needed for
+                                   updating the database later""")
+        group = parser.add_argument_group('Expert options')
+        group.add_argument('--ignore-errors', action='store_true',
+                           help='Continue import even when errors in SQL are present')
+        group.add_argument('--index-noanalyse', action='store_true',
+                           help='Do not perform analyse operations during index')
+
+
+    @staticmethod
+    def run(args): # pylint: disable=too-many-statements
+        from ..tools import database_import
+        from ..tools import refresh
+        from ..indexer.indexer import Indexer
+
+        if args.osm_file and not Path(args.osm_file).is_file():
+            LOG.fatal("OSM file '%s' does not exist.", args.osm_file)
+            raise UsageError('Cannot access file.')
+
+        if args.continue_at is None:
+            database_import.setup_database_skeleton(args.config.get_libpq_dsn(),
+                                                    args.data_dir,
+                                                    args.no_partitions,
+                                                    rouser=args.config.DATABASE_WEBUSER)
+
+            LOG.warning('Installing database module')
+            with connect(args.config.get_libpq_dsn()) as conn:
+                database_import.install_module(args.module_dir, args.project_dir,
+                                               args.config.DATABASE_MODULE_PATH,
+                                               conn=conn)
+
+            LOG.warning('Importing OSM data file')
+            database_import.import_osm_data(Path(args.osm_file),
+                                            args.osm2pgsql_options(0, 1),
+                                            drop=args.no_updates)
+
+            LOG.warning('Create functions (1st pass)')
+            with connect(args.config.get_libpq_dsn()) as conn:
+                refresh.create_functions(conn, args.config, args.sqllib_dir,
+                                         False, False)
+
+            LOG.warning('Create tables')
+            params = ['setup.php', '--create-tables', '--create-partition-tables']
+            if args.reverse_only:
+                params.append('--reverse-only')
+            run_legacy_script(*params, nominatim_env=args)
+
+            LOG.warning('Create functions (2nd pass)')
+            with connect(args.config.get_libpq_dsn()) as conn:
+                refresh.create_functions(conn, args.config, args.sqllib_dir,
+                                         False, False)
+
+            LOG.warning('Importing wikipedia importance data')
+            data_path = Path(args.config.WIKIPEDIA_DATA_PATH or args.project_dir)
+            if refresh.import_wikipedia_articles(args.config.get_libpq_dsn(),
+                                                 data_path) > 0:
+                LOG.error('Wikipedia importance dump file not found. '
+                          'Will be using default importances.')
+
+            LOG.warning('Initialise tables')
+            with connect(args.config.get_libpq_dsn()) as conn:
+                database_import.truncate_data_tables(conn, args.config.MAX_WORD_FREQUENCY)
+
+        if args.continue_at is None or args.continue_at == 'load-data':
+            LOG.warning('Load data into placex table')
+            database_import.load_data(args.config.get_libpq_dsn(),
+                                      args.data_dir,
+                                      args.threads or psutil.cpu_count() or 1)
+
+            LOG.warning('Calculate postcodes')
+            run_legacy_script('setup.php', '--calculate-postcodes', nominatim_env=args)
+
+        if args.continue_at is None or args.continue_at in ('load-data', 'indexing'):
+            LOG.warning('Indexing places')
+            indexer = Indexer(args.config.get_libpq_dsn(),
+                              args.threads or psutil.cpu_count() or 1)
+            indexer.index_full(analyse=not args.index_noanalyse)
+
+        LOG.warning('Post-process tables')
+        params = ['setup.php', '--create-search-indices', '--create-country-names']
+        if args.no_updates:
+            params.append('--drop')
+        run_legacy_script(*params, nominatim_env=args)
+
+        webdir = args.project_dir / 'website'
+        LOG.warning('Setup website at %s', webdir)
+        refresh.setup_website(webdir, args.phplib_dir, args.config)
+
+        with connect(args.config.get_libpq_dsn()) as conn:
+            try:
+                dbdate = status.compute_database_date(conn)
+                status.set_status(conn, dbdate)
+                LOG.info('Database is at %s.', dbdate)
+            except Exception as exc: # pylint: disable=broad-except
+                LOG.error('Cannot determine date of database: %s', exc)
+
+        return 0
index de4e16cac2e05bda4f1b2c06f4d58787b4d90dae..1d78062ed9995a22a82377f81401622d7e4516b6 100644 (file)
@@ -59,12 +59,12 @@ class AdminTransition:
 
         if args.setup_db:
             LOG.warning('Setup DB')
 
         if args.setup_db:
             LOG.warning('Setup DB')
-            mpath = database_import.install_module(args.module_dir, args.project_dir,
-                                                   args.config.DATABASE_MODULE_PATH)
 
             with connect(args.config.get_libpq_dsn()) as conn:
                 database_import.setup_extensions(conn)
 
             with connect(args.config.get_libpq_dsn()) as conn:
                 database_import.setup_extensions(conn)
-                database_import.check_module_dir_path(conn, mpath)
+                database_import.install_module(args.module_dir, args.project_dir,
+                                               args.config.DATABASE_MODULE_PATH,
+                                               conn=conn)
 
             database_import.import_base_data(args.config.get_libpq_dsn(),
                                              args.data_dir, args.no_partitions)
 
             database_import.import_base_data(args.config.get_libpq_dsn(),
                                              args.data_dir, args.no_partitions)
@@ -88,7 +88,7 @@ class AdminTransition:
             with connect(args.config.get_libpq_dsn()) as conn:
                 try:
                     status.set_status(conn, status.compute_database_date(conn))
             with connect(args.config.get_libpq_dsn()) as conn:
                 try:
                     status.set_status(conn, status.compute_database_date(conn))
-                except Exception as exc: # pylint: disable=bare-except
+                except Exception as exc: # pylint: disable=broad-except
                     LOG.error('Cannot determine date of database: %s', exc)
 
         if args.index:
                     LOG.error('Cannot determine date of database: %s', exc)
 
         if args.index:
index d997e522490735994080df726e256662a4c0c632..93723844d8d63582c982a4f9846058af7b512b8e 100644 (file)
@@ -119,6 +119,13 @@ class PostcodeRunner:
                   WHERE place_id IN ({})
                """.format(','.join((str(i) for i in ids)))
 
                   WHERE place_id IN ({})
                """.format(','.join((str(i) for i in ids)))
 
+
+def _analyse_db_if(conn, condition):
+    if condition:
+        with conn.cursor() as cur:
+            cur.execute('ANALYSE')
+
+
 class Indexer:
     """ Main indexing routine.
     """
 class Indexer:
     """ Main indexing routine.
     """
@@ -142,7 +149,7 @@ class Indexer:
 
         for thread in self.threads:
             thread.close()
 
         for thread in self.threads:
             thread.close()
-        threads = []
+        self.threads = []
 
 
     def index_full(self, analyse=True):
 
 
     def index_full(self, analyse=True):
@@ -155,26 +162,22 @@ class Indexer:
 
         try:
             self.index_by_rank(0, 4)
 
         try:
             self.index_by_rank(0, 4)
-            self._analyse_db_if(conn, analyse)
+            _analyse_db_if(conn, analyse)
 
             self.index_boundaries(0, 30)
 
             self.index_boundaries(0, 30)
-            self._analyse_db_if(conn, analyse)
+            _analyse_db_if(conn, analyse)
 
             self.index_by_rank(5, 25)
 
             self.index_by_rank(5, 25)
-            self._analyse_db_if(conn, analyse)
+            _analyse_db_if(conn, analyse)
 
             self.index_by_rank(26, 30)
 
             self.index_by_rank(26, 30)
-            self._analyse_db_if(conn, analyse)
+            _analyse_db_if(conn, analyse)
 
             self.index_postcodes()
 
             self.index_postcodes()
-            self._analyse_db_if(conn, analyse)
+            _analyse_db_if(conn, analyse)
         finally:
             conn.close()
 
         finally:
             conn.close()
 
-    def _analyse_db_if(self, conn, condition):
-        if condition:
-            with conn.cursor() as cur:
-                cur.execute('ANALYSE')
 
     def index_boundaries(self, minrank, maxrank):
         """ Index only administrative boundaries within the given rank range.
 
     def index_boundaries(self, minrank, maxrank):
         """ Index only administrative boundaries within the given rank range.
index a6df275517a4134af973fc10943f83bfbd8404c6..6e65e73ad101148c150e971710e5bdc5e47d70c2 100644 (file)
@@ -9,6 +9,7 @@ import shutil
 from pathlib import Path
 
 import psutil
 from pathlib import Path
 
 import psutil
+import psycopg2
 
 from ..db.connection import connect, get_pg_env
 from ..db import utils as db_utils
 
 from ..db.connection import connect, get_pg_env
 from ..db import utils as db_utils
@@ -19,6 +20,21 @@ from ..version import POSTGRESQL_REQUIRED_VERSION, POSTGIS_REQUIRED_VERSION
 
 LOG = logging.getLogger()
 
 
 LOG = logging.getLogger()
 
+def setup_database_skeleton(dsn, data_dir, no_partitions, rouser=None):
+    """ Create a new database for Nominatim and populate it with the
+        essential extensions and data.
+    """
+    LOG.warning('Creating database')
+    create_db(dsn, rouser)
+
+    LOG.warning('Setting up database')
+    with connect(dsn) as conn:
+        setup_extensions(conn)
+
+    LOG.warning('Loading basic data')
+    import_base_data(dsn, data_dir, no_partitions)
+
+
 def create_db(dsn, rouser=None):
     """ Create a new database for the given DSN. Fails when the database
         already exists or the PostgreSQL version is too old.
 def create_db(dsn, rouser=None):
     """ Create a new database for the given DSN. Fails when the database
         already exists or the PostgreSQL version is too old.
@@ -72,7 +88,7 @@ def setup_extensions(conn):
         raise UsageError('PostGIS version is too old.')
 
 
         raise UsageError('PostGIS version is too old.')
 
 
-def install_module(src_dir, project_dir, module_dir):
+def install_module(src_dir, project_dir, module_dir, conn=None):
     """ Copy the normalization module from src_dir into the project
         directory under the '/module' directory. If 'module_dir' is set, then
         use the module from there instead and check that it is accessible
     """ Copy the normalization module from src_dir into the project
         directory under the '/module' directory. If 'module_dir' is set, then
         use the module from there instead and check that it is accessible
@@ -80,6 +96,9 @@ def install_module(src_dir, project_dir, module_dir):
 
         The function detects when the installation is run from the
         build directory. It doesn't touch the module in that case.
 
         The function detects when the installation is run from the
         build directory. It doesn't touch the module in that case.
+
+        If 'conn' is given, then the function also tests if the module
+        can be accessed via the given database.
     """
     if not module_dir:
         module_dir = project_dir / 'module'
     """
     if not module_dir:
         module_dir = project_dir / 'module'
@@ -99,19 +118,17 @@ def install_module(src_dir, project_dir, module_dir):
     else:
         LOG.info("Using custom path for database module at '%s'", module_dir)
 
     else:
         LOG.info("Using custom path for database module at '%s'", module_dir)
 
-    return module_dir
-
-
-def check_module_dir_path(conn, path):
-    """ Check that the normalisation module can be found and executed
-        from the given path.
-    """
-    with conn.cursor() as cur:
-        cur.execute("""CREATE FUNCTION nominatim_test_import_func(text)
-                       RETURNS text AS '{}/nominatim.so', 'transliteration'
-                       LANGUAGE c IMMUTABLE STRICT;
-                       DROP FUNCTION nominatim_test_import_func(text)
-                    """.format(path))
+    if conn is not None:
+        with conn.cursor() as cur:
+            try:
+                cur.execute("""CREATE FUNCTION nominatim_test_import_func(text)
+                               RETURNS text AS '{}/nominatim.so', 'transliteration'
+                               LANGUAGE c IMMUTABLE STRICT;
+                               DROP FUNCTION nominatim_test_import_func(text)
+                            """.format(module_dir))
+            except psycopg2.DatabaseError as err:
+                LOG.fatal("Error accessing database module: %s", err)
+                raise UsageError("Database module cannot be accessed.") from err
 
 
 def import_base_data(dsn, sql_dir, ignore_partitions=False):
 
 
 def import_base_data(dsn, sql_dir, ignore_partitions=False):
@@ -174,7 +191,7 @@ def truncate_data_tables(conn, max_word_frequency=None):
         cur.execute('TRUNCATE location_property_osmline')
         cur.execute('TRUNCATE location_postcode')
         cur.execute('TRUNCATE search_name')
         cur.execute('TRUNCATE location_property_osmline')
         cur.execute('TRUNCATE location_postcode')
         cur.execute('TRUNCATE search_name')
-        cur.execute('DROP SEQUENCE seq_place')
+        cur.execute('DROP SEQUENCE IF EXISTS seq_place')
         cur.execute('CREATE SEQUENCE seq_place start 100000')
 
         cur.execute("""SELECT tablename FROM pg_tables
         cur.execute('CREATE SEQUENCE seq_place start 100000')
 
         cur.execute("""SELECT tablename FROM pg_tables
index 40b611c03a6bd4168ba62dcd0078d4c0c6e70962..d16dceffccb1008a1fe6082ab0a5e9d04cd2a439 100644 (file)
@@ -43,6 +43,11 @@ class _TestingCursor(psycopg2.extras.DictCursor):
                              WHERE tablename = %s""", (table, ))
         return num == 1
 
                              WHERE tablename = %s""", (table, ))
         return num == 1
 
+    def table_rows(self, table):
+        """ Return the number of rows in the given table.
+        """
+        return self.scalar('SELECT count(*) FROM ' + table)
+
 
 @pytest.fixture
 def temp_db(monkeypatch):
 
 @pytest.fixture
 def temp_db(monkeypatch):
@@ -109,8 +114,12 @@ def temp_db_cursor(temp_db):
 
 @pytest.fixture
 def table_factory(temp_db_cursor):
 
 @pytest.fixture
 def table_factory(temp_db_cursor):
-    def mk_table(name, definition='id INT'):
+    def mk_table(name, definition='id INT', content=None):
         temp_db_cursor.execute('CREATE TABLE {} ({})'.format(name, definition))
         temp_db_cursor.execute('CREATE TABLE {} ({})'.format(name, definition))
+        if content is not None:
+            if not isinstance(content, str):
+                content = '),('.join([str(x) for x in content])
+            temp_db_cursor.execute("INSERT INTO {} VALUES ({})".format(name, content))
 
     return mk_table
 
 
     return mk_table
 
@@ -174,7 +183,7 @@ def place_row(place_table, temp_db_cursor):
         temp_db_cursor.execute("INSERT INTO place VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)",
                                (osm_id or next(idseq), osm_type, cls, typ, names,
                                 admin_level, address, extratags,
         temp_db_cursor.execute("INSERT INTO place VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)",
                                (osm_id or next(idseq), osm_type, cls, typ, names,
                                 admin_level, address, extratags,
-                                geom or 'SRID=4326;POINT(0 0 )'))
+                                geom or 'SRID=4326;POINT(0 0)'))
 
     return _insert
 
 
     return _insert
 
@@ -184,7 +193,7 @@ def placex_table(temp_db_with_extensions, temp_db_conn):
     """
     with temp_db_conn.cursor() as cur:
         cur.execute("""CREATE TABLE placex (
     """
     with temp_db_conn.cursor() as cur:
         cur.execute("""CREATE TABLE placex (
-                           place_id BIGINT NOT NULL,
+                           place_id BIGINT,
                            parent_place_id BIGINT,
                            linked_place_id BIGINT,
                            importance FLOAT,
                            parent_place_id BIGINT,
                            linked_place_id BIGINT,
                            importance FLOAT,
@@ -207,8 +216,43 @@ def placex_table(temp_db_with_extensions, temp_db_conn):
                            country_code varchar(2),
                            housenumber TEXT,
                            postcode TEXT,
                            country_code varchar(2),
                            housenumber TEXT,
                            postcode TEXT,
-                           centroid GEOMETRY(Geometry, 4326))
-                           """)
+                           centroid GEOMETRY(Geometry, 4326))""")
+    temp_db_conn.commit()
+
+
+@pytest.fixture
+def osmline_table(temp_db_with_extensions, temp_db_conn):
+    with temp_db_conn.cursor() as cur:
+        cur.execute("""CREATE TABLE location_property_osmline (
+                           place_id BIGINT,
+                           osm_id BIGINT,
+                           parent_place_id BIGINT,
+                           geometry_sector INTEGER,
+                           indexed_date TIMESTAMP,
+                           startnumber INTEGER,
+                           endnumber INTEGER,
+                           partition SMALLINT,
+                           indexed_status SMALLINT,
+                           linegeo GEOMETRY,
+                           interpolationtype TEXT,
+                           address HSTORE,
+                           postcode TEXT,
+                           country_code VARCHAR(2))""")
+    temp_db_conn.commit()
+
+
+@pytest.fixture
+def word_table(temp_db, temp_db_conn):
+    with temp_db_conn.cursor() as cur:
+        cur.execute("""CREATE TABLE word (
+                           word_id INTEGER,
+                           word_token text,
+                           word text,
+                           class text,
+                           type text,
+                           country_code varchar(2),
+                           search_name_count INTEGER,
+                           operator TEXT)""")
     temp_db_conn.commit()
 
 
     temp_db_conn.commit()
 
 
index e2a44e37f44122c81a927bf11d54b770d6145106..70c106051f4a621838c1be167db4c9d1d8cee2f1 100644 (file)
@@ -13,9 +13,11 @@ import nominatim.cli
 import nominatim.clicmd.api
 import nominatim.clicmd.refresh
 import nominatim.clicmd.admin
 import nominatim.clicmd.api
 import nominatim.clicmd.refresh
 import nominatim.clicmd.admin
+import nominatim.clicmd.setup
 import nominatim.indexer.indexer
 import nominatim.tools.admin
 import nominatim.tools.check_database
 import nominatim.indexer.indexer
 import nominatim.tools.admin
 import nominatim.tools.check_database
+import nominatim.tools.database_import
 import nominatim.tools.freeze
 import nominatim.tools.refresh
 
 import nominatim.tools.freeze
 import nominatim.tools.refresh
 
@@ -61,7 +63,6 @@ def test_cli_help(capsys):
 
 
 @pytest.mark.parametrize("command,script", [
 
 
 @pytest.mark.parametrize("command,script", [
-                         (('import', '--continue', 'load-data'), 'setup'),
                          (('special-phrases',), 'specialphrases'),
                          (('add-data', '--tiger-data', 'tiger'), 'setup'),
                          (('add-data', '--file', 'foo.osm'), 'update'),
                          (('special-phrases',), 'specialphrases'),
                          (('add-data', '--tiger-data', 'tiger'), 'setup'),
                          (('add-data', '--file', 'foo.osm'), 'update'),
@@ -74,6 +75,36 @@ def test_legacy_commands_simple(mock_run_legacy, command, script):
     assert mock_run_legacy.last_args[0] == script + '.php'
 
 
     assert mock_run_legacy.last_args[0] == script + '.php'
 
 
+def test_import_missing_file(temp_db):
+    assert 1 == call_nominatim('import', '--osm-file', 'sfsafegweweggdgw.reh.erh')
+
+
+def test_import_bad_file(temp_db):
+    assert 1 == call_nominatim('import', '--osm-file', '.')
+
+
+def test_import_full(temp_db, mock_func_factory):
+    mocks = [
+        mock_func_factory(nominatim.tools.database_import, 'setup_database_skeleton'),
+        mock_func_factory(nominatim.tools.database_import, 'install_module'),
+        mock_func_factory(nominatim.tools.database_import, 'import_osm_data'),
+        mock_func_factory(nominatim.tools.refresh, 'import_wikipedia_articles'),
+        mock_func_factory(nominatim.tools.database_import, 'truncate_data_tables'),
+        mock_func_factory(nominatim.tools.database_import, 'load_data'),
+        mock_func_factory(nominatim.indexer.indexer.Indexer, 'index_full'),
+        mock_func_factory(nominatim.tools.refresh, 'setup_website'),
+    ]
+
+    cf_mock = mock_func_factory(nominatim.tools.refresh, 'create_functions')
+    mock_func_factory(nominatim.clicmd.setup, 'run_legacy_script')
+
+    assert 0 == call_nominatim('import', '--osm-file', __file__)
+
+    assert cf_mock.called > 1
+
+    for mock in mocks:
+        assert mock.called == 1
+
 def test_freeze_command(mock_func_factory, temp_db):
     mock_drop = mock_func_factory(nominatim.tools.freeze, 'drop_update_tables')
     mock_flatnode = mock_func_factory(nominatim.tools.freeze, 'drop_flatnode_file')
 def test_freeze_command(mock_func_factory, temp_db):
     mock_drop = mock_func_factory(nominatim.tools.freeze, 'drop_update_tables')
     mock_flatnode = mock_func_factory(nominatim.tools.freeze, 'drop_flatnode_file')
index 3787c3be16e9351fd4ccedb5a7d98e698a3b358b..68b376a781c585b417c09a441e3f7485d7d231fa 100644 (file)
@@ -63,6 +63,10 @@ def test_check_database_indexes_bad(temp_db_conn, def_config):
     assert chkdb.check_database_indexes(temp_db_conn, def_config) == chkdb.CheckState.FAIL
 
 
     assert chkdb.check_database_indexes(temp_db_conn, def_config) == chkdb.CheckState.FAIL
 
 
+def test_check_database_indexes_valid(temp_db_conn, def_config):
+    assert chkdb.check_database_index_valid(temp_db_conn, def_config) == chkdb.CheckState.OK
+
+
 def test_check_tiger_table_disabled(temp_db_conn, def_config, monkeypatch):
     monkeypatch.setenv('NOMINATIM_USE_US_TIGER_DATA' , 'no')
     assert chkdb.check_tiger_table(temp_db_conn, def_config) == chkdb.CheckState.NOT_APPLICABLE
 def test_check_tiger_table_disabled(temp_db_conn, def_config, monkeypatch):
     monkeypatch.setenv('NOMINATIM_USE_US_TIGER_DATA' , 'no')
     assert chkdb.check_tiger_table(temp_db_conn, def_config) == chkdb.CheckState.NOT_APPLICABLE
index 597fdfc126f0838f75b8690d04b6fb5ff11b411b..f9760fc0fc0da4c754a94576af4411075d81b4cf 100644 (file)
@@ -24,6 +24,24 @@ def nonexistant_db():
     with conn.cursor() as cur:
         cur.execute('DROP DATABASE IF EXISTS {}'.format(dbname))
 
     with conn.cursor() as cur:
         cur.execute('DROP DATABASE IF EXISTS {}'.format(dbname))
 
+@pytest.mark.parametrize("no_partitions", (True, False))
+def test_setup_skeleton(src_dir, nonexistant_db, no_partitions):
+    database_import.setup_database_skeleton('dbname=' + nonexistant_db,
+                                            src_dir / 'data', no_partitions)
+
+    conn = psycopg2.connect(database=nonexistant_db)
+
+    try:
+        with conn.cursor() as cur:
+            cur.execute("SELECT distinct partition FROM country_name")
+            partitions = set([r[0] for r in list(cur)])
+            if no_partitions:
+                assert partitions == set([0])
+            else:
+                assert len(partitions) > 10
+    finally:
+        conn.close()
+
 
 def test_create_db_success(nonexistant_db):
     database_import.create_db('dbname=' + nonexistant_db, rouser='www-data')
 
 def test_create_db_success(nonexistant_db):
     database_import.create_db('dbname=' + nonexistant_db, rouser='www-data')
@@ -79,6 +97,22 @@ def test_install_module(tmp_path):
     assert outfile.stat().st_mode == 33261
 
 
     assert outfile.stat().st_mode == 33261
 
 
+def test_install_module_custom(tmp_path):
+    (tmp_path / 'nominatim.so').write_text('TEST nomiantim.so')
+
+    database_import.install_module(tmp_path, tmp_path, str(tmp_path.resolve()))
+
+    assert not (tmp_path / 'module').exists()
+
+
+def test_install_module_fail_access(temp_db_conn, tmp_path):
+    (tmp_path / 'nominatim.so').write_text('TEST nomiantim.so')
+
+    with pytest.raises(UsageError, match='.*module cannot be accessed.*'):
+        database_import.install_module(tmp_path, tmp_path, '',
+                                       conn=temp_db_conn)
+
+
 def test_import_base_data(src_dir, temp_db, temp_db_cursor):
     temp_db_cursor.execute('CREATE EXTENSION hstore')
     temp_db_cursor.execute('CREATE EXTENSION postgis')
 def test_import_base_data(src_dir, temp_db, temp_db_cursor):
     temp_db_cursor.execute('CREATE EXTENSION hstore')
     temp_db_cursor.execute('CREATE EXTENSION postgis')
@@ -134,3 +168,35 @@ def test_import_osm_data_default_cache(temp_db_cursor,osm2pgsql_options):
     osm2pgsql_options['osm2pgsql_cache'] = 0
 
     database_import.import_osm_data(Path(__file__), osm2pgsql_options)
     osm2pgsql_options['osm2pgsql_cache'] = 0
 
     database_import.import_osm_data(Path(__file__), osm2pgsql_options)
+
+
+def test_truncate_database_tables(temp_db_conn, temp_db_cursor, table_factory):
+    tables = ('word', 'placex', 'place_addressline', 'location_area',
+              'location_area_country', 'location_property',
+              'location_property_tiger', 'location_property_osmline',
+              'location_postcode', 'search_name', 'location_road_23')
+    for table in tables:
+        table_factory(table, content=(1, 2, 3))
+
+    database_import.truncate_data_tables(temp_db_conn, max_word_frequency=23)
+
+    for table in tables:
+        assert temp_db_cursor.table_rows(table) == 0
+
+
+@pytest.mark.parametrize("threads", (1, 5))
+def test_load_data(dsn, src_dir, place_row, placex_table, osmline_table, word_table,
+                   temp_db_cursor, threads):
+    for func in ('make_keywords', 'getorcreate_housenumber_id', 'make_standard_name'):
+        temp_db_cursor.execute("""CREATE FUNCTION {} (src TEXT)
+                                  RETURNS TEXT AS $$ SELECT 'a' $$ LANGUAGE SQL
+                               """.format(func))
+    for oid in range(100, 130):
+        place_row(osm_id=oid)
+    place_row(osm_type='W', osm_id=342, cls='place', typ='houses',
+              geom='SRID=4326;LINESTRING(0 0, 10 10)')
+
+    database_import.load_data(dsn, src_dir / 'data', threads)
+
+    assert temp_db_cursor.table_rows('placex') == 30
+    assert temp_db_cursor.table_rows('location_property_osmline') == 1