User defined select expression support for aggregates
This commit is contained in:
@@ -3,6 +3,7 @@ from collections import defaultdict
|
|||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
import six
|
import six
|
||||||
from sqlalchemy.ext.declarative import declared_attr
|
from sqlalchemy.ext.declarative import declared_attr
|
||||||
|
from sqlalchemy.sql.expression import _FunctionGenerator
|
||||||
|
|
||||||
|
|
||||||
class aggregated_attr(declared_attr):
|
class aggregated_attr(declared_attr):
|
||||||
@@ -10,6 +11,10 @@ class aggregated_attr(declared_attr):
|
|||||||
super(aggregated_attr, self).__init__(fget, *arg, **kw)
|
super(aggregated_attr, self).__init__(fget, *arg, **kw)
|
||||||
self.__doc__ = fget.__doc__
|
self.__doc__ = fget.__doc__
|
||||||
|
|
||||||
|
def select_expression(self, expr):
|
||||||
|
self.__aggregate__['select_expression'] = expr
|
||||||
|
return self
|
||||||
|
|
||||||
def __get__(desc, self, cls):
|
def __get__(desc, self, cls):
|
||||||
result = desc.fget(cls)
|
result = desc.fget(cls)
|
||||||
cls.__aggregates__ = {
|
cls.__aggregates__ = {
|
||||||
@@ -18,6 +23,48 @@ class aggregated_attr(declared_attr):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class AggregatedValue(object):
|
||||||
|
def __init__(self, class_, attr, relationships, select_expression):
|
||||||
|
self.class_ = class_
|
||||||
|
self.attr = attr
|
||||||
|
self.relationships = relationships
|
||||||
|
|
||||||
|
if isinstance(select_expression, sa.sql.visitors.Visitable):
|
||||||
|
self.select_expression = select_expression
|
||||||
|
elif isinstance(select_expression, _FunctionGenerator):
|
||||||
|
self.select_expression = select_expression(sa.sql.literal('1'))
|
||||||
|
else:
|
||||||
|
self.select_expression = select_expression(class_)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def aggregate_query(self):
|
||||||
|
from_ = self.relationships[0].mapper.class_
|
||||||
|
for relationship in self.relationships[0:-1]:
|
||||||
|
property_ = relationship.property
|
||||||
|
from_ = (
|
||||||
|
from_.__table__
|
||||||
|
.join(
|
||||||
|
property_.parent.class_,
|
||||||
|
property_.primaryjoin
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
query = sa.select(
|
||||||
|
[self.select_expression],
|
||||||
|
from_obj=[from_]
|
||||||
|
)
|
||||||
|
|
||||||
|
query = query.where(self.relationships[-1])
|
||||||
|
|
||||||
|
return query.correlate(self.class_).as_scalar()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def update_query(self):
|
||||||
|
return self.class_.__table__.update().values(
|
||||||
|
{self.attr: self.aggregate_query}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class AggregateValueGenerator(object):
|
class AggregateValueGenerator(object):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.reset()
|
self.reset()
|
||||||
@@ -26,10 +73,10 @@ class AggregateValueGenerator(object):
|
|||||||
self.generator_registry = defaultdict(list)
|
self.generator_registry = defaultdict(list)
|
||||||
self.pending_queries = defaultdict(list)
|
self.pending_queries = defaultdict(list)
|
||||||
|
|
||||||
def generator_wrapper(self, func, aggregate_func, relationship):
|
def generator_wrapper(self, func, relationship, select_expression):
|
||||||
func = aggregated_attr(func)
|
func = aggregated_attr(func)
|
||||||
func.__aggregate__ = {
|
func.__aggregate__ = {
|
||||||
'func': aggregate_func,
|
'select_expression': select_expression,
|
||||||
'relationship': relationship
|
'relationship': relationship
|
||||||
}
|
}
|
||||||
return func
|
return func
|
||||||
@@ -47,7 +94,6 @@ class AggregateValueGenerator(object):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def update_generator_registry(self, mapper, class_):
|
def update_generator_registry(self, mapper, class_):
|
||||||
#self.reset()
|
|
||||||
if hasattr(class_, '__aggregates__'):
|
if hasattr(class_, '__aggregates__'):
|
||||||
for key, value in six.iteritems(class_.__aggregates__):
|
for key, value in six.iteritems(class_.__aggregates__):
|
||||||
relationships = []
|
relationships = []
|
||||||
@@ -58,53 +104,32 @@ class AggregateValueGenerator(object):
|
|||||||
relationships.append(rel)
|
relationships.append(rel)
|
||||||
rel_class = rel.mapper.class_
|
rel_class = rel.mapper.class_
|
||||||
|
|
||||||
self.generator_registry[rel_class.__name__].append({
|
self.generator_registry[rel_class.__name__].append(
|
||||||
'class': class_,
|
AggregatedValue(
|
||||||
'attr': key,
|
class_=class_,
|
||||||
'relationship': list(reversed(relationships)),
|
attr=key,
|
||||||
'aggregate': value['func']
|
relationships=list(reversed(relationships)),
|
||||||
})
|
select_expression=value['select_expression']
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
def construct_aggregate_queries(self, session, ctx):
|
def construct_aggregate_queries(self, session, ctx):
|
||||||
for obj in session:
|
for obj in session:
|
||||||
class_ = obj.__class__.__name__
|
class_ = obj.__class__.__name__
|
||||||
if class_ in self.generator_registry:
|
if class_ in self.generator_registry:
|
||||||
for func in self.generator_registry[class_]:
|
for aggregate_value in self.generator_registry[class_]:
|
||||||
if isinstance(func['aggregate'], six.string_types):
|
session.execute(aggregate_value.update_query)
|
||||||
agg_func = eval(func['aggregate'])
|
|
||||||
else:
|
|
||||||
agg_func = func['aggregate'](obj.__class__.id)
|
|
||||||
|
|
||||||
aggregate_value = (
|
|
||||||
session.query(agg_func)
|
|
||||||
)
|
|
||||||
|
|
||||||
for rel in func['relationship'][0:-1]:
|
|
||||||
aggregate_value = (
|
|
||||||
aggregate_value
|
|
||||||
.join(
|
|
||||||
rel.property.parent.class_,
|
|
||||||
rel.property.primaryjoin
|
|
||||||
)
|
|
||||||
)
|
|
||||||
aggregate_value = aggregate_value.filter(
|
|
||||||
func['relationship'][-1]
|
|
||||||
)
|
|
||||||
|
|
||||||
aggregate_value = (
|
|
||||||
aggregate_value.correlate(func['class']).as_scalar()
|
|
||||||
)
|
|
||||||
query = func['class'].__table__.update().values(
|
|
||||||
{func['attr']: aggregate_value}
|
|
||||||
)
|
|
||||||
session.execute(query)
|
|
||||||
|
|
||||||
|
|
||||||
generator = AggregateValueGenerator()
|
generator = AggregateValueGenerator()
|
||||||
generator.register_listeners()
|
generator.register_listeners()
|
||||||
|
|
||||||
|
|
||||||
def aggregate(aggregate_func, relationship, generator=generator):
|
def aggregate(
|
||||||
|
relationship,
|
||||||
|
select_expression=sa.func.count,
|
||||||
|
generator=generator
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
Non-atomic implementation:
|
Non-atomic implementation:
|
||||||
@@ -179,7 +204,7 @@ def aggregate(aggregate_func, relationship, generator=generator):
|
|||||||
def wraps(func):
|
def wraps(func):
|
||||||
return generator.generator_wrapper(
|
return generator.generator_wrapper(
|
||||||
func,
|
func,
|
||||||
aggregate_func,
|
relationship,
|
||||||
relationship
|
select_expression=select_expression
|
||||||
)
|
)
|
||||||
return wraps
|
return wraps
|
||||||
|
@@ -5,17 +5,23 @@ from tests import TestCase
|
|||||||
|
|
||||||
|
|
||||||
class TestDeepModelPathsForAggregates(TestCase):
|
class TestDeepModelPathsForAggregates(TestCase):
|
||||||
|
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
|
||||||
|
|
||||||
def create_models(self):
|
def create_models(self):
|
||||||
class Catalog(self.Base):
|
class Catalog(self.Base):
|
||||||
__tablename__ = 'catalog'
|
__tablename__ = 'catalog'
|
||||||
id = sa.Column(sa.Integer, primary_key=True)
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
name = sa.Column(sa.Unicode(255))
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
|
||||||
@aggregate(sa.func.count, 'categories.products')
|
@aggregate('products')
|
||||||
def product_count(self):
|
def net_worth(self):
|
||||||
return sa.Column(sa.Integer, default=0)
|
return sa.Column(sa.Numeric, default=0)
|
||||||
|
|
||||||
categories = sa.orm.relationship('Product', backref='catalog')
|
@net_worth.select_expression
|
||||||
|
def net_worth(self):
|
||||||
|
return sa.func.sum(Product.price)
|
||||||
|
|
||||||
|
products = sa.orm.relationship('Product', backref='catalog')
|
||||||
|
|
||||||
class Product(self.Base):
|
class Product(self.Base):
|
||||||
__tablename__ = 'product'
|
__tablename__ = 'product'
|
||||||
@@ -37,8 +43,9 @@ class TestDeepModelPathsForAggregates(TestCase):
|
|||||||
product = self.Product(
|
product = self.Product(
|
||||||
name=u'Some product',
|
name=u'Some product',
|
||||||
price=Decimal('1000'),
|
price=Decimal('1000'),
|
||||||
|
catalog=catalog
|
||||||
)
|
)
|
||||||
self.session.add(product)
|
self.session.add(product)
|
||||||
self.session.commit()
|
self.session.commit()
|
||||||
self.session.refresh(catalog)
|
self.session.refresh(catalog)
|
||||||
assert catalog.product_count == 1
|
assert catalog.net_worth == Decimal('1000')
|
||||||
|
@@ -11,7 +11,7 @@ class TestDeepModelPathsForAggregates(TestCase):
|
|||||||
id = sa.Column(sa.Integer, primary_key=True)
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
name = sa.Column(sa.Unicode(255))
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
|
||||||
@aggregate(sa.func.count, 'categories.products')
|
@aggregate('categories.products')
|
||||||
def product_count(self):
|
def product_count(self):
|
||||||
return sa.Column(sa.Integer, default=0)
|
return sa.Column(sa.Integer, default=0)
|
||||||
|
|
||||||
|
@@ -10,7 +10,7 @@ class TestAggregateValueGenerationForSimpleModelPaths(TestCase):
|
|||||||
id = sa.Column(sa.Integer, primary_key=True)
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
name = sa.Column(sa.Unicode(255))
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
|
||||||
@aggregate(sa.func.count, 'comments')
|
@aggregate('comments')
|
||||||
def comment_count(self):
|
def comment_count(self):
|
||||||
return sa.Column(sa.Integer, default=0)
|
return sa.Column(sa.Integer, default=0)
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user