Added more tests for deep paths

This commit is contained in:
Konsta Vesterinen
2013-11-06 10:51:12 +02:00
parent 18af0e7d3d
commit 32384f9c53
2 changed files with 160 additions and 8 deletions

View File

@@ -268,11 +268,11 @@ class AggregatedValue(object):
@property @property
def aggregate_query(self): def aggregate_query(self):
from_ = self.relationships[0].mapper.class_ from_ = self.relationships[0].mapper.class_.__table__
for relationship in self.relationships[0:-1]: for relationship in self.relationships[0:-1]:
property_ = relationship.property property_ = relationship.property
from_ = ( from_ = (
from_.__table__ from_
.join( .join(
property_.parent.class_, property_.parent.class_,
property_.primaryjoin property_.primaryjoin
@@ -288,11 +288,58 @@ class AggregatedValue(object):
return query.correlate(self.class_).as_scalar() return query.correlate(self.class_).as_scalar()
@property def update_query(self, objects):
def update_query(self): table = self.class_.__table__
return self.class_.__table__.update().values( query = table.update().values(
{self.attr: self.aggregate_query} {self.attr: self.aggregate_query}
) )
if len(self.relationships) == 1:
remote_pairs = self.relationships[-1].property.local_remote_pairs
query = query.where(
remote_pairs[0][0].in_(
getattr(obj, remote_pairs[0][1].key) for obj in objects
)
)
else:
# Builds query such as:
#
# UPDATE catalog SET product_count = (aggregate_query)
# WHERE id IN (
# SELECT catalog_id
# FROM category
# INNER JOIN sub_category
# ON category.id = sub_category.category_id
# WHERE sub_category.id IN (product_sub_category_ids)
# )
property_ = self.relationships[-1].property
remote_pairs = property_.local_remote_pairs
from_ = property_.mapper.class_.__table__
for relationship in reversed(self.relationships[1:-1]):
property_ = relationship.property
from_ = (
from_.join(property_.mapper.class_, property_.primaryjoin)
)
property_ = self.relationships[0].property
query = query.where(
remote_pairs[0][0].in_(
sa.select(
[remote_pairs[0][1]],
from_obj=[from_]
).where(
property_.local_remote_pairs[0][0].in_(
getattr(
obj, property_.local_remote_pairs[0][1].key
)
for obj in objects
)
)
)
)
return query
class AggregationManager(object): class AggregationManager(object):
@@ -336,12 +383,16 @@ class AggregationManager(object):
) )
def construct_aggregate_queries(self, session, ctx): def construct_aggregate_queries(self, session, ctx):
object_dict = defaultdict(list)
for obj in session: for obj in session:
class_ = obj.__class__.__name__ class_ = obj.__class__.__name__
if class_ in self.generator_registry: if class_ in self.generator_registry:
for aggregate_value in self.generator_registry[class_]: object_dict[class_].append(obj)
query = aggregate_value.update_query
session.execute(query) for class_, objects in six.iteritems(object_dict):
for aggregate_value in self.generator_registry[class_]:
query = aggregate_value.update_query(objects)
session.execute(query)
manager = AggregationManager() manager = AggregationManager()

View File

@@ -5,6 +5,8 @@ from tests import TestCase
class TestDeepModelPathsForAggregates(TestCase): class TestDeepModelPathsForAggregates(TestCase):
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
def create_models(self): def create_models(self):
class Catalog(self.Base): class Catalog(self.Base):
__tablename__ = 'catalog' __tablename__ = 'catalog'
@@ -55,3 +57,102 @@ class TestDeepModelPathsForAggregates(TestCase):
self.session.commit() self.session.commit()
self.session.refresh(catalog) self.session.refresh(catalog)
assert catalog.product_count == 1 assert catalog.product_count == 1
class Test3LevelDeepModelPathsForAggregates(TestCase):
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
n = 1
def create_models(self):
class Catalog(self.Base):
__tablename__ = 'catalog'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
@aggregated_attr('categories.sub_categories.products')
def product_count(self):
return sa.Column(sa.Integer, default=0)
categories = sa.orm.relationship('Category', backref='catalog')
class Category(self.Base):
__tablename__ = 'category'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))
sub_categories = sa.orm.relationship(
'SubCategory', backref='category'
)
class SubCategory(self.Base):
__tablename__ = 'sub_category'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id'))
products = sa.orm.relationship('Product', backref='sub_category')
class Product(self.Base):
__tablename__ = 'product'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
price = sa.Column(sa.Numeric)
sub_category_id = sa.Column(
sa.Integer, sa.ForeignKey('sub_category.id')
)
self.Catalog = Catalog
self.Category = Category
self.SubCategory = SubCategory
self.Product = Product
def test_assigns_aggregates(self):
catalog = self.catalog_factory()
self.session.commit()
self.session.refresh(catalog)
assert catalog.product_count == 1
def catalog_factory(self):
product = self.Product(
name=u'Product %d' % self.n
)
sub_category = self.SubCategory(
name=u'SubCategory %d' % self.n,
products=[product]
)
category = self.Category(
name=u'Category %d' % self.n,
sub_categories=[sub_category]
)
catalog = self.Catalog(
categories=[category]
)
catalog.name = u'Catalog %d' % self.n
self.session.add(catalog)
self.n += 1
return catalog
def test_only_updates_affected_aggregates(self):
catalog = self.catalog_factory()
catalog2 = self.catalog_factory()
self.session.commit()
# force set catalog2 product_count to zero in order to check if it gets
# updated when the other catalog's product count gets updated
self.session.execute(
'UPDATE catalog SET product_count = 0 WHERE id = %d'
% catalog2.id
)
catalog.categories[0].sub_categories[0].products.append(
self.Product(name=u'Product 3')
)
self.session.commit()
self.session.refresh(catalog)
self.session.refresh(catalog2)
assert catalog.product_count == 2
assert catalog2.product_count == 0