Make aggregates use weakrefs

This commit is contained in:
Konsta Vesterinen
2014-01-21 22:34:30 +02:00
parent 207301580a
commit 450fe55cc9

View File

@@ -305,6 +305,7 @@ TODO
from collections import defaultdict from collections import defaultdict
from weakref import WeakKeyDictionary
import sqlalchemy as sa import sqlalchemy as sa
import six import six
@@ -317,6 +318,9 @@ except ImportError:
from sqlalchemy.sql.expression import _FunctionGenerator from sqlalchemy.sql.expression import _FunctionGenerator
aggregated_attrs = WeakKeyDictionary(defaultdict(list))
class AggregatedAttribute(declared_attr): class AggregatedAttribute(declared_attr):
def __init__( def __init__(
self, self,
@@ -332,12 +336,10 @@ class AggregatedAttribute(declared_attr):
self.relationship = relationship self.relationship = relationship
def __get__(desc, self, cls): def __get__(desc, self, cls):
if '__aggregates__' not in cls.__dict__: if cls not in aggregated_attrs:
cls.__aggregates__ = {} aggregated_attrs[cls] = [(desc.fget, desc.relationship)]
cls.__aggregates__[desc.fget.__name__] = { else:
'expression': desc.fget, aggregated_attrs[cls].append((desc.fget, desc.relationship))
'relationship': desc.relationship
}
return desc.column return desc.column
@@ -449,12 +451,11 @@ class AggregationManager(object):
def reset(self): def reset(self):
self.generator_registry = defaultdict(list) self.generator_registry = defaultdict(list)
self.pending_queries = defaultdict(list)
def register_listeners(self): def register_listeners(self):
sa.event.listen( sa.event.listen(
sa.orm.mapper, sa.orm.mapper,
'mapper_configured', 'after_configured',
self.update_generator_registry self.update_generator_registry
) )
sa.event.listen( sa.event.listen(
@@ -463,30 +464,30 @@ class AggregationManager(object):
self.construct_aggregate_queries self.construct_aggregate_queries
) )
def update_generator_registry(self, mapper, class_): def update_generator_registry(self):
if '__aggregates__' in class_.__dict__: for class_, attrs in six.iteritems(aggregated_attrs):
for key, value in six.iteritems(class_.__aggregates__): for expr, relationship in attrs:
relationships = [] relationships = []
rel_class = class_ rel_class = class_
for path_name in value['relationship'].split('.'): for path_name in relationship.split('.'):
rel = getattr(rel_class, path_name) rel = getattr(rel_class, path_name)
relationships.append(rel) relationships.append(rel)
rel_class = rel.mapper.class_ rel_class = rel.mapper.class_
self.generator_registry[rel_class.__name__].append( self.generator_registry[rel_class].append(
AggregatedValue( AggregatedValue(
class_=class_, class_=class_,
attr=key, attr=expr.__name__,
relationships=list(reversed(relationships)), relationships=list(reversed(relationships)),
expr=value['expression'](class_) expr=expr(class_)
) )
) )
def construct_aggregate_queries(self, session, ctx): def construct_aggregate_queries(self, session, ctx):
object_dict = defaultdict(list) object_dict = defaultdict(list)
for obj in session: for obj in session:
class_ = obj.__class__.__name__ class_ = obj.__class__
if class_ in self.generator_registry: if class_ in self.generator_registry:
object_dict[class_].append(obj) object_dict[class_].append(obj)