]> git.openstreetmap.org Git - nominatim.git/blob - src/nominatim_db/db/sql_preprocessor.py
Merge pull request #3458 from lonvia/python-package
[nominatim.git] / src / nominatim_db / db / sql_preprocessor.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2024 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Preprocessing of SQL files.
9 """
10 from typing import Set, Dict, Any, cast
11 import jinja2
12
13 from .connection import Connection
14 from .async_connection import WorkerPool
15 from ..config import Configuration
16
17 def _get_partitions(conn: Connection) -> Set[int]:
18     """ Get the set of partitions currently in use.
19     """
20     with conn.cursor() as cur:
21         cur.execute('SELECT DISTINCT partition FROM country_name')
22         partitions = set([0])
23         for row in cur:
24             partitions.add(row[0])
25
26     return partitions
27
28
29 def _get_tables(conn: Connection) -> Set[str]:
30     """ Return the set of tables currently in use.
31     """
32     with conn.cursor() as cur:
33         cur.execute("SELECT tablename FROM pg_tables WHERE schemaname = 'public'")
34
35         return set((row[0] for row in list(cur)))
36
37 def _get_middle_db_format(conn: Connection, tables: Set[str]) -> str:
38     """ Returns the version of the slim middle tables.
39     """
40     if 'osm2pgsql_properties' not in tables:
41         return '1'
42
43     with conn.cursor() as cur:
44         cur.execute("SELECT value FROM osm2pgsql_properties WHERE property = 'db_format'")
45         row = cur.fetchone()
46
47         return cast(str, row[0]) if row is not None else '1'
48
49
50 def _setup_tablespace_sql(config: Configuration) -> Dict[str, str]:
51     """ Returns a dict with tablespace expressions for the different tablespace
52         kinds depending on whether a tablespace is configured or not.
53     """
54     out = {}
55     for subset in ('ADDRESS', 'SEARCH', 'AUX'):
56         for kind in ('DATA', 'INDEX'):
57             tspace = getattr(config, f'TABLESPACE_{subset}_{kind}')
58             if tspace:
59                 tspace = f'TABLESPACE "{tspace}"'
60             out[f'{subset.lower()}_{kind.lower()}'] = tspace
61
62     return out
63
64
65 def _setup_postgresql_features(conn: Connection) -> Dict[str, Any]:
66     """ Set up a dictionary with various optional Postgresql/Postgis features that
67         depend on the database version.
68     """
69     pg_version = conn.server_version_tuple()
70     postgis_version = conn.postgis_version_tuple()
71     pg11plus = pg_version >= (11, 0, 0)
72     ps3 = postgis_version >= (3, 0)
73     return {
74         'has_index_non_key_column': pg11plus,
75         'spgist_geom' : 'SPGIST' if pg11plus and ps3 else 'GIST'
76     }
77
78 class SQLPreprocessor:
79     """ A environment for preprocessing SQL files from the
80         lib-sql directory.
81
82         The preprocessor provides a number of default filters and variables.
83         The variables may be overwritten when rendering an SQL file.
84
85         The preprocessing is currently based on the jinja2 templating library
86         and follows its syntax.
87     """
88
89     def __init__(self, conn: Connection, config: Configuration) -> None:
90         self.env = jinja2.Environment(autoescape=False,
91                                       loader=jinja2.FileSystemLoader(str(config.lib_dir.sql)))
92
93         db_info: Dict[str, Any] = {}
94         db_info['partitions'] = _get_partitions(conn)
95         db_info['tables'] = _get_tables(conn)
96         db_info['reverse_only'] = 'search_name' not in db_info['tables']
97         db_info['tablespace'] = _setup_tablespace_sql(config)
98         db_info['middle_db_format'] = _get_middle_db_format(conn, db_info['tables'])
99
100         self.env.globals['config'] = config
101         self.env.globals['db'] = db_info
102         self.env.globals['postgres'] = _setup_postgresql_features(conn)
103
104
105     def run_string(self, conn: Connection, template: str, **kwargs: Any) -> None:
106         """ Execute the given SQL template string on the connection.
107             The keyword arguments may supply additional parameters
108             for preprocessing.
109         """
110         sql = self.env.from_string(template).render(**kwargs)
111
112         with conn.cursor() as cur:
113             cur.execute(sql)
114         conn.commit()
115
116
117     def run_sql_file(self, conn: Connection, name: str, **kwargs: Any) -> None:
118         """ Execute the given SQL file on the connection. The keyword arguments
119             may supply additional parameters for preprocessing.
120         """
121         sql = self.env.get_template(name).render(**kwargs)
122
123         with conn.cursor() as cur:
124             cur.execute(sql)
125         conn.commit()
126
127
128     def run_parallel_sql_file(self, dsn: str, name: str, num_threads: int = 1,
129                               **kwargs: Any) -> None:
130         """ Execute the given SQL files using parallel asynchronous connections.
131             The keyword arguments may supply additional parameters for
132             preprocessing.
133
134             After preprocessing the SQL code is cut at lines containing only
135             '---'. Each chunk is sent to one of the `num_threads` workers.
136         """
137         sql = self.env.get_template(name).render(**kwargs)
138
139         parts = sql.split('\n---\n')
140
141         with WorkerPool(dsn, num_threads) as pool:
142             for part in parts:
143                 pool.next_free_worker().perform(part)