Fix m2m relationship handling in aggregated
This commit is contained in:
@@ -4,6 +4,12 @@ Changelog
|
||||
Here you can see the full list of changes between each SQLAlchemy-Utils release.
|
||||
|
||||
|
||||
0.27.9 (2014-12-01)
|
||||
^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
- Fix aggregated decorator many-to-many relationship handling
|
||||
|
||||
|
||||
0.27.8 (2014-11-13)
|
||||
^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
|
@@ -401,6 +401,40 @@ class AggregatedAttribute(declared_attr):
|
||||
return desc.column
|
||||
|
||||
|
||||
def get_aggregate_query(agg_expr, relationships):
|
||||
from_ = relationships[0].mapper.class_.__table__
|
||||
for relationship in relationships[0:-1]:
|
||||
property_ = relationship.property
|
||||
if property_.secondary is not None:
|
||||
from_ = from_.join(
|
||||
property_.secondary,
|
||||
property_.secondaryjoin
|
||||
)
|
||||
|
||||
from_ = (
|
||||
from_
|
||||
.join(
|
||||
property_.parent.class_,
|
||||
property_.primaryjoin
|
||||
)
|
||||
)
|
||||
|
||||
prop = relationships[-1].property
|
||||
condition = prop.primaryjoin
|
||||
if prop.secondary is not None:
|
||||
from_ = from_.join(
|
||||
prop.secondary,
|
||||
prop.secondaryjoin
|
||||
)
|
||||
|
||||
query = sa.select(
|
||||
[agg_expr],
|
||||
from_obj=[from_]
|
||||
)
|
||||
|
||||
return query.where(condition)
|
||||
|
||||
|
||||
class AggregatedValue(object):
|
||||
def __init__(self, class_, attr, relationships, expr):
|
||||
self.class_ = class_
|
||||
@@ -418,23 +452,7 @@ class AggregatedValue(object):
|
||||
|
||||
@property
|
||||
def aggregate_query(self):
|
||||
from_ = self.relationships[0].mapper.class_.__table__
|
||||
for relationship in self.relationships[0:-1]:
|
||||
property_ = relationship.property
|
||||
from_ = (
|
||||
from_
|
||||
.join(
|
||||
property_.parent.class_,
|
||||
property_.primaryjoin
|
||||
)
|
||||
)
|
||||
|
||||
query = sa.select(
|
||||
[self.expr],
|
||||
from_obj=[from_]
|
||||
)
|
||||
|
||||
query = query.where(self.relationships[-1])
|
||||
query = get_aggregate_query(self.expr, self.relationships)
|
||||
|
||||
return query.correlate(self.class_).as_scalar()
|
||||
|
||||
@@ -484,10 +502,21 @@ class AggregatedValue(object):
|
||||
property_ = self.relationships[-1].property
|
||||
|
||||
from_ = property_.mapper.class_.__table__
|
||||
for relationship in reversed(self.relationships[1:-1]):
|
||||
for relationship in reversed(self.relationships[0:-1]):
|
||||
property_ = relationship.property
|
||||
from_ = (
|
||||
from_.join(property_.mapper.class_, property_.primaryjoin)
|
||||
if property_.secondary is not None:
|
||||
from_ = from_.join(
|
||||
property_.secondary,
|
||||
property_.primaryjoin
|
||||
)
|
||||
from_ = from_.join(
|
||||
property_.mapper.class_,
|
||||
property_.secondaryjoin
|
||||
)
|
||||
else:
|
||||
from_ = from_.join(
|
||||
property_.mapper.class_,
|
||||
property_.primaryjoin
|
||||
)
|
||||
return from_
|
||||
|
||||
|
@@ -3,7 +3,7 @@ from sqlalchemy_utils.aggregates import aggregated
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
class TestAggregateValueGenerationForSimpleModelPaths(TestCase):
|
||||
class TestAggregateValueGenerationWithBackrefs(TestCase):
|
||||
def create_models(self):
|
||||
class Thread(self.Base):
|
||||
__tablename__ = 'thread'
|
||||
|
80
tests/aggregate/test_m2m_m2m.py
Normal file
80
tests/aggregate/test_m2m_m2m.py
Normal file
@@ -0,0 +1,80 @@
|
||||
import sqlalchemy as sa
|
||||
|
||||
from sqlalchemy_utils import aggregated
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
class TestAggregateManyToManyAndManyToMany(TestCase):
|
||||
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
|
||||
|
||||
def create_models(self):
|
||||
catalog_products = sa.Table(
|
||||
'catalog_product',
|
||||
self.Base.metadata,
|
||||
sa.Column('catalog_id', sa.Integer, sa.ForeignKey('catalog.id')),
|
||||
sa.Column('product_id', sa.Integer, sa.ForeignKey('product.id'))
|
||||
)
|
||||
|
||||
product_categories = sa.Table(
|
||||
'category_product',
|
||||
self.Base.metadata,
|
||||
sa.Column('category_id', sa.Integer, sa.ForeignKey('category.id')),
|
||||
sa.Column('product_id', sa.Integer, sa.ForeignKey('product.id'))
|
||||
)
|
||||
|
||||
class Catalog(self.Base):
|
||||
__tablename__ = 'catalog'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
@aggregated(
|
||||
'products.categories',
|
||||
sa.Column(sa.Integer, default=0)
|
||||
)
|
||||
def category_count(self):
|
||||
return sa.func.count(sa.distinct(Category.id))
|
||||
|
||||
class Category(self.Base):
|
||||
__tablename__ = 'category'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
class Product(self.Base):
|
||||
__tablename__ = 'product'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
price = sa.Column(sa.Numeric)
|
||||
|
||||
catalog_id = sa.Column(
|
||||
sa.Integer, sa.ForeignKey('catalog.id')
|
||||
)
|
||||
|
||||
catalogs = sa.orm.relationship(
|
||||
Catalog,
|
||||
backref='products',
|
||||
secondary=catalog_products
|
||||
)
|
||||
|
||||
categories = sa.orm.relationship(
|
||||
Category,
|
||||
backref='products',
|
||||
secondary=product_categories
|
||||
)
|
||||
|
||||
self.Catalog = Catalog
|
||||
self.Category = Category
|
||||
self.Product = Product
|
||||
|
||||
def test_insert(self):
|
||||
category = self.Category()
|
||||
products = [
|
||||
self.Product(categories=[category]),
|
||||
self.Product(categories=[category])
|
||||
]
|
||||
catalog = self.Catalog(products=products)
|
||||
self.session.add(catalog)
|
||||
catalog2 = self.Catalog(products=products)
|
||||
self.session.add(catalog)
|
||||
self.session.commit()
|
||||
assert catalog.category_count == 1
|
||||
assert catalog2.category_count == 1
|
76
tests/aggregate/test_o2m_m2m.py
Normal file
76
tests/aggregate/test_o2m_m2m.py
Normal file
@@ -0,0 +1,76 @@
|
||||
import sqlalchemy as sa
|
||||
|
||||
from sqlalchemy_utils import aggregated
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
class TestAggregateOneToManyAndManyToMany(TestCase):
|
||||
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
|
||||
|
||||
def create_models(self):
|
||||
product_categories = sa.Table(
|
||||
'category_product',
|
||||
self.Base.metadata,
|
||||
sa.Column('category_id', sa.Integer, sa.ForeignKey('category.id')),
|
||||
sa.Column('product_id', sa.Integer, sa.ForeignKey('product.id'))
|
||||
)
|
||||
|
||||
class Catalog(self.Base):
|
||||
__tablename__ = 'catalog'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
@aggregated(
|
||||
'products.categories',
|
||||
sa.Column(sa.Integer, default=0)
|
||||
)
|
||||
def category_count(self):
|
||||
return sa.func.count(sa.distinct(Category.id))
|
||||
|
||||
class Category(self.Base):
|
||||
__tablename__ = 'category'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
class Product(self.Base):
|
||||
__tablename__ = 'product'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
price = sa.Column(sa.Numeric)
|
||||
|
||||
catalog_id = sa.Column(
|
||||
sa.Integer, sa.ForeignKey('catalog.id')
|
||||
)
|
||||
|
||||
catalog = sa.orm.relationship(
|
||||
Catalog,
|
||||
backref='products'
|
||||
)
|
||||
|
||||
categories = sa.orm.relationship(
|
||||
Category,
|
||||
backref='products',
|
||||
secondary=product_categories
|
||||
)
|
||||
|
||||
self.Catalog = Catalog
|
||||
self.Category = Category
|
||||
self.Product = Product
|
||||
|
||||
def test_insert(self):
|
||||
category = self.Category()
|
||||
products = [
|
||||
self.Product(categories=[category]),
|
||||
self.Product(categories=[category])
|
||||
]
|
||||
catalog = self.Catalog(products=products)
|
||||
self.session.add(catalog)
|
||||
products2 = [
|
||||
self.Product(categories=[category]),
|
||||
self.Product(categories=[category])
|
||||
]
|
||||
catalog2 = self.Catalog(products=products2)
|
||||
self.session.add(catalog)
|
||||
self.session.commit()
|
||||
assert catalog.category_count == 1
|
||||
assert catalog2.category_count == 1
|
62
tests/aggregate/test_o2m_o2m.py
Normal file
62
tests/aggregate/test_o2m_o2m.py
Normal file
@@ -0,0 +1,62 @@
|
||||
from decimal import Decimal
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy_utils.aggregates import aggregated
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
class TestAggregateOneToManyAndOneToMany(TestCase):
|
||||
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
|
||||
|
||||
def create_models(self):
|
||||
class Catalog(self.Base):
|
||||
__tablename__ = 'catalog'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
@aggregated(
|
||||
'categories.products',
|
||||
sa.Column(sa.Integer, default=0)
|
||||
)
|
||||
def product_count(self):
|
||||
return sa.func.count('1')
|
||||
|
||||
categories = sa.orm.relationship('Category', backref='catalog')
|
||||
|
||||
class Category(self.Base):
|
||||
__tablename__ = 'category'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))
|
||||
|
||||
products = sa.orm.relationship('Product', backref='category')
|
||||
|
||||
class Product(self.Base):
|
||||
__tablename__ = 'product'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
price = sa.Column(sa.Numeric)
|
||||
|
||||
category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id'))
|
||||
|
||||
self.Catalog = Catalog
|
||||
self.Category = Category
|
||||
self.Product = Product
|
||||
|
||||
def test_assigns_aggregates(self):
|
||||
category = self.Category(name=u'Some category')
|
||||
catalog = self.Catalog(
|
||||
categories=[category]
|
||||
)
|
||||
catalog.name = u'Some catalog'
|
||||
self.session.add(catalog)
|
||||
self.session.commit()
|
||||
product = self.Product(
|
||||
name=u'Some product',
|
||||
price=Decimal('1000'),
|
||||
category=category
|
||||
)
|
||||
self.session.add(product)
|
||||
self.session.commit()
|
||||
self.session.refresh(catalog)
|
||||
assert catalog.product_count == 1
|
@@ -1,76 +1,15 @@
|
||||
from decimal import Decimal
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy_utils.aggregates import aggregated
|
||||
|
||||
from sqlalchemy_utils import aggregated
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
class TestDeepModelPathsForAggregates(TestCase):
|
||||
class Test3LevelDeepOneToMany(TestCase):
|
||||
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
|
||||
|
||||
def create_models(self):
|
||||
class Catalog(self.Base):
|
||||
__tablename__ = 'catalog'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
@aggregated(
|
||||
'categories.products',
|
||||
sa.Column(sa.Integer, default=0)
|
||||
)
|
||||
def product_count(self):
|
||||
return sa.func.count('1')
|
||||
|
||||
categories = sa.orm.relationship('Category', backref='catalog')
|
||||
|
||||
class Category(self.Base):
|
||||
__tablename__ = 'category'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))
|
||||
|
||||
products = sa.orm.relationship('Product', backref='category')
|
||||
|
||||
class Product(self.Base):
|
||||
__tablename__ = 'product'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
price = sa.Column(sa.Numeric)
|
||||
|
||||
category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id'))
|
||||
|
||||
self.Catalog = Catalog
|
||||
self.Category = Category
|
||||
self.Product = Product
|
||||
|
||||
def test_assigns_aggregates(self):
|
||||
category = self.Category(name=u'Some category')
|
||||
catalog = self.Catalog(
|
||||
categories=[category]
|
||||
)
|
||||
catalog.name = u'Some catalog'
|
||||
self.session.add(catalog)
|
||||
self.session.commit()
|
||||
product = self.Product(
|
||||
name=u'Some product',
|
||||
price=Decimal('1000'),
|
||||
category=category
|
||||
)
|
||||
self.session.add(product)
|
||||
self.session.commit()
|
||||
self.session.refresh(catalog)
|
||||
assert catalog.product_count == 1
|
||||
|
||||
|
||||
class Test3LevelDeepModelPathsForAggregates(TestCase):
|
||||
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
|
||||
n = 1
|
||||
|
||||
def create_models(self):
|
||||
class Catalog(self.Base):
|
||||
__tablename__ = 'catalog'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
@aggregated(
|
||||
'categories.sub_categories.products',
|
||||
@@ -84,8 +23,6 @@ class Test3LevelDeepModelPathsForAggregates(TestCase):
|
||||
class Category(self.Base):
|
||||
__tablename__ = 'category'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))
|
||||
|
||||
sub_categories = sa.orm.relationship(
|
||||
@@ -95,16 +32,12 @@ class Test3LevelDeepModelPathsForAggregates(TestCase):
|
||||
class SubCategory(self.Base):
|
||||
__tablename__ = 'sub_category'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id'))
|
||||
|
||||
products = sa.orm.relationship('Product', backref='sub_category')
|
||||
|
||||
class Product(self.Base):
|
||||
__tablename__ = 'product'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
price = sa.Column(sa.Numeric)
|
||||
|
||||
sub_category_id = sa.Column(
|
||||
@@ -123,23 +56,13 @@ class Test3LevelDeepModelPathsForAggregates(TestCase):
|
||||
assert catalog.product_count == 1
|
||||
|
||||
def catalog_factory(self):
|
||||
product = self.Product(
|
||||
name=u'Product %d' % self.n
|
||||
)
|
||||
product = self.Product()
|
||||
sub_category = self.SubCategory(
|
||||
name=u'SubCategory %d' % self.n,
|
||||
products=[product]
|
||||
)
|
||||
category = self.Category(
|
||||
name=u'Category %d' % self.n,
|
||||
sub_categories=[sub_category]
|
||||
)
|
||||
catalog = self.Catalog(
|
||||
categories=[category]
|
||||
)
|
||||
catalog.name = u'Catalog %d' % self.n
|
||||
category = self.Category(sub_categories=[sub_category])
|
||||
catalog = self.Catalog(categories=[category])
|
||||
self.session.add(catalog)
|
||||
self.n += 1
|
||||
return catalog
|
||||
|
||||
def test_only_updates_affected_aggregates(self):
|
||||
@@ -155,7 +78,7 @@ class Test3LevelDeepModelPathsForAggregates(TestCase):
|
||||
)
|
||||
|
||||
catalog.categories[0].sub_categories[0].products.append(
|
||||
self.Product(name=u'Product 3')
|
||||
self.Product()
|
||||
)
|
||||
self.session.commit()
|
||||
self.session.refresh(catalog)
|
Reference in New Issue
Block a user