From 59657ec2dad37b1d20c66e6648ece711334ac8dd Mon Sep 17 00:00:00 2001 From: Konsta Vesterinen Date: Tue, 12 Nov 2013 11:49:56 +0200 Subject: [PATCH] Changed aggregate syntax according to Mike Bayer's feedback --- sqlalchemy_utils/__init__.py | 4 ++-- sqlalchemy_utils/aggregates.py | 21 +++++++------------ .../test_custom_select_expressions.py | 8 ++----- tests/aggregate/test_deep_paths.py | 16 +++++++++----- .../test_multiple_aggregates_per_class.py | 13 +++++++----- tests/aggregate/test_simple_paths.py | 6 +++--- 6 files changed, 34 insertions(+), 34 deletions(-) diff --git a/sqlalchemy_utils/__init__.py b/sqlalchemy_utils/__init__.py index 6fcc71e..76f1ae4 100644 --- a/sqlalchemy_utils/__init__.py +++ b/sqlalchemy_utils/__init__.py @@ -1,4 +1,4 @@ -from .aggregates import aggregated_attr +from .aggregates import aggregated from .decorators import generates from .exceptions import ImproperlyConfigured from .functions import ( @@ -55,7 +55,7 @@ __version__ = '0.21.0' __all__ = ( - aggregated_attr, + aggregated, batch_fetch, coercion_listener, defer_except, diff --git a/sqlalchemy_utils/aggregates.py b/sqlalchemy_utils/aggregates.py index 1e033be..e3fedc9 100644 --- a/sqlalchemy_utils/aggregates.py +++ b/sqlalchemy_utils/aggregates.py @@ -281,28 +281,23 @@ class AggregatedAttribute(declared_attr): self, fget, relationship, - expr, + column, *args, **kwargs ): super(AggregatedAttribute, self).__init__(fget, *args, **kwargs) self.__doc__ = fget.__doc__ - self.expr = expr + self.column = column self.relationship = relationship - def expression(self, expr): - self.expr = expr - return self - def __get__(desc, self, cls): - result = desc.fget(cls) if not hasattr(cls, '__aggregates__'): cls.__aggregates__ = {} cls.__aggregates__[desc.fget.__name__] = { - 'expression': desc.expr, + 'expression': desc.fget, 'relationship': desc.relationship } - return result + return desc.column class AggregatedValue(object): @@ -430,7 +425,7 @@ class AggregationManager(object): class_=class_, attr=key, relationships=list(reversed(relationships)), - expr=value['expression'] + expr=value['expression'](class_) ) ) @@ -451,14 +446,14 @@ manager = AggregationManager() manager.register_listeners() -def aggregated_attr( +def aggregated( relationship, - expression=sa.func.count + column ): def wraps(func): return AggregatedAttribute( func, relationship, - expression + column ) return wraps diff --git a/tests/aggregate/test_custom_select_expressions.py b/tests/aggregate/test_custom_select_expressions.py index 650b45e..725acca 100644 --- a/tests/aggregate/test_custom_select_expressions.py +++ b/tests/aggregate/test_custom_select_expressions.py @@ -1,6 +1,6 @@ from decimal import Decimal import sqlalchemy as sa -from sqlalchemy_utils.aggregates import aggregated_attr +from sqlalchemy_utils.aggregates import aggregated from tests import TestCase @@ -13,12 +13,8 @@ class TestLazyEvaluatedSelectExpressionsForAggregates(TestCase): id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) - @aggregated_attr('products') + @aggregated('products', sa.Column(sa.Numeric, default=0)) def net_worth(self): - return sa.Column(sa.Numeric, default=0) - - @net_worth.expression - def net_worth_expr(self): return sa.func.sum(Product.price) products = sa.orm.relationship('Product', backref='catalog') diff --git a/tests/aggregate/test_deep_paths.py b/tests/aggregate/test_deep_paths.py index e74d546..ab8bbfb 100644 --- a/tests/aggregate/test_deep_paths.py +++ b/tests/aggregate/test_deep_paths.py @@ -1,6 +1,6 @@ from decimal import Decimal import sqlalchemy as sa -from sqlalchemy_utils.aggregates import aggregated_attr +from sqlalchemy_utils.aggregates import aggregated from tests import TestCase @@ -13,9 +13,12 @@ class TestDeepModelPathsForAggregates(TestCase): id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) - @aggregated_attr('categories.products') + @aggregated( + 'categories.products', + sa.Column(sa.Integer, default=0) + ) def product_count(self): - return sa.Column(sa.Integer, default=0) + return sa.func.count('1') categories = sa.orm.relationship('Category', backref='catalog') @@ -69,9 +72,12 @@ class Test3LevelDeepModelPathsForAggregates(TestCase): id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) - @aggregated_attr('categories.sub_categories.products') + @aggregated( + 'categories.sub_categories.products', + sa.Column(sa.Integer, default=0) + ) def product_count(self): - return sa.Column(sa.Integer, default=0) + return sa.func.count('1') categories = sa.orm.relationship('Category', backref='catalog') diff --git a/tests/aggregate/test_multiple_aggregates_per_class.py b/tests/aggregate/test_multiple_aggregates_per_class.py index b49bc79..aa949b5 100644 --- a/tests/aggregate/test_multiple_aggregates_per_class.py +++ b/tests/aggregate/test_multiple_aggregates_per_class.py @@ -1,5 +1,5 @@ import sqlalchemy as sa -from sqlalchemy_utils.aggregates import aggregated_attr +from sqlalchemy_utils.aggregates import aggregated from tests import TestCase @@ -10,13 +10,16 @@ class TestAggregateValueGenerationForSimpleModelPaths(TestCase): id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) - @aggregated_attr('comments') + @aggregated( + 'comments', + sa.Column(sa.Integer, default=0) + ) def comment_count(self): - return sa.Column(sa.Integer, default=0) + return sa.func.count('1') - @aggregated_attr('comments', sa.func.max) + @aggregated('comments', sa.Column(sa.Integer)) def last_comment_id(self): - return sa.Column(sa.Integer) + return sa.func.max(Comment.id) comments = sa.orm.relationship( 'Comment', diff --git a/tests/aggregate/test_simple_paths.py b/tests/aggregate/test_simple_paths.py index 6435fc8..07104a2 100644 --- a/tests/aggregate/test_simple_paths.py +++ b/tests/aggregate/test_simple_paths.py @@ -1,5 +1,5 @@ import sqlalchemy as sa -from sqlalchemy_utils.aggregates import aggregated_attr +from sqlalchemy_utils.aggregates import aggregated from tests import TestCase @@ -10,9 +10,9 @@ class TestAggregateValueGenerationForSimpleModelPaths(TestCase): id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) - @aggregated_attr('comments') + @aggregated('comments', sa.Column(sa.Integer, default=0)) def comment_count(self): - return sa.Column(sa.Integer, default=0) + return sa.func.count('1') comments = sa.orm.relationship('Comment', backref='thread')