Added more tests for deep paths
This commit is contained in:
@@ -268,11 +268,11 @@ class AggregatedValue(object):
|
||||
|
||||
@property
|
||||
def aggregate_query(self):
|
||||
from_ = self.relationships[0].mapper.class_
|
||||
from_ = self.relationships[0].mapper.class_.__table__
|
||||
for relationship in self.relationships[0:-1]:
|
||||
property_ = relationship.property
|
||||
from_ = (
|
||||
from_.__table__
|
||||
from_
|
||||
.join(
|
||||
property_.parent.class_,
|
||||
property_.primaryjoin
|
||||
@@ -288,11 +288,58 @@ class AggregatedValue(object):
|
||||
|
||||
return query.correlate(self.class_).as_scalar()
|
||||
|
||||
@property
|
||||
def update_query(self):
|
||||
return self.class_.__table__.update().values(
|
||||
def update_query(self, objects):
|
||||
table = self.class_.__table__
|
||||
query = table.update().values(
|
||||
{self.attr: self.aggregate_query}
|
||||
)
|
||||
if len(self.relationships) == 1:
|
||||
remote_pairs = self.relationships[-1].property.local_remote_pairs
|
||||
|
||||
query = query.where(
|
||||
remote_pairs[0][0].in_(
|
||||
getattr(obj, remote_pairs[0][1].key) for obj in objects
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Builds query such as:
|
||||
#
|
||||
# UPDATE catalog SET product_count = (aggregate_query)
|
||||
# WHERE id IN (
|
||||
# SELECT catalog_id
|
||||
# FROM category
|
||||
# INNER JOIN sub_category
|
||||
# ON category.id = sub_category.category_id
|
||||
# WHERE sub_category.id IN (product_sub_category_ids)
|
||||
# )
|
||||
property_ = self.relationships[-1].property
|
||||
remote_pairs = property_.local_remote_pairs
|
||||
from_ = property_.mapper.class_.__table__
|
||||
for relationship in reversed(self.relationships[1:-1]):
|
||||
property_ = relationship.property
|
||||
from_ = (
|
||||
from_.join(property_.mapper.class_, property_.primaryjoin)
|
||||
)
|
||||
|
||||
property_ = self.relationships[0].property
|
||||
|
||||
query = query.where(
|
||||
remote_pairs[0][0].in_(
|
||||
sa.select(
|
||||
[remote_pairs[0][1]],
|
||||
from_obj=[from_]
|
||||
).where(
|
||||
property_.local_remote_pairs[0][0].in_(
|
||||
getattr(
|
||||
obj, property_.local_remote_pairs[0][1].key
|
||||
)
|
||||
for obj in objects
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
return query
|
||||
|
||||
|
||||
class AggregationManager(object):
|
||||
@@ -336,12 +383,16 @@ class AggregationManager(object):
|
||||
)
|
||||
|
||||
def construct_aggregate_queries(self, session, ctx):
|
||||
object_dict = defaultdict(list)
|
||||
for obj in session:
|
||||
class_ = obj.__class__.__name__
|
||||
if class_ in self.generator_registry:
|
||||
for aggregate_value in self.generator_registry[class_]:
|
||||
query = aggregate_value.update_query
|
||||
session.execute(query)
|
||||
object_dict[class_].append(obj)
|
||||
|
||||
for class_, objects in six.iteritems(object_dict):
|
||||
for aggregate_value in self.generator_registry[class_]:
|
||||
query = aggregate_value.update_query(objects)
|
||||
session.execute(query)
|
||||
|
||||
|
||||
manager = AggregationManager()
|
||||
|
@@ -5,6 +5,8 @@ from tests import TestCase
|
||||
|
||||
|
||||
class TestDeepModelPathsForAggregates(TestCase):
|
||||
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
|
||||
|
||||
def create_models(self):
|
||||
class Catalog(self.Base):
|
||||
__tablename__ = 'catalog'
|
||||
@@ -55,3 +57,102 @@ class TestDeepModelPathsForAggregates(TestCase):
|
||||
self.session.commit()
|
||||
self.session.refresh(catalog)
|
||||
assert catalog.product_count == 1
|
||||
|
||||
|
||||
class Test3LevelDeepModelPathsForAggregates(TestCase):
|
||||
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
|
||||
n = 1
|
||||
|
||||
def create_models(self):
|
||||
class Catalog(self.Base):
|
||||
__tablename__ = 'catalog'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
@aggregated_attr('categories.sub_categories.products')
|
||||
def product_count(self):
|
||||
return sa.Column(sa.Integer, default=0)
|
||||
|
||||
categories = sa.orm.relationship('Category', backref='catalog')
|
||||
|
||||
class Category(self.Base):
|
||||
__tablename__ = 'category'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))
|
||||
|
||||
sub_categories = sa.orm.relationship(
|
||||
'SubCategory', backref='category'
|
||||
)
|
||||
|
||||
class SubCategory(self.Base):
|
||||
__tablename__ = 'sub_category'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id'))
|
||||
|
||||
products = sa.orm.relationship('Product', backref='sub_category')
|
||||
|
||||
class Product(self.Base):
|
||||
__tablename__ = 'product'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
price = sa.Column(sa.Numeric)
|
||||
|
||||
sub_category_id = sa.Column(
|
||||
sa.Integer, sa.ForeignKey('sub_category.id')
|
||||
)
|
||||
|
||||
self.Catalog = Catalog
|
||||
self.Category = Category
|
||||
self.SubCategory = SubCategory
|
||||
self.Product = Product
|
||||
|
||||
def test_assigns_aggregates(self):
|
||||
catalog = self.catalog_factory()
|
||||
self.session.commit()
|
||||
self.session.refresh(catalog)
|
||||
assert catalog.product_count == 1
|
||||
|
||||
def catalog_factory(self):
|
||||
product = self.Product(
|
||||
name=u'Product %d' % self.n
|
||||
)
|
||||
sub_category = self.SubCategory(
|
||||
name=u'SubCategory %d' % self.n,
|
||||
products=[product]
|
||||
)
|
||||
category = self.Category(
|
||||
name=u'Category %d' % self.n,
|
||||
sub_categories=[sub_category]
|
||||
)
|
||||
catalog = self.Catalog(
|
||||
categories=[category]
|
||||
)
|
||||
catalog.name = u'Catalog %d' % self.n
|
||||
self.session.add(catalog)
|
||||
self.n += 1
|
||||
return catalog
|
||||
|
||||
def test_only_updates_affected_aggregates(self):
|
||||
catalog = self.catalog_factory()
|
||||
catalog2 = self.catalog_factory()
|
||||
self.session.commit()
|
||||
|
||||
# force set catalog2 product_count to zero in order to check if it gets
|
||||
# updated when the other catalog's product count gets updated
|
||||
self.session.execute(
|
||||
'UPDATE catalog SET product_count = 0 WHERE id = %d'
|
||||
% catalog2.id
|
||||
)
|
||||
|
||||
catalog.categories[0].sub_categories[0].products.append(
|
||||
self.Product(name=u'Product 3')
|
||||
)
|
||||
self.session.commit()
|
||||
self.session.refresh(catalog)
|
||||
self.session.refresh(catalog2)
|
||||
assert catalog.product_count == 2
|
||||
assert catalog2.product_count == 0
|
||||
|
Reference in New Issue
Block a user