]> git.openstreetmap.org Git - osqa.git/blob - forum_modules/exporter/importer.py
fix breach in award points that allows user to award infinite points
[osqa.git] / forum_modules / exporter / importer.py
1 from __future__ import with_statement
2
3 import os, tarfile, datetime, ConfigParser, logging
4
5 from django.utils.translation import ugettext as _
6 from django.core.cache import cache
7
8 from south.db import db
9
10 from xml.sax import make_parser
11 from xml.sax.handler import ContentHandler, ErrorHandler
12
13 from forum.templatetags.extra_tags import diff_date
14
15 from exporter import TMP_FOLDER, DATETIME_FORMAT, DATE_FORMAT, META_INF_SECTION, CACHE_KEY
16 from orm import orm
17 import commands, settings
18
# Sentinel meaning "no default was passed"; compared by identity so that None
# (or any other value) can still be supplied as an explicit default.
NO_DEFAULT = object()

import string

# Characters that may never appear in an XML 1.0 document: the C0 control
# characters other than tab, line feed and carriage return.
_XML_INVALID_CHARS = frozenset([chr(c) for c in range(0x20) if chr(c) not in '\t\n\r'])


class SafeReader():
    """Read-only file wrapper that strips characters illegal in XML 1.0.

    Only C0 control characters other than tab, LF and CR are removed, so the
    SAX parser does not abort on them.  Everything else, including non-ASCII
    text, is passed through unchanged.
    """

    def __init__(self, loc):
        self.base = open(loc)

    def _clean(self, text):
        """Return *text* without the characters in _XML_INVALID_CHARS."""
        return "".join([c for c in text if c not in _XML_INVALID_CHARS])

    def read(self, *args):
        """Read from the underlying file (same arguments as file.read) and clean the result."""
        return self._clean(self.base.read(*args))

    def readline(self, *args):
        """Read one line from the underlying file and clean it."""
        return self._clean(self.base.readline(*args))

    # Historical camel-case name, kept so existing callers keep working.
    readLine = readline

    def close(self):
        """Close the underlying file."""
        self.base.close()
35
36
class ContentElement():
    """Wrap a raw text value from the XML dump and convert it on demand.

    The ``as_*`` converters are lenient: when the text cannot be converted
    they return a fallback value instead of raising.
    """

    def __init__(self, content):
        self._content = content

    def content(self):
        """Return the wrapped text with surrounding whitespace removed."""
        return self._content.strip()

    def as_bool(self):
        """Return True only when the text is exactly ``"true"``."""
        return self.content() == "true"

    def as_date(self, default=NO_DEFAULT):
        """Parse the text with DATE_FORMAT and return a ``datetime.date``.

        On failure return *default*, or the epoch date when no default was given.
        """
        try:
            # .date() so the success path returns the same type as the fallback.
            return datetime.datetime.strptime(self.content(), DATE_FORMAT).date()
        except (ValueError, TypeError, AttributeError):
            if default is NO_DEFAULT:
                return datetime.date.fromtimestamp(0)
            return default

    def as_datetime(self, default=NO_DEFAULT):
        """Parse the text with DATETIME_FORMAT and return a ``datetime.datetime``.

        On failure return *default*, or the epoch datetime when no default was given.
        """
        try:
            return datetime.datetime.strptime(self.content(), DATETIME_FORMAT)
        except (ValueError, TypeError, AttributeError):
            if default is NO_DEFAULT:
                return datetime.datetime.fromtimestamp(0)
            return default

    def as_int(self, default=0):
        """Return the text as an int, or *default* when it is not a valid integer."""
        try:
            return int(self.content())
        except (ValueError, TypeError, AttributeError):
            return default

    def __str__(self):
        return self.content()
