git.openstreetmap.org Git - osqa.git/blob - forum_modules/exporter/importer.py
#OSQA-603, fixing the badges count calculation error. After the update of the count...
[osqa.git] / forum_modules / exporter / importer.py
1 import os, tarfile, datetime, ConfigParser, logging
2
3 from django.utils.translation import ugettext as _
4 from django.core.cache import cache
5
6 from south.db import db
7
8 from xml.sax import make_parser
9 from xml.sax.handler import ContentHandler, ErrorHandler
10
11 from forum.templatetags.extra_tags import diff_date
12
13 from exporter import TMP_FOLDER, DATETIME_FORMAT, DATE_FORMAT, META_INF_SECTION, CACHE_KEY
14 from orm import orm
15 import commands, settings
16
# Sentinel meaning "no default was supplied".  It is distinct from None so
# that callers of the ContentElement.as_* accessors can pass None as a real
# default value.
NO_DEFAULT = object()
18
19 import string
20
class SafeReader(object):
    """Read-only file wrapper that strips characters XML 1.0 does not allow.

    Only ASCII control characters other than tab, newline and carriage
    return (plus DEL, which the previous filter also dropped) are removed.
    Every other character is passed through untouched.  That includes
    non-ASCII text and the bytes of multi-byte UTF-8 sequences, which are
    always >= 0x80.
    """

    # Characters deleted from everything that is read.
    _STRIPPED = frozenset(chr(code) for code in list(range(32)) + [0x7f]
                          if chr(code) not in '\t\n\r')

    def __init__(self, loc):
        self.base = open(loc)

    def _clean(self, text):
        """Return *text* without the characters listed in _STRIPPED."""
        return "".join(c for c in text if c not in self._STRIPPED)

    def read(self, *args):
        return self._clean(self.base.read(*args))

    def readline(self, *args):
        return self._clean(self.base.readline(*args))

    # Backward-compatible alias for the original camel-case name.  It used
    # to call the nonexistent file.readLine and always failed.
    readLine = readline

    def close(self):
        self.base.close()
33
34
class ContentElement(object):
    """Text value of an XML attribute or element, with lenient typed accessors.

    The as_* accessors never raise on unparseable text.  They return a
    default instead, so one malformed field does not abort a whole import.
    """

    def __init__(self, content):
        self._content = content

    def content(self):
        """Return the raw text with surrounding whitespace removed."""
        return self._content.strip()

    def as_bool(self):
        """Return True only when the text is exactly "true"."""
        return self.content() == "true"

    def as_date(self, default=NO_DEFAULT):
        """Parse the text with DATE_FORMAT and return a datetime.date.

        On failure, return *default* when given; otherwise return the epoch
        date.  The success path now also yields a date (it used to return a
        datetime), so both paths have the same type.
        """
        try:
            return datetime.datetime.strptime(self.content(), DATE_FORMAT).date()
        except (TypeError, ValueError):
            if default is NO_DEFAULT:
                return datetime.date.fromtimestamp(0)
            return default

    def as_datetime(self, default=NO_DEFAULT):
        """Parse the text with DATETIME_FORMAT.

        On failure, return *default* when given; otherwise return the epoch
        as a datetime.
        """
        try:
            return datetime.datetime.strptime(self.content(), DATETIME_FORMAT)
        except (TypeError, ValueError):
            if default is NO_DEFAULT:
                return datetime.datetime.fromtimestamp(0)
            return default

    def as_int(self, default=0):
        """Return the text as an int, or *default* if it is not a valid integer."""
        try:
            return int(self.content())
        except (TypeError, ValueError):
            return default

    def __str__(self):
        return self.content()
