]> git.openstreetmap.org Git - osqa.git/blobdiff - forum_modules/sximporter/importer.py
the mark_favorite view should accept not only Questions but any Nodes (required for...
[osqa.git] / forum_modules / sximporter / importer.py
index 5277b8dd3720fa83bf68650e5d2a8288a11d05b6..6669b8c0fe7dc8bf44efd9b951b8ba8fe3758c22 100644 (file)
@@ -1,15 +1,11 @@
 # -*- coding: utf-8 -*-
 
 # -*- coding: utf-8 -*-
 
-from xml.dom import minidom
-from datetime import datetime, timedelta
+from datetime import datetime
 import time
 import re
 import os
 import gc
 from django.utils.translation import ugettext as _
 import time
 import re
 import os
 import gc
 from django.utils.translation import ugettext as _
-from django.template.defaultfilters import slugify
-from forum.models.utils import dbsafe_encode
-from orm import orm
 
 from django.utils.encoding import force_unicode
 
 
 from django.utils.encoding import force_unicode
 
@@ -25,6 +21,33 @@ from zlib import compress, decompress
 from xml.sax import make_parser
 from xml.sax.handler import ContentHandler
 
 from xml.sax import make_parser
 from xml.sax.handler import ContentHandler
 
+def create_orm():
+    from django.conf import settings
+    from south.orm import FakeORM
+
+    get_migration_number_re = re.compile(r'^((\d+)_.*)\.py$')
+
+    migrations_folder = os.path.join(settings.SITE_SRC_ROOT, 'forum/migrations')
+
+    highest_number = 0
+    highest_file = None
+
+    for f in os.listdir(migrations_folder):
+        if os.path.isfile(os.path.join(migrations_folder, f)):
+            m = get_migration_number_re.match(f)
+
+            if m:
+                found = int(m.group(2))
+
+                if found > highest_number:
+                    highest_number = found
+                    highest_file = m.group(1)
+
+    mod = __import__('forum.migrations.%s' % highest_file, globals(), locals(), ['forum.migrations'])
+    return FakeORM(getattr(mod, 'Migration'), "forum")
+
+orm = create_orm()
+
 class SXTableHandler(ContentHandler):
     def __init__(self, fname, callback):
         self.in_row = False
 class SXTableHandler(ContentHandler):
     def __init__(self, fname, callback):
         self.in_row = False
@@ -151,9 +174,13 @@ class UnknownYahooUser(UnknownUser):
 
 
 class IdMapper(dict):
 
 
 class IdMapper(dict):
+
+    def __init__(self):
+        self.default = 1
+
     def __getitem__(self, key):
         key = int(key)
     def __getitem__(self, key):
         key = int(key)
-        return super(IdMapper, self).get(key, 1)
+        return super(IdMapper, self).get(key, self.default)
 
     def __setitem__(self, key, value):
         super(IdMapper, self).__setitem__(int(key), int(value))
 
     def __setitem__(self, key, value):
         super(IdMapper, self).__setitem__(int(key), int(value))
@@ -167,43 +194,40 @@ class IdIncrementer():
 
 openidre = re.compile('^https?\:\/\/')
 def userimport(path, options):
 
 openidre = re.compile('^https?\:\/\/')
 def userimport(path, options):
-#users = readTable(dump, "Users")
 
 
-    user_by_name = {}
+    usernames = []
+    openids = set()
     uidmapper = IdMapper()
     uidmapper = IdMapper()
-    #merged_users = []
 
 
+    authenticated_user = options.get('authenticated_user', None)
     owneruid = options.get('owneruid', None)
     #check for empty values
     if not owneruid:
         owneruid = None
     owneruid = options.get('owneruid', None)
     #check for empty values
     if not owneruid:
         owneruid = None
+    else:
+        owneruid = int(owneruid)
 
     def callback(sxu):
         create = True
 
     def callback(sxu):
         create = True
+        set_mapper_defaults = False
 
         if sxu.get('id') == '-1':
             return
         #print "\n".join(["%s : %s" % i for i in sxu.items()])
 
         if sxu.get('id') == '-1':
             return
         #print "\n".join(["%s : %s" % i for i in sxu.items()])
-        if int(sxu.get('id')) == int(owneruid):
-            osqau = orm.User.objects.get(id=1)
-            uidmapper[owneruid] = 1
-            uidmapper[-1] = 1
-            create = False
-        else:
-            username = sxu.get('displayname',
-                               sxu.get('displaynamecleaned', sxu.get('realname', final_username_attempt(sxu))))
 
 
-            if not isinstance(username, UnknownUser) and username in user_by_name:
-            #if options.get('mergesimilar', False) and sxu.get('email', 'INVALID') == user_by_name[username].email:
-            #    osqau = user_by_name[username]
-            #    create = False
-            #    uidmapper[sxu.get('id')] = osqau.id
-            #else:
-                inc = 1
-                while ("%s %d" % (username, inc)) in user_by_name:
-                    inc += 1
+        if (owneruid and (int(sxu.get('id')) == owneruid)) or (
+            (not owneruid) and len(uidmapper)):
+
+            set_mapper_defaults = True
+
+            if authenticated_user:
+                osqau = orm.User.objects.get(id=authenticated_user.id)
+
+                for assoc in orm.AuthKeyUserAssociation.objects.filter(user=osqau):
+                    openids.add(assoc.key)
 
 
-                username = "%s %d" % (username, inc)
+                uidmapper[owneruid] = osqau.id
+                create = False
 
         sxbadges = sxu.get('badgesummary', None)
         badges = {'1':'0', '2':'0', '3':'0'}
 
         sxbadges = sxu.get('badgesummary', None)
         badges = {'1':'0', '2':'0', '3':'0'}
@@ -212,9 +236,28 @@ def userimport(path, options):
             badges.update(dict([b.split('=') for b in sxbadges.split()]))
 
         if create:
             badges.update(dict([b.split('=') for b in sxbadges.split()]))
 
         if create:
+            username = unicode(sxu.get('displayname',
+                               sxu.get('displaynamecleaned', sxu.get('realname', final_username_attempt(sxu)))))[:30]
+
+            if username in usernames:
+            #if options.get('mergesimilar', False) and sxu.get('email', 'INVALID') == user_by_name[username].email:
+            #    osqau = user_by_name[username]
+            #    create = False
+            #    uidmapper[sxu.get('id')] = osqau.id
+            #else:
+                inc = 0
+
+                while True:
+                    inc += 1
+                    totest = "%s %d" % (username[:29 - len(str(inc))], inc)
+
+                    if not totest in usernames:
+                        username = totest
+                        break
+
             osqau = orm.User(
                     id           = sxu.get('id'),
             osqau = orm.User(
                     id           = sxu.get('id'),
-                    username     = unicode(username),
+                    username     = username,
                     password     = '!',
                     email        = sxu.get('email', ''),
                     is_superuser = sxu.get('usertypeid') == '5',
                     password     = '!',
                     email        = sxu.get('email', ''),
                     is_superuser = sxu.get('usertypeid') == '5',
@@ -230,7 +273,7 @@ def userimport(path, options):
                     gold          = int(badges['1']),
                     silver        = int(badges['2']),
                     bronze        = int(badges['3']),
                     gold          = int(badges['1']),
                     silver        = int(badges['2']),
                     bronze        = int(badges['3']),
-                    real_name     = sxu.get('realname', ''),
+                    real_name     = sxu.get('realname', '')[:30],
                     location      = sxu.get('location', ''),
                     )
 
                     location      = sxu.get('location', ''),
                     )
 
@@ -283,22 +326,32 @@ def userimport(path, options):
             #merged_users.append(osqau.id)
             osqau.save()
 
             #merged_users.append(osqau.id)
             osqau.save()
 
-        user_by_name[osqau.username] = osqau
+        if set_mapper_defaults:
+            uidmapper[-1] = osqau.id
+            uidmapper.default = osqau.id
+
+        usernames.append(osqau.username)
 
         openid = sxu.get('openid', None)
 
         openid = sxu.get('openid', None)
-        if openid and openidre.match(openid):
+        if openid and openidre.match(openid) and (not openid in openids):
             assoc = orm.AuthKeyUserAssociation(user=osqau, key=openid, provider="openidurl")
             assoc.save()
             assoc = orm.AuthKeyUserAssociation(user=osqau, key=openid, provider="openidurl")
             assoc.save()
+            openids.add(openid)
+
+        openidalt = sxu.get('openidalt', None)
+        if openidalt and openidre.match(openidalt) and (not openidalt in openids):
+            assoc = orm.AuthKeyUserAssociation(user=osqau, key=openidalt, provider="openidurl")
+            assoc.save()
+            openids.add(openidalt)
 
     readTable(path, "Users", callback)
 
 
     readTable(path, "Users", callback)
 
-    if uidmapper[-1] == -1:
-        uidmapper[-1] = 1
+    #if uidmapper[-1] == -1:
+    #    uidmapper[-1] = 1
 
     return uidmapper
 
 def tagsimport(dump, uidmap):
 
     return uidmapper
 
 def tagsimport(dump, uidmap):
-#tags = readTable(dump, "Tags")
 
     tagmap = {}
 
 
     tagmap = {}
 
@@ -340,17 +393,7 @@ def remove_post_state(name, post):
     post.state_string = "".join("(%s)" % s for s in re.findall('\w+', post.state_string) if s != name)
 
 def postimport(dump, uidmap, tagmap):
     post.state_string = "".join("(%s)" % s for s in re.findall('\w+', post.state_string) if s != name)
 
 def postimport(dump, uidmap, tagmap):
-#history = {}
-#accepted = {}
-    all = []
-
-    #for h in readTable(dump, "PostHistory"):
-    #    if not history.get(h.get('postid'), None):
-    #        history[h.get('postid')] = []
-    #
-    #    history[h.get('postid')].append(h)
-
-    #posts = readTable(dump, "Posts")
+    all = {}
 
     def callback(sxpost):
         nodetype = (sxpost.get('posttypeid') == '1') and "nodetype" or "answer"
 
     def callback(sxpost):
         nodetype = (sxpost.get('posttypeid') == '1') and "nodetype" or "answer"
@@ -411,13 +454,16 @@ def postimport(dump, uidmap, tagmap):
             post.extra_count = sxpost.get('viewcount', 0)
 
             add_tags_to_post(post, tagmap)
             post.extra_count = sxpost.get('viewcount', 0)
 
             add_tags_to_post(post, tagmap)
+            all[int(post.id)] = int(post.id)
 
         else:
             post.parent_id = sxpost['parentid']
 
         else:
             post.parent_id = sxpost['parentid']
+            post.abs_parent_id = sxpost['parentid']
+            all[int(post.id)] = int(sxpost['parentid'])
 
         post.save()
 
 
         post.save()
 
-        all.append(int(post.id))
+        create_and_activate_revision(post)
 
         del post
 
 
         del post
 
@@ -425,8 +471,9 @@ def postimport(dump, uidmap, tagmap):
 
     return all
 
 
     return all
 
-def comment_import(dump, uidmap, posts):
-#comments = readTable(dump, "PostComments")
+def comment_import(dump, uidmap, absparent_map):
+    posts = absparent_map.keys()
+
     currid = IdIncrementer(max(posts))
     mapping = {}
 
     currid = IdIncrementer(max(posts))
     mapping = {}
 
@@ -439,6 +486,7 @@ def comment_import(dump, uidmap, posts):
                 author_id = uidmap[sxc.get('userid', 1)],
                 body = sxc['text'],
                 parent_id = sxc.get('postid'),
                 author_id = uidmap[sxc.get('userid', 1)],
                 body = sxc['text'],
                 parent_id = sxc.get('postid'),
+                abs_parent_id = absparent_map.get(int(sxc.get('postid')), sxc.get('postid'))
                 )
 
         if sxc.get('deletiondate', None):
                 )
 
         if sxc.get('deletiondate', None):