74
75
class RowElement(ContentElement):
    """One XML element of a dump row: its attributes, text and children.

    Tag and attribute names are lower-cased.  Attribute values are wrapped in
    ContentElement, text accumulates in ``_content``, and child elements are
    grouped by tag name in ``sub_elements`` in document order.  Creating an
    element with a *parent* registers it as that parent's child.
    """

    def __init__(self, name, attrs, parent=None):
        self.name = name.lower()
        self.parent = parent
        self.attrs = dict([(k.lower(), ContentElement(v)) for k, v in attrs.items()])
        self._content = u''
        self.sub_elements = {}

        if parent is not None:
            parent.add(self)

    def add_to_content(self, ch):
        """Append a chunk of character data to this element's text."""
        self._content += unicode(ch)

    def add(self, sub):
        """Register *sub* as a child, after any earlier children with the same name."""
        self.sub_elements.setdefault(sub.name, []).append(sub)

    def get(self, name, default=None):
        """Return the last child named *name* (case-insensitive), or *default*."""
        return self.sub_elements.get(name.lower(), [default])[-1]

    def get_list(self, name):
        """Return every child named *name* (case-insensitive); empty list if none."""
        return self.sub_elements.get(name.lower(), [])

    def get_listc(self, name):
        """Return the stripped text of every child named *name*."""
        return [r.content() for r in self.get_list(name)]

    def getc(self, name, default=""):
        """Return the stripped text of the last child named *name*, or *default*."""
        el = self.get(name, None)

        if el is None:
            return default
        return el.content()

    def get_attr(self, name, default=""):
        """Return attribute *name* (case-insensitive) as a ContentElement, or *default*."""
        return self.attrs.get(name.lower(), default)

    def as_pickled(self, default=None):
        """Decode this element's typed <value> child (see _as_pickled).

        Returns *default* when there is no <value> child.
        """
        value_el = self.get('value')

        if value_el is None:
            return default
        return value_el._as_pickled(default)

    # Scalar type names accepted in a value's ``type`` attribute.
    TYPES_MAP = dict([(c.__name__, c) for c in (int, long, str, unicode, float)])

    def _as_pickled(self, default=None):
        """Decode this element according to its ``type`` attribute.

        ``dict``: one entry per <item> child, keyed by the text of the item's
        ``key`` attribute and valued by the item's decoded <value>.
        ``list``: the decoded <value> of each <item>.
        ``bool``: whether the text equals "true" (case-insensitive).
        A name in TYPES_MAP: that type applied to the text.
        Anything else, including a missing ``type``: the text itself.
        Returns *default* if decoding raises.
        """
        type_el = self.get_attr('type', None)
        value_type = type_el.content() if type_el is not None else ''

        try:
            if value_type == 'dict':
                entries = []
                for item in self.get_list('item'):
                    key_el = item.get_attr('key', None)
                    # Use the key's text: ContentElement objects hash by
                    # identity and would make the dict unusable.
                    key = key_el.content() if key_el is not None else u''
                    entries.append((key, item.as_pickled()))
                return dict(entries)
            elif value_type == 'list':
                return [item.as_pickled() for item in self.get_list('item')]
            elif value_type == 'bool':
                return self.content().lower() == 'true'
            elif value_type in RowElement.TYPES_MAP:
                return RowElement.TYPES_MAP[value_type](self.content())
            else:
                return self.content()
        except Exception:
            return default
145
146
147
148
class TableHandler(ContentHandler):
    """SAX handler that builds one RowElement per row of a table dump.

    ``root_name`` is the document's wrapper tag and ``row_name`` the tag of a
    single row (both compared case-insensitively).  Any other tag becomes a
    child of the element currently being built.  When a row closes,
    ``callback(row, *callback_args)`` is called, then ``ping()`` if given.
    """

    def __init__(self, root_name, row_name, callback, callback_args=None, ping=None):
        ContentHandler.__init__(self)
        self.root_name = root_name.lower()
        self.row_name = row_name.lower()
        self.callback = callback
        # None default avoids sharing one mutable list between instances.
        if callback_args is None:
            callback_args = []
        self.callback_args = callback_args
        self.ping = ping

        self._reset()

    def _reset(self):
        """Forget the element under construction."""
        self.curr_element = None
        self.in_tag = None

    def startElement(self, name, attrs):
        name = name.lower()

        if name == self.root_name:
            # The document wrapper carries no row data.
            return
        if name == self.row_name:
            self.curr_element = RowElement(name, attrs)
        else:
            # Any other tag is a (possibly nested) field of the current row.
            self.curr_element = RowElement(name, attrs, self.curr_element)

    def characters(self, ch):
        if self.curr_element is not None:
            self.curr_element.add_to_content(ch)

    def endElement(self, name):
        name = name.lower()

        if name == self.root_name:
            return
        if name == self.row_name:
            self.callback(self.curr_element, *self.callback_args)
            if self.ping:
                self.ping()

            self._reset()
        elif self.curr_element is not None:
            # A nested field closed: continue filling its parent.
            self.curr_element = self.curr_element.parent
