From 4a916e5fabd1e92437dacd2d72da24e40009c85d Mon Sep 17 00:00:00 2001 From: Konsta Vesterinen Date: Sat, 4 Jan 2014 03:07:38 +0200 Subject: [PATCH] Refactor aggregates --- sqlalchemy_utils/aggregates.py | 32 +++++++++++++++++--------------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/sqlalchemy_utils/aggregates.py b/sqlalchemy_utils/aggregates.py index 45fecd5..0543d2b 100644 --- a/sqlalchemy_utils/aggregates.py +++ b/sqlalchemy_utils/aggregates.py @@ -331,7 +331,7 @@ class AggregatedValue(object): if len(self.relationships) == 1: remote_pairs = self.relationships[-1].property.local_remote_pairs - query = query.where( + return query.where( remote_pairs[0][0].in_( getattr(obj, remote_pairs[0][1].key) for obj in objects ) @@ -352,29 +352,31 @@ class AggregatedValue(object): local = remote_pairs[0][0] remote = remote_pairs[0][1] - - from_ = property_.mapper.class_.__table__ - for relationship in reversed(self.relationships[1:-1]): - property_ = relationship.property - from_ = ( - from_.join(property_.mapper.class_, property_.primaryjoin) - ) - - property_ = self.relationships[0].property - - query = query.where( + return query.where( local.in_( sa.select( [remote], - from_obj=[from_] + from_obj=[self.multi_level_aggregate_query_base] ).where( - self.local_condition(property_, objects) + self.local_condition( + self.relationships[0].property, + objects + ) ) ) ) - return query + @property + def multi_level_aggregate_query_base(self): + property_ = self.relationships[-1].property + from_ = property_.mapper.class_.__table__ + for relationship in reversed(self.relationships[1:-1]): + property_ = relationship.property + from_ = ( + from_.join(property_.mapper.class_, property_.primaryjoin) + ) + return from_ def local_condition(self, prop, objects): return prop.local_remote_pairs[0][0].in_(