Fix aggregate handling for cascade deleted objects
This commit is contained in:
@@ -445,8 +445,9 @@ class AggregatedValue(object):
|
||||
)
|
||||
if len(self.relationships) == 1:
|
||||
prop = self.relationships[-1].property
|
||||
|
||||
return query.where(self.local_condition(prop, objects))
|
||||
condition = self.local_condition(prop, objects)
|
||||
if condition is not None:
|
||||
return query.where(condition)
|
||||
else:
|
||||
# Builds query such as:
|
||||
#
|
||||
@@ -462,20 +463,21 @@ class AggregatedValue(object):
|
||||
remote_pairs = property_.local_remote_pairs
|
||||
local = remote_pairs[0][0]
|
||||
remote = remote_pairs[0][1]
|
||||
|
||||
return query.where(
|
||||
local.in_(
|
||||
sa.select(
|
||||
[remote],
|
||||
from_obj=[self.multi_level_aggregate_query_base]
|
||||
).where(
|
||||
self.local_condition(
|
||||
self.relationships[0].property,
|
||||
objects
|
||||
condition = self.local_condition(
|
||||
self.relationships[0].property,
|
||||
objects
|
||||
)
|
||||
if condition is not None:
|
||||
return query.where(
|
||||
local.in_(
|
||||
sa.select(
|
||||
[remote],
|
||||
from_obj=[self.multi_level_aggregate_query_base]
|
||||
).where(
|
||||
condition
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
@property
|
||||
def multi_level_aggregate_query_base(self):
|
||||
@@ -490,22 +492,23 @@ class AggregatedValue(object):
|
||||
return from_
|
||||
|
||||
def local_condition(self, prop, objects):
|
||||
pairs = prop.local_remote_pairs
|
||||
if prop.secondary is not None:
|
||||
return prop.local_remote_pairs[1][0].in_(
|
||||
getattr(
|
||||
obj,
|
||||
prop.local_remote_pairs[1][0].key
|
||||
)
|
||||
for obj in objects
|
||||
)
|
||||
column = pairs[1][0]
|
||||
key = pairs[1][0].key
|
||||
else:
|
||||
return prop.local_remote_pairs[0][0].in_(
|
||||
getattr(
|
||||
obj,
|
||||
prop.local_remote_pairs[0][1].key
|
||||
)
|
||||
for obj in objects
|
||||
)
|
||||
column = pairs[0][0]
|
||||
key = pairs[0][1].key
|
||||
|
||||
values = []
|
||||
for obj in objects:
|
||||
try:
|
||||
values.append(getattr(obj, key))
|
||||
except sa.orm.exc.ObjectDeletedError:
|
||||
pass
|
||||
|
||||
if values:
|
||||
return column.in_(values)
|
||||
|
||||
|
||||
class AggregationManager(object):
|
||||
@@ -557,7 +560,8 @@ class AggregationManager(object):
|
||||
for class_, objects in six.iteritems(object_dict):
|
||||
for aggregate_value in self.generator_registry[class_]:
|
||||
query = aggregate_value.update_query(objects)
|
||||
session.execute(query)
|
||||
if query is not None:
|
||||
session.execute(query)
|
||||
|
||||
|
||||
manager = AggregationManager()
|
||||
|
46
tests/aggregate/test_with_ondelete_cascade.py
Normal file
46
tests/aggregate/test_with_ondelete_cascade.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy_utils.aggregates import aggregated
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
class TestAggregateValueGenerationWithCascadeDelete(TestCase):
|
||||
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
|
||||
|
||||
def create_models(self):
|
||||
class Thread(self.Base):
|
||||
__tablename__ = 'thread'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
@aggregated('comments', sa.Column(sa.Integer, default=0))
|
||||
def comment_count(self):
|
||||
return sa.func.count('1')
|
||||
|
||||
comments = sa.orm.relationship(
|
||||
'Comment',
|
||||
passive_deletes=True,
|
||||
backref='thread'
|
||||
)
|
||||
|
||||
class Comment(self.Base):
|
||||
__tablename__ = 'comment'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
content = sa.Column(sa.Unicode(255))
|
||||
thread_id = sa.Column(
|
||||
sa.Integer,
|
||||
sa.ForeignKey('thread.id', ondelete='CASCADE')
|
||||
)
|
||||
|
||||
self.Thread = Thread
|
||||
self.Comment = Comment
|
||||
|
||||
def test_something(self):
|
||||
thread = self.Thread()
|
||||
thread.name = u'some article name'
|
||||
self.session.add(thread)
|
||||
comment = self.Comment(content=u'Some content', thread=thread)
|
||||
self.session.add(comment)
|
||||
self.session.commit()
|
||||
self.session.expire_all()
|
||||
self.session.delete(thread)
|
||||
self.session.commit()
|
Reference in New Issue
Block a user