190
191
class SaxErrorHandler(ErrorHandler):
    """SAX error handler that turns every reported problem into an exception.

    Errors, fatal errors and warnings alike re-raise the exception object
    they receive, so parsing stops at the first issue.
    """

    def _propagate(self, exception):
        raise exception

    error = _propagate
    fatalError = _propagate
    warning = _propagate
201
def _run_postgres_commands(statements):
    """Execute *statements* in a transaction of their own, on PostgreSQL only.

    Does nothing on other database backends.  If execution or the commit
    fails, the transaction is rolled back before the error propagates.
    """
    if db.backend_name != "postgres":
        return

    db.start_transaction()
    committed = False
    try:
        db.execute_many(statements)
        db.commit_transaction()
        committed = True
    finally:
        if not committed:
            db.rollback_transaction()


def disable_triggers():
    """Run commands.PG_DISABLE_TRIGGERS (PostgreSQL only)."""
    _run_postgres_commands(commands.PG_DISABLE_TRIGGERS)


def enable_triggers():
    """Run commands.PG_ENABLE_TRIGGERS (PostgreSQL only)."""
    _run_postgres_commands(commands.PG_ENABLE_TRIGGERS)


def reset_sequences():
    """Run commands.PG_SEQUENCE_RESETS (PostgreSQL only)."""
    _run_postgres_commands(commands.PG_SEQUENCE_RESETS)


def reset_fts_indexes():
    """Currently a no-op."""
    pass


