X-Git-Url: https://git.openstreetmap.org./osqa.git/blobdiff_plain/0ba16baba0615dd405486c7d87f943d71518375c..edcd42cff2f41ebea0f49df0390a17886d32bd38:/forum/models/base.py diff --git a/forum/models/base.py b/forum/models/base.py index dfdf989..c5a80f9 100644 --- a/forum/models/base.py +++ b/forum/models/base.py @@ -20,24 +20,33 @@ import logging class LazyQueryList(object): def __init__(self, model, items): - self.model = model self.items = items + self.model = model def __getitem__(self, k): - return self.model.objects.get(id=self.items[k]) + return self.model.objects.get(id=self.items[k][0]) def __iter__(self): for id in self.items: - yield self.model.objects.get(id=id) + yield self.model.objects.get(id=id[0]) def __len__(self): return len(self.items) class CachedQuerySet(models.query.QuerySet): + def lazy(self): - if len(self.query.extra) == 0: - return LazyQueryList(self.model, list(self.values_list('id', flat=True))) + if not len(self.query.aggregates): + values_list = ['id'] + + if len(self.query.extra): + extra_keys = self.query.extra.keys() + values_list += extra_keys + + return LazyQueryList(self.model, list(self.values_list(*values_list))) else: + if len(self.query.extra): + print self.query.extra return self def obj_from_datadict(self, datadict): @@ -46,14 +55,9 @@ class CachedQuerySet(models.query.QuerySet): return obj def get(self, *args, **kwargs): - try: - pk = [v for (k,v) in kwargs.items() if k in ('pk', 'pk__exact', 'id', 'id__exact' - ) or k.endswith('_ptr__pk') or k.endswith('_ptr__id')][0] - except: - pk = None + key = self.model.infer_cache_key(kwargs) - if pk is not None: - key = self.model.cache_key(pk) + if key is not None: obj = cache.get(key) if obj is None: @@ -65,7 +69,7 @@ class CachedQuerySet(models.query.QuerySet): return obj - return super(CachedQuerySet, self).get(*args, **kwargs) + return super(CachedQuerySet, self).get(*args, **kwargs) class CachedManager(models.Manager): use_for_related_fields = True @@ -81,9 +85,9 @@ class CachedManager(models.Manager): class DenormalizedField(object): - def __init__(self, manager, **kwargs): + def __init__(self, manager, *args, **kwargs): self.manager = manager - self.filter = kwargs + self.filter = (args, kwargs) def setup_class(self, cls, name): dict_name = '_%s_dencache_' % name @@ -92,7 +96,7 @@ class DenormalizedField(object): val = inst.__dict__.get(dict_name, None) if val is None: - val = getattr(inst, self.manager).filter(**self.filter).count() + val = getattr(inst, self.manager).filter(*self.filter[0], **self.filter[1]).count() inst.__dict__[dict_name] = val inst.cache() @@ -135,14 +139,13 @@ class BaseModel(models.Model): def __init__(self, *args, **kwargs): super(BaseModel, self).__init__(*args, **kwargs) - self.reset_original_state() - - @classmethod - def cache_key(cls, pk): - return '%s:%s:%s' % (settings.APP_URL, cls.__name__, pk) + self.reset_original_state(kwargs.keys()) - def reset_original_state(self): + def reset_original_state(self, reset_fields=None): self._original_state = self._as_dict() + + if reset_fields: + self._original_state.update(dict([(f, None) for f in reset_fields])) def get_dirty_fields(self): return [f.name for f in self._meta.fields if self._original_state[f.attname] != self.__dict__[f.attname]] @@ -157,10 +160,10 @@ class BaseModel(models.Model): (f.name, getattr(self, f.name)) for f in self._meta.fields if self._original_state[f.attname] != self.__dict__[f.attname] ]) - def save(self, *args, **kwargs): + def save(self, full_save=False, *args, **kwargs): put_back = [k for k, v in self.__dict__.items() if isinstance(v, models.expressions.ExpressionNode)] - if self.id: + if self.id and not full_save: self.__class__.objects.filter(id=self.id).update(**self._get_update_kwargs()) else: super(BaseModel, self).save() @@ -177,11 +180,31 @@ class BaseModel(models.Model): self.reset_original_state() self.cache() + @classmethod + def _generate_cache_key(cls, key, group=None): + if group is None: + group = cls.__name__ + + return '%s:%s:%s' % (settings.APP_URL, group, key) + + def cache_key(self): + return self._generate_cache_key(self.id) + + @classmethod + def infer_cache_key(cls, querydict): + try: + pk = [v for (k,v) in querydict.items() if k in ('pk', 'pk__exact', 'id', 'id__exact' + ) or k.endswith('_ptr__pk') or k.endswith('_ptr__id')][0] + + return cls._generate_cache_key(pk) + except: + return None + def cache(self): - cache.set(self.cache_key(self.id), self._as_dict(), 60 * 60) + cache.set(self.cache_key(), self._as_dict(), 60 * 60) def uncache(self): - cache.delete(self.cache_key(self.id)) + cache.delete(self.cache_key()) def delete(self): self.uncache()