git.openstreetmap.org Git - nominatim.git/blob - nominatim/nominatim.py
Use assertEqualsWithDelta for float comparisons
[nominatim.git] / nominatim / nominatim.py
1 #! /usr/bin/env python3
2 #-----------------------------------------------------------------------------
3 # nominatim - [description]
4 #-----------------------------------------------------------------------------
5 #
6 # Indexing tool for the Nominatim database.
7 #
8 # Based on C version by Brian Quinion
9 #
10 # This program is free software; you can redistribute it and/or
11 # modify it under the terms of the GNU General Public License
12 # as published by the Free Software Foundation; either version 2
13 # of the License, or (at your option) any later version.
14 #
15 # This program is distributed in the hope that it will be useful,
16 # but WITHOUT ANY WARRANTY; without even the implied warranty of
17 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
18 # GNU General Public License for more details.
19 #
20 # You should have received a copy of the GNU General Public License
21 # along with this program; if not, write to the Free Software
22 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
23 #-----------------------------------------------------------------------------
24
25 from argparse import ArgumentParser, RawDescriptionHelpFormatter, ArgumentTypeError
26 import logging
27 import sys
28 import re
29 import getpass
30 from datetime import datetime
31 import psycopg2
32 from psycopg2.extras import wait_select
33 import select
34
# Module-wide logger. Calling getLogger without a name yields the root
# logger; its level is adjusted from the -v count in the __main__ block.
log = logging.getLogger(None)
36
def make_connection(options, asynchronous=False):
    """ Open a psycopg2 connection from the settings held in `options`.

        `options` must provide the attributes dbname, user, password, host
        and port. When `asynchronous` is true the connection is created in
        psycopg2's asynchronous mode.
    """
    setting_names = ('dbname', 'user', 'password', 'host', 'port')
    params = {key: getattr(options, key) for key in setting_names}
    # 'async' is a reserved word since Python 3.7, so it can only be
    # handed over through the keyword dictionary.
    params['async'] = asynchronous

    return psycopg2.connect(**params)
46
47
class RankRunner(object):
    """ SQL provider for indexing all places of one search rank in the
        placex table.

        The place-selecting statements only consider rows whose
        `rank_search` equals the rank given to the constructor and whose
        `indexed_status` is still greater than zero.
    """

    # Statement templates: '{}' is filled with the rank here in Python,
    # while '%s' stays as placeholder for the parameter that is passed
    # along when the statement is executed.
    _SECTORS_SQL = """SELECT geometry_sector, count(*) FROM placex
                  WHERE rank_search = {} and indexed_status > 0
                  GROUP BY geometry_sector
                  ORDER BY geometry_sector"""

    _NOSECTOR_SQL = """SELECT place_id FROM placex
                  WHERE indexed_status > 0 and rank_search = {}
                  ORDER BY geometry_sector"""

    _SECTOR_SQL = """SELECT place_id FROM placex
                  WHERE indexed_status > 0 and rank_search = {}
                        and geometry_sector = %s"""

    _INDEX_PLACE_SQL = "UPDATE placex SET indexed_status = 0 WHERE place_id = %s"

    def __init__(self, rank):
        self.rank = rank

    def _for_rank(self, template):
        """ Fill this runner's rank into the given template. """
        return template.format(self.rank)

    def name(self):
        """ Human-readable label used in log messages. """
        return self._for_rank("rank {}")

    def sql_index_sectors(self):
        """ Query returning (geometry_sector, count) pairs still to index. """
        return self._for_rank(self._SECTORS_SQL)

    def sql_nosector_places(self):
        """ Query returning all place_ids of the rank still to index. """
        return self._for_rank(self._NOSECTOR_SQL)

    def sql_sector_places(self):
        """ Query returning place_ids still to index in one sector (%s). """
        return self._for_rank(self._SECTOR_SQL)

    def sql_index_place(self):
        """ Statement that resets indexed_status for one place_id (%s). """
        return self._INDEX_PLACE_SQL
