From 67bddd0735054761b6910093d11cace49d4e1d4d Mon Sep 17 00:00:00 2001 From: Konsta Vesterinen Date: Mon, 1 Dec 2014 15:28:36 +0200 Subject: [PATCH] Fix column alias handling in aggregated --- sqlalchemy_utils/aggregates.py | 9 ++-- tests/aggregate/test_with_column_alias.py | 58 +++++++++++++++++++++++ 2 files changed, 63 insertions(+), 4 deletions(-) create mode 100644 tests/aggregate/test_with_column_alias.py diff --git a/sqlalchemy_utils/aggregates.py b/sqlalchemy_utils/aggregates.py index 6e7ed2a..ffb3743 100644 --- a/sqlalchemy_utils/aggregates.py +++ b/sqlalchemy_utils/aggregates.py @@ -394,10 +394,11 @@ class AggregatedAttribute(declared_attr): self.relationship = relationship def __get__(desc, self, cls): + value = (desc.fget, desc.relationship, desc.column) if cls not in aggregated_attrs: - aggregated_attrs[cls] = [(desc.fget, desc.relationship)] + aggregated_attrs[cls] = [value] else: - aggregated_attrs[cls].append((desc.fget, desc.relationship)) + aggregated_attrs[cls].append(value) return desc.column @@ -577,7 +578,7 @@ class AggregationManager(object): def update_generator_registry(self): for class_, attrs in six.iteritems(aggregated_attrs): - for expr, relationship in attrs: + for expr, relationship, column in attrs: relationships = [] rel_class = class_ @@ -589,7 +590,7 @@ class AggregationManager(object): self.generator_registry[rel_class].append( AggregatedValue( class_=class_, - attr=expr.__name__, + attr=column, relationships=list(reversed(relationships)), expr=expr(class_) ) diff --git a/tests/aggregate/test_with_column_alias.py b/tests/aggregate/test_with_column_alias.py new file mode 100644 index 0000000..744cbcd --- /dev/null +++ b/tests/aggregate/test_with_column_alias.py @@ -0,0 +1,58 @@ +import sqlalchemy as sa +from sqlalchemy_utils.aggregates import aggregated +from tests import TestCase + + +class TestAggregatedWithColumnAlias(TestCase): + def create_models(self): + class Thread(self.Base): + __tablename__ = 'thread' + id = sa.Column(sa.Integer, primary_key=True) + + @aggregated( + 'comments', + sa.Column('_comment_count', sa.Integer, default=0) + ) + def comment_count(self): + return sa.func.count('1') + + comments = sa.orm.relationship('Comment', backref='thread') + + class Comment(self.Base): + __tablename__ = 'comment' + id = sa.Column(sa.Integer, primary_key=True) + thread_id = sa.Column(sa.Integer, sa.ForeignKey('thread.id')) + + self.Thread = Thread + self.Comment = Comment + + def test_assigns_aggregates_on_insert(self): + thread = self.Thread() + self.session.add(thread) + comment = self.Comment(thread=thread) + self.session.add(comment) + self.session.commit() + self.session.refresh(thread) + assert thread.comment_count == 1 + + def test_assigns_aggregates_on_separate_insert(self): + thread = self.Thread() + self.session.add(thread) + self.session.commit() + comment = self.Comment(thread=thread) + self.session.add(comment) + self.session.commit() + self.session.refresh(thread) + assert thread.comment_count == 1 + + def test_assigns_aggregates_on_delete(self): + thread = self.Thread() + self.session.add(thread) + self.session.commit() + comment = self.Comment(thread=thread) + self.session.add(comment) + self.session.commit() + self.session.delete(comment) + self.session.commit() + self.session.refresh(thread) + assert thread.comment_count == 0