# Import steps registered by file_handler(), in registration order.
FILE_HANDLERS = []
224
225 def start_import(fname, tag_merge, user):
226
227     start_time = datetime.datetime.now()
228     steps = [s for s in FILE_HANDLERS]
229
230     with open(os.path.join(TMP_FOLDER, 'backup.inf'), 'r') as inffile:
231         inf = ConfigParser.SafeConfigParser()
232         inf.readfp(inffile)
233
234         state = dict([(s['id'], {
235             'status': _('Queued'), 'count': int(inf.get(META_INF_SECTION, s['id'])), 'parsed': 0
236         }) for s in steps] + [
237             ('overall', {
238                 'status': _('Starting'), 'count': int(inf.get(META_INF_SECTION, 'overall')), 'parsed': 0
239             })
240         ])
241
242     full_state = dict(running=True, state=state, time_started="")
243
244     def set_state():
245         full_state['time_started'] = diff_date(start_time)
246         cache.set(CACHE_KEY, full_state)
247
248     set_state()
249
250     def ping_state(name):
251         state[name]['parsed'] += 1
252         state['overall']['parsed'] += 1
253         set_state()
254
255     data = {
256         'is_merge': True,
257         'tag_merge': tag_merge
258     }
259
260     def run(fn, name):
261         def ping():
262             ping_state(name)
263
264         state['overall']['status'] = _('Importing %s') % s['name']
265         state[name]['status'] = _('Importing')
266
267
268         fn(TMP_FOLDER, user, ping, data)
269
270         state[name]['status'] = _('Done')
271
272         set_state()
273
274         return fname
275
276     #dump = tarfile.open(fname, 'r')
277     #dump.extractall(TMP_FOLDER)
278
279     try:
280
281         disable_triggers()
282         db.start_transaction()
283
284         for h in FILE_HANDLERS:
285             run(h['fn'], h['id'])
286
287         db.commit_transaction()
288         enable_triggers()
289
290         settings.MERGE_MAPPINGS.set_value(dict(merged_nodes=data['nodes_map'], merged_users=data['users_map']))
291
292         reset_sequences()
293     except Exception, e:
294         full_state['running'] = False
295         full_state['errors'] = "%s: %s" % (e.__class__.__name__, unicode(e))
296         set_state()
297
298         import traceback
299         logging.error("Error executing xml import: \n %s" % (traceback.format_exc()))
300
def file_handler(file_name, root_tag, el_tag, name, args_handler=None, pre_callback=None, post_callback=None):
    """Decorator factory that registers a row function as an import step.

    The step parses ``<location>/<file_name>`` through SafeReader and calls
    the decorated function as ``fn(row, *args)`` for every ``el_tag`` element
    inside ``root_tag``.  *args* comes from ``args_handler(current_user,
    data)`` and is empty when no handler is given.  ``pre_callback(current_user,
    data)`` runs before parsing and ``post_callback()`` after it.

    The step is appended to FILE_HANDLERS as ``dict(id=root_tag, name=name,
    fn=step)``.  The decorator returns that step function, whose signature is
    ``(location, current_user, ping, data)``; it does not return the original
    function.
    """
    def decorator(fn):
        def decorated(location, current_user, ping, data):
            if pre_callback:
                pre_callback(current_user, data)

            if args_handler:
                args = args_handler(current_user, data)
            else:
                args = []

            parser = make_parser()
            parser.setContentHandler(TableHandler(root_tag, el_tag, fn, args, ping))
            # The parser's default error handler is used (SaxErrorHandler is not installed).

            reader = SafeReader(os.path.join(location, file_name))
            try:
                parser.parse(reader)
            finally:
                # Always release the file, even when parsing fails.
                reader.close()

            if post_callback:
                post_callback()

        FILE_HANDLERS.append(dict(id=root_tag, name=name, fn=decorated))
        return decorated
    return decorator
325
def verify_existence(row):
    """Return the local user matching an imported <user> row, or None.

    First an exact e-mail match is tried, but only when the row's e-mail is
    non-empty.  Then each of the row's auth keys is tried, skipping keys that
    contain "google.com" or "yahoo.com".
    """
    email = row.getc('email')

    # An empty address would match any local user whose address is also empty.
    if email:
        try:
            return orm.User.objects.get(email=email)
        except (orm.User.DoesNotExist, orm.User.MultipleObjectsReturned):
            pass

    auth_keys = row.get('authKeys')
    if auth_keys is None:
        return None

    for key_el in auth_keys.get_list('key'):
        key = key_el.getc('key')

        # NOTE(review): presumably Google/Yahoo OpenID identifiers are not
        # stable across sites, so they are not trusted for matching — confirm.
        if not key or "google.com" in key or "yahoo.com" in key:
            continue

        try:
            return orm.AuthKeyUserAssociation.objects.get(key=key).user
        except (orm.AuthKeyUserAssociation.DoesNotExist,
                orm.AuthKeyUserAssociation.MultipleObjectsReturned,
                orm.User.DoesNotExist):
            pass

    return None
340
def user_import_pre_callback(user, data):
    """Give the users step a fresh, empty exported-id -> local-id map.

    *user* (the importing user) is not used.
    """
    data.update(users_map={})
