258 lines
6.3 KiB
Python
258 lines
6.3 KiB
Python
"""
|
|
SQLAlchemy-Utils provides a way of automatically calculating aggregate values of related models and saving them to the parent model.
|
|
|
|
This solution is inspired by RoR counter cache and especially counter_culture_.
|
|
|
|
|
|
|
|
.. _counter_culture: https://github.com/magnusvk/counter_culture
|
|
|
|
|
|
Non-atomic implementation:
|
|
|
|
http://stackoverflow.com/questions/13693872/
|
|
|
|
|
|
We should avoid deadlocks:
|
|
|
|
http://mina.naguib.ca/blog/2010/11/22/postgresql-foreign-key-deadlocks.html
|
|
|
|
|
|
Simple aggregates
|
|
-----------------
|
|
|
|
::
|
|
|
|
from sqlalchemy_utils import aggregated_attr
|
|
|
|
|
|
class Thread(Base):
|
|
__tablename__ = 'thread'
|
|
id = sa.Column(sa.Integer, primary_key=True)
|
|
name = sa.Column(sa.Unicode(255))
|
|
|
|
@aggregated_attr('comments')
|
|
def comment_count(self):
|
|
return sa.Column(sa.Integer)
|
|
|
|
comments = sa.orm.relationship(
|
|
'Comment',
|
|
backref='thread'
|
|
)
|
|
|
|
|
|
class Comment(Base):
|
|
__tablename__ = 'comment'
|
|
id = sa.Column(sa.Integer, primary_key=True)
|
|
content = sa.Column(sa.UnicodeText)
|
|
thread_id = sa.Column(sa.Integer, sa.ForeignKey(Thread.id))
|
|
|
|
# Comment.thread is provided by the backref declared on Thread.comments.
|
|
|
|
|
|
|
|
|
|
Custom aggregate expressions
|
|
----------------------------
|
|
|
|
|
|
::
|
|
|
|
|
|
from sqlalchemy_utils import aggregated_attr
|
|
|
|
|
|
class Catalog(Base):
|
|
__tablename__ = 'catalog'
|
|
id = sa.Column(sa.Integer, primary_key=True)
|
|
name = sa.Column(sa.Unicode(255))
|
|
|
|
@aggregated_attr('products')
|
|
def net_worth(self):
|
|
return sa.Column(sa.Integer)
|
|
|
|
@net_worth.expression
|
|
def net_worth(self):
|
|
return sa.func.sum(Product.price)
|
|
|
|
|
|
products = sa.orm.relationship('Product')
|
|
|
|
|
|
class Product(Base):
|
|
__tablename__ = 'product'
|
|
id = sa.Column(sa.Integer, primary_key=True)
|
|
name = sa.Column(sa.Unicode(255))
|
|
price = sa.Column(sa.Numeric)
|
|
monthly_license_price = sa.Column(sa.Numeric)
|
|
|
|
catalog_id = sa.Column(sa.Integer, sa.ForeignKey(Catalog.id))
|
|
|
|
|
|
|
|
::
|
|
|
|
|
|
from decimal import Decimal
|
|
|
|
|
|
catalog = Catalog(
|
|
name=u'My first catalog',
|
|
products=[
|
|
Product(name='Some product', price=Decimal(1000)),
|
|
Product(name='Some other product', price=Decimal(500))
|
|
]
|
|
)
|
|
session.add(catalog)
|
|
session.commit()
|
|
|
|
catalog.net_worth # 1500
|
|
|
|
|
|
"""
|
|
|
|
|
|
from collections import defaultdict
|
|
|
|
import sqlalchemy as sa
|
|
import six
|
|
from sqlalchemy.ext.declarative import declared_attr
|
|
from sqlalchemy.sql.expression import _FunctionGenerator
|
|
|
|
|
|
class AggregatedAttribute(declared_attr):
    """``declared_attr`` variant that also records an aggregate definition.

    Besides evaluating the wrapped function, every descriptor access stores
    the aggregate's expression and relationship path in the owning class's
    ``__aggregates__`` dictionary, keyed by the wrapped function's name.
    """

    def __init__(self, fget, relationship, expr, *arg, **kw):
        """
        :param fget: function wrapped by the descriptor; it is called with
            the owning class and its result is what the descriptor returns
        :param relationship: relationship path stored for this aggregate
        :param expr: aggregate expression stored for this aggregate
        :param arg: extra positional arguments passed to ``declared_attr``
        :param kw: extra keyword arguments passed to ``declared_attr``
        """
        super(AggregatedAttribute, self).__init__(fget, *arg, **kw)
        self.__doc__ = fget.__doc__
        self.relationship = relationship
        self.expr = expr

    def expression(self, expr):
        """Replace the stored aggregate expression with ``expr``.

        Returns ``self``, so when used as a decorator the decorated name
        stays bound to this attribute object.
        """
        self.expr = expr
        return self

    def __get__(desc, self, cls):
        """Evaluate the wrapped function for ``cls`` and register the
        aggregate definition on ``cls``.

        :returns: whatever the wrapped function returns for ``cls``
        """
        column = desc.fget(cls)

        # Create the registry lazily the first time an aggregate is seen.
        try:
            registry = cls.__aggregates__
        except AttributeError:
            registry = cls.__aggregates__ = {}

        registry[desc.fget.__name__] = dict(
            expression=desc.expr,
            relationship=desc.relationship
        )
        return column
