]> git.openstreetmap.org Git - nominatim.git/blob - src/nominatim_db/db/sql_preprocessor.py
remove remaining pylint hints
[nominatim.git] / src / nominatim_db / db / sql_preprocessor.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2024 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Preprocessing of SQL files.
9 """
10 from typing import Set, Dict, Any, cast
11
12 import jinja2
13
14 from .connection import Connection, server_version_tuple, postgis_version_tuple
15 from ..config import Configuration
16 from ..db.query_pool import QueryPool
17
18
19 def _get_partitions(conn: Connection) -> Set[int]:
20     """ Get the set of partitions currently in use.
21     """
22     with conn.cursor() as cur:
23         cur.execute('SELECT DISTINCT partition FROM country_name')
24         partitions = set([0])
25         for row in cur:
26             partitions.add(row[0])
27
28     return partitions
29
30
31 def _get_tables(conn: Connection) -> Set[str]:
32     """ Return the set of tables currently in use.
33     """
34     with conn.cursor() as cur:
35         cur.execute("SELECT tablename FROM pg_tables WHERE schemaname = 'public'")
36
37         return set((row[0] for row in list(cur)))
38
39
40 def _get_middle_db_format(conn: Connection, tables: Set[str]) -> str:
41     """ Returns the version of the slim middle tables.
42     """
43     if 'osm2pgsql_properties' not in tables:
44         return '1'
45
46     with conn.cursor() as cur:
47         cur.execute("SELECT value FROM osm2pgsql_properties WHERE property = 'db_format'")
48         row = cur.fetchone()
49
50         return cast(str, row[0]) if row is not None else '1'
51
52
53 def _setup_tablespace_sql(config: Configuration) -> Dict[str, str]:
54     """ Returns a dict with tablespace expressions for the different tablespace
55         kinds depending on whether a tablespace is configured or not.
56     """
57     out = {}
58     for subset in ('ADDRESS', 'SEARCH', 'AUX'):
59         for kind in ('DATA', 'INDEX'):
60             tspace = getattr(config, f'TABLESPACE_{subset}_{kind}')
61             if tspace:
62                 tspace = f'TABLESPACE "{tspace}"'
63             out[f'{subset.lower()}_{kind.lower()}'] = tspace
64
65     return out
66
67
68 def _setup_postgresql_features(conn: Connection) -> Dict[str, Any]:
69     """ Set up a dictionary with various optional Postgresql/Postgis features that
70         depend on the database version.
71     """
72     pg_version = server_version_tuple(conn)
73     postgis_version = postgis_version_tuple(conn)
74     pg11plus = pg_version >= (11, 0, 0)
75     ps3 = postgis_version >= (3, 0)
76     return {
77         'has_index_non_key_column': pg11plus,
78         'spgist_geom': 'SPGIST' if pg11plus and ps3 else 'GIST'
79     }
80
81
82 class SQLPreprocessor:
83     """ A environment for preprocessing SQL files from the
84         lib-sql directory.
85
86         The preprocessor provides a number of default filters and variables.
87         The variables may be overwritten when rendering an SQL file.
88
89         The preprocessing is currently based on the jinja2 templating library
90         and follows its syntax.
91     """
92
93     def __init__(self, conn: Connection, config: Configuration) -> None:
94         self.env = jinja2.Environment(autoescape=False,
95                                       loader=jinja2.FileSystemLoader(str(config.lib_dir.sql)))
96
97         db_info: Dict[str, Any] = {}
98         db_info['partitions'] = _get_partitions(conn)
99         db_info['tables'] = _get_tables(conn)
100         db_info['reverse_only'] = 'search_name' not in db_info['tables']
101         db_info['tablespace'] = _setup_tablespace_sql(config)
102         db_info['middle_db_format'] = _get_middle_db_format(conn, db_info['tables'])
103
104         self.env.globals['config'] = config
105         self.env.globals['db'] = db_info
106         self.env.globals['postgres'] = _setup_postgresql_features(conn)
107
108     def run_string(self, conn: Connection, template: str, **kwargs: Any) -> None:
109         """ Execute the given SQL template string on the connection.
110             The keyword arguments may supply additional parameters
111             for preprocessing.
112         """
113         sql = self.env.from_string(template).render(**kwargs)
114
115         with conn.cursor() as cur:
116             cur.execute(sql)
117         conn.commit()
118
119     def run_sql_file(self, conn: Connection, name: str, **kwargs: Any) -> None:
120         """ Execute the given SQL file on the connection. The keyword arguments
121             may supply additional parameters for preprocessing.
122         """
123         sql = self.env.get_template(name).render(**kwargs)
124
125         with conn.cursor() as cur:
126             cur.execute(sql)
127         conn.commit()
128
129     async def run_parallel_sql_file(self, dsn: str, name: str, num_threads: int = 1,
130                                     **kwargs: Any) -> None:
131         """ Execute the given SQL files using parallel asynchronous connections.
132             The keyword arguments may supply additional parameters for
133             preprocessing.
134
135             After preprocessing the SQL code is cut at lines containing only
136             '---'. Each chunk is sent to one of the `num_threads` workers.
137         """
138         sql = self.env.get_template(name).render(**kwargs)
139
140         parts = sql.split('\n---\n')
141
142         async with QueryPool(dsn, num_threads) as pool:
143             for part in parts:
144                 await pool.put_query(part, None)