]> git.openstreetmap.org Git - nominatim.git/blob - test/python/conftest.py
Merge remote-tracking branch 'upstream/master'
[nominatim.git] / test / python / conftest.py
1 import itertools
2 import sys
3 from pathlib import Path
4
5 import psycopg2
6 import psycopg2.extras
7 import pytest
8 import tempfile
9
SRC_DIR = Path(__file__).joinpath('..', '..', '..')

# Make the repository root the first entry of the module search path so the
# tests import the code from this source tree rather than an installed copy.
sys.path.insert(0, str(SRC_DIR.resolve()))
14
15 from nominatim.config import Configuration
16 from nominatim.db import connection
17 from nominatim.db.sql_preprocessor import SQLPreprocessor
18
class _TestingCursor(psycopg2.extras.DictCursor):
    """ Extension to the DictCursor class that provides execution
        short-cuts that simplify writing assertions.
    """

    def scalar(self, sql, params=None):
        """ Execute a query with a single return value and return this value.
            Raises an assertion when not exactly one row is returned.
        """
        self.execute(sql, params)
        assert self.rowcount == 1
        return self.fetchone()[0]

    def row_set(self, sql, params=None):
        """ Execute a query and return the result as a set of tuples,
            one tuple per returned row.

            A single-row result is returned the same way, i.e. as a set
            containing exactly one tuple (not as the set of its column values).
        """
        self.execute(sql, params)

        return {tuple(row) for row in self}

    def table_exists(self, table):
        """ Check that a table with the given name exists in the database.
        """
        num = self.scalar("""SELECT count(*) FROM pg_tables
                             WHERE tablename = %s""", (table, ))
        return num == 1

    def table_rows(self, table):
        """ Return the number of rows in the given table.

            The table name is concatenated into the SQL as-is (no quoting).
        """
        return self.scalar('SELECT count(*) FROM ' + table)
52
53
@pytest.fixture
def temp_db(monkeypatch):
    """ Create an empty database for the test. The database name is also
        exported into NOMINATIM_DATABASE_DSN.

        A database of the same name left over from an earlier run is dropped
        first; the database is dropped again after the test.
    """
    name = 'test_nominatim_python_unittest'

    def _run_on_postgres_db(*commands):
        """ Execute the given SQL commands in auto-commit mode on the
            'postgres' database, always closing the connection afterwards.
        """
        conn = psycopg2.connect(database='postgres')
        try:
            # CREATE/DROP DATABASE cannot be run inside a transaction block.
            conn.set_isolation_level(psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT)
            with conn.cursor() as cur:
                for command in commands:
                    cur.execute(command)
        finally:
            conn.close()

    _run_on_postgres_db('DROP DATABASE IF EXISTS {}'.format(name),
                        'CREATE DATABASE {}'.format(name))

    monkeypatch.setenv('NOMINATIM_DATABASE_DSN' , 'dbname=' + name)

    yield name

    _run_on_postgres_db('DROP DATABASE IF EXISTS {}'.format(name))
80
81
@pytest.fixture
def dsn(temp_db):
    """ Return a 'dbname=...' connection string that points at the
        test database created by temp_db.
    """
    conninfo = 'dbname=' + temp_db
    return conninfo
85
86
@pytest.fixture
def temp_db_with_extensions(temp_db):
    """ Install the hstore and postgis extensions into the test database.

        Returns the name of the test database.
    """
    conn = psycopg2.connect(database=temp_db)
    try:
        with conn.cursor() as cur:
            cur.execute('CREATE EXTENSION hstore; CREATE EXTENSION postgis;')
        conn.commit()
    finally:
        # Always close the session, even when an extension is unavailable:
        # a dangling connection to the test database can make the
        # DROP DATABASE in temp_db's teardown fail.
        conn.close()

    return temp_db
96
@pytest.fixture
def temp_db_conn(temp_db):
    """ Yield a connection to the test database opened through
        nominatim.db.connection.connect().

        The connection is used inside the context manager returned by
        connection.connect(), which handles its cleanup after the test.
    """
    test_dsn = 'dbname=' + temp_db
    with connection.connect(test_dsn) as conn:
        yield conn