@@ -466,6 +514,8 @@ def comment_import(dump, uidmap, posts):
                 action_date = oc.added_at
                 )
 
                 action_date = oc.added_at
                 )
 
+        create_and_activate_revision(oc)
+
         create_action.save()
         oc.save()
 
         create_action.save()
         oc.save()
 
@@ -480,7 +530,6 @@ def add_tags_to_post(post, tagmap):
     tags = [tag for tag in [tagmap.get(name.strip()) for name in post.tagnames.split(u' ') if name] if tag]
     post.tagnames = " ".join([t.name for t in tags]).strip()
     post.tags = tags
     tags = [tag for tag in [tagmap.get(name.strip()) for name in post.tagnames.split(u' ') if name] if tag]
     post.tagnames = " ".join([t.name for t in tags]).strip()
     post.tags = tags
-    create_and_activate_revision(post)
 
 
 def create_and_activate_revision(post):
 
 
 def create_and_activate_revision(post):
@@ -500,7 +549,6 @@ def create_and_activate_revision(post):
     post.save()
 
 def post_vote_import(dump, uidmap, posts):
     post.save()
 
 def post_vote_import(dump, uidmap, posts):
-#votes = readTable(dump, "Posts2Votes")
     close_reasons = {}
 
     def close_callback(r):
     close_reasons = {}
 
     def close_callback(r):
