Files
deb-python-sqlalchemy-utils/sqlalchemy_utils/aggregates.py
2013-11-11 10:21:16 +02:00

465 lines
12 KiB
Python

"""
SQLAlchemy-Utils provides a way of automatically calculating aggregate values of
related models and saving them to the parent model.
This solution is inspired by RoR counter cache,
`counter_culture`_ and `stackoverflow reply by Michael Bayer`_.
Why?
----
Many times you may have situations where you need to calculate dynamically some
aggregate value for given model. Some simple examples include:
- Number of products in a catalog
- Average rating for movie
- Latest forum post
- Total price of orders for given customer
Now all these aggregates can be elegantly implemented with SQLAlchemy
column_property_ function. However when your data grows calculating these
values on the fly might start to hurt the performance of your application. The
more aggregates you are using the more performance penalty you get.
This module provides a way of calculating these values automatically and
efficiently at the time of modification rather than on the fly.
Features
--------
* Automatically updates aggregate columns when aggregated values change
* Supports aggregate values through an arbitrary number of levels of relations
* Highly optimized: uses single query per transaction per aggregate column
* Aggregated columns can be of any data type and use any selectable scalar expression
.. _column_property: http://docs.sqlalchemy.org/en/latest/orm/mapper_config.html#using-column-property
.. _counter_culture: https://github.com/magnusvk/counter_culture
.. _stackoverflow reply by Michael Bayer:
http://stackoverflow.com/questions/13693872/
Simple aggregates
-----------------
::
from sqlalchemy_utils import aggregated_attr
class Thread(Base):
__tablename__ = 'thread'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
@aggregated_attr('comments')
def comment_count(self):
return sa.Column(sa.Integer)
comments = sa.orm.relationship(
'Comment',
backref='thread'
)
class Comment(Base):
__tablename__ = 'comment'
id = sa.Column(sa.Integer, primary_key=True)
content = sa.Column(sa.UnicodeText)
thread_id = sa.Column(sa.Integer, sa.ForeignKey(Thread.id))
Custom aggregate expressions
----------------------------
::
from sqlalchemy_utils import aggregated_attr
class Catalog(Base):
__tablename__ = 'catalog'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
@aggregated_attr('products')
def net_worth(self):
return sa.Column(sa.Integer)
@net_worth.expression
def net_worth(self):
return sa.func.sum(Product.price)
products = sa.orm.relationship('Product')
class Product(Base):
__tablename__ = 'product'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
price = sa.Column(sa.Numeric)
catalog_id = sa.Column(sa.Integer, sa.ForeignKey(Catalog.id))
Now the net_worth column of Catalog model will be automatically updated whenever:
* A new product is added to the catalog
* A product is deleted from the catalog
* The price of catalog product is changed
::
from decimal import Decimal
product1 = Product(name='Some product', price=Decimal(1000))
product2 = Product(name='Some other product', price=Decimal(500))
catalog = Catalog(
name=u'My first catalog',
products=[
product1,
product2
]
)
session.add(catalog)
session.commit()
session.refresh(catalog)
catalog.net_worth # 1500
session.delete(product2)
session.commit()
session.refresh(catalog)
catalog.net_worth # 1000
product1.price = 2000
session.commit()
session.refresh(catalog)
catalog.net_worth # 2000
Multiple aggregates per class
-----------------------------
Sometimes you may need to define multiple aggregate values for same class. If you need
to define lots of relationships pointing to same class, remember to define the relationships as viewonly when possible.
::
from sqlalchemy_utils import aggregated_attr
class Customer(Base):
__tablename__ = 'customer'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
@aggregated_attr('orders')
def orders_sum(self):
return sa.Column(sa.Integer)
@orders_sum.expression
def orders_sum(self):
return sa.func.sum(Order.price)
@aggregated_attr('invoiced_orders')
def invoiced_orders_sum(self):
return sa.Column(sa.Integer)
@invoiced_orders_sum.expression
def invoiced_orders_sum(self):
return sa.func.sum(Order.price)
orders = sa.orm.relationship('Order')
invoiced_orders = sa.orm.relationship(
'Order',
primaryjoin=
'sa.and_(Order.customer_id == Customer.id, Order.invoiced)',
viewonly=True
)
class Order(Base):
__tablename__ = 'order'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
price = sa.Column(sa.Numeric)
invoiced = sa.Column(sa.Boolean, default=False)
customer_id = sa.Column(sa.Integer, sa.ForeignKey(Customer.id))
Multi-level aggregates
----------------------
::
from sqlalchemy_utils import aggregated_attr
class Catalog(Base):
__tablename__ = 'catalog'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
@aggregated_attr('categories.products')
def net_worth(self):
return sa.Column(sa.Integer)
@net_worth.expression
def net_worth(self):
return sa.func.sum(Product.price)
categories = sa.orm.relationship('Category')
class Category(Base):
__tablename__ = 'category'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
catalog_id = sa.Column(sa.Integer, sa.ForeignKey(Catalog.id))
products = sa.orm.relationship('Product')
class Product(Base):
__tablename__ = 'product'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
price = sa.Column(sa.Numeric)
category_id = sa.Column(sa.Integer, sa.ForeignKey(Category.id))
TODO
----
* Support calculation of many-to-many aggregates
* Special consideration should be given to `deadlocks`_.
.. _deadlocks:
http://mina.naguib.ca/blog/2010/11/22/postgresql-foreign-key-deadlocks.html
"""
from collections import defaultdict
import sqlalchemy as sa
import six
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.sql.expression import _FunctionGenerator
class AggregatedAttribute(declared_attr):
    """A ``declared_attr`` that remembers how its value is aggregated.

    Besides wrapping ``fget`` like a plain ``declared_attr``, the descriptor
    keeps the dotted relationship path and the aggregate expression, and
    publishes both in the owning class' ``__aggregates__`` dictionary each
    time the attribute is resolved through ``__get__``.
    """

    def __init__(self, fget, relationship, expr, *args, **kwargs):
        """Store the wrapped function together with its aggregate settings.

        :param fget: function handed to ``declared_attr``; ``__get__`` calls
            it with the class.
        :param relationship: relationship path stored as-is on the descriptor.
        :param expr: aggregate expression stored as-is; can be replaced later
            through :meth:`expression`.
        """
        super(AggregatedAttribute, self).__init__(fget, *args, **kwargs)
        self.relationship = relationship
        self.expr = expr
        self.__doc__ = fget.__doc__

    def expression(self, expr):
        """Replace the aggregate expression and return this descriptor.

        Returning ``self`` lets the method be used as a decorator without
        losing the descriptor bound to the attribute name.
        """
        self.expr = expr
        return self

    def __get__(desc, self, cls):
        """Call the wrapped function with ``cls`` and register the aggregate.

        The settings are stored in ``cls.__aggregates__`` under the wrapped
        function's name; the dictionary is created when ``hasattr`` does not
        find one on ``cls``.
        """
        column = desc.fget(cls)

        if not hasattr(cls, '__aggregates__'):
            cls.__aggregates__ = {}

        settings = {
            'expression': desc.expr,
            'relationship': desc.relationship
        }
        cls.__aggregates__[desc.fget.__name__] = settings
        return column
