diff --git a/sqlalchemy_utils/functions/foreign_keys.py b/sqlalchemy_utils/functions/foreign_keys.py index 1b3f051..2ade46f 100644 --- a/sqlalchemy_utils/functions/foreign_keys.py +++ b/sqlalchemy_utils/functions/foreign_keys.py @@ -4,6 +4,7 @@ from itertools import groupby import six import sqlalchemy as sa from sqlalchemy.engine import reflection +from sqlalchemy.exc import NoInspectionAvailable from sqlalchemy.orm import object_session, mapperlib from sqlalchemy.schema import MetaData, Table, ForeignKeyConstraint @@ -195,9 +196,7 @@ def dependent_objects(obj, foreign_keys=None): through all dependent objects for given SQLAlchemy object. Consider a User object is referenced in various articles and also in - various orders. Getting all these dependent objects is as easy as: - - :: + various orders. Getting all these dependent objects is as easy as:: from sqlalchemy_utils import dependent_objects @@ -219,9 +218,7 @@ def dependent_objects(obj, foreign_keys=None): it will lead to nasty IntegrityErrors being raised. In the following example we delete given user if it doesn't have any - foreign key restricted dependent objects. - - :: + foreign key restricted dependent objects:: from sqlalchemy_utils import get_referencing_foreign_keys @@ -272,8 +269,17 @@ def dependent_objects(obj, foreign_keys=None): classes = obj.__class__._decl_class_registry for table, keys in group_foreign_keys(foreign_keys): + keys = list(keys) for class_ in classes.values(): - if hasattr(class_, '__table__') and class_.__table__ == table: + try: + mapper = sa.inspect(class_) + except NoInspectionAvailable: + continue + parent_mapper = mapper.inherits + if ( + table in mapper.tables and + not (parent_mapper and table in parent_mapper.tables) + ): criteria = [] visited_constraints = [] for key in keys: diff --git a/tests/functions/test_dependent_objects.py b/tests/functions/test_dependent_objects.py index 8a717aa..de64222 100644 --- a/tests/functions/test_dependent_objects.py +++ b/tests/functions/test_dependent_objects.py @@ -153,3 +153,64 @@ class TestDependentObjectsWithCompositeKeys(TestCase): assert len(deps) == 2 assert articles[0] in deps assert articles[3] in deps + + +class TestDependentObjectsWithSingleTableInheritance(TestCase): + def create_models(self): + class Category(self.Base): + __tablename__ = 'category' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + class TextItem(self.Base): + __tablename__ = 'text_item' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + category_id = sa.Column( + sa.Integer, + sa.ForeignKey(Category.id) + ) + category = sa.orm.relationship( + Category, + backref=sa.orm.backref( + 'articles' + ) + ) + type = sa.Column(sa.Unicode(255)) + + __mapper_args__ = { + 'polymorphic_on': type, + } + + class Article(TextItem): + __mapper_args__ = { + 'polymorphic_identity': u'article' + } + + class BlogPost(TextItem): + __mapper_args__ = { + 'polymorphic_identity': u'blog_post' + } + + + self.Category = Category + self.TextItem = TextItem + self.Article = Article + self.BlogPost = BlogPost + + def test_returns_all_dependent_objects(self): + category1 = self.Category(name=u'Category #1') + category2 = self.Category(name=u'Category #2') + articles = [ + self.Article(category=category1), + self.Article(category=category1), + self.Article(category=category2), + self.Article(category=category2), + ] + self.session.add_all(articles) + self.session.commit() + + deps = list(dependent_objects(category1)) + assert len(deps) == 2 + assert articles[0] in deps + assert articles[1] in deps