@@ -630,7 +678,6 @@ def post_vote_import(dump, uidmap, posts):
 
 
 def comment_vote_import(dump, uidmap, comments):
 
 
 def comment_vote_import(dump, uidmap, comments):
-#votes = readTable(dump, "Comments2Votes")
     user2vote = []
     comments2score = {}
 
     user2vote = []
     comments2score = {}
 
@@ -672,7 +719,6 @@ def comment_vote_import(dump, uidmap, comments):
 
 
 def badges_import(dump, uidmap, post_list):
 
 
 def badges_import(dump, uidmap, post_list):
-#node_ctype = orm['contenttypes.contenttype'].objects.get(name='node')
 
     sxbadges = {}
 
 
     sxbadges = {}
 
@@ -727,6 +773,7 @@ def badges_import(dump, uidmap, post_list):
 
         osqaa.save()
         badge.awarded_count += 1
 
         osqaa.save()
         badge.awarded_count += 1
+
         user_badge_count[user_id] += 1
 
     readTable(dump, "Users2Badges", callback)
         user_badge_count[user_id] += 1
 
     readTable(dump, "Users2Badges", callback)
@@ -734,10 +781,19 @@ def badges_import(dump, uidmap, post_list):
     for badge in obadges.values():
         badge.save()
 
     for badge in obadges.values():
         badge.save()
 
