From 3e083955d328d5dfa8b32d31019e10e63894695c Mon Sep 17 00:00:00 2001 From: Konsta Vesterinen Date: Tue, 9 Sep 2014 08:55:35 +0300 Subject: [PATCH] Fix aggregate handling for cascade deleted objects --- sqlalchemy_utils/aggregates.py | 60 ++++++++++--------- tests/aggregate/test_with_ondelete_cascade.py | 46 ++++++++++++++ 2 files changed, 78 insertions(+), 28 deletions(-) create mode 100644 tests/aggregate/test_with_ondelete_cascade.py diff --git a/sqlalchemy_utils/aggregates.py b/sqlalchemy_utils/aggregates.py index a646448..1909b8f 100644 --- a/sqlalchemy_utils/aggregates.py +++ b/sqlalchemy_utils/aggregates.py @@ -445,8 +445,9 @@ class AggregatedValue(object): ) if len(self.relationships) == 1: prop = self.relationships[-1].property - - return query.where(self.local_condition(prop, objects)) + condition = self.local_condition(prop, objects) + if condition is not None: + return query.where(condition) else: # Builds query such as: # @@ -462,20 +463,21 @@ class AggregatedValue(object): remote_pairs = property_.local_remote_pairs local = remote_pairs[0][0] remote = remote_pairs[0][1] - - return query.where( - local.in_( - sa.select( - [remote], - from_obj=[self.multi_level_aggregate_query_base] - ).where( - self.local_condition( - self.relationships[0].property, - objects + condition = self.local_condition( + self.relationships[0].property, + objects + ) + if condition is not None: + return query.where( + local.in_( + sa.select( + [remote], + from_obj=[self.multi_level_aggregate_query_base] + ).where( + condition ) ) ) - ) @property def multi_level_aggregate_query_base(self): @@ -490,22 +492,23 @@ class AggregatedValue(object): return from_ def local_condition(self, prop, objects): + pairs = prop.local_remote_pairs if prop.secondary is not None: - return prop.local_remote_pairs[1][0].in_( - getattr( - obj, - prop.local_remote_pairs[1][0].key - ) - for obj in objects - ) + column = pairs[1][0] + key = pairs[1][0].key else: - return prop.local_remote_pairs[0][0].in_( - getattr( - obj, - prop.local_remote_pairs[0][1].key - ) - for obj in objects - ) + column = pairs[0][0] + key = pairs[0][1].key + + values = [] + for obj in objects: + try: + values.append(getattr(obj, key)) + except sa.orm.exc.ObjectDeletedError: + pass + + if values: + return column.in_(values) class AggregationManager(object): @@ -557,7 +560,8 @@ class AggregationManager(object): for class_, objects in six.iteritems(object_dict): for aggregate_value in self.generator_registry[class_]: query = aggregate_value.update_query(objects) - session.execute(query) + if query is not None: + session.execute(query) manager = AggregationManager() diff --git a/tests/aggregate/test_with_ondelete_cascade.py b/tests/aggregate/test_with_ondelete_cascade.py new file mode 100644 index 0000000..6070d9b --- /dev/null +++ b/tests/aggregate/test_with_ondelete_cascade.py @@ -0,0 +1,46 @@ +import sqlalchemy as sa +from sqlalchemy_utils.aggregates import aggregated +from tests import TestCase + + +class TestAggregateValueGenerationWithCascadeDelete(TestCase): + dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' + + def create_models(self): + class Thread(self.Base): + __tablename__ = 'thread' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + @aggregated('comments', sa.Column(sa.Integer, default=0)) + def comment_count(self): + return sa.func.count('1') + + comments = sa.orm.relationship( + 'Comment', + passive_deletes=True, + backref='thread' + ) + + class Comment(self.Base): + __tablename__ = 'comment' + id = sa.Column(sa.Integer, primary_key=True) + content = sa.Column(sa.Unicode(255)) + thread_id = sa.Column( + sa.Integer, + sa.ForeignKey('thread.id', ondelete='CASCADE') + ) + + self.Thread = Thread + self.Comment = Comment + + def test_something(self): + thread = self.Thread() + thread.name = u'some article name' + self.session.add(thread) + comment = self.Comment(content=u'Some content', thread=thread) + self.session.add(comment) + self.session.commit() + self.session.expire_all() + self.session.delete(thread) + self.session.commit()