]> git.openstreetmap.org Git - osqa.git/blob - forum_modules/exporter/importer.py
1519a44f9e6b2d6d134702cf82a90294535eb54e
[osqa.git] / forum_modules / exporter / importer.py
1 import os, tarfile, datetime, ConfigParser, logging
2
3 from django.utils.translation import ugettext as _
4 from django.core.cache import cache
5
6 from south.db import db
7
8 from xml.sax import make_parser
9 from xml.sax.handler import ContentHandler, ErrorHandler
10
11 from forum.templatetags.extra_tags import diff_date
12
13 from exporter import TMP_FOLDER, DATETIME_FORMAT, DATE_FORMAT, META_INF_SECTION, CACHE_KEY
14 from orm import orm
15 import commands, settings
16
# NOTE(review): the old ``try: from __future__ import with_statement`` block
# was removed.  A __future__ import is only legal as the first statement of a
# module, so wrapping it in try/except can never work: the SyntaxError is
# raised at compile time, before the except clause exists.  The module already
# uses bare ``with`` statements, so it requires Python >= 2.6 where ``with``
# is a core keyword and no future import is needed.

# Sentinel used by ContentElement.as_date()/as_datetime() to distinguish
# "no default supplied" from an explicit default of None.
NO_DEFAULT = object()

import string
26
class SafeReader():
    """File-like wrapper that strips non-printable characters from a stream.

    The XML dumps occasionally contain control characters that abort the SAX
    parser; this reader filters everything outside ``string.printable``
    before the parser sees it.
    """

    def __init__(self, loc):
        # Keep a handle on the underlying file; close() releases it.
        self.base = open(loc)

    def read(self, *args):
        """Read (optionally up to a size) and drop non-printable characters."""
        return "".join(c for c in self.base.read(*args) if c in string.printable)

    def readLine(self, *args):
        """Read one line and drop non-printable characters.

        Bug fix: file objects expose ``readline`` (lower-case "l"); the old
        ``self.base.readLine(...)`` raised AttributeError on every call.
        """
        return "".join(c for c in self.base.readline(*args) if c in string.printable)

    def close(self):
        self.base.close()
39
40
class ContentElement():
    """Wraps one raw text value from the XML dump and offers typed accessors."""

    def __init__(self, content):
        self._content = content

    def content(self):
        """Return the wrapped text with surrounding whitespace removed."""
        return self._content.strip()

    def as_bool(self):
        # The exporter serializes booleans as the literals "true"/"false".
        return self.content() == "true"

    def as_date(self, default=NO_DEFAULT):
        """Parse the content with DATE_FORMAT.

        When parsing fails, returns ``default``; when no default was supplied
        (the NO_DEFAULT sentinel), falls back to the epoch date.
        """
        try:
            # NOTE(review): strptime returns a datetime while the fallback is
            # a date -- callers appear to tolerate either; confirm before
            # tightening.
            return datetime.datetime.strptime(self.content(), DATE_FORMAT)
        except (ValueError, TypeError):
            # Narrowed from a bare except: only parse failures fall back.
            if default is NO_DEFAULT:
                return datetime.date.fromtimestamp(0)
            return default

    def as_datetime(self, default=NO_DEFAULT):
        """Parse the content with DATETIME_FORMAT; same fallback rules as
        as_date()."""
        try:
            return datetime.datetime.strptime(self.content(), DATETIME_FORMAT)
        except (ValueError, TypeError):
            if default is NO_DEFAULT:
                return datetime.datetime.fromtimestamp(0)
            return default

    def as_int(self, default=0):
        """Return the content as an int, or ``default`` when not numeric."""
        try:
            return int(self.content())
        except (ValueError, TypeError):
            return default

    def __str__(self):
        return self.content()