-def pages_import(dump, currid):
+def save_setting(k, v):
+    try:
+        kv = orm.KeyValue.objects.get(key=k)
+        kv.value = v
+    except:
+        kv = orm.KeyValue(key = k, value = v)
+
+    kv.save()
+
+
+def pages_import(dump, currid, owner):
     currid = IdIncrementer(currid)
     registry = {}
     currid = IdIncrementer(currid)
     registry = {}
-    #sx_pages = readTable(dump, "FlatPages")
 
     def callback(sxp):
         currid.inc()
 
     def callback(sxp):
         currid.inc()
@@ -756,9 +812,11 @@ def pages_import(dump, currid):
                 'sidebar_render': "html",
                 'comments': False
                 }),
                 'sidebar_render': "html",
                 'comments': False
                 }),
-                author_id = 1
+                author_id = owner
                 )
 
                 )
 
+        create_and_activate_revision(page)
+
         page.save()
         registry[sxp['url'][1:]] = page.id
 
         page.save()
         registry[sxp['url'][1:]] = page.id
 
@@ -782,8 +840,7 @@ def pages_import(dump, currid):
 
     readTable(dump, "FlatPages", callback)
 
 
     readTable(dump, "FlatPages", callback)
 
-    kv = orm.KeyValue(key='STATIC_PAGE_REGISTRY', value=dbsafe_encode(registry))
-    kv.save()
+    save_setting('STATIC_PAGE_REGISTRY', dbsafe_encode(registry))
 
 sx2osqa_set_map = {
 u'theme.html.name': 'APP_TITLE',
 
 sx2osqa_set_map = {
 u'theme.html.name': 'APP_TITLE',
@@ -814,24 +871,17 @@ def html_decode(html):
 
 
 def static_import(dump):
 
 
 def static_import(dump):
-#sx_sets = readTable(dump, "ThemeTextResources")
     sx_unknown = {}
 
     def callback(set):
         if unicode(set['name']) in sx2osqa_set_map:
     sx_unknown = {}
 
     def callback(set):
         if unicode(set['name']) in sx2osqa_set_map:
-            kv = orm.KeyValue(
-                    key = sx2osqa_set_map[set['name']],
-                    value = dbsafe_encode(html_decode(set['value']))
-                    )
-
-            kv.save()
+            save_setting(sx2osqa_set_map[set['name']], dbsafe_encode(html_decode(set['value'])))
         else:
             sx_unknown[set['name']] = html_decode(set['value'])
 
     readTable(dump, "ThemeTextResources", callback)
 
         else:
             sx_unknown[set['name']] = html_decode(set['value'])
 
     readTable(dump, "ThemeTextResources", callback)
 
-    unknown = orm.KeyValue(key='SXIMPORT_UNKNOWN_SETS', value=dbsafe_encode(sx_unknown))
-    unknown.save()
+    save_setting('SXIMPORT_UNKNOWN_SETS', dbsafe_encode(sx_unknown))
 
 def disable_triggers():
     from south.db import db
 
 def disable_triggers():
     from south.db import db
@@ -854,6 +904,13 @@ def reset_sequences():
         db.execute_many(PG_SEQUENCE_RESETS)
         db.commit_transaction()
 
         db.execute_many(PG_SEQUENCE_RESETS)
         db.commit_transaction()
 
+def reindex_fts():
+    from south.db import db
+    if db.backend_name == "postgres":
+        db.start_transaction()
+        db.execute_many("UPDATE forum_noderevision set id = id WHERE TRUE;")
+        db.commit_transaction()
+
 
 def sximport(dump, options):
     try:
 
 def sximport(dump, options):
     try:
@@ -861,6 +918,7 @@ def sximport(dump, options):
         triggers_disabled = True
     except:
         triggers_disabled = False
         triggers_disabled = True
     except:
         triggers_disabled = False
+
     uidmap = userimport(dump, options)
     tagmap = tagsimport(dump, uidmap)
     gc.collect()
     uidmap = userimport(dump, options)
     tagmap = tagsimport(dump, uidmap)
     gc.collect()
@@ -879,7 +937,7 @@ def sximport(dump, options):
 
     badges_import(dump, uidmap, posts)
 
 
     badges_import(dump, uidmap, posts)
 
-    pages_import(dump, max(posts))
+    pages_import(dump, max(posts), uidmap.default)
     static_import(dump)
     gc.collect()
 
     static_import(dump)
     gc.collect()
 
@@ -890,6 +948,7 @@ def sximport(dump, options):
 
     if triggers_disabled:
         enable_triggers()
 
     if triggers_disabled:
         enable_triggers()
+        reindex_fts()
 
 
 PG_DISABLE_TRIGGERS = """
 
 
 PG_DISABLE_TRIGGERS = """
@@ -972,4 +1031,4 @@ SELECT setval('"forum_openidassociation_id_seq"', coalesce(max("id"), 1) + 2, ma
 
 
     
 
 
     
-    
\ No newline at end of file
+