Make aggregates use weakrefs
This commit is contained in:
@@ -305,6 +305,7 @@ TODO
|
|||||||
|
|
||||||
|
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
from weakref import WeakKeyDictionary
|
||||||
|
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
import six
|
import six
|
||||||
@@ -317,6 +318,9 @@ except ImportError:
|
|||||||
from sqlalchemy.sql.expression import _FunctionGenerator
|
from sqlalchemy.sql.expression import _FunctionGenerator
|
||||||
|
|
||||||
|
|
||||||
|
aggregated_attrs = WeakKeyDictionary(defaultdict(list))
|
||||||
|
|
||||||
|
|
||||||
class AggregatedAttribute(declared_attr):
|
class AggregatedAttribute(declared_attr):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -332,12 +336,10 @@ class AggregatedAttribute(declared_attr):
|
|||||||
self.relationship = relationship
|
self.relationship = relationship
|
||||||
|
|
||||||
def __get__(desc, self, cls):
|
def __get__(desc, self, cls):
|
||||||
if '__aggregates__' not in cls.__dict__:
|
if cls not in aggregated_attrs:
|
||||||
cls.__aggregates__ = {}
|
aggregated_attrs[cls] = [(desc.fget, desc.relationship)]
|
||||||
cls.__aggregates__[desc.fget.__name__] = {
|
else:
|
||||||
'expression': desc.fget,
|
aggregated_attrs[cls].append((desc.fget, desc.relationship))
|
||||||
'relationship': desc.relationship
|
|
||||||
}
|
|
||||||
return desc.column
|
return desc.column
|
||||||
|
|
||||||
|
|
||||||
@@ -449,12 +451,11 @@ class AggregationManager(object):
|
|||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
self.generator_registry = defaultdict(list)
|
self.generator_registry = defaultdict(list)
|
||||||
self.pending_queries = defaultdict(list)
|
|
||||||
|
|
||||||
def register_listeners(self):
|
def register_listeners(self):
|
||||||
sa.event.listen(
|
sa.event.listen(
|
||||||
sa.orm.mapper,
|
sa.orm.mapper,
|
||||||
'mapper_configured',
|
'after_configured',
|
||||||
self.update_generator_registry
|
self.update_generator_registry
|
||||||
)
|
)
|
||||||
sa.event.listen(
|
sa.event.listen(
|
||||||
@@ -463,30 +464,30 @@ class AggregationManager(object):
|
|||||||
self.construct_aggregate_queries
|
self.construct_aggregate_queries
|
||||||
)
|
)
|
||||||
|
|
||||||
def update_generator_registry(self, mapper, class_):
|
def update_generator_registry(self):
|
||||||
if '__aggregates__' in class_.__dict__:
|
for class_, attrs in six.iteritems(aggregated_attrs):
|
||||||
for key, value in six.iteritems(class_.__aggregates__):
|
for expr, relationship in attrs:
|
||||||
relationships = []
|
relationships = []
|
||||||
rel_class = class_
|
rel_class = class_
|
||||||
|
|
||||||
for path_name in value['relationship'].split('.'):
|
for path_name in relationship.split('.'):
|
||||||
rel = getattr(rel_class, path_name)
|
rel = getattr(rel_class, path_name)
|
||||||
relationships.append(rel)
|
relationships.append(rel)
|
||||||
rel_class = rel.mapper.class_
|
rel_class = rel.mapper.class_
|
||||||
|
|
||||||
self.generator_registry[rel_class.__name__].append(
|
self.generator_registry[rel_class].append(
|
||||||
AggregatedValue(
|
AggregatedValue(
|
||||||
class_=class_,
|
class_=class_,
|
||||||
attr=key,
|
attr=expr.__name__,
|
||||||
relationships=list(reversed(relationships)),
|
relationships=list(reversed(relationships)),
|
||||||
expr=value['expression'](class_)
|
expr=expr(class_)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def construct_aggregate_queries(self, session, ctx):
|
def construct_aggregate_queries(self, session, ctx):
|
||||||
object_dict = defaultdict(list)
|
object_dict = defaultdict(list)
|
||||||
for obj in session:
|
for obj in session:
|
||||||
class_ = obj.__class__.__name__
|
class_ = obj.__class__
|
||||||
if class_ in self.generator_registry:
|
if class_ in self.generator_registry:
|
||||||
object_dict[class_].append(obj)
|
object_dict[class_].append(obj)
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user