72
73
class RowElement(ContentElement):
    """One XML element captured while streaming a backup file.

    Attribute values are wrapped in ContentElement and keyed by lower-cased
    name.  Child elements are grouped by lower-cased tag name in document
    order, so a row can be queried like a small tree.
    """

    # Scalar type names understood by _as_pickled.
    TYPES_MAP = dict((c.__name__, c) for c in (int, long, str, unicode, float))

    def __init__(self, name, attrs, parent=None):
        self.name = name.lower()
        self.parent = parent
        self.attrs = dict((k.lower(), ContentElement(v)) for k, v in attrs.items())
        self._content = u''
        self.sub_elements = {}

        if parent:
            parent.add(self)

    def add_to_content(self, ch):
        """Append a chunk of character data (SAX may deliver text in pieces)."""
        self._content += unicode(ch)

    def add(self, sub):
        """Register *sub* as a child element of this one."""
        self.sub_elements.setdefault(sub.name, []).append(sub)

    def get(self, name, default=None):
        """Return the last child named *name*, or *default* if there is none."""
        return self.sub_elements.get(name.lower(), [default])[-1]

    def get_list(self, name):
        """Return every child named *name*, in document order."""
        return self.sub_elements.get(name.lower(), [])

    def get_listc(self, name):
        """Return the stripped text of every child named *name*."""
        return [r.content() for r in self.get_list(name)]

    def getc(self, name, default=""):
        """Return the stripped text of the last child named *name*, or *default*."""
        el = self.get(name, None)

        if el:
            return el.content()
        else:
            return default

    def get_attr(self, name, default=""):
        """Return the ContentElement for attribute *name*, or *default* if absent."""
        return self.attrs.get(name.lower(), default)

    def _attr_text(self, name, default=""):
        """Return the stripped text of attribute *name*, or *default* if absent."""
        attr = self.attrs.get(name.lower())
        if attr is None:
            return default
        return attr.content()

    def as_pickled(self, default=None):
        """Decode the serialized value held in this row's <value> child."""
        value_el = self.get('value')

        if value_el:
            return value_el._as_pickled(default)
        else:
            return default

    def _as_pickled(self, default=None):
        """Rebuild the Python value serialized in this element.

        The ``type`` attribute selects the decoding:

        - ``dict`` and ``list`` recurse into ``<item>`` children;
        - ``bool`` compares the text with "true";
        - names found in TYPES_MAP convert the text with that type;
        - anything else, including a missing ``type``, returns the raw text.

        Any failure while decoding yields *default*.
        """
        # A missing type attribute used to raise AttributeError ("".content())
        # outside the try block and abort the import.
        type_name = self._attr_text('type')

        try:
            if type_name == 'dict':
                # Keys must be the attribute *text*.  A ContentElement wrapper
                # hashes by identity and would never match a lookup.
                return dict((item._attr_text('key'), item.as_pickled())
                            for item in self.get_list('item'))
            elif type_name == 'list':
                return [item.as_pickled() for item in self.get_list('item')]
            elif type_name == 'bool':
                return self.content().lower() == 'true'
            elif type_name in RowElement.TYPES_MAP:
                return RowElement.TYPES_MAP[type_name](self.content())
            else:
                return self.content()
        except Exception:
            return default
143
144
145
146
class TableHandler(ContentHandler):
    """SAX handler that turns each row element of a backup file into a RowElement.

    The element named *root_name* is ignored.  Every element named
    *row_name* starts a new row, and other elements become children of the
    currently open element.  When a row closes,
    ``callback(row, *callback_args)`` is invoked, followed by ``ping()`` if
    one was given.
    """

    def __init__(self, root_name, row_name, callback, callback_args=None, ping=None):
        self.root_name = root_name.lower()
        self.row_name = row_name.lower()
        self.callback = callback
        # None instead of a shared mutable [] default; explicit lists are kept as-is.
        self.callback_args = [] if callback_args is None else callback_args
        self.ping = ping

        self._reset()

    def _reset(self):
        self.curr_element = None
        self.in_tag = None

    def startElement(self, name, attrs):
        name = name.lower()

        if name == self.root_name:
            return

        if name == self.row_name:
            self.curr_element = RowElement(name, attrs)
        else:
            self.curr_element = RowElement(name, attrs, self.curr_element)

    def characters(self, ch):
        # Text outside any row (e.g. whitespace between rows) is dropped.
        if self.curr_element:
            self.curr_element.add_to_content(ch)

    def endElement(self, name):
        name = name.lower()

        if name == self.root_name:
            return

        if name == self.row_name:
            self.callback(self.curr_element, *self.callback_args)
            if self.ping:
                self.ping()

            self._reset()
        elif self.curr_element is not None:
            # Guarded: an end tag with no open element (e.g. a wrapper around
            # the root) used to raise AttributeError on None.parent.
            self.curr_element = self.curr_element.parent
