Refactor aggregates module

This commit is contained in:
Konsta Vesterinen
2014-01-04 02:41:47 +02:00
parent ac4a2a7829
commit 795c3f2f45

View File

@@ -286,13 +286,15 @@ class AggregatedValue(object):
self.class_ = class_ self.class_ = class_
self.attr = attr self.attr = attr
self.relationships = relationships self.relationships = relationships
self.expr = self.aggregate_expression(expr, class_)
def aggregate_expression(self, expr, class_):
if isinstance(expr, sa.sql.visitors.Visitable): if isinstance(expr, sa.sql.visitors.Visitable):
self.expr = expr return expr
elif isinstance(expr, _FunctionGenerator): elif isinstance(expr, _FunctionGenerator):
self.expr = expr(sa.sql.text('1')) return expr(sa.sql.text('1'))
else: else:
self.expr = expr(class_) return expr(class_)
@property @property
def aggregate_query(self): def aggregate_query(self):
@@ -342,6 +344,10 @@ class AggregatedValue(object):
# ) # )
property_ = self.relationships[-1].property property_ = self.relationships[-1].property
remote_pairs = property_.local_remote_pairs remote_pairs = property_.local_remote_pairs
local = remote_pairs[0][0]
remote = remote_paris[0][1]
from_ = property_.mapper.class_.__table__ from_ = property_.mapper.class_.__table__
for relationship in reversed(self.relationships[1:-1]): for relationship in reversed(self.relationships[1:-1]):
property_ = relationship.property property_ = relationship.property
@@ -352,15 +358,13 @@ class AggregatedValue(object):
property_ = self.relationships[0].property property_ = self.relationships[0].property
query = query.where( query = query.where(
remote_pairs[0][0].in_( local.in_(
sa.select( sa.select(
[remote_pairs[0][1]], [remote],
from_obj=[from_] from_obj=[from_]
).where( ).where(
property_.local_remote_pairs[0][0].in_( local.in_(
getattr( getattr(obj, remote.key)
obj, property_.local_remote_pairs[0][1].key
)
for obj in objects for obj in objects
) )
) )