343
@file_handler('users.xml', 'users', 'user', _('Users'), pre_callback=user_import_pre_callback, args_handler=lambda u, d: [u, d['is_merge'], d['users_map']])
def user_import(row, current_user, is_merge, users_map):
    """Import one <user> row from users.xml.

    When merging and the row matches an existing local user (see
    verify_existence), that user's reputation and gold/silver/bronze badge
    counts are increased by the row's values.

    Otherwise a new user is created.  When merging, a clashing username gets
    a numeric suffix and the exported id is not reused.

    In both cases the exported id -> local id pair is stored in *users_map*
    and the row's auth keys are attached; when merging, only keys not already
    associated with some user are attached.  New users also get
    SubscriptionSettings built from the row's <notifications> element.
    """
    existent = is_merge and verify_existence(row) or None

    roles = row.get('roles').get_listc('role')
    valid_email = row.get('email').get_attr('validated').as_bool()
    badges = row.get('badges')

    if existent:
        user = existent

        user.reputation += row.get('reputation').as_int()
        # Each badge class is accumulated from its own attribute.
        user.gold += badges.get_attr('gold').as_int()
        user.silver += badges.get_attr('silver').as_int()
        user.bronze += badges.get_attr('bronze').as_int()

    else:
        username = row.getc('username')

        if is_merge:
            username_count = 0

            while orm.User.objects.filter(username=username).count():
                username_count += 1
                username = "%s %s" % (row.getc('username'), username_count)

        user = orm.User(
                id           = (not is_merge) and row.getc('id') or None,
                username     = username,
                password     = row.getc('password'),
                email        = row.getc('email'),
                email_isvalid= valid_email,
                is_superuser = (not is_merge) and 'superuser' in roles,
                is_staff     = ('moderator' in roles) or (is_merge and 'superuser' in roles),
                is_active    = row.get('active').as_bool(),
                date_joined  = row.get('joindate').as_datetime(),
                about         = row.getc('bio'),
                date_of_birth = row.get('birthdate').as_date(None),
                website       = row.getc('website'),
                reputation    = row.get('reputation').as_int(),
                gold          = badges.get_attr('gold').as_int(),
                silver        = badges.get_attr('silver').as_int(),
                bronze        = badges.get_attr('bronze').as_int(),
                real_name     = row.getc('realname'),
                location      = row.getc('location'),
        )

    user.save()

    users_map[row.get('id').as_int()] = user.id

    authKeys = row.get('authKeys')

    for key in authKeys.get_list('key'):
        if (not is_merge) or orm.AuthKeyUserAssociation.objects.filter(key=key.getc('key')).count() == 0:
            orm.AuthKeyUserAssociation(user=user, key=key.getc('key'), provider=key.getc('provider')).save()

    if not existent:
        notifications = row.get('notifications')

        # <notify> flags become 'i'/'n' values; <autoSubscribe> and
        # <notifyOnSubscribed> flags become booleans.
        attributes = dict([(str(k), v.as_bool() and 'i' or 'n') for k, v in notifications.get('notify').attrs.items()])
        attributes.update(dict([(str(k), v.as_bool()) for k, v in notifications.get('autoSubscribe').attrs.items()]))
        attributes.update(dict([(str("notify_%s" % k), v.as_bool()) for k, v in notifications.get('notifyOnSubscribed').attrs.items()]))

        ss = orm.SubscriptionSettings(user=user, enable_notifications=notifications.get_attr('enabled').as_bool(), **attributes)

        # Reuse the importing user's existing settings row when the exported
        # id equals the importing user's id.
        if current_user.id == row.get('id').as_int():
            ss.id = current_user.subscription_settings.id

        ss.save()
414         
415
def pre_tag_import(user, data):
    """Seed data['tag_mappings'] with every existing Tag, keyed by its name."""
    by_name = {}
    for existing in orm.Tag.objects.all():
        by_name[existing.name] = existing
    data['tag_mappings'] = by_name
418
419
@file_handler('tags.xml', 'tags', 'tag', _('Tags'), pre_callback=pre_tag_import, args_handler=lambda u, d: [d['is_merge'], d['tag_merge'], d['users_map'], d['tag_mappings']])
def tag_import(row, is_merge, tag_merge, users_map, tag_mappings):
    """Import one <tag> row from tags.xml.

    The tag name is translated through *tag_merge*.  When merging into an
    existing tag of that name, its use count is increased.  Otherwise a new
    Tag is created and recorded in *tag_mappings*.
    """
    used = row.get('used').as_int()

    # NOTE(review): assumes the exporter writes the creator's id as <author>
    # — confirm against the exporter.  Rows without it keep the previous
    # behaviour of reading <used> (the use count), which is not a user id.
    author_el = row.get('author')
    if author_el is not None:
        created_by = author_el.as_int()
    else:
        created_by = used
    created_by = users_map.get(created_by, created_by)

    tag_name = row.getc('name')
    tag_name = tag_merge and tag_merge.get(tag_name, tag_name) or tag_name

    if is_merge and tag_name in tag_mappings:
        tag = tag_mappings[tag_name]
        tag.used_count += used
    else:
        tag = orm.Tag(name=tag_name, used_count=used, created_by_id=created_by)
        tag_mappings[tag.name] = tag

    tag.save()
