User defined select expression support for aggregates

This commit is contained in:
Konsta Vesterinen
2013-11-04 11:53:09 +02:00
parent e4c9d338dc
commit 2e7053f9b2
4 changed files with 80 additions and 48 deletions

View File

@@ -3,6 +3,7 @@ from collections import defaultdict
import sqlalchemy as sa
import six
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.sql.expression import _FunctionGenerator
class aggregated_attr(declared_attr):
@@ -10,6 +11,10 @@ class aggregated_attr(declared_attr):
super(aggregated_attr, self).__init__(fget, *arg, **kw)
self.__doc__ = fget.__doc__
def select_expression(self, expr):
self.__aggregate__['select_expression'] = expr
return self
def __get__(desc, self, cls):
result = desc.fget(cls)
cls.__aggregates__ = {
@@ -18,6 +23,48 @@ class aggregated_attr(declared_attr):
return result
class AggregatedValue(object):
def __init__(self, class_, attr, relationships, select_expression):
self.class_ = class_
self.attr = attr
self.relationships = relationships
if isinstance(select_expression, sa.sql.visitors.Visitable):
self.select_expression = select_expression
elif isinstance(select_expression, _FunctionGenerator):
self.select_expression = select_expression(sa.sql.literal('1'))
else:
self.select_expression = select_expression(class_)
@property
def aggregate_query(self):
from_ = self.relationships[0].mapper.class_
for relationship in self.relationships[0:-1]:
property_ = relationship.property
from_ = (
from_.__table__
.join(
property_.parent.class_,
property_.primaryjoin
)
)
query = sa.select(
[self.select_expression],
from_obj=[from_]
)
query = query.where(self.relationships[-1])
return query.correlate(self.class_).as_scalar()
@property
def update_query(self):
return self.class_.__table__.update().values(
{self.attr: self.aggregate_query}
)
class AggregateValueGenerator(object):
def __init__(self):
self.reset()
@@ -26,10 +73,10 @@ class AggregateValueGenerator(object):
self.generator_registry = defaultdict(list)
self.pending_queries = defaultdict(list)
def generator_wrapper(self, func, aggregate_func, relationship):
def generator_wrapper(self, func, relationship, select_expression):
func = aggregated_attr(func)
func.__aggregate__ = {
'func': aggregate_func,
'select_expression': select_expression,
'relationship': relationship
}
return func
@@ -47,7 +94,6 @@ class AggregateValueGenerator(object):
)
def update_generator_registry(self, mapper, class_):
#self.reset()
if hasattr(class_, '__aggregates__'):
for key, value in six.iteritems(class_.__aggregates__):
relationships = []
@@ -58,53 +104,32 @@ class AggregateValueGenerator(object):
relationships.append(rel)
rel_class = rel.mapper.class_
self.generator_registry[rel_class.__name__].append({
'class': class_,
'attr': key,
'relationship': list(reversed(relationships)),
'aggregate': value['func']
})
self.generator_registry[rel_class.__name__].append(
AggregatedValue(
class_=class_,
attr=key,
relationships=list(reversed(relationships)),
select_expression=value['select_expression']
)
)
def construct_aggregate_queries(self, session, ctx):
for obj in session:
class_ = obj.__class__.__name__
if class_ in self.generator_registry:
for func in self.generator_registry[class_]:
if isinstance(func['aggregate'], six.string_types):
agg_func = eval(func['aggregate'])
else:
agg_func = func['aggregate'](obj.__class__.id)
aggregate_value = (
session.query(agg_func)
)
for rel in func['relationship'][0:-1]:
aggregate_value = (
aggregate_value
.join(
rel.property.parent.class_,
rel.property.primaryjoin
)
)
aggregate_value = aggregate_value.filter(
func['relationship'][-1]
)
aggregate_value = (
aggregate_value.correlate(func['class']).as_scalar()
)
query = func['class'].__table__.update().values(
{func['attr']: aggregate_value}
)
session.execute(query)
for aggregate_value in self.generator_registry[class_]:
session.execute(aggregate_value.update_query)
generator = AggregateValueGenerator()
generator.register_listeners()
def aggregate(aggregate_func, relationship, generator=generator):
def aggregate(
relationship,
select_expression=sa.func.count,
generator=generator
):
"""
Non-atomic implementation:
@@ -179,7 +204,7 @@ def aggregate(aggregate_func, relationship, generator=generator):
def wraps(func):
return generator.generator_wrapper(
func,
aggregate_func,
relationship
relationship,
select_expression=select_expression
)
return wraps

View File

@@ -5,17 +5,23 @@ from tests import TestCase
class TestDeepModelPathsForAggregates(TestCase):
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
def create_models(self):
class Catalog(self.Base):
__tablename__ = 'catalog'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
@aggregate(sa.func.count, 'categories.products')
def product_count(self):
return sa.Column(sa.Integer, default=0)
@aggregate('products')
def net_worth(self):
return sa.Column(sa.Numeric, default=0)
categories = sa.orm.relationship('Product', backref='catalog')
@net_worth.select_expression
def net_worth(self):
return sa.func.sum(Product.price)
products = sa.orm.relationship('Product', backref='catalog')
class Product(self.Base):
__tablename__ = 'product'
@@ -37,8 +43,9 @@ class TestDeepModelPathsForAggregates(TestCase):
product = self.Product(
name=u'Some product',
price=Decimal('1000'),
catalog=catalog
)
self.session.add(product)
self.session.commit()
self.session.refresh(catalog)
assert catalog.product_count == 1
assert catalog.net_worth == Decimal('1000')

View File

@@ -11,7 +11,7 @@ class TestDeepModelPathsForAggregates(TestCase):
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
@aggregate(sa.func.count, 'categories.products')
@aggregate('categories.products')
def product_count(self):
return sa.Column(sa.Integer, default=0)

View File

@@ -10,7 +10,7 @@ class TestAggregateValueGenerationForSimpleModelPaths(TestCase):
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
@aggregate(sa.func.count, 'comments')
@aggregate('comments')
def comment_count(self):
return sa.Column(sa.Integer, default=0)