76
77
class InterpolationRunner(object):
    """ SQL provider for indexing the address interpolation table
        location_property_osmline.

        No rank restriction applies here: every row whose
        `indexed_status` is greater than zero is considered.
    """

    # Fixed statements; '%s' is the placeholder for the parameter that is
    # passed along when the statement is executed.
    _SECTORS_SQL = """SELECT geometry_sector, count(*) FROM location_property_osmline
                  WHERE indexed_status > 0
                  GROUP BY geometry_sector
                  ORDER BY geometry_sector"""

    _NOSECTOR_SQL = """SELECT place_id FROM location_property_osmline
                  WHERE indexed_status > 0
                  ORDER BY geometry_sector"""

    _SECTOR_SQL = """SELECT place_id FROM location_property_osmline
                  WHERE indexed_status > 0 and geometry_sector = %s
                  ORDER BY geometry_sector"""

    _INDEX_PLACE_SQL = """UPDATE location_property_osmline
                  SET indexed_status = 0 WHERE place_id = %s"""

    def name(self):
        """ Human-readable label used in log messages. """
        return "interpolation lines (location_property_osmline)"

    def sql_index_sectors(self):
        """ Query returning (geometry_sector, count) pairs still to index. """
        return self._SECTORS_SQL

    def sql_nosector_places(self):
        """ Query returning all place_ids still to index. """
        return self._NOSECTOR_SQL

    def sql_sector_places(self):
        """ Query returning place_ids still to index in one sector (%s). """
        return self._SECTOR_SQL

    def sql_index_place(self):
        """ Statement that resets indexed_status for one place_id (%s). """
        return self._INDEX_PLACE_SQL
105
106
class DBConnection(object):
    """ A single non-blocking database connection.

        Queries are sent with perform(), which returns without waiting for
        the result. Completion is polled with is_done() or awaited with
        wait(). If the server aborts the running query with a deadlock
        error (SQLSTATE 40P01), the same query is sent again with the same
        parameters.
    """

    def __init__(self, options):
        # Kept on the instance so that connect() can (re-)open the
        # connection without relying on a module-level `options` variable.
        self.options = options
        self.current_query = None
        self.current_params = None

        self.conn = None
        self.cursor = None
        self.connect()

    def connect(self):
        """ Open a new asynchronous connection and cursor, closing the
            previous ones first if there are any.
        """
        if self.conn is not None:
            self.cursor.close()
            self.conn.close()

        self.conn = make_connection(self.options, asynchronous=True)
        # Block until the asynchronous connection is established.
        self.wait()

        self.cursor = self.conn.cursor()

    def wait(self):
        """ Block until any pending operation is done.

            A query that ran into a deadlock is re-sent and waited for
            again; any other error is propagated to the caller.
        """
        while True:
            try:
                wait_select(self.conn)
                self.current_query = None
                return
            except psycopg2.extensions.TransactionRollbackError as e:
                self._retry_on_deadlock(e)

    def perform(self, sql, args=None):
        """ Send SQL query to the server. Returns immediately without
            blocking.

            The query and its parameters are remembered so that they can
            be repeated after a deadlock.
        """
        self.current_query = sql
        self.current_params = args
        self.cursor.execute(sql, args)

    def fileno(self):
        """ File descriptor to wait for. (Makes this class select()able.)
        """
        return self.conn.fileno()

    def is_done(self):
        """ Check if the connection is available for a new query.

            Also checks if the previous query has run into a deadlock.
            If so, then the previous query is repeated and False is
            returned.
        """
        if self.current_query is None:
            return True

        try:
            if self.conn.poll() == psycopg2.extensions.POLL_OK:
                self.current_query = None
                return True
        except psycopg2.extensions.TransactionRollbackError as e:
            self._retry_on_deadlock(e)

        return False

    def _retry_on_deadlock(self, error):
        """ Re-send the current query if `error` is a deadlock (SQLSTATE
            40P01), otherwise re-raise `error`.
        """
        # psycopg2.errors.DeadlockDetected is a TransactionRollbackError
        # with pgcode 40P01, so checking the code covers it as well; no
        # separate except clause for it is needed.
        if error.pgcode != '40P01':
            raise error
        log.info("Deadlock detected (params = %s), retry.", self.current_params)
        self.cursor.execute(self.current_query, self.current_params)
