Changed aggregate syntax according to Mike Bayer's feedback
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
from .aggregates import aggregated_attr
|
||||
from .aggregates import aggregated
|
||||
from .decorators import generates
|
||||
from .exceptions import ImproperlyConfigured
|
||||
from .functions import (
|
||||
@@ -55,7 +55,7 @@ __version__ = '0.21.0'
|
||||
|
||||
|
||||
__all__ = (
|
||||
aggregated_attr,
|
||||
aggregated,
|
||||
batch_fetch,
|
||||
coercion_listener,
|
||||
defer_except,
|
||||
|
@@ -281,28 +281,23 @@ class AggregatedAttribute(declared_attr):
|
||||
self,
|
||||
fget,
|
||||
relationship,
|
||||
expr,
|
||||
column,
|
||||
*args,
|
||||
**kwargs
|
||||
):
|
||||
super(AggregatedAttribute, self).__init__(fget, *args, **kwargs)
|
||||
self.__doc__ = fget.__doc__
|
||||
self.expr = expr
|
||||
self.column = column
|
||||
self.relationship = relationship
|
||||
|
||||
def expression(self, expr):
|
||||
self.expr = expr
|
||||
return self
|
||||
|
||||
def __get__(desc, self, cls):
|
||||
result = desc.fget(cls)
|
||||
if not hasattr(cls, '__aggregates__'):
|
||||
cls.__aggregates__ = {}
|
||||
cls.__aggregates__[desc.fget.__name__] = {
|
||||
'expression': desc.expr,
|
||||
'expression': desc.fget,
|
||||
'relationship': desc.relationship
|
||||
}
|
||||
return result
|
||||
return desc.column
|
||||
|
||||
|
||||
class AggregatedValue(object):
|
||||
@@ -430,7 +425,7 @@ class AggregationManager(object):
|
||||
class_=class_,
|
||||
attr=key,
|
||||
relationships=list(reversed(relationships)),
|
||||
expr=value['expression']
|
||||
expr=value['expression'](class_)
|
||||
)
|
||||
)
|
||||
|
||||
@@ -451,14 +446,14 @@ manager = AggregationManager()
|
||||
manager.register_listeners()
|
||||
|
||||
|
||||
def aggregated_attr(
|
||||
def aggregated(
|
||||
relationship,
|
||||
expression=sa.func.count
|
||||
column
|
||||
):
|
||||
def wraps(func):
|
||||
return AggregatedAttribute(
|
||||
func,
|
||||
relationship,
|
||||
expression
|
||||
column
|
||||
)
|
||||
return wraps
|
||||
|
@@ -1,6 +1,6 @@
|
||||
from decimal import Decimal
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy_utils.aggregates import aggregated_attr
|
||||
from sqlalchemy_utils.aggregates import aggregated
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
@@ -13,12 +13,8 @@ class TestLazyEvaluatedSelectExpressionsForAggregates(TestCase):
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
@aggregated_attr('products')
|
||||
@aggregated('products', sa.Column(sa.Numeric, default=0))
|
||||
def net_worth(self):
|
||||
return sa.Column(sa.Numeric, default=0)
|
||||
|
||||
@net_worth.expression
|
||||
def net_worth_expr(self):
|
||||
return sa.func.sum(Product.price)
|
||||
|
||||
products = sa.orm.relationship('Product', backref='catalog')
|
||||
|
@@ -1,6 +1,6 @@
|
||||
from decimal import Decimal
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy_utils.aggregates import aggregated_attr
|
||||
from sqlalchemy_utils.aggregates import aggregated
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
@@ -13,9 +13,12 @@ class TestDeepModelPathsForAggregates(TestCase):
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
@aggregated_attr('categories.products')
|
||||
@aggregated(
|
||||
'categories.products',
|
||||
sa.Column(sa.Integer, default=0)
|
||||
)
|
||||
def product_count(self):
|
||||
return sa.Column(sa.Integer, default=0)
|
||||
return sa.func.count('1')
|
||||
|
||||
categories = sa.orm.relationship('Category', backref='catalog')
|
||||
|
||||
@@ -69,9 +72,12 @@ class Test3LevelDeepModelPathsForAggregates(TestCase):
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
@aggregated_attr('categories.sub_categories.products')
|
||||
@aggregated(
|
||||
'categories.sub_categories.products',
|
||||
sa.Column(sa.Integer, default=0)
|
||||
)
|
||||
def product_count(self):
|
||||
return sa.Column(sa.Integer, default=0)
|
||||
return sa.func.count('1')
|
||||
|
||||
categories = sa.orm.relationship('Category', backref='catalog')
|
||||
|
||||
|
@@ -1,5 +1,5 @@
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy_utils.aggregates import aggregated_attr
|
||||
from sqlalchemy_utils.aggregates import aggregated
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
@@ -10,13 +10,16 @@ class TestAggregateValueGenerationForSimpleModelPaths(TestCase):
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
@aggregated_attr('comments')
|
||||
@aggregated(
|
||||
'comments',
|
||||
sa.Column(sa.Integer, default=0)
|
||||
)
|
||||
def comment_count(self):
|
||||
return sa.Column(sa.Integer, default=0)
|
||||
return sa.func.count('1')
|
||||
|
||||
@aggregated_attr('comments', sa.func.max)
|
||||
@aggregated('comments', sa.Column(sa.Integer))
|
||||
def last_comment_id(self):
|
||||
return sa.Column(sa.Integer)
|
||||
return sa.func.max(Comment.id)
|
||||
|
||||
comments = sa.orm.relationship(
|
||||
'Comment',
|
||||
|
@@ -1,5 +1,5 @@
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy_utils.aggregates import aggregated_attr
|
||||
from sqlalchemy_utils.aggregates import aggregated
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
@@ -10,9 +10,9 @@ class TestAggregateValueGenerationForSimpleModelPaths(TestCase):
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
@aggregated_attr('comments')
|
||||
@aggregated('comments', sa.Column(sa.Integer, default=0))
|
||||
def comment_count(self):
|
||||
return sa.Column(sa.Integer, default=0)
|
||||
return sa.func.count('1')
|
||||
|
||||
comments = sa.orm.relationship('Comment', backref='thread')
|
||||
|
||||
|
Reference in New Issue
Block a user