git.openstreetmap.org Git - nominatim.git/blob - nominatim/nominatim.py
nominatim.py: also catch deadlocks on final wait
[nominatim.git] / nominatim / nominatim.py
1 #! /usr/bin/env python3
2 #-----------------------------------------------------------------------------
3 # nominatim - [description]
4 #-----------------------------------------------------------------------------
5 #
6 # Indexing tool for the Nominatim database.
7 #
8 # Based on C version by Brian Quinion
9 #
10 # This program is free software; you can redistribute it and/or
11 # modify it under the terms of the GNU General Public License
12 # as published by the Free Software Foundation; either version 2
13 # of the License, or (at your option) any later version.
14 #
15 # This program is distributed in the hope that it will be useful,
16 # but WITHOUT ANY WARRANTY; without even the implied warranty of
17 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
18 # GNU General Public License for more details.
19 #
20 # You should have received a copy of the GNU General Public License
21 # along with this program; if not, write to the Free Software
22 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
23 #-----------------------------------------------------------------------------
24
25 from argparse import ArgumentParser, RawDescriptionHelpFormatter, ArgumentTypeError
26 import logging
27 import sys
28 import re
29 import getpass
30 from datetime import datetime
31 import psycopg2
32 from psycopg2.extras import wait_select
33 import select
34
35 log = logging.getLogger()
36
def make_connection(options, asynchronous=False):
    """ Open a new psycopg2 connection to the database described by
        `options`.

        `options` must provide the attributes dbname, user, password,
        host and port. The `asynchronous` flag is handed on to psycopg2
        as its `async_` connection parameter.
    """
    connect_args = {
        'dbname': options.dbname,
        'user': options.user,
        'password': options.password,
        'host': options.host,
        'port': options.port,
        'async_': asynchronous,
    }
    return psycopg2.connect(**connect_args)
41
42
class RankRunner:
    """ Supplies the SQL statements needed to index a single search rank
        of the placex table.
    """

    def __init__(self, rank):
        # Value compared against placex.rank_search in every query below.
        self.rank = rank

    def name(self):
        """ Human-readable label of this runner, used in log output.
        """
        return "rank {}".format(self.rank)

    def sql_index_sectors(self):
        """ Query listing each geometry sector of this rank together with
            the number of rows that still have indexed_status > 0.
        """
        query = """SELECT geometry_sector, count(*) FROM placex
                  WHERE rank_search = {} and indexed_status > 0
                  GROUP BY geometry_sector
                  ORDER BY geometry_sector"""
        return query.format(self.rank)

    def sql_nosector_places(self):
        """ Query for all place ids of this rank with indexed_status > 0,
            regardless of sector.
        """
        query = """SELECT place_id FROM placex
                  WHERE indexed_status > 0 and rank_search = {}
                  ORDER BY geometry_sector"""
        return query.format(self.rank)

    def sql_sector_places(self):
        """ Query for the place ids of this rank with indexed_status > 0
            inside one sector. The sector is left as a `%s` placeholder
            to be bound at execution time.
        """
        query = """SELECT place_id FROM placex
                  WHERE indexed_status > 0 and rank_search = {}
                        and geometry_sector = %s"""
        return query.format(self.rank)

    def sql_index_place(self):
        """ Statement that sets indexed_status to 0 for one place; the
            place id is a `%s` placeholder.
        """
        return "UPDATE placex SET indexed_status = 0 WHERE place_id = %s"
71
72
class InterpolationRunner:
    """ Supplies the SQL statements needed to index the address
        interpolation table location_property_osmline.

        All statements are constant, so they are kept as class attributes
        and the accessor methods simply hand them out.
    """

    _LABEL = "interpolation lines (location_property_osmline)"

    _SQL_SECTORS = """SELECT geometry_sector, count(*) FROM location_property_osmline
                  WHERE indexed_status > 0
                  GROUP BY geometry_sector
                  ORDER BY geometry_sector"""

    _SQL_ALL_PLACES = """SELECT place_id FROM location_property_osmline
                  WHERE indexed_status > 0
                  ORDER BY geometry_sector"""

    _SQL_SECTOR_PLACES = """SELECT place_id FROM location_property_osmline
                  WHERE indexed_status > 0 and geometry_sector = %s
                  ORDER BY geometry_sector"""

    _SQL_INDEX_PLACE = """UPDATE location_property_osmline
                  SET indexed_status = 0 WHERE place_id = %s"""

    def name(self):
        """ Human-readable label of this runner, used in log output.
        """
        return self._LABEL

    def sql_index_sectors(self):
        """ Query listing each geometry sector together with the number
            of rows that still have indexed_status > 0.
        """
        return self._SQL_SECTORS

    def sql_nosector_places(self):
        """ Query for all place ids with indexed_status > 0, regardless
            of sector.
        """
        return self._SQL_ALL_PLACES

    def sql_sector_places(self):
        """ Query for the place ids with indexed_status > 0 inside one
            sector; the sector is a `%s` placeholder.
        """
        return self._SQL_SECTOR_PLACES

    def sql_index_place(self):
        """ Statement that sets indexed_status to 0 for one row; the
            place id is a `%s` placeholder.
        """
        return self._SQL_INDEX_PLACE
