]> git.openstreetmap.org Git - nominatim.git/blob - nominatim/db/sql_preprocessor.py
Merge remote-tracking branch 'upstream/master'
[nominatim.git] / nominatim / db / sql_preprocessor.py
1 # SPDX-License-Identifier: GPL-2.0-only
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2022 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Preprocessing of SQL files.
9 """
10 from typing import Set, Dict, Any
11 import jinja2
12
13 from nominatim.db.connection import Connection
14 from nominatim.db.async_connection import WorkerPool
15 from nominatim.config import Configuration
16
17 def _get_partitions(conn: Connection) -> Set[int]:
18     """ Get the set of partitions currently in use.
19     """
20     with conn.cursor() as cur:
21         cur.execute('SELECT DISTINCT partition FROM country_name')
22         partitions = set([0])
23         for row in cur:
24             partitions.add(row[0])
25
26     return partitions
27
28
29 def _get_tables(conn: Connection) -> Set[str]:
30     """ Return the set of tables currently in use.
31         Only includes non-partitioned
32     """
33     with conn.cursor() as cur:
34         cur.execute("SELECT tablename FROM pg_tables WHERE schemaname = 'public'")
35
36         return set((row[0] for row in list(cur)))
37
38
39 def _setup_tablespace_sql(config: Configuration) -> Dict[str, str]:
40     """ Returns a dict with tablespace expressions for the different tablespace
41         kinds depending on whether a tablespace is configured or not.
42     """
43     out = {}
44     for subset in ('ADDRESS', 'SEARCH', 'AUX'):
45         for kind in ('DATA', 'INDEX'):
46             tspace = getattr(config, f'TABLESPACE_{subset}_{kind}')
47             if tspace:
48                 tspace = f'TABLESPACE "{tspace}"'
49             out[f'{subset.lower()}_{kind.lower()}'] = tspace
50
51     return out
52
53
54 def _setup_postgresql_features(conn: Connection) -> Dict[str, Any]:
55     """ Set up a dictionary with various optional Postgresql/Postgis features that
56         depend on the database version.
57     """
58     pg_version = conn.server_version_tuple()
59     postgis_version = conn.postgis_version_tuple()
60     pg11plus = pg_version >= (11, 0, 0)
61     ps3 = postgis_version >= (3, 0)
62     return {
63         'has_index_non_key_column': pg11plus,
64         'spgist_geom' : 'SPGIST' if pg11plus and ps3 else 'GIST'
65     }
66
67 class SQLPreprocessor:
68     """ A environment for preprocessing SQL files from the
69         lib-sql directory.
70
71         The preprocessor provides a number of default filters and variables.
72         The variables may be overwritten when rendering an SQL file.
73
74         The preprocessing is currently based on the jinja2 templating library
75         and follows its syntax.
76     """
77
78     def __init__(self, conn: Connection, config: Configuration) -> None:
79         self.env = jinja2.Environment(autoescape=False,
80                                       loader=jinja2.FileSystemLoader(str(config.lib_dir.sql)))
81
82         db_info: Dict[str, Any] = {}
83         db_info['partitions'] = _get_partitions(conn)
84         db_info['tables'] = _get_tables(conn)
85         db_info['reverse_only'] = 'search_name' not in db_info['tables']
86         db_info['tablespace'] = _setup_tablespace_sql(config)
87
88         self.env.globals['config'] = config
89         self.env.globals['db'] = db_info
90         self.env.globals['postgres'] = _setup_postgresql_features(conn)
91
92
93     def run_string(self, conn: Connection, template: str, **kwargs: Any) -> None:
94         """ Execute the given SQL template string on the connection.
95             The keyword arguments may supply additional parameters
96             for preprocessing.
97         """
98         sql = self.env.from_string(template).render(**kwargs)
99
100         with conn.cursor() as cur:
101             cur.execute(sql)
102         conn.commit()
103
104
105     def run_sql_file(self, conn: Connection, name: str, **kwargs: Any) -> None:
106         """ Execute the given SQL file on the connection. The keyword arguments
107             may supply additional parameters for preprocessing.
108         """
109         sql = self.env.get_template(name).render(**kwargs)
110
111         with conn.cursor() as cur:
112             cur.execute(sql)
113         conn.commit()
114
115
116     def run_parallel_sql_file(self, dsn: str, name: str, num_threads: int = 1,
117                               **kwargs: Any) -> None:
118         """ Execure the given SQL files using parallel asynchronous connections.
119             The keyword arguments may supply additional parameters for
120             preprocessing.
121
122             After preprocessing the SQL code is cut at lines containing only
123             '---'. Each chunk is sent to one of the `num_threads` workers.
124         """
125         sql = self.env.get_template(name).render(**kwargs)
126
127         parts = sql.split('\n---\n')
128
129         with WorkerPool(dsn, num_threads) as pool:
130             for part in parts:
131                 pool.next_free_worker().perform(part)