Files
deb-python-sqlalchemy-utils/sqlalchemy_utils/aggregates.py
2013-11-04 16:19:45 +02:00

362 lines
9.0 KiB
Python

"""
SQLAlchemy-Utils provides a way of automatically calculating aggregate values
of related models and saving them to the parent model.
This solution is inspired by RoR counter cache,
`counter_culture`_ and `stackoverflow reply by Michael Bayer`_.
Why?
----
Many times you may have situations where you need to calculate dynamically some
aggregate value for given model. Some simple examples include:
- Number of products in a catalog
- Average rating for movie
- Latest forum post
- Total price of orders for given customer
Now all these aggregates can be elegantly implemented with SQLAlchemy
column_property_ function. However when your data grows calculating these
values on the fly might start to hurt the performance of your application. The
more aggregates you are using the more performance penalty you get.
This module provides a way of calculating these values automatically and
efficiently at the time of modification rather than on the fly.
Features
--------
* Automatically updates aggregate columns when aggregated values change
* Supports aggregate values through an arbitrary number of relationship levels
* Highly optimized: uses single query per transaction per aggregate column
* Aggregated columns can be of any data type and use any selectable scalar expression
.. _column_property: http://docs.sqlalchemy.org/en/latest/orm/mapper_config.html#using-column-property
.. _counter_culture: https://github.com/magnusvk/counter_culture
.. _stackoverflow reply by Michael Bayer:
http://stackoverflow.com/questions/13693872/
Simple aggregates
-----------------
::
from sqlalchemy_utils import aggregated_attr
class Thread(Base):
__tablename__ = 'thread'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
@aggregated_attr('comments')
def comment_count(self):
return sa.Column(sa.Integer)
comments = sa.orm.relationship(
'Comment',
backref='thread'
)
class Comment(Base):
__tablename__ = 'comment'
id = sa.Column(sa.Integer, primary_key=True)
content = sa.Column(sa.UnicodeText)
thread_id = sa.Column(sa.Integer, sa.ForeignKey(Thread.id))
thread = sa.orm.relationship(Thread, backref='comments')
Custom aggregate expressions
----------------------------
::
from sqlalchemy_utils import aggregated_attr
class Catalog(Base):
__tablename__ = 'catalog'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
@aggregated_attr('products')
def net_worth(self):
return sa.Column(sa.Integer)
@net_worth.expression
def net_worth(self):
return sa.func.sum(Product.price)
products = sa.orm.relationship('Product')
class Product(Base):
__tablename__ = 'product'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
price = sa.Column(sa.Numeric)
catalog_id = sa.Column(sa.Integer, sa.ForeignKey(Catalog.id))
Now the net_worth column of the Catalog model will be automatically updated whenever:
* A new product is added to the catalog
* A product is deleted from the catalog
* The price of catalog product is changed
::
from decimal import Decimal
product1 = Product(name='Some product', price=Decimal(1000))
product2 = Product(name='Some other product', price=Decimal(500))
catalog = Catalog(
name=u'My first catalog',
products=[
product1,
product2
]
)
session.add(catalog)
session.commit()
session.refresh(catalog)
catalog.net_worth # 1500
session.delete(product2)
session.commit()
session.refresh(catalog)
catalog.net_worth # 1000
product1.price = 2000
session.commit()
session.refresh(catalog)
catalog.net_worth # 2000
Multi-level aggregates
----------------------
::
from sqlalchemy_utils import aggregated_attr
class Catalog(Base):
__tablename__ = 'catalog'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
@aggregated_attr('categories.products')
def net_worth(self):
return sa.Column(sa.Integer)
@net_worth.expression
def net_worth(self):
return sa.func.sum(Product.price)
categories = sa.orm.relationship('Category')
class Category(Base):
__tablename__ = 'category'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
catalog_id = sa.Column(sa.Integer, sa.ForeignKey(Catalog.id))
products = sa.orm.relationship('Product')
class Product(Base):
__tablename__ = 'product'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
price = sa.Column(sa.Numeric)
category_id = sa.Column(sa.Integer, sa.ForeignKey(Category.id))
TODO
----
* Special consideration should be given to `deadlocks`_.
.. _deadlocks:
http://mina.naguib.ca/blog/2010/11/22/postgresql-foreign-key-deadlocks.html
"""
from collections import defaultdict
import sqlalchemy as sa
import six
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.sql.expression import _FunctionGenerator
class AggregatedAttribute(declared_attr):
    """A ``declared_attr`` that also records aggregate metadata on its class.

    Besides producing whatever ``fget`` returns, accessing the attribute
    stores the aggregate expression and the relationship path in the owning
    class's ``__aggregates__`` dictionary, keyed by the name of ``fget``.
    """

    def __init__(self, fget, relationship, expr, *arg, **kw):
        """
        :param fget: function called with the class to produce the attribute
            value; its ``__name__`` is used as the aggregate's key.
        :param relationship: relationship path stored for this aggregate.
        :param expr: aggregate expression stored for this aggregate.
        :param arg: extra positional arguments forwarded to ``declared_attr``.
        :param kw: extra keyword arguments forwarded to ``declared_attr``.
        """
        super(AggregatedAttribute, self).__init__(fget, *arg, **kw)
        self.relationship = relationship
        self.expr = expr
        self.__doc__ = fget.__doc__

    def expression(self, expr):
        """Replace the stored aggregate expression with ``expr``.

        Returns the attribute itself, so the method can be applied as a
        decorator without losing the attribute binding.
        """
        self.expr = expr
        return self

    def __get__(desc, self, cls):
        """Evaluate ``fget`` for ``cls`` and register the aggregate on it."""
        # fget runs first: if it raises, nothing gets registered.
        column = desc.fget(cls)

        if not hasattr(cls, '__aggregates__'):
            cls.__aggregates__ = {}

        cls.__aggregates__[desc.fget.__name__] = dict(
            expression=desc.expr,
            relationship=desc.relationship,
        )
        return column
