diff --git a/sqlalchemy_utils/aggregates.py b/sqlalchemy_utils/aggregates.py index f36efdb..c68c8ee 100644 --- a/sqlalchemy_utils/aggregates.py +++ b/sqlalchemy_utils/aggregates.py @@ -268,11 +268,11 @@ class AggregatedValue(object): @property def aggregate_query(self): - from_ = self.relationships[0].mapper.class_ + from_ = self.relationships[0].mapper.class_.__table__ for relationship in self.relationships[0:-1]: property_ = relationship.property from_ = ( - from_.__table__ + from_ .join( property_.parent.class_, property_.primaryjoin @@ -288,11 +288,58 @@ class AggregatedValue(object): return query.correlate(self.class_).as_scalar() - @property - def update_query(self): - return self.class_.__table__.update().values( + def update_query(self, objects): + table = self.class_.__table__ + query = table.update().values( {self.attr: self.aggregate_query} ) + if len(self.relationships) == 1: + remote_pairs = self.relationships[-1].property.local_remote_pairs + + query = query.where( + remote_pairs[0][0].in_( + getattr(obj, remote_pairs[0][1].key) for obj in objects + ) + ) + else: + # Builds query such as: + # + # UPDATE catalog SET product_count = (aggregate_query) + # WHERE id IN ( + # SELECT catalog_id + # FROM category + # INNER JOIN sub_category + # ON category.id = sub_category.category_id + # WHERE sub_category.id IN (product_sub_category_ids) + # ) + property_ = self.relationships[-1].property + remote_pairs = property_.local_remote_pairs + from_ = property_.mapper.class_.__table__ + for relationship in reversed(self.relationships[1:-1]): + property_ = relationship.property + from_ = ( + from_.join(property_.mapper.class_, property_.primaryjoin) + ) + + property_ = self.relationships[0].property + + query = query.where( + remote_pairs[0][0].in_( + sa.select( + [remote_pairs[0][1]], + from_obj=[from_] + ).where( + property_.local_remote_pairs[0][0].in_( + getattr( + obj, property_.local_remote_pairs[0][1].key + ) + for obj in objects + ) + ) + ) + ) + + return query class AggregationManager(object): @@ -336,12 +383,16 @@ class AggregationManager(object): ) def construct_aggregate_queries(self, session, ctx): + object_dict = defaultdict(list) for obj in session: class_ = obj.__class__.__name__ if class_ in self.generator_registry: - for aggregate_value in self.generator_registry[class_]: - query = aggregate_value.update_query - session.execute(query) + object_dict[class_].append(obj) + + for class_, objects in six.iteritems(object_dict): + for aggregate_value in self.generator_registry[class_]: + query = aggregate_value.update_query(objects) + session.execute(query) manager = AggregationManager() diff --git a/tests/aggregate/test_deep_paths.py b/tests/aggregate/test_deep_paths.py index 1cecd1d..e74d546 100644 --- a/tests/aggregate/test_deep_paths.py +++ b/tests/aggregate/test_deep_paths.py @@ -5,6 +5,8 @@ from tests import TestCase class TestDeepModelPathsForAggregates(TestCase): + dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' + def create_models(self): class Catalog(self.Base): __tablename__ = 'catalog' @@ -55,3 +57,102 @@ class TestDeepModelPathsForAggregates(TestCase): self.session.commit() self.session.refresh(catalog) assert catalog.product_count == 1 + + +class Test3LevelDeepModelPathsForAggregates(TestCase): + dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' + n = 1 + + def create_models(self): + class Catalog(self.Base): + __tablename__ = 'catalog' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + @aggregated_attr('categories.sub_categories.products') + def product_count(self): + return sa.Column(sa.Integer, default=0) + + categories = sa.orm.relationship('Category', backref='catalog') + + class Category(self.Base): + __tablename__ = 'category' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id')) + + sub_categories = sa.orm.relationship( + 'SubCategory', backref='category' + ) + + class SubCategory(self.Base): + __tablename__ = 'sub_category' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id')) + + products = sa.orm.relationship('Product', backref='sub_category') + + class Product(self.Base): + __tablename__ = 'product' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + price = sa.Column(sa.Numeric) + + sub_category_id = sa.Column( + sa.Integer, sa.ForeignKey('sub_category.id') + ) + + self.Catalog = Catalog + self.Category = Category + self.SubCategory = SubCategory + self.Product = Product + + def test_assigns_aggregates(self): + catalog = self.catalog_factory() + self.session.commit() + self.session.refresh(catalog) + assert catalog.product_count == 1 + + def catalog_factory(self): + product = self.Product( + name=u'Product %d' % self.n + ) + sub_category = self.SubCategory( + name=u'SubCategory %d' % self.n, + products=[product] + ) + category = self.Category( + name=u'Category %d' % self.n, + sub_categories=[sub_category] + ) + catalog = self.Catalog( + categories=[category] + ) + catalog.name = u'Catalog %d' % self.n + self.session.add(catalog) + self.n += 1 + return catalog + + def test_only_updates_affected_aggregates(self): + catalog = self.catalog_factory() + catalog2 = self.catalog_factory() + self.session.commit() + + # force set catalog2 product_count to zero in order to check if it gets + # updated when the other catalog's product count gets updated + self.session.execute( + 'UPDATE catalog SET product_count = 0 WHERE id = %d' + % catalog2.id + ) + + catalog.categories[0].sub_categories[0].products.append( + self.Product(name=u'Product 3') + ) + self.session.commit() + self.session.refresh(catalog) + self.session.refresh(catalog2) + assert catalog.product_count == 2 + assert catalog2.product_count == 0