188
189
class SaxErrorHandler(ErrorHandler):
    """SAX error handler that treats every reported problem, even a warning, as fatal."""

    def _reraise(self, e):
        raise e

    # All three SAX callbacks share the same "propagate immediately" policy.
    error = _reraise
    fatalError = _reraise
    warning = _reraise
199
def _run_postgres_statements(statements):
    """Execute *statements* in a transaction of their own, on PostgreSQL only.

    Any other backend is a no-op.  If execution fails, the transaction is
    rolled back before the error propagates, so the connection is not left
    inside a half-finished transaction.
    """
    if db.backend_name != "postgres":
        return

    db.start_transaction()
    try:
        db.execute_many(statements)
    except Exception:
        db.rollback_transaction()
        raise
    db.commit_transaction()


def disable_triggers():
    """Run commands.PG_DISABLE_TRIGGERS (PostgreSQL only)."""
    _run_postgres_statements(commands.PG_DISABLE_TRIGGERS)


def enable_triggers():
    """Run commands.PG_ENABLE_TRIGGERS (PostgreSQL only)."""
    _run_postgres_statements(commands.PG_ENABLE_TRIGGERS)


def reset_sequences():
    """Run commands.PG_SEQUENCE_RESETS (PostgreSQL only)."""
    _run_postgres_statements(commands.PG_SEQUENCE_RESETS)
217
def reset_fts_indexes():
    """Hook for rebuilding full-text-search indexes; intentionally does nothing for now."""
220
# Registry of import steps.  @file_handler appends dict(id, name, fn) entries
# in definition order, and start_import runs them in that order.
FILE_HANDLERS = list()
222
def start_import(fname, tag_merge, user):
    """Run every registered import step against the backup in TMP_FOLDER.

    Progress is published to the cache under CACHE_KEY as a dict with:

    - ``running``;
    - ``state``: a per-step and ``overall`` dict of status/count/parsed;
    - ``time_started``;
    - ``errors``: set only when a step fails.

    :param fname: path of the uploaded backup.  It is currently unused
        because the extraction code here is commented out, so the archive
        must already be unpacked in TMP_FOLDER.
    :param tag_merge: optional mapping of imported tag name -> replacement.
    :param user: the importing user, passed to every step.
    """
    import traceback

    start_time = datetime.datetime.now()
    steps = list(FILE_HANDLERS)

    with open(os.path.join(TMP_FOLDER, 'backup.inf'), 'r') as inffile:
        inf = ConfigParser.SafeConfigParser()
        inf.readfp(inffile)

    # Expected row counts come from the backup's meta section.
    state = {}
    for step in steps:
        state[step['id']] = {
            'status': _('Queued'), 'count': int(inf.get(META_INF_SECTION, step['id'])), 'parsed': 0
        }
    state['overall'] = {
        'status': _('Starting'), 'count': int(inf.get(META_INF_SECTION, 'overall')), 'parsed': 0
    }

    full_state = dict(running=True, state=state, time_started="")

    def set_state():
        full_state['time_started'] = diff_date(start_time)
        cache.set(CACHE_KEY, full_state)

    set_state()

    def ping_state(step_id):
        state[step_id]['parsed'] += 1
        state['overall']['parsed'] += 1
        set_state()

    data = {
        'is_merge': True,
        'tag_merge': tag_merge
    }

    def run(step):
        """Run one registered step, reporting progress under its own id."""
        step_id = step['id']

        def ping():
            ping_state(step_id)

        # Use *this* step's label.  The old code read a variable leaked from
        # a list comprehension, so it always showed the last step's name.
        state['overall']['status'] = _('Importing %s') % step['name']
        state[step_id]['status'] = _('Importing')
        set_state()

        step['fn'](TMP_FOLDER, user, ping, data)

        state[step_id]['status'] = _('Done')
        set_state()

    #dump = tarfile.open(fname, 'r')
    #dump.extractall(TMP_FOLDER)

    triggers_disabled = False
    in_transaction = False

    try:
        disable_triggers()
        triggers_disabled = True

        db.start_transaction()
        in_transaction = True

        for step in steps:
            run(step)

        db.commit_transaction()
        in_transaction = False

        enable_triggers()
        triggers_disabled = False

        settings.MERGE_MAPPINGS.set_value(dict(merged_nodes=data['nodes_map'], merged_users=data['users_map']))

        reset_sequences()

        # Report completion so pollers stop seeing a running import.
        state['overall']['status'] = _('Done')
        full_state['running'] = False
        set_state()
    except Exception as e:
        # Format the traceback before cleanup can replace the active exception.
        error_trace = traceback.format_exc()

        # Undo the partial import and restore the triggers.  A failure here
        # must not mask the original error, so it is only logged.
        if in_transaction:
            try:
                db.rollback_transaction()
            except Exception:
                logging.exception("Error rolling back failed xml import")
        if triggers_disabled:
            try:
                enable_triggers()
            except Exception:
                logging.exception("Error re-enabling triggers after failed xml import")

        full_state['running'] = False
        full_state['errors'] = "%s: %s" % (e.__class__.__name__, unicode(e))
        set_state()

        logging.error("Error executing xml import: \n %s" % error_trace)