100
101
class DBConnection(object):
    """ A single non-blocking database connection.

        Holds one asynchronous psycopg2 connection and one cursor on it.
        The query sent last is remembered together with its parameters,
        so that it can be sent again when the server reports a deadlock.
    """

    # SQLSTATE that PostgreSQL reports for 'deadlock_detected'.
    DEADLOCK_PGCODE = '40P01'

    def __init__(self, options):
        """ Open the connection. `options` is passed on to
            make_connection() and kept for later reconnects.
        """
        # connect() is called again later to refresh the connection, so
        # the parameters must live on the instance instead of being read
        # from a module-level global.
        self.options = options
        self.current_query = None
        self.current_params = None

        self.conn = None
        self.cursor = None
        self.connect()

    def connect(self):
        """ (Re)open the connection. An already existing cursor and
            connection are closed first. Blocks until the new connection
            is ready, then creates a fresh cursor.
        """
        if self.conn is not None:
            self.cursor.close()
            self.conn.close()

        self.conn = make_connection(self.options, asynchronous=True)
        self.wait()

        self.cursor = self.conn.cursor()

    def wait(self):
        """ Block until any pending operation is done.

            A query that was aborted because of a deadlock is sent again
            and waited for. Any other transaction rollback error is
            raised to the caller.
        """
        while True:
            try:
                wait_select(self.conn)
                self.current_query = None
                return
            except psycopg2.extensions.TransactionRollbackError as e:
                # psycopg2.errors.DeadlockDetected is a subclass of this
                # exception, so checking the pgcode covers it as well.
                if e.pgcode != self.DEADLOCK_PGCODE:
                    raise
                self._resend_after_deadlock()

    def perform(self, sql, args=None):
        """ Send SQL query to the server. Returns immediately without
            blocking.

            The query and its arguments are remembered so that they can
            be repeated after a deadlock.
        """
        self.current_query = sql
        self.current_params = args
        self.cursor.execute(sql, args)

    def fileno(self):
        """ File descriptor to wait for. (Makes this class select()able.)
        """
        return self.conn.fileno()

    def is_done(self):
        """ Check if the connection is available for a new query.

            Also checks if the previous query has run into a deadlock.
            If so, then the previous query is repeated.

            Returns True when no query is pending or the pending one has
            finished, False otherwise.
        """
        if self.current_query is None:
            return True

        try:
            if self.conn.poll() == psycopg2.extensions.POLL_OK:
                self.current_query = None
                return True
        except psycopg2.extensions.TransactionRollbackError as e:
            if e.pgcode != self.DEADLOCK_PGCODE:
                raise
            self._resend_after_deadlock()

        return False

    def _resend_after_deadlock(self):
        """ Log the deadlock and send the remembered query once more.
        """
        log.info("Deadlock detected (params = {}), retry."
                 .format(self.current_params))
        self.cursor.execute(self.current_query, self.current_params)