436
def pre_node_import(user, data):
    """Give the nodes step a fresh, empty exported-id -> local-id map.

    *user* (the importing user) is not used.
    """
    data.update(nodes_map={})
439
@file_handler('nodes.xml', 'nodes', 'node', _('Nodes'), pre_callback=pre_node_import,
              args_handler=lambda u, d: [d['is_merge'], d['tag_merge'], d['tag_mappings'], d['nodes_map'], d['users_map']])
def node_import(row, is_merge, tag_merge, tags, nodes_map, users_map):
    """Import one <node> row from nodes.xml together with its revisions.

    User and node ids are translated through *users_map* / *nodes_map*;
    unmapped ids are used as-is.  Tag names go through *tag_merge*, and the
    new node id is recorded in *nodes_map*.

    The active revision is the exported revision whose number equals the
    ``<revisions active="...">`` attribute.  If no revision has that number,
    the last exported revision is used.  When the attribute is 0 or missing,
    or no revisions were exported, a single initial revision is built from
    the node's own title, body and tags.
    """
    ntags = []

    for t in row.get('tags').get_list('tag'):
        t = t.content()
        ntags.append(tags[tag_merge and tag_merge.get(t, t) or t])

    author = row.get('author').as_int()

    last_act = row.get('lastactivity')
    last_act_user = last_act.get('by').as_int(None)

    parent = row.get('parent').as_int(None)
    abs_parent = row.get('absparent').as_int(None)

    node = orm.Node(
            id            = (not is_merge) and row.getc('id') or None,
            node_type     = row.getc('type'),
            author_id     = users_map.get(author, author),
            added_at      = row.get('date').as_datetime(),
            parent_id     = nodes_map.get(parent, parent),
            abs_parent_id = nodes_map.get(abs_parent, abs_parent),
            score         = row.get('score').as_int(0),

            last_activity_by_id = last_act_user and users_map.get(last_act_user, last_act_user) or last_act_user,
            last_activity_at    = last_act.get('at').as_datetime(None),

            title         = row.getc('title'),
            body          = row.getc('body'),
            tagnames      = " ".join([t.name for t in ntags]),

            marked        = row.get('marked').as_bool(),
            extra_ref_id  = row.get('extraRef').as_int(None),
            extra_count   = row.get('extraCount').as_int(0),
            extra         = row.get('extraData').as_pickled()
    )

    node.save()

    nodes_map[row.get('id').as_int()] = node.id

    node.tags = ntags

    revisions = row.get('revisions')
    active_attr = revisions.get_attr('active', None)
    active_number = active_attr.as_int() if active_attr is not None else 0

    active = None

    if active_number != 0:
        last_rev = None

        for r in revisions.get_list('revision'):
            # Credit each revision to its own author when the export records
            # one; otherwise fall back to the node's author.
            rev_author_el = r.get('author')
            rev_author = rev_author_el.as_int(author) if rev_author_el is not None else author

            rev = orm.NodeRevision(
                author_id = users_map.get(rev_author, rev_author),
                body = r.getc('body'),
                node = node,
                revised_at = r.get('date').as_datetime(),
                revision = r.get('number').as_int(),
                summary = r.getc('summary'),
                tagnames = " ".join(r.getc('tags').split(',')),
                title = r.getc('title'),
            )

            rev.save()
            last_rev = rev

            if rev.revision == active_number:
                active = rev

        # The active number matched no revision: use the last one saved
        # rather than assigning an int to the foreign key.
        if active is None:
            active = last_rev

    if active is None:
        active = orm.NodeRevision(
            author_id = node.author_id,
            body = row.getc('body'),
            node = node,
            revised_at = row.get('date').as_datetime(),
            revision = 1,
            summary = _('Initial revision'),
            tagnames = " ".join([t.name for t in ntags]),
            title = row.getc('title'),
        )

        active.save()

    node.active_revision = active
    node.save()
