From 2e7053f9b22be8499c054771f147bb31c5d7a300 Mon Sep 17 00:00:00 2001 From: Konsta Vesterinen Date: Mon, 4 Nov 2013 11:53:09 +0200 Subject: [PATCH] User defined select expression support for aggregates --- sqlalchemy_utils/aggregates.py | 107 +++++++++++------- .../aggregate/test_aggregate_combinations.py | 17 ++- tests/aggregate/test_deep_paths.py | 2 +- tests/aggregate/test_simple_paths.py | 2 +- 4 files changed, 80 insertions(+), 48 deletions(-) diff --git a/sqlalchemy_utils/aggregates.py b/sqlalchemy_utils/aggregates.py index 4a85dc6..aa00d76 100644 --- a/sqlalchemy_utils/aggregates.py +++ b/sqlalchemy_utils/aggregates.py @@ -3,6 +3,7 @@ from collections import defaultdict import sqlalchemy as sa import six from sqlalchemy.ext.declarative import declared_attr +from sqlalchemy.sql.expression import _FunctionGenerator class aggregated_attr(declared_attr): @@ -10,6 +11,10 @@ class aggregated_attr(declared_attr): super(aggregated_attr, self).__init__(fget, *arg, **kw) self.__doc__ = fget.__doc__ + def select_expression(self, expr): + self.__aggregate__['select_expression'] = expr + return self + def __get__(desc, self, cls): result = desc.fget(cls) cls.__aggregates__ = { @@ -18,6 +23,48 @@ class aggregated_attr(declared_attr): return result +class AggregatedValue(object): + def __init__(self, class_, attr, relationships, select_expression): + self.class_ = class_ + self.attr = attr + self.relationships = relationships + + if isinstance(select_expression, sa.sql.visitors.Visitable): + self.select_expression = select_expression + elif isinstance(select_expression, _FunctionGenerator): + self.select_expression = select_expression(sa.sql.literal('1')) + else: + self.select_expression = select_expression(class_) + + @property + def aggregate_query(self): + from_ = self.relationships[0].mapper.class_ + for relationship in self.relationships[0:-1]: + property_ = relationship.property + from_ = ( + from_.__table__ + .join( + property_.parent.class_, + property_.primaryjoin + ) + ) + + query = sa.select( + [self.select_expression], + from_obj=[from_] + ) + + query = query.where(self.relationships[-1]) + + return query.correlate(self.class_).as_scalar() + + @property + def update_query(self): + return self.class_.__table__.update().values( + {self.attr: self.aggregate_query} + ) + + class AggregateValueGenerator(object): def __init__(self): self.reset() @@ -26,10 +73,10 @@ class AggregateValueGenerator(object): self.generator_registry = defaultdict(list) self.pending_queries = defaultdict(list) - def generator_wrapper(self, func, aggregate_func, relationship): + def generator_wrapper(self, func, relationship, select_expression): func = aggregated_attr(func) func.__aggregate__ = { - 'func': aggregate_func, + 'select_expression': select_expression, 'relationship': relationship } return func @@ -47,7 +94,6 @@ class AggregateValueGenerator(object): ) def update_generator_registry(self, mapper, class_): - #self.reset() if hasattr(class_, '__aggregates__'): for key, value in six.iteritems(class_.__aggregates__): relationships = [] @@ -58,53 +104,32 @@ class AggregateValueGenerator(object): relationships.append(rel) rel_class = rel.mapper.class_ - self.generator_registry[rel_class.__name__].append({ - 'class': class_, - 'attr': key, - 'relationship': list(reversed(relationships)), - 'aggregate': value['func'] - }) + self.generator_registry[rel_class.__name__].append( + AggregatedValue( + class_=class_, + attr=key, + relationships=list(reversed(relationships)), + select_expression=value['select_expression'] + ) + ) def construct_aggregate_queries(self, session, ctx): for obj in session: class_ = obj.__class__.__name__ if class_ in self.generator_registry: - for func in self.generator_registry[class_]: - if isinstance(func['aggregate'], six.string_types): - agg_func = eval(func['aggregate']) - else: - agg_func = func['aggregate'](obj.__class__.id) - - aggregate_value = ( - session.query(agg_func) - ) - - for rel in func['relationship'][0:-1]: - aggregate_value = ( - aggregate_value - .join( - rel.property.parent.class_, - rel.property.primaryjoin - ) - ) - aggregate_value = aggregate_value.filter( - func['relationship'][-1] - ) - - aggregate_value = ( - aggregate_value.correlate(func['class']).as_scalar() - ) - query = func['class'].__table__.update().values( - {func['attr']: aggregate_value} - ) - session.execute(query) + for aggregate_value in self.generator_registry[class_]: + session.execute(aggregate_value.update_query) generator = AggregateValueGenerator() generator.register_listeners() -def aggregate(aggregate_func, relationship, generator=generator): +def aggregate( + relationship, + select_expression=sa.func.count, + generator=generator +): """ Non-atomic implementation: @@ -179,7 +204,7 @@ def aggregate(aggregate_func, relationship, generator=generator): def wraps(func): return generator.generator_wrapper( func, - aggregate_func, - relationship + relationship, + select_expression=select_expression ) return wraps diff --git a/tests/aggregate/test_aggregate_combinations.py b/tests/aggregate/test_aggregate_combinations.py index 5861929..305fa18 100644 --- a/tests/aggregate/test_aggregate_combinations.py +++ b/tests/aggregate/test_aggregate_combinations.py @@ -5,17 +5,23 @@ from tests import TestCase class TestDeepModelPathsForAggregates(TestCase): + dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' + def create_models(self): class Catalog(self.Base): __tablename__ = 'catalog' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) - @aggregate(sa.func.count, 'categories.products') - def product_count(self): - return sa.Column(sa.Integer, default=0) + @aggregate('products') + def net_worth(self): + return sa.Column(sa.Numeric, default=0) - categories = sa.orm.relationship('Product', backref='catalog') + @net_worth.select_expression + def net_worth(self): + return sa.func.sum(Product.price) + + products = sa.orm.relationship('Product', backref='catalog') class Product(self.Base): __tablename__ = 'product' @@ -37,8 +43,9 @@ class TestDeepModelPathsForAggregates(TestCase): product = self.Product( name=u'Some product', price=Decimal('1000'), + catalog=catalog ) self.session.add(product) self.session.commit() self.session.refresh(catalog) - assert catalog.product_count == 1 + assert catalog.net_worth == Decimal('1000') diff --git a/tests/aggregate/test_deep_paths.py b/tests/aggregate/test_deep_paths.py index e16ef1d..c0366de 100644 --- a/tests/aggregate/test_deep_paths.py +++ b/tests/aggregate/test_deep_paths.py @@ -11,7 +11,7 @@ class TestDeepModelPathsForAggregates(TestCase): id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) - @aggregate(sa.func.count, 'categories.products') + @aggregate('categories.products') def product_count(self): return sa.Column(sa.Integer, default=0) diff --git a/tests/aggregate/test_simple_paths.py b/tests/aggregate/test_simple_paths.py index cdd19c8..66b0861 100644 --- a/tests/aggregate/test_simple_paths.py +++ b/tests/aggregate/test_simple_paths.py @@ -10,7 +10,7 @@ class TestAggregateValueGenerationForSimpleModelPaths(TestCase): id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) - @aggregate(sa.func.count, 'comments') + @aggregate('comments') def comment_count(self): return sa.Column(sa.Integer, default=0)