Fix m2m relationship handling in aggregated

This commit is contained in:
Konsta Vesterinen
2014-12-01 15:07:39 +02:00
parent 5e00019be5
commit d4f4f4ec3c
8 changed files with 282 additions and 106 deletions

View File

@@ -4,6 +4,12 @@ Changelog
Here you can see the full list of changes between each SQLAlchemy-Utils release.
0.27.9 (2014-12-01)
^^^^^^^^^^^^^^^^^^^
- Fix aggregated decorator many-to-many relationship handling
0.27.8 (2014-11-13)
^^^^^^^^^^^^^^^^^^^

View File

@@ -401,6 +401,40 @@ class AggregatedAttribute(declared_attr):
return desc.column
def get_aggregate_query(agg_expr, relationships):
from_ = relationships[0].mapper.class_.__table__
for relationship in relationships[0:-1]:
property_ = relationship.property
if property_.secondary is not None:
from_ = from_.join(
property_.secondary,
property_.secondaryjoin
)
from_ = (
from_
.join(
property_.parent.class_,
property_.primaryjoin
)
)
prop = relationships[-1].property
condition = prop.primaryjoin
if prop.secondary is not None:
from_ = from_.join(
prop.secondary,
prop.secondaryjoin
)
query = sa.select(
[agg_expr],
from_obj=[from_]
)
return query.where(condition)
class AggregatedValue(object):
def __init__(self, class_, attr, relationships, expr):
self.class_ = class_
@@ -418,23 +452,7 @@ class AggregatedValue(object):
@property
def aggregate_query(self):
from_ = self.relationships[0].mapper.class_.__table__
for relationship in self.relationships[0:-1]:
property_ = relationship.property
from_ = (
from_
.join(
property_.parent.class_,
property_.primaryjoin
)
)
query = sa.select(
[self.expr],
from_obj=[from_]
)
query = query.where(self.relationships[-1])
query = get_aggregate_query(self.expr, self.relationships)
return query.correlate(self.class_).as_scalar()
@@ -484,11 +502,22 @@ class AggregatedValue(object):
property_ = self.relationships[-1].property
from_ = property_.mapper.class_.__table__
for relationship in reversed(self.relationships[1:-1]):
for relationship in reversed(self.relationships[0:-1]):
property_ = relationship.property
from_ = (
from_.join(property_.mapper.class_, property_.primaryjoin)
)
if property_.secondary is not None:
from_ = from_.join(
property_.secondary,
property_.primaryjoin
)
from_ = from_.join(
property_.mapper.class_,
property_.secondaryjoin
)
else:
from_ = from_.join(
property_.mapper.class_,
property_.primaryjoin
)
return from_
def local_condition(self, prop, objects):

View File

@@ -3,7 +3,7 @@ from sqlalchemy_utils.aggregates import aggregated
from tests import TestCase
class TestAggregateValueGenerationForSimpleModelPaths(TestCase):
class TestAggregateValueGenerationWithBackrefs(TestCase):
def create_models(self):
class Thread(self.Base):
__tablename__ = 'thread'

View File

@@ -0,0 +1,80 @@
import sqlalchemy as sa
from sqlalchemy_utils import aggregated
from tests import TestCase
class TestAggregateManyToManyAndManyToMany(TestCase):
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
def create_models(self):
catalog_products = sa.Table(
'catalog_product',
self.Base.metadata,
sa.Column('catalog_id', sa.Integer, sa.ForeignKey('catalog.id')),
sa.Column('product_id', sa.Integer, sa.ForeignKey('product.id'))
)
product_categories = sa.Table(
'category_product',
self.Base.metadata,
sa.Column('category_id', sa.Integer, sa.ForeignKey('category.id')),
sa.Column('product_id', sa.Integer, sa.ForeignKey('product.id'))
)
class Catalog(self.Base):
__tablename__ = 'catalog'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
@aggregated(
'products.categories',
sa.Column(sa.Integer, default=0)
)
def category_count(self):
return sa.func.count(sa.distinct(Category.id))
class Category(self.Base):
__tablename__ = 'category'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
class Product(self.Base):
__tablename__ = 'product'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
price = sa.Column(sa.Numeric)
catalog_id = sa.Column(
sa.Integer, sa.ForeignKey('catalog.id')
)
catalogs = sa.orm.relationship(
Catalog,
backref='products',
secondary=catalog_products
)
categories = sa.orm.relationship(
Category,
backref='products',
secondary=product_categories
)
self.Catalog = Catalog
self.Category = Category
self.Product = Product
def test_insert(self):
category = self.Category()
products = [
self.Product(categories=[category]),
self.Product(categories=[category])
]
catalog = self.Catalog(products=products)
self.session.add(catalog)
catalog2 = self.Catalog(products=products)
self.session.add(catalog)
self.session.commit()
assert catalog.category_count == 1
assert catalog2.category_count == 1

View File