523
# Maps an action type string to the handler run after such an action is imported.
POST_ACTION = {}


def post_action(*types):
    """Return a decorator that registers its function as the handler for each of *types*.

    A later registration for the same type replaces an earlier one.  The
    decorated function is returned unchanged.
    """
    def register(handler):
        POST_ACTION.update(dict.fromkeys(types, handler))
        return handler
    return register
532
def pre_action_import_callback(user, data):
    """Give the actions step a fresh, empty exported-id -> local-id map.

    *user* (the importing user) is not used.
    """
    data.update({'actions_map': {}})
535
def post_action_import_callback():
    """Rebuild ``state_string`` for every node that has at least one NodeState.

    The string is the concatenation of ``"(<state_type>)"`` for each of the
    node's states.
    """
    stateful_ids = orm.NodeState.objects.values_list('node_id', flat=True).distinct()

    for node in orm.Node.objects.filter(id__in=stateful_ids):
        state_types = node.states.values_list('state_type', flat=True)
        node.state_string = "".join(["(%s)" % state_type for state_type in state_types])
        node.save()
542
@file_handler('actions.xml', 'actions', 'action', _('Actions'), post_callback=post_action_import_callback,
              pre_callback=pre_action_import_callback, args_handler=lambda u, d: [d['nodes_map'], d['users_map'], d['actions_map']])
def actions_import(row, nodes, users, actions_map):
    """Import one <action> row from actions.xml with its reputation changes.

    Node and user ids are translated through *nodes* / *users*; unmapped ids
    are used as-is.  The new action id is recorded in *actions_map*.
    Cancellation data is copied when ``<canceled state="true">``.  If the
    action is not canceled and its type has a POST_ACTION handler, that
    handler runs afterwards.
    """
    node = row.get('node').as_int(None)
    user = row.get('user').as_int()
    real_user = row.get('realUser').as_int(None)

    # A fresh id is assigned; the exported id is only kept as an actions_map key.
    action = orm.Action(
        action_type  = row.getc('type'),
        action_date  = row.get('date').as_datetime(),
        node_id      = nodes.get(node, node),
        user_id      = users.get(user, user),
        real_user_id = users.get(real_user, real_user),
        ip           = row.getc('ip'),
        extra        = row.get('extraData').as_pickled(),
    )

    canceled = row.get('canceled')
    if canceled.get_attr('state').as_bool():
        by = canceled.get('user').as_int()
        action.canceled = True
        action.canceled_by_id = users.get(by, by)
        # Parse the timestamp like every other date in the dump instead of
        # storing the raw string.
        canceled_date = canceled.get('date')
        action.canceled_at = canceled_date.as_datetime() if canceled_date is not None else None
        action.canceled_ip = canceled.getc('ip')

    action.save()

    actions_map[row.get('id').as_int()] = action.id

    for r in row.get('reputes').get_list('repute'):
        by_canceled = r.get_attr('byCanceled').as_bool()
        repute_user = r.get('user').as_int()

        orm.ActionRepute(
            action = action,
            # Same lenient id mapping as every other user reference.
            user_id = users.get(repute_user, repute_user),
            value = r.get('value').as_int(),

            date = by_canceled and action.canceled_at or action.action_date,
            by_canceled = by_canceled
        ).save()

    if (not action.canceled) and (action.action_type in POST_ACTION):
        POST_ACTION[action.action_type](row, action, users, nodes, actions_map)
