User defined select expression support for aggregates
This commit is contained in:
@@ -3,6 +3,7 @@ from collections import defaultdict
|
||||
import sqlalchemy as sa
|
||||
import six
|
||||
from sqlalchemy.ext.declarative import declared_attr
|
||||
from sqlalchemy.sql.expression import _FunctionGenerator
|
||||
|
||||
|
||||
class aggregated_attr(declared_attr):
|
||||
@@ -10,6 +11,10 @@ class aggregated_attr(declared_attr):
|
||||
super(aggregated_attr, self).__init__(fget, *arg, **kw)
|
||||
self.__doc__ = fget.__doc__
|
||||
|
||||
def select_expression(self, expr):
|
||||
self.__aggregate__['select_expression'] = expr
|
||||
return self
|
||||
|
||||
def __get__(desc, self, cls):
|
||||
result = desc.fget(cls)
|
||||
cls.__aggregates__ = {
|
||||
@@ -18,6 +23,48 @@ class aggregated_attr(declared_attr):
|
||||
return result
|
||||
|
||||
|
||||
class AggregatedValue(object):
|
||||
def __init__(self, class_, attr, relationships, select_expression):
|
||||
self.class_ = class_
|
||||
self.attr = attr
|
||||
self.relationships = relationships
|
||||
|
||||
if isinstance(select_expression, sa.sql.visitors.Visitable):
|
||||
self.select_expression = select_expression
|
||||
elif isinstance(select_expression, _FunctionGenerator):
|
||||
self.select_expression = select_expression(sa.sql.literal('1'))
|
||||
else:
|
||||
self.select_expression = select_expression(class_)
|
||||
|
||||
@property
|
||||
def aggregate_query(self):
|
||||
from_ = self.relationships[0].mapper.class_
|
||||
for relationship in self.relationships[0:-1]:
|
||||
property_ = relationship.property
|
||||
from_ = (
|
||||
from_.__table__
|
||||
.join(
|
||||
property_.parent.class_,
|
||||
property_.primaryjoin
|
||||
)
|
||||
)
|
||||
|
||||
query = sa.select(
|
||||
[self.select_expression],
|
||||
from_obj=[from_]
|
||||
)
|
||||
|
||||
query = query.where(self.relationships[-1])
|
||||
|
||||
return query.correlate(self.class_).as_scalar()
|
||||
|
||||
@property
|
||||
def update_query(self):
|
||||
return self.class_.__table__.update().values(
|
||||
{self.attr: self.aggregate_query}
|
||||
)
|
||||
|
||||
|
||||
class AggregateValueGenerator(object):
|
||||
def __init__(self):
|
||||
self.reset()
|
||||
@@ -26,10 +73,10 @@ class AggregateValueGenerator(object):
|
||||
self.generator_registry = defaultdict(list)
|
||||
self.pending_queries = defaultdict(list)
|
||||
|
||||
def generator_wrapper(self, func, aggregate_func, relationship):
|
||||
def generator_wrapper(self, func, relationship, select_expression):
|
||||
func = aggregated_attr(func)
|
||||
func.__aggregate__ = {
|
||||
'func': aggregate_func,
|
||||
'select_expression': select_expression,
|
||||
'relationship': relationship
|
||||
}
|
||||
return func
|
||||
@@ -47,7 +94,6 @@ class AggregateValueGenerator(object):
|
||||
)
|
||||
|
||||
def update_generator_registry(self, mapper, class_):
|
||||
#self.reset()
|
||||
if hasattr(class_, '__aggregates__'):
|
||||
for key, value in six.iteritems(class_.__aggregates__):
|
||||
relationships = []
|
||||
@@ -58,53 +104,32 @@ class AggregateValueGenerator(object):
|
||||
relationships.append(rel)
|
||||
rel_class = rel.mapper.class_
|
||||
|
||||
self.generator_registry[rel_class.__name__].append({
|
||||
'class': class_,
|
||||
'attr': key,
|
||||
'relationship': list(reversed(relationships)),
|
||||
'aggregate': value['func']
|
||||
})
|
||||
self.generator_registry[rel_class.__name__].append(
|
||||
AggregatedValue(
|
||||
class_=class_,
|
||||
attr=key,
|
||||
relationships=list(reversed(relationships)),
|
||||
select_expression=value['select_expression']
|
||||
)
|
||||
)
|
||||
|
||||
def construct_aggregate_queries(self, session, ctx):
|
||||
for obj in session:
|
||||
class_ = obj.__class__.__name__
|
||||
if class_ in self.generator_registry:
|
||||
for func in self.generator_registry[class_]:
|
||||
if isinstance(func['aggregate'], six.string_types):
|
||||
agg_func = eval(func['aggregate'])
|
||||
else:
|
||||
agg_func = func['aggregate'](obj.__class__.id)
|
||||
|
||||
aggregate_value = (
|
||||
session.query(agg_func)
|
||||
)
|
||||
|
||||
for rel in func['relationship'][0:-1]:
|
||||
aggregate_value = (
|
||||
aggregate_value
|
||||
.join(
|
||||
rel.property.parent.class_,
|
||||
rel.property.primaryjoin
|
||||
)
|
||||
)
|
||||
aggregate_value = aggregate_value.filter(
|
||||
func['relationship'][-1]
|
||||
)
|
||||
|
||||
aggregate_value = (
|
||||
aggregate_value.correlate(func['class']).as_scalar()
|
||||
)
|
||||
query = func['class'].__table__.update().values(
|
||||
{func['attr']: aggregate_value}
|
||||
)
|
||||
session.execute(query)
|
||||
for aggregate_value in self.generator_registry[class_]:
|
||||
session.execute(aggregate_value.update_query)
|
||||
|
||||
|
||||
generator = AggregateValueGenerator()
|
||||
generator.register_listeners()
|
||||
|
||||
|
||||
def aggregate(aggregate_func, relationship, generator=generator):
|
||||
def aggregate(
|
||||
relationship,
|
||||
select_expression=sa.func.count,
|
||||
generator=generator
|
||||
):
|
||||
"""
|
||||
|
||||
Non-atomic implementation:
|
||||
@@ -179,7 +204,7 @@ def aggregate(aggregate_func, relationship, generator=generator):
|
||||
def wraps(func):
|
||||
return generator.generator_wrapper(
|
||||
func,
|
||||
aggregate_func,
|
||||
relationship
|
||||
relationship,
|
||||
select_expression=select_expression
|
||||
)
|
||||
return wraps
|
||||
|
@@ -5,17 +5,23 @@ from tests import TestCase
|
||||
|
||||
|
||||
class TestDeepModelPathsForAggregates(TestCase):
|
||||
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
|
||||
|
||||
def create_models(self):
|
||||
class Catalog(self.Base):
|
||||
__tablename__ = 'catalog'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
@aggregate(sa.func.count, 'categories.products')
|
||||
def product_count(self):
|
||||
return sa.Column(sa.Integer, default=0)
|
||||
@aggregate('products')
|
||||
def net_worth(self):
|
||||
return sa.Column(sa.Numeric, default=0)
|
||||
|
||||
categories = sa.orm.relationship('Product', backref='catalog')
|
||||
@net_worth.select_expression
|
||||
def net_worth(self):
|
||||
return sa.func.sum(Product.price)
|
||||
|
||||
products = sa.orm.relationship('Product', backref='catalog')
|
||||
|
||||
class Product(self.Base):
|
||||
__tablename__ = 'product'
|
||||
@@ -37,8 +43,9 @@ class TestDeepModelPathsForAggregates(TestCase):
|
||||
product = self.Product(
|
||||
name=u'Some product',
|
||||
price=Decimal('1000'),
|
||||
catalog=catalog
|
||||
)
|
||||
self.session.add(product)
|
||||
self.session.commit()
|
||||
self.session.refresh(catalog)
|
||||
assert catalog.product_count == 1
|
||||
assert catalog.net_worth == Decimal('1000')
|
||||
|
@@ -11,7 +11,7 @@ class TestDeepModelPathsForAggregates(TestCase):
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
@aggregate(sa.func.count, 'categories.products')
|
||||
@aggregate('categories.products')
|
||||
def product_count(self):
|
||||
return sa.Column(sa.Integer, default=0)
|
||||
|
||||
|
@@ -10,7 +10,7 @@ class TestAggregateValueGenerationForSimpleModelPaths(TestCase):
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
@aggregate(sa.func.count, 'comments')
|
||||
@aggregate('comments')
|
||||
def comment_count(self):
|
||||
return sa.Column(sa.Integer, default=0)
|
||||
|
||||
|
Reference in New Issue
Block a user