182
183
class Indexer(object):
    """ Main indexing routine.

        Hands the places that need indexing to a pool of DBConnection
        objects, one rank (or table) at a time.
    """

    def __init__(self, options):
        self.minrank = max(0, options.minrank)
        self.maxrank = min(30, options.maxrank)
        self.conn = make_connection(options)
        self.threads = [DBConnection(options) for _ in range(options.threads)]

    @staticmethod
    def _safe_div(numerator, denominator):
        """ Return numerator / denominator, or 0.0 if the denominator is
            not positive (e.g. when no measurable time has elapsed yet).
        """
        return numerator / denominator if denominator > 0 else 0.0

    def run(self):
        """ Run indexing over the entire database.

            Ranks minrank .. maxrank-1 are indexed first, then the
            interpolation table (only when maxrank is 30), and finally
            maxrank itself.
        """
        log.warning("Starting indexing rank ({} to {}) using {} threads".format(
                 self.minrank, self.maxrank, len(self.threads)))

        for rank in range(self.minrank, self.maxrank):
            self.index(RankRunner(rank))

        if self.maxrank == 30:
            self.index(InterpolationRunner())

        self.index(RankRunner(self.maxrank))

    def index(self, obj):
        """ Index a single rank or table. `obj` describes the SQL to use
            for indexing (a RankRunner or InterpolationRunner).

            Places are handed out sector by sector; once little work is
            left, the remainder is fetched with a single query. Returns
            after all connections have finished their queries.
        """
        log.warning("Starting {}".format(obj.name()))

        # The sector list is read twice (first to count, then to process).
        # A server-side cursor can only reliably be scrolled back when it
        # is declared scrollable.
        cur = self.conn.cursor(name='main', scrollable=True)
        cur.execute(obj.sql_index_sectors())

        total_tuples = sum(r[1] for r in cur)
        log.debug("Total number of rows: {}".format(total_tuples))

        cur.scroll(0, mode='absolute')

        next_thread = self.find_free_thread()
        done_tuples = 0
        rank_start_time = datetime.now()

        sector_sql = obj.sql_sector_places()
        index_sql = obj.sql_index_place()
        # Once fewer than about 1000 places per connection remain, the
        # rest is fetched with one query instead of per sector.
        min_grouped_tuples = total_tuples - len(self.threads) * 1000

        # Without INFO logging the threshold is put past the expected
        # total, so progress is normally never computed.
        next_info = 100 if log.isEnabledFor(logging.INFO) else total_tuples + 1

        for r in cur:
            sector = r[0]

            # Should we do the remaining ones together?
            do_all = done_tuples > min_grouped_tuples

            pcur = self.conn.cursor(name='places')

            if do_all:
                pcur.execute(obj.sql_nosector_places())
            else:
                pcur.execute(sector_sql, (sector, ))

            for place in pcur:
                place_id = place[0]
                log.debug("Processing place {}".format(place_id))
                thread = next(next_thread)

                thread.perform(index_sql, (place_id,))
                done_tuples += 1

                if done_tuples >= next_info:
                    done_time = (datetime.now() - rank_start_time).total_seconds()
                    tuples_per_sec = self._safe_div(done_tuples, done_time)
                    log.info("Done {} in {} @ {:.3f} per second - {} ETA (seconds): {:.2f}"
                             .format(done_tuples, int(done_time),
                                     tuples_per_sec, obj.name(),
                                     self._safe_div(total_tuples - done_tuples,
                                                    tuples_per_sec)))
                    # Always advance: below one place per second
                    # int(tuples_per_sec) is 0 and every place would log.
                    next_info += max(1, int(tuples_per_sec))

            pcur.close()

            if do_all:
                break

        cur.close()

        for t in self.threads:
            t.wait()

        diff_seconds = (datetime.now() - rank_start_time).total_seconds()

        log.warning("Done {}/{} in {} @ {:.3f} per second - FINISHED {}\n".format(
                 done_tuples, total_tuples, int(diff_seconds),
                 self._safe_div(done_tuples, diff_seconds), obj.name()))

    def find_free_thread(self):
        """ Generator that returns the next connection that is free for
            sending a query.

            Once more than 100000 connections have been handed out, all
            connections are drained and re-opened.
        """
        ready = self.threads
        command_stat = 0

        while True:
            for thread in ready:
                if thread.is_done():
                    command_stat += 1
                    yield thread

            # refresh the connections occasionally to avoid potential
            # memory leaks in Postgresql.
            if command_stat > 100000:
                for t in self.threads:
                    # wait() also re-sends a query that hit a deadlock,
                    # where a bare wait_select() would raise instead.
                    t.wait()
                    t.connect()
                command_stat = 0
                ready = self.threads
            else:
                ready, _, _ = select.select(self.threads, [], [])

        assert False, "Unreachable code"
