select_expression -> expression
This commit is contained in:
@@ -156,6 +156,13 @@ UUIDType
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Aggregated attributes
|
||||||
|
---------------------
|
||||||
|
|
||||||
|
.. automodule:: sqlalchemy_utils.aggregates
|
||||||
|
|
||||||
|
.. autofunction:: aggregated_attr
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
The generates decorator
|
The generates decorator
|
||||||
|
@@ -1,3 +1,117 @@
|
|||||||
|
"""
|
||||||
|
SQLAlchemy-Utils provides way of automatically calculating aggregate values of related models and saving them to parent model.
|
||||||
|
|
||||||
|
This solution is inspired by RoR counter cache and especially counter_culture_.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
.. _counter_culter:: https://github.com/magnusvk/counter_culture
|
||||||
|
|
||||||
|
|
||||||
|
Non-atomic implementation:
|
||||||
|
|
||||||
|
http://stackoverflow.com/questions/13693872/
|
||||||
|
|
||||||
|
|
||||||
|
We should avoid deadlocks:
|
||||||
|
|
||||||
|
http://mina.naguib.ca/blog/2010/11/22/postgresql-foreign-key-deadlocks.html
|
||||||
|
|
||||||
|
|
||||||
|
Simple aggregates
|
||||||
|
-----------------
|
||||||
|
|
||||||
|
::
|
||||||
|
|
||||||
|
from sqlalchemy_utils import aggregated_attr
|
||||||
|
|
||||||
|
|
||||||
|
class Thread(Base):
|
||||||
|
__tablename__ = 'thread'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
|
||||||
|
@aggregated_attr('comments')
|
||||||
|
def comment_count(self):
|
||||||
|
return sa.Column(sa.Integer)
|
||||||
|
|
||||||
|
comments = sa.orm.relationship(
|
||||||
|
'Comment',
|
||||||
|
backref='thread'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Comment(Base):
|
||||||
|
__tablename__ = 'comment'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
content = sa.Column(sa.UnicodeText)
|
||||||
|
thread_id = sa.Column(sa.Integer, sa.ForeignKey(Thread.id))
|
||||||
|
|
||||||
|
thread = sa.orm.relationship(Thread, backref='comments')
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Custom aggregate expressions
|
||||||
|
----------------------------
|
||||||
|
|
||||||
|
|
||||||
|
::
|
||||||
|
|
||||||
|
|
||||||
|
from sqlalchemy_utils import aggregated_attr
|
||||||
|
|
||||||
|
|
||||||
|
class Catalog(Base):
|
||||||
|
__tablename__ = 'catalog'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
|
||||||
|
@aggregated_attr
|
||||||
|
def net_worth(self):
|
||||||
|
return sa.Column(sa.Integer)
|
||||||
|
|
||||||
|
@aggregated_attr.expression
|
||||||
|
def net_worth(self):
|
||||||
|
return sa.func.sum(Product.price)
|
||||||
|
|
||||||
|
|
||||||
|
products = sa.orm.relationship('Product')
|
||||||
|
|
||||||
|
|
||||||
|
class Product(Base):
|
||||||
|
__tablename__ = 'product'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
price = sa.Column(sa.Numeric)
|
||||||
|
monthly_license_price = sa.Column(sa.Numeric)
|
||||||
|
|
||||||
|
catalog_id = sa.Column(sa.Integer, sa.ForeignKey(Catalog.id))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
::
|
||||||
|
|
||||||
|
|
||||||
|
from decimal import Decimal
|
||||||
|
|
||||||
|
|
||||||
|
catalog = Catalog(
|
||||||
|
name=u'My first catalog'
|
||||||
|
products=[
|
||||||
|
Product(name='Some product', price=Decimal(1000)),
|
||||||
|
Product(name='Some other product', price=Decimal(500))
|
||||||
|
]
|
||||||
|
)
|
||||||
|
session.add(catalog)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
catalog.net_worth # 1500
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
@@ -6,35 +120,47 @@ from sqlalchemy.ext.declarative import declared_attr
|
|||||||
from sqlalchemy.sql.expression import _FunctionGenerator
|
from sqlalchemy.sql.expression import _FunctionGenerator
|
||||||
|
|
||||||
|
|
||||||
class aggregated_attr(declared_attr):
|
class AggregatedAttribute(declared_attr):
|
||||||
def __init__(self, fget, *arg, **kw):
|
def __init__(
|
||||||
super(aggregated_attr, self).__init__(fget, *arg, **kw)
|
self,
|
||||||
|
fget,
|
||||||
|
relationship,
|
||||||
|
expr,
|
||||||
|
*arg,
|
||||||
|
**kw
|
||||||
|
):
|
||||||
|
super(AggregatedAttribute, self).__init__(fget, *arg, **kw)
|
||||||
self.__doc__ = fget.__doc__
|
self.__doc__ = fget.__doc__
|
||||||
|
self.expr = expr
|
||||||
|
self.relationship = relationship
|
||||||
|
|
||||||
def select_expression(self, expr):
|
def expression(self, expr):
|
||||||
self.__aggregate__['select_expression'] = expr
|
self.expr = expr
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __get__(desc, self, cls):
|
def __get__(desc, self, cls):
|
||||||
result = desc.fget(cls)
|
result = desc.fget(cls)
|
||||||
cls.__aggregates__ = {
|
if not hasattr(cls, '__aggregates__'):
|
||||||
desc.fget.__name__: desc.__aggregate__
|
cls.__aggregates__ = {}
|
||||||
|
cls.__aggregates__[desc.fget.__name__] = {
|
||||||
|
'expression': desc.expr,
|
||||||
|
'relationship': desc.relationship
|
||||||
}
|
}
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
class AggregatedValue(object):
|
class AggregatedValue(object):
|
||||||
def __init__(self, class_, attr, relationships, select_expression):
|
def __init__(self, class_, attr, relationships, expr):
|
||||||
self.class_ = class_
|
self.class_ = class_
|
||||||
self.attr = attr
|
self.attr = attr
|
||||||
self.relationships = relationships
|
self.relationships = relationships
|
||||||
|
|
||||||
if isinstance(select_expression, sa.sql.visitors.Visitable):
|
if isinstance(expr, sa.sql.visitors.Visitable):
|
||||||
self.select_expression = select_expression
|
self.expr = expr
|
||||||
elif isinstance(select_expression, _FunctionGenerator):
|
elif isinstance(expr, _FunctionGenerator):
|
||||||
self.select_expression = select_expression(sa.sql.literal('1'))
|
self.expr = expr(sa.sql.text('1'))
|
||||||
else:
|
else:
|
||||||
self.select_expression = select_expression(class_)
|
self.expr = expr(class_)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def aggregate_query(self):
|
def aggregate_query(self):
|
||||||
@@ -50,7 +176,7 @@ class AggregatedValue(object):
|
|||||||
)
|
)
|
||||||
|
|
||||||
query = sa.select(
|
query = sa.select(
|
||||||
[self.select_expression],
|
[self.expr],
|
||||||
from_obj=[from_]
|
from_obj=[from_]
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -65,7 +191,7 @@ class AggregatedValue(object):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class AggregateValueGenerator(object):
|
class AggregationManager(object):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
@@ -73,14 +199,6 @@ class AggregateValueGenerator(object):
|
|||||||
self.generator_registry = defaultdict(list)
|
self.generator_registry = defaultdict(list)
|
||||||
self.pending_queries = defaultdict(list)
|
self.pending_queries = defaultdict(list)
|
||||||
|
|
||||||
def generator_wrapper(self, func, relationship, select_expression):
|
|
||||||
func = aggregated_attr(func)
|
|
||||||
func.__aggregate__ = {
|
|
||||||
'select_expression': select_expression,
|
|
||||||
'relationship': relationship
|
|
||||||
}
|
|
||||||
return func
|
|
||||||
|
|
||||||
def register_listeners(self):
|
def register_listeners(self):
|
||||||
sa.event.listen(
|
sa.event.listen(
|
||||||
sa.orm.mapper,
|
sa.orm.mapper,
|
||||||
@@ -109,7 +227,7 @@ class AggregateValueGenerator(object):
|
|||||||
class_=class_,
|
class_=class_,
|
||||||
attr=key,
|
attr=key,
|
||||||
relationships=list(reversed(relationships)),
|
relationships=list(reversed(relationships)),
|
||||||
select_expression=value['select_expression']
|
expr=value['expression']
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -118,93 +236,22 @@ class AggregateValueGenerator(object):
|
|||||||
class_ = obj.__class__.__name__
|
class_ = obj.__class__.__name__
|
||||||
if class_ in self.generator_registry:
|
if class_ in self.generator_registry:
|
||||||
for aggregate_value in self.generator_registry[class_]:
|
for aggregate_value in self.generator_registry[class_]:
|
||||||
session.execute(aggregate_value.update_query)
|
query = aggregate_value.update_query
|
||||||
|
session.execute(query)
|
||||||
|
|
||||||
|
|
||||||
generator = AggregateValueGenerator()
|
manager = AggregationManager()
|
||||||
generator.register_listeners()
|
manager.register_listeners()
|
||||||
|
|
||||||
|
|
||||||
def aggregate(
|
def aggregated_attr(
|
||||||
relationship,
|
relationship,
|
||||||
select_expression=sa.func.count,
|
expression=sa.func.count
|
||||||
generator=generator
|
|
||||||
):
|
):
|
||||||
"""
|
|
||||||
|
|
||||||
Non-atomic implementation:
|
|
||||||
|
|
||||||
http://stackoverflow.com/questions/13693872/
|
|
||||||
|
|
||||||
|
|
||||||
We should avoid deadlocks:
|
|
||||||
|
|
||||||
http://mina.naguib.ca/blog/2010/11/22/postgresql-foreign-key-deadlocks.html
|
|
||||||
|
|
||||||
|
|
||||||
::
|
|
||||||
|
|
||||||
|
|
||||||
class Thread(Base):
|
|
||||||
__tablename__ = 'thread'
|
|
||||||
id = sa.Column(sa.Integer, primary_key=True)
|
|
||||||
name = sa.Column(sa.Unicode(255))
|
|
||||||
|
|
||||||
@aggregate(sa.func.count, 'comments')
|
|
||||||
def comment_count(self):
|
|
||||||
return sa.Column(sa.Integer)
|
|
||||||
|
|
||||||
@aggregate(sa.func.max, 'comments')
|
|
||||||
def latest_comment_id(self):
|
|
||||||
return sa.Column(sa.Integer)
|
|
||||||
|
|
||||||
latest_comment = sa.orm.relationship('Comment', viewonly=True)
|
|
||||||
|
|
||||||
|
|
||||||
class Comment(Base):
|
|
||||||
__tablename__ = 'comment'
|
|
||||||
id = sa.Column(sa.Integer, primary_key=True)
|
|
||||||
content = sa.Column(sa.Unicode(255))
|
|
||||||
thread_id = sa.Column(sa.Integer, sa.ForeignKey(Thread.id))
|
|
||||||
|
|
||||||
thread = sa.orm.relationship(Thread, backref='comments')
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
::
|
|
||||||
|
|
||||||
|
|
||||||
class Catalog(Base):
|
|
||||||
__tablename__ = 'catalog'
|
|
||||||
id = sa.Column(sa.Integer, primary_key=True)
|
|
||||||
name = sa.Column(sa.Unicode(255))
|
|
||||||
|
|
||||||
@aggregate(
|
|
||||||
sa.func.sum(price) +
|
|
||||||
sa.func.coalesce(monthly_license_price, 0),
|
|
||||||
'products'
|
|
||||||
)
|
|
||||||
def net_worth(self):
|
|
||||||
return sa.Column(sa.Integer)
|
|
||||||
|
|
||||||
products = sa.orm.relationship('Product')
|
|
||||||
|
|
||||||
|
|
||||||
class Product(Base):
|
|
||||||
__tablename__ = 'product'
|
|
||||||
id = sa.Column(sa.Integer, primary_key=True)
|
|
||||||
name = sa.Column(sa.Unicode(255))
|
|
||||||
price = sa.Column(sa.Numeric)
|
|
||||||
monthly_license_price = sa.Column(sa.Numeric)
|
|
||||||
|
|
||||||
catalog_id = sa.Column(sa.Integer, sa.ForeignKey(Catalog.id))
|
|
||||||
|
|
||||||
|
|
||||||
"""
|
|
||||||
def wraps(func):
|
def wraps(func):
|
||||||
return generator.generator_wrapper(
|
return AggregatedAttribute(
|
||||||
func,
|
func,
|
||||||
relationship,
|
relationship,
|
||||||
select_expression=select_expression
|
expression
|
||||||
)
|
)
|
||||||
return wraps
|
return wraps
|
||||||
|
@@ -39,7 +39,7 @@ class TestCase(object):
|
|||||||
self.session = Session()
|
self.session = Session()
|
||||||
|
|
||||||
def teardown_method(self, method):
|
def teardown_method(self, method):
|
||||||
aggregates.generator.reset()
|
aggregates.manager.reset()
|
||||||
self.session.close_all()
|
self.session.close_all()
|
||||||
self.Base.metadata.drop_all(self.connection)
|
self.Base.metadata.drop_all(self.connection)
|
||||||
self.connection.close()
|
self.connection.close()
|
||||||
|
@@ -1,51 +0,0 @@
|
|||||||
from decimal import Decimal
|
|
||||||
import sqlalchemy as sa
|
|
||||||
from sqlalchemy_utils.aggregates import aggregate
|
|
||||||
from tests import TestCase
|
|
||||||
|
|
||||||
|
|
||||||
class TestDeepModelPathsForAggregates(TestCase):
|
|
||||||
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
|
|
||||||
|
|
||||||
def create_models(self):
|
|
||||||
class Catalog(self.Base):
|
|
||||||
__tablename__ = 'catalog'
|
|
||||||
id = sa.Column(sa.Integer, primary_key=True)
|
|
||||||
name = sa.Column(sa.Unicode(255))
|
|
||||||
|
|
||||||
@aggregate('products')
|
|
||||||
def net_worth(self):
|
|
||||||
return sa.Column(sa.Numeric, default=0)
|
|
||||||
|
|
||||||
@net_worth.select_expression
|
|
||||||
def net_worth(self):
|
|
||||||
return sa.func.sum(Product.price)
|
|
||||||
|
|
||||||
products = sa.orm.relationship('Product', backref='catalog')
|
|
||||||
|
|
||||||
class Product(self.Base):
|
|
||||||
__tablename__ = 'product'
|
|
||||||
id = sa.Column(sa.Integer, primary_key=True)
|
|
||||||
name = sa.Column(sa.Unicode(255))
|
|
||||||
price = sa.Column(sa.Numeric)
|
|
||||||
|
|
||||||
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))
|
|
||||||
|
|
||||||
self.Catalog = Catalog
|
|
||||||
self.Product = Product
|
|
||||||
|
|
||||||
def test_assigns_aggregates(self):
|
|
||||||
catalog = self.Catalog(
|
|
||||||
name=u'Some catalog'
|
|
||||||
)
|
|
||||||
self.session.add(catalog)
|
|
||||||
self.session.commit()
|
|
||||||
product = self.Product(
|
|
||||||
name=u'Some product',
|
|
||||||
price=Decimal('1000'),
|
|
||||||
catalog=catalog
|
|
||||||
)
|
|
||||||
self.session.add(product)
|
|
||||||
self.session.commit()
|
|
||||||
self.session.refresh(catalog)
|
|
||||||
assert catalog.net_worth == Decimal('1000')
|
|
@@ -1,6 +1,6 @@
|
|||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
from sqlalchemy_utils.aggregates import aggregate
|
from sqlalchemy_utils.aggregates import aggregated_attr
|
||||||
from tests import TestCase
|
from tests import TestCase
|
||||||
|
|
||||||
|
|
||||||
@@ -11,7 +11,7 @@ class TestDeepModelPathsForAggregates(TestCase):
|
|||||||
id = sa.Column(sa.Integer, primary_key=True)
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
name = sa.Column(sa.Unicode(255))
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
|
||||||
@aggregate('categories.products')
|
@aggregated_attr('categories.products')
|
||||||
def product_count(self):
|
def product_count(self):
|
||||||
return sa.Column(sa.Integer, default=0)
|
return sa.Column(sa.Integer, default=0)
|
||||||
|
|
||||||
|
@@ -1,5 +1,5 @@
|
|||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
from sqlalchemy_utils.aggregates import aggregate
|
from sqlalchemy_utils.aggregates import aggregated_attr
|
||||||
from tests import TestCase
|
from tests import TestCase
|
||||||
|
|
||||||
|
|
||||||
@@ -10,7 +10,7 @@ class TestAggregateValueGenerationForSimpleModelPaths(TestCase):
|
|||||||
id = sa.Column(sa.Integer, primary_key=True)
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
name = sa.Column(sa.Unicode(255))
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
|
||||||
@aggregate('comments')
|
@aggregated_attr('comments')
|
||||||
def comment_count(self):
|
def comment_count(self):
|
||||||
return sa.Column(sa.Integer, default=0)
|
return sa.Column(sa.Integer, default=0)
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user