Fix m2m relationship handling in aggregated

This commit is contained in:
Konsta Vesterinen
2014-12-01 15:07:39 +02:00
parent 5e00019be5
commit d4f4f4ec3c
8 changed files with 282 additions and 106 deletions

View File

@@ -4,6 +4,12 @@ Changelog
Here you can see the full list of changes between each SQLAlchemy-Utils release. Here you can see the full list of changes between each SQLAlchemy-Utils release.
0.27.9 (2014-12-01)
^^^^^^^^^^^^^^^^^^^
- Fix aggregated decorator many-to-many relationship handling
0.27.8 (2014-11-13) 0.27.8 (2014-11-13)
^^^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^

View File

@@ -401,6 +401,40 @@ class AggregatedAttribute(declared_attr):
return desc.column return desc.column
def get_aggregate_query(agg_expr, relationships):
from_ = relationships[0].mapper.class_.__table__
for relationship in relationships[0:-1]:
property_ = relationship.property
if property_.secondary is not None:
from_ = from_.join(
property_.secondary,
property_.secondaryjoin
)
from_ = (
from_
.join(
property_.parent.class_,
property_.primaryjoin
)
)
prop = relationships[-1].property
condition = prop.primaryjoin
if prop.secondary is not None:
from_ = from_.join(
prop.secondary,
prop.secondaryjoin
)
query = sa.select(
[agg_expr],
from_obj=[from_]
)
return query.where(condition)
class AggregatedValue(object): class AggregatedValue(object):
def __init__(self, class_, attr, relationships, expr): def __init__(self, class_, attr, relationships, expr):
self.class_ = class_ self.class_ = class_
@@ -418,23 +452,7 @@ class AggregatedValue(object):
@property @property
def aggregate_query(self): def aggregate_query(self):
from_ = self.relationships[0].mapper.class_.__table__ query = get_aggregate_query(self.expr, self.relationships)
for relationship in self.relationships[0:-1]:
property_ = relationship.property
from_ = (
from_
.join(
property_.parent.class_,
property_.primaryjoin
)
)
query = sa.select(
[self.expr],
from_obj=[from_]
)
query = query.where(self.relationships[-1])
return query.correlate(self.class_).as_scalar() return query.correlate(self.class_).as_scalar()
@@ -484,11 +502,22 @@ class AggregatedValue(object):
property_ = self.relationships[-1].property property_ = self.relationships[-1].property
from_ = property_.mapper.class_.__table__ from_ = property_.mapper.class_.__table__
for relationship in reversed(self.relationships[1:-1]): for relationship in reversed(self.relationships[0:-1]):
property_ = relationship.property property_ = relationship.property
from_ = ( if property_.secondary is not None:
from_.join(property_.mapper.class_, property_.primaryjoin) from_ = from_.join(
) property_.secondary,
property_.primaryjoin
)
from_ = from_.join(
property_.mapper.class_,
property_.secondaryjoin
)
else:
from_ = from_.join(
property_.mapper.class_,
property_.primaryjoin
)
return from_ return from_
def local_condition(self, prop, objects): def local_condition(self, prop, objects):

View File

@@ -3,7 +3,7 @@ from sqlalchemy_utils.aggregates import aggregated
from tests import TestCase from tests import TestCase
class TestAggregateValueGenerationForSimpleModelPaths(TestCase): class TestAggregateValueGenerationWithBackrefs(TestCase):
def create_models(self): def create_models(self):
class Thread(self.Base): class Thread(self.Base):
__tablename__ = 'thread' __tablename__ = 'thread'

View File