78
79
class RowElement(ContentElement):
    """One XML element (a dump row): attributes, text content and named child
    elements.  Instances are built incrementally by TableHandler."""

    def __init__(self, name, attrs, parent=None):
        # Element and attribute names are lower-cased so lookups are
        # case-insensitive.
        self.name = name.lower()
        self.parent = parent
        self.attrs = dict([(k.lower(), ContentElement(v)) for k, v in attrs.items()])
        self._content = u''
        self.sub_elements = {}

        if parent:
            parent.add(self)

    def add_to_content(self, ch):
        # SAX may deliver text in several chunks; accumulate them all.
        self._content += unicode(ch)

    def add(self, sub):
        """Register ``sub`` as a child; same-named children keep arrival order."""
        curr = self.sub_elements.get(sub.name, None)

        if not curr:
            curr = []
            self.sub_elements[sub.name] = curr

        curr.append(sub)

    def get(self, name, default=None):
        """Return the last child named ``name``, or ``default``."""
        return self.sub_elements.get(name.lower(), [default])[-1]

    def get_list(self, name):
        """Return all children named ``name`` (possibly empty)."""
        return self.sub_elements.get(name.lower(), [])

    def get_listc(self, name):
        """Return the text content of every child named ``name``."""
        return [r.content() for r in self.get_list(name)]

    def getc(self, name, default=""):
        """Return the text content of the last child named ``name``, or
        ``default`` when missing."""
        el = self.get(name, None)

        if el:
            return el.content()
        else:
            return default

    def get_attr(self, name, default=""):
        """Return attribute ``name`` as a ContentElement.

        Bug fix: a missing attribute used to return the default as a bare
        string, so chained typed accessors like
        ``row.get_attr('missing').as_int()`` raised AttributeError; the
        default is now wrapped so the accessors always work.
        """
        attr = self.attrs.get(name.lower(), None)
        if attr is None:
            attr = ContentElement(default)
        return attr

    def as_pickled(self, default=None):
        """Deserialize the <value> child written by the exporter (see
        _as_pickled), or return ``default`` when there is none."""
        value_el = self.get('value')

        if value_el:
            return value_el._as_pickled(default)
        else:
            return default

    # Scalar types the exporter may tag a value with, keyed by type name.
    TYPES_MAP = dict([(c.__name__, c) for c in (int, long, str, unicode, float)])

    def _as_pickled(self, default=None):
        # Renamed from ``type`` to avoid shadowing the builtin.  An element
        # without a type attribute now yields "" here and falls through to
        # the raw-text branch below (previously it hit the except clause and
        # returned ``default``).
        type_name = self.get_attr('type').content()

        try:
            if type_name == 'dict':
                # NOTE(review): keys here are ContentElement instances, not
                # strings -- looks like it should be ``.content()``; left
                # unchanged to preserve the stored format.
                return dict([(item.get_attr('key'), item.as_pickled()) for item in self.get_list('item')])
            elif type_name == 'list':
                return [item.as_pickled() for item in self.get_list('item')]
            elif type_name == 'bool':
                return self.content().lower() == 'true'
            elif type_name in RowElement.TYPES_MAP:
                return RowElement.TYPES_MAP[type_name](self.content())
            else:
                # Unknown or absent type tag: hand back the raw text.
                return self.content()
        except Exception:
            # Narrowed from a bare except so KeyboardInterrupt still works.
            return default
