Changed aggregate syntax according to Mike Bayer's feedback
This commit is contained in:
		| @@ -1,4 +1,4 @@ | ||||
| from .aggregates import aggregated_attr | ||||
| from .aggregates import aggregated | ||||
| from .decorators import generates | ||||
| from .exceptions import ImproperlyConfigured | ||||
| from .functions import ( | ||||
| @@ -55,7 +55,7 @@ __version__ = '0.21.0' | ||||
|  | ||||
|  | ||||
| __all__ = ( | ||||
|     aggregated_attr, | ||||
|     aggregated, | ||||
|     batch_fetch, | ||||
|     coercion_listener, | ||||
|     defer_except, | ||||
|   | ||||
| @@ -281,28 +281,23 @@ class AggregatedAttribute(declared_attr): | ||||
|         self, | ||||
|         fget, | ||||
|         relationship, | ||||
|         expr, | ||||
|         column, | ||||
|         *args, | ||||
|         **kwargs | ||||
|     ): | ||||
|         super(AggregatedAttribute, self).__init__(fget, *args, **kwargs) | ||||
|         self.__doc__ = fget.__doc__ | ||||
|         self.expr = expr | ||||
|         self.column = column | ||||
|         self.relationship = relationship | ||||
|  | ||||
|     def expression(self, expr): | ||||
|         self.expr = expr | ||||
|         return self | ||||
|  | ||||
|     def __get__(desc, self, cls): | ||||
|         result = desc.fget(cls) | ||||
|         if not hasattr(cls, '__aggregates__'): | ||||
|             cls.__aggregates__ = {} | ||||
|         cls.__aggregates__[desc.fget.__name__] = { | ||||
|             'expression': desc.expr, | ||||
|             'expression': desc.fget, | ||||
|             'relationship': desc.relationship | ||||
|         } | ||||
|         return result | ||||
|         return desc.column | ||||
|  | ||||
|  | ||||
| class AggregatedValue(object): | ||||
| @@ -430,7 +425,7 @@ class AggregationManager(object): | ||||
|                         class_=class_, | ||||
|                         attr=key, | ||||
|                         relationships=list(reversed(relationships)), | ||||
|                         expr=value['expression'] | ||||
|                         expr=value['expression'](class_) | ||||
|                     ) | ||||
|                 ) | ||||
|  | ||||
| @@ -451,14 +446,14 @@ manager = AggregationManager() | ||||
| manager.register_listeners() | ||||
|  | ||||
|  | ||||
| def aggregated_attr( | ||||
| def aggregated( | ||||
|     relationship, | ||||
|     expression=sa.func.count | ||||
|     column | ||||
| ): | ||||
|     def wraps(func): | ||||
|         return AggregatedAttribute( | ||||
|             func, | ||||
|             relationship, | ||||
|             expression | ||||
|             column | ||||
|         ) | ||||
|     return wraps | ||||
|   | ||||
| @@ -1,6 +1,6 @@ | ||||
| from decimal import Decimal | ||||
| import sqlalchemy as sa | ||||
| from sqlalchemy_utils.aggregates import aggregated_attr | ||||
| from sqlalchemy_utils.aggregates import aggregated | ||||
| from tests import TestCase | ||||
|  | ||||
|  | ||||
| @@ -13,12 +13,8 @@ class TestLazyEvaluatedSelectExpressionsForAggregates(TestCase): | ||||
|             id = sa.Column(sa.Integer, primary_key=True) | ||||
|             name = sa.Column(sa.Unicode(255)) | ||||
|  | ||||
|             @aggregated_attr('products') | ||||
|             @aggregated('products', sa.Column(sa.Numeric, default=0)) | ||||
|             def net_worth(self): | ||||
|                 return sa.Column(sa.Numeric, default=0) | ||||
|  | ||||
|             @net_worth.expression | ||||
|             def net_worth_expr(self): | ||||
|                 return sa.func.sum(Product.price) | ||||
|  | ||||
|             products = sa.orm.relationship('Product', backref='catalog') | ||||
|   | ||||
| @@ -1,6 +1,6 @@ | ||||
| from decimal import Decimal | ||||
| import sqlalchemy as sa | ||||
| from sqlalchemy_utils.aggregates import aggregated_attr | ||||
| from sqlalchemy_utils.aggregates import aggregated | ||||
| from tests import TestCase | ||||
|  | ||||
|  | ||||
| @@ -13,9 +13,12 @@ class TestDeepModelPathsForAggregates(TestCase): | ||||
|             id = sa.Column(sa.Integer, primary_key=True) | ||||
|             name = sa.Column(sa.Unicode(255)) | ||||
|  | ||||
|             @aggregated_attr('categories.products') | ||||
|             @aggregated( | ||||
|                 'categories.products', | ||||
|                 sa.Column(sa.Integer, default=0) | ||||
|             ) | ||||
|             def product_count(self): | ||||
|                 return sa.Column(sa.Integer, default=0) | ||||
|                 return sa.func.count('1') | ||||
|  | ||||
|             categories = sa.orm.relationship('Category', backref='catalog') | ||||
|  | ||||
| @@ -69,9 +72,12 @@ class Test3LevelDeepModelPathsForAggregates(TestCase): | ||||
|             id = sa.Column(sa.Integer, primary_key=True) | ||||
|             name = sa.Column(sa.Unicode(255)) | ||||
|  | ||||
|             @aggregated_attr('categories.sub_categories.products') | ||||
|             @aggregated( | ||||
|                 'categories.sub_categories.products', | ||||
|                 sa.Column(sa.Integer, default=0) | ||||
|             ) | ||||
|             def product_count(self): | ||||
|                 return sa.Column(sa.Integer, default=0) | ||||
|                 return sa.func.count('1') | ||||
|  | ||||
|             categories = sa.orm.relationship('Category', backref='catalog') | ||||
|  | ||||
|   | ||||
| @@ -1,5 +1,5 @@ | ||||
| import sqlalchemy as sa | ||||
| from sqlalchemy_utils.aggregates import aggregated_attr | ||||
| from sqlalchemy_utils.aggregates import aggregated | ||||
| from tests import TestCase | ||||
|  | ||||
|  | ||||
| @@ -10,13 +10,16 @@ class TestAggregateValueGenerationForSimpleModelPaths(TestCase): | ||||
|             id = sa.Column(sa.Integer, primary_key=True) | ||||
|             name = sa.Column(sa.Unicode(255)) | ||||
|  | ||||
|             @aggregated_attr('comments') | ||||
|             @aggregated( | ||||
|                 'comments', | ||||
|                 sa.Column(sa.Integer, default=0) | ||||
|             ) | ||||
|             def comment_count(self): | ||||
|                 return sa.Column(sa.Integer, default=0) | ||||
|                 return sa.func.count('1') | ||||
|  | ||||
|             @aggregated_attr('comments', sa.func.max) | ||||
|             @aggregated('comments', sa.Column(sa.Integer)) | ||||
|             def last_comment_id(self): | ||||
|                 return sa.Column(sa.Integer) | ||||
|                 return sa.func.max(Comment.id) | ||||
|  | ||||
|             comments = sa.orm.relationship( | ||||
|                 'Comment', | ||||
|   | ||||
| @@ -1,5 +1,5 @@ | ||||
| import sqlalchemy as sa | ||||
| from sqlalchemy_utils.aggregates import aggregated_attr | ||||
| from sqlalchemy_utils.aggregates import aggregated | ||||
| from tests import TestCase | ||||
|  | ||||
|  | ||||
| @@ -10,9 +10,9 @@ class TestAggregateValueGenerationForSimpleModelPaths(TestCase): | ||||
|             id = sa.Column(sa.Integer, primary_key=True) | ||||
|             name = sa.Column(sa.Unicode(255)) | ||||
|  | ||||
|             @aggregated_attr('comments') | ||||
|             @aggregated('comments', sa.Column(sa.Integer, default=0)) | ||||
|             def comment_count(self): | ||||
|                 return sa.Column(sa.Integer, default=0) | ||||
|                 return sa.func.count('1') | ||||
|  | ||||
|             comments = sa.orm.relationship('Comment', backref='thread') | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Konsta Vesterinen
					Konsta Vesterinen