103
104
@pytest.fixture
def temp_db_cursor(temp_db):
    """ Connection and cursor towards the test database. The connection will
        be in auto-commit mode.

        The yielded cursor is a _TestingCursor, so the scalar(), row_set(),
        table_exists() and table_rows() helpers are available.
    """
    conn = psycopg2.connect('dbname=' + temp_db)
    try:
        conn.set_isolation_level(psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT)
        with conn.cursor(cursor_factory=_TestingCursor) as cur:
            yield cur
    finally:
        # Close the connection even if the fixture generator is finalized
        # early, so no session stays attached to the test database.
        conn.close()
115
116
@pytest.fixture
def table_factory(temp_db_cursor):
    """ Factory fixture that creates tables in the test database and can
        optionally fill them with content.
    """
    def mk_table(name, definition='id INT', content=None):
        """ Create table `name` with the column `definition` and insert
            `content` if given.

            `content` may be:
              * a string: used verbatim as the inside of `VALUES (...)`;
              * an iterable of tuples: one row per tuple, the values being
                passed to the database as query parameters;
              * an iterable of scalars: one single-value row per item,
                inserted via the item's str() representation.
            An empty iterable creates the table without inserting rows.
        """
        temp_db_cursor.execute('CREATE TABLE {} ({})'.format(name, definition))
        if content is None:
            return

        if isinstance(content, str):
            temp_db_cursor.execute("INSERT INTO {} VALUES ({})".format(name, content))
            return

        rows = list(content)
        if not rows:
            # 'VALUES ()' is not valid SQL; there is simply nothing to insert.
            return

        if all(isinstance(row, tuple) for row in rows):
            # Multi-column rows: let psycopg2 quote each value properly.
            psycopg2.extras.execute_values(
                temp_db_cursor, "INSERT INTO {} VALUES %s".format(name), rows)
        else:
            values = '),('.join(str(row) for row in rows)
            temp_db_cursor.execute("INSERT INTO {} VALUES ({})".format(name, values))

    return mk_table
127
128
@pytest.fixture
def def_config():
    """ Return a Configuration constructed with None as its first argument
        and the 'settings' directory of the source tree as its second.
    """
    settings_dir = SRC_DIR.resolve() / 'settings'
    return Configuration(None, settings_dir)
132
@pytest.fixture
def src_dir():
    """ Return the absolute, symlink-resolved path of the source tree root.
    """
    root_dir = SRC_DIR.resolve()
    return root_dir