298
def file_handler(file_name, root_tag, el_tag, name, args_handler=None, pre_callback=None, post_callback=None):
    """Register the decorated row callback as an import step for *file_name*.

    The registered step, ``decorated(location, current_user, ping, data)``,
    does the following in order:

    1. Runs ``pre_callback(current_user, data)``.
    2. Builds the extra callback arguments with
       ``args_handler(current_user, data)``.
    3. Streams ``location/file_name`` through a TableHandler, calling
       ``fn(row, *args)`` for every ``<el_tag>`` row.
    4. Runs ``post_callback()``.

    The step is appended to FILE_HANDLERS under id *root_tag* with display
    name *name*.
    """
    def decorator(fn):
        def decorated(location, current_user, ping, data):
            if pre_callback:
                pre_callback(current_user, data)

            args = args_handler(current_user, data) if args_handler else []

            parser = make_parser()
            parser.setContentHandler(TableHandler(root_tag, el_tag, fn, args, ping))
            # Strict mode is deliberately left off; enable with
            # parser.setErrorHandler(SaxErrorHandler()).

            reader = SafeReader(os.path.join(location, file_name))
            try:
                parser.parse(reader)
            finally:
                # The SAX parser may not close objects it was handed, so
                # release the file handle even when parsing fails.
                reader.close()

            if post_callback:
                post_callback()

        FILE_HANDLERS.append(dict(id=root_tag, name=name, fn=decorated))
        return decorated
    return decorator
323
def verify_existence(row):
    """Return an existing user matching the exported user *row*, or None.

    Users are matched first by e-mail, but only when the row has one, so
    users without an address are not merged together.  Otherwise they are
    matched by the first auth key (other than google.com / yahoo.com keys)
    that is associated with exactly one user.  Lookups that find no user or
    several users count as "no match".
    """
    from django.core.exceptions import ObjectDoesNotExist, MultipleObjectsReturned

    email = row.getc('email')
    if email:
        try:
            return orm.User.objects.get(email=email)
        except (ObjectDoesNotExist, MultipleObjectsReturned):
            pass

    auth_keys = row.get('authKeys')
    if auth_keys is None:
        return None

    for key_el in auth_keys.get_list('key'):
        key = key_el.getc('key')

        # NOTE(review): google.com / yahoo.com keys are skipped, presumably
        # because they are not unique per user -- confirm.
        if "google.com" in key or "yahoo.com" in key:
            continue

        try:
            return orm.AuthKeyUserAssociation.objects.get(key=key).user
        except (ObjectDoesNotExist, MultipleObjectsReturned):
            pass

    return None
338
def user_import_pre_callback(user, data):
    """Give the users step a fresh exported-id -> new-id map in *data*."""
    data.update(users_map={})