308
309
def nominatim_arg_parser():
    """ Setup the command-line parser for the tool.

        The parsed namespace provides dbname, user, password_prompt, host,
        port, minrank, maxrank, threads and loglevel.
    """
    def positive_int(value):
        """ argparse type: an integer that is at least 1. """
        number = int(value)
        if number < 1:
            # Zero connections would leave the indexer blocked forever in
            # select() waiting on an empty set of connections.
            raise ArgumentTypeError(
                "must be a positive integer, got {}".format(value))
        return number

    p = ArgumentParser(description="Indexing tool for Nominatim.",
                       formatter_class=RawDescriptionHelpFormatter)

    p.add_argument('-d', '--database',
                   dest='dbname', action='store', default='nominatim',
                   help='Name of the PostgreSQL database to connect to.')
    p.add_argument('-U', '--username',
                   dest='user', action='store',
                   help='PostgreSQL user name.')
    p.add_argument('-W', '--password',
                   dest='password_prompt', action='store_true',
                   help='Force password prompt.')
    p.add_argument('-H', '--host',
                   dest='host', action='store',
                   help='PostgreSQL server hostname or socket location.')
    p.add_argument('-P', '--port',
                   dest='port', action='store',
                   help='PostgreSQL server port')
    p.add_argument('-r', '--minrank',
                   dest='minrank', type=int, metavar='RANK', default=0,
                   help='Minimum/starting rank.')
    p.add_argument('-R', '--maxrank',
                   dest='maxrank', type=int, metavar='RANK', default=30,
                   help='Maximum/finishing rank.')
    p.add_argument('-t', '--threads',
                   dest='threads', type=positive_int, metavar='NUM', default=1,
                   help='Number of threads to create for indexing.')
    p.add_argument('-v', '--verbose',
                   dest='loglevel', action='count', default=0,
                   help='Increase verbosity')

    return p
348
if __name__ == '__main__':
    logging.basicConfig(stream=sys.stderr, format='%(levelname)s: %(message)s')

    options = nominatim_arg_parser().parse_args(sys.argv[1:])

    # Without -v the level is 30 (WARNING); every -v lowers it by 10
    # (INFO, then DEBUG), stopping at 0.
    log.setLevel((3 - min(options.loglevel, 3)) * 10)

    # Only prompt when -W was given; otherwise the password stays None.
    if options.password_prompt:
        options.password = getpass.getpass("Database password: ")
    else:
        options.password = None

    Indexer(options).run()