Fix aggregate handling for cascade deleted objects
This commit is contained in:
@@ -445,8 +445,9 @@ class AggregatedValue(object):
|
|||||||
)
|
)
|
||||||
if len(self.relationships) == 1:
|
if len(self.relationships) == 1:
|
||||||
prop = self.relationships[-1].property
|
prop = self.relationships[-1].property
|
||||||
|
condition = self.local_condition(prop, objects)
|
||||||
return query.where(self.local_condition(prop, objects))
|
if condition is not None:
|
||||||
|
return query.where(condition)
|
||||||
else:
|
else:
|
||||||
# Builds query such as:
|
# Builds query such as:
|
||||||
#
|
#
|
||||||
@@ -462,17 +463,18 @@ class AggregatedValue(object):
|
|||||||
remote_pairs = property_.local_remote_pairs
|
remote_pairs = property_.local_remote_pairs
|
||||||
local = remote_pairs[0][0]
|
local = remote_pairs[0][0]
|
||||||
remote = remote_pairs[0][1]
|
remote = remote_pairs[0][1]
|
||||||
|
condition = self.local_condition(
|
||||||
|
self.relationships[0].property,
|
||||||
|
objects
|
||||||
|
)
|
||||||
|
if condition is not None:
|
||||||
return query.where(
|
return query.where(
|
||||||
local.in_(
|
local.in_(
|
||||||
sa.select(
|
sa.select(
|
||||||
[remote],
|
[remote],
|
||||||
from_obj=[self.multi_level_aggregate_query_base]
|
from_obj=[self.multi_level_aggregate_query_base]
|
||||||
).where(
|
).where(
|
||||||
self.local_condition(
|
condition
|
||||||
self.relationships[0].property,
|
|
||||||
objects
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -490,22 +492,23 @@ class AggregatedValue(object):
|
|||||||
return from_
|
return from_
|
||||||
|
|
||||||
def local_condition(self, prop, objects):
|
def local_condition(self, prop, objects):
|
||||||
|
pairs = prop.local_remote_pairs
|
||||||
if prop.secondary is not None:
|
if prop.secondary is not None:
|
||||||
return prop.local_remote_pairs[1][0].in_(
|
column = pairs[1][0]
|
||||||
getattr(
|
key = pairs[1][0].key
|
||||||
obj,
|
|
||||||
prop.local_remote_pairs[1][0].key
|
|
||||||
)
|
|
||||||
for obj in objects
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
return prop.local_remote_pairs[0][0].in_(
|
column = pairs[0][0]
|
||||||
getattr(
|
key = pairs[0][1].key
|
||||||
obj,
|
|
||||||
prop.local_remote_pairs[0][1].key
|
values = []
|
||||||
)
|
for obj in objects:
|
||||||
for obj in objects
|
try:
|
||||||
)
|
values.append(getattr(obj, key))
|
||||||
|
except sa.orm.exc.ObjectDeletedError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if values:
|
||||||
|
return column.in_(values)
|
||||||
|
|
||||||
|
|
||||||
class AggregationManager(object):
|
class AggregationManager(object):
|
||||||
@@ -557,6 +560,7 @@ class AggregationManager(object):
|
|||||||
for class_, objects in six.iteritems(object_dict):
|
for class_, objects in six.iteritems(object_dict):
|
||||||
for aggregate_value in self.generator_registry[class_]:
|
for aggregate_value in self.generator_registry[class_]:
|
||||||
query = aggregate_value.update_query(objects)
|
query = aggregate_value.update_query(objects)
|
||||||
|
if query is not None:
|
||||||
session.execute(query)
|
session.execute(query)
|
||||||
|
|
||||||
|
|
||||||
|
46
tests/aggregate/test_with_ondelete_cascade.py
Normal file
46
tests/aggregate/test_with_ondelete_cascade.py
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy_utils.aggregates import aggregated
|
||||||
|
from tests import TestCase
|
||||||
|
|
||||||
|
|
||||||
|
class TestAggregateValueGenerationWithCascadeDelete(TestCase):
|
||||||
|
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
|
||||||
|
|
||||||
|
def create_models(self):
|
||||||
|
class Thread(self.Base):
|
||||||
|
__tablename__ = 'thread'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
|
||||||
|
@aggregated('comments', sa.Column(sa.Integer, default=0))
|
||||||
|
def comment_count(self):
|
||||||
|
return sa.func.count('1')
|
||||||
|
|
||||||
|
comments = sa.orm.relationship(
|
||||||
|
'Comment',
|
||||||
|
passive_deletes=True,
|
||||||
|
backref='thread'
|
||||||
|
)
|
||||||
|
|
||||||
|
class Comment(self.Base):
|
||||||
|
__tablename__ = 'comment'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
content = sa.Column(sa.Unicode(255))
|
||||||
|
thread_id = sa.Column(
|
||||||
|
sa.Integer,
|
||||||
|
sa.ForeignKey('thread.id', ondelete='CASCADE')
|
||||||
|
)
|
||||||
|
|
||||||
|
self.Thread = Thread
|
||||||
|
self.Comment = Comment
|
||||||
|
|
||||||
|
def test_something(self):
|
||||||
|
thread = self.Thread()
|
||||||
|
thread.name = u'some article name'
|
||||||
|
self.session.add(thread)
|
||||||
|
comment = self.Comment(content=u'Some content', thread=thread)
|
||||||
|
self.session.add(comment)
|
||||||
|
self.session.commit()
|
||||||
|
self.session.expire_all()
|
||||||
|
self.session.delete(thread)
|
||||||
|
self.session.commit()
|
Reference in New Issue
Block a user