|
|
|
|
|
|
class AggregatedValue(object):
    """Builds the SQL statements that recompute one aggregated column.

    :param class_: class whose ``__table__`` receives the UPDATE
    :param attr: name of the column that stores the aggregate value
    :param relationships: relationship attributes describing the path to the
        aggregated model. The class targeted by the first item is the
        starting FROM element and the last item is used as the WHERE
        criterion (``AggregationManager`` passes the declared path reversed).
    :param expr: the aggregate expression. May be a ready SQL expression, a
        bare function generator such as ``sa.func.count`` or any other
        callable, which is then called with ``class_``.
    """

    def __init__(self, class_, attr, relationships, expr):
        self.class_ = class_
        self.attr = attr
        self.relationships = relationships

        if isinstance(expr, sa.sql.visitors.Visitable):
            # Already a SQL construct: use it unchanged.
            self.expr = expr
        elif isinstance(expr, _FunctionGenerator):
            # Bare function generator (e.g. ``sa.func.count``): apply it to
            # the literal ``1``.
            self.expr = expr(sa.sql.text('1'))
        else:
            # Any other callable is invoked with the owning class.
            self.expr = expr(class_)

    @property
    def aggregate_query(self):
        """Correlated scalar subquery computing the aggregate expression.

        The FROM element starts at the class targeted by the first
        relationship and is joined, one relationship at a time, to each
        relationship's parent class for every relationship except the last.
        The last relationship is applied as the WHERE criterion.
        """
        from_ = self.relationships[0].mapper.class_
        for relationship in self.relationships[0:-1]:
            property_ = relationship.property
            # Only the initial mapped class carries ``__table__``. After the
            # first pass ``from_`` is the result of ``join()``, so
            # unconditionally reading ``from_.__table__`` broke paths with
            # three or more relationships. Join onto the previous result
            # directly in that case.
            joinable = getattr(from_, '__table__', from_)
            from_ = joinable.join(
                property_.parent.class_,
                property_.primaryjoin
            )

        query = sa.select(
            [self.expr],
            from_obj=[from_]
        )
        query = query.where(self.relationships[-1])

        return query.correlate(self.class_).as_scalar()

    @property
    def update_query(self):
        """UPDATE statement assigning ``aggregate_query`` to ``attr`` on the
        table of ``class_``.
        """
        return self.class_.__table__.update().values(
            {self.attr: self.aggregate_query}
        )
|
|
|
|
|
|
class AggregationManager(object):
    """Collects aggregate definitions and refreshes them after each flush.

    ``generator_registry`` maps the class name of an aggregated model to the
    list of ``AggregatedValue`` objects that must be recomputed when objects
    of that class are present in a flushed session.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Drop every registered aggregate and pending query."""
        self.generator_registry = defaultdict(list)
        self.pending_queries = defaultdict(list)

    def register_listeners(self):
        """Attach the manager to the ``mapper_configured`` mapper event and
        the ``after_flush`` session event.
        """
        sa.event.listen(
            sa.orm.mapper,
            'mapper_configured',
            self.update_generator_registry
        )
        sa.event.listen(
            sa.orm.session.Session,
            'after_flush',
            self.construct_aggregate_queries
        )

    def update_generator_registry(self, mapper, class_):
        """Register the aggregates declared on ``class_``, if any.

        :param mapper: mapper passed by the event (not used here)
        :param class_: class whose ``__aggregates__`` dictionary is read
        """
        if not hasattr(class_, '__aggregates__'):
            return

        for key, value in six.iteritems(class_.__aggregates__):
            relationships = []
            rel_class = class_

            # Walk the dotted relationship path starting from ``class_``;
            # ``rel_class`` ends up being the class at the end of the path.
            for path_name in value['relationship'].split('.'):
                rel = getattr(rel_class, path_name)
                relationships.append(rel)
                rel_class = rel.mapper.class_

            self.generator_registry[rel_class.__name__].append(
                AggregatedValue(
                    class_=class_,
                    attr=key,
                    relationships=list(reversed(relationships)),
                    expr=value['expression']
                )
            )

    def construct_aggregate_queries(self, session, ctx):
        """Execute the update query of every aggregate affected by the
        objects currently held by ``session``.

        Each aggregate's ``update_query`` takes no per-object parameters, so
        the queries of a given class are executed once per call (on the first
        object of that class encountered) rather than once per object.

        :param session: iterable of objects that also provides ``execute``
        :param ctx: flush context passed by the event (not used here)
        """
        handled = set()
        for obj in session:
            class_ = obj.__class__.__name__
            # Membership test (rather than indexing) keeps the defaultdict
            # from growing an empty entry for unrelated classes.
            if class_ in handled or class_ not in self.generator_registry:
                continue
            handled.add(class_)
            for aggregate_value in self.generator_registry[class_]:
                session.execute(aggregate_value.update_query)
|
|
|
|
|
|
# Module-level AggregationManager instance created when this module runs.
manager = AggregationManager()
|
|
# Attach the manager's 'mapper_configured' and 'after_flush' listeners.
manager.register_listeners()
|
|
|
|
|
|
def aggregated_attr(relationship, expression=sa.func.count):
    """Decorator factory that turns a function into an
    ``AggregatedAttribute``.

    :param relationship: relationship path of the aggregate; the manager
        splits it on ``'.'`` to walk the relationships
    :param expression: aggregate expression, ``sa.func.count`` by default
    :returns: a decorator wrapping the decorated function in an
        ``AggregatedAttribute`` bound to ``relationship`` and ``expression``
    """
    def decorate(getter):
        return AggregatedAttribute(getter, relationship, expression)

    return decorate
|