Refactor aggregates module

This commit is contained in:
Konsta Vesterinen
2014-01-04 02:41:47 +02:00
parent ac4a2a7829
commit 795c3f2f45

View File

@@ -286,13 +286,15 @@ class AggregatedValue(object):
self.class_ = class_
self.attr = attr
self.relationships = relationships
self.expr = self.aggregate_expression(expr, class_)
def aggregate_expression(self, expr, class_):
if isinstance(expr, sa.sql.visitors.Visitable):
self.expr = expr
return expr
elif isinstance(expr, _FunctionGenerator):
self.expr = expr(sa.sql.text('1'))
return expr(sa.sql.text('1'))
else:
self.expr = expr(class_)
return expr(class_)
@property
def aggregate_query(self):
@@ -342,6 +344,10 @@ class AggregatedValue(object):
# )
property_ = self.relationships[-1].property
remote_pairs = property_.local_remote_pairs
local = remote_pairs[0][0]
remote = remote_paris[0][1]
from_ = property_.mapper.class_.__table__
for relationship in reversed(self.relationships[1:-1]):
property_ = relationship.property
@@ -352,15 +358,13 @@ class AggregatedValue(object):
property_ = self.relationships[0].property
query = query.where(
remote_pairs[0][0].in_(
local.in_(
sa.select(
[remote_pairs[0][1]],
[remote],
from_obj=[from_]
).where(
property_.local_remote_pairs[0][0].in_(
getattr(
obj, property_.local_remote_pairs[0][1].key
)
local.in_(
getattr(obj, remote.key)
for obj in objects
)
)