Added more tests for deep paths
This commit is contained in:
@@ -268,11 +268,11 @@ class AggregatedValue(object):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def aggregate_query(self):
|
def aggregate_query(self):
|
||||||
from_ = self.relationships[0].mapper.class_
|
from_ = self.relationships[0].mapper.class_.__table__
|
||||||
for relationship in self.relationships[0:-1]:
|
for relationship in self.relationships[0:-1]:
|
||||||
property_ = relationship.property
|
property_ = relationship.property
|
||||||
from_ = (
|
from_ = (
|
||||||
from_.__table__
|
from_
|
||||||
.join(
|
.join(
|
||||||
property_.parent.class_,
|
property_.parent.class_,
|
||||||
property_.primaryjoin
|
property_.primaryjoin
|
||||||
@@ -288,11 +288,58 @@ class AggregatedValue(object):
|
|||||||
|
|
||||||
return query.correlate(self.class_).as_scalar()
|
return query.correlate(self.class_).as_scalar()
|
||||||
|
|
||||||
@property
|
def update_query(self, objects):
|
||||||
def update_query(self):
|
table = self.class_.__table__
|
||||||
return self.class_.__table__.update().values(
|
query = table.update().values(
|
||||||
{self.attr: self.aggregate_query}
|
{self.attr: self.aggregate_query}
|
||||||
)
|
)
|
||||||
|
if len(self.relationships) == 1:
|
||||||
|
remote_pairs = self.relationships[-1].property.local_remote_pairs
|
||||||
|
|
||||||
|
query = query.where(
|
||||||
|
remote_pairs[0][0].in_(
|
||||||
|
getattr(obj, remote_pairs[0][1].key) for obj in objects
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Builds query such as:
|
||||||
|
#
|
||||||
|
# UPDATE catalog SET product_count = (aggregate_query)
|
||||||
|
# WHERE id IN (
|
||||||
|
# SELECT catalog_id
|
||||||
|
# FROM category
|
||||||
|
# INNER JOIN sub_category
|
||||||
|
# ON category.id = sub_category.category_id
|
||||||
|
# WHERE sub_category.id IN (product_sub_category_ids)
|
||||||
|
# )
|
||||||
|
property_ = self.relationships[-1].property
|
||||||
|
remote_pairs = property_.local_remote_pairs
|
||||||
|
from_ = property_.mapper.class_.__table__
|
||||||
|
for relationship in reversed(self.relationships[1:-1]):
|
||||||
|
property_ = relationship.property
|
||||||
|
from_ = (
|
||||||
|
from_.join(property_.mapper.class_, property_.primaryjoin)
|
||||||
|
)
|
||||||
|
|
||||||
|
property_ = self.relationships[0].property
|
||||||
|
|
||||||
|
query = query.where(
|
||||||
|
remote_pairs[0][0].in_(
|
||||||
|
sa.select(
|
||||||
|
[remote_pairs[0][1]],
|
||||||
|
from_obj=[from_]
|
||||||
|
).where(
|
||||||
|
property_.local_remote_pairs[0][0].in_(
|
||||||
|
getattr(
|
||||||
|
obj, property_.local_remote_pairs[0][1].key
|
||||||
|
)
|
||||||
|
for obj in objects
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return query
|
||||||
|
|
||||||
|
|
||||||
class AggregationManager(object):
|
class AggregationManager(object):
|
||||||
@@ -336,12 +383,16 @@ class AggregationManager(object):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def construct_aggregate_queries(self, session, ctx):
|
def construct_aggregate_queries(self, session, ctx):
|
||||||
|
object_dict = defaultdict(list)
|
||||||
for obj in session:
|
for obj in session:
|
||||||
class_ = obj.__class__.__name__
|
class_ = obj.__class__.__name__
|
||||||
if class_ in self.generator_registry:
|
if class_ in self.generator_registry:
|
||||||
for aggregate_value in self.generator_registry[class_]:
|
object_dict[class_].append(obj)
|
||||||
query = aggregate_value.update_query
|
|
||||||
session.execute(query)
|
for class_, objects in six.iteritems(object_dict):
|
||||||
|
for aggregate_value in self.generator_registry[class_]:
|
||||||
|
query = aggregate_value.update_query(objects)
|
||||||
|
session.execute(query)
|
||||||
|
|
||||||
|
|
||||||
manager = AggregationManager()
|
manager = AggregationManager()
|
||||||
|
@@ -5,6 +5,8 @@ from tests import TestCase
|
|||||||
|
|
||||||
|
|
||||||
class TestDeepModelPathsForAggregates(TestCase):
|
class TestDeepModelPathsForAggregates(TestCase):
|
||||||
|
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
|
||||||
|
|
||||||
def create_models(self):
|
def create_models(self):
|
||||||
class Catalog(self.Base):
|
class Catalog(self.Base):
|
||||||
__tablename__ = 'catalog'
|
__tablename__ = 'catalog'
|
||||||
@@ -55,3 +57,102 @@ class TestDeepModelPathsForAggregates(TestCase):
|
|||||||
self.session.commit()
|
self.session.commit()
|
||||||
self.session.refresh(catalog)
|
self.session.refresh(catalog)
|
||||||
assert catalog.product_count == 1
|
assert catalog.product_count == 1
|
||||||
|
|
||||||
|
|
||||||
|
class Test3LevelDeepModelPathsForAggregates(TestCase):
|
||||||
|
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
|
||||||
|
n = 1
|
||||||
|
|
||||||
|
def create_models(self):
|
||||||
|
class Catalog(self.Base):
|
||||||
|
__tablename__ = 'catalog'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
|
||||||
|
@aggregated_attr('categories.sub_categories.products')
|
||||||
|
def product_count(self):
|
||||||
|
return sa.Column(sa.Integer, default=0)
|
||||||
|
|
||||||
|
categories = sa.orm.relationship('Category', backref='catalog')
|
||||||
|
|
||||||
|
class Category(self.Base):
|
||||||
|
__tablename__ = 'category'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
|
||||||
|
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))
|
||||||
|
|
||||||
|
sub_categories = sa.orm.relationship(
|
||||||
|
'SubCategory', backref='category'
|
||||||
|
)
|
||||||
|
|
||||||
|
class SubCategory(self.Base):
|
||||||
|
__tablename__ = 'sub_category'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
|
||||||
|
category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id'))
|
||||||
|
|
||||||
|
products = sa.orm.relationship('Product', backref='sub_category')
|
||||||
|
|
||||||
|
class Product(self.Base):
|
||||||
|
__tablename__ = 'product'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
price = sa.Column(sa.Numeric)
|
||||||
|
|
||||||
|
sub_category_id = sa.Column(
|
||||||
|
sa.Integer, sa.ForeignKey('sub_category.id')
|
||||||
|
)
|
||||||
|
|
||||||
|
self.Catalog = Catalog
|
||||||
|
self.Category = Category
|
||||||
|
self.SubCategory = SubCategory
|
||||||
|
self.Product = Product
|
||||||
|
|
||||||
|
def test_assigns_aggregates(self):
|
||||||
|
catalog = self.catalog_factory()
|
||||||
|
self.session.commit()
|
||||||
|
self.session.refresh(catalog)
|
||||||
|
assert catalog.product_count == 1
|
||||||
|
|
||||||
|
def catalog_factory(self):
|
||||||
|
product = self.Product(
|
||||||
|
name=u'Product %d' % self.n
|
||||||
|
)
|
||||||
|
sub_category = self.SubCategory(
|
||||||
|
name=u'SubCategory %d' % self.n,
|
||||||
|
products=[product]
|
||||||
|
)
|
||||||
|
category = self.Category(
|
||||||
|
name=u'Category %d' % self.n,
|
||||||
|
sub_categories=[sub_category]
|
||||||
|
)
|
||||||
|
catalog = self.Catalog(
|
||||||
|
categories=[category]
|
||||||
|
)
|
||||||
|
catalog.name = u'Catalog %d' % self.n
|
||||||
|
self.session.add(catalog)
|
||||||
|
self.n += 1
|
||||||
|
return catalog
|
||||||
|
|
||||||
|
def test_only_updates_affected_aggregates(self):
|
||||||
|
catalog = self.catalog_factory()
|
||||||
|
catalog2 = self.catalog_factory()
|
||||||
|
self.session.commit()
|
||||||
|
|
||||||
|
# force set catalog2 product_count to zero in order to check if it gets
|
||||||
|
# updated when the other catalog's product count gets updated
|
||||||
|
self.session.execute(
|
||||||
|
'UPDATE catalog SET product_count = 0 WHERE id = %d'
|
||||||
|
% catalog2.id
|
||||||
|
)
|
||||||
|
|
||||||
|
catalog.categories[0].sub_categories[0].products.append(
|
||||||
|
self.Product(name=u'Product 3')
|
||||||
|
)
|
||||||
|
self.session.commit()
|
||||||
|
self.session.refresh(catalog)
|
||||||
|
self.session.refresh(catalog2)
|
||||||
|
assert catalog.product_count == 2
|
||||||
|
assert catalog2.product_count == 0
|
||||||
|
Reference in New Issue
Block a user