587
588
589
590
# (user_id, node_id) pairs that already received a Vote row here, so repeated
# vote actions for the same user and node create only one Vote.  A set keeps
# the membership test O(1) however many votes are imported.
# NOTE(review): module-level, so it is never cleared if several imports run in
# the same process.
persisted_votes = set()


@post_action('voteup', 'votedown', 'voteupcomment')
def vote_action(row, action, users, nodes, actions):
    """Create the Vote for an imported vote action, once per (user, node).

    'votedown' actions get value -1; the other registered types get +1.
    """
    vote_key = (action.user_id, action.node_id)
    if vote_key in persisted_votes:
        return

    orm.Vote(user_id=action.user_id, node_id=action.node_id, action=action,
             voted_at=action.action_date, value=(action.action_type != 'votedown') and 1 or -1).save()

    persisted_votes.add(vote_key)
603
604
def state_action(state):
    """Return a post-action handler that gives the action's node a NodeState of type *state*.

    The handler does nothing when the node already has a NodeState of that type.
    """
    def add_state(row, action, users, nodes, actions):
        existing = orm.NodeState.objects.filter(state_type = state, node = action.node_id).count()

        if not existing:
            new_state = orm.NodeState(state_type=state, node_id=action.node_id, action=action)
            new_state.save()

    return add_state

# Action types that mark their node with a state.
post_action('wikify')(state_action('wiki'))
post_action('delete')(state_action('deleted'))
post_action('acceptanswer')(state_action('accepted'))
post_action('publish')(state_action('published'))
621
622
@post_action('flag')
def flag_action(row, action, users, nodes, actions):
    """Create a Flag for an imported 'flag' action.

    The reason is the action's extra data, or "" when that is empty.
    """
    reason = action.extra or ""
    flag = orm.Flag(user_id=action.user_id, node_id=action.node_id, action=action, reason=reason)
    flag.save()
626
627
def award_import_args(user, data):
    """Build the extra arguments for awards_import.

    Returns the local badges keyed by ``cls``, followed by the node, user and
    action id maps collected by earlier steps.
    """
    badges_by_cls = {}
    for badge in orm.Badge.objects.all():
        badges_by_cls[badge.cls] = badge
    return [badges_by_cls, data['nodes_map'], data['users_map'], data['actions_map']]
630
631
@file_handler('awards.xml', 'awards', 'award', _('Awards'), args_handler=award_import_args)
def awards_import(row, badges, nodes, users, actions):
    """Import one <award> row from awards.xml.

    The row is skipped when its badge class is unknown locally, or when an
    Award with the same badge, user and node already exists.  User, node,
    action and trigger ids are translated through the id maps; unmapped ids
    are used as-is.
    """
    badge_type = badges.get(row.getc('badge'), None)
    if badge_type is None:
        return

    action = row.get('action').as_int(None)
    trigger = row.get('trigger').as_int(None)
    node = row.get('node').as_int(None)
    user = row.get('user').as_int()

    user_id = users.get(user, user)
    node_id = nodes.get(node, node)

    if orm.Award.objects.filter(badge=badge_type, user=user_id, node=node_id).count():
        return

    orm.Award(
        user_id = user_id,
        badge = badge_type,
        node_id = node_id,
        action_id = actions.get(action, action),
        trigger_id = actions.get(trigger, trigger)
    ).save()
654
655
656 #@file_handler('settings.xml', 'settings', 'setting', _('Settings'))
def settings_import(row):
    """Store one <setting> row as a KeyValue, updating any entry with the same key.

    NOTE(review): this step is not registered because its file_handler
    decorator is commented out.  The value is read with
    ``row.get('value').as_pickled()``, i.e. from a <value> nested inside the
    row's <value>.  Confirm this matches the exporter's layout before
    enabling the step.
    """
    key = row.getc('key')
    value = row.get('value').as_pickled()

    try:
        key_value = orm.KeyValue.objects.get(key=key)
    except orm.KeyValue.DoesNotExist:
        key_value = orm.KeyValue(key=key)

    key_value.value = value
    # The original built the object but never saved it.
    key_value.save()
659
660
661
662
663
664
665
666
667