Fix aggregate handling for cascade deleted objects

This commit is contained in:
Konsta Vesterinen
2014-09-09 08:55:35 +03:00
parent aa478a8e20
commit 3e083955d3
2 changed files with 78 additions and 28 deletions

View File

@@ -445,8 +445,9 @@ class AggregatedValue(object):
) )
if len(self.relationships) == 1: if len(self.relationships) == 1:
prop = self.relationships[-1].property prop = self.relationships[-1].property
condition = self.local_condition(prop, objects)
return query.where(self.local_condition(prop, objects)) if condition is not None:
return query.where(condition)
else: else:
# Builds query such as: # Builds query such as:
# #
@@ -462,17 +463,18 @@ class AggregatedValue(object):
remote_pairs = property_.local_remote_pairs remote_pairs = property_.local_remote_pairs
local = remote_pairs[0][0] local = remote_pairs[0][0]
remote = remote_pairs[0][1] remote = remote_pairs[0][1]
condition = self.local_condition(
self.relationships[0].property,
objects
)
if condition is not None:
return query.where( return query.where(
local.in_( local.in_(
sa.select( sa.select(
[remote], [remote],
from_obj=[self.multi_level_aggregate_query_base] from_obj=[self.multi_level_aggregate_query_base]
).where( ).where(
self.local_condition( condition
self.relationships[0].property,
objects
)
) )
) )
) )
@@ -490,22 +492,23 @@ class AggregatedValue(object):
return from_ return from_
def local_condition(self, prop, objects): def local_condition(self, prop, objects):
pairs = prop.local_remote_pairs
if prop.secondary is not None: if prop.secondary is not None:
return prop.local_remote_pairs[1][0].in_( column = pairs[1][0]
getattr( key = pairs[1][0].key
obj,
prop.local_remote_pairs[1][0].key
)
for obj in objects
)
else: else:
return prop.local_remote_pairs[0][0].in_( column = pairs[0][0]
getattr( key = pairs[0][1].key
obj,
prop.local_remote_pairs[0][1].key values = []
) for obj in objects:
for obj in objects try:
) values.append(getattr(obj, key))
except sa.orm.exc.ObjectDeletedError:
pass
if values:
return column.in_(values)
class AggregationManager(object): class AggregationManager(object):
@@ -557,6 +560,7 @@ class AggregationManager(object):
for class_, objects in six.iteritems(object_dict): for class_, objects in six.iteritems(object_dict):
for aggregate_value in self.generator_registry[class_]: for aggregate_value in self.generator_registry[class_]:
query = aggregate_value.update_query(objects) query = aggregate_value.update_query(objects)
if query is not None:
session.execute(query) session.execute(query)

View File

@@ -0,0 +1,46 @@
import sqlalchemy as sa
from sqlalchemy_utils.aggregates import aggregated
from tests import TestCase
class TestAggregateValueGenerationWithCascadeDelete(TestCase):
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
def create_models(self):
class Thread(self.Base):
__tablename__ = 'thread'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
@aggregated('comments', sa.Column(sa.Integer, default=0))
def comment_count(self):
return sa.func.count('1')
comments = sa.orm.relationship(
'Comment',
passive_deletes=True,
backref='thread'
)
class Comment(self.Base):
__tablename__ = 'comment'
id = sa.Column(sa.Integer, primary_key=True)
content = sa.Column(sa.Unicode(255))
thread_id = sa.Column(
sa.Integer,
sa.ForeignKey('thread.id', ondelete='CASCADE')
)
self.Thread = Thread
self.Comment = Comment
def test_something(self):
thread = self.Thread()
thread.name = u'some article name'
self.session.add(thread)
comment = self.Comment(content=u'Some content', thread=thread)
self.session.add(comment)
self.session.commit()
self.session.expire_all()
self.session.delete(thread)
self.session.commit()