@@ -0,0 +1,80 @@
import sqlalchemy as sa
from sqlalchemy_utils import aggregated
from tests import TestCase
class TestAggregateManyToManyAndManyToMany(TestCase):
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
def create_models(self):
catalog_products = sa.Table(
'catalog_product',
self.Base.metadata,
sa.Column('catalog_id', sa.Integer, sa.ForeignKey('catalog.id')),
sa.Column('product_id', sa.Integer, sa.ForeignKey('product.id'))
)
product_categories = sa.Table(
'category_product',
self.Base.metadata,
sa.Column('category_id', sa.Integer, sa.ForeignKey('category.id')),
sa.Column('product_id', sa.Integer, sa.ForeignKey('product.id'))
)
class Catalog(self.Base):
__tablename__ = 'catalog'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
@aggregated(
'products.categories',
sa.Column(sa.Integer, default=0)
)
def category_count(self):
return sa.func.count(sa.distinct(Category.id))
class Category(self.Base):
__tablename__ = 'category'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
class Product(self.Base):
__tablename__ = 'product'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
price = sa.Column(sa.Numeric)
catalog_id = sa.Column(
sa.Integer, sa.ForeignKey('catalog.id')
)
catalogs = sa.orm.relationship(
Catalog,
backref='products',
secondary=catalog_products
)
categories = sa.orm.relationship(
Category,
backref='products',
secondary=product_categories
)
self.Catalog = Catalog
self.Category = Category
self.Product = Product
def test_insert(self):
category = self.Category()
products = [
self.Product(categories=[category]),
self.Product(categories=[category])
]
catalog = self.Catalog(products=products)
self.session.add(catalog)
catalog2 = self.Catalog(products=products)
self.session.add(catalog)
self.session.commit()
assert catalog.category_count == 1
assert catalog2.category_count == 1

View File

@@ -0,0 +1,76 @@
import sqlalchemy as sa
from sqlalchemy_utils import aggregated
from tests import TestCase
class TestAggregateOneToManyAndManyToMany(TestCase):
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
def create_models(self):
product_categories = sa.Table(
'category_product',
self.Base.metadata,
sa.Column('category_id', sa.Integer, sa.ForeignKey('category.id')),
sa.Column('product_id', sa.Integer, sa.ForeignKey('product.id'))
)
class Catalog(self.Base):
__tablename__ = 'catalog'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
@aggregated(
'products.categories',
sa.Column(sa.Integer, default=0)
)
def category_count(self):
return sa.func.count(sa.distinct(Category.id))
class Category(self.Base):
__tablename__ = 'category'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
class Product(self.Base):
__tablename__ = 'product'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
price = sa.Column(sa.Numeric)
catalog_id = sa.Column(
sa.Integer, sa.ForeignKey('catalog.id')
)
catalog = sa.orm.relationship(
Catalog,
backref='products'
)
categories = sa.orm.relationship(
Category,
backref='products',
secondary=product_categories
)
self.Catalog = Catalog
self.Category = Category
self.Product = Product
def test_insert(self):
category = self.Category()
products = [
self.Product(categories=[category]),
self.Product(categories=[category])
]
catalog = self.Catalog(products=products)
self.session.add(catalog)
products2 = [
self.Product(categories=[category]),
self.Product(categories=[category])
]
catalog2 = self.Catalog(products=products2)
self.session.add(catalog)
self.session.commit()
assert catalog.category_count == 1
assert catalog2.category_count == 1

View File

@@ -0,0 +1,62 @@
from decimal import Decimal
import sqlalchemy as sa
from sqlalchemy_utils.aggregates import aggregated
from tests import TestCase
class TestAggregateOneToManyAndOneToMany(TestCase):
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
def create_models(self):
class Catalog(self.Base):
__tablename__ = 'catalog'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
@aggregated(
'categories.products',
sa.Column(sa.Integer, default=0)
)
def product_count(self):
return sa.func.count('1')
categories = sa.orm.relationship('Category', backref='catalog')
class Category(self.Base):
__tablename__ = 'category'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))
products = sa.orm.relationship('Product', backref='category')
class Product(self.Base):
__tablename__ = 'product'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
price = sa.Column(sa.Numeric)
category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id'))
self.Catalog = Catalog
self.Category = Category
self.Product = Product
def test_assigns_aggregates(self):
category = self.Category(name=u'Some category')
catalog = self.Catalog(
categories=[category]
)
catalog.name = u'Some catalog'
self.session.add(catalog)
self.session.commit()
product = self.Product(
name=u'Some product',
price=Decimal('1000'),
category=category
)
self.session.add(product)
self.session.commit()
self.session.refresh(catalog)
assert catalog.product_count == 1

