Changed aggregate syntax according to Mike Bayer's feedback

This commit is contained in:
Konsta Vesterinen
2013-11-12 11:49:56 +02:00
parent 19dc87e8ef
commit 59657ec2da
6 changed files with 34 additions and 34 deletions

View File

@@ -1,4 +1,4 @@
from .aggregates import aggregated_attr from .aggregates import aggregated
from .decorators import generates from .decorators import generates
from .exceptions import ImproperlyConfigured from .exceptions import ImproperlyConfigured
from .functions import ( from .functions import (
@@ -55,7 +55,7 @@ __version__ = '0.21.0'
__all__ = ( __all__ = (
aggregated_attr, aggregated,
batch_fetch, batch_fetch,
coercion_listener, coercion_listener,
defer_except, defer_except,

View File

@@ -281,28 +281,23 @@ class AggregatedAttribute(declared_attr):
self, self,
fget, fget,
relationship, relationship,
expr, column,
*args, *args,
**kwargs **kwargs
): ):
super(AggregatedAttribute, self).__init__(fget, *args, **kwargs) super(AggregatedAttribute, self).__init__(fget, *args, **kwargs)
self.__doc__ = fget.__doc__ self.__doc__ = fget.__doc__
self.expr = expr self.column = column
self.relationship = relationship self.relationship = relationship
def expression(self, expr):
self.expr = expr
return self
def __get__(desc, self, cls): def __get__(desc, self, cls):
result = desc.fget(cls)
if not hasattr(cls, '__aggregates__'): if not hasattr(cls, '__aggregates__'):
cls.__aggregates__ = {} cls.__aggregates__ = {}
cls.__aggregates__[desc.fget.__name__] = { cls.__aggregates__[desc.fget.__name__] = {
'expression': desc.expr, 'expression': desc.fget,
'relationship': desc.relationship 'relationship': desc.relationship
} }
return result return desc.column
class AggregatedValue(object): class AggregatedValue(object):
@@ -430,7 +425,7 @@ class AggregationManager(object):
class_=class_, class_=class_,
attr=key, attr=key,
relationships=list(reversed(relationships)), relationships=list(reversed(relationships)),
expr=value['expression'] expr=value['expression'](class_)
) )
) )
@@ -451,14 +446,14 @@ manager = AggregationManager()
manager.register_listeners() manager.register_listeners()
def aggregated_attr( def aggregated(
relationship, relationship,
expression=sa.func.count column
): ):
def wraps(func): def wraps(func):
return AggregatedAttribute( return AggregatedAttribute(
func, func,
relationship, relationship,
expression column
) )
return wraps return wraps

View File

@@ -1,6 +1,6 @@
from decimal import Decimal from decimal import Decimal
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy_utils.aggregates import aggregated_attr from sqlalchemy_utils.aggregates import aggregated
from tests import TestCase from tests import TestCase
@@ -13,12 +13,8 @@ class TestLazyEvaluatedSelectExpressionsForAggregates(TestCase):
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255)) name = sa.Column(sa.Unicode(255))
@aggregated_attr('products') @aggregated('products', sa.Column(sa.Numeric, default=0))
def net_worth(self): def net_worth(self):
return sa.Column(sa.Numeric, default=0)
@net_worth.expression
def net_worth_expr(self):
return sa.func.sum(Product.price) return sa.func.sum(Product.price)
products = sa.orm.relationship('Product', backref='catalog') products = sa.orm.relationship('Product', backref='catalog')

View File

@@ -1,6 +1,6 @@
from decimal import Decimal from decimal import Decimal
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy_utils.aggregates import aggregated_attr from sqlalchemy_utils.aggregates import aggregated
from tests import TestCase from tests import TestCase
@@ -13,9 +13,12 @@ class TestDeepModelPathsForAggregates(TestCase):
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255)) name = sa.Column(sa.Unicode(255))
@aggregated_attr('categories.products') @aggregated(
'categories.products',
sa.Column(sa.Integer, default=0)
)
def product_count(self): def product_count(self):
return sa.Column(sa.Integer, default=0) return sa.func.count('1')
categories = sa.orm.relationship('Category', backref='catalog') categories = sa.orm.relationship('Category', backref='catalog')
@@ -69,9 +72,12 @@ class Test3LevelDeepModelPathsForAggregates(TestCase):
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255)) name = sa.Column(sa.Unicode(255))
@aggregated_attr('categories.sub_categories.products') @aggregated(
'categories.sub_categories.products',
sa.Column(sa.Integer, default=0)
)
def product_count(self): def product_count(self):
return sa.Column(sa.Integer, default=0) return sa.func.count('1')
categories = sa.orm.relationship('Category', backref='catalog') categories = sa.orm.relationship('Category', backref='catalog')

View File

@@ -1,5 +1,5 @@
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy_utils.aggregates import aggregated_attr from sqlalchemy_utils.aggregates import aggregated
from tests import TestCase from tests import TestCase
@@ -10,13 +10,16 @@ class TestAggregateValueGenerationForSimpleModelPaths(TestCase):
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255)) name = sa.Column(sa.Unicode(255))
@aggregated_attr('comments') @aggregated(
'comments',
sa.Column(sa.Integer, default=0)
)
def comment_count(self): def comment_count(self):
return sa.Column(sa.Integer, default=0) return sa.func.count('1')
@aggregated_attr('comments', sa.func.max) @aggregated('comments', sa.Column(sa.Integer))
def last_comment_id(self): def last_comment_id(self):
return sa.Column(sa.Integer) return sa.func.max(Comment.id)
comments = sa.orm.relationship( comments = sa.orm.relationship(
'Comment', 'Comment',

View File

@@ -1,5 +1,5 @@
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy_utils.aggregates import aggregated_attr from sqlalchemy_utils.aggregates import aggregated
from tests import TestCase from tests import TestCase
@@ -10,9 +10,9 @@ class TestAggregateValueGenerationForSimpleModelPaths(TestCase):
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255)) name = sa.Column(sa.Unicode(255))
@aggregated_attr('comments') @aggregated('comments', sa.Column(sa.Integer, default=0))
def comment_count(self): def comment_count(self):
return sa.Column(sa.Integer, default=0) return sa.func.count('1')
comments = sa.orm.relationship('Comment', backref='thread') comments = sa.orm.relationship('Comment', backref='thread')