select_expression -> expression

This commit is contained in:
Konsta Vesterinen
2013-11-04 15:43:51 +02:00
parent 2e7053f9b2
commit 2bc6e632ca
6 changed files with 162 additions and 159 deletions

View File

@@ -156,6 +156,13 @@ UUIDType
Aggregated attributes
---------------------
.. automodule:: sqlalchemy_utils.aggregates
.. autofunction:: aggregated_attr
The generates decorator

View File

@@ -1,3 +1,117 @@
"""
SQLAlchemy-Utils provides way of automatically calculating aggregate values of related models and saving them to parent model.
This solution is inspired by RoR counter cache and especially counter_culture_.
.. _counter_culter:: https://github.com/magnusvk/counter_culture
Non-atomic implementation:
http://stackoverflow.com/questions/13693872/
We should avoid deadlocks:
http://mina.naguib.ca/blog/2010/11/22/postgresql-foreign-key-deadlocks.html
Simple aggregates
-----------------
::
from sqlalchemy_utils import aggregated_attr
class Thread(Base):
__tablename__ = 'thread'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
@aggregated_attr('comments')
def comment_count(self):
return sa.Column(sa.Integer)
comments = sa.orm.relationship(
'Comment',
backref='thread'
)
class Comment(Base):
__tablename__ = 'comment'
id = sa.Column(sa.Integer, primary_key=True)
content = sa.Column(sa.UnicodeText)
thread_id = sa.Column(sa.Integer, sa.ForeignKey(Thread.id))
thread = sa.orm.relationship(Thread, backref='comments')
Custom aggregate expressions
----------------------------
::
from sqlalchemy_utils import aggregated_attr
class Catalog(Base):
__tablename__ = 'catalog'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
@aggregated_attr
def net_worth(self):
return sa.Column(sa.Integer)
@aggregated_attr.expression
def net_worth(self):
return sa.func.sum(Product.price)
products = sa.orm.relationship('Product')
class Product(Base):
__tablename__ = 'product'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
price = sa.Column(sa.Numeric)
monthly_license_price = sa.Column(sa.Numeric)
catalog_id = sa.Column(sa.Integer, sa.ForeignKey(Catalog.id))
::
from decimal import Decimal
catalog = Catalog(
name=u'My first catalog'
products=[
Product(name='Some product', price=Decimal(1000)),
Product(name='Some other product', price=Decimal(500))
]
)
session.add(catalog)
session.commit()
catalog.net_worth # 1500
"""
from collections import defaultdict
import sqlalchemy as sa
@@ -6,35 +120,47 @@ from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.sql.expression import _FunctionGenerator
class aggregated_attr(declared_attr):
def __init__(self, fget, *arg, **kw):
super(aggregated_attr, self).__init__(fget, *arg, **kw)
class AggregatedAttribute(declared_attr):
def __init__(
self,
fget,
relationship,
expr,
*arg,
**kw
):
super(AggregatedAttribute, self).__init__(fget, *arg, **kw)
self.__doc__ = fget.__doc__
self.expr = expr
self.relationship = relationship
def select_expression(self, expr):
self.__aggregate__['select_expression'] = expr
def expression(self, expr):
self.expr = expr
return self
def __get__(desc, self, cls):
result = desc.fget(cls)
cls.__aggregates__ = {
desc.fget.__name__: desc.__aggregate__
if not hasattr(cls, '__aggregates__'):
cls.__aggregates__ = {}
cls.__aggregates__[desc.fget.__name__] = {
'expression': desc.expr,
'relationship': desc.relationship
}
return result
class AggregatedValue(object):
def __init__(self, class_, attr, relationships, select_expression):
def __init__(self, class_, attr, relationships, expr):
self.class_ = class_
self.attr = attr
self.relationships = relationships
if isinstance(select_expression, sa.sql.visitors.Visitable):
self.select_expression = select_expression
elif isinstance(select_expression, _FunctionGenerator):
self.select_expression = select_expression(sa.sql.literal('1'))
if isinstance(expr, sa.sql.visitors.Visitable):
self.expr = expr
elif isinstance(expr, _FunctionGenerator):
self.expr = expr(sa.sql.text('1'))
else:
self.select_expression = select_expression(class_)
self.expr = expr(class_)
@property
def aggregate_query(self):
@@ -50,7 +176,7 @@ class AggregatedValue(object):
)
query = sa.select(
[self.select_expression],
[self.expr],
from_obj=[from_]
)
@@ -65,7 +191,7 @@ class AggregatedValue(object):
)
class AggregateValueGenerator(object):
class AggregationManager(object):
def __init__(self):
self.reset()
@@ -73,14 +199,6 @@ class AggregateValueGenerator(object):
self.generator_registry = defaultdict(list)
self.pending_queries = defaultdict(list)
def generator_wrapper(self, func, relationship, select_expression):
func = aggregated_attr(func)
func.__aggregate__ = {
'select_expression': select_expression,
'relationship': relationship
}
return func
def register_listeners(self):
sa.event.listen(
sa.orm.mapper,
@@ -109,7 +227,7 @@ class AggregateValueGenerator(object):
class_=class_,
attr=key,
relationships=list(reversed(relationships)),
select_expression=value['select_expression']
expr=value['expression']
)
)
@@ -118,93 +236,22 @@ class AggregateValueGenerator(object):
class_ = obj.__class__.__name__
if class_ in self.generator_registry:
for aggregate_value in self.generator_registry[class_]:
session.execute(aggregate_value.update_query)
query = aggregate_value.update_query
session.execute(query)
generator = AggregateValueGenerator()
generator.register_listeners()
manager = AggregationManager()
manager.register_listeners()
def aggregate(
def aggregated_attr(
relationship,
select_expression=sa.func.count,
generator=generator
expression=sa.func.count
):
"""
Non-atomic implementation:
http://stackoverflow.com/questions/13693872/
We should avoid deadlocks:
http://mina.naguib.ca/blog/2010/11/22/postgresql-foreign-key-deadlocks.html
::
class Thread(Base):
__tablename__ = 'thread'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
@aggregate(sa.func.count, 'comments')
def comment_count(self):
return sa.Column(sa.Integer)
@aggregate(sa.func.max, 'comments')
def latest_comment_id(self):
return sa.Column(sa.Integer)
latest_comment = sa.orm.relationship('Comment', viewonly=True)
class Comment(Base):
__tablename__ = 'comment'
id = sa.Column(sa.Integer, primary_key=True)
content = sa.Column(sa.Unicode(255))
thread_id = sa.Column(sa.Integer, sa.ForeignKey(Thread.id))
thread = sa.orm.relationship(Thread, backref='comments')
::
class Catalog(Base):
__tablename__ = 'catalog'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
@aggregate(
sa.func.sum(price) +
sa.func.coalesce(monthly_license_price, 0),
'products'
)
def net_worth(self):
return sa.Column(sa.Integer)
products = sa.orm.relationship('Product')
class Product(Base):
__tablename__ = 'product'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
price = sa.Column(sa.Numeric)
monthly_license_price = sa.Column(sa.Numeric)
catalog_id = sa.Column(sa.Integer, sa.ForeignKey(Catalog.id))
"""
def wraps(func):
return generator.generator_wrapper(
return AggregatedAttribute(
func,
relationship,
select_expression=select_expression
expression
)
return wraps

View File

@@ -39,7 +39,7 @@ class TestCase(object):
self.session = Session()
def teardown_method(self, method):
aggregates.generator.reset()
aggregates.manager.reset()
self.session.close_all()
self.Base.metadata.drop_all(self.connection)
self.connection.close()

View File

@@ -1,51 +0,0 @@
from decimal import Decimal
import sqlalchemy as sa
from sqlalchemy_utils.aggregates import aggregate
from tests import TestCase
class TestDeepModelPathsForAggregates(TestCase):
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
def create_models(self):
class Catalog(self.Base):
__tablename__ = 'catalog'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
@aggregate('products')
def net_worth(self):
return sa.Column(sa.Numeric, default=0)
@net_worth.select_expression
def net_worth(self):
return sa.func.sum(Product.price)
products = sa.orm.relationship('Product', backref='catalog')
class Product(self.Base):
__tablename__ = 'product'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
price = sa.Column(sa.Numeric)
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))
self.Catalog = Catalog
self.Product = Product
def test_assigns_aggregates(self):
catalog = self.Catalog(
name=u'Some catalog'
)
self.session.add(catalog)
self.session.commit()
product = self.Product(
name=u'Some product',
price=Decimal('1000'),
catalog=catalog
)
self.session.add(product)
self.session.commit()
self.session.refresh(catalog)
assert catalog.net_worth == Decimal('1000')

View File

@@ -1,6 +1,6 @@
from decimal import Decimal
import sqlalchemy as sa
from sqlalchemy_utils.aggregates import aggregate
from sqlalchemy_utils.aggregates import aggregated_attr
from tests import TestCase
@@ -11,7 +11,7 @@ class TestDeepModelPathsForAggregates(TestCase):
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
@aggregate('categories.products')
@aggregated_attr('categories.products')
def product_count(self):
return sa.Column(sa.Integer, default=0)

View File

@@ -1,5 +1,5 @@
import sqlalchemy as sa
from sqlalchemy_utils.aggregates import aggregate
from sqlalchemy_utils.aggregates import aggregated_attr
from tests import TestCase
@@ -10,7 +10,7 @@ class TestAggregateValueGenerationForSimpleModelPaths(TestCase):
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
@aggregate('comments')
@aggregated_attr('comments')
def comment_count(self):
return sa.Column(sa.Integer, default=0)