View File

@@ -1,76 +1,15 @@
from decimal import Decimal
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy_utils.aggregates import aggregated
from sqlalchemy_utils import aggregated
from tests import TestCase from tests import TestCase
class Test3LevelDeepOneToMany(TestCase):
class TestDeepModelPathsForAggregates(TestCase):
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
def create_models(self): def create_models(self):
class Catalog(self.Base): class Catalog(self.Base):
__tablename__ = 'catalog' __tablename__ = 'catalog'
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
@aggregated(
'categories.products',
sa.Column(sa.Integer, default=0)
)
def product_count(self):
return sa.func.count('1')
categories = sa.orm.relationship('Category', backref='catalog')
class Category(self.Base):
__tablename__ = 'category'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))
products = sa.orm.relationship('Product', backref='category')
class Product(self.Base):
__tablename__ = 'product'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
price = sa.Column(sa.Numeric)
category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id'))
self.Catalog = Catalog
self.Category = Category
self.Product = Product
def test_assigns_aggregates(self):
category = self.Category(name=u'Some category')
catalog = self.Catalog(
categories=[category]
)
catalog.name = u'Some catalog'
self.session.add(catalog)
self.session.commit()
product = self.Product(
name=u'Some product',
price=Decimal('1000'),
category=category
)
self.session.add(product)
self.session.commit()
self.session.refresh(catalog)
assert catalog.product_count == 1
class Test3LevelDeepModelPathsForAggregates(TestCase):
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
n = 1
def create_models(self):
class Catalog(self.Base):
__tablename__ = 'catalog'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
@aggregated( @aggregated(
'categories.sub_categories.products', 'categories.sub_categories.products',
@@ -84,8 +23,6 @@ class Test3LevelDeepModelPathsForAggregates(TestCase):
class Category(self.Base): class Category(self.Base):
__tablename__ = 'category' __tablename__ = 'category'
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id')) catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))
sub_categories = sa.orm.relationship( sub_categories = sa.orm.relationship(
@@ -95,16 +32,12 @@ class Test3LevelDeepModelPathsForAggregates(TestCase):
class SubCategory(self.Base): class SubCategory(self.Base):
__tablename__ = 'sub_category' __tablename__ = 'sub_category'
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id')) category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id'))
products = sa.orm.relationship('Product', backref='sub_category') products = sa.orm.relationship('Product', backref='sub_category')
class Product(self.Base): class Product(self.Base):
__tablename__ = 'product' __tablename__ = 'product'
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
price = sa.Column(sa.Numeric) price = sa.Column(sa.Numeric)
sub_category_id = sa.Column( sub_category_id = sa.Column(
@@ -123,23 +56,13 @@ class Test3LevelDeepModelPathsForAggregates(TestCase):
assert catalog.product_count == 1 assert catalog.product_count == 1
def catalog_factory(self): def catalog_factory(self):
product = self.Product( product = self.Product()
name=u'Product %d' % self.n
)
sub_category = self.SubCategory( sub_category = self.SubCategory(
name=u'SubCategory %d' % self.n,
products=[product] products=[product]
) )
category = self.Category( category = self.Category(sub_categories=[sub_category])
name=u'Category %d' % self.n, catalog = self.Catalog(categories=[category])
sub_categories=[sub_category]
)
catalog = self.Catalog(
categories=[category]
)
catalog.name = u'Catalog %d' % self.n
self.session.add(catalog) self.session.add(catalog)
self.n += 1
return catalog return catalog
def test_only_updates_affected_aggregates(self): def test_only_updates_affected_aggregates(self):
@@ -155,7 +78,7 @@ class Test3LevelDeepModelPathsForAggregates(TestCase):
) )
catalog.categories[0].sub_categories[0].products.append( catalog.categories[0].sub_categories[0].products.append(
self.Product(name=u'Product 3') self.Product()
) )
self.session.commit() self.session.commit()
self.session.refresh(catalog) self.session.refresh(catalog)