diff --git a/docs/index.rst b/docs/index.rst index 9ee394b..ca30b17 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -156,6 +156,13 @@ UUIDType +Aggregated attributes +--------------------- + +.. automodule:: sqlalchemy_utils.aggregates + +.. autofunction:: aggregated_attr + The generates decorator diff --git a/sqlalchemy_utils/aggregates.py b/sqlalchemy_utils/aggregates.py index aa00d76..6edbf06 100644 --- a/sqlalchemy_utils/aggregates.py +++ b/sqlalchemy_utils/aggregates.py @@ -1,3 +1,117 @@ +""" +SQLAlchemy-Utils provides way of automatically calculating aggregate values of related models and saving them to parent model. + +This solution is inspired by RoR counter cache and especially counter_culture_. + + + +.. _counter_culter:: https://github.com/magnusvk/counter_culture + + +Non-atomic implementation: + +http://stackoverflow.com/questions/13693872/ + + +We should avoid deadlocks: + +http://mina.naguib.ca/blog/2010/11/22/postgresql-foreign-key-deadlocks.html + + +Simple aggregates +----------------- + +:: + + from sqlalchemy_utils import aggregated_attr + + + class Thread(Base): + __tablename__ = 'thread' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + @aggregated_attr('comments') + def comment_count(self): + return sa.Column(sa.Integer) + + comments = sa.orm.relationship( + 'Comment', + backref='thread' + ) + + + class Comment(Base): + __tablename__ = 'comment' + id = sa.Column(sa.Integer, primary_key=True) + content = sa.Column(sa.UnicodeText) + thread_id = sa.Column(sa.Integer, sa.ForeignKey(Thread.id)) + + thread = sa.orm.relationship(Thread, backref='comments') + + + + +Custom aggregate expressions +---------------------------- + + +:: + + + from sqlalchemy_utils import aggregated_attr + + + class Catalog(Base): + __tablename__ = 'catalog' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + @aggregated_attr + def net_worth(self): + return sa.Column(sa.Integer) + + @aggregated_attr.expression + def net_worth(self): + return sa.func.sum(Product.price) + + + products = sa.orm.relationship('Product') + + + class Product(Base): + __tablename__ = 'product' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + price = sa.Column(sa.Numeric) + monthly_license_price = sa.Column(sa.Numeric) + + catalog_id = sa.Column(sa.Integer, sa.ForeignKey(Catalog.id)) + + + +:: + + + from decimal import Decimal + + + catalog = Catalog( + name=u'My first catalog' + products=[ + Product(name='Some product', price=Decimal(1000)), + Product(name='Some other product', price=Decimal(500)) + ] + ) + session.add(catalog) + session.commit() + + catalog.net_worth # 1500 + + +""" + + from collections import defaultdict import sqlalchemy as sa @@ -6,35 +120,47 @@ from sqlalchemy.ext.declarative import declared_attr from sqlalchemy.sql.expression import _FunctionGenerator -class aggregated_attr(declared_attr): - def __init__(self, fget, *arg, **kw): - super(aggregated_attr, self).__init__(fget, *arg, **kw) +class AggregatedAttribute(declared_attr): + def __init__( + self, + fget, + relationship, + expr, + *arg, + **kw + ): + super(AggregatedAttribute, self).__init__(fget, *arg, **kw) self.__doc__ = fget.__doc__ + self.expr = expr + self.relationship = relationship - def select_expression(self, expr): - self.__aggregate__['select_expression'] = expr + def expression(self, expr): + self.expr = expr return self def __get__(desc, self, cls): result = desc.fget(cls) - cls.__aggregates__ = { - desc.fget.__name__: desc.__aggregate__ + if not hasattr(cls, '__aggregates__'): + cls.__aggregates__ = {} + cls.__aggregates__[desc.fget.__name__] = { + 'expression': desc.expr, + 'relationship': desc.relationship } return result class AggregatedValue(object): - def __init__(self, class_, attr, relationships, select_expression): + def __init__(self, class_, attr, relationships, expr): self.class_ = class_ self.attr = attr self.relationships = relationships - if isinstance(select_expression, sa.sql.visitors.Visitable): - self.select_expression = select_expression - elif isinstance(select_expression, _FunctionGenerator): - self.select_expression = select_expression(sa.sql.literal('1')) + if isinstance(expr, sa.sql.visitors.Visitable): + self.expr = expr + elif isinstance(expr, _FunctionGenerator): + self.expr = expr(sa.sql.text('1')) else: - self.select_expression = select_expression(class_) + self.expr = expr(class_) @property def aggregate_query(self): @@ -50,7 +176,7 @@ class AggregatedValue(object): ) query = sa.select( - [self.select_expression], + [self.expr], from_obj=[from_] ) @@ -65,7 +191,7 @@ class AggregatedValue(object): ) -class AggregateValueGenerator(object): +class AggregationManager(object): def __init__(self): self.reset() @@ -73,14 +199,6 @@ class AggregateValueGenerator(object): self.generator_registry = defaultdict(list) self.pending_queries = defaultdict(list) - def generator_wrapper(self, func, relationship, select_expression): - func = aggregated_attr(func) - func.__aggregate__ = { - 'select_expression': select_expression, - 'relationship': relationship - } - return func - def register_listeners(self): sa.event.listen( sa.orm.mapper, @@ -109,7 +227,7 @@ class AggregateValueGenerator(object): class_=class_, attr=key, relationships=list(reversed(relationships)), - select_expression=value['select_expression'] + expr=value['expression'] ) ) @@ -118,93 +236,22 @@ class AggregateValueGenerator(object): class_ = obj.__class__.__name__ if class_ in self.generator_registry: for aggregate_value in self.generator_registry[class_]: - session.execute(aggregate_value.update_query) + query = aggregate_value.update_query + session.execute(query) -generator = AggregateValueGenerator() -generator.register_listeners() +manager = AggregationManager() +manager.register_listeners() -def aggregate( +def aggregated_attr( relationship, - select_expression=sa.func.count, - generator=generator + expression=sa.func.count ): - """ - - Non-atomic implementation: - - http://stackoverflow.com/questions/13693872/ - - - We should avoid deadlocks: - - http://mina.naguib.ca/blog/2010/11/22/postgresql-foreign-key-deadlocks.html - - - :: - - - class Thread(Base): - __tablename__ = 'thread' - id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column(sa.Unicode(255)) - - @aggregate(sa.func.count, 'comments') - def comment_count(self): - return sa.Column(sa.Integer) - - @aggregate(sa.func.max, 'comments') - def latest_comment_id(self): - return sa.Column(sa.Integer) - - latest_comment = sa.orm.relationship('Comment', viewonly=True) - - - class Comment(Base): - __tablename__ = 'comment' - id = sa.Column(sa.Integer, primary_key=True) - content = sa.Column(sa.Unicode(255)) - thread_id = sa.Column(sa.Integer, sa.ForeignKey(Thread.id)) - - thread = sa.orm.relationship(Thread, backref='comments') - - - - :: - - - class Catalog(Base): - __tablename__ = 'catalog' - id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column(sa.Unicode(255)) - - @aggregate( - sa.func.sum(price) + - sa.func.coalesce(monthly_license_price, 0), - 'products' - ) - def net_worth(self): - return sa.Column(sa.Integer) - - products = sa.orm.relationship('Product') - - - class Product(Base): - __tablename__ = 'product' - id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column(sa.Unicode(255)) - price = sa.Column(sa.Numeric) - monthly_license_price = sa.Column(sa.Numeric) - - catalog_id = sa.Column(sa.Integer, sa.ForeignKey(Catalog.id)) - - - """ def wraps(func): - return generator.generator_wrapper( + return AggregatedAttribute( func, relationship, - select_expression=select_expression + expression ) return wraps diff --git a/tests/__init__.py b/tests/__init__.py index 719ed5b..32b7e3b 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -39,7 +39,7 @@ class TestCase(object): self.session = Session() def teardown_method(self, method): - aggregates.generator.reset() + aggregates.manager.reset() self.session.close_all() self.Base.metadata.drop_all(self.connection) self.connection.close() diff --git a/tests/aggregate/test_aggregate_combinations.py b/tests/aggregate/test_aggregate_combinations.py deleted file mode 100644 index 305fa18..0000000 --- a/tests/aggregate/test_aggregate_combinations.py +++ /dev/null @@ -1,51 +0,0 @@ -from decimal import Decimal -import sqlalchemy as sa -from sqlalchemy_utils.aggregates import aggregate -from tests import TestCase - - -class TestDeepModelPathsForAggregates(TestCase): - dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' - - def create_models(self): - class Catalog(self.Base): - __tablename__ = 'catalog' - id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column(sa.Unicode(255)) - - @aggregate('products') - def net_worth(self): - return sa.Column(sa.Numeric, default=0) - - @net_worth.select_expression - def net_worth(self): - return sa.func.sum(Product.price) - - products = sa.orm.relationship('Product', backref='catalog') - - class Product(self.Base): - __tablename__ = 'product' - id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column(sa.Unicode(255)) - price = sa.Column(sa.Numeric) - - catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id')) - - self.Catalog = Catalog - self.Product = Product - - def test_assigns_aggregates(self): - catalog = self.Catalog( - name=u'Some catalog' - ) - self.session.add(catalog) - self.session.commit() - product = self.Product( - name=u'Some product', - price=Decimal('1000'), - catalog=catalog - ) - self.session.add(product) - self.session.commit() - self.session.refresh(catalog) - assert catalog.net_worth == Decimal('1000') diff --git a/tests/aggregate/test_deep_paths.py b/tests/aggregate/test_deep_paths.py index c0366de..1cecd1d 100644 --- a/tests/aggregate/test_deep_paths.py +++ b/tests/aggregate/test_deep_paths.py @@ -1,6 +1,6 @@ from decimal import Decimal import sqlalchemy as sa -from sqlalchemy_utils.aggregates import aggregate +from sqlalchemy_utils.aggregates import aggregated_attr from tests import TestCase @@ -11,7 +11,7 @@ class TestDeepModelPathsForAggregates(TestCase): id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) - @aggregate('categories.products') + @aggregated_attr('categories.products') def product_count(self): return sa.Column(sa.Integer, default=0) diff --git a/tests/aggregate/test_simple_paths.py b/tests/aggregate/test_simple_paths.py index 66b0861..6435fc8 100644 --- a/tests/aggregate/test_simple_paths.py +++ b/tests/aggregate/test_simple_paths.py @@ -1,5 +1,5 @@ import sqlalchemy as sa -from sqlalchemy_utils.aggregates import aggregate +from sqlalchemy_utils.aggregates import aggregated_attr from tests import TestCase @@ -10,7 +10,7 @@ class TestAggregateValueGenerationForSimpleModelPaths(TestCase): id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) - @aggregate('comments') + @aggregated_attr('comments') def comment_count(self): return sa.Column(sa.Integer, default=0)