149
150
151
152
class TableHandler(ContentHandler):
    """SAX handler that builds one RowElement per record and streams it to a
    callback, instead of loading the whole table file into memory.

    root_name -- the table's wrapping tag (carries no data, ignored)
    row_name  -- the tag delimiting one record
    callback  -- invoked as ``callback(row_element, *callback_args)``
    ping      -- optional no-arg progress hook called after each record
    """

    def __init__(self, root_name, row_name, callback, callback_args=None, ping=None):
        self.root_name = root_name.lower()
        self.row_name = row_name.lower()
        self.callback = callback
        # Bug fix: the default used to be a mutable ``[]`` shared between
        # every instance; use a None sentinel instead.
        self.callback_args = callback_args if callback_args is not None else []
        self.ping = ping

        self._reset()

    def _reset(self):
        # Forget any partially-built element between records.
        self.curr_element = None
        self.in_tag = None

    def startElement(self, name, attrs):
        name = name.lower()

        if name == self.root_name:
            pass
        elif name == self.row_name:
            # A new record starts.
            self.curr_element = RowElement(name, attrs)
        else:
            # Nested element: attach to the element currently being built.
            self.curr_element = RowElement(name, attrs, self.curr_element)

    def characters(self, ch):
        if self.curr_element:
            self.curr_element.add_to_content(ch)

    def endElement(self, name):
        name = name.lower()

        if name == self.root_name:
            pass
        elif name == self.row_name:
            # Record finished: deliver it, report progress, start clean.
            self.callback(self.curr_element, *self.callback_args)
            if self.ping:
                self.ping()

            self._reset()
        else:
            self.curr_element = self.curr_element.parent
194
195
class SaxErrorHandler(ErrorHandler):
    """Escalates every SAX problem -- warnings included -- so a malformed
    dump aborts the import instead of being silently skipped over."""

    def _propagate(self, e):
        raise e

    error = _propagate
    fatalError = _propagate
    warning = _propagate
205
def disable_triggers():
    """Turn PostgreSQL triggers off for the duration of the bulk import.

    No-op on any other database backend.
    """
    if db.backend_name != "postgres":
        return

    db.start_transaction()
    db.execute_many(commands.PG_DISABLE_TRIGGERS)
    db.commit_transaction()
211
def enable_triggers():
    """Re-enable PostgreSQL triggers after the bulk import.

    No-op on any other database backend.
    """
    if db.backend_name != "postgres":
        return

    db.start_transaction()
    db.execute_many(commands.PG_ENABLE_TRIGGERS)
    db.commit_transaction()
217
def reset_sequences():
    """Resynchronize PostgreSQL id sequences after rows were inserted with
    explicit primary keys.

    No-op on any other database backend.
    """
    if db.backend_name != "postgres":
        return

    db.start_transaction()
    db.execute_many(commands.PG_SEQUENCE_RESETS)
    db.commit_transaction()
223
def reset_fts_indexes():
    """Placeholder for rebuilding full-text-search indexes after an import;
    currently a no-op."""
