Refactor aggregates

This commit is contained in:
Konsta Vesterinen
2014-12-15 16:13:46 +02:00
parent ed158c8f9f
commit be21ffc1d5

View File

@@ -404,7 +404,7 @@ class AggregatedAttribute(declared_attr):
return desc.column
def get_aggregate_query(agg_expr, relationships):
def aggregate_select(agg_expr, relationships):
"""
Return a subquery for fetching an aggregate value of given aggregate
expression and given sequence of relationships.
@@ -474,24 +474,25 @@ def local_condition(prop, objects):
return column.in_(values)
def aggregate_expression(expr, class_):
if isinstance(expr, sa.sql.visitors.Visitable):
return expr
elif isinstance(expr, _FunctionGenerator):
return expr(sa.sql.text('1'))
else:
return expr(class_)
class AggregatedValue(object):
def __init__(self, class_, attr, relationships, expr):
self.class_ = class_
self.attr = attr
self.relationships = relationships
self.expr = self.aggregate_expression(expr, class_)
def aggregate_expression(self, expr, class_):
if isinstance(expr, sa.sql.visitors.Visitable):
return expr
elif isinstance(expr, _FunctionGenerator):
return expr(sa.sql.text('1'))
else:
return expr(class_)
self.expr = aggregate_expression(expr, class_)
@property
def aggregate_query(self):
query = get_aggregate_query(self.expr, self.relationships)
query = aggregate_select(self.expr, self.relationships)
return query.correlate(self.class_).as_scalar()