]> git.openstreetmap.org Git - osqa.git/blobdiff - forum_modules/exporter/importer.py
allow only AJAX requests for post votes, otherwise it makes CSRF possible
[osqa.git] / forum_modules / exporter / importer.py
index bab6d216468bddaa881a1bd93c8833e4ce628eb2..c6b60ab0e78a7e1e9042eaa6ce7f38c746baf8c1 100644 (file)
@@ -1,3 +1,5 @@
+from __future__ import with_statement
+
 import os, tarfile, datetime, ConfigParser, logging
 
 from django.utils.translation import ugettext as _
 import os, tarfile, datetime, ConfigParser, logging
 
 from django.utils.translation import ugettext as _
@@ -12,10 +14,26 @@ from forum.templatetags.extra_tags import diff_date
 
 from exporter import TMP_FOLDER, DATETIME_FORMAT, DATE_FORMAT, META_INF_SECTION, CACHE_KEY
 from orm import orm
 
 from exporter import TMP_FOLDER, DATETIME_FORMAT, DATE_FORMAT, META_INF_SECTION, CACHE_KEY
 from orm import orm
-import commands
+import commands, settings
 
 NO_DEFAULT = object()
 
 
 NO_DEFAULT = object()
 
+import string
+
+class SafeReader():
+    def __init__(self, loc):
+        self.base = open(loc)
+
+    def read(self, *args):
+        return "".join(c for c in self.base.read(*args) if c in string.printable)
+
+    def readLine(self, *args):
+        return "".join(c for c in self.base.readLine(*args) if c in string.printable)
+
+    def close(self):
+        self.base.close()
+
+
 class ContentElement():
     def __init__(self, content):
         self._content = content
 class ContentElement():
     def __init__(self, content):
         self._content = content
@@ -199,9 +217,12 @@ def reset_sequences():
         db.execute_many(commands.PG_SEQUENCE_RESETS)
         db.commit_transaction()
 
         db.execute_many(commands.PG_SEQUENCE_RESETS)
         db.commit_transaction()
 
+def reset_fts_indexes():
+    pass
+
 FILE_HANDLERS = []
 
 FILE_HANDLERS = []
 
-def start_import(fname, user):
+def start_import(fname, tag_merge, user):
 
     start_time = datetime.datetime.now()
     steps = [s for s in FILE_HANDLERS]
 
     start_time = datetime.datetime.now()
     steps = [s for s in FILE_HANDLERS]
@@ -231,6 +252,11 @@ def start_import(fname, user):
         state['overall']['parsed'] += 1
         set_state()
 
         state['overall']['parsed'] += 1
         set_state()
 
+    data = {
+        'is_merge': True,
+        'tag_merge': tag_merge
+    }
+
     def run(fn, name):
         def ping():
             ping_state(name)
     def run(fn, name):
         def ping():
             ping_state(name)
@@ -239,7 +265,7 @@ def start_import(fname, user):
         state[name]['status'] = _('Importing')
 
 
         state[name]['status'] = _('Importing')
 
 
-        fn(TMP_FOLDER, user, ping)
+        fn(TMP_FOLDER, user, ping, data)
 
         state[name]['status'] = _('Done')
 
 
         state[name]['status'] = _('Done')
 
@@ -258,10 +284,11 @@ def start_import(fname, user):
         for h in FILE_HANDLERS:
             run(h['fn'], h['id'])
 
         for h in FILE_HANDLERS:
             run(h['fn'], h['id'])
 
-        raise Exception
         db.commit_transaction()
         enable_triggers()
 
         db.commit_transaction()
         enable_triggers()
 
+        settings.MERGE_MAPPINGS.set_value(dict(merged_nodes=data['nodes_map'], merged_users=data['users_map']))
+
         reset_sequences()
     except Exception, e:
         full_state['running'] = False
         reset_sequences()
     except Exception, e:
         full_state['running'] = False
@@ -273,12 +300,12 @@ def start_import(fname, user):
 
 def file_handler(file_name, root_tag, el_tag, name, args_handler=None, pre_callback=None, post_callback=None):
     def decorator(fn):
 
 def file_handler(file_name, root_tag, el_tag, name, args_handler=None, pre_callback=None, post_callback=None):
     def decorator(fn):