class AggregatedValue(object):
    """Builds the SQL that recalculates one aggregated column.

    :param class_: model class owning the aggregated column.
    :param attr: name of the aggregated column; used as the key of the
        ``UPDATE ... SET`` values.
    :param relationships: relationship attributes describing the path between
        the models. ``relationships[-1]`` is used for the correlation with
        ``class_`` and ``relationships[0]`` supplies the table the aggregate
        selects from.
    :param expr: aggregate expression. A SQL construct is used as-is, a bare
        function generator (e.g. ``sa.func.count``) is called with the
        literal SQL text ``1``, and anything else is called with ``class_``
        and must return the expression.
    """

    def __init__(self, class_, attr, relationships, expr):
        self.class_ = class_
        self.attr = attr
        self.relationships = relationships

        if isinstance(expr, sa.sql.visitors.Visitable):
            self.expr = expr
        elif isinstance(expr, _FunctionGenerator):
            self.expr = expr(sa.sql.text('1'))
        else:
            self.expr = expr(class_)

    @property
    def aggregate_query(self):
        """Scalar subquery computing the aggregate for one ``class_`` row.

        Starts from the table of ``relationships[0]``'s target, joins the
        parent class of every relationship except the last one, filters by
        the last relationship and correlates the result with ``class_``.
        """
        from_ = self.relationships[0].mapper.class_.__table__
        for relationship in self.relationships[:-1]:
            property_ = relationship.property
            from_ = from_.join(
                property_.parent.class_,
                property_.primaryjoin
            )

        query = sa.select(
            [self.expr],
            from_obj=[from_]
        )
        query = query.where(self.relationships[-1])
        return query.correlate(self.class_).as_scalar()

    def update_query(self, objects):
        """Return the UPDATE statement refreshing the aggregate column.

        :param objects: iterable of model instances; the attribute named
            after the remote column of the relevant relationship is read from
            each one to decide which ``class_`` rows are updated.
        :returns: an ``UPDATE`` construct on ``class_.__table__`` setting
            ``attr`` to :attr:`aggregate_query`.
        """
        table = self.class_.__table__
        query = table.update().values(
            {self.attr: self.aggregate_query}
        )

        if len(self.relationships) == 1:
            local_column, remote_column = (
                self.relationships[-1].property.local_remote_pairs[0]
            )
            # in_() is documented to take a list of literals, so the values
            # are materialised instead of being passed as a generator.
            remote_values = [
                getattr(obj, remote_column.key) for obj in objects
            ]
            query = query.where(local_column.in_(remote_values))
        else:
            # Builds query such as:
            #
            # UPDATE catalog SET product_count = (aggregate_query)
            # WHERE id IN (
            #     SELECT catalog_id
            #       FROM category
            #       INNER JOIN sub_category
            #           ON category.id = sub_category.category_id
            #       WHERE sub_category.id IN (product_sub_category_ids)
            # )
            property_ = self.relationships[-1].property
            local_column, remote_column = property_.local_remote_pairs[0]

            from_ = property_.mapper.class_.__table__
            for relationship in reversed(self.relationships[1:-1]):
                property_ = relationship.property
                from_ = from_.join(
                    property_.mapper.class_,
                    property_.primaryjoin
                )

            property_ = self.relationships[0].property
            leaf_local, leaf_remote = property_.local_remote_pairs[0]
            # Same reasoning as above: hand in_() a real list.
            leaf_values = [
                getattr(obj, leaf_remote.key) for obj in objects
            ]

            parent_ids = sa.select(
                [remote_column],
                from_obj=[from_]
            ).where(
                leaf_local.in_(leaf_values)
            )
            query = query.where(local_column.in_(parent_ids))
        return query
