diff --git a/tests/aggregate/test_lazy_select_expressions.py b/tests/aggregate/test_lazy_select_expressions.py new file mode 100644 index 0000000..73b1f47 --- /dev/null +++ b/tests/aggregate/test_lazy_select_expressions.py @@ -0,0 +1,51 @@ +from decimal import Decimal +import sqlalchemy as sa +from sqlalchemy_utils.aggregates import aggregated_attr +from tests import TestCase + + +class TestLazyEvaluatedSelectExpressionsForAggregates(TestCase): + dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' + + def create_models(self): + class Catalog(self.Base): + __tablename__ = 'catalog' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + @aggregated_attr('products') + def net_worth(self): + return sa.Column(sa.Numeric, default=0) + + @net_worth.expression + def net_worth_expr(self): + return sa.func.sum(Product.price) + + products = sa.orm.relationship('Product', backref='catalog') + + class Product(self.Base): + __tablename__ = 'product' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + price = sa.Column(sa.Numeric) + + catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id')) + + self.Catalog = Catalog + self.Product = Product + + def test_assigns_aggregates(self): + catalog = self.Catalog( + name=u'Some catalog' + ) + self.session.add(catalog) + self.session.commit() + product = self.Product( + name=u'Some product', + price=Decimal('1000'), + catalog=catalog + ) + self.session.add(product) + self.session.commit() + self.session.refresh(catalog) + assert catalog.net_worth == Decimal('1000') diff --git a/tests/aggregate/test_multiple_aggregates_per_class.py b/tests/aggregate/test_multiple_aggregates_per_class.py new file mode 100644 index 0000000..b49bc79 --- /dev/null +++ b/tests/aggregate/test_multiple_aggregates_per_class.py @@ -0,0 +1,77 @@ +import sqlalchemy as sa +from sqlalchemy_utils.aggregates import aggregated_attr +from tests import TestCase + + +class TestAggregateValueGenerationForSimpleModelPaths(TestCase): + def create_models(self): + class Thread(self.Base): + __tablename__ = 'thread' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + @aggregated_attr('comments') + def comment_count(self): + return sa.Column(sa.Integer, default=0) + + @aggregated_attr('comments', sa.func.max) + def last_comment_id(self): + return sa.Column(sa.Integer) + + comments = sa.orm.relationship( + 'Comment', + backref='thread' + ) + + Thread.last_comment = sa.orm.relationship( + 'Comment', + primaryjoin='Thread.last_comment_id == Comment.id', + foreign_keys=[Thread.last_comment_id], + viewonly=True + ) + + class Comment(self.Base): + __tablename__ = 'comment' + id = sa.Column(sa.Integer, primary_key=True) + content = sa.Column(sa.Unicode(255)) + thread_id = sa.Column(sa.Integer, sa.ForeignKey('thread.id')) + + self.Thread = Thread + self.Comment = Comment + + def test_assigns_aggregates_on_insert(self): + thread = self.Thread() + thread.name = u'some article name' + self.session.add(thread) + comment = self.Comment(content=u'Some content', thread=thread) + self.session.add(comment) + self.session.commit() + self.session.refresh(thread) + assert thread.comment_count == 1 + assert thread.last_comment_id == comment.id + + def test_assigns_aggregates_on_separate_insert(self): + thread = self.Thread() + thread.name = u'some article name' + self.session.add(thread) + self.session.commit() + comment = self.Comment(content=u'Some content', thread=thread) + self.session.add(comment) + self.session.commit() + self.session.refresh(thread) + assert thread.comment_count == 1 + assert thread.last_comment_id == 1 + + def test_assigns_aggregates_on_delete(self): + thread = self.Thread() + thread.name = u'some article name' + self.session.add(thread) + self.session.commit() + comment = self.Comment(content=u'Some content', thread=thread) + self.session.add(comment) + self.session.commit() + self.session.delete(comment) + self.session.commit() + self.session.refresh(thread) + assert thread.comment_count == 0 + assert thread.last_comment_id is None