From e4c9d338dc0fab8c2cee4b5f84a4715559ca1d0a Mon Sep 17 00:00:00 2001 From: Konsta Vesterinen Date: Mon, 4 Nov 2013 10:16:47 +0200 Subject: [PATCH] Moved aggregate tests, more docs for aggregates --- sqlalchemy_utils/aggregates.py | 107 +++++++++++++----- tests/__init__.py | 4 +- tests/aggregate/__init__.py | 0 .../aggregate/test_aggregate_combinations.py | 44 +++++++ tests/aggregate/test_deep_paths.py | 57 ++++++++++ .../test_simple_paths.py} | 27 ++++- 6 files changed, 204 insertions(+), 35 deletions(-) create mode 100644 tests/aggregate/__init__.py create mode 100644 tests/aggregate/test_aggregate_combinations.py create mode 100644 tests/aggregate/test_deep_paths.py rename tests/{test_aggregates.py => aggregate/test_simple_paths.py} (55%) diff --git a/sqlalchemy_utils/aggregates.py b/sqlalchemy_utils/aggregates.py index 9e91c41..4a85dc6 100644 --- a/sqlalchemy_utils/aggregates.py +++ b/sqlalchemy_utils/aggregates.py @@ -24,7 +24,7 @@ class AggregateValueGenerator(object): def reset(self): self.generator_registry = defaultdict(list) - self.listeners_registered = False + self.pending_queries = defaultdict(list) def generator_wrapper(self, func, aggregate_func, relationship): func = aggregated_attr(func) @@ -35,40 +35,64 @@ class AggregateValueGenerator(object): return func def register_listeners(self): - if not self.listeners_registered: - sa.event.listen( - sa.orm.mapper, - 'mapper_configured', - self.update_generator_registry - ) - sa.event.listen( - sa.orm.session.Session, - 'after_flush', - self.update_generated_properties - ) - self.listeners_registered = True + sa.event.listen( + sa.orm.mapper, + 'mapper_configured', + self.update_generator_registry + ) + sa.event.listen( + sa.orm.session.Session, + 'after_flush', + self.construct_aggregate_queries + ) def update_generator_registry(self, mapper, class_): + #self.reset() if hasattr(class_, '__aggregates__'): for key, value in six.iteritems(class_.__aggregates__): - rel = getattr(class_, value['relationship']) - rel_class = rel.mapper.class_ + relationships = [] + rel_class = class_ + + for path_name in value['relationship'].split('.'): + rel = getattr(rel_class, path_name) + relationships.append(rel) + rel_class = rel.mapper.class_ + self.generator_registry[rel_class.__name__].append({ 'class': class_, 'attr': key, - 'relationship': rel, + 'relationship': list(reversed(relationships)), 'aggregate': value['func'] }) - def update_generated_properties(self, session, ctx): + def construct_aggregate_queries(self, session, ctx): for obj in session: class_ = obj.__class__.__name__ if class_ in self.generator_registry: for func in self.generator_registry[class_]: + if isinstance(func['aggregate'], six.string_types): + agg_func = eval(func['aggregate']) + else: + agg_func = func['aggregate'](obj.__class__.id) + aggregate_value = ( - session.query(func['aggregate'](obj.__class__.id)) - .filter(func['relationship'].property.primaryjoin) - .correlate(func['class']).as_scalar() + session.query(agg_func) + ) + + for rel in func['relationship'][0:-1]: + aggregate_value = ( + aggregate_value + .join( + rel.property.parent.class_, + rel.property.primaryjoin + ) + ) + aggregate_value = aggregate_value.filter( + func['relationship'][-1] + ) + + aggregate_value = ( + aggregate_value.correlate(func['class']).as_scalar() ) query = func['class'].__table__.update().values( {func['attr']: aggregate_value} @@ -77,6 +101,7 @@ class AggregateValueGenerator(object): generator = AggregateValueGenerator() +generator.register_listeners() def aggregate(aggregate_func, relationship, generator=generator): @@ -96,17 +121,10 @@ def aggregate(aggregate_func, relationship, generator=generator): class Thread(Base): - __tablename__ = 'article' + __tablename__ = 'thread' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) - # _comment_count = sa.Column(sa.Integer) - - # comment_count = aggregate( - # '_comment_count', - # sa.func.count, - # 'comments' - # ) @aggregate(sa.func.count, 'comments') def comment_count(self): return sa.Column(sa.Integer) @@ -115,7 +133,6 @@ def aggregate(aggregate_func, relationship, generator=generator): def latest_comment_id(self): return sa.Column(sa.Integer) - latest_comment = sa.orm.relationship('Comment', viewonly=True) @@ -128,9 +145,37 @@ def aggregate(aggregate_func, relationship, generator=generator): thread = sa.orm.relationship(Thread, backref='comments') - """ - generator.register_listeners() + :: + + + class Catalog(Base): + __tablename__ = 'catalog' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + @aggregate( + sa.func.sum(price) + + sa.func.coalesce(monthly_license_price, 0), + 'products' + ) + def net_worth(self): + return sa.Column(sa.Integer) + + products = sa.orm.relationship('Product') + + + class Product(Base): + __tablename__ = 'product' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + price = sa.Column(sa.Numeric) + monthly_license_price = sa.Column(sa.Numeric) + + catalog_id = sa.Column(sa.Integer, sa.ForeignKey(Catalog.id)) + + + """ def wraps(func): return generator.generator_wrapper( func, diff --git a/tests/__init__.py b/tests/__init__.py index 53ba079..719ed5b 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -6,8 +6,7 @@ from sqlalchemy.orm import sessionmaker from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.hybrid import hybrid_property -from sqlalchemy_utils import InstrumentedList -from sqlalchemy_utils import coercion_listener +from sqlalchemy_utils import InstrumentedList, coercion_listener, aggregates @sa.event.listens_for(sa.engine.Engine, 'before_cursor_execute') @@ -40,6 +39,7 @@ class TestCase(object): self.session = Session() def teardown_method(self, method): + aggregates.generator.reset() self.session.close_all() self.Base.metadata.drop_all(self.connection) self.connection.close() diff --git a/tests/aggregate/__init__.py b/tests/aggregate/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/aggregate/test_aggregate_combinations.py b/tests/aggregate/test_aggregate_combinations.py new file mode 100644 index 0000000..5861929 --- /dev/null +++ b/tests/aggregate/test_aggregate_combinations.py @@ -0,0 +1,44 @@ +from decimal import Decimal +import sqlalchemy as sa +from sqlalchemy_utils.aggregates import aggregate +from tests import TestCase + + +class TestDeepModelPathsForAggregates(TestCase): + def create_models(self): + class Catalog(self.Base): + __tablename__ = 'catalog' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + @aggregate(sa.func.count, 'categories.products') + def product_count(self): + return sa.Column(sa.Integer, default=0) + + categories = sa.orm.relationship('Product', backref='catalog') + + class Product(self.Base): + __tablename__ = 'product' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + price = sa.Column(sa.Numeric) + + catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id')) + + self.Catalog = Catalog + self.Product = Product + + def test_assigns_aggregates(self): + catalog = self.Catalog( + name=u'Some catalog' + ) + self.session.add(catalog) + self.session.commit() + product = self.Product( + name=u'Some product', + price=Decimal('1000'), + ) + self.session.add(product) + self.session.commit() + self.session.refresh(catalog) + assert catalog.product_count == 1 diff --git a/tests/aggregate/test_deep_paths.py b/tests/aggregate/test_deep_paths.py new file mode 100644 index 0000000..e16ef1d --- /dev/null +++ b/tests/aggregate/test_deep_paths.py @@ -0,0 +1,57 @@ +from decimal import Decimal +import sqlalchemy as sa +from sqlalchemy_utils.aggregates import aggregate +from tests import TestCase + + +class TestDeepModelPathsForAggregates(TestCase): + def create_models(self): + class Catalog(self.Base): + __tablename__ = 'catalog' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + @aggregate(sa.func.count, 'categories.products') + def product_count(self): + return sa.Column(sa.Integer, default=0) + + categories = sa.orm.relationship('Category', backref='catalog') + + class Category(self.Base): + __tablename__ = 'category' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id')) + + products = sa.orm.relationship('Product', backref='category') + + class Product(self.Base): + __tablename__ = 'product' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + price = sa.Column(sa.Numeric) + + category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id')) + + self.Catalog = Catalog + self.Category = Category + self.Product = Product + + def test_assigns_aggregates(self): + category = self.Category(name=u'Some category') + catalog = self.Catalog( + categories=[category] + ) + catalog.name = u'Some catalog' + self.session.add(catalog) + self.session.commit() + product = self.Product( + name=u'Some product', + price=Decimal('1000'), + category=category + ) + self.session.add(product) + self.session.commit() + self.session.refresh(catalog) + assert catalog.product_count == 1 diff --git a/tests/test_aggregates.py b/tests/aggregate/test_simple_paths.py similarity index 55% rename from tests/test_aggregates.py rename to tests/aggregate/test_simple_paths.py index 682dcb6..cdd19c8 100644 --- a/tests/test_aggregates.py +++ b/tests/aggregate/test_simple_paths.py @@ -3,7 +3,7 @@ from sqlalchemy_utils.aggregates import aggregate from tests import TestCase -class TestAggregateValueGeneration(TestCase): +class TestAggregateValueGenerationForSimpleModelPaths(TestCase): def create_models(self): class Thread(self.Base): __tablename__ = 'thread' @@ -25,7 +25,17 @@ class TestAggregateValueGeneration(TestCase): self.Thread = Thread self.Comment = Comment - def test_assigns_aggregates(self): + def test_assigns_aggregates_on_insert(self): + thread = self.Thread() + thread.name = u'some article name' + self.session.add(thread) + comment = self.Comment(content=u'Some content', thread=thread) + self.session.add(comment) + self.session.commit() + self.session.refresh(thread) + assert thread.comment_count == 1 + + def test_assigns_aggregates_on_separate_insert(self): thread = self.Thread() thread.name = u'some article name' self.session.add(thread) @@ -35,3 +45,16 @@ class TestAggregateValueGeneration(TestCase): self.session.commit() self.session.refresh(thread) assert thread.comment_count == 1 + + def test_assigns_aggregates_on_delete(self): + thread = self.Thread() + thread.name = u'some article name' + self.session.add(thread) + self.session.commit() + comment = self.Comment(content=u'Some content', thread=thread) + self.session.add(comment) + self.session.commit() + self.session.delete(comment) + self.session.commit() + self.session.refresh(thread) + assert thread.comment_count == 0