Fix aggregate handling for cascade deleted objects

This commit is contained in:
Konsta Vesterinen
2014-09-09 08:55:35 +03:00
parent aa478a8e20
commit 3e083955d3
2 changed files with 78 additions and 28 deletions

View File

@@ -445,8 +445,9 @@ class AggregatedValue(object):
)
if len(self.relationships) == 1:
prop = self.relationships[-1].property
return query.where(self.local_condition(prop, objects))
condition = self.local_condition(prop, objects)
if condition is not None:
return query.where(condition)
else:
# Builds query such as:
#
@@ -462,20 +463,21 @@ class AggregatedValue(object):
remote_pairs = property_.local_remote_pairs
local = remote_pairs[0][0]
remote = remote_pairs[0][1]
return query.where(
local.in_(
sa.select(
[remote],
from_obj=[self.multi_level_aggregate_query_base]
).where(
self.local_condition(
self.relationships[0].property,
objects
condition = self.local_condition(
self.relationships[0].property,
objects
)
if condition is not None:
return query.where(
local.in_(
sa.select(
[remote],
from_obj=[self.multi_level_aggregate_query_base]
).where(
condition
)
)
)
)
@property
def multi_level_aggregate_query_base(self):
@@ -490,22 +492,23 @@ class AggregatedValue(object):
return from_
def local_condition(self, prop, objects):
pairs = prop.local_remote_pairs
if prop.secondary is not None:
return prop.local_remote_pairs[1][0].in_(
getattr(
obj,
prop.local_remote_pairs[1][0].key
)
for obj in objects
)
column = pairs[1][0]
key = pairs[1][0].key
else:
return prop.local_remote_pairs[0][0].in_(
getattr(
obj,
prop.local_remote_pairs[0][1].key
)
for obj in objects
)
column = pairs[0][0]
key = pairs[0][1].key
values = []
for obj in objects:
try:
values.append(getattr(obj, key))
except sa.orm.exc.ObjectDeletedError:
pass
if values:
return column.in_(values)
class AggregationManager(object):
@@ -557,7 +560,8 @@ class AggregationManager(object):
for class_, objects in six.iteritems(object_dict):
for aggregate_value in self.generator_registry[class_]:
query = aggregate_value.update_query(objects)
session.execute(query)
if query is not None:
session.execute(query)
manager = AggregationManager()

View File

@@ -0,0 +1,46 @@
import sqlalchemy as sa
from sqlalchemy_utils.aggregates import aggregated
from tests import TestCase
class TestAggregateValueGenerationWithCascadeDelete(TestCase):
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
def create_models(self):
class Thread(self.Base):
__tablename__ = 'thread'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
@aggregated('comments', sa.Column(sa.Integer, default=0))
def comment_count(self):
return sa.func.count('1')
comments = sa.orm.relationship(
'Comment',
passive_deletes=True,
backref='thread'
)
class Comment(self.Base):
__tablename__ = 'comment'
id = sa.Column(sa.Integer, primary_key=True)
content = sa.Column(sa.Unicode(255))
thread_id = sa.Column(
sa.Integer,
sa.ForeignKey('thread.id', ondelete='CASCADE')
)
self.Thread = Thread
self.Comment = Comment
def test_something(self):
thread = self.Thread()
thread.name = u'some article name'
self.session.add(thread)
comment = self.Comment(content=u'Some content', thread=thread)
self.session.add(comment)
self.session.commit()
self.session.expire_all()
self.session.delete(thread)
self.session.commit()