-        def decorated(location, current_user, ping):
+        def decorated(location, current_user, ping, data):
             if pre_callback:
             if pre_callback:
-                pre_callback(current_user)
+                pre_callback(current_user, data)
 
             if (args_handler):
 
             if (args_handler):
-                args = args_handler(current_user)
+                args = args_handler(current_user, data)
             else:
                 args = []
 
             else:
                 args = []
 
@@ -287,7 +314,7 @@ def file_handler(file_name, root_tag, el_tag, name, args_handler=None, pre_callb
             parser.setContentHandler(handler)
             #parser.setErrorHandler(SaxErrorHandler())
 
             parser.setContentHandler(handler)
             #parser.setErrorHandler(SaxErrorHandler())
 
-            parser.parse(os.path.join(location, file_name))
+            parser.parse(SafeReader(os.path.join(location, file_name)))
 
             if post_callback:
                 post_callback()
 
             if post_callback:
                 post_callback()
@@ -296,94 +323,148 @@ def file_handler(file_name, root_tag, el_tag, name, args_handler=None, pre_callb
         return decorated
     return decorator
 
         return decorated
     return decorator
 
+def verify_existence(row):
+    try:
+        return orm.User.objects.get(email=row.getc('email'))
+    except:
+        for key in row.get('authKeys').get_list('key'):
+            key = key=key.getc('key')
+
+            if not ("google.com" in key or "yahoo.com" in key):
+                try:
+                    return orm.AuthKeyUserAssociation.objects.get(key=key).user
+                except:
+                    pass
+
+    return None
 
 
-@file_handler('users.xml', 'users', 'user', _('Users'), args_handler=lambda u: [u])
-def user_import(row, current_user):
-    existent = False
+def user_import_pre_callback(user, data):
+    data['users_map'] = {}
 
 
-    if str(current_user.id) == row.getc('id'):
-        existent = True
+@file_handler('users.xml', 'users', 'user', _('Users'), pre_callback=user_import_pre_callback, args_handler=lambda u, d: [u, d['is_merge'], d['users_map']])
+def user_import(row, current_user, is_merge, users_map):
+    existent = is_merge and verify_existence(row) or None
 
     roles = row.get('roles').get_listc('role')
     valid_email = row.get('email').get_attr('validated').as_bool()
     badges = row.get('badges')
 
 
     roles = row.get('roles').get_listc('role')
     valid_email = row.get('email').get_attr('validated').as_bool()
     badges = row.get('badges')
 
-    user = orm.User(
-            id           = row.getc('id'),
-            username     = row.getc('username'),
-            password     = row.getc('password'),
-            email        = row.getc('email'),
-            email_isvalid= valid_email,
-            is_superuser = 'superuser' in roles,
-            is_staff     = 'moderator' in roles,
-            is_active    = True,
-            date_joined  = row.get('joindate').as_datetime(),
-            about         = row.getc('bio'),
-            date_of_birth = row.get('birthdate').as_date(None),
-            website       = row.getc('website'),
-            reputation    = row.get('reputation').as_int(),
-            gold          = badges.get_attr('gold').as_int(),
-            silver        = badges.get_attr('silver').as_int(),
-            bronze        = badges.get_attr('bronze').as_int(),
-            real_name     = row.getc('realname'),
-            location      = row.getc('location'),
-    )
+    if existent:
+        user = existent
+
+        user.reputation += row.get('reputation').as_int()
+        user.gold += badges.get_attr('gold').as_int()
+        user.silver += badges.get_attr('gold').as_int()
+        user.bronze += badges.get_attr('gold').as_int()
+
+    else:
+        username = row.getc('username')
+
+        if is_merge:
+            username_count = 0
+
+            while orm.User.objects.filter(username=username).count():
+                username_count += 1
+                username = "%s %s" % (row.getc('username'), username_count)
+
+        user = orm.User(
+                id           = (not is_merge) and row.getc('id') or None,
+                username     = username,
+                password     = row.getc('password'),
+                email        = row.getc('email'),
+                email_isvalid= valid_email,
+                is_superuser = (not is_merge) and 'superuser' in roles,
+                is_staff     = ('moderator' in roles) or (is_merge and 'superuser' in roles),
+                is_active    = row.get('active').as_bool(),
+                date_joined  = row.get('joindate').as_datetime(),
+                about         = row.getc('bio'),
+                date_of_birth = row.get('birthdate').as_date(None),
+                website       = row.getc('website'),
+                reputation    = row.get('reputation').as_int(),
+                gold          = badges.get_attr('gold').as_int(),
+                silver        = badges.get_attr('silver').as_int(),
+                bronze        = badges.get_attr('bronze').as_int(),
+                real_name     = row.getc('realname'),
+                location      = row.getc('location'),
+        )
 
     user.save()
 
 
     user.save()
 
+    users_map[row.get('id').as_int()] = user.id
+
     authKeys = row.get('authKeys')
 
     for key in authKeys.get_list('key'):
     authKeys = row.get('authKeys')
 
     for key in authKeys.get_list('key'):
-        orm.AuthKeyUserAssociation(user=user, key=key.getc('key'), provider=key.getc('provider')).save()
+        if (not is_merge) or orm.AuthKeyUserAssociation.objects.filter(key=key.getc('key')).count() == 0:
+            orm.AuthKeyUserAssociation(user=user, key=key.getc('key'), provider=key.getc('provider')).save()
 
 
-    notifications = row.get('notifications')
+    if not existent:
+        notifications = row.get('notifications')
 
 
-    attributes = dict([(str(k), v.as_bool() and 'i' or 'n') for k, v in notifications.get('notify').attrs.items()])
-    attributes.update(dict([(str(k), v.as_bool()) for k, v in notifications.get('autoSubscribe').attrs.items()]))
-    attributes.update(dict([(str("notify_%s" % k), v.as_bool()) for k, v in notifications.get('notifyOnSubscribed').attrs.items()]))
+        attributes = dict([(str(k), v.as_bool() and 'i' or 'n') for k, v in notifications.get('notify').attrs.items()])
+        attributes.update(dict([(str(k), v.as_bool()) for k, v in notifications.get('autoSubscribe').attrs.items()]))
+        attributes.update(dict([(str("notify_%s" % k), v.as_bool()) for k, v in notifications.get('notifyOnSubscribed').attrs.items()]))
 
 
-    ss = orm.SubscriptionSettings(user=user, enable_notifications=notifications.get_attr('enabled').as_bool(), **attributes)
+        ss = orm.SubscriptionSettings(user=user, enable_notifications=notifications.get_attr('enabled').as_bool(), **attributes)
 
 
-    if existent:
-        ss.id = current_user.subscription_settings.id
+        if current_user.id == row.get('id').as_int():
+            ss.id = current_user.subscription_settings.id
 
 
-    ss.save()
+        ss.save()
         
 
         
 
-def pre_tag_import(user):
-    tag_import.tag_mappings={}
+def pre_tag_import(user, data):
+    data['tag_mappings'] = dict([ (t.name, t) for t in orm.Tag.objects.all() ])
 
 
 
 
-@file_handler('tags.xml', 'tags', 'tag', _('Tags'), pre_callback=pre_tag_import)
-def tag_import(row):
-    tag = orm.Tag(name=row.getc('name'), used_count=row.get('used').as_int(), created_by_id=row.get('author').as_int())
-    tag.save()
-    tag_import.tag_mappings[tag.name] = tag
+@file_handler('tags.xml', 'tags', 'tag', _('Tags'), pre_callback=pre_tag_import, args_handler=lambda u, d: [d['is_merge'], d['tag_merge'], d['users_map'], d['tag_mappings']])
+def tag_import(row, is_merge, tag_merge, users_map, tag_mappings):
+    created_by = row.get('used').as_int()
+    created_by = users_map.get(created_by, created_by)
+
+    tag_name = row.getc('name')
+    tag_name = tag_merge and tag_merge.get(tag_name, tag_name) or tag_name
 
 
+    if is_merge and tag_name in tag_mappings:
+        tag = tag_mappings[tag_name]
+        tag.used_count += row.get('used').as_int()
+    else:
+        tag = orm.Tag(name=tag_name, used_count=row.get('used').as_int(), created_by_id=created_by)
+        tag_mappings[tag.name] = tag
 
 
-def post_node_import():
-    tag_import.tag_mappings = None
+    tag.save()
+
+def pre_node_import(user, data):
+    data['nodes_map'] = {}
 
 
-@file_handler('nodes.xml', 'nodes', 'node', _('Nodes'), args_handler=lambda u: [tag_import.tag_mappings], post_callback=post_node_import)
-def node_import(row, tags):
+@file_handler('nodes.xml', 'nodes', 'node', _('Nodes'), pre_callback=pre_node_import,
+              args_handler=lambda u, d: [d['is_merge'], d['tag_merge'], d['tag_mappings'], d['nodes_map'], d['users_map']])
+def node_import(row, is_merge, tag_merge, tags, nodes_map, users_map):
 
     ntags = []
 
     for t in row.get('tags').get_list('tag'):
 
     ntags = []
 
     for t in row.get('tags').get_list('tag'):
-        ntags.append(tags[t.content()])
+        t = t.content()
+        ntags.append(tags[tag_merge and tag_merge.get(t, t) or t])
+
+    author = row.get('author').as_int()
 
     last_act = row.get('lastactivity')
 
     last_act = row.get('lastactivity')
+    last_act_user = last_act.get('by').as_int(None)
+
+    parent = row.get('parent').as_int(None)
+    abs_parent = row.get('absparent').as_int(None)
 
     node = orm.Node(
 
     node = orm.Node(
-            id            = row.getc('id'),
+            id            = (not is_merge) and row.getc('id') or None,
             node_type     = row.getc('type'),
             node_type     = row.getc('type'),
-            author_id     = row.get('author').as_int(),
+            author_id     = users_map.get(author, author),
             added_at      = row.get('date').as_datetime(),
             added_at      = row.get('date').as_datetime(),
-            parent_id     = row.get('parent').as_int(None),
-            abs_parent_id = row.get('absparent').as_int(None),
+            parent_id     = nodes_map.get(parent, parent),
+            abs_parent_id = nodes_map.get(abs_parent, abs_parent),
             score         = row.get('score').as_int(0),
 
             score         = row.get('score').as_int(0),
 
-            last_activity_by_id = last_act.get('by').as_int(None),
+            last_activity_by_id = last_act_user and users_map.get(last_act_user, last_act_user) or last_act_user,
             last_activity_at    = last_act.get('at').as_datetime(None),
 
             title         = row.getc('title'),
             last_activity_at    = last_act.get('at').as_datetime(None),
 
             title         = row.getc('title'),
@@ -397,26 +478,45 @@ def node_import(row, tags):
     )
 
     node.save()
     )
 
     node.save()
+
+    nodes_map[row.get('id').as_int()] = node.id
+
     node.tags = ntags
 
     revisions = row.get('revisions')
     active = revisions.get_attr('active').as_int()
 
     node.tags = ntags
 
     revisions = row.get('revisions')
     active = revisions.get_attr('active').as_int()
 
-    for r in revisions.get_list('revision'):
-        rev = orm.NodeRevision(
-            author_id = r.getc('author'),
-            body = r.getc('body'),
+    if active == 0:
+        active = orm.NodeRevision(
+            author_id = node.author_id,
+            body = row.getc('body'),
             node = node,
             node = node,
-            revised_at = r.get('date').as_datetime(),
-            revision = r.get('number').as_int(),
-            summary = r.getc('summary'),
-            tagnames = " ".join(r.getc('tags').split(',')),
-            title = r.getc('title'),
+            revised_at = row.get('date').as_datetime(),
+            revision = 1,
+            summary = _('Initial revision'),
+            tagnames = " ".join([t.name for t in ntags]),
+            title = row.getc('title'),
         )
 
         )
 
-        rev.save()
-        if rev.revision == active:
-            active = rev
+        active.save()
+    else:
+        for r in revisions.get_list('revision'):
+            author = row.get('author').as_int()
+
+            rev = orm.NodeRevision(
+                author_id = users_map.get(author, author),
+                body = r.getc('body'),
+                node = node,
+                revised_at = r.get('date').as_datetime(),
+                revision = r.get('number').as_int(),
+                summary = r.getc('summary'),
+                tagnames = " ".join(r.getc('tags').split(',')),
+                title = r.getc('title'),
+            )
+
+            rev.save()
+            if rev.revision == active:
+                active = rev
 
     node.active_revision = active
     node.save()
 
     node.active_revision = active
     node.save()
@@ -430,6 +530,9 @@ def post_action(*types):
         return fn
     return decorator
 
         return fn
     return decorator
 
+def pre_action_import_callback(user, data):
+    data['actions_map'] = {}
+
 def post_action_import_callback():
     with_state = orm.Node.objects.filter(id__in=orm.NodeState.objects.values_list('node_id', flat=True).distinct())
 
 def post_action_import_callback():
     with_state = orm.Node.objects.filter(id__in=orm.NodeState.objects.values_list('node_id', flat=True).distinct())
 
@@ -437,34 +540,42 @@ def post_action_import_callback():
         n.state_string = "".join(["(%s)" % s for s in n.states.values_list('state_type')])
         n.save()
 
         n.state_string = "".join(["(%s)" % s for s in n.states.values_list('state_type')])
         n.save()
 
-@file_handler('actions.xml', 'actions', 'action', _('Actions'), post_callback=post_action_import_callback)
-def actions_import(row):
+@file_handler('actions.xml', 'actions', 'action', _('Actions'), post_callback=post_action_import_callback,
+              pre_callback=pre_action_import_callback, args_handler=lambda u, d: [d['nodes_map'], d['users_map'], d['actions_map']])
+def actions_import(row, nodes, users, actions_map):
+    node = row.get('node').as_int(None)
+    user = row.get('user').as_int()
+    real_user = row.get('realUser').as_int(None)
+
     action = orm.Action(
     action = orm.Action(
-        id           = row.get('id').as_int(),
+        #id           = row.get('id').as_int(),
         action_type  = row.getc('type'),
         action_date  = row.get('date').as_datetime(),
         action_type  = row.getc('type'),
         action_date  = row.get('date').as_datetime(),
-        node_id      = row.get('node').as_int(None),
-        user_id      = row.get('user').as_int(),
-        real_user_id = row.get('realUser').as_int(None),
+        node_id      = nodes.get(node, node),
+        user_id      = users.get(user, user),
+        real_user_id = users.get(real_user, real_user),
         ip           = row.getc('ip'),
         extra        = row.get('extraData').as_pickled(),
     )
 
     canceled = row.get('canceled')
     if canceled.get_attr('state').as_bool():
         ip           = row.getc('ip'),
         extra        = row.get('extraData').as_pickled(),
     )
 
     canceled = row.get('canceled')
     if canceled.get_attr('state').as_bool():
+        by = canceled.get('user').as_int()
         action.canceled = True
         action.canceled = True
-        action.canceled_by_id = canceled.get('user').as_int()
-        #action.canceled_at = canceled.get('date').as_datetime(),
+        action.canceled_by_id = users.get(by, by)
+        action.canceled_at = canceled.getc('date') #.as_datetime(),
         action.canceled_ip = canceled.getc('ip')
 
     action.save()
 
         action.canceled_ip = canceled.getc('ip')
 
     action.save()
 
+    actions_map[row.get('id').as_int()] = action.id
+
     for r in row.get('reputes').get_list('repute'):
         by_canceled = r.get_attr('byCanceled').as_bool()
 
         orm.ActionRepute(
             action = action,
     for r in row.get('reputes').get_list('repute'):
         by_canceled = r.get_attr('byCanceled').as_bool()
 
         orm.ActionRepute(
             action = action,
-            user_id = r.get('user').as_int(),
+            user_id = users[r.get('user').as_int()],
             value = r.get('value').as_int(),
 
             date = by_canceled and action.canceled_at or action.action_date,
             value = r.get('value').as_int(),
 
             date = by_canceled and action.canceled_at or action.action_date,
@@ -472,18 +583,30 @@ def actions_import(row):
         ).save()
 
     if (not action.canceled) and (action.action_type in POST_ACTION):
         ).save()
 
     if (not action.canceled) and (action.action_type in POST_ACTION):
-        POST_ACTION[action.action_type](row, action)
+        POST_ACTION[action.action_type](row, action, users, nodes, actions_map)
 
 
 
 
 
 
 
 
+# Record of all persisted votes.
+persisted_votes = []
 @post_action('voteup', 'votedown', 'voteupcomment')
 @post_action('voteup', 'votedown', 'voteupcomment')
-def vote_action(row, action):
-    orm.Vote(user_id=action.user_id, node_id=action.node_id, action=action,
-             voted_at=action.action_date, value=(action.action_type != 'votedown') and 1 or -1).save()
+def vote_action(row, action, users, nodes, actions):
+    # Check to see if the vote has already been registered.
+    if not (action.user_id, action.node_id) in persisted_votes:
+        # Persist the vote action.
+        orm.Vote(user_id=action.user_id, node_id=action.node_id, action=action,
+                 voted_at=action.action_date, value=(action.action_type != 'votedown') and 1 or -1).save()
+
+        # Record the vote action.  This will help us avoid duplicates.
+        persisted_votes.append((action.user_id, action.node_id))
+
 
 def state_action(state):
 
 def state_action(state):
-    def fn(row, action):
+    def fn(row, action, users, nodes, actions):
+        if orm.NodeState.objects.filter(state_type = state, node = action.node_id).count():
+            return
+
         orm.NodeState(
             state_type = state,
             node_id = action.node_id,
         orm.NodeState(
             state_type = state,
             node_id = action.node_id,
@@ -498,34 +621,39 @@ post_action('publish')(state_action('published'))
 
 
 @post_action('flag')
 
 
 @post_action('flag')
-def flag_action(row, action):
-    orm.Flag(user_id=action.user_id, node_id=action.node_id, action=action, reason=action.extra).save()
+def flag_action(row, action, users, nodes, actions):
+    orm.Flag(user_id=action.user_id, node_id=action.node_id, action=action, reason=action.extra or "").save()
 
 
 
 
-def award_import_args(user):
-    return [ dict([ (b.cls, b) for b in orm.Badge.objects.all() ]) ]
+def award_import_args(user, data):
+    return [ dict([ (b.cls, b) for b in orm.Badge.objects.all() ]) , data['nodes_map'], data['users_map'], data['actions_map']]
 
 
 @file_handler('awards.xml', 'awards', 'award', _('Awards'), args_handler=award_import_args)
 
 
 @file_handler('awards.xml', 'awards', 'award', _('Awards'), args_handler=award_import_args)
-def awards_import(row, badges):
-    try:
-        badge_type = badges.get(row.getc('badge'), None)
+def awards_import(row, badges, nodes, users, actions):
+    badge_type = badges.get(row.getc('badge'), None)
 
 
-        if not badge_type:
-            return
+    if not badge_type:
+        return
 
 
-        award = orm.Award(
-            user_id = row.get('user').as_int(),
-            badge = badges[row.getc('badge')],
-            node_id = row.get('node').as_int(None),
-            action_id = row.get('action').as_int(None),
-            trigger_id = row.get('trigger').as_int(None)
-        ).save()
-    except Exception, e:
+    action = row.get('action').as_int(None)
+    trigger = row.get('trigger').as_int(None)
+    node = row.get('node').as_int(None)
+    user = row.get('user').as_int()
+
+    if orm.Award.objects.filter(badge=badges[row.getc('badge')], user=users.get(user, user), node=nodes.get(node, node)).count():
         return
 
         return
 
+    award = orm.Award(
+        user_id = users.get(user, user),
+        badge = badge_type,
+        node_id = nodes.get(node, node),
+        action_id = actions.get(action, action),
+        trigger_id = actions.get(trigger, trigger)
+    ).save()
+
 
 
-@file_handler('settings.xml', 'settings', 'setting', _('Settings'))
+#@file_handler('settings.xml', 'settings', 'setting', _('Settings'))
 def settings_import(row):
     orm.KeyValue(key=row.getc('key'), value=row.get('value').as_pickled())
 
 def settings_import(row):
     orm.KeyValue(key=row.getc('key'), value=row.get('value').as_pickled())