Refactor aggregates

This commit is contained in:
Konsta Vesterinen
2014-01-04 03:07:38 +02:00
parent 9f6f1f98d1
commit 4a916e5fab

View File

@@ -331,7 +331,7 @@ class AggregatedValue(object):
if len(self.relationships) == 1: if len(self.relationships) == 1:
remote_pairs = self.relationships[-1].property.local_remote_pairs remote_pairs = self.relationships[-1].property.local_remote_pairs
query = query.where( return query.where(
remote_pairs[0][0].in_( remote_pairs[0][0].in_(
getattr(obj, remote_pairs[0][1].key) for obj in objects getattr(obj, remote_pairs[0][1].key) for obj in objects
) )
@@ -352,6 +352,23 @@ class AggregatedValue(object):
local = remote_pairs[0][0] local = remote_pairs[0][0]
remote = remote_pairs[0][1] remote = remote_pairs[0][1]
return query.where(
local.in_(
sa.select(
[remote],
from_obj=[self.multi_level_aggregate_query_base]
).where(
self.local_condition(
self.relationships[0].property,
objects
)
)
)
)
@property
def multi_level_aggregate_query_base(self):
property_ = self.relationships[-1].property
from_ = property_.mapper.class_.__table__ from_ = property_.mapper.class_.__table__
for relationship in reversed(self.relationships[1:-1]): for relationship in reversed(self.relationships[1:-1]):
@@ -359,22 +376,7 @@ class AggregatedValue(object):
from_ = ( from_ = (
from_.join(property_.mapper.class_, property_.primaryjoin) from_.join(property_.mapper.class_, property_.primaryjoin)
) )
return from_
property_ = self.relationships[0].property
query = query.where(
local.in_(
sa.select(
[remote],
from_obj=[from_]
).where(
self.local_condition(property_, objects)
)
)
)
return query
def local_condition(self, prop, objects): def local_condition(self, prop, objects):
return prop.local_remote_pairs[0][0].in_( return prop.local_remote_pairs[0][0].in_(