226
# Import steps registered by the file_handler decorator; each entry is a dict
# with 'id', 'name' and 'fn' keys, run in registration order by start_import().
FILE_HANDLERS = []
228
def start_import(fname, tag_merge, user):
    """Run every registered FILE_HANDLERS step against the dump files in
    TMP_FOLDER.

    fname     -- the uploaded dump archive (extraction is currently handled
                 elsewhere; the tarfile code below is intentionally disabled)
    tag_merge -- optional {imported tag name: target tag name} mapping
    user      -- the requesting user, used when remapping ids during a merge

    Progress snapshots are written to the cache under CACHE_KEY so the admin
    UI can poll them.
    """
    start_time = datetime.datetime.now()
    steps = [s for s in FILE_HANDLERS]

    with open(os.path.join(TMP_FOLDER, 'backup.inf'), 'r') as inffile:
        inf = ConfigParser.SafeConfigParser()
        inf.readfp(inffile)

        # Per-step progress counters plus an aggregated 'overall' entry; the
        # expected counts come from the backup.inf manifest.
        state = dict([(s['id'], {
            'status': _('Queued'), 'count': int(inf.get(META_INF_SECTION, s['id'])), 'parsed': 0
        }) for s in steps] + [
            ('overall', {
                'status': _('Starting'), 'count': int(inf.get(META_INF_SECTION, 'overall')), 'parsed': 0
            })
        ])

    full_state = dict(running=True, state=state, time_started="")

    def set_state():
        # Publish the current progress snapshot for the UI to poll.
        full_state['time_started'] = diff_date(start_time)
        cache.set(CACHE_KEY, full_state)

    set_state()

    def ping_state(name):
        state[name]['parsed'] += 1
        state['overall']['parsed'] += 1
        set_state()

    data = {
        'is_merge': True,
        'tag_merge': tag_merge
    }

    def run(handler):
        name = handler['id']

        def ping():
            ping_state(name)

        # Bug fix: this previously read ``s['name']``, where ``s`` was the
        # loop variable leaked by the comprehensions above -- so the status
        # line always displayed the *last* step's title.
        state['overall']['status'] = _('Importing %s') % handler['name']
        state[name]['status'] = _('Importing')

        handler['fn'](TMP_FOLDER, user, ping, data)

        state[name]['status'] = _('Done')

        set_state()

    #dump = tarfile.open(fname, 'r')
    #dump.extractall(TMP_FOLDER)

    try:
        disable_triggers()
        db.start_transaction()

        for h in FILE_HANDLERS:
            run(h)

        db.commit_transaction()
        enable_triggers()

        settings.MERGE_MAPPINGS.set_value(dict(merged_nodes=data['nodes_map'], merged_users=data['users_map']))

        reset_sequences()
        # NOTE(review): on success full_state['running'] stays True in the
        # cache -- confirm whether the UI expects a final running=False here.
    except Exception as e:
        # NOTE(review): triggers stay disabled and the transaction is not
        # rolled back on failure; left as-is to avoid changing recovery
        # semantics blindly.
        full_state['running'] = False
        full_state['errors'] = "%s: %s" % (e.__class__.__name__, unicode(e))
        set_state()

        import traceback
        logging.error("Error executing xml import: \n %s" % (traceback.format_exc()))
304
def file_handler(file_name, root_tag, el_tag, name, args_handler=None, pre_callback=None, post_callback=None):
    """Decorator factory: registers the decorated row-callback as an import
    step in FILE_HANDLERS.

    file_name    -- the XML file (inside the dump folder) this step parses
    root_tag     -- the table's root tag; doubles as the step id
    el_tag       -- the tag delimiting one record
    name         -- human-readable step title
    args_handler -- optional fn(user, data) producing extra callback args
    pre_callback / post_callback -- hooks run before / after parsing
    """
    def decorator(fn):
        def decorated(location, current_user, ping, data):
            if pre_callback:
                pre_callback(current_user, data)

            extra_args = args_handler(current_user, data) if args_handler else []

            sax_parser = make_parser()
            sax_parser.setContentHandler(TableHandler(root_tag, el_tag, fn, extra_args, ping))
            #sax_parser.setErrorHandler(SaxErrorHandler())

            sax_parser.parse(SafeReader(os.path.join(location, file_name)))

            if post_callback:
                post_callback()

        FILE_HANDLERS.append(dict(id=root_tag, name=name, fn=decorated))
        return decorated
    return decorator
329
def verify_existence(row):
    """Find an already-registered user matching an imported <user> row.

    Matches by e-mail first, then by any non-Google/Yahoo authentication key
    (presumably because those providers' keys are not stable identifiers --
    confirm).  Returns the matching orm.User, or None.
    """
    try:
        return orm.User.objects.get(email=row.getc('email'))
    except:
        # Bare except kept on purpose: the orm wrapper's DoesNotExist type is
        # not importable here.  Fall through to auth-key matching.
        for key_el in row.get('authKeys').get_list('key'):
            # Bug fix: was ``key = key=key.getc('key')`` -- a confusing
            # chained double assignment with the same effect.
            key = key_el.getc('key')

            if not ("google.com" in key or "yahoo.com" in key):
                try:
                    return orm.AuthKeyUserAssociation.objects.get(key=key).user
                except:
                    pass

    return None
