From d4f4f4ec3c4ecc48852fcfafcfd11388ab37d16e Mon Sep 17 00:00:00 2001 From: Konsta Vesterinen Date: Mon, 1 Dec 2014 15:07:39 +0200 Subject: [PATCH] Fix m2m relationship handling in aggregated --- CHANGES.rst | 6 ++ sqlalchemy_utils/aggregates.py | 71 ++++++++++----- tests/aggregate/test_backrefs.py | 2 +- ...y_to_many_relationships.py => test_m2m.py} | 0 tests/aggregate/test_m2m_m2m.py | 80 ++++++++++++++++ tests/aggregate/test_o2m_m2m.py | 76 ++++++++++++++++ tests/aggregate/test_o2m_o2m.py | 62 +++++++++++++ ...test_deep_paths.py => test_o2m_o2m_o2m.py} | 91 ++----------------- 8 files changed, 282 insertions(+), 106 deletions(-) rename tests/aggregate/{test_many_to_many_relationships.py => test_m2m.py} (100%) create mode 100644 tests/aggregate/test_m2m_m2m.py create mode 100644 tests/aggregate/test_o2m_m2m.py create mode 100644 tests/aggregate/test_o2m_o2m.py rename tests/aggregate/{test_deep_paths.py => test_o2m_o2m_o2m.py} (50%) diff --git a/CHANGES.rst b/CHANGES.rst index 759392d..e987167 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -4,6 +4,12 @@ Changelog Here you can see the full list of changes between each SQLAlchemy-Utils release. +0.27.9 (2014-12-01) +^^^^^^^^^^^^^^^^^^^ + +- Fix aggregated decorator many-to-many relationship handling + + 0.27.8 (2014-11-13) ^^^^^^^^^^^^^^^^^^^ diff --git a/sqlalchemy_utils/aggregates.py b/sqlalchemy_utils/aggregates.py index 1909b8f..340887f 100644 --- a/sqlalchemy_utils/aggregates.py +++ b/sqlalchemy_utils/aggregates.py @@ -401,6 +401,40 @@ class AggregatedAttribute(declared_attr): return desc.column +def get_aggregate_query(agg_expr, relationships): + from_ = relationships[0].mapper.class_.__table__ + for relationship in relationships[0:-1]: + property_ = relationship.property + if property_.secondary is not None: + from_ = from_.join( + property_.secondary, + property_.secondaryjoin + ) + + from_ = ( + from_ + .join( + property_.parent.class_, + property_.primaryjoin + ) + ) + + prop = relationships[-1].property + condition = prop.primaryjoin + if prop.secondary is not None: + from_ = from_.join( + prop.secondary, + prop.secondaryjoin + ) + + query = sa.select( + [agg_expr], + from_obj=[from_] + ) + + return query.where(condition) + + class AggregatedValue(object): def __init__(self, class_, attr, relationships, expr): self.class_ = class_ @@ -418,23 +452,7 @@ class AggregatedValue(object): @property def aggregate_query(self): - from_ = self.relationships[0].mapper.class_.__table__ - for relationship in self.relationships[0:-1]: - property_ = relationship.property - from_ = ( - from_ - .join( - property_.parent.class_, - property_.primaryjoin - ) - ) - - query = sa.select( - [self.expr], - from_obj=[from_] - ) - - query = query.where(self.relationships[-1]) + query = get_aggregate_query(self.expr, self.relationships) return query.correlate(self.class_).as_scalar() @@ -484,11 +502,22 @@ class AggregatedValue(object): property_ = self.relationships[-1].property from_ = property_.mapper.class_.__table__ - for relationship in reversed(self.relationships[1:-1]): + for relationship in reversed(self.relationships[0:-1]): property_ = relationship.property - from_ = ( - from_.join(property_.mapper.class_, property_.primaryjoin) - ) + if property_.secondary is not None: + from_ = from_.join( + property_.secondary, + property_.primaryjoin + ) + from_ = from_.join( + property_.mapper.class_, + property_.secondaryjoin + ) + else: + from_ = from_.join( + property_.mapper.class_, + property_.primaryjoin + ) return from_ def local_condition(self, prop, objects): diff --git a/tests/aggregate/test_backrefs.py b/tests/aggregate/test_backrefs.py index c752feb..c870444 100644 --- a/tests/aggregate/test_backrefs.py +++ b/tests/aggregate/test_backrefs.py @@ -3,7 +3,7 @@ from sqlalchemy_utils.aggregates import aggregated from tests import TestCase -class TestAggregateValueGenerationForSimpleModelPaths(TestCase): +class TestAggregateValueGenerationWithBackrefs(TestCase): def create_models(self): class Thread(self.Base): __tablename__ = 'thread' diff --git a/tests/aggregate/test_many_to_many_relationships.py b/tests/aggregate/test_m2m.py similarity index 100% rename from tests/aggregate/test_many_to_many_relationships.py rename to tests/aggregate/test_m2m.py diff --git a/tests/aggregate/test_m2m_m2m.py b/tests/aggregate/test_m2m_m2m.py new file mode 100644 index 0000000..1fee8c3 --- /dev/null +++ b/tests/aggregate/test_m2m_m2m.py @@ -0,0 +1,80 @@ +import sqlalchemy as sa + +from sqlalchemy_utils import aggregated +from tests import TestCase + + +class TestAggregateManyToManyAndManyToMany(TestCase): + dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' + + def create_models(self): + catalog_products = sa.Table( + 'catalog_product', + self.Base.metadata, + sa.Column('catalog_id', sa.Integer, sa.ForeignKey('catalog.id')), + sa.Column('product_id', sa.Integer, sa.ForeignKey('product.id')) + ) + + product_categories = sa.Table( + 'category_product', + self.Base.metadata, + sa.Column('category_id', sa.Integer, sa.ForeignKey('category.id')), + sa.Column('product_id', sa.Integer, sa.ForeignKey('product.id')) + ) + + class Catalog(self.Base): + __tablename__ = 'catalog' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + @aggregated( + 'products.categories', + sa.Column(sa.Integer, default=0) + ) + def category_count(self): + return sa.func.count(sa.distinct(Category.id)) + + class Category(self.Base): + __tablename__ = 'category' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + class Product(self.Base): + __tablename__ = 'product' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + price = sa.Column(sa.Numeric) + + catalog_id = sa.Column( + sa.Integer, sa.ForeignKey('catalog.id') + ) + + catalogs = sa.orm.relationship( + Catalog, + backref='products', + secondary=catalog_products + ) + + categories = sa.orm.relationship( + Category, + backref='products', + secondary=product_categories + ) + + self.Catalog = Catalog + self.Category = Category + self.Product = Product + + def test_insert(self): + category = self.Category() + products = [ + self.Product(categories=[category]), + self.Product(categories=[category]) + ] + catalog = self.Catalog(products=products) + self.session.add(catalog) + catalog2 = self.Catalog(products=products) + self.session.add(catalog) + self.session.commit() + assert catalog.category_count == 1 + assert catalog2.category_count == 1 diff --git a/tests/aggregate/test_o2m_m2m.py b/tests/aggregate/test_o2m_m2m.py new file mode 100644 index 0000000..c2b4ff5 --- /dev/null +++ b/tests/aggregate/test_o2m_m2m.py @@ -0,0 +1,76 @@ +import sqlalchemy as sa + +from sqlalchemy_utils import aggregated +from tests import TestCase + + +class TestAggregateOneToManyAndManyToMany(TestCase): + dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' + + def create_models(self): + product_categories = sa.Table( + 'category_product', + self.Base.metadata, + sa.Column('category_id', sa.Integer, sa.ForeignKey('category.id')), + sa.Column('product_id', sa.Integer, sa.ForeignKey('product.id')) + ) + + class Catalog(self.Base): + __tablename__ = 'catalog' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + @aggregated( + 'products.categories', + sa.Column(sa.Integer, default=0) + ) + def category_count(self): + return sa.func.count(sa.distinct(Category.id)) + + class Category(self.Base): + __tablename__ = 'category' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + class Product(self.Base): + __tablename__ = 'product' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + price = sa.Column(sa.Numeric) + + catalog_id = sa.Column( + sa.Integer, sa.ForeignKey('catalog.id') + ) + + catalog = sa.orm.relationship( + Catalog, + backref='products' + ) + + categories = sa.orm.relationship( + Category, + backref='products', + secondary=product_categories + ) + + self.Catalog = Catalog + self.Category = Category + self.Product = Product + + def test_insert(self): + category = self.Category() + products = [ + self.Product(categories=[category]), + self.Product(categories=[category]) + ] + catalog = self.Catalog(products=products) + self.session.add(catalog) + products2 = [ + self.Product(categories=[category]), + self.Product(categories=[category]) + ] + catalog2 = self.Catalog(products=products2) + self.session.add(catalog) + self.session.commit() + assert catalog.category_count == 1 + assert catalog2.category_count == 1 diff --git a/tests/aggregate/test_o2m_o2m.py b/tests/aggregate/test_o2m_o2m.py new file mode 100644 index 0000000..b245a0b --- /dev/null +++ b/tests/aggregate/test_o2m_o2m.py @@ -0,0 +1,62 @@ +from decimal import Decimal +import sqlalchemy as sa +from sqlalchemy_utils.aggregates import aggregated +from tests import TestCase + + +class TestAggregateOneToManyAndOneToMany(TestCase): + dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' + + def create_models(self): + class Catalog(self.Base): + __tablename__ = 'catalog' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + @aggregated( + 'categories.products', + sa.Column(sa.Integer, default=0) + ) + def product_count(self): + return sa.func.count('1') + + categories = sa.orm.relationship('Category', backref='catalog') + + class Category(self.Base): + __tablename__ = 'category' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id')) + + products = sa.orm.relationship('Product', backref='category') + + class Product(self.Base): + __tablename__ = 'product' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + price = sa.Column(sa.Numeric) + + category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id')) + + self.Catalog = Catalog + self.Category = Category + self.Product = Product + + def test_assigns_aggregates(self): + category = self.Category(name=u'Some category') + catalog = self.Catalog( + categories=[category] + ) + catalog.name = u'Some catalog' + self.session.add(catalog) + self.session.commit() + product = self.Product( + name=u'Some product', + price=Decimal('1000'), + category=category + ) + self.session.add(product) + self.session.commit() + self.session.refresh(catalog) + assert catalog.product_count == 1 diff --git a/tests/aggregate/test_deep_paths.py b/tests/aggregate/test_o2m_o2m_o2m.py similarity index 50% rename from tests/aggregate/test_deep_paths.py rename to tests/aggregate/test_o2m_o2m_o2m.py index ab8bbfb..6dcc0e9 100644 --- a/tests/aggregate/test_deep_paths.py +++ b/tests/aggregate/test_o2m_o2m_o2m.py @@ -1,76 +1,15 @@ -from decimal import Decimal import sqlalchemy as sa -from sqlalchemy_utils.aggregates import aggregated + +from sqlalchemy_utils import aggregated from tests import TestCase - -class TestDeepModelPathsForAggregates(TestCase): +class Test3LevelDeepOneToMany(TestCase): dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' def create_models(self): class Catalog(self.Base): __tablename__ = 'catalog' id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column(sa.Unicode(255)) - - @aggregated( - 'categories.products', - sa.Column(sa.Integer, default=0) - ) - def product_count(self): - return sa.func.count('1') - - categories = sa.orm.relationship('Category', backref='catalog') - - class Category(self.Base): - __tablename__ = 'category' - id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column(sa.Unicode(255)) - - catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id')) - - products = sa.orm.relationship('Product', backref='category') - - class Product(self.Base): - __tablename__ = 'product' - id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column(sa.Unicode(255)) - price = sa.Column(sa.Numeric) - - category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id')) - - self.Catalog = Catalog - self.Category = Category - self.Product = Product - - def test_assigns_aggregates(self): - category = self.Category(name=u'Some category') - catalog = self.Catalog( - categories=[category] - ) - catalog.name = u'Some catalog' - self.session.add(catalog) - self.session.commit() - product = self.Product( - name=u'Some product', - price=Decimal('1000'), - category=category - ) - self.session.add(product) - self.session.commit() - self.session.refresh(catalog) - assert catalog.product_count == 1 - - -class Test3LevelDeepModelPathsForAggregates(TestCase): - dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' - n = 1 - - def create_models(self): - class Catalog(self.Base): - __tablename__ = 'catalog' - id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column(sa.Unicode(255)) @aggregated( 'categories.sub_categories.products', @@ -84,8 +23,6 @@ class Test3LevelDeepModelPathsForAggregates(TestCase): class Category(self.Base): __tablename__ = 'category' id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column(sa.Unicode(255)) - catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id')) sub_categories = sa.orm.relationship( @@ -95,16 +32,12 @@ class Test3LevelDeepModelPathsForAggregates(TestCase): class SubCategory(self.Base): __tablename__ = 'sub_category' id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column(sa.Unicode(255)) - category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id')) - products = sa.orm.relationship('Product', backref='sub_category') class Product(self.Base): __tablename__ = 'product' id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column(sa.Unicode(255)) price = sa.Column(sa.Numeric) sub_category_id = sa.Column( @@ -123,23 +56,13 @@ class Test3LevelDeepModelPathsForAggregates(TestCase): assert catalog.product_count == 1 def catalog_factory(self): - product = self.Product( - name=u'Product %d' % self.n - ) + product = self.Product() sub_category = self.SubCategory( - name=u'SubCategory %d' % self.n, products=[product] ) - category = self.Category( - name=u'Category %d' % self.n, - sub_categories=[sub_category] - ) - catalog = self.Catalog( - categories=[category] - ) - catalog.name = u'Catalog %d' % self.n + category = self.Category(sub_categories=[sub_category]) + catalog = self.Catalog(categories=[category]) self.session.add(catalog) - self.n += 1 return catalog def test_only_updates_affected_aggregates(self): @@ -155,7 +78,7 @@ class Test3LevelDeepModelPathsForAggregates(TestCase): ) catalog.categories[0].sub_categories[0].products.append( - self.Product(name=u'Product 3') + self.Product() ) self.session.commit() self.session.refresh(catalog)