From 450fe55cc98f4a856a70acb9414817a99e61cd71 Mon Sep 17 00:00:00 2001 From: Konsta Vesterinen Date: Tue, 21 Jan 2014 22:34:30 +0200 Subject: [PATCH] Make aggregates use weakrefs --- sqlalchemy_utils/aggregates.py | 33 +++++++++++++++++---------------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/sqlalchemy_utils/aggregates.py b/sqlalchemy_utils/aggregates.py index 2f71373..22b2a7e 100644 --- a/sqlalchemy_utils/aggregates.py +++ b/sqlalchemy_utils/aggregates.py @@ -305,6 +305,7 @@ TODO from collections import defaultdict +from weakref import WeakKeyDictionary import sqlalchemy as sa import six @@ -317,6 +318,9 @@ except ImportError: from sqlalchemy.sql.expression import _FunctionGenerator +aggregated_attrs = WeakKeyDictionary(defaultdict(list)) + + class AggregatedAttribute(declared_attr): def __init__( self, @@ -332,12 +336,10 @@ class AggregatedAttribute(declared_attr): self.relationship = relationship def __get__(desc, self, cls): - if '__aggregates__' not in cls.__dict__: - cls.__aggregates__ = {} - cls.__aggregates__[desc.fget.__name__] = { - 'expression': desc.fget, - 'relationship': desc.relationship - } + if cls not in aggregated_attrs: + aggregated_attrs[cls] = [(desc.fget, desc.relationship)] + else: + aggregated_attrs[cls].append((desc.fget, desc.relationship)) return desc.column @@ -449,12 +451,11 @@ class AggregationManager(object): def reset(self): self.generator_registry = defaultdict(list) - self.pending_queries = defaultdict(list) def register_listeners(self): sa.event.listen( sa.orm.mapper, - 'mapper_configured', + 'after_configured', self.update_generator_registry ) sa.event.listen( @@ -463,30 +464,30 @@ class AggregationManager(object): self.construct_aggregate_queries ) - def update_generator_registry(self, mapper, class_): - if '__aggregates__' in class_.__dict__: - for key, value in six.iteritems(class_.__aggregates__): + def update_generator_registry(self): + for class_, attrs in six.iteritems(aggregated_attrs): + for expr, relationship in attrs: relationships = [] rel_class = class_ - for path_name in value['relationship'].split('.'): + for path_name in relationship.split('.'): rel = getattr(rel_class, path_name) relationships.append(rel) rel_class = rel.mapper.class_ - self.generator_registry[rel_class.__name__].append( + self.generator_registry[rel_class].append( AggregatedValue( class_=class_, - attr=key, + attr=expr.__name__, relationships=list(reversed(relationships)), - expr=value['expression'](class_) + expr=expr(class_) ) ) def construct_aggregate_queries(self, session, ctx): object_dict = defaultdict(list) for obj in session: - class_ = obj.__class__.__name__ + class_ = obj.__class__ if class_ in self.generator_registry: object_dict[class_].append(obj)