]> git.openstreetmap.org Git - nominatim.git/blob - src/nominatim_db/tools/database_import.py
e96954ddd2df605e59618fd23e204c955649ef66
[nominatim.git] / src / nominatim_db / tools / database_import.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2024 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Functions for setting up and importing a new Nominatim database.
9 """
10 from typing import Tuple, Optional, Union, Sequence, MutableMapping, Any
11 import logging
12 import os
13 import subprocess
14 import asyncio
15 from pathlib import Path
16
17 import psutil
18 import psycopg
19 from psycopg import sql as pysql
20
21 from ..errors import UsageError
22 from ..config import Configuration
23 from ..db.connection import connect, get_pg_env, Connection, server_version_tuple,\
24                             postgis_version_tuple, drop_tables, table_exists, execute_scalar
25 from ..db.sql_preprocessor import SQLPreprocessor
26 from ..db.query_pool import QueryPool
27 from .exec_utils import run_osm2pgsql
28 from ..version import POSTGRESQL_REQUIRED_VERSION, POSTGIS_REQUIRED_VERSION
29
30 LOG = logging.getLogger()
31
32 def _require_version(module: str, actual: Tuple[int, int], expected: Tuple[int, int]) -> None:
33     """ Compares the version for the given module and raises an exception
34         if the actual version is too old.
35     """
36     if actual < expected:
37         LOG.fatal('Minimum supported version of %s is %d.%d. '
38                   'Found version %d.%d.',
39                   module, expected[0], expected[1], actual[0], actual[1])
40         raise UsageError(f'{module} is too old.')
41
42
43 def _require_loaded(extension_name: str, conn: Connection) -> None:
44     """ Check that the given extension is loaded. """
45     with conn.cursor() as cur:
46         cur.execute('SELECT * FROM pg_extension WHERE extname = %s', (extension_name, ))
47         if cur.rowcount <= 0:
48             LOG.fatal('Required module %s is not loaded.', extension_name)
49             raise UsageError(f'{extension_name} is not loaded.')
50
51
52 def check_existing_database_plugins(dsn: str) -> None:
53     """ Check that the database has the required plugins installed."""
54     with connect(dsn) as conn:
55         _require_version('PostgreSQL server',
56                          server_version_tuple(conn),
57                          POSTGRESQL_REQUIRED_VERSION)
58         _require_version('PostGIS',
59                          postgis_version_tuple(conn),
60                          POSTGIS_REQUIRED_VERSION)
61         _require_loaded('hstore', conn)
62
63
64 def setup_database_skeleton(dsn: str, rouser: Optional[str] = None) -> None:
65     """ Create a new database for Nominatim and populate it with the
66         essential extensions.
67
68         The function fails when the database already exists or Postgresql or
69         PostGIS versions are too old.
70
71         Uses `createdb` to create the database.
72
73         If 'rouser' is given, then the function also checks that the user
74         with that given name exists.
75
76         Requires superuser rights by the caller.
77     """
78     proc = subprocess.run(['createdb'], env=get_pg_env(dsn), check=False)
79
80     if proc.returncode != 0:
81         raise UsageError('Creating new database failed.')
82
83     with connect(dsn) as conn:
84         _require_version('PostgreSQL server',
85                          server_version_tuple(conn),
86                          POSTGRESQL_REQUIRED_VERSION)
87
88         if rouser is not None:
89             cnt = execute_scalar(conn, 'SELECT count(*) FROM pg_user where usename = %s',
90                                  (rouser, ))
91             if cnt == 0:
92                 LOG.fatal("Web user '%s' does not exist. Create it with:\n"
93                           "\n      createuser %s", rouser, rouser)
94                 raise UsageError('Missing read-only user.')
95
96         # Create extensions.
97         with conn.cursor() as cur:
98             cur.execute('CREATE EXTENSION IF NOT EXISTS hstore')
99             cur.execute('CREATE EXTENSION IF NOT EXISTS postgis')
100
101             postgis_version = postgis_version_tuple(conn)
102             if postgis_version[0] >= 3:
103                 cur.execute('CREATE EXTENSION IF NOT EXISTS postgis_raster')
104
105         conn.commit()
106
107         _require_version('PostGIS',
108                          postgis_version_tuple(conn),
109                          POSTGIS_REQUIRED_VERSION)
110
111
112 def import_osm_data(osm_files: Union[Path, Sequence[Path]],
113                     options: MutableMapping[str, Any],
114                     drop: bool = False, ignore_errors: bool = False) -> None:
115     """ Import the given OSM files. 'options' contains the list of
116         default settings for osm2pgsql.
117     """
118     options['import_file'] = osm_files
119     options['append'] = False
120     options['threads'] = 1
121
122     if not options['flatnode_file'] and options['osm2pgsql_cache'] == 0:
123         # Make some educated guesses about cache size based on the size
124         # of the import file and the available memory.
125         mem = psutil.virtual_memory()
126         fsize = 0
127         if isinstance(osm_files, list):
128             for fname in osm_files:
129                 fsize += os.stat(str(fname)).st_size
130         else:
131             fsize = os.stat(str(osm_files)).st_size
132         options['osm2pgsql_cache'] = int(min((mem.available + mem.cached) * 0.75,
133                                              fsize * 2) / 1024 / 1024) + 1
134
135     run_osm2pgsql(options)
136
137     with connect(options['dsn']) as conn:
138         if not ignore_errors:
139             with conn.cursor() as cur:
140                 cur.execute('SELECT true FROM place LIMIT 1')
141                 if cur.rowcount == 0:
142                     raise UsageError('No data imported by osm2pgsql.')
143
144         if drop:
145             drop_tables(conn, 'planet_osm_nodes')
146             conn.commit()
147
148     if drop and options['flatnode_file']:
149         Path(options['flatnode_file']).unlink()
150
151
152 def create_tables(conn: Connection, config: Configuration, reverse_only: bool = False) -> None:
153     """ Create the set of basic tables.
154         When `reverse_only` is True, then the main table for searching will
155         be skipped and only reverse search is possible.
156     """
157     sql = SQLPreprocessor(conn, config)
158     sql.env.globals['db']['reverse_only'] = reverse_only
159
160     sql.run_sql_file(conn, 'tables.sql')
161
162
163 def create_table_triggers(conn: Connection, config: Configuration) -> None:
164     """ Create the triggers for the tables. The trigger functions must already
165         have been imported with refresh.create_functions().
166     """
167     sql = SQLPreprocessor(conn, config)
168     sql.run_sql_file(conn, 'table-triggers.sql')
169
170
171 def create_partition_tables(conn: Connection, config: Configuration) -> None:
172     """ Create tables that have explicit partitioning.
173     """
174     sql = SQLPreprocessor(conn, config)
175     sql.run_sql_file(conn, 'partition-tables.src.sql')
176
177
178 def truncate_data_tables(conn: Connection) -> None:
179     """ Truncate all data tables to prepare for a fresh load.
180     """
181     with conn.cursor() as cur:
182         cur.execute('TRUNCATE placex')
183         cur.execute('TRUNCATE place_addressline')
184         cur.execute('TRUNCATE location_area')
185         cur.execute('TRUNCATE location_area_country')
186         cur.execute('TRUNCATE location_property_tiger')
187         cur.execute('TRUNCATE location_property_osmline')
188         cur.execute('TRUNCATE location_postcode')
189         if table_exists(conn, 'search_name'):
190             cur.execute('TRUNCATE search_name')
191         cur.execute('DROP SEQUENCE IF EXISTS seq_place')
192         cur.execute('CREATE SEQUENCE seq_place start 100000')
193
194         cur.execute("""SELECT tablename FROM pg_tables
195                        WHERE tablename LIKE 'location_road_%'""")
196
197         for table in [r[0] for r in list(cur)]:
198             cur.execute('TRUNCATE ' + table)
199
200     conn.commit()
201
202
203 _COPY_COLUMNS = pysql.SQL(',').join(map(pysql.Identifier,
204                                         ('osm_type', 'osm_id', 'class', 'type',
205                                          'name', 'admin_level', 'address',
206                                          'extratags', 'geometry')))
207
208
209 async def load_data(dsn: str, threads: int) -> None:
210     """ Copy data into the word and placex table.
211     """
212     placex_threads = max(1, threads - 1)
213
214     progress = asyncio.create_task(_progress_print())
215
216     async with QueryPool(dsn, placex_threads + 1) as pool:
217         # Copy data from place to placex in <threads - 1> chunks.
218         for imod in range(placex_threads):
219             await pool.put_query(
220                 pysql.SQL("""INSERT INTO placex ({columns})
221                                SELECT {columns} FROM place
222                                 WHERE osm_id % {total} = {mod}
223                                   AND NOT (class='place'
224                                            and (type='houses' or type='postcode'))
225                                   AND ST_IsValid(geometry)
226                           """).format(columns=_COPY_COLUMNS,
227                                       total=pysql.Literal(placex_threads),
228                                       mod=pysql.Literal(imod)), None)
229
230         # Interpolations need to be copied seperately
231         await pool.put_query("""
232                 INSERT INTO location_property_osmline (osm_id, address, linegeo)
233                   SELECT osm_id, address, geometry FROM place
234                   WHERE class='place' and type='houses' and osm_type='W'
235                         and ST_GeometryType(geometry) = 'ST_LineString' """, None)
236
237     progress.cancel()
238
239     async with await psycopg.AsyncConnection.connect(dsn) as aconn:
240         await aconn.execute('ANALYSE')
241
242
243 async def _progress_print() -> None:
244     while True:
245         try:
246             await asyncio.sleep(1)
247         except asyncio.CancelledError:
248             print('', flush=True)
249             break
250         print('.', end='', flush=True)
251
252
253 async def create_search_indices(conn: Connection, config: Configuration,
254                           drop: bool = False, threads: int = 1) -> None:
255     """ Create tables that have explicit partitioning.
256     """
257
258     # If index creation failed and left an index invalid, they need to be
259     # cleaned out first, so that the script recreates them.
260     with conn.cursor() as cur:
261         cur.execute("""SELECT relname FROM pg_class, pg_index
262                        WHERE pg_index.indisvalid = false
263                              AND pg_index.indexrelid = pg_class.oid""")
264         bad_indices = [row[0] for row in list(cur)]
265         for idx in bad_indices:
266             LOG.info("Drop invalid index %s.", idx)
267             cur.execute(pysql.SQL('DROP INDEX {}').format(pysql.Identifier(idx)))
268     conn.commit()
269
270     sql = SQLPreprocessor(conn, config)
271
272     await sql.run_parallel_sql_file(config.get_libpq_dsn(),
273                                     'indices.sql', min(8, threads), drop=drop)