diff --git a/sqlalchemy_utils/aggregates.py b/sqlalchemy_utils/aggregates.py index ffb3743..a828b23 100644 --- a/sqlalchemy_utils/aggregates.py +++ b/sqlalchemy_utils/aggregates.py @@ -375,6 +375,8 @@ except ImportError: # SQLAlchemy 0.8 from sqlalchemy.sql.expression import _FunctionGenerator +from .relationships import chained_join + aggregated_attrs = WeakKeyDictionary(defaultdict(list)) @@ -452,6 +454,26 @@ def get_aggregate_query(agg_expr, relationships): return query.where(condition) +def local_condition(prop, objects): + pairs = prop.local_remote_pairs + if prop.secondary is not None: + column = pairs[1][0] + key = pairs[1][0].key + else: + column = pairs[0][0] + key = pairs[0][1].key + + values = [] + for obj in objects: + try: + values.append(getattr(obj, key)) + except sa.orm.exc.ObjectDeletedError: + pass + + if values: + return column.in_(values) + + class AggregatedValue(object): def __init__(self, class_, attr, relationships, expr): self.class_ = class_ @@ -480,7 +502,7 @@ class AggregatedValue(object): ) if len(self.relationships) == 1: prop = self.relationships[-1].property - condition = self.local_condition(prop, objects) + condition = local_condition(prop, objects) if condition is not None: return query.where(condition) else: @@ -498,7 +520,7 @@ class AggregatedValue(object): remote_pairs = property_.local_remote_pairs local = remote_pairs[0][0] remote = remote_pairs[0][1] - condition = self.local_condition( + condition = local_condition( self.relationships[0].property, objects ) @@ -507,55 +529,15 @@ class AggregatedValue(object): local.in_( sa.select( [remote], - from_obj=[self.multi_level_aggregate_query_base] + from_obj=[ + chained_join(*reversed(self.relationships)) + ] ).where( condition ) ) ) - @property - def multi_level_aggregate_query_base(self): - property_ = self.relationships[-1].property - - from_ = property_.mapper.class_.__table__ - for relationship in reversed(self.relationships[0:-1]): - property_ = relationship.property - if property_.secondary is not None: - from_ = from_.join( - property_.secondary, - property_.primaryjoin - ) - from_ = from_.join( - property_.mapper.class_, - property_.secondaryjoin - ) - else: - from_ = from_.join( - property_.mapper.class_, - property_.primaryjoin - ) - return from_ - - def local_condition(self, prop, objects): - pairs = prop.local_remote_pairs - if prop.secondary is not None: - column = pairs[1][0] - key = pairs[1][0].key - else: - column = pairs[0][0] - key = pairs[0][1].key - - values = [] - for obj in objects: - try: - values.append(getattr(obj, key)) - except sa.orm.exc.ObjectDeletedError: - pass - - if values: - return column.in_(values) - class AggregationManager(object): def __init__(self): diff --git a/sqlalchemy_utils/relationships/__init__.py b/sqlalchemy_utils/relationships/__init__.py new file mode 100644 index 0000000..81784bf --- /dev/null +++ b/sqlalchemy_utils/relationships/__init__.py @@ -0,0 +1 @@ +from .chained_join import chained_join diff --git a/sqlalchemy_utils/relationships/chained_join.py b/sqlalchemy_utils/relationships/chained_join.py new file mode 100644 index 0000000..68803cf --- /dev/null +++ b/sqlalchemy_utils/relationships/chained_join.py @@ -0,0 +1,28 @@ +def chained_join(*relationships): + property_ = relationships[0].property + + if property_.secondary is not None: + from_ = property_.secondary.join( + property_.mapper.class_.__table__, + property_.secondaryjoin + ) + else: + from_ = property_.mapper.class_.__table__ + for relationship in relationships[1:]: + prop = relationship.property + if prop.secondary is not None: + from_ = from_.join( + prop.secondary, + prop.primaryjoin + ) + + from_ = from_.join( + prop.mapper.class_, + prop.secondaryjoin + ) + else: + from_ = from_.join( + prop.mapper.class_, + prop.primaryjoin + ) + return from_ diff --git a/tests/relationships/__init__.py b/tests/relationships/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/relationships/test_chained_join.py b/tests/relationships/test_chained_join.py new file mode 100644 index 0000000..550e73b --- /dev/null +++ b/tests/relationships/test_chained_join.py @@ -0,0 +1,266 @@ +import sqlalchemy as sa + +from sqlalchemy_utils.relationships import chained_join +from tests import TestCase + + +class TestChainedJoinForManyToManyToManyToMany(TestCase): + dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' + create_tables = False + + def create_models(self): + catalog_category = sa.Table( + 'catalog_category', + self.Base.metadata, + sa.Column('catalog_id', sa.Integer, sa.ForeignKey('catalog.id')), + sa.Column('category_id', sa.Integer, sa.ForeignKey('category.id')) + ) + + category_subcategory = sa.Table( + 'category_subcategory', + self.Base.metadata, + sa.Column( + 'category_id', + sa.Integer, + sa.ForeignKey('category.id') + ), + sa.Column( + 'subcategory_id', + sa.Integer, + sa.ForeignKey('sub_category.id') + ) + ) + + subcategory_product = sa.Table( + 'subcategory_product', + self.Base.metadata, + sa.Column( + 'subcategory_id', + sa.Integer, + sa.ForeignKey('sub_category.id') + ), + sa.Column( + 'product_id', + sa.Integer, + sa.ForeignKey('product.id') + ) + ) + + class Catalog(self.Base): + __tablename__ = 'catalog' + id = sa.Column(sa.Integer, primary_key=True) + + categories = sa.orm.relationship( + 'Category', + backref='catalogs', + secondary=catalog_category + ) + + class Category(self.Base): + __tablename__ = 'category' + id = sa.Column(sa.Integer, primary_key=True) + + sub_categories = sa.orm.relationship( + 'SubCategory', + backref='categories', + secondary=category_subcategory + ) + + class SubCategory(self.Base): + __tablename__ = 'sub_category' + id = sa.Column(sa.Integer, primary_key=True) + products = sa.orm.relationship( + 'Product', + backref='sub_categories', + secondary=subcategory_product + ) + + class Product(self.Base): + __tablename__ = 'product' + id = sa.Column(sa.Integer, primary_key=True) + price = sa.Column(sa.Numeric) + + self.Catalog = Catalog + self.Category = Category + self.SubCategory = SubCategory + self.Product = Product + + def test_simple_join(self): + assert str(chained_join(self.Catalog.categories)) == ( + 'catalog_category JOIN category ON ' + 'category.id = catalog_category.category_id' + ) + + def test_two_relations(self): + sql = chained_join( + self.Catalog.categories, + self.Category.sub_categories + ) + assert str(sql) == ( + 'catalog_category JOIN category ON category.id = ' + 'catalog_category.category_id JOIN category_subcategory ON ' + 'category.id = category_subcategory.category_id JOIN sub_category ' + 'ON sub_category.id = category_subcategory.subcategory_id' + ) + + def test_three_relations(self): + sql = chained_join( + self.Catalog.categories, + self.Category.sub_categories, + self.SubCategory.products + ) + assert str(sql) == ( + 'catalog_category JOIN category ON category.id = ' + 'catalog_category.category_id JOIN category_subcategory ON ' + 'category.id = category_subcategory.category_id JOIN sub_category ' + 'ON sub_category.id = category_subcategory.subcategory_id JOIN ' + 'subcategory_product ON sub_category.id = ' + 'subcategory_product.subcategory_id JOIN product ON product.id = ' + 'subcategory_product.product_id' + ) + + +class TestChainedJoinFor3LevelDeepOneToMany(TestCase): + dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' + create_tables = False + + def create_models(self): + class Catalog(self.Base): + __tablename__ = 'catalog' + id = sa.Column(sa.Integer, primary_key=True) + + categories = sa.orm.relationship('Category', backref='catalog') + + class Category(self.Base): + __tablename__ = 'category' + id = sa.Column(sa.Integer, primary_key=True) + catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id')) + + sub_categories = sa.orm.relationship( + 'SubCategory', backref='category' + ) + + class SubCategory(self.Base): + __tablename__ = 'sub_category' + id = sa.Column(sa.Integer, primary_key=True) + category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id')) + products = sa.orm.relationship( + 'Product', + backref='sub_category' + ) + + class Product(self.Base): + __tablename__ = 'product' + id = sa.Column(sa.Integer, primary_key=True) + price = sa.Column(sa.Numeric) + + sub_category_id = sa.Column( + sa.Integer, sa.ForeignKey('sub_category.id') + ) + + def __repr__(self): + return '' % self.id + + self.Catalog = Catalog + self.Category = Category + self.SubCategory = SubCategory + self.Product = Product + + def test_simple_join(self): + assert str(chained_join(self.Catalog.categories)) == 'category' + + def test_two_relations(self): + sql = chained_join( + self.Catalog.categories, + self.Category.sub_categories + ) + assert str(sql) == ( + 'category JOIN sub_category ON category.id = ' + 'sub_category.category_id' + ) + + def test_three_relations(self): + sql = chained_join( + self.Catalog.categories, + self.Category.sub_categories, + self.SubCategory.products + ) + assert str(sql) == ( + 'category JOIN sub_category ON category.id = ' + 'sub_category.category_id JOIN product ON sub_category.id = ' + 'product.sub_category_id' + ) + + +class TestChainedJoinForOneToOneToOneToOne(TestCase): + dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' + + def create_models(self): + class Catalog(self.Base): + __tablename__ = 'catalog' + id = sa.Column(sa.Integer, primary_key=True) + category = sa.orm.relationship( + 'Category', + uselist=False, + backref='catalog' + ) + + class Category(self.Base): + __tablename__ = 'category' + id = sa.Column(sa.Integer, primary_key=True) + catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id')) + + sub_category = sa.orm.relationship( + 'SubCategory', + uselist=False, + backref='category' + ) + + class SubCategory(self.Base): + __tablename__ = 'sub_category' + id = sa.Column(sa.Integer, primary_key=True) + category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id')) + product = sa.orm.relationship( + 'Product', + uselist=False, + backref='sub_category' + ) + + class Product(self.Base): + __tablename__ = 'product' + id = sa.Column(sa.Integer, primary_key=True) + price = sa.Column(sa.Integer) + + sub_category_id = sa.Column( + sa.Integer, sa.ForeignKey('sub_category.id') + ) + + self.Catalog = Catalog + self.Category = Category + self.SubCategory = SubCategory + self.Product = Product + + def test_simple_join(self): + assert str(chained_join(self.Catalog.category)) == 'category' + + def test_two_relations(self): + sql = chained_join( + self.Catalog.category, + self.Category.sub_category + ) + assert str(sql) == ( + 'category JOIN sub_category ON category.id = ' + 'sub_category.category_id' + ) + + def test_three_relations(self): + sql = chained_join( + self.Catalog.category, + self.Category.sub_category, + self.SubCategory.product + ) + assert str(sql) == ( + 'category JOIN sub_category ON category.id = ' + 'sub_category.category_id JOIN product ON sub_category.id = ' + 'product.sub_category_id' + )