344
def user_import_pre_callback(user, data):
    """Prepare the shared import state: users_map will translate user ids
    from the dump into the ids of the rows actually stored."""
    data['users_map'] = dict()
347
@file_handler('users.xml', 'users', 'user', _('Users'), pre_callback=user_import_pre_callback, args_handler=lambda u, d: [u, d['is_merge'], d['users_map']])
def user_import(row, current_user, is_merge, users_map):
    """Import one <user> row.

    On a merge, an existing user (matched by verify_existence) is updated in
    place; otherwise a new orm.User is created (with a de-duplicated username
    when merging).  Also imports the user's auth keys and, for new users,
    their subscription/notification settings.
    """
    existent = is_merge and verify_existence(row) or None

    roles = row.get('roles').get_listc('role')
    valid_email = row.get('email').get_attr('validated').as_bool()
    badges = row.get('badges')

    if existent:
        user = existent

        user.reputation += row.get('reputation').as_int()
        user.gold += badges.get_attr('gold').as_int()
        # Bug fix: silver and bronze were both incremented with the *gold*
        # badge count (copy-paste error; compare the creation branch below).
        user.silver += badges.get_attr('silver').as_int()
        user.bronze += badges.get_attr('bronze').as_int()

    else:
        username = row.getc('username')

        if is_merge:
            # Avoid username collisions with existing users by suffixing a
            # counter.
            username_count = 0

            while orm.User.objects.filter(username=username).count():
                username_count += 1
                username = "%s %s" % (row.getc('username'), username_count)

        user = orm.User(
                id           = (not is_merge) and row.getc('id') or None,
                username     = username,
                password     = row.getc('password'),
                email        = row.getc('email'),
                email_isvalid= valid_email,
                is_superuser = (not is_merge) and 'superuser' in roles,
                is_staff     = ('moderator' in roles) or (is_merge and 'superuser' in roles),
                is_active    = row.get('active').as_bool(),
                date_joined  = row.get('joindate').as_datetime(),
                about         = row.getc('bio'),
                date_of_birth = row.get('birthdate').as_date(None),
                website       = row.getc('website'),
                reputation    = row.get('reputation').as_int(),
                gold          = badges.get_attr('gold').as_int(),
                silver        = badges.get_attr('silver').as_int(),
                bronze        = badges.get_attr('bronze').as_int(),
                real_name     = row.getc('realname'),
                location      = row.getc('location'),
        )

    user.save()

    # Remember dump id -> stored id so later steps can remap foreign keys.
    users_map[row.get('id').as_int()] = user.id

    authKeys = row.get('authKeys')

    for key in authKeys.get_list('key'):
        if (not is_merge) or orm.AuthKeyUserAssociation.objects.filter(key=key.getc('key')).count() == 0:
            orm.AuthKeyUserAssociation(user=user, key=key.getc('key'), provider=key.getc('provider')).save()

    if not existent:
        notifications = row.get('notifications')

        # 'i'/'n' presumably mean immediate/no-notification -- confirm
        # against the SubscriptionSettings model choices.
        attributes = dict([(str(k), v.as_bool() and 'i' or 'n') for k, v in notifications.get('notify').attrs.items()])
        attributes.update(dict([(str(k), v.as_bool()) for k, v in notifications.get('autoSubscribe').attrs.items()]))
        attributes.update(dict([(str("notify_%s" % k), v.as_bool()) for k, v in notifications.get('notifyOnSubscribed').attrs.items()]))

        ss = orm.SubscriptionSettings(user=user, enable_notifications=notifications.get_attr('enabled').as_bool(), **attributes)

        if current_user.id == row.get('id').as_int():
            # The importing admin keeps their existing settings row id.
            ss.id = current_user.subscription_settings.id

        ss.save()
418         
419
def pre_tag_import(user, data):
    """Cache every existing tag by name so tag_import can merge into them."""
    data['tag_mappings'] = dict((t.name, t) for t in orm.Tag.objects.all())
