git.openstreetmap.org Git - osqa.git/blobdiff - forum_modules/exporter/importer.py
More tweaks and improvements.
[osqa.git] / forum_modules / exporter / importer.py
index 35b42b17c9ef9625570109bd8daf92b59b0e99e7..bab6d216468bddaa881a1bd93c8833e4ce628eb2 100644 (file)
@@ -1,10 +1,18 @@
-import os, tarfile, datetime
+import os, tarfile, datetime, ConfigParser, logging
+
+from django.utils.translation import ugettext as _
+from django.core.cache import cache
+
+from south.db import db
 
 from xml.sax import make_parser
 from xml.sax.handler import ContentHandler, ErrorHandler
 
-from exporter import TMP_FOLDER, DATETIME_FORMAT, DATE_FORMAT
+from forum.templatetags.extra_tags import diff_date
+
+from exporter import TMP_FOLDER, DATETIME_FORMAT, DATE_FORMAT, META_INF_SECTION, CACHE_KEY
 from orm import orm
+import commands
 
 NO_DEFAULT = object()
 
@@ -52,14 +60,14 @@ class RowElement(ContentElement):
         self.name = name.lower()
         self.parent = parent
         self.attrs = dict([(k.lower(), ContentElement(v)) for k, v in attrs.items()])
-        self._content = ''
+        self._content = u''
         self.sub_elements = {}
 
         if parent:
             parent.add(self)
 
     def add_to_content(self, ch):
-        self._content += ch
+        # Append the chunk to the unicode buffer (initialised as u'' in
+        # __init__), coercing it to unicode first.
+        self._content += unicode(ch)
 
     def add(self, sub):
         curr = self.sub_elements.get(sub.name, None)
@@ -121,11 +129,12 @@ class RowElement(ContentElement):
 
 
 class TableHandler(ContentHandler):
-    def __init__(self, root_name, row_name, callback, callback_args = []):
+    def __init__(self, root_name, row_name, callback, callback_args = None, ping = None):
+        """Set up a handler that hands parsed row elements to a callback.
+
+        root_name, row_name -- tag names; both are stored lower-cased.
+        callback -- invoked as callback(element, *callback_args) for each
+            element whose name equals row_name.
+        callback_args -- extra positional arguments for callback; omitted or
+            None means no extra arguments.
+        ping -- optional zero-argument callable, invoked after each callback.
+        """
         self.root_name = root_name.lower()
         self.row_name = row_name.lower()
         self.callback = callback
-        self.callback_args = callback_args
+        # Build a fresh list per instance rather than sharing one mutable
+        # default argument between every handler.
+        self.callback_args = [] if callback_args is None else callback_args
+        self.ping = ping
 
         self._reset()
 
@@ -154,6 +163,9 @@ class TableHandler(ContentHandler):
             pass
         elif name == self.row_name:
             self.callback(self.curr_element, *self.callback_args)
+            if self.ping:
+                self.ping()
+
             self._reset()
         else:
             self.curr_element = self.curr_element.parent
@@ -169,18 +181,99 @@ class SaxErrorHandler(ErrorHandler):
     def warning(self, e):
+        # Re-raise the reported warning instead of ignoring it.
         raise e
 
+def disable_triggers():
+    """Execute commands.PG_DISABLE_TRIGGERS inside its own transaction.
+
+    Does nothing unless the South backend name is "postgres".
+    """
+    if db.backend_name != "postgres":
+        return
+
+    db.start_transaction()
+    db.execute_many(commands.PG_DISABLE_TRIGGERS)
+    db.commit_transaction()
+
+def enable_triggers():
+    """Execute commands.PG_ENABLE_TRIGGERS inside its own transaction.
+
+    Does nothing unless the South backend name is "postgres".
+    """
+    if db.backend_name != "postgres":
+        return
+
+    db.start_transaction()
+    db.execute_many(commands.PG_ENABLE_TRIGGERS)
+    db.commit_transaction()
+
+def reset_sequences():
+    """Execute commands.PG_SEQUENCE_RESETS inside its own transaction.
+
+    Does nothing unless the South backend name is "postgres".
+    """
+    if db.backend_name != "postgres":
+        return
+
+    db.start_transaction()
+    db.execute_many(commands.PG_SEQUENCE_RESETS)
+    db.commit_transaction()
+
+# Registry of import steps: the file_handler decorator appends one entry per
+# decorated function, and start_import runs the entries in that order.
 FILE_HANDLERS = []
 
 def start_import(fname, user):