341
@file_handler('users.xml', 'users', 'user', _('Users'), pre_callback=user_import_pre_callback, args_handler=lambda u, d: [u, d['is_merge'], d['users_map']])
def user_import(row, current_user, is_merge, users_map):
    """Import one <user> row and record its exported id -> new id in *users_map*.

    When merging, a row matching an existing account (via verify_existence)
    adds its reputation and badge counts to that account.  Otherwise a new
    user is created; in merge mode " N" is appended to the username while it
    collides with an existing one.  Auth keys are copied unless, in merge
    mode, the key is already registered.  Subscription settings are created
    only for newly created users.
    """
    existent = is_merge and verify_existence(row) or None

    roles = row.get('roles').get_listc('role')
    valid_email = row.get('email').get_attr('validated').as_bool()
    badges = row.get('badges')

    reputation = row.get('reputation').as_int()
    gold = badges.get_attr('gold').as_int()
    silver = badges.get_attr('silver').as_int()
    bronze = badges.get_attr('bronze').as_int()

    if existent:
        user = existent

        # Each medal count goes to its own field (silver and bronze used to
        # be incremented by the gold count).
        user.reputation += reputation
        user.gold += gold
        user.silver += silver
        user.bronze += bronze

    else:
        username = row.getc('username')

        if is_merge:
            username_count = 0

            while orm.User.objects.filter(username=username).count():
                username_count += 1
                username = "%s %s" % (row.getc('username'), username_count)

        user = orm.User(
                id           = (not is_merge) and row.getc('id') or None,
                username     = username,
                password     = row.getc('password'),
                email        = row.getc('email'),
                email_isvalid= valid_email,
                is_superuser = (not is_merge) and 'superuser' in roles,
                is_staff     = ('moderator' in roles) or (is_merge and 'superuser' in roles),
                is_active    = row.get('active').as_bool(),
                date_joined  = row.get('joindate').as_datetime(),
                about         = row.getc('bio'),
                date_of_birth = row.get('birthdate').as_date(None),
                website       = row.getc('website'),
                reputation    = reputation,
                gold          = gold,
                silver        = silver,
                bronze        = bronze,
                real_name     = row.getc('realname'),
                location      = row.getc('location'),
        )

    user.save()

    users_map[row.get('id').as_int()] = user.id

    for key in row.get('authKeys').get_list('key'):
        key_value = key.getc('key')
        # In merge mode never duplicate a key that is already associated.
        if (not is_merge) or orm.AuthKeyUserAssociation.objects.filter(key=key_value).count() == 0:
            orm.AuthKeyUserAssociation(user=user, key=key_value, provider=key.getc('provider')).save()

    if not existent:
        notifications = row.get('notifications')

        # <notify> flags are stored as 'i' when true and 'n' otherwise;
        # <autoSubscribe> flags as booleans; <notifyOnSubscribed> flags as
        # booleans under a "notify_" prefix.
        attributes = dict([(str(k), v.as_bool() and 'i' or 'n') for k, v in notifications.get('notify').attrs.items()])
        attributes.update(dict([(str(k), v.as_bool()) for k, v in notifications.get('autoSubscribe').attrs.items()]))
        attributes.update(dict([(str("notify_%s" % k), v.as_bool()) for k, v in notifications.get('notifyOnSubscribed').attrs.items()]))

        ss = orm.SubscriptionSettings(user=user, enable_notifications=notifications.get_attr('enabled').as_bool(), **attributes)

        # The row whose exported id equals the importing user's id reuses
        # that user's existing settings record.
        if current_user.id == row.get('id').as_int():
            ss.id = current_user.subscription_settings.id

        ss.save()
412         
413
def pre_tag_import(user, data):
    """Index every existing tag by name so imported tags can be merged into them."""
    tag_mappings = {}
    for tag in orm.Tag.objects.all():
        tag_mappings[tag.name] = tag
    data['tag_mappings'] = tag_mappings