422
423
@file_handler('tags.xml', 'tags', 'tag', _('Tags'), pre_callback=pre_tag_import, args_handler=lambda u, d: [d['is_merge'], d['tag_merge'], d['users_map'], d['tag_mappings']])
def tag_import(row, is_merge, tag_merge, users_map, tag_mappings):
    """Import one <tag> row, adding its use count to an existing tag when a
    tag with the same (possibly remapped) name already exists."""
    # NOTE(review): the creator id is read from the 'used' child, which also
    # supplies the use count below -- looks suspicious, but it matches what
    # this importer has always done; confirm against the exporter's format
    # before changing.
    creator_id = row.get('used').as_int()
    creator_id = users_map.get(creator_id, creator_id)

    name = row.getc('name')
    # Apply the optional tag-merge rename mapping.
    name = tag_merge and tag_merge.get(name, name) or name

    if is_merge and name in tag_mappings:
        tag = tag_mappings[name]
        tag.used_count += row.get('used').as_int()
    else:
        tag = orm.Tag(name=name, used_count=row.get('used').as_int(), created_by_id=creator_id)
        tag_mappings[tag.name] = tag

    tag.save()
440
def pre_node_import(user, data):
    """Prepare nodes_map: translates node ids from the dump into the ids of
    the rows actually stored."""
    data['nodes_map'] = dict()
443
@file_handler('nodes.xml', 'nodes', 'node', _('Nodes'), pre_callback=pre_node_import,
              args_handler=lambda u, d: [d['is_merge'], d['tag_merge'], d['tag_mappings'], d['nodes_map'], d['users_map']])
def node_import(row, is_merge, tag_merge, tags, nodes_map, users_map):
    """Import one <node> row (question/answer/comment) plus its revisions.

    Author/parent ids are remapped through users_map/nodes_map, the optional
    tag-merge mapping is applied, and the dump id -> stored id pair is
    recorded in nodes_map for later steps.
    """
    ntags = []

    # Resolve the node's tags, applying any tag-merge renames first.
    for t in row.get('tags').get_list('tag'):
        t = t.content()
        ntags.append(tags[tag_merge and tag_merge.get(t, t) or t])

    author = row.get('author').as_int()

    last_act = row.get('lastactivity')
    last_act_user = last_act.get('by').as_int(None)

    parent = row.get('parent').as_int(None)
    abs_parent = row.get('absparent').as_int(None)

    node = orm.Node(
            id            = (not is_merge) and row.getc('id') or None,
            node_type     = row.getc('type'),
            author_id     = users_map.get(author, author),
            added_at      = row.get('date').as_datetime(),
            parent_id     = nodes_map.get(parent, parent),
            abs_parent_id = nodes_map.get(abs_parent, abs_parent),
            score         = row.get('score').as_int(0),

            last_activity_by_id = last_act_user and users_map.get(last_act_user, last_act_user) or last_act_user,
            last_activity_at    = last_act.get('at').as_datetime(None),

            title         = row.getc('title'),
            body          = row.getc('body'),
            tagnames      = " ".join([t.name for t in ntags]),

            marked        = row.get('marked').as_bool(),
            extra_ref_id  = row.get('extraRef').as_int(None),
            extra_count   = row.get('extraCount').as_int(0),
            extra         = row.get('extraData').as_pickled()
    )

    node.save()

    nodes_map[row.get('id').as_int()] = node.id

    node.tags = ntags

    revisions = row.get('revisions')
    active = revisions.get_attr('active').as_int()

    if active == 0:
        # No revisions in the dump: synthesize an initial one.
        active = orm.NodeRevision(
            author_id = node.author_id,
            body = row.getc('body'),
            node = node,
            revised_at = row.get('date').as_datetime(),
            revision = 1,
            summary = _('Initial revision'),
            tagnames = " ".join([t.name for t in ntags]),
            title = row.getc('title'),
        )

        active.save()
    else:
        for r in revisions.get_list('revision'):
            # Bug fix: the author was read from the *node* row
            # (``row.get('author')``), so every revision was attributed to
            # the node's author.  Read it from the revision itself, falling
            # back to the node author when the revision carries none.
            rev_author = r.get('author', row.get('author')).as_int()

            rev = orm.NodeRevision(
                author_id = users_map.get(rev_author, rev_author),
                body = r.getc('body'),
                node = node,
                revised_at = r.get('date').as_datetime(),
                revision = r.get('number').as_int(),
                summary = r.getc('summary'),
                tagnames = " ".join(r.getc('tags').split(',')),
                title = r.getc('title'),
            )

            rev.save()
            if rev.revision == active:
                active = rev

    node.active_revision = active
    node.save()
