import sqlalchemy as sa

from sqlalchemy_utils import dependent_objects, get_referencing_foreign_keys
from tests import TestCase


class TestDependentObjects(TestCase):
    """Tests for ``dependent_objects`` with several FKs pointing at ``user``.

    ``Article`` references ``user`` twice: ``author_id`` (no ``ondelete``
    rule) and ``owner_id`` (``ondelete='SET NULL'``). ``BlogPost`` references
    it once through ``owner_id`` (``ondelete='CASCADE'``).
    """

    def create_models(self):
        class User(self.Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, primary_key=True)
            first_name = sa.Column(sa.Unicode(255))
            last_name = sa.Column(sa.Unicode(255))

        class Article(self.Base):
            __tablename__ = 'article'
            id = sa.Column(sa.Integer, primary_key=True)
            author_id = sa.Column(sa.Integer, sa.ForeignKey('user.id'))
            owner_id = sa.Column(
                sa.Integer, sa.ForeignKey('user.id', ondelete='SET NULL')
            )

            # Two FKs target the same table, so each relationship names the
            # column it joins through.
            author = sa.orm.relationship(User, foreign_keys=[author_id])
            owner = sa.orm.relationship(User, foreign_keys=[owner_id])

        class BlogPost(self.Base):
            __tablename__ = 'blog_post'
            id = sa.Column(sa.Integer, primary_key=True)
            owner_id = sa.Column(
                sa.Integer, sa.ForeignKey('user.id', ondelete='CASCADE')
            )
            owner = sa.orm.relationship(User)

        self.User = User
        self.Article = Article
        self.BlogPost = BlogPost

    def test_returns_all_dependent_objects(self):
        user = self.User(first_name=u'John')
        articles = [
            self.Article(author=user),
            self.Article(),
            self.Article(owner=user),
            self.Article(author=user, owner=user)
        ]
        self.session.add_all(articles)
        self.session.commit()

        deps = list(dependent_objects(user))
        # Three distinct articles reference the user; articles[3] does so via
        # both FKs, and the count of 3 pins that it is returned only once.
        assert len(deps) == 3
        assert articles[0] in deps
        assert articles[2] in deps
        assert articles[3] in deps

    def test_with_foreign_keys_parameter(self):
        user = self.User(first_name=u'John')
        objects = [
            self.Article(author=user),
            self.Article(),
            self.Article(owner=user),
            self.Article(author=user, owner=user),
            self.BlogPost(owner=user)
        ]
        self.session.add_all(objects)
        self.session.commit()

        # Restrict the search to FKs whose ondelete rule is RESTRICT or unset.
        # With the models above, only Article.author_id qualifies, so the
        # SET NULL (Article.owner_id) and CASCADE (BlogPost.owner_id)
        # references are ignored.
        restricting_fks = (
            fk for fk in get_referencing_foreign_keys(self.User)
            if fk.ondelete == 'RESTRICT' or fk.ondelete is None
        )
        deps = list(dependent_objects(user, restricting_fks).limit(5))
        assert len(deps) == 2
        assert objects[0] in deps
        assert objects[3] in deps


class TestDependentObjectsWithManyReferences(TestCase):
    """Tests for ``dependent_objects`` with dependents in two tables.

    ``Article`` and ``BlogPost`` each have one plain FK to ``user``.
    """

    def create_models(self):
        class User(self.Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, primary_key=True)
            first_name = sa.Column(sa.Unicode(255))
            last_name = sa.Column(sa.Unicode(255))

        class BlogPost(self.Base):
            __tablename__ = 'blog_post'
            id = sa.Column(sa.Integer, primary_key=True)
            author_id = sa.Column(sa.Integer, sa.ForeignKey('user.id'))
            author = sa.orm.relationship(User)

        class Article(self.Base):
            __tablename__ = 'article'
            id = sa.Column(sa.Integer, primary_key=True)
            author_id = sa.Column(sa.Integer, sa.ForeignKey('user.id'))
            author = sa.orm.relationship(User)

        self.User = User
        self.Article = Article
        self.BlogPost = BlogPost

    def test_with_many_dependencies(self):
        user = self.User(first_name=u'John')
        objects = [
            self.Article(author=user),
            self.BlogPost(author=user)
        ]
        self.session.add_all(objects)
        self.session.commit()

        deps = list(dependent_objects(user))
        assert len(deps) == 2
        # Checking the count alone would not catch a wrong pair of objects;
        # verify that one dependent from each referencing table is returned.
        assert objects[0] in deps
        assert objects[1] in deps


class TestDependentObjectsWithCompositeKeys(TestCase):
    """Tests for ``dependent_objects`` with a composite primary key.

    ``user`` has the primary key (``first_name``, ``last_name``).
    ``Article`` references it through a two-column ``ForeignKeyConstraint``.
    """

    def create_models(self):
        class User(self.Base):
            __tablename__ = 'user'
            first_name = sa.Column(sa.Unicode(255), primary_key=True)
            last_name = sa.Column(sa.Unicode(255), primary_key=True)

        class Article(self.Base):
            __tablename__ = 'article'
            id = sa.Column(sa.Integer, primary_key=True)
            author_first_name = sa.Column(sa.Unicode(255))
            author_last_name = sa.Column(sa.Unicode(255))
            __table_args__ = (
                sa.ForeignKeyConstraint(
                    [author_first_name, author_last_name],
                    [User.first_name, User.last_name]
                ),
            )

            author = sa.orm.relationship(User)

        self.User = User
        self.Article = Article

    def test_returns_all_dependent_objects(self):
        user = self.User(first_name=u'John', last_name=u'Smith')
        articles = [
            self.Article(author=user),
            self.Article(),
            self.Article(),
            self.Article(author=user)
        ]
        self.session.add_all(articles)
        self.session.commit()

        deps = list(dependent_objects(user))
        assert len(deps) == 2
        assert articles[0] in deps
        assert articles[3] in deps


class TestDependentObjectsWithSingleTableInheritance(TestCase):
    """Tests for ``dependent_objects`` with single-table inheritance.

    ``Article`` and ``BlogPost`` share the ``text_item`` table, with the
    ``type`` column as polymorphic discriminator. ``text_item.category_id``
    references ``category``.
    """

    def create_models(self):
        class Category(self.Base):
            __tablename__ = 'category'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.Unicode(255))

        class TextItem(self.Base):
            __tablename__ = 'text_item'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.Unicode(255))
            category_id = sa.Column(
                sa.Integer,
                sa.ForeignKey(Category.id)
            )
            category = sa.orm.relationship(
                Category,
                backref=sa.orm.backref(
                    'articles'
                )
            )
            type = sa.Column(sa.Unicode(255))

            __mapper_args__ = {
                'polymorphic_on': type,
            }

        class Article(TextItem):
            __mapper_args__ = {
                'polymorphic_identity': u'article'
            }

        class BlogPost(TextItem):
            __mapper_args__ = {
                'polymorphic_identity': u'blog_post'
            }

        self.Category = Category
        self.TextItem = TextItem
        self.Article = Article
        self.BlogPost = BlogPost

    def test_returns_all_dependent_objects(self):
        category1 = self.Category(name=u'Category #1')
        category2 = self.Category(name=u'Category #2')
        articles = [
            self.Article(category=category1),
            self.Article(category=category1),
            self.Article(category=category2),
            self.Article(category=category2),
        ]
        self.session.add_all(articles)
        self.session.commit()

        # Only the two articles attached to category1 may be returned.
        deps = list(dependent_objects(category1))
        assert len(deps) == 2
        assert articles[0] in deps
        assert articles[1] in deps