Fix m2m relationship handling in aggregated
This commit is contained in:
@@ -4,6 +4,12 @@ Changelog
|
|||||||
Here you can see the full list of changes between each SQLAlchemy-Utils release.
|
Here you can see the full list of changes between each SQLAlchemy-Utils release.
|
||||||
|
|
||||||
|
|
||||||
|
0.27.9 (2014-12-01)
|
||||||
|
^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
- Fix aggregated decorator many-to-many relationship handling
|
||||||
|
|
||||||
|
|
||||||
0.27.8 (2014-11-13)
|
0.27.8 (2014-11-13)
|
||||||
^^^^^^^^^^^^^^^^^^^
|
^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
@@ -401,6 +401,40 @@ class AggregatedAttribute(declared_attr):
|
|||||||
return desc.column
|
return desc.column
|
||||||
|
|
||||||
|
|
||||||
|
def get_aggregate_query(agg_expr, relationships):
|
||||||
|
from_ = relationships[0].mapper.class_.__table__
|
||||||
|
for relationship in relationships[0:-1]:
|
||||||
|
property_ = relationship.property
|
||||||
|
if property_.secondary is not None:
|
||||||
|
from_ = from_.join(
|
||||||
|
property_.secondary,
|
||||||
|
property_.secondaryjoin
|
||||||
|
)
|
||||||
|
|
||||||
|
from_ = (
|
||||||
|
from_
|
||||||
|
.join(
|
||||||
|
property_.parent.class_,
|
||||||
|
property_.primaryjoin
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
prop = relationships[-1].property
|
||||||
|
condition = prop.primaryjoin
|
||||||
|
if prop.secondary is not None:
|
||||||
|
from_ = from_.join(
|
||||||
|
prop.secondary,
|
||||||
|
prop.secondaryjoin
|
||||||
|
)
|
||||||
|
|
||||||
|
query = sa.select(
|
||||||
|
[agg_expr],
|
||||||
|
from_obj=[from_]
|
||||||
|
)
|
||||||
|
|
||||||
|
return query.where(condition)
|
||||||
|
|
||||||
|
|
||||||
class AggregatedValue(object):
|
class AggregatedValue(object):
|
||||||
def __init__(self, class_, attr, relationships, expr):
|
def __init__(self, class_, attr, relationships, expr):
|
||||||
self.class_ = class_
|
self.class_ = class_
|
||||||
@@ -418,23 +452,7 @@ class AggregatedValue(object):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def aggregate_query(self):
|
def aggregate_query(self):
|
||||||
from_ = self.relationships[0].mapper.class_.__table__
|
query = get_aggregate_query(self.expr, self.relationships)
|
||||||
for relationship in self.relationships[0:-1]:
|
|
||||||
property_ = relationship.property
|
|
||||||
from_ = (
|
|
||||||
from_
|
|
||||||
.join(
|
|
||||||
property_.parent.class_,
|
|
||||||
property_.primaryjoin
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
query = sa.select(
|
|
||||||
[self.expr],
|
|
||||||
from_obj=[from_]
|
|
||||||
)
|
|
||||||
|
|
||||||
query = query.where(self.relationships[-1])
|
|
||||||
|
|
||||||
return query.correlate(self.class_).as_scalar()
|
return query.correlate(self.class_).as_scalar()
|
||||||
|
|
||||||
@@ -484,11 +502,22 @@ class AggregatedValue(object):
|
|||||||
property_ = self.relationships[-1].property
|
property_ = self.relationships[-1].property
|
||||||
|
|
||||||
from_ = property_.mapper.class_.__table__
|
from_ = property_.mapper.class_.__table__
|
||||||
for relationship in reversed(self.relationships[1:-1]):
|
for relationship in reversed(self.relationships[0:-1]):
|
||||||
property_ = relationship.property
|
property_ = relationship.property
|
||||||
from_ = (
|
if property_.secondary is not None:
|
||||||
from_.join(property_.mapper.class_, property_.primaryjoin)
|
from_ = from_.join(
|
||||||
)
|
property_.secondary,
|
||||||
|
property_.primaryjoin
|
||||||
|
)
|
||||||
|
from_ = from_.join(
|
||||||
|
property_.mapper.class_,
|
||||||
|
property_.secondaryjoin
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
from_ = from_.join(
|
||||||
|
property_.mapper.class_,
|
||||||
|
property_.primaryjoin
|
||||||
|
)
|
||||||
return from_
|
return from_
|
||||||
|
|
||||||
def local_condition(self, prop, objects):
|
def local_condition(self, prop, objects):
|
||||||
|
@@ -3,7 +3,7 @@ from sqlalchemy_utils.aggregates import aggregated
|
|||||||
from tests import TestCase
|
from tests import TestCase
|
||||||
|
|
||||||
|
|
||||||
class TestAggregateValueGenerationForSimpleModelPaths(TestCase):
|
class TestAggregateValueGenerationWithBackrefs(TestCase):
|
||||||
def create_models(self):
|
def create_models(self):
|
||||||
class Thread(self.Base):
|
class Thread(self.Base):
|
||||||
__tablename__ = 'thread'
|
__tablename__ = 'thread'
|
||||||
|
80
tests/aggregate/test_m2m_m2m.py
Normal file
80
tests/aggregate/test_m2m_m2m.py
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
from sqlalchemy_utils import aggregated
|
||||||
|
from tests import TestCase
|
||||||
|
|
||||||
|
|
||||||
|
class TestAggregateManyToManyAndManyToMany(TestCase):
|
||||||
|
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
|
||||||
|
|
||||||
|
def create_models(self):
|
||||||
|
catalog_products = sa.Table(
|
||||||
|
'catalog_product',
|
||||||
|
self.Base.metadata,
|
||||||
|
sa.Column('catalog_id', sa.Integer, sa.ForeignKey('catalog.id')),
|
||||||
|
sa.Column('product_id', sa.Integer, sa.ForeignKey('product.id'))
|
||||||
|
)
|
||||||
|
|
||||||
|
product_categories = sa.Table(
|
||||||
|
'category_product',
|
||||||
|
self.Base.metadata,
|
||||||
|
sa.Column('category_id', sa.Integer, sa.ForeignKey('category.id')),
|
||||||
|
sa.Column('product_id', sa.Integer, sa.ForeignKey('product.id'))
|
||||||
|
)
|
||||||
|
|
||||||
|
class Catalog(self.Base):
|
||||||
|
__tablename__ = 'catalog'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
|
||||||
|
@aggregated(
|
||||||
|
'products.categories',
|
||||||
|
sa.Column(sa.Integer, default=0)
|
||||||
|
)
|
||||||
|
def category_count(self):
|
||||||
|
return sa.func.count(sa.distinct(Category.id))
|
||||||
|
|
||||||
|
class Category(self.Base):
|
||||||
|
__tablename__ = 'category'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
|
||||||
|
class Product(self.Base):
|
||||||
|
__tablename__ = 'product'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
price = sa.Column(sa.Numeric)
|
||||||
|
|
||||||
|
catalog_id = sa.Column(
|
||||||
|
sa.Integer, sa.ForeignKey('catalog.id')
|
||||||
|
)
|
||||||
|
|
||||||
|
catalogs = sa.orm.relationship(
|
||||||
|
Catalog,
|
||||||
|
backref='products',
|
||||||
|
secondary=catalog_products
|
||||||
|
)
|
||||||
|
|
||||||
|
categories = sa.orm.relationship(
|
||||||
|
Category,
|
||||||
|
backref='products',
|
||||||
|
secondary=product_categories
|
||||||
|
)
|
||||||
|
|
||||||
|
self.Catalog = Catalog
|
||||||
|
self.Category = Category
|
||||||
|
self.Product = Product
|
||||||
|
|
||||||
|
def test_insert(self):
|
||||||
|
category = self.Category()
|
||||||
|
products = [
|
||||||
|
self.Product(categories=[category]),
|
||||||
|
self.Product(categories=[category])
|
||||||
|
]
|
||||||
|
catalog = self.Catalog(products=products)
|
||||||
|
self.session.add(catalog)
|
||||||
|
catalog2 = self.Catalog(products=products)
|
||||||
|
self.session.add(catalog)
|
||||||
|
self.session.commit()
|
||||||
|
assert catalog.category_count == 1
|
||||||
|
assert catalog2.category_count == 1
|
76
tests/aggregate/test_o2m_m2m.py
Normal file
76
tests/aggregate/test_o2m_m2m.py
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
from sqlalchemy_utils import aggregated
|
||||||
|
from tests import TestCase
|
||||||
|
|
||||||
|
|
||||||
|
class TestAggregateOneToManyAndManyToMany(TestCase):
|
||||||
|
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
|
||||||
|
|
||||||
|
def create_models(self):
|
||||||
|
product_categories = sa.Table(
|
||||||
|
'category_product',
|
||||||
|
self.Base.metadata,
|
||||||
|
sa.Column('category_id', sa.Integer, sa.ForeignKey('category.id')),
|
||||||
|
sa.Column('product_id', sa.Integer, sa.ForeignKey('product.id'))
|
||||||
|
)
|
||||||
|
|
||||||
|
class Catalog(self.Base):
|
||||||
|
__tablename__ = 'catalog'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
|
||||||
|
@aggregated(
|
||||||
|
'products.categories',
|
||||||
|
sa.Column(sa.Integer, default=0)
|
||||||
|
)
|
||||||
|
def category_count(self):
|
||||||
|
return sa.func.count(sa.distinct(Category.id))
|
||||||
|
|
||||||
|
class Category(self.Base):
|
||||||
|
__tablename__ = 'category'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
|
||||||
|
class Product(self.Base):
|
||||||
|
__tablename__ = 'product'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
price = sa.Column(sa.Numeric)
|
||||||
|
|
||||||
|
catalog_id = sa.Column(
|
||||||
|
sa.Integer, sa.ForeignKey('catalog.id')
|
||||||
|
)
|
||||||
|
|
||||||
|
catalog = sa.orm.relationship(
|
||||||
|
Catalog,
|
||||||
|
backref='products'
|
||||||
|
)
|
||||||
|
|
||||||
|
categories = sa.orm.relationship(
|
||||||
|
Category,
|
||||||
|
backref='products',
|
||||||
|
secondary=product_categories
|
||||||
|
)
|
||||||
|
|
||||||
|
self.Catalog = Catalog
|
||||||
|
self.Category = Category
|
||||||
|
self.Product = Product
|
||||||
|
|
||||||
|
def test_insert(self):
|
||||||
|
category = self.Category()
|
||||||
|
products = [
|
||||||
|
self.Product(categories=[category]),
|
||||||
|
self.Product(categories=[category])
|
||||||
|
]
|
||||||
|
catalog = self.Catalog(products=products)
|
||||||
|
self.session.add(catalog)
|
||||||
|
products2 = [
|
||||||
|
self.Product(categories=[category]),
|
||||||
|
self.Product(categories=[category])
|
||||||
|
]
|
||||||
|
catalog2 = self.Catalog(products=products2)
|
||||||
|
self.session.add(catalog)
|
||||||
|
self.session.commit()
|
||||||
|
assert catalog.category_count == 1
|
||||||
|
assert catalog2.category_count == 1
|
62
tests/aggregate/test_o2m_o2m.py
Normal file
62
tests/aggregate/test_o2m_o2m.py
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
from decimal import Decimal
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy_utils.aggregates import aggregated
|
||||||
|
from tests import TestCase
|
||||||
|
|
||||||
|
|
||||||
|
class TestAggregateOneToManyAndOneToMany(TestCase):
|
||||||
|
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
|
||||||
|
|
||||||
|
def create_models(self):
|
||||||
|
class Catalog(self.Base):
|
||||||
|
__tablename__ = 'catalog'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
|
||||||
|
@aggregated(
|
||||||
|
'categories.products',
|
||||||
|
sa.Column(sa.Integer, default=0)
|
||||||
|
)
|
||||||
|
def product_count(self):
|
||||||
|
return sa.func.count('1')
|
||||||
|
|
||||||
|
categories = sa.orm.relationship('Category', backref='catalog')
|
||||||
|
|
||||||
|
class Category(self.Base):
|
||||||
|
__tablename__ = 'category'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
|
||||||
|
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))
|
||||||
|
|
||||||
|
products = sa.orm.relationship('Product', backref='category')
|
||||||
|
|
||||||
|
class Product(self.Base):
|
||||||
|
__tablename__ = 'product'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
price = sa.Column(sa.Numeric)
|
||||||
|
|
||||||
|
category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id'))
|
||||||
|
|
||||||
|
self.Catalog = Catalog
|
||||||
|
self.Category = Category
|
||||||
|
self.Product = Product
|
||||||
|
|
||||||
|
def test_assigns_aggregates(self):
|
||||||
|
category = self.Category(name=u'Some category')
|
||||||
|
catalog = self.Catalog(
|
||||||
|
categories=[category]
|
||||||
|
)
|
||||||
|
catalog.name = u'Some catalog'
|
||||||
|
self.session.add(catalog)
|
||||||
|
self.session.commit()
|
||||||
|
product = self.Product(
|
||||||
|
name=u'Some product',
|
||||||
|
price=Decimal('1000'),
|
||||||
|
category=category
|
||||||
|
)
|
||||||
|
self.session.add(product)
|
||||||
|
self.session.commit()
|
||||||
|
self.session.refresh(catalog)
|
||||||
|
assert catalog.product_count == 1
|
@@ -1,76 +1,15 @@
|
|||||||
from decimal import Decimal
|
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
from sqlalchemy_utils.aggregates import aggregated
|
|
||||||
|
from sqlalchemy_utils import aggregated
|
||||||
from tests import TestCase
|
from tests import TestCase
|
||||||
|
|
||||||
|
class Test3LevelDeepOneToMany(TestCase):
|
||||||
class TestDeepModelPathsForAggregates(TestCase):
|
|
||||||
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
|
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
|
||||||
|
|
||||||
def create_models(self):
|
def create_models(self):
|
||||||
class Catalog(self.Base):
|
class Catalog(self.Base):
|
||||||
__tablename__ = 'catalog'
|
__tablename__ = 'catalog'
|
||||||
id = sa.Column(sa.Integer, primary_key=True)
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
name = sa.Column(sa.Unicode(255))
|
|
||||||
|
|
||||||
@aggregated(
|
|
||||||
'categories.products',
|
|
||||||
sa.Column(sa.Integer, default=0)
|
|
||||||
)
|
|
||||||
def product_count(self):
|
|
||||||
return sa.func.count('1')
|
|
||||||
|
|
||||||
categories = sa.orm.relationship('Category', backref='catalog')
|
|
||||||
|
|
||||||
class Category(self.Base):
|
|
||||||
__tablename__ = 'category'
|
|
||||||
id = sa.Column(sa.Integer, primary_key=True)
|
|
||||||
name = sa.Column(sa.Unicode(255))
|
|
||||||
|
|
||||||
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))
|
|
||||||
|
|
||||||
products = sa.orm.relationship('Product', backref='category')
|
|
||||||
|
|
||||||
class Product(self.Base):
|
|
||||||
__tablename__ = 'product'
|
|
||||||
id = sa.Column(sa.Integer, primary_key=True)
|
|
||||||
name = sa.Column(sa.Unicode(255))
|
|
||||||
price = sa.Column(sa.Numeric)
|
|
||||||
|
|
||||||
category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id'))
|
|
||||||
|
|
||||||
self.Catalog = Catalog
|
|
||||||
self.Category = Category
|
|
||||||
self.Product = Product
|
|
||||||
|
|
||||||
def test_assigns_aggregates(self):
|
|
||||||
category = self.Category(name=u'Some category')
|
|
||||||
catalog = self.Catalog(
|
|
||||||
categories=[category]
|
|
||||||
)
|
|
||||||
catalog.name = u'Some catalog'
|
|
||||||
self.session.add(catalog)
|
|
||||||
self.session.commit()
|
|
||||||
product = self.Product(
|
|
||||||
name=u'Some product',
|
|
||||||
price=Decimal('1000'),
|
|
||||||
category=category
|
|
||||||
)
|
|
||||||
self.session.add(product)
|
|
||||||
self.session.commit()
|
|
||||||
self.session.refresh(catalog)
|
|
||||||
assert catalog.product_count == 1
|
|
||||||
|
|
||||||
|
|
||||||
class Test3LevelDeepModelPathsForAggregates(TestCase):
|
|
||||||
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
|
|
||||||
n = 1
|
|
||||||
|
|
||||||
def create_models(self):
|
|
||||||
class Catalog(self.Base):
|
|
||||||
__tablename__ = 'catalog'
|
|
||||||
id = sa.Column(sa.Integer, primary_key=True)
|
|
||||||
name = sa.Column(sa.Unicode(255))
|
|
||||||
|
|
||||||
@aggregated(
|
@aggregated(
|
||||||
'categories.sub_categories.products',
|
'categories.sub_categories.products',
|
||||||
@@ -84,8 +23,6 @@ class Test3LevelDeepModelPathsForAggregates(TestCase):
|
|||||||
class Category(self.Base):
|
class Category(self.Base):
|
||||||
__tablename__ = 'category'
|
__tablename__ = 'category'
|
||||||
id = sa.Column(sa.Integer, primary_key=True)
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
name = sa.Column(sa.Unicode(255))
|
|
||||||
|
|
||||||
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))
|
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))
|
||||||
|
|
||||||
sub_categories = sa.orm.relationship(
|
sub_categories = sa.orm.relationship(
|
||||||
@@ -95,16 +32,12 @@ class Test3LevelDeepModelPathsForAggregates(TestCase):
|
|||||||
class SubCategory(self.Base):
|
class SubCategory(self.Base):
|
||||||
__tablename__ = 'sub_category'
|
__tablename__ = 'sub_category'
|
||||||
id = sa.Column(sa.Integer, primary_key=True)
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
name = sa.Column(sa.Unicode(255))
|
|
||||||
|
|
||||||
category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id'))
|
category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id'))
|
||||||
|
|
||||||
products = sa.orm.relationship('Product', backref='sub_category')
|
products = sa.orm.relationship('Product', backref='sub_category')
|
||||||
|
|
||||||
class Product(self.Base):
|
class Product(self.Base):
|
||||||
__tablename__ = 'product'
|
__tablename__ = 'product'
|
||||||
id = sa.Column(sa.Integer, primary_key=True)
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
name = sa.Column(sa.Unicode(255))
|
|
||||||
price = sa.Column(sa.Numeric)
|
price = sa.Column(sa.Numeric)
|
||||||
|
|
||||||
sub_category_id = sa.Column(
|
sub_category_id = sa.Column(
|
||||||
@@ -123,23 +56,13 @@ class Test3LevelDeepModelPathsForAggregates(TestCase):
|
|||||||
assert catalog.product_count == 1
|
assert catalog.product_count == 1
|
||||||
|
|
||||||
def catalog_factory(self):
|
def catalog_factory(self):
|
||||||
product = self.Product(
|
product = self.Product()
|
||||||
name=u'Product %d' % self.n
|
|
||||||
)
|
|
||||||
sub_category = self.SubCategory(
|
sub_category = self.SubCategory(
|
||||||
name=u'SubCategory %d' % self.n,
|
|
||||||
products=[product]
|
products=[product]
|
||||||
)
|
)
|
||||||
category = self.Category(
|
category = self.Category(sub_categories=[sub_category])
|
||||||
name=u'Category %d' % self.n,
|
catalog = self.Catalog(categories=[category])
|
||||||
sub_categories=[sub_category]
|
|
||||||
)
|
|
||||||
catalog = self.Catalog(
|
|
||||||
categories=[category]
|
|
||||||
)
|
|
||||||
catalog.name = u'Catalog %d' % self.n
|
|
||||||
self.session.add(catalog)
|
self.session.add(catalog)
|
||||||
self.n += 1
|
|
||||||
return catalog
|
return catalog
|
||||||
|
|
||||||
def test_only_updates_affected_aggregates(self):
|
def test_only_updates_affected_aggregates(self):
|
||||||
@@ -155,7 +78,7 @@ class Test3LevelDeepModelPathsForAggregates(TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
catalog.categories[0].sub_categories[0].products.append(
|
catalog.categories[0].sub_categories[0].products.append(
|
||||||
self.Product(name=u'Product 3')
|
self.Product()
|
||||||
)
|
)
|
||||||
self.session.commit()
|
self.session.commit()
|
||||||
self.session.refresh(catalog)
|
self.session.refresh(catalog)
|
Reference in New Issue
Block a user