527
# Maps an action_type string to the hook run after that action row is
# imported (see actions_import).
POST_ACTION = {}

def post_action(*types):
    """Decorator: register the decorated function as the post-import hook
    for each of the given action types."""
    def decorator(fn):
        for action_type in types:
            POST_ACTION[action_type] = fn
        return fn
    return decorator
536
def pre_action_import_callback(user, data):
    """Prepare actions_map: translates action ids from the dump into the ids
    of the rows actually stored."""
    data['actions_map'] = dict()
539
def post_action_import_callback():
    """Denormalization pass: once all actions are in, rebuild each affected
    node's state_string from its NodeState rows."""
    state_node_ids = orm.NodeState.objects.values_list('node_id', flat=True).distinct()

    for node in orm.Node.objects.filter(id__in=state_node_ids):
        node.state_string = "".join(["(%s)" % s for s in node.states.values_list('state_type')])
        node.save()
546
@file_handler('actions.xml', 'actions', 'action', _('Actions'), post_callback=post_action_import_callback,
              pre_callback=pre_action_import_callback, args_handler=lambda u, d: [d['nodes_map'], d['users_map'], d['actions_map']])
def actions_import(row, nodes, users, actions_map):
    """Import one <action> row plus its reputation deltas, then dispatch to
    the matching POST_ACTION hook (votes, state changes, flags, ...)."""
    node = row.get('node').as_int(None)
    user = row.get('user').as_int()
    real_user = row.get('realUser').as_int(None)

    action = orm.Action(
        #id           = row.get('id').as_int(),
        action_type  = row.getc('type'),
        action_date  = row.get('date').as_datetime(),
        node_id      = nodes.get(node, node),
        user_id      = users.get(user, user),
        real_user_id = users.get(real_user, real_user),
        ip           = row.getc('ip'),
        extra        = row.get('extraData').as_pickled(),
    )

    canceled = row.get('canceled')
    if canceled.get_attr('state').as_bool():
        by = canceled.get('user').as_int()
        action.canceled = True
        action.canceled_by_id = users.get(by, by)
        # NOTE(review): stored as the raw string; the commented-out
        # .as_datetime() suggests parsing was deliberately disabled --
        # confirm the column tolerates this before changing.
        action.canceled_at = canceled.getc('date') #.as_datetime(),
        action.canceled_ip = canceled.getc('ip')

    action.save()

    actions_map[row.get('id').as_int()] = action.id

    for r in row.get('reputes').get_list('repute'):
        by_canceled = r.get_attr('byCanceled').as_bool()
        repute_user = r.get('user').as_int()

        orm.ActionRepute(
            action = action,
            # Consistency fix: remap through users like every other lookup
            # here; the old direct indexing (``users[...]``) raised KeyError
            # for ids absent from the map.
            user_id = users.get(repute_user, repute_user),
            value = r.get('value').as_int(),

            date = by_canceled and action.canceled_at or action.action_date,
            by_canceled = by_canceled
        ).save()

    if (not action.canceled) and (action.action_type in POST_ACTION):
        POST_ACTION[action.action_type](row, action, users, nodes, actions_map)