@@ -0,0 +1,76 @@
import sqlalchemy as sa
from sqlalchemy_utils import aggregated
from tests import TestCase
class TestAggregateOneToManyAndManyToMany(TestCase):
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
def create_models(self):
product_categories = sa.Table(
'category_product',
self.Base.metadata,
sa.Column('category_id', sa.Integer, sa.ForeignKey('category.id')),
sa.Column('product_id', sa.Integer, sa.ForeignKey('product.id'))
)
class Catalog(self.Base):
__tablename__ = 'catalog'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
@aggregated(
'products.categories',
sa.Column(sa.Integer, default=0)
)
def category_count(self):
return sa.func.count(sa.distinct(Category.id))
class Category(self.Base):
__tablename__ = 'category'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
class Product(self.Base):
__tablename__ = 'product'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
price = sa.Column(sa.Numeric)
catalog_id = sa.Column(
sa.Integer, sa.ForeignKey('catalog.id')
)
catalog = sa.orm.relationship(
Catalog,
backref='products'
)
categories = sa.orm.relationship(
Category,
backref='products',
secondary=product_categories
)
self.Catalog = Catalog
self.Category = Category
self.Product = Product
def test_insert(self):
category = self.Category()
products = [
self.Product(categories=[category]),
self.Product(categories=[category])
]
catalog = self.Catalog(products=products)
self.session.add(catalog)
products2 = [
self.Product(categories=[category]),
self.Product(categories=[category])
]
catalog2 = self.Catalog(products=products2)
self.session.add(catalog)
self.session.commit()
assert catalog.category_count == 1
assert catalog2.category_count == 1

View File

@@ -0,0 +1,62 @@
from decimal import Decimal
import sqlalchemy as sa
from sqlalchemy_utils.aggregates import aggregated
from tests import TestCase
class TestAggregateOneToManyAndOneToMany(TestCase):
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
def create_models(self):
class Catalog(self.Base):
__tablename__ = 'catalog'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
@aggregated(
'categories.products',
sa.Column(sa.Integer, default=0)
)
def product_count(self):
return sa.func.count('1')
categories = sa.orm.relationship('Category', backref='catalog')
class Category(self.Base):
__tablename__ = 'category'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))
products = sa.orm.relationship('Product', backref='category')
class Product(self.Base):
__tablename__ = 'product'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
price = sa.Column(sa.Numeric)
category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id'))
self.Catalog = Catalog
self.Category = Category
self.Product = Product
def test_assigns_aggregates(self):
category = self.Category(name=u'Some category')
catalog = self.Catalog(
categories=[category]
)
catalog.name = u'Some catalog'
self.session.add(catalog)
self.session.commit()
product = self.Product(
name=u'Some product',
price=Decimal('1000'),
category=category
)
self.session.add(product)
self.session.commit()
self.session.refresh(catalog)
assert catalog.product_count == 1

View File

@@ -1,76 +1,15 @@
from decimal import Decimal
import sqlalchemy as sa
from sqlalchemy_utils.aggregates import aggregated
from sqlalchemy_utils import aggregated
from tests import TestCase
class TestDeepModelPathsForAggregates(TestCase):
class Test3LevelDeepOneToMany(TestCase):
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
def create_models(self):
class Catalog(self.Base):
__tablename__ = 'catalog'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
@aggregated(
'categories.products',
sa.Column(sa.Integer, default=0)
)
def product_count(self):
return sa.func.count('1')
categories = sa.orm.relationship('Category', backref='catalog')
class Category(self.Base):
__tablename__ = 'category'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))
products = sa.orm.relationship('Product', backref='category')
class Product(self.Base):
__tablename__ = 'product'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
price = sa.Column(sa.Numeric)
category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id'))
self.Catalog = Catalog
self.Category = Category
self.Product = Product
def test_assigns_aggregates(self):
category = self.Category(name=u'Some category')
catalog = self.Catalog(
categories=[category]
)
catalog.name = u'Some catalog'
self.session.add(catalog)
self.session.commit()
product = self.Product(
name=u'Some product',
price=Decimal('1000'),
category=category
)
self.session.add(product)
self.session.commit()
self.session.refresh(catalog)
assert catalog.product_count == 1
class Test3LevelDeepModelPathsForAggregates(TestCase):
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
n = 1
def create_models(self):
class Catalog(self.Base):
__tablename__ = 'catalog'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
@aggregated(
'categories.sub_categories.products',
@@ -84,8 +23,6 @@ class Test3LevelDeepModelPathsForAggregates(TestCase):
class Category(self.Base):
__tablename__ = 'category'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))
sub_categories = sa.orm.relationship(
@@ -95,16 +32,12 @@ class Test3LevelDeepModelPathsForAggregates(TestCase):
class SubCategory(self.Base):
__tablename__ = 'sub_category'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id'))
products = sa.orm.relationship('Product', backref='sub_category')
class Product(self.Base):
__tablename__ = 'product'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
price = sa.Column(sa.Numeric)
sub_category_id = sa.Column(
@@ -123,23 +56,13 @@ class Test3LevelDeepModelPathsForAggregates(TestCase):
assert catalog.product_count == 1
def catalog_factory(self):
product = self.Product(
name=u'Product %d' % self.n
)
product = self.Product()
sub_category = self.SubCategory(
name=u'SubCategory %d' % self.n,
products=[product]
)
category = self.Category(
name=u'Category %d' % self.n,
sub_categories=[sub_category]
)
catalog = self.Catalog(
categories=[category]
)
catalog.name = u'Catalog %d' % self.n
category = self.Category(sub_categories=[sub_category])
catalog = self.Catalog(categories=[category])
self.session.add(catalog)
self.n += 1
return catalog
def test_only_updates_affected_aggregates(self):
@@ -155,7 +78,7 @@ class Test3LevelDeepModelPathsForAggregates(TestCase):
)
catalog.categories[0].sub_categories[0].products.append(
self.Product(name=u'Product 3')
self.Product()
)
self.session.commit()
self.session.refresh(catalog)