416
417
@file_handler('tags.xml', 'tags', 'tag', _('Tags'), pre_callback=pre_tag_import, args_handler=lambda u, d: [d['is_merge'], d['tag_merge'], d['users_map'], d['tag_mappings']])
def tag_import(row, is_merge, tag_merge, users_map, tag_mappings):
    """Import one <tag> row, renaming it through *tag_merge* if given.

    In merge mode, a tag whose (possibly renamed) name already exists gets
    its use count added to the existing tag.  Otherwise a new tag is
    created and registered in *tag_mappings*.
    """
    used = row.get('used').as_int()

    # The creator used to be read from the <used> count, i.e. a use count was
    # stored as a user id.  Prefer an <author> element, as node rows use.
    # NOTE(review): assumes the exporter writes <author> for tags -- confirm.
    # Rows without one fall back to the legacy value, so they import as before.
    author_el = row.get('author')
    created_by = author_el.as_int() if author_el is not None else used
    created_by = users_map.get(created_by, created_by)

    tag_name = row.getc('name')
    if tag_merge:
        tag_name = tag_merge.get(tag_name, tag_name) or tag_name

    if is_merge and tag_name in tag_mappings:
        tag = tag_mappings[tag_name]
        tag.used_count += used
    else:
        tag = orm.Tag(name=tag_name, used_count=used, created_by_id=created_by)
        tag_mappings[tag.name] = tag

    tag.save()
434
def pre_node_import(user, data):
    """Give the nodes step a fresh exported-id -> new-id map in *data*."""
    data.update(nodes_map={})
437
@file_handler('nodes.xml', 'nodes', 'node', _('Nodes'), pre_callback=pre_node_import,
              args_handler=lambda u, d: [d['is_merge'], d['tag_merge'], d['tag_mappings'], d['nodes_map'], d['users_map']])
def node_import(row, is_merge, tag_merge, tags, nodes_map, users_map):
    """Import one <node> row with its tags and revision history.

    Exported user and node ids are translated through *users_map* and
    *nodes_map*; ids missing from a map are used unchanged.  The node's new
    id is recorded in *nodes_map*.  *tags* maps tag names to Tag objects
    that have already been imported.
    """
    ntags = []
    for tag_el in row.get('tags').get_list('tag'):
        tag_name = tag_el.content()
        ntags.append(tags[tag_merge and tag_merge.get(tag_name, tag_name) or tag_name])
    tagnames = " ".join([t.name for t in ntags])

    author = row.get('author').as_int()
    author_id = users_map.get(author, author)

    last_act = row.get('lastactivity')
    last_act_user = last_act.get('by').as_int(None)

    parent = row.get('parent').as_int(None)
    abs_parent = row.get('absparent').as_int(None)

    node = orm.Node(
            id            = (not is_merge) and row.getc('id') or None,
            node_type     = row.getc('type'),
            author_id     = author_id,
            added_at      = row.get('date').as_datetime(),
            parent_id     = nodes_map.get(parent, parent),
            abs_parent_id = nodes_map.get(abs_parent, abs_parent),
            score         = row.get('score').as_int(0),

            last_activity_by_id = last_act_user and users_map.get(last_act_user, last_act_user) or last_act_user,
            last_activity_at    = last_act.get('at').as_datetime(None),

            title         = row.getc('title'),
            body          = row.getc('body'),
            tagnames      = tagnames,

            marked        = row.get('marked').as_bool(),
            extra_ref_id  = row.get('extraRef').as_int(None),
            extra_count   = row.get('extraCount').as_int(0),
            extra         = row.get('extraData').as_pickled()
    )

    node.save()

    nodes_map[row.get('id').as_int()] = node.id

    node.tags = ntags

    revisions = row.get('revisions')
    active_number = revisions.get_attr('active').as_int()

    active = None

    if active_number != 0:
        latest = None

        for r in revisions.get_list('revision'):
            # Credit each revision to its own <author> when present (the old
            # code always used the node's author).  NOTE(review): assumes the
            # exporter writes <author> inside <revision> -- confirm.
            rev_author_el = r.get('author')
            rev_author = rev_author_el.as_int() if rev_author_el is not None else author

            rev = orm.NodeRevision(
                author_id = users_map.get(rev_author, rev_author),
                body = r.getc('body'),
                node = node,
                revised_at = r.get('date').as_datetime(),
                revision = r.get('number').as_int(),
                summary = r.getc('summary'),
                # Trim whitespace around comma-separated names and drop empty
                # entries, so "a, b" yields "a b" rather than "a  b".
                tagnames = " ".join([t.strip() for t in r.getc('tags').split(',') if t.strip()]),
                title = r.getc('title'),
            )

            rev.save()
            latest = rev

            if rev.revision == active_number:
                active = rev

        # If no imported revision carries the active number, fall back to the
        # last one rather than assigning a bare int to active_revision.
        if active is None:
            active = latest

    if active is None:
        # No revision history was imported: synthesize the initial revision.
        active = orm.NodeRevision(
            author_id = author_id,
            body = row.getc('body'),
            node = node,
            revised_at = row.get('date').as_datetime(),
            revision = 1,
            summary = _('Initial revision'),
            tagnames = tagnames,
            title = row.getc('title'),
        )

        active.save()

    node.active_revision = active
    node.save()