+
+    start_time = datetime.datetime.now()
+    steps = [s for s in FILE_HANDLERS]
+
+    with open(os.path.join(TMP_FOLDER, 'backup.inf'), 'r') as inffile:
+        inf = ConfigParser.SafeConfigParser()
+        inf.readfp(inffile)
+
+        state = dict([(s['id'], {
+            'status': _('Queued'), 'count': int(inf.get(META_INF_SECTION, s['id'])), 'parsed': 0
+        }) for s in steps] + [
+            ('overall', {
+                'status': _('Starting'), 'count': int(inf.get(META_INF_SECTION, 'overall')), 'parsed': 0
+            })
+        ])
+
+    full_state = dict(running=True, state=state, time_started="")
+
+    def set_state():
+        full_state['time_started'] = diff_date(start_time)
+        cache.set(CACHE_KEY, full_state)
+
+    set_state()
+
+    def ping_state(name):
+        state[name]['parsed'] += 1
+        state['overall']['parsed'] += 1
+        set_state()
+
+    def run(fn, name):
+        def ping():
+            ping_state(name)
+
+        state['overall']['status'] = _('Importing %s') % s['name']
+        state[name]['status'] = _('Importing')
+
+
+        fn(TMP_FOLDER, user, ping)
+
+        state[name]['status'] = _('Done')
+
+        set_state()
+
+        return fname
+
     #dump = tarfile.open(fname, 'r')
     #dump.extractall(TMP_FOLDER)
 
-    for h in FILE_HANDLERS:
-        h(TMP_FOLDER, user)
+    try:
+
+        disable_triggers()
+        db.start_transaction()
+
+        for h in FILE_HANDLERS:
+            run(h['fn'], h['id'])
+
+        raise Exception
+        db.commit_transaction()
+        enable_triggers()
+
+        reset_sequences()
+    except Exception, e:
+        full_state['running'] = False
+        full_state['errors'] = "%s: %s" % (e.__class__.__name__, unicode(e))
+        set_state()
+
+        import traceback
+        logging.error("Error executing xml import: \n %s" % (traceback.format_exc()))
 
-def file_handler(file_name, root_tag, el_tag, args_handler=None, pre_callback=None, post_callback=None):
+def file_handler(file_name, root_tag, el_tag, name, args_handler=None, pre_callback=None, post_callback=None):
     def decorator(fn):
-        def decorated(location, current_user):
+        def decorated(location, current_user, ping):
             if pre_callback:
                 pre_callback(current_user)
 
