From 795c3f2f45643b272b1ccf401d82086f8afac9ab Mon Sep 17 00:00:00 2001 From: Konsta Vesterinen Date: Sat, 4 Jan 2014 02:41:47 +0200 Subject: [PATCH] Refactor aggregates module --- sqlalchemy_utils/aggregates.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/sqlalchemy_utils/aggregates.py b/sqlalchemy_utils/aggregates.py index fa84610..18d2621 100644 --- a/sqlalchemy_utils/aggregates.py +++ b/sqlalchemy_utils/aggregates.py @@ -286,13 +286,15 @@ class AggregatedValue(object): self.class_ = class_ self.attr = attr self.relationships = relationships + self.expr = self.aggregate_expression(expr, class_) + def aggregate_expression(self, expr, class_): if isinstance(expr, sa.sql.visitors.Visitable): - self.expr = expr + return expr elif isinstance(expr, _FunctionGenerator): - self.expr = expr(sa.sql.text('1')) + return expr(sa.sql.text('1')) else: - self.expr = expr(class_) + return expr(class_) @property def aggregate_query(self): @@ -342,6 +344,10 @@ class AggregatedValue(object): # ) property_ = self.relationships[-1].property remote_pairs = property_.local_remote_pairs + local = remote_pairs[0][0] + remote = remote_paris[0][1] + + from_ = property_.mapper.class_.__table__ for relationship in reversed(self.relationships[1:-1]): property_ = relationship.property @@ -352,15 +358,13 @@ class AggregatedValue(object): property_ = self.relationships[0].property query = query.where( - remote_pairs[0][0].in_( + local.in_( sa.select( - [remote_pairs[0][1]], + [remote], from_obj=[from_] ).where( - property_.local_remote_pairs[0][0].in_( - getattr( - obj, property_.local_remote_pairs[0][1].key - ) + local.in_( + getattr(obj, remote.key) for obj in objects ) )