Moved aggregate tests, more docs for aggregates
This commit is contained in:
@@ -24,7 +24,7 @@ class AggregateValueGenerator(object):
|
|||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
self.generator_registry = defaultdict(list)
|
self.generator_registry = defaultdict(list)
|
||||||
self.listeners_registered = False
|
self.pending_queries = defaultdict(list)
|
||||||
|
|
||||||
def generator_wrapper(self, func, aggregate_func, relationship):
|
def generator_wrapper(self, func, aggregate_func, relationship):
|
||||||
func = aggregated_attr(func)
|
func = aggregated_attr(func)
|
||||||
@@ -35,7 +35,6 @@ class AggregateValueGenerator(object):
|
|||||||
return func
|
return func
|
||||||
|
|
||||||
def register_listeners(self):
|
def register_listeners(self):
|
||||||
if not self.listeners_registered:
|
|
||||||
sa.event.listen(
|
sa.event.listen(
|
||||||
sa.orm.mapper,
|
sa.orm.mapper,
|
||||||
'mapper_configured',
|
'mapper_configured',
|
||||||
@@ -44,31 +43,56 @@ class AggregateValueGenerator(object):
|
|||||||
sa.event.listen(
|
sa.event.listen(
|
||||||
sa.orm.session.Session,
|
sa.orm.session.Session,
|
||||||
'after_flush',
|
'after_flush',
|
||||||
self.update_generated_properties
|
self.construct_aggregate_queries
|
||||||
)
|
)
|
||||||
self.listeners_registered = True
|
|
||||||
|
|
||||||
def update_generator_registry(self, mapper, class_):
|
def update_generator_registry(self, mapper, class_):
|
||||||
|
#self.reset()
|
||||||
if hasattr(class_, '__aggregates__'):
|
if hasattr(class_, '__aggregates__'):
|
||||||
for key, value in six.iteritems(class_.__aggregates__):
|
for key, value in six.iteritems(class_.__aggregates__):
|
||||||
rel = getattr(class_, value['relationship'])
|
relationships = []
|
||||||
|
rel_class = class_
|
||||||
|
|
||||||
|
for path_name in value['relationship'].split('.'):
|
||||||
|
rel = getattr(rel_class, path_name)
|
||||||
|
relationships.append(rel)
|
||||||
rel_class = rel.mapper.class_
|
rel_class = rel.mapper.class_
|
||||||
|
|
||||||
self.generator_registry[rel_class.__name__].append({
|
self.generator_registry[rel_class.__name__].append({
|
||||||
'class': class_,
|
'class': class_,
|
||||||
'attr': key,
|
'attr': key,
|
||||||
'relationship': rel,
|
'relationship': list(reversed(relationships)),
|
||||||
'aggregate': value['func']
|
'aggregate': value['func']
|
||||||
})
|
})
|
||||||
|
|
||||||
def update_generated_properties(self, session, ctx):
|
def construct_aggregate_queries(self, session, ctx):
|
||||||
for obj in session:
|
for obj in session:
|
||||||
class_ = obj.__class__.__name__
|
class_ = obj.__class__.__name__
|
||||||
if class_ in self.generator_registry:
|
if class_ in self.generator_registry:
|
||||||
for func in self.generator_registry[class_]:
|
for func in self.generator_registry[class_]:
|
||||||
|
if isinstance(func['aggregate'], six.string_types):
|
||||||
|
agg_func = eval(func['aggregate'])
|
||||||
|
else:
|
||||||
|
agg_func = func['aggregate'](obj.__class__.id)
|
||||||
|
|
||||||
aggregate_value = (
|
aggregate_value = (
|
||||||
session.query(func['aggregate'](obj.__class__.id))
|
session.query(agg_func)
|
||||||
.filter(func['relationship'].property.primaryjoin)
|
)
|
||||||
.correlate(func['class']).as_scalar()
|
|
||||||
|
for rel in func['relationship'][0:-1]:
|
||||||
|
aggregate_value = (
|
||||||
|
aggregate_value
|
||||||
|
.join(
|
||||||
|
rel.property.parent.class_,
|
||||||
|
rel.property.primaryjoin
|
||||||
|
)
|
||||||
|
)
|
||||||
|
aggregate_value = aggregate_value.filter(
|
||||||
|
func['relationship'][-1]
|
||||||
|
)
|
||||||
|
|
||||||
|
aggregate_value = (
|
||||||
|
aggregate_value.correlate(func['class']).as_scalar()
|
||||||
)
|
)
|
||||||
query = func['class'].__table__.update().values(
|
query = func['class'].__table__.update().values(
|
||||||
{func['attr']: aggregate_value}
|
{func['attr']: aggregate_value}
|
||||||
@@ -77,6 +101,7 @@ class AggregateValueGenerator(object):
|
|||||||
|
|
||||||
|
|
||||||
generator = AggregateValueGenerator()
|
generator = AggregateValueGenerator()
|
||||||
|
generator.register_listeners()
|
||||||
|
|
||||||
|
|
||||||
def aggregate(aggregate_func, relationship, generator=generator):
|
def aggregate(aggregate_func, relationship, generator=generator):
|
||||||
@@ -96,17 +121,10 @@ def aggregate(aggregate_func, relationship, generator=generator):
|
|||||||
|
|
||||||
|
|
||||||
class Thread(Base):
|
class Thread(Base):
|
||||||
__tablename__ = 'article'
|
__tablename__ = 'thread'
|
||||||
id = sa.Column(sa.Integer, primary_key=True)
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
name = sa.Column(sa.Unicode(255))
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
|
||||||
# _comment_count = sa.Column(sa.Integer)
|
|
||||||
|
|
||||||
# comment_count = aggregate(
|
|
||||||
# '_comment_count',
|
|
||||||
# sa.func.count,
|
|
||||||
# 'comments'
|
|
||||||
# )
|
|
||||||
@aggregate(sa.func.count, 'comments')
|
@aggregate(sa.func.count, 'comments')
|
||||||
def comment_count(self):
|
def comment_count(self):
|
||||||
return sa.Column(sa.Integer)
|
return sa.Column(sa.Integer)
|
||||||
@@ -115,7 +133,6 @@ def aggregate(aggregate_func, relationship, generator=generator):
|
|||||||
def latest_comment_id(self):
|
def latest_comment_id(self):
|
||||||
return sa.Column(sa.Integer)
|
return sa.Column(sa.Integer)
|
||||||
|
|
||||||
|
|
||||||
latest_comment = sa.orm.relationship('Comment', viewonly=True)
|
latest_comment = sa.orm.relationship('Comment', viewonly=True)
|
||||||
|
|
||||||
|
|
||||||
@@ -128,9 +145,37 @@ def aggregate(aggregate_func, relationship, generator=generator):
|
|||||||
thread = sa.orm.relationship(Thread, backref='comments')
|
thread = sa.orm.relationship(Thread, backref='comments')
|
||||||
|
|
||||||
|
|
||||||
"""
|
|
||||||
generator.register_listeners()
|
|
||||||
|
|
||||||
|
::
|
||||||
|
|
||||||
|
|
||||||
|
class Catalog(Base):
|
||||||
|
__tablename__ = 'catalog'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
|
||||||
|
@aggregate(
|
||||||
|
sa.func.sum(price) +
|
||||||
|
sa.func.coalesce(monthly_license_price, 0),
|
||||||
|
'products'
|
||||||
|
)
|
||||||
|
def net_worth(self):
|
||||||
|
return sa.Column(sa.Integer)
|
||||||
|
|
||||||
|
products = sa.orm.relationship('Product')
|
||||||
|
|
||||||
|
|
||||||
|
class Product(Base):
|
||||||
|
__tablename__ = 'product'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
price = sa.Column(sa.Numeric)
|
||||||
|
monthly_license_price = sa.Column(sa.Numeric)
|
||||||
|
|
||||||
|
catalog_id = sa.Column(sa.Integer, sa.ForeignKey(Catalog.id))
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
def wraps(func):
|
def wraps(func):
|
||||||
return generator.generator_wrapper(
|
return generator.generator_wrapper(
|
||||||
func,
|
func,
|
||||||
|
@@ -6,8 +6,7 @@ from sqlalchemy.orm import sessionmaker
|
|||||||
from sqlalchemy.ext.declarative import declarative_base
|
from sqlalchemy.ext.declarative import declarative_base
|
||||||
from sqlalchemy.ext.hybrid import hybrid_property
|
from sqlalchemy.ext.hybrid import hybrid_property
|
||||||
|
|
||||||
from sqlalchemy_utils import InstrumentedList
|
from sqlalchemy_utils import InstrumentedList, coercion_listener, aggregates
|
||||||
from sqlalchemy_utils import coercion_listener
|
|
||||||
|
|
||||||
|
|
||||||
@sa.event.listens_for(sa.engine.Engine, 'before_cursor_execute')
|
@sa.event.listens_for(sa.engine.Engine, 'before_cursor_execute')
|
||||||
@@ -40,6 +39,7 @@ class TestCase(object):
|
|||||||
self.session = Session()
|
self.session = Session()
|
||||||
|
|
||||||
def teardown_method(self, method):
|
def teardown_method(self, method):
|
||||||
|
aggregates.generator.reset()
|
||||||
self.session.close_all()
|
self.session.close_all()
|
||||||
self.Base.metadata.drop_all(self.connection)
|
self.Base.metadata.drop_all(self.connection)
|
||||||
self.connection.close()
|
self.connection.close()
|
||||||
|
0
tests/aggregate/__init__.py
Normal file
0
tests/aggregate/__init__.py
Normal file
44
tests/aggregate/test_aggregate_combinations.py
Normal file
44
tests/aggregate/test_aggregate_combinations.py
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
from decimal import Decimal
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy_utils.aggregates import aggregate
|
||||||
|
from tests import TestCase
|
||||||
|
|
||||||
|
|
||||||
|
class TestDeepModelPathsForAggregates(TestCase):
|
||||||
|
def create_models(self):
|
||||||
|
class Catalog(self.Base):
|
||||||
|
__tablename__ = 'catalog'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
|
||||||
|
@aggregate(sa.func.count, 'categories.products')
|
||||||
|
def product_count(self):
|
||||||
|
return sa.Column(sa.Integer, default=0)
|
||||||
|
|
||||||
|
categories = sa.orm.relationship('Product', backref='catalog')
|
||||||
|
|
||||||
|
class Product(self.Base):
|
||||||
|
__tablename__ = 'product'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
price = sa.Column(sa.Numeric)
|
||||||
|
|
||||||
|
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))
|
||||||
|
|
||||||
|
self.Catalog = Catalog
|
||||||
|
self.Product = Product
|
||||||
|
|
||||||
|
def test_assigns_aggregates(self):
|
||||||
|
catalog = self.Catalog(
|
||||||
|
name=u'Some catalog'
|
||||||
|
)
|
||||||
|
self.session.add(catalog)
|
||||||
|
self.session.commit()
|
||||||
|
product = self.Product(
|
||||||
|
name=u'Some product',
|
||||||
|
price=Decimal('1000'),
|
||||||
|
)
|
||||||
|
self.session.add(product)
|
||||||
|
self.session.commit()
|
||||||
|
self.session.refresh(catalog)
|
||||||
|
assert catalog.product_count == 1
|
57
tests/aggregate/test_deep_paths.py
Normal file
57
tests/aggregate/test_deep_paths.py
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
from decimal import Decimal
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy_utils.aggregates import aggregate
|
||||||
|
from tests import TestCase
|
||||||
|
|
||||||
|
|
||||||
|
class TestDeepModelPathsForAggregates(TestCase):
|
||||||
|
def create_models(self):
|
||||||
|
class Catalog(self.Base):
|
||||||
|
__tablename__ = 'catalog'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
|
||||||
|
@aggregate(sa.func.count, 'categories.products')
|
||||||
|
def product_count(self):
|
||||||
|
return sa.Column(sa.Integer, default=0)
|
||||||
|
|
||||||
|
categories = sa.orm.relationship('Category', backref='catalog')
|
||||||
|
|
||||||
|
class Category(self.Base):
|
||||||
|
__tablename__ = 'category'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
|
||||||
|
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))
|
||||||
|
|
||||||
|
products = sa.orm.relationship('Product', backref='category')
|
||||||
|
|
||||||
|
class Product(self.Base):
|
||||||
|
__tablename__ = 'product'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
price = sa.Column(sa.Numeric)
|
||||||
|
|
||||||
|
category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id'))
|
||||||
|
|
||||||
|
self.Catalog = Catalog
|
||||||
|
self.Category = Category
|
||||||
|
self.Product = Product
|
||||||
|
|
||||||
|
def test_assigns_aggregates(self):
|
||||||
|
category = self.Category(name=u'Some category')
|
||||||
|
catalog = self.Catalog(
|
||||||
|
categories=[category]
|
||||||
|
)
|
||||||
|
catalog.name = u'Some catalog'
|
||||||
|
self.session.add(catalog)
|
||||||
|
self.session.commit()
|
||||||
|
product = self.Product(
|
||||||
|
name=u'Some product',
|
||||||
|
price=Decimal('1000'),
|
||||||
|
category=category
|
||||||
|
)
|
||||||
|
self.session.add(product)
|
||||||
|
self.session.commit()
|
||||||
|
self.session.refresh(catalog)
|
||||||
|
assert catalog.product_count == 1
|
@@ -3,7 +3,7 @@ from sqlalchemy_utils.aggregates import aggregate
|
|||||||
from tests import TestCase
|
from tests import TestCase
|
||||||
|
|
||||||
|
|
||||||
class TestAggregateValueGeneration(TestCase):
|
class TestAggregateValueGenerationForSimpleModelPaths(TestCase):
|
||||||
def create_models(self):
|
def create_models(self):
|
||||||
class Thread(self.Base):
|
class Thread(self.Base):
|
||||||
__tablename__ = 'thread'
|
__tablename__ = 'thread'
|
||||||
@@ -25,7 +25,17 @@ class TestAggregateValueGeneration(TestCase):
|
|||||||
self.Thread = Thread
|
self.Thread = Thread
|
||||||
self.Comment = Comment
|
self.Comment = Comment
|
||||||
|
|
||||||
def test_assigns_aggregates(self):
|
def test_assigns_aggregates_on_insert(self):
|
||||||
|
thread = self.Thread()
|
||||||
|
thread.name = u'some article name'
|
||||||
|
self.session.add(thread)
|
||||||
|
comment = self.Comment(content=u'Some content', thread=thread)
|
||||||
|
self.session.add(comment)
|
||||||
|
self.session.commit()
|
||||||
|
self.session.refresh(thread)
|
||||||
|
assert thread.comment_count == 1
|
||||||
|
|
||||||
|
def test_assigns_aggregates_on_separate_insert(self):
|
||||||
thread = self.Thread()
|
thread = self.Thread()
|
||||||
thread.name = u'some article name'
|
thread.name = u'some article name'
|
||||||
self.session.add(thread)
|
self.session.add(thread)
|
||||||
@@ -35,3 +45,16 @@ class TestAggregateValueGeneration(TestCase):
|
|||||||
self.session.commit()
|
self.session.commit()
|
||||||
self.session.refresh(thread)
|
self.session.refresh(thread)
|
||||||
assert thread.comment_count == 1
|
assert thread.comment_count == 1
|
||||||
|
|
||||||
|
def test_assigns_aggregates_on_delete(self):
|
||||||
|
thread = self.Thread()
|
||||||
|
thread.name = u'some article name'
|
||||||
|
self.session.add(thread)
|
||||||
|
self.session.commit()
|
||||||
|
comment = self.Comment(content=u'Some content', thread=thread)
|
||||||
|
self.session.add(comment)
|
||||||
|
self.session.commit()
|
||||||
|
self.session.delete(comment)
|
||||||
|
self.session.commit()
|
||||||
|
self.session.refresh(thread)
|
||||||
|
assert thread.comment_count == 0
|
Reference in New Issue
Block a user