class AggregationManager(object):
    """Keeps track of aggregate definitions and runs their UPDATE queries.

    ``generator_registry`` maps a class name to the list of
    ``AggregatedValue`` objects that must be refreshed when instances of
    that class are present in a flushed session.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Drop everything registered so far."""
        self.generator_registry = defaultdict(list)
        self.pending_queries = defaultdict(list)

    def register_listeners(self):
        """Subscribe to the mapper and session events this manager needs."""
        sa.event.listen(
            sa.orm.mapper,
            'mapper_configured',
            self.update_generator_registry
        )
        sa.event.listen(
            sa.orm.session.Session,
            'after_flush',
            self.construct_aggregate_queries
        )

    def update_generator_registry(self, mapper, class_):
        """Register every aggregate declared in ``class_.__aggregates__``.

        Each dotted relationship path is walked starting at ``class_``; the
        resulting ``AggregatedValue`` is filed under the name of the class
        reached at the end of the path, with the relationships in reverse
        order. Classes without ``__aggregates__`` are ignored.
        """
        if not hasattr(class_, '__aggregates__'):
            return

        for attr_name, settings in six.iteritems(class_.__aggregates__):
            path = []
            current = class_
            for step in settings['relationship'].split('.'):
                rel = getattr(current, step)
                path.append(rel)
                current = rel.mapper.class_

            aggregated = AggregatedValue(
                class_=class_,
                attr=attr_name,
                relationships=path[::-1],
                expr=settings['expression']
            )
            self.generator_registry[current.__name__].append(aggregated)

    def construct_aggregate_queries(self, session, ctx):
        """Execute the update query of every aggregate affected by a flush.

        Objects found in ``session`` are grouped by class name; only classes
        present in ``generator_registry`` are considered.
        """
        grouped = defaultdict(list)
        for obj in session:
            name = obj.__class__.__name__
            if name not in self.generator_registry:
                continue
            grouped[name].append(obj)

        for name, objects in six.iteritems(grouped):
            for aggregate_value in self.generator_registry[name]:
                session.execute(aggregate_value.update_query(objects))
# Single module-level AggregationManager instance, created at import time.
manager = AggregationManager()
# Attach its handlers to the 'mapper_configured' and 'after_flush' events.
manager.register_listeners()
def aggregated_attr(
    relationship,
    expression=sa.func.count
):
    """Decorator factory producing an ``AggregatedAttribute``.

    :param relationship: relationship path passed on to
        ``AggregatedAttribute``.
    :param expression: aggregate expression passed on to
        ``AggregatedAttribute``; defaults to ``sa.func.count``.
    :returns: a decorator that wraps the decorated function in an
        ``AggregatedAttribute``.
    """
    def decorator(fget):
        return AggregatedAttribute(fget, relationship, expression)

    return decorator