import sqlalchemy as sa

from sqlalchemy_utils.aggregates import select_aggregate
from tests import TestCase
from tests.mixins import ThreeLevelDeepManyToMany


def normalize(sql):
    """Return *sql* with every run of whitespace collapsed to one space.

    Leading and trailing whitespace is dropped. Newlines are treated like
    any other whitespace, so tokens separated only by a line break stay
    separate instead of being fused together.
    """
    # str.split() with no argument already splits on newlines; deleting
    # them first would glue "a\nb" into "ab" and could mask a real
    # difference between two statements.
    return ' '.join(sql.split())


class TestAggregateQueryForDeepToManyToMany(
    ThreeLevelDeepManyToMany,
    TestCase
):
    """Check the SQL that ``select_aggregate`` renders for chained
    many-to-many relationships (catalog -> category -> sub-category).
    """

    # Attribute name and value are kept exactly as the original spells
    # them; presumably they are read by the ``TestCase`` base class --
    # verify there before renaming.
    # NOTE(review): SQLAlchemy 1.4+ no longer accepts the 'postgres://'
    # scheme alias ('postgresql://' is required) -- confirm the supported
    # SQLAlchemy range before changing this URL.
    dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
    # Presumably tells the base class to skip table creation; these tests
    # only render statements to strings and never execute them.
    create_tables = False

    def assert_sql(self, construct, sql):
        """Assert that ``str(construct)`` equals *sql*, ignoring
        differences in whitespace layout.
        """
        assert normalize(str(construct)) == normalize(sql)

    def build_update(self, *relationships):
        """Build an UPDATE of the catalog table whose ``_id`` column is
        set to a ``count(1)`` aggregate subquery.

        :param relationships:
            Relationship attributes forwarded as a tuple to
            ``select_aggregate``. The tests pass them starting from the
            most deeply nested relationship and ending with the one
            defined on ``Catalog``.
        :returns: The UPDATE construct (not executed).
        """
        expr = sa.func.count(sa.text('1'))
        # Correlate against Catalog so the subquery refers to the row
        # being updated rather than adding catalog to its own FROM list.
        aggregate_query = select_aggregate(
            expr,
            relationships
        ).correlate(self.Catalog)
        return self.Catalog.__table__.update().values(_id=aggregate_query)

    def test_simple_join(self):
        """A single many-to-many relationship joins through its
        association table only.
        """
        self.assert_sql(
            self.build_update(self.Catalog.categories),
            (
                'UPDATE catalog SET _id=(SELECT count(1) AS count_1 '
                'FROM category '
                'JOIN catalog_category '
                'ON category._id = catalog_category.category_id '
                'WHERE catalog._id = catalog_category.catalog_id)'
            )
        )

    def test_two_relations(self):
        """Two chained many-to-many relationships join through both
        association tables.
        """
        self.assert_sql(
            self.build_update(
                self.Category.sub_categories,
                self.Catalog.categories,
            ),
            (
                'UPDATE catalog SET _id=(SELECT count(1) AS count_1 '
                'FROM sub_category '
                'JOIN category_subcategory '
                'ON sub_category._id = category_subcategory.subcategory_id '
                'JOIN category '
                'ON category._id = category_subcategory.category_id '
                'JOIN catalog_category '
                'ON category._id = catalog_category.category_id '
                'WHERE catalog._id = catalog_category.catalog_id)'
            )
        )