From be21ffc1d577c80cb17e67f5ec27eb057d084af1 Mon Sep 17 00:00:00 2001 From: Konsta Vesterinen Date: Mon, 15 Dec 2014 16:13:46 +0200 Subject: [PATCH] Refactor aggregates --- sqlalchemy_utils/aggregates.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/sqlalchemy_utils/aggregates.py b/sqlalchemy_utils/aggregates.py index a828b23..958ed1b 100644 --- a/sqlalchemy_utils/aggregates.py +++ b/sqlalchemy_utils/aggregates.py @@ -404,7 +404,7 @@ class AggregatedAttribute(declared_attr): return desc.column -def get_aggregate_query(agg_expr, relationships): +def aggregate_select(agg_expr, relationships): """ Return a subquery for fetching an aggregate value of given aggregate expression and given sequence of relationships. @@ -474,24 +474,25 @@ def local_condition(prop, objects): return column.in_(values) +def aggregate_expression(expr, class_): + if isinstance(expr, sa.sql.visitors.Visitable): + return expr + elif isinstance(expr, _FunctionGenerator): + return expr(sa.sql.text('1')) + else: + return expr(class_) + + class AggregatedValue(object): def __init__(self, class_, attr, relationships, expr): self.class_ = class_ self.attr = attr self.relationships = relationships - self.expr = self.aggregate_expression(expr, class_) - - def aggregate_expression(self, expr, class_): - if isinstance(expr, sa.sql.visitors.Visitable): - return expr - elif isinstance(expr, _FunctionGenerator): - return expr(sa.sql.text('1')) - else: - return expr(class_) + self.expr = aggregate_expression(expr, class_) @property def aggregate_query(self): - query = get_aggregate_query(self.expr, self.relationships) + query = aggregate_select(self.expr, self.relationships) return query.correlate(self.class_).as_scalar()