class AggregatedValue(object):
    """Builds the SQL that recalculates one aggregate column.

    :param class_: model class that owns the aggregate column.
    :param attr: name of the aggregate column on ``class_``.
    :param relationships: relationship attributes of the path, ordered from
        the aggregated (innermost) model outwards to ``class_``.
    :param expr: the aggregate expression. It may be a ready SQL construct
        (used as is), a ``sa.func`` generator such as ``sa.func.count``
        (called with the literal ``1``), or any other callable (called with
        ``class_``).
    """

    def __init__(self, class_, attr, relationships, expr):
        self.class_ = class_
        self.attr = attr
        self.relationships = relationships

        if isinstance(expr, sa.sql.visitors.Visitable):
            self.expr = expr
        elif isinstance(expr, _FunctionGenerator):
            # e.g. sa.func.count -> count(1)
            self.expr = expr(sa.sql.text('1'))
        else:
            self.expr = expr(class_)

    @property
    def aggregate_query(self):
        """Scalar subquery computing the aggregate, correlated to ``class_``.

        Every relationship except the last contributes a JOIN towards its
        parent class; the last relationship is used as the WHERE criterion
        of the subquery.
        """
        target = self.relationships[0].mapper.class_
        intermediate = self.relationships[:-1]

        if intermediate:
            # Start from the table and keep chaining ``.join()`` on the
            # resulting selectable.  The joined result has no ``__table__``
            # attribute, so looking that attribute up on every iteration
            # would fail for paths with three or more relationships.
            from_ = target.__table__
            for relationship in intermediate:
                property_ = relationship.property
                from_ = from_.join(
                    property_.parent.class_,
                    property_.primaryjoin
                )
        else:
            from_ = target

        query = sa.select(
            [self.expr],
            from_obj=[from_]
        )
        query = query.where(self.relationships[-1])
        return query.correlate(self.class_).as_scalar()

    @property
    def update_query(self):
        """UPDATE statement assigning the aggregate subquery to ``attr``."""
        return self.class_.__table__.update().values(
            {self.attr: self.aggregate_query}
        )
class AggregationManager(object):
    """Tracks aggregate definitions and refreshes them after session flushes.

    ``generator_registry`` maps the class name of an aggregated model to the
    list of ``AggregatedValue`` objects that depend on it.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Drop all registered aggregates and pending queries."""
        self.generator_registry = defaultdict(list)
        self.pending_queries = defaultdict(list)

    def register_listeners(self):
        """Hook the manager into mapper configuration and session flushes."""
        sa.event.listen(
            sa.orm.mapper,
            'mapper_configured',
            self.update_generator_registry
        )
        sa.event.listen(
            sa.orm.session.Session,
            'after_flush',
            self.construct_aggregate_queries
        )

    def update_generator_registry(self, mapper, class_):
        """Register every aggregate declared in ``class_.__aggregates__``.

        Each dotted relationship path is walked starting from ``class_``;
        the resulting ``AggregatedValue`` is stored under the name of the
        class found at the end of the path.

        :param mapper: the configured mapper (unused).
        :param class_: the mapped class that was just configured.
        """
        if not hasattr(class_, '__aggregates__'):
            return

        for key, value in six.iteritems(class_.__aggregates__):
            relationships = []
            rel_class = class_

            for path_name in value['relationship'].split('.'):
                rel = getattr(rel_class, path_name)
                relationships.append(rel)
                rel_class = rel.mapper.class_

            self.generator_registry[rel_class.__name__].append(
                AggregatedValue(
                    class_=class_,
                    attr=key,
                    # AggregatedValue expects the path innermost-first.
                    relationships=list(reversed(relationships)),
                    expr=value['expression']
                )
            )

    def construct_aggregate_queries(self, session, ctx):
        """Execute the update query of each aggregate affected by a flush.

        Every update query recalculates the aggregate for the whole table,
        so it is issued once per class present in the session, however many
        objects of that class the session holds.

        :param session: the session being flushed; it is iterated for its
            objects and used to execute the queries.
        :param ctx: flush context passed by the event (unused).
        """
        seen = set()
        for obj in session:
            class_ = obj.__class__.__name__
            if class_ in seen:
                # Queries for this class were already issued in this flush.
                continue
            seen.add(class_)

            # Membership test first: indexing the defaultdict directly would
            # create an empty entry for every unrelated class.
            if class_ in self.generator_registry:
                for aggregate_value in self.generator_registry[class_]:
                    session.execute(aggregate_value.update_query)
# Module-level manager instance. Its listeners are registered as an
# import-time side effect: `mapper_configured` on `sa.orm.mapper` and
# `after_flush` on `Session`.
manager = AggregationManager()
manager.register_listeners()
def aggregated_attr(
    relationship,
    expression=sa.func.count
):
    """Decorator factory that turns a function into an aggregated attribute.

    :param relationship: relationship path handed to ``AggregatedAttribute``.
    :param expression: aggregate expression handed to
        ``AggregatedAttribute``; defaults to ``sa.func.count``.
    :returns: a decorator wrapping the decorated function in an
        ``AggregatedAttribute``.
    """
    def decorator(func):
        return AggregatedAttribute(
            fget=func,
            relationship=relationship,
            expr=expression,
        )

    return decorator