]> git.openstreetmap.org Git - nominatim.git/blob - src/nominatim_db/tools/database_import.py
fix style issue found by flake8
[nominatim.git] / src / nominatim_db / tools / database_import.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2024 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Functions for setting up and importing a new Nominatim database.
9 """
10 from typing import Tuple, Optional, Union, Sequence, MutableMapping, Any
11 import logging
12 import os
13 import subprocess
14 import asyncio
15 from pathlib import Path
16
17 import psutil
18 import psycopg
19 from psycopg import sql as pysql
20
21 from ..errors import UsageError
22 from ..config import Configuration
23 from ..db.connection import connect, get_pg_env, Connection, server_version_tuple, \
24                             postgis_version_tuple, drop_tables, table_exists, execute_scalar
25 from ..db.sql_preprocessor import SQLPreprocessor
26 from ..db.query_pool import QueryPool
27 from .exec_utils import run_osm2pgsql
28 from ..version import POSTGRESQL_REQUIRED_VERSION, POSTGIS_REQUIRED_VERSION
29
30 LOG = logging.getLogger()
31
32
33 def _require_version(module: str, actual: Tuple[int, int], expected: Tuple[int, int]) -> None:
34     """ Compares the version for the given module and raises an exception
35         if the actual version is too old.
36     """
37     if actual < expected:
38         LOG.fatal('Minimum supported version of %s is %d.%d. '
39                   'Found version %d.%d.',
40                   module, expected[0], expected[1], actual[0], actual[1])
41         raise UsageError(f'{module} is too old.')
42
43
44 def _require_loaded(extension_name: str, conn: Connection) -> None:
45     """ Check that the given extension is loaded. """
46     with conn.cursor() as cur:
47         cur.execute('SELECT * FROM pg_extension WHERE extname = %s', (extension_name, ))
48         if cur.rowcount <= 0:
49             LOG.fatal('Required module %s is not loaded.', extension_name)
50             raise UsageError(f'{extension_name} is not loaded.')
51
52
53 def check_existing_database_plugins(dsn: str) -> None:
54     """ Check that the database has the required plugins installed."""
55     with connect(dsn) as conn:
56         _require_version('PostgreSQL server',
57                          server_version_tuple(conn),
58                          POSTGRESQL_REQUIRED_VERSION)
59         _require_version('PostGIS',
60                          postgis_version_tuple(conn),
61                          POSTGIS_REQUIRED_VERSION)
62         _require_loaded('hstore', conn)
63
64
65 def setup_database_skeleton(dsn: str, rouser: Optional[str] = None) -> None:
66     """ Create a new database for Nominatim and populate it with the
67         essential extensions.
68
69         The function fails when the database already exists or Postgresql or
70         PostGIS versions are too old.
71
72         Uses `createdb` to create the database.
73
74         If 'rouser' is given, then the function also checks that the user
75         with that given name exists.
76
77         Requires superuser rights by the caller.
78     """
79     proc = subprocess.run(['createdb'], env=get_pg_env(dsn), check=False)
80
81     if proc.returncode != 0:
82         raise UsageError('Creating new database failed.')
83
84     with connect(dsn) as conn:
85         _require_version('PostgreSQL server',
86                          server_version_tuple(conn),
87                          POSTGRESQL_REQUIRED_VERSION)
88
89         if rouser is not None:
90             cnt = execute_scalar(conn, 'SELECT count(*) FROM pg_user where usename = %s',
91                                  (rouser, ))
92             if cnt == 0:
93                 LOG.fatal("Web user '%s' does not exist. Create it with:\n"
94                           "\n      createuser %s", rouser, rouser)
95                 raise UsageError('Missing read-only user.')
96
97         # Create extensions.
98         with conn.cursor() as cur:
99             cur.execute('CREATE EXTENSION IF NOT EXISTS hstore')
100             cur.execute('CREATE EXTENSION IF NOT EXISTS postgis')
101
102             postgis_version = postgis_version_tuple(conn)
103             if postgis_version[0] >= 3:
104                 cur.execute('CREATE EXTENSION IF NOT EXISTS postgis_raster')
105
106         conn.commit()
107
108         _require_version('PostGIS',
109                          postgis_version_tuple(conn),
110                          POSTGIS_REQUIRED_VERSION)
111
112
113 def import_osm_data(osm_files: Union[Path, Sequence[Path]],
114                     options: MutableMapping[str, Any],
115                     drop: bool = False, ignore_errors: bool = False) -> None:
116     """ Import the given OSM files. 'options' contains the list of
117         default settings for osm2pgsql.
118     """
119     options['import_file'] = osm_files
120     options['append'] = False
121     options['threads'] = 1
122
123     if not options['flatnode_file'] and options['osm2pgsql_cache'] == 0:
124         # Make some educated guesses about cache size based on the size
125         # of the import file and the available memory.
126         mem = psutil.virtual_memory()
127         fsize = 0
128         if isinstance(osm_files, list):
129             for fname in osm_files:
130                 fsize += os.stat(str(fname)).st_size
131         else:
132             fsize = os.stat(str(osm_files)).st_size
133         options['osm2pgsql_cache'] = int(min((mem.available + mem.cached) * 0.75,
134                                              fsize * 2) / 1024 / 1024) + 1
135
136     run_osm2pgsql(options)
137
138     with connect(options['dsn']) as conn:
139         if not ignore_errors:
140             with conn.cursor() as cur:
141                 cur.execute('SELECT true FROM place LIMIT 1')
142                 if cur.rowcount == 0:
143                     raise UsageError('No data imported by osm2pgsql.')
144
145         if drop:
146             drop_tables(conn, 'planet_osm_nodes')
147             conn.commit()
148
149     if drop and options['flatnode_file']:
150         Path(options['flatnode_file']).unlink()
151
152
153 def create_tables(conn: Connection, config: Configuration, reverse_only: bool = False) -> None:
154     """ Create the set of basic tables.
155         When `reverse_only` is True, then the main table for searching will
156         be skipped and only reverse search is possible.
157     """
158     sql = SQLPreprocessor(conn, config)
159     sql.env.globals['db']['reverse_only'] = reverse_only
160
161     sql.run_sql_file(conn, 'tables.sql')
162
163
164 def create_table_triggers(conn: Connection, config: Configuration) -> None:
165     """ Create the triggers for the tables. The trigger functions must already
166         have been imported with refresh.create_functions().
167     """
168     sql = SQLPreprocessor(conn, config)
169     sql.run_sql_file(conn, 'table-triggers.sql')
170
171
172 def create_partition_tables(conn: Connection, config: Configuration) -> None:
173     """ Create tables that have explicit partitioning.
174     """
175     sql = SQLPreprocessor(conn, config)
176     sql.run_sql_file(conn, 'partition-tables.src.sql')
177
178
179 def truncate_data_tables(conn: Connection) -> None:
180     """ Truncate all data tables to prepare for a fresh load.
181     """
182     with conn.cursor() as cur:
183         cur.execute('TRUNCATE placex')
184         cur.execute('TRUNCATE place_addressline')
185         cur.execute('TRUNCATE location_area')
186         cur.execute('TRUNCATE location_area_country')
187         cur.execute('TRUNCATE location_property_tiger')
188         cur.execute('TRUNCATE location_property_osmline')
189         cur.execute('TRUNCATE location_postcode')
190         if table_exists(conn, 'search_name'):
191             cur.execute('TRUNCATE search_name')
192         cur.execute('DROP SEQUENCE IF EXISTS seq_place')
193         cur.execute('CREATE SEQUENCE seq_place start 100000')
194
195         cur.execute("""SELECT tablename FROM pg_tables
196                        WHERE tablename LIKE 'location_road_%'""")
197
198         for table in [r[0] for r in list(cur)]:
199             cur.execute('TRUNCATE ' + table)
200
201     conn.commit()
202
203
204 _COPY_COLUMNS = pysql.SQL(',').join(map(pysql.Identifier,
205                                         ('osm_type', 'osm_id', 'class', 'type',
206                                          'name', 'admin_level', 'address',
207                                          'extratags', 'geometry')))
208
209
210 async def load_data(dsn: str, threads: int) -> None:
211     """ Copy data into the word and placex table.
212     """
213     placex_threads = max(1, threads - 1)
214
215     progress = asyncio.create_task(_progress_print())
216
217     async with QueryPool(dsn, placex_threads + 1) as pool:
218         # Copy data from place to placex in <threads - 1> chunks.
219         for imod in range(placex_threads):
220             await pool.put_query(
221                 pysql.SQL("""INSERT INTO placex ({columns})
222                                SELECT {columns} FROM place
223                                 WHERE osm_id % {total} = {mod}
224                                   AND NOT (class='place'
225                                            and (type='houses' or type='postcode'))
226                                   AND ST_IsValid(geometry)
227                           """).format(columns=_COPY_COLUMNS,
228                                       total=pysql.Literal(placex_threads),
229                                       mod=pysql.Literal(imod)), None)
230
231         # Interpolations need to be copied seperately
232         await pool.put_query("""
233                 INSERT INTO location_property_osmline (osm_id, address, linegeo)
234                   SELECT osm_id, address, geometry FROM place
235                   WHERE class='place' and type='houses' and osm_type='W'
236                         and ST_GeometryType(geometry) = 'ST_LineString' """, None)
237
238     progress.cancel()
239
240     async with await psycopg.AsyncConnection.connect(dsn) as aconn:
241         await aconn.execute('ANALYSE')
242
243
244 async def _progress_print() -> None:
245     while True:
246         try:
247             await asyncio.sleep(1)
248         except asyncio.CancelledError:
249             print('', flush=True)
250             break
251         print('.', end='', flush=True)
252
253
254 async def create_search_indices(conn: Connection, config: Configuration,
255                                 drop: bool = False, threads: int = 1) -> None:
256     """ Create tables that have explicit partitioning.
257     """
258
259     # If index creation failed and left an index invalid, they need to be
260     # cleaned out first, so that the script recreates them.
261     with conn.cursor() as cur:
262         cur.execute("""SELECT relname FROM pg_class, pg_index
263                        WHERE pg_index.indisvalid = false
264                              AND pg_index.indexrelid = pg_class.oid""")
265         bad_indices = [row[0] for row in list(cur)]
266         for idx in bad_indices:
267             LOG.info("Drop invalid index %s.", idx)
268             cur.execute(pysql.SQL('DROP INDEX {}').format(pysql.Identifier(idx)))
269     conn.commit()
270
271     sql = SQLPreprocessor(conn, config)
272
273     await sql.run_parallel_sql_file(config.get_libpq_dsn(),
274                                     'indices.sql', min(8, threads), drop=drop)