Added first draft for aggregates
This commit is contained in:
140
sqlalchemy_utils/aggregates.py
Normal file
140
sqlalchemy_utils/aggregates.py
Normal file
@@ -0,0 +1,140 @@
|
|||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
import six
|
||||||
|
from sqlalchemy.ext.declarative import declared_attr
|
||||||
|
|
||||||
|
|
||||||
|
class aggregated_attr(declared_attr):
|
||||||
|
def __init__(self, fget, *arg, **kw):
|
||||||
|
super(aggregated_attr, self).__init__(fget, *arg, **kw)
|
||||||
|
self.__doc__ = fget.__doc__
|
||||||
|
|
||||||
|
def __get__(desc, self, cls):
|
||||||
|
result = desc.fget(cls)
|
||||||
|
cls.__aggregates__ = {
|
||||||
|
desc.fget.__name__: desc.__aggregate__
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class AggregateValueGenerator(object):
|
||||||
|
def __init__(self):
|
||||||
|
self.reset()
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self.generator_registry = defaultdict(list)
|
||||||
|
self.listeners_registered = False
|
||||||
|
|
||||||
|
def generator_wrapper(self, func, aggregate_func, relationship):
|
||||||
|
func = aggregated_attr(func)
|
||||||
|
func.__aggregate__ = {
|
||||||
|
'func': aggregate_func,
|
||||||
|
'relationship': relationship
|
||||||
|
}
|
||||||
|
return func
|
||||||
|
|
||||||
|
def register_listeners(self):
|
||||||
|
if not self.listeners_registered:
|
||||||
|
sa.event.listen(
|
||||||
|
sa.orm.mapper,
|
||||||
|
'mapper_configured',
|
||||||
|
self.update_generator_registry
|
||||||
|
)
|
||||||
|
sa.event.listen(
|
||||||
|
sa.orm.session.Session,
|
||||||
|
'after_flush',
|
||||||
|
self.update_generated_properties
|
||||||
|
)
|
||||||
|
self.listeners_registered = True
|
||||||
|
|
||||||
|
def update_generator_registry(self, mapper, class_):
|
||||||
|
if hasattr(class_, '__aggregates__'):
|
||||||
|
for key, value in six.iteritems(class_.__aggregates__):
|
||||||
|
rel = getattr(class_, value['relationship'])
|
||||||
|
rel_class = rel.mapper.class_
|
||||||
|
self.generator_registry[rel_class.__name__].append({
|
||||||
|
'class': class_,
|
||||||
|
'attr': key,
|
||||||
|
'relationship': rel,
|
||||||
|
'aggregate': value['func']
|
||||||
|
})
|
||||||
|
|
||||||
|
def update_generated_properties(self, session, ctx):
|
||||||
|
for obj in session:
|
||||||
|
class_ = obj.__class__.__name__
|
||||||
|
if class_ in self.generator_registry:
|
||||||
|
for func in self.generator_registry[class_]:
|
||||||
|
aggregate_value = (
|
||||||
|
session.query(func['aggregate'](obj.__class__.id))
|
||||||
|
.filter(func['relationship'].property.primaryjoin)
|
||||||
|
.correlate(func['class']).as_scalar()
|
||||||
|
)
|
||||||
|
query = func['class'].__table__.update().values(
|
||||||
|
{func['attr']: aggregate_value}
|
||||||
|
)
|
||||||
|
session.execute(query)
|
||||||
|
|
||||||
|
|
||||||
|
generator = AggregateValueGenerator()
|
||||||
|
|
||||||
|
|
||||||
|
def aggregate(aggregate_func, relationship, generator=generator):
|
||||||
|
"""
|
||||||
|
|
||||||
|
Non-atomic implementation:
|
||||||
|
|
||||||
|
http://stackoverflow.com/questions/13693872/
|
||||||
|
|
||||||
|
|
||||||
|
We should avoid deadlocks:
|
||||||
|
|
||||||
|
http://mina.naguib.ca/blog/2010/11/22/postgresql-foreign-key-deadlocks.html
|
||||||
|
|
||||||
|
|
||||||
|
::
|
||||||
|
|
||||||
|
|
||||||
|
class Thread(Base):
|
||||||
|
__tablename__ = 'article'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
|
||||||
|
# _comment_count = sa.Column(sa.Integer)
|
||||||
|
|
||||||
|
# comment_count = aggregate(
|
||||||
|
# '_comment_count',
|
||||||
|
# sa.func.count,
|
||||||
|
# 'comments'
|
||||||
|
# )
|
||||||
|
@aggregate(sa.func.count, 'comments')
|
||||||
|
def comment_count(self):
|
||||||
|
return sa.Column(sa.Integer)
|
||||||
|
|
||||||
|
@aggregate(sa.func.max, 'comments')
|
||||||
|
def latest_comment_id(self):
|
||||||
|
return sa.Column(sa.Integer)
|
||||||
|
|
||||||
|
|
||||||
|
latest_comment = sa.orm.relationship('Comment', viewonly=True)
|
||||||
|
|
||||||
|
|
||||||
|
class Comment(Base):
|
||||||
|
__tablename__ = 'comment'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
content = sa.Column(sa.Unicode(255))
|
||||||
|
thread_id = sa.Column(sa.Integer, sa.ForeignKey(Thread.id))
|
||||||
|
|
||||||
|
thread = sa.orm.relationship(Thread, backref='comments')
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
generator.register_listeners()
|
||||||
|
|
||||||
|
def wraps(func):
|
||||||
|
return generator.generator_wrapper(
|
||||||
|
func,
|
||||||
|
aggregate_func,
|
||||||
|
relationship
|
||||||
|
)
|
||||||
|
return wraps
|
37
tests/test_aggregates.py
Normal file
37
tests/test_aggregates.py
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy_utils.aggregates import aggregate
|
||||||
|
from tests import TestCase
|
||||||
|
|
||||||
|
|
||||||
|
class TestAggregateValueGeneration(TestCase):
|
||||||
|
def create_models(self):
|
||||||
|
class Thread(self.Base):
|
||||||
|
__tablename__ = 'thread'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
|
||||||
|
@aggregate(sa.func.count, 'comments')
|
||||||
|
def comment_count(self):
|
||||||
|
return sa.Column(sa.Integer, default=0)
|
||||||
|
|
||||||
|
comments = sa.orm.relationship('Comment', backref='thread')
|
||||||
|
|
||||||
|
class Comment(self.Base):
|
||||||
|
__tablename__ = 'comment'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
content = sa.Column(sa.Unicode(255))
|
||||||
|
thread_id = sa.Column(sa.Integer, sa.ForeignKey('thread.id'))
|
||||||
|
|
||||||
|
self.Thread = Thread
|
||||||
|
self.Comment = Comment
|
||||||
|
|
||||||
|
def test_assigns_aggregates(self):
|
||||||
|
thread = self.Thread()
|
||||||
|
thread.name = u'some article name'
|
||||||
|
self.session.add(thread)
|
||||||
|
self.session.commit()
|
||||||
|
comment = self.Comment(content=u'Some content', thread=thread)
|
||||||
|
self.session.add(comment)
|
||||||
|
self.session.commit()
|
||||||
|
self.session.refresh(thread)
|
||||||
|
assert thread.comment_count == 1
|
Reference in New Issue
Block a user