git.openstreetmap.org Git - osqa.git/blobdiff - forum/models/utils.py
OSQA-349
[osqa.git] / forum / models / utils.py
index 5ac89001c07ceaaab616ca92b2e9a2308cd23771..ecbe038c15072b4c2c4619ac4d375bd54f2196de 100644 (file)
@@ -11,9 +11,13 @@ except ImportError:
 from copy import deepcopy
 from base64 import b64encode, b64decode
 from zlib import compress, decompress
+import re  # used by PickledObjectField.marker_re
 
+from base import BaseModel  # new base class of KeyValue
 
-class PickledObject(str):
+MAX_MARKABLE_STRING_LENGTH = 100  # str/unicode values longer than this are pickled instead of type-marked
+
+class PickledObject(unicode):  # value already encoded for the DB; get_db_prep_value passes it through as-is
     pass
 
 def dbsafe_encode(value, compress_object=True):
@@ -33,6 +37,9 @@ def dbsafe_decode(value, compress_object=True):
 class PickledObjectField(models.Field):
     __metaclass__ = models.SubfieldBase
 
+    marker_re = re.compile(r'^T\[(?P<type>\w+)\](?P<value>.*)$', re.DOTALL)  # parses "T[<type name>]<value>" text
+    markable_types = dict((t.__name__, t) for t in (str, int, unicode))  # exact type name -> constructor applied on read
+
     def __init__(self, *args, **kwargs):
         self.compress = kwargs.pop('compress', True)
         self.protocol = kwargs.pop('protocol', 2)
@@ -40,6 +47,36 @@ class PickledObjectField(models.Field):
         kwargs.setdefault('editable', False)
         super(PickledObjectField, self).__init__(*args, **kwargs)
 
+    def generate_type_marked_value(self, value):
+        """Serialize a simple value as text of the form ``T[<type name>]<value>``.
+
+        The result is a PickledObject, which get_db_prep_value passes through without re-encoding.
+        """
+        return PickledObject(u"T[%s]%s" % (type(value).__name__, value))
+
+    def read_marked_value(self, value):
+        """Parse text produced by generate_type_marked_value back into a value.
+
+        Text not matching ``marker_re`` is returned unchanged. For an unknown
+        type name, or a payload the named type cannot convert, the payload
+        text after the marker is returned.
+        """
+        m = self.marker_re.match(value)
+
+        if m:
+            marker = m.group('type')
+            value = m.group('value')
+            if marker in self.markable_types:
+                try:
+                    value = self.markable_types[marker](value)
+                except ValueError:
+                    # Corrupt payload (e.g. non-numeric text under T[int], or
+                    # non-ASCII text under T[str]; UnicodeError subclasses
+                    # ValueError): keep the unmarked text rather than raising.
+                    pass
+
+        return value
+
     def get_default(self):
         if self.has_default():
             if callable(self.default):
@@ -51,7 +72,10 @@ class PickledObjectField(models.Field):
     def to_python(self, value):
         if value is not None:
             try:
-                value = dbsafe_decode(value, self.compress)
+                if value.startswith("T["):  # type-marked text; non-string values raise AttributeError, handled by the except below
+                    value = self.read_marked_value(value)
+                else:  # otherwise treat it as dbsafe_encode output
+                    value = dbsafe_decode(value, self.compress)
             except:
                 if isinstance(value, PickledObject):
                     raise
@@ -59,7 +83,24 @@ class PickledObjectField(models.Field):
 
     def get_db_prep_value(self, value):
         if value is not None and not isinstance(value, PickledObject):
-            value = force_unicode(dbsafe_encode(value, self.compress))
+            # Values whose exact type is in markable_types (int, and str/unicode
+            # of length <= MAX_MARKABLE_STRING_LENGTH) are stored as readable
+            # "T[<type>]<value>" text; everything else is pickled.
+            markable = type(value).__name__ in self.markable_types
+            if markable and isinstance(value, basestring):
+                markable = len(value) <= MAX_MARKABLE_STRING_LENGTH
+            if markable and isinstance(value, str):
+                # Interpolating non-ASCII bytes into the unicode marker text
+                # raises UnicodeDecodeError when the default codec is ASCII,
+                # so such byte strings are pickled instead.
+                try:
+                    value.decode('ascii')
+                except UnicodeDecodeError:
+                    markable = False
+            if markable:
+                value = unicode(self.generate_type_marked_value(value))
+            else:
+                value = unicode(dbsafe_encode(value, self.compress))
         return value
 
     def value_to_string(self, obj):
@@ -75,43 +103,26 @@ class PickledObjectField(models.Field):
         return super(PickledObjectField, self).get_db_prep_lookup(lookup_type, value)
 
 
-class KeyValueManager(models.Manager):
-
-    def create_cache_key(self, key):
-        return "%s:keyvalue:%s" % (settings.APP_URL, key)
-
-    def save_to_cache(self, instance):
-        cache.set(self.create_cache_key(instance.key), instance, 2592000)
-
-    def remove_from_cache(self, instance):
-        cache.delete(self.create_cache_key(instance.key))
-
-    def get(self, **kwargs):
-        if 'key' in kwargs:
-            instance = cache.get(self.create_cache_key(kwargs['key']))
-
-            if instance is None:
-                instance = super(KeyValueManager, self).get(**kwargs)
-                self.save_to_cache(instance)
-
-            return instance
-
-        else:
-            return super(KeyValueManager, self).get(**kwargs)
-
-class KeyValue(models.Model):
+class KeyValue(BaseModel):
     key = models.CharField(max_length=255, unique=True)
     value = PickledObjectField()
 
-    objects = KeyValueManager()
-
     class Meta:
         app_label = 'forum'
 
-    def save(self, *args, **kwargs):
-        super(KeyValue, self).save(*args, **kwargs)
-        KeyValue.objects.save_to_cache(self)
+    def cache_key(self):
+        # Cache key for this row, built from its unique `key` column.
+        return self._generate_cache_key(self.key)
+
+    @classmethod
+    def infer_cache_key(cls, querydict):
+        # Best-effort: build the cache key from lookup kwargs such as
+        # {'key': ...} or {'key__exact': ...}. Returns None when no such lookup
+        # is present or key generation fails (presumably the caller then skips the cache).
+        try:
+            key = [v for (k, v) in querydict.items() if k in ('key', 'key__exact')][0]
+
+            return cls._generate_cache_key(key)
+        except Exception:
+            return None
 
-    def delete(self):
-        KeyValue.objects.remove_from_cache(self)
-        super(KeyValue, self).delete()