136
@pytest.fixture
def tmp_phplib_dir():
    """ Yield a temporary directory (as a Path) that contains an empty
        'admin' subdirectory. The directory is removed after the test.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        phplib_dir = Path(tmpdir)
        phplib_dir.joinpath('admin').mkdir()

        yield phplib_dir
143
@pytest.fixture
def status_table(temp_db_conn):
    """ Create empty versions of the import_status table and of the
        import_osmosis_log status logging table, then commit.
    """
    ddl_statements = (
        """CREATE TABLE import_status (
                           lastimportdate timestamp with time zone NOT NULL,
                           sequence_id integer,
                           indexed boolean
                       )""",
        """CREATE TABLE import_osmosis_log (
                           batchend timestamp,
                           batchseq integer,
                           batchsize bigint,
                           starttime timestamp,
                           endtime timestamp,
                           event text
                           )""",
    )
    with temp_db_conn.cursor() as cur:
        for ddl in ddl_statements:
            cur.execute(ddl)
    temp_db_conn.commit()
164
165
@pytest.fixture
def place_table(temp_db_with_extensions, temp_db_conn):
    """ Create an empty place table in the test database and commit.

        The hstore and Geometry column types come from the extensions
        installed by temp_db_with_extensions.
    """
    create_sql = """CREATE TABLE place (
                           osm_id int8 NOT NULL,
                           osm_type char(1) NOT NULL,
                           class text NOT NULL,
                           type text NOT NULL,
                           name hstore,
                           admin_level smallint,
                           address hstore,
                           extratags hstore,
                           geometry Geometry(Geometry,4326) NOT NULL)"""
    with temp_db_conn.cursor() as cur:
        cur.execute(create_sql)
    temp_db_conn.commit()
182
183
@pytest.fixture
def place_row(place_table, temp_db_cursor):
    """ A factory for rows in the place table. The table is created as a
        prerequisite to the fixture.

        Rows without an explicit osm_id get consecutive ids starting at 1001;
        rows without a geometry get 'SRID=4326;POINT(0 0)'.
    """
    idseq = itertools.count(1001)

    def _insert(osm_type='N', osm_id=None, cls='amenity', typ='cafe', names=None,
                admin_level=None, address=None, extratags=None, geom=None):
        if osm_id is None:
            # Only fall back to the sequence when no id was given, so that an
            # explicit (falsy) osm_id such as 0 is not silently replaced.
            osm_id = next(idseq)
        temp_db_cursor.execute("INSERT INTO place VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)",
                               (osm_id, osm_type, cls, typ, names,
                                admin_level, address, extratags,
                                geom or 'SRID=4326;POINT(0 0)'))

    return _insert
198
@pytest.fixture
def placex_table(temp_db_with_extensions, temp_db_conn):
    """ Create an empty version of the placex table and commit.

        The hstore and Geometry column types come from the extensions
        installed by temp_db_with_extensions.
    """
    with temp_db_conn.cursor() as cur:
        cur.execute("""CREATE TABLE placex (
                           place_id BIGINT,
                           parent_place_id BIGINT,
                           linked_place_id BIGINT,
                           importance FLOAT,
                           indexed_date TIMESTAMP,
                           geometry_sector INTEGER,
                           rank_address SMALLINT,
                           rank_search SMALLINT,
                           partition SMALLINT,
                           indexed_status SMALLINT,
                           osm_id int8,
                           osm_type char(1),
                           class text,
                           type text,
                           name hstore,
                           admin_level smallint,
                           address hstore,
                           extratags hstore,
                           geometry Geometry(Geometry,4326),
                           wikipedia TEXT,
                           country_code varchar(2),
                           housenumber TEXT,
                           postcode TEXT,
                           centroid GEOMETRY(Geometry, 4326))""")
    temp_db_conn.commit()
230
231
@pytest.fixture
def osmline_table(temp_db_with_extensions, temp_db_conn):
    """ Create an empty location_property_osmline table and commit.

        The HSTORE and GEOMETRY column types come from the extensions
        installed by temp_db_with_extensions.
    """
    create_sql = """CREATE TABLE location_property_osmline (
                           place_id BIGINT,
                           osm_id BIGINT,
                           parent_place_id BIGINT,
                           geometry_sector INTEGER,
                           indexed_date TIMESTAMP,
                           startnumber INTEGER,
                           endnumber INTEGER,
                           partition SMALLINT,
                           indexed_status SMALLINT,
                           linegeo GEOMETRY,
                           interpolationtype TEXT,
                           address HSTORE,
                           postcode TEXT,
                           country_code VARCHAR(2))"""
    with temp_db_conn.cursor() as cur:
        cur.execute(create_sql)
    temp_db_conn.commit()
251
252
@pytest.fixture
def word_table(temp_db, temp_db_conn):
    """ Create an empty word table in the test database and commit.
    """
    create_sql = """CREATE TABLE word (
                           word_id INTEGER,
                           word_token text,
                           word text,
                           class text,
                           type text,
                           country_code varchar(2),
                           search_name_count INTEGER,
                           operator TEXT)"""
    with temp_db_conn.cursor() as cur:
        cur.execute(create_sql)
    temp_db_conn.commit()
266
267
@pytest.fixture
def osm2pgsql_options(temp_db):
    """ Return a dictionary of osm2pgsql-related options that point at the
        test database. The 'osm2pgsql' entry is 'echo' rather than a real
        osm2pgsql executable, and all tablespaces are empty strings.
    """
    empty_tablespaces = {'slim_data': '', 'slim_index': '',
                         'main_data': '', 'main_index': ''}
    return {'osm2pgsql': 'echo',
            'osm2pgsql_cache': 10,
            'osm2pgsql_style': 'style.file',
            'threads': 1,
            'dsn': 'dbname=' + temp_db,
            'flatnode_file': '',
            'tablespaces': empty_tablespaces}
278
@pytest.fixture
def sql_preprocessor(temp_db_conn, tmp_path, def_config, monkeypatch, table_factory):
    """ Return an SQLPreprocessor built from the test database connection,
        the def_config configuration and tmp_path.

        Before construction, NOMINATIM_DATABASE_MODULE_PATH is set to '.' and
        a country_name table with the partitions 0, 1 and 2 is created.
    """
    monkeypatch.setenv('NOMINATIM_DATABASE_MODULE_PATH', '.')
    partitions = (0, 1, 2)
    table_factory('country_name', 'partition INT', partitions)
    preprocessor = SQLPreprocessor(temp_db_conn, def_config, tmp_path)
    return preprocessor