Changed aggregate syntax according to Mike Bayer's feedback

This commit is contained in:
Konsta Vesterinen
2013-11-12 11:49:56 +02:00
parent 19dc87e8ef
commit 59657ec2da
6 changed files with 34 additions and 34 deletions

View File

@@ -1,4 +1,4 @@
from .aggregates import aggregated_attr
from .aggregates import aggregated
from .decorators import generates
from .exceptions import ImproperlyConfigured
from .functions import (
@@ -55,7 +55,7 @@ __version__ = '0.21.0'
__all__ = (
aggregated_attr,
aggregated,
batch_fetch,
coercion_listener,
defer_except,

View File

@@ -281,28 +281,23 @@ class AggregatedAttribute(declared_attr):
self,
fget,
relationship,
expr,
column,
*args,
**kwargs
):
super(AggregatedAttribute, self).__init__(fget, *args, **kwargs)
self.__doc__ = fget.__doc__
self.expr = expr
self.column = column
self.relationship = relationship
def expression(self, expr):
self.expr = expr
return self
def __get__(desc, self, cls):
result = desc.fget(cls)
if not hasattr(cls, '__aggregates__'):
cls.__aggregates__ = {}
cls.__aggregates__[desc.fget.__name__] = {
'expression': desc.expr,
'expression': desc.fget,
'relationship': desc.relationship
}
return result
return desc.column
class AggregatedValue(object):
@@ -430,7 +425,7 @@ class AggregationManager(object):
class_=class_,
attr=key,
relationships=list(reversed(relationships)),
expr=value['expression']
expr=value['expression'](class_)
)
)
@@ -451,14 +446,14 @@ manager = AggregationManager()
manager.register_listeners()
def aggregated_attr(
def aggregated(
relationship,
expression=sa.func.count
column
):
def wraps(func):
return AggregatedAttribute(
func,
relationship,
expression
column
)
return wraps

View File

@@ -1,6 +1,6 @@
from decimal import Decimal
import sqlalchemy as sa
from sqlalchemy_utils.aggregates import aggregated_attr
from sqlalchemy_utils.aggregates import aggregated
from tests import TestCase
@@ -13,12 +13,8 @@ class TestLazyEvaluatedSelectExpressionsForAggregates(TestCase):
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
@aggregated_attr('products')
@aggregated('products', sa.Column(sa.Numeric, default=0))
def net_worth(self):
return sa.Column(sa.Numeric, default=0)
@net_worth.expression
def net_worth_expr(self):
return sa.func.sum(Product.price)
products = sa.orm.relationship('Product', backref='catalog')

View File

@@ -1,6 +1,6 @@
from decimal import Decimal
import sqlalchemy as sa
from sqlalchemy_utils.aggregates import aggregated_attr
from sqlalchemy_utils.aggregates import aggregated
from tests import TestCase
@@ -13,9 +13,12 @@ class TestDeepModelPathsForAggregates(TestCase):
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
@aggregated_attr('categories.products')
@aggregated(
'categories.products',
sa.Column(sa.Integer, default=0)
)
def product_count(self):
return sa.Column(sa.Integer, default=0)
return sa.func.count('1')
categories = sa.orm.relationship('Category', backref='catalog')
@@ -69,9 +72,12 @@ class Test3LevelDeepModelPathsForAggregates(TestCase):
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
@aggregated_attr('categories.sub_categories.products')
@aggregated(
'categories.sub_categories.products',
sa.Column(sa.Integer, default=0)
)
def product_count(self):
return sa.Column(sa.Integer, default=0)
return sa.func.count('1')
categories = sa.orm.relationship('Category', backref='catalog')

View File

@@ -1,5 +1,5 @@
import sqlalchemy as sa
from sqlalchemy_utils.aggregates import aggregated_attr
from sqlalchemy_utils.aggregates import aggregated
from tests import TestCase
@@ -10,13 +10,16 @@ class TestAggregateValueGenerationForSimpleModelPaths(TestCase):
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
@aggregated_attr('comments')
@aggregated(
'comments',
sa.Column(sa.Integer, default=0)
)
def comment_count(self):
return sa.Column(sa.Integer, default=0)
return sa.func.count('1')
@aggregated_attr('comments', sa.func.max)
@aggregated('comments', sa.Column(sa.Integer))
def last_comment_id(self):
return sa.Column(sa.Integer)
return sa.func.max(Comment.id)
comments = sa.orm.relationship(
'Comment',

View File

@@ -1,5 +1,5 @@
import sqlalchemy as sa
from sqlalchemy_utils.aggregates import aggregated_attr
from sqlalchemy_utils.aggregates import aggregated
from tests import TestCase
@@ -10,9 +10,9 @@ class TestAggregateValueGenerationForSimpleModelPaths(TestCase):
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
@aggregated_attr('comments')
@aggregated('comments', sa.Column(sa.Integer, default=0))
def comment_count(self):
return sa.Column(sa.Integer, default=0)
return sa.func.count('1')
comments = sa.orm.relationship('Comment', backref='thread')