]> git.openstreetmap.org Git - nominatim.git/blob - nominatim/tools/database_import.py
Merge pull request #2770 from lonvia/typed-python
[nominatim.git] / nominatim / tools / database_import.py
1 # SPDX-License-Identifier: GPL-2.0-only
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2022 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Functions for setting up and importing a new Nominatim database.
9 """
10 from typing import Tuple, Optional, Union, Sequence, MutableMapping, Any
11 import logging
12 import os
13 import selectors
14 import subprocess
15 from pathlib import Path
16
17 import psutil
18 from psycopg2 import sql as pysql
19
20 from nominatim.config import Configuration
21 from nominatim.db.connection import connect, get_pg_env, Connection
22 from nominatim.db.async_connection import DBConnection
23 from nominatim.db.sql_preprocessor import SQLPreprocessor
24 from nominatim.tools.exec_utils import run_osm2pgsql
25 from nominatim.errors import UsageError
26 from nominatim.version import POSTGRESQL_REQUIRED_VERSION, POSTGIS_REQUIRED_VERSION
27
28 LOG = logging.getLogger()
29
30 def _require_version(module: str, actual: Tuple[int, int], expected: Tuple[int, int]) -> None:
31     """ Compares the version for the given module and raises an exception
32         if the actual version is too old.
33     """
34     if actual < expected:
35         LOG.fatal('Minimum supported version of %s is %d.%d. '
36                   'Found version %d.%d.',
37                   module, expected[0], expected[1], actual[0], actual[1])
38         raise UsageError(f'{module} is too old.')
39
40
41 def setup_database_skeleton(dsn: str, rouser: Optional[str] = None) -> None:
42     """ Create a new database for Nominatim and populate it with the
43         essential extensions.
44
45         The function fails when the database already exists or Postgresql or
46         PostGIS versions are too old.
47
48         Uses `createdb` to create the database.
49
50         If 'rouser' is given, then the function also checks that the user
51         with that given name exists.
52
53         Requires superuser rights by the caller.
54     """
55     proc = subprocess.run(['createdb'], env=get_pg_env(dsn), check=False)
56
57     if proc.returncode != 0:
58         raise UsageError('Creating new database failed.')
59
60     with connect(dsn) as conn:
61         _require_version('PostgreSQL server',
62                          conn.server_version_tuple(),
63                          POSTGRESQL_REQUIRED_VERSION)
64
65         if rouser is not None:
66             with conn.cursor() as cur:
67                 cnt = cur.scalar('SELECT count(*) FROM pg_user where usename = %s',
68                                  (rouser, ))
69                 if cnt == 0:
70                     LOG.fatal("Web user '%s' does not exist. Create it with:\n"
71                               "\n      createuser %s", rouser, rouser)
72                     raise UsageError('Missing read-only user.')
73
74         # Create extensions.
75         with conn.cursor() as cur:
76             cur.execute('CREATE EXTENSION IF NOT EXISTS hstore')
77             cur.execute('CREATE EXTENSION IF NOT EXISTS postgis')
78         conn.commit()
79
80         _require_version('PostGIS',
81                          conn.postgis_version_tuple(),
82                          POSTGIS_REQUIRED_VERSION)
83
84
85 def import_osm_data(osm_files: Union[Path, Sequence[Path]],
86                     options: MutableMapping[str, Any],
87                     drop: bool = False, ignore_errors: bool = False) -> None:
88     """ Import the given OSM files. 'options' contains the list of
89         default settings for osm2pgsql.
90     """
91     options['import_file'] = osm_files
92     options['append'] = False
93     options['threads'] = 1
94
95     if not options['flatnode_file'] and options['osm2pgsql_cache'] == 0:
96         # Make some educated guesses about cache size based on the size
97         # of the import file and the available memory.
98         mem = psutil.virtual_memory() # type: ignore[no-untyped-call]
99         fsize = 0
100         if isinstance(osm_files, list):
101             for fname in osm_files:
102                 fsize += os.stat(str(fname)).st_size
103         else:
104             fsize = os.stat(str(osm_files)).st_size
105         options['osm2pgsql_cache'] = int(min((mem.available + mem.cached) * 0.75,
106                                              fsize * 2) / 1024 / 1024) + 1
107
108     run_osm2pgsql(options)
109
110     with connect(options['dsn']) as conn:
111         if not ignore_errors:
112             with conn.cursor() as cur:
113                 cur.execute('SELECT * FROM place LIMIT 1')
114                 if cur.rowcount == 0:
115                     raise UsageError('No data imported by osm2pgsql.')
116
117         if drop:
118             conn.drop_table('planet_osm_nodes')
119
120     if drop and options['flatnode_file']:
121         Path(options['flatnode_file']).unlink()
122
123
124 def create_tables(conn: Connection, config: Configuration, reverse_only: bool = False) -> None:
125     """ Create the set of basic tables.
126         When `reverse_only` is True, then the main table for searching will
127         be skipped and only reverse search is possible.
128     """
129     sql = SQLPreprocessor(conn, config)
130     sql.env.globals['db']['reverse_only'] = reverse_only
131
132     sql.run_sql_file(conn, 'tables.sql')
133
134
135 def create_table_triggers(conn: Connection, config: Configuration) -> None:
136     """ Create the triggers for the tables. The trigger functions must already
137         have been imported with refresh.create_functions().
138     """
139     sql = SQLPreprocessor(conn, config)
140     sql.run_sql_file(conn, 'table-triggers.sql')
141
142
143 def create_partition_tables(conn: Connection, config: Configuration) -> None:
144     """ Create tables that have explicit partitioning.
145     """
146     sql = SQLPreprocessor(conn, config)
147     sql.run_sql_file(conn, 'partition-tables.src.sql')
148
149
150 def truncate_data_tables(conn: Connection) -> None:
151     """ Truncate all data tables to prepare for a fresh load.
152     """
153     with conn.cursor() as cur:
154         cur.execute('TRUNCATE placex')
155         cur.execute('TRUNCATE place_addressline')
156         cur.execute('TRUNCATE location_area')
157         cur.execute('TRUNCATE location_area_country')
158         cur.execute('TRUNCATE location_property_tiger')
159         cur.execute('TRUNCATE location_property_osmline')
160         cur.execute('TRUNCATE location_postcode')
161         if conn.table_exists('search_name'):
162             cur.execute('TRUNCATE search_name')
163         cur.execute('DROP SEQUENCE IF EXISTS seq_place')
164         cur.execute('CREATE SEQUENCE seq_place start 100000')
165
166         cur.execute("""SELECT tablename FROM pg_tables
167                        WHERE tablename LIKE 'location_road_%'""")
168
169         for table in [r[0] for r in list(cur)]:
170             cur.execute('TRUNCATE ' + table)
171
172     conn.commit()
173
174
175 _COPY_COLUMNS = pysql.SQL(',').join(map(pysql.Identifier,
176                                         ('osm_type', 'osm_id', 'class', 'type',
177                                          'name', 'admin_level', 'address',
178                                          'extratags', 'geometry')))
179
180
181 def load_data(dsn: str, threads: int) -> None:
182     """ Copy data into the word and placex table.
183     """
184     sel = selectors.DefaultSelector()
185     # Then copy data from place to placex in <threads - 1> chunks.
186     place_threads = max(1, threads - 1)
187     for imod in range(place_threads):
188         conn = DBConnection(dsn)
189         conn.connect()
190         conn.perform(
191             pysql.SQL("""INSERT INTO placex ({columns})
192                            SELECT {columns} FROM place
193                            WHERE osm_id % {total} = {mod}
194                              AND NOT (class='place' and (type='houses' or type='postcode'))
195                              AND ST_IsValid(geometry)
196                       """).format(columns=_COPY_COLUMNS,
197                                   total=pysql.Literal(place_threads),
198                                   mod=pysql.Literal(imod)))
199         sel.register(conn, selectors.EVENT_READ, conn)
200
201     # Address interpolations go into another table.
202     conn = DBConnection(dsn)
203     conn.connect()
204     conn.perform("""INSERT INTO location_property_osmline (osm_id, address, linegeo)
205                       SELECT osm_id, address, geometry FROM place
206                       WHERE class='place' and type='houses' and osm_type='W'
207                             and ST_GeometryType(geometry) = 'ST_LineString'
208                  """)
209     sel.register(conn, selectors.EVENT_READ, conn)
210
211     # Now wait for all of them to finish.
212     todo = place_threads + 1
213     while todo > 0:
214         for key, _ in sel.select(1):
215             conn = key.data
216             sel.unregister(conn)
217             conn.wait()
218             conn.close()
219             todo -= 1
220         print('.', end='', flush=True)
221     print('\n')
222
223     with connect(dsn) as syn_conn:
224         with syn_conn.cursor() as cur:
225             cur.execute('ANALYSE')
226
227
228 def create_search_indices(conn: Connection, config: Configuration, drop: bool = False) -> None:
229     """ Create tables that have explicit partitioning.
230     """
231
232     # If index creation failed and left an index invalid, they need to be
233     # cleaned out first, so that the script recreates them.
234     with conn.cursor() as cur:
235         cur.execute("""SELECT relname FROM pg_class, pg_index
236                        WHERE pg_index.indisvalid = false
237                              AND pg_index.indexrelid = pg_class.oid""")
238         bad_indices = [row[0] for row in list(cur)]
239         for idx in bad_indices:
240             LOG.info("Drop invalid index %s.", idx)
241             cur.execute(pysql.SQL('DROP INDEX {}').format(pysql.Identifier(idx)))
242     conn.commit()
243
244     sql = SQLPreprocessor(conn, config)
245
246     sql.run_sql_file(conn, 'indices.sql', drop=drop)