@@ -190,7 +283,7 @@ def file_handler(file_name, root_tag, el_tag, args_handler=None, pre_callback=No
                 args = []
 
             parser = make_parser()
-            handler = TableHandler(root_tag, el_tag, fn, args)
+            handler = TableHandler(root_tag, el_tag, fn, args, ping)
             parser.setContentHandler(handler)
             #parser.setErrorHandler(SaxErrorHandler())
 
@@ -199,15 +292,17 @@ def file_handler(file_name, root_tag, el_tag, args_handler=None, pre_callback=No
             if post_callback:
                 post_callback()
 
-        FILE_HANDLERS.append(decorated)
+        FILE_HANDLERS.append(dict(id=root_tag, name=name, fn=decorated))
         return decorated
     return decorator
 
 
-@file_handler('users.xml', 'users', 'user', args_handler=lambda u: [u])
+@file_handler('users.xml', 'users', 'user', _('Users'), args_handler=lambda u: [u])
 def user_import(row, current_user):
+    existent = False
+
     if str(current_user.id) == row.getc('id'):
-        return
+        existent = True
 
     roles = row.get('roles').get_listc('role')
     valid_email = row.get('email').get_attr('validated').as_bool()
@@ -247,13 +342,19 @@ def user_import(row, current_user):
     attributes.update(dict([(str(k), v.as_bool()) for k, v in notifications.get('autoSubscribe').attrs.items()]))
     attributes.update(dict([(str("notify_%s" % k), v.as_bool()) for k, v in notifications.get('notifyOnSubscribed').attrs.items()]))
 
-    orm.SubscriptionSettings(user=user, enable_notifications=notifications.get_attr('enabled').as_bool(), **attributes).save()
+    ss = orm.SubscriptionSettings(user=user, enable_notifications=notifications.get_attr('enabled').as_bool(), **attributes)
+
+    if existent:
+        ss.id = current_user.subscription_settings.id
+
+    ss.save()
+        
 
 def pre_tag_import(user):
+    # pre_callback of the tags step: start with an empty mapping stored on the
+    # tag_import function. node_import later receives this same object as its
+    # `tags` argument. `user` is not used.
     tag_import.tag_mappings={}
 
 
-@file_handler('tags.xml', 'tags', 'tag', pre_callback=pre_tag_import)
+@file_handler('tags.xml', 'tags', 'tag', _('Tags'), pre_callback=pre_tag_import)
 def tag_import(row):
     tag = orm.Tag(name=row.getc('name'), used_count=row.get('used').as_int(), created_by_id=row.get('author').as_int())
     tag.save()
@@ -263,7 +364,7 @@ def tag_import(row):
 def post_node_import():
+    # post_callback of the nodes step: drop the mapping stored on tag_import.
     tag_import.tag_mappings = None
 
-@file_handler('nodes.xml', 'nodes', 'node', args_handler=lambda u: [tag_import.tag_mappings], post_callback=post_node_import)
+@file_handler('nodes.xml', 'nodes', 'node', _('Nodes'), args_handler=lambda u: [tag_import.tag_mappings], post_callback=post_node_import)
 def node_import(row, tags):
 
     ntags = []
@@ -280,6 +381,7 @@ def node_import(row, tags):
             added_at      = row.get('date').as_datetime(),
             parent_id     = row.get('parent').as_int(None),
             abs_parent_id = row.get('absparent').as_int(None),
+            score         = row.get('score').as_int(0),
 
             last_activity_by_id = last_act.get('by').as_int(None),
             last_activity_at    = last_act.get('at').as_datetime(None),
@@ -335,7 +437,7 @@ def post_action_import_callback():
         n.state_string = "".join(["(%s)" % s for s in n.states.values_list('state_type')])
         n.save()
 
-@file_handler('actions.xml', 'actions', 'action', post_callback=post_action_import_callback)
+@file_handler('actions.xml', 'actions', 'action', _('Actions'), post_callback=post_action_import_callback)
 def actions_import(row):
     action = orm.Action(
         id           = row.get('id').as_int(),
@@ -350,8 +452,9 @@ def actions_import(row):
 
     canceled = row.get('canceled')
     if canceled.get_attr('state').as_bool():
+        action.canceled = True
         action.canceled_by_id = canceled.get('user').as_int()
-        action.canceled_at = canceled.get('date').as_datetime(),
+        #action.canceled_at = canceled.get('date').as_datetime(),
         action.canceled_ip = canceled.getc('ip')
 
     action.save()
@@ -368,7 +471,7 @@ def actions_import(row):
             by_canceled = by_canceled
         ).save()
 
-    if (not action.canceled) and action.action_type in POST_ACTION:
+    if (not action.canceled) and (action.action_type in POST_ACTION):
         POST_ACTION[action.action_type](row, action)
 
 
@@ -403,15 +506,28 @@ def award_import_args(user):
     return [ dict([ (b.cls, b) for b in orm.Badge.objects.all() ]) ]
 
 
-@file_handler('awards.xml', 'awards', 'award', args_handler=award_import_args)
+@file_handler('awards.xml', 'awards', 'award', _('Awards'), args_handler=award_import_args)
 def awards_import(row, badges):
-    award = orm.Award(
-        user_id = row.get('user').as_int(),
-        badge = badges[row.getc('badge')],
-        node_id = row.get('node').as_int(None),
-        action_id = row.get('action').as_int(None),
-        trigger_id = row.get('trigger').as_int(None)
-    ).save()
+    """Import one award row.
+
+    row -- the parsed row element.
+    badges -- dictionary mapping a badge's `cls` value to its orm.Badge
+        object, as built by award_import_args.
+
+    Rows naming a badge that is not in `badges` are skipped.  Any error while
+    building or saving the award is logged and the row is skipped, so one bad
+    row does not abort the step.
+    """
+    try:
+        badge_type = badges.get(row.getc('badge'), None)
+
+        if not badge_type:
+            return
+
+        orm.Award(
+            user_id = row.get('user').as_int(),
+            badge = badge_type,
+            node_id = row.get('node').as_int(None),
+            action_id = row.get('action').as_int(None),
+            trigger_id = row.get('trigger').as_int(None)
+        ).save()
+    except Exception:
+        # Still best effort, but leave a trace instead of failing silently.
+        logging.exception("Skipping award row that could not be imported")
+
+
+@file_handler('settings.xml', 'settings', 'setting', _('Settings'))
+def settings_import(row):
+    """Store one setting row as an orm.KeyValue record."""
+    # NOTE(review): as_pickled() presumably unpickles data taken from the
+    # dump -- confirm, and only import dumps from a trusted source.
+    # The record must be saved explicitly; building it alone writes nothing.
+    orm.KeyValue(key=row.getc('key'), value=row.get('value').as_pickled()).save()