521
# Action type -> handler that actions_import runs for non-canceled actions
# of that type.
POST_ACTION = {}


def post_action(*types):
    """Decorator registering the function in POST_ACTION under each of *types*.

    The function itself is returned unchanged.  Registering the same type
    again keeps the later handler.
    """
    def register(fn):
        POST_ACTION.update(dict.fromkeys(types, fn))
        return fn
    return register
530
def pre_action_import_callback(user, data):
    """Give the actions step a fresh exported-id -> new-id map in *data*."""
    data.update(actions_map={})
533
def post_action_import_callback():
    """Recompute ``state_string`` for every node that has NodeState rows.

    The string is each state type wrapped in parentheses, e.g. "(wiki)(deleted)".
    """
    node_ids = orm.NodeState.objects.values_list('node_id', flat=True).distinct()

    for node in orm.Node.objects.filter(id__in=node_ids):
        # flat=True yields plain strings.  The old code formatted 1-tuples and
        # only worked because "%s" % (x,) unpacks the tuple.
        state_types = node.states.values_list('state_type', flat=True)
        node.state_string = "".join(["(%s)" % state_type for state_type in state_types])
        node.save()
540
@file_handler('actions.xml', 'actions', 'action', _('Actions'), post_callback=post_action_import_callback,
              pre_callback=pre_action_import_callback, args_handler=lambda u, d: [d['nodes_map'], d['users_map'], d['actions_map']])
def actions_import(row, nodes, users, actions_map):
    """Import one <action> row together with its reputation changes.

    Exported node and user ids are translated through *nodes* and *users*;
    ids missing from a map are used unchanged.  The new action id is
    recorded in *actions_map*.  For non-canceled actions whose type is
    registered in POST_ACTION, that handler is run as well.
    """
    node = row.get('node').as_int(None)
    user = row.get('user').as_int()
    real_user = row.get('realUser').as_int(None)

    # Actions always get fresh ids; actions_map links exported ids to them.
    action = orm.Action(
        action_type  = row.getc('type'),
        action_date  = row.get('date').as_datetime(),
        node_id      = nodes.get(node, node),
        user_id      = users.get(user, user),
        real_user_id = users.get(real_user, real_user),
        ip           = row.getc('ip'),
        extra        = row.get('extraData').as_pickled(),
    )

    canceled = row.get('canceled')
    if canceled.get_attr('state').as_bool():
        by = canceled.get('user').as_int()
        action.canceled = True
        action.canceled_by_id = users.get(by, by)
        # Parse the date like every other timestamp instead of storing the raw
        # text, so the repute dates below are datetimes as well.
        canceled_date = canceled.get('date')
        action.canceled_at = canceled_date.as_datetime() if canceled_date is not None else None
        action.canceled_ip = canceled.getc('ip')

    action.save()

    actions_map[row.get('id').as_int()] = action.id

    for r in row.get('reputes').get_list('repute'):
        by_canceled = r.get_attr('byCanceled').as_bool()

        orm.ActionRepute(
            action = action,
            # Strict lookup: a repute for a user that was not imported raises KeyError.
            user_id = users[r.get('user').as_int()],
            value = r.get('value').as_int(),

            date = action.canceled_at if (by_canceled and action.canceled_at) else action.action_date,
            by_canceled = by_canceled
        ).save()

    if (not action.canceled) and (action.action_type in POST_ACTION):
        POST_ACTION[action.action_type](row, action, users, nodes, actions_map)