177
178
class Indexer(object):
    """ Main indexing routine.

        Reads the places to index over one regular connection and hands
        the update statements out to a pool of non-blocking DBConnection
        objects (called 'threads' here).
    """

    # Number of commands sent after which all worker connections are
    # closed and reopened.
    REFRESH_AFTER_COMMANDS = 100000

    def __init__(self, options):
        """ Set up the connections. `options` must provide minrank,
            maxrank, threads and the connection parameters read by
            make_connection(). The ranks are clamped to the range 0..30.
        """
        self.minrank = max(0, options.minrank)
        self.maxrank = min(30, options.maxrank)
        self.conn = make_connection(options)
        self.threads = [DBConnection(options) for _ in range(options.threads)]

    def run(self):
        """ Run indexing over the entire database.

            Ranks from minrank up to (but excluding) maxrank come first.
            When maxrank is 30, the interpolation table is indexed next.
            The maxrank itself is always indexed last.
        """
        log.warning("Starting indexing rank ({} to {}) using {} threads".format(
                 self.minrank, self.maxrank, len(self.threads)))

        for rank in range(self.minrank, self.maxrank):
            self.index(RankRunner(rank))

        if self.maxrank == 30:
            self.index(InterpolationRunner())

        self.index(RankRunner(self.maxrank))

    def index(self, obj):
        """ Index a single rank or table. `obj` describes the SQL to use
            for indexing.

            `obj` must provide name(), sql_index_sectors(),
            sql_sector_places(), sql_nosector_places() and
            sql_index_place(). Returns once all worker connections have
            finished their last query.
        """
        log.warning("Starting {}".format(obj.name()))

        cur = self.conn.cursor(name='main')
        cur.execute(obj.sql_index_sectors())

        # First pass over the sector list only sums up the row counts.
        total_tuples = sum(r[1] for r in cur)
        log.debug("Total number of rows; {}".format(total_tuples))

        # Rewind so that the sectors can be iterated a second time.
        cur.scroll(0, mode='absolute')

        next_thread = self.find_free_thread()
        done_tuples = 0
        rank_start_time = datetime.now()

        sector_sql = obj.sql_sector_places()
        index_sql = obj.sql_index_place()
        min_grouped_tuples = total_tuples - len(self.threads) * 1000

        # Progress is only computed when INFO output is enabled; otherwise
        # the threshold is placed beyond the expected number of rows.
        next_info = 100 if log.isEnabledFor(logging.INFO) else total_tuples + 1

        for r in cur:
            sector = r[0]

            # Should we do the remaining ones together?
            do_all = done_tuples > min_grouped_tuples

            pcur = self.conn.cursor(name='places')

            if do_all:
                pcur.execute(obj.sql_nosector_places())
            else:
                pcur.execute(sector_sql, (sector, ))

            for place in pcur:
                place_id = place[0]
                log.debug("Processing place {}".format(place_id))
                thread = next(next_thread)

                thread.perform(index_sql, (place_id,))
                done_tuples += 1

                if done_tuples >= next_info:
                    done_time = self._elapsed_seconds(rank_start_time)
                    tuples_per_sec = done_tuples / done_time
                    log.info("Done {} in {} @ {:.3f} per second - {} ETA (seconds): {:.2f}"
                           .format(done_tuples, int(done_time),
                                   tuples_per_sec, obj.name(),
                                   (total_tuples - done_tuples)/tuples_per_sec))
                    next_info += int(tuples_per_sec)

            pcur.close()

            if do_all:
                break

        cur.close()

        # Let every connection finish its last query before reporting.
        for t in self.threads:
            t.wait()

        diff_seconds = self._elapsed_seconds(rank_start_time)

        log.warning("Done {}/{} in {} @ {:.3f} per second - FINISHED {}\n".format(
                 done_tuples, total_tuples, int(diff_seconds),
                 done_tuples/diff_seconds, obj.name()))

    @staticmethod
    def _elapsed_seconds(start_time):
        """ Seconds passed since `start_time`, never less than one
            microsecond so that the result is safe to divide by.
        """
        return max((datetime.now() - start_time).total_seconds(), 1e-6)

    def find_free_thread(self):
        """ Generator that returns the next connection that is free for
            sending a query.

            The generator never terminates; callers pull connections with
            next() as needed.
        """
        ready = self.threads
        command_stat = 0

        while True:
            for thread in ready:
                if thread.is_done():
                    command_stat += 1
                    yield thread

            # refresh the connections occasionally to avoid potential
            # memory leaks in Postgresql.
            if command_stat > self.REFRESH_AFTER_COMMANDS:
                for t in self.threads:
                    # wait() retries queries that ran into a deadlock,
                    # which a bare wait_select() on the connection would
                    # let escape as an exception.
                    t.wait()
                    t.connect()
                command_stat = 0
                ready = self.threads
            else:
                ready, _, _ = select.select(self.threads, [], [])
303
304
def nominatim_arg_parser():
    """ Setup the command-line parser for the tool.

        Returns an ArgumentParser. Parsing yields the attributes dbname,
        user, password_prompt, host, port, minrank, maxrank, threads and
        loglevel.
    """
    p = ArgumentParser(description="Indexing tool for Nominatim.",
                       formatter_class=RawDescriptionHelpFormatter)

    p.add_argument('-d', '--database',
                   dest='dbname', action='store', default='nominatim',
                   help='Name of the PostgreSQL database to connect to.')
    p.add_argument('-U', '--username',
                   dest='user', action='store',
                   help='PostgreSQL user name.')
    p.add_argument('-W', '--password',
                   dest='password_prompt', action='store_true',
                   help='Force password prompt.')
    p.add_argument('-H', '--host',
                   dest='host', action='store',
                   help='PostgreSQL server hostname or socket location.')
    p.add_argument('-P', '--port',
                   dest='port', action='store',
                   help='PostgreSQL server port')
    p.add_argument('-r', '--minrank',
                   dest='minrank', type=int, metavar='RANK', default=0,
                   help='Minimum/starting rank.')
    p.add_argument('-R', '--maxrank',
                   dest='maxrank', type=int, metavar='RANK', default=30,
                   help='Maximum/finishing rank.')
    p.add_argument('-t', '--threads',
                   dest='threads', type=int, metavar='NUM', default=1,
                   help='Number of threads to create for indexing.')
    # Each repetition of -v increases the count by one.
    p.add_argument('-v', '--verbose',
                   dest='loglevel', action='count', default=0,
                   help='Increase verbosity')

    return p
343
if __name__ == '__main__':
    # Script entry point: configure logging, read the command line and
    # run the indexer.
    logging.basicConfig(stream=sys.stderr, format='%(levelname)s: %(message)s')

    options = nominatim_arg_parser().parse_args(sys.argv[1:])

    # Without -v the level is 30 (WARNING); every -v lowers it by 10,
    # but never below 0.
    level_steps = max(3 - options.loglevel, 0)
    log.setLevel(level_steps * 10)

    # The password is only ever taken from an interactive prompt.
    if options.password_prompt:
        options.password = getpass.getpass("Database password: ")
    else:
        options.password = None

    Indexer(options).run()