diff --git a/CHANGES.rst b/CHANGES.rst index 0fb62e1..e749aa5 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -3,6 +3,11 @@ Changelog Here you can see the full list of changes between each SQLAlchemy-Utils release. +0.27.10 (2014-12-03) +^^^^^^^^^^^^^^^^^^^^ + +- Fixed column alias handling in dependent_objects + 0.27.9 (2014-12-01) ^^^^^^^^^^^^^^^^^^^ diff --git a/sqlalchemy_utils/__init__.py b/sqlalchemy_utils/__init__.py index ba78c70..906547e 100644 --- a/sqlalchemy_utils/__init__.py +++ b/sqlalchemy_utils/__init__.py @@ -80,7 +80,7 @@ from .types import ( from .models import Timestamp -__version__ = '0.27.9' +__version__ = '0.27.10' __all__ = ( diff --git a/sqlalchemy_utils/functions/foreign_keys.py b/sqlalchemy_utils/functions/foreign_keys.py index 11a646f..6a90b1c 100644 --- a/sqlalchemy_utils/functions/foreign_keys.py +++ b/sqlalchemy_utils/functions/foreign_keys.py @@ -279,31 +279,43 @@ def dependent_objects(obj, foreign_keys=None): table in mapper.tables and not (parent_mapper and table in parent_mapper.tables) ): - criteria = [] - visited_constraints = [] - for key in keys: - if key.constraint not in visited_constraints: - visited_constraints.append(key.constraint) - subcriteria = [ - getattr(class_, column.key) == - getattr( - obj, - key.constraint.elements[index].column.key - ) - for index, column - in enumerate(key.constraint.columns) - ] - criteria.append(sa.and_(*subcriteria)) - query = session.query(class_).filter( - sa.or_( - *criteria - ) + sa.or_(*_get_criteria(keys, class_, obj)) ) chain.queries.append(query) return chain +def _get_criteria(keys, class_, obj): + criteria = [] + visited_constraints = [] + for key in keys: + if key.constraint in visited_constraints: + continue + visited_constraints.append(key.constraint) + + subcriteria = [] + for index, column in enumerate(key.constraint.columns): + prop = sa.inspect(class_).get_property_by_column( + column + ) + foreign_column = ( + key.constraint.elements[index].column + ) + subcriteria.append( + getattr(class_, prop.key) == + getattr( + obj, + sa.inspect(type(obj)) + .get_property_by_column( + foreign_column + ).key + ) + ) + criteria.append(sa.and_(*subcriteria)) + return criteria + + def non_indexed_foreign_keys(metadata, engine=None): """ Finds all non indexed foreign keys from all tables of given MetaData. @@ -347,10 +359,10 @@ def is_indexed_foreign_key(constraint): :param constraint: ForeignKeyConstraint object to check the indexes """ - for index in constraint.table.indexes: - index_column_names = set( - column.name for column in index.columns - ) - if index_column_names == set(constraint.columns): - return True - return False + return any( + set(column.name for column in index.columns) + == + set(constraint.columns) + for index + in constraint.table.indexes + ) diff --git a/tests/functions/test_dependent_objects.py b/tests/functions/test_dependent_objects.py index de64222..c9a904b 100644 --- a/tests/functions/test_dependent_objects.py +++ b/tests/functions/test_dependent_objects.py @@ -78,6 +78,85 @@ class TestDependentObjects(TestCase): assert objects[3] in deps +class TestDependentObjectsWithColumnAliases(TestCase): + def create_models(self): + class User(self.Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + first_name = sa.Column(sa.Unicode(255)) + last_name = sa.Column(sa.Unicode(255)) + + class Article(self.Base): + __tablename__ = 'article' + id = sa.Column(sa.Integer, primary_key=True) + author_id = sa.Column( + '_author_id', sa.Integer, sa.ForeignKey('user.id') + ) + owner_id = sa.Column( + '_owner_id', + sa.Integer, sa.ForeignKey('user.id', ondelete='SET NULL') + ) + + author = sa.orm.relationship(User, foreign_keys=[author_id]) + owner = sa.orm.relationship(User, foreign_keys=[owner_id]) + + class BlogPost(self.Base): + __tablename__ = 'blog_post' + id = sa.Column(sa.Integer, primary_key=True) + owner_id = sa.Column( + '_owner_id', + sa.Integer, sa.ForeignKey('user.id', ondelete='CASCADE') + ) + + owner = sa.orm.relationship(User) + + self.User = User + self.Article = Article + self.BlogPost = BlogPost + + def test_returns_all_dependent_objects(self): + user = self.User(first_name=u'John') + articles = [ + self.Article(author=user), + self.Article(), + self.Article(owner=user), + self.Article(author=user, owner=user) + ] + self.session.add_all(articles) + self.session.commit() + + deps = list(dependent_objects(user)) + assert len(deps) == 3 + assert articles[0] in deps + assert articles[2] in deps + assert articles[3] in deps + + def test_with_foreign_keys_parameter(self): + user = self.User(first_name=u'John') + objects = [ + self.Article(author=user), + self.Article(), + self.Article(owner=user), + self.Article(author=user, owner=user), + self.BlogPost(owner=user) + ] + self.session.add_all(objects) + self.session.commit() + + deps = list( + dependent_objects( + user, + ( + fk for fk in get_referencing_foreign_keys(self.User) + if fk.ondelete == 'RESTRICT' or fk.ondelete is None + ) + ).limit(5) + ) + assert len(deps) == 2 + assert objects[0] in deps + assert objects[3] in deps + + class TestDependentObjectsWithManyReferences(TestCase): def create_models(self): class User(self.Base):