585
586
587
588
# (user_id, node_id) pairs that already received a Vote row, used to skip
# duplicate vote actions.  A set keeps the membership test O(1); the list
# used before made the vote import quadratic.  NOTE: it is module-level and
# never cleared, so it spans every import run in the same process.
persisted_votes = set()


@post_action('voteup', 'votedown', 'voteupcomment')
def vote_action(row, action, users, nodes, actions):
    """Create a Vote for an imported vote action, at most once per user and node.

    'votedown' becomes a value of -1 and every other registered vote type +1.
    """
    vote_key = (action.user_id, action.node_id)
    if vote_key in persisted_votes:
        return

    orm.Vote(user_id=action.user_id, node_id=action.node_id, action=action,
             voted_at=action.action_date, value=-1 if action.action_type == 'votedown' else 1).save()

    persisted_votes.add(vote_key)
601
602
def state_action(state):
    """Build a POST_ACTION handler that records NodeState *state* for the action's node.

    The returned handler does nothing if the node already has that state.
    """
    def record_state(row, action, users, nodes, actions):
        already_set = orm.NodeState.objects.filter(state_type = state, node = action.node_id).count()
        if not already_set:
            new_state = orm.NodeState(state_type = state, node_id = action.node_id, action = action)
            new_state.save()
    return record_state
614
# Action types whose import marks the target node with a NodeState.
for _action_type, _node_state in (('wikify', 'wiki'),
                                  ('delete', 'deleted'),
                                  ('acceptanswer', 'accepted'),
                                  ('publish', 'published')):
    post_action(_action_type)(state_action(_node_state))
del _action_type, _node_state
619
620
@post_action('flag')
def flag_action(row, action, users, nodes, actions):
    """Create a Flag for an imported flag action, using its extra data as the reason."""
    reason = action.extra if action.extra else ""
    flag = orm.Flag(user_id=action.user_id, node_id=action.node_id, action=action, reason=reason)
    flag.save()
624
625
def award_import_args(user, data):
    """Build awards_import's extra arguments: badges keyed by class name, then the id maps."""
    badges_by_cls = {}
    for badge in orm.Badge.objects.all():
        badges_by_cls[badge.cls] = badge
    return [badges_by_cls, data['nodes_map'], data['users_map'], data['actions_map']]
628
629
@file_handler('awards.xml', 'awards', 'award', _('Awards'), args_handler=award_import_args)
def awards_import(row, badges, nodes, users, actions):
    """Import one <award> row.

    Rows naming a badge class not present in *badges* are skipped, as are
    awards whose badge/user/node combination already exists.  Exported ids
    are translated through the node, user and action maps; unmapped ids are
    used unchanged.
    """
    badge_type = badges.get(row.getc('badge'), None)

    if not badge_type:
        return

    action = row.get('action').as_int(None)
    trigger = row.get('trigger').as_int(None)
    node = row.get('node').as_int(None)
    user = row.get('user').as_int()

    user_id = users.get(user, user)
    node_id = nodes.get(node, node)

    if orm.Award.objects.filter(badge=badge_type, user=user_id, node=node_id).count():
        return

    # .save() returns None, so the result is not bound to a name.
    orm.Award(
        user_id = user_id,
        badge = badge_type,
        node_id = node_id,
        action_id = actions.get(action, action),
        trigger_id = actions.get(trigger, trigger)
    ).save()
652
653
#@file_handler('settings.xml', 'settings', 'setting', _('Settings'))
def settings_import(row):
    """Store one <setting> row as a KeyValue entry, creating or updating it.

    The step is not registered: its decorator above is commented out.
    Previously the KeyValue was built but never saved, so even an enabled
    step would have done nothing.
    """
    key = row.getc('key')

    # NOTE(review): assumes KeyValue.key is unique -- confirm before enabling.
    try:
        key_value = orm.KeyValue.objects.get(key=key)
    except orm.KeyValue.DoesNotExist:
        key_value = orm.KeyValue(key=key)

    key_value.value = row.get('value').as_pickled()
    key_value.save()
657
658
659
660
661
662
663
664
665