Fix aggregate handling for cascade deleted objects
This commit is contained in:
		@@ -445,8 +445,9 @@ class AggregatedValue(object):
 | 
			
		||||
        )
 | 
			
		||||
        if len(self.relationships) == 1:
 | 
			
		||||
            prop = self.relationships[-1].property
 | 
			
		||||
 | 
			
		||||
            return query.where(self.local_condition(prop, objects))
 | 
			
		||||
            condition = self.local_condition(prop, objects)
 | 
			
		||||
            if condition is not None:
 | 
			
		||||
                return query.where(condition)
 | 
			
		||||
        else:
 | 
			
		||||
            # Builds query such as:
 | 
			
		||||
            #
 | 
			
		||||
@@ -462,20 +463,21 @@ class AggregatedValue(object):
 | 
			
		||||
            remote_pairs = property_.local_remote_pairs
 | 
			
		||||
            local = remote_pairs[0][0]
 | 
			
		||||
            remote = remote_pairs[0][1]
 | 
			
		||||
 | 
			
		||||
            return query.where(
 | 
			
		||||
                local.in_(
 | 
			
		||||
                    sa.select(
 | 
			
		||||
                        [remote],
 | 
			
		||||
                        from_obj=[self.multi_level_aggregate_query_base]
 | 
			
		||||
                    ).where(
 | 
			
		||||
                        self.local_condition(
 | 
			
		||||
                            self.relationships[0].property,
 | 
			
		||||
                            objects
 | 
			
		||||
            condition = self.local_condition(
 | 
			
		||||
                self.relationships[0].property,
 | 
			
		||||
                objects
 | 
			
		||||
            )
 | 
			
		||||
            if condition is not None:
 | 
			
		||||
                return query.where(
 | 
			
		||||
                    local.in_(
 | 
			
		||||
                        sa.select(
 | 
			
		||||
                            [remote],
 | 
			
		||||
                            from_obj=[self.multi_level_aggregate_query_base]
 | 
			
		||||
                        ).where(
 | 
			
		||||
                            condition
 | 
			
		||||
                        )
 | 
			
		||||
                    )
 | 
			
		||||
                )
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def multi_level_aggregate_query_base(self):
 | 
			
		||||
@@ -490,22 +492,23 @@ class AggregatedValue(object):
 | 
			
		||||
        return from_
 | 
			
		||||
 | 
			
		||||
    def local_condition(self, prop, objects):
 | 
			
		||||
        pairs = prop.local_remote_pairs
 | 
			
		||||
        if prop.secondary is not None:
 | 
			
		||||
            return prop.local_remote_pairs[1][0].in_(
 | 
			
		||||
                getattr(
 | 
			
		||||
                    obj,
 | 
			
		||||
                    prop.local_remote_pairs[1][0].key
 | 
			
		||||
                )
 | 
			
		||||
                for obj in objects
 | 
			
		||||
            )
 | 
			
		||||
            column = pairs[1][0]
 | 
			
		||||
            key = pairs[1][0].key
 | 
			
		||||
        else:
 | 
			
		||||
            return prop.local_remote_pairs[0][0].in_(
 | 
			
		||||
                getattr(
 | 
			
		||||
                    obj,
 | 
			
		||||
                    prop.local_remote_pairs[0][1].key
 | 
			
		||||
                )
 | 
			
		||||
                for obj in objects
 | 
			
		||||
            )
 | 
			
		||||
            column = pairs[0][0]
 | 
			
		||||
            key = pairs[0][1].key
 | 
			
		||||
 | 
			
		||||
        values = []
 | 
			
		||||
        for obj in objects:
 | 
			
		||||
            try:
 | 
			
		||||
                values.append(getattr(obj, key))
 | 
			
		||||
            except sa.orm.exc.ObjectDeletedError:
 | 
			
		||||
                pass
 | 
			
		||||
 | 
			
		||||
        if values:
 | 
			
		||||
            return column.in_(values)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class AggregationManager(object):
 | 
			
		||||
@@ -557,7 +560,8 @@ class AggregationManager(object):
 | 
			
		||||
        for class_, objects in six.iteritems(object_dict):
 | 
			
		||||
            for aggregate_value in self.generator_registry[class_]:
 | 
			
		||||
                query = aggregate_value.update_query(objects)
 | 
			
		||||
                session.execute(query)
 | 
			
		||||
                if query is not None:
 | 
			
		||||
                    session.execute(query)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
manager = AggregationManager()
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										46
									
								
								tests/aggregate/test_with_ondelete_cascade.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										46
									
								
								tests/aggregate/test_with_ondelete_cascade.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,46 @@
 | 
			
		||||
import sqlalchemy as sa
 | 
			
		||||
from sqlalchemy_utils.aggregates import aggregated
 | 
			
		||||
from tests import TestCase
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestAggregateValueGenerationWithCascadeDelete(TestCase):
 | 
			
		||||
    dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
 | 
			
		||||
 | 
			
		||||
    def create_models(self):
 | 
			
		||||
        class Thread(self.Base):
 | 
			
		||||
            __tablename__ = 'thread'
 | 
			
		||||
            id = sa.Column(sa.Integer, primary_key=True)
 | 
			
		||||
            name = sa.Column(sa.Unicode(255))
 | 
			
		||||
 | 
			
		||||
            @aggregated('comments', sa.Column(sa.Integer, default=0))
 | 
			
		||||
            def comment_count(self):
 | 
			
		||||
                return sa.func.count('1')
 | 
			
		||||
 | 
			
		||||
            comments = sa.orm.relationship(
 | 
			
		||||
                'Comment',
 | 
			
		||||
                passive_deletes=True,
 | 
			
		||||
                backref='thread'
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        class Comment(self.Base):
 | 
			
		||||
            __tablename__ = 'comment'
 | 
			
		||||
            id = sa.Column(sa.Integer, primary_key=True)
 | 
			
		||||
            content = sa.Column(sa.Unicode(255))
 | 
			
		||||
            thread_id = sa.Column(
 | 
			
		||||
                sa.Integer,
 | 
			
		||||
                sa.ForeignKey('thread.id', ondelete='CASCADE')
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        self.Thread = Thread
 | 
			
		||||
        self.Comment = Comment
 | 
			
		||||
 | 
			
		||||
    def test_something(self):
 | 
			
		||||
        thread = self.Thread()
 | 
			
		||||
        thread.name = u'some article name'
 | 
			
		||||
        self.session.add(thread)
 | 
			
		||||
        comment = self.Comment(content=u'Some content', thread=thread)
 | 
			
		||||
        self.session.add(comment)
 | 
			
		||||
        self.session.commit()
 | 
			
		||||
        self.session.expire_all()
 | 
			
		||||
        self.session.delete(thread)
 | 
			
		||||
        self.session.commit()
 | 
			
		||||
		Reference in New Issue
	
	Block a user