Refactor aggregates module
This commit is contained in:
@@ -286,13 +286,15 @@ class AggregatedValue(object):
|
||||
self.class_ = class_
|
||||
self.attr = attr
|
||||
self.relationships = relationships
|
||||
self.expr = self.aggregate_expression(expr, class_)
|
||||
|
||||
def aggregate_expression(self, expr, class_):
|
||||
if isinstance(expr, sa.sql.visitors.Visitable):
|
||||
self.expr = expr
|
||||
return expr
|
||||
elif isinstance(expr, _FunctionGenerator):
|
||||
self.expr = expr(sa.sql.text('1'))
|
||||
return expr(sa.sql.text('1'))
|
||||
else:
|
||||
self.expr = expr(class_)
|
||||
return expr(class_)
|
||||
|
||||
@property
|
||||
def aggregate_query(self):
|
||||
@@ -342,6 +344,10 @@ class AggregatedValue(object):
|
||||
# )
|
||||
property_ = self.relationships[-1].property
|
||||
remote_pairs = property_.local_remote_pairs
|
||||
local = remote_pairs[0][0]
|
||||
remote = remote_paris[0][1]
|
||||
|
||||
|
||||
from_ = property_.mapper.class_.__table__
|
||||
for relationship in reversed(self.relationships[1:-1]):
|
||||
property_ = relationship.property
|
||||
@@ -352,15 +358,13 @@ class AggregatedValue(object):
|
||||
property_ = self.relationships[0].property
|
||||
|
||||
query = query.where(
|
||||
remote_pairs[0][0].in_(
|
||||
local.in_(
|
||||
sa.select(
|
||||
[remote_pairs[0][1]],
|
||||
[remote],
|
||||
from_obj=[from_]
|
||||
).where(
|
||||
property_.local_remote_pairs[0][0].in_(
|
||||
getattr(
|
||||
obj, property_.local_remote_pairs[0][1].key
|
||||
)
|
||||
local.in_(
|
||||
getattr(obj, remote.key)
|
||||
for obj in objects
|
||||
)
|
||||
)
|
||||
|
Reference in New Issue
Block a user