591
592
593
594
# Record of all persisted votes, keyed by (user_id, node_id).  A set gives
# O(1) duplicate checks (this was a list, giving an O(n) scan per vote).
# NOTE(review): module-level state that is never cleared between imports in
# the same process -- confirm that is acceptable.
persisted_votes = set()

@post_action('voteup', 'votedown', 'voteupcomment')
def vote_action(row, action, users, nodes, actions):
    """Persist a Vote for a vote-type action, skipping duplicates for the
    same (user, node) pair."""
    vote_key = (action.user_id, action.node_id)

    if vote_key not in persisted_votes:
        # Persist the vote action; down-votes count -1, everything else +1.
        orm.Vote(user_id=action.user_id, node_id=action.node_id, action=action,
                 voted_at=action.action_date, value=(action.action_type != 'votedown') and 1 or -1).save()

        # Remember the pair so later duplicates are skipped.
        persisted_votes.add(vote_key)
607
608
def state_action(state):
    """Build a post-action hook that records ``state`` (e.g. 'deleted',
    'wiki') for the acted-on node, unless that state already exists."""
    def apply_state(row, action, users, nodes, actions):
        already_set = orm.NodeState.objects.filter(state_type=state, node=action.node_id).count()
        if already_set:
            return

        orm.NodeState(
            state_type = state,
            node_id = action.node_id,
            action = action
        ).save()
    return apply_state
620
# Wire the generic state-changing actions to the node state each one records.
post_action('wikify')(state_action('wiki'))
post_action('delete')(state_action('deleted'))
post_action('acceptanswer')(state_action('accepted'))
post_action('publish')(state_action('published'))
625
626
@post_action('flag')
def flag_action(row, action, users, nodes, actions):
    """Persist a Flag for a flag action; the reason travels in the action's
    extra data (empty string when absent)."""
    reason = action.extra or ""
    orm.Flag(user_id=action.user_id, node_id=action.node_id, action=action, reason=reason).save()
630
631
def award_import_args(user, data):
    """Arguments for awards_import: a badge lookup keyed by badge class plus
    the node/user/action id maps built by the earlier steps."""
    badges_by_cls = dict((b.cls, b) for b in orm.Badge.objects.all())
    return [badges_by_cls, data['nodes_map'], data['users_map'], data['actions_map']]
634
635
@file_handler('awards.xml', 'awards', 'award', _('Awards'), args_handler=award_import_args)
def awards_import(row, badges, nodes, users, actions):
    """Import one <award> row, skipping unknown badge types and duplicates."""
    badge_type = badges.get(row.getc('badge'), None)

    if not badge_type:
        # Badge class no longer exists in this installation; skip the row.
        return

    action = row.get('action').as_int(None)
    trigger = row.get('trigger').as_int(None)
    node = row.get('node').as_int(None)
    user = row.get('user').as_int()

    # Consistency fix: reuse badge_type instead of a second, KeyError-prone
    # ``badges[...]`` index for the same key.
    if orm.Award.objects.filter(badge=badge_type, user=users.get(user, user), node=nodes.get(node, node)).count():
        return

    # Note: .save() returns None, so the previous ``award = ...`` binding
    # was meaningless and has been dropped.
    orm.Award(
        user_id = users.get(user, user),
        badge = badge_type,
        node_id = nodes.get(node, node),
        action_id = actions.get(action, action),
        trigger_id = actions.get(trigger, trigger)
    ).save()
658
659
#@file_handler('settings.xml', 'settings', 'setting', _('Settings'))
def settings_import(row):
    """Import one <setting> row.  Currently disabled: the registering
    decorator above is commented out."""
    # Bug fix: the KeyValue instance was created but never persisted; every
    # other importer in this module calls .save() on what it builds.
    orm.KeyValue(key=row.getc('key'), value=row.get('value').as_pickled()).save()
663
664
665
666
667
668
669
670
671