From 9e2c47d50eecfad22e70e631bf34fec93dd5ef28 Mon Sep 17 00:00:00 2001 From: Konsta Vesterinen Date: Tue, 6 May 2014 20:22:44 +0300 Subject: [PATCH] Add dependencies function --- sqlalchemy_utils/functions/orm.py | 57 ++++++++++++++++++ tests/functions/test_dependencies.py | 90 ++++++++++++++++++++++++++++ 2 files changed, 147 insertions(+) create mode 100644 tests/functions/test_dependencies.py diff --git a/sqlalchemy_utils/functions/orm.py b/sqlalchemy_utils/functions/orm.py index b5ab5c3..b72d20a 100644 --- a/sqlalchemy_utils/functions/orm.py +++ b/sqlalchemy_utils/functions/orm.py @@ -12,7 +12,64 @@ from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm.attributes import InstrumentedAttribute from sqlalchemy.orm.mapper import Mapper from sqlalchemy.orm.query import _ColumnEntity +from sqlalchemy.orm.session import object_session from sqlalchemy.orm.util import AliasedInsp +from ..query_chain import QueryChain + + +def dependencies(obj, foreign_keys=None): + """ + Return a QueryChain that iterates through all dependent objects for given + SQLAlchemy object. + + :param obj: SQLAlchemy declarative model object + :param foreign_keys: + A sequence of foreign keys to use for searching the dependencies for + given object. By default this is None, indicating that all foreign keys + referencing the object will be used. + + .. note:: + This function does not support exotic mappers that use multiple tables + + .. versionadded: 0.26.0 + """ + if foreign_keys is None: + foreign_keys = get_referencing_foreign_keys(obj) + + session = object_session(obj) + + foreign_keys = sorted( + foreign_keys, key=lambda key: key.constraint.table.name + ) + chain = QueryChain([]) + classes = obj.__class__._decl_class_registry + + for table, keys in groupby(foreign_keys, lambda key: key.constraint.table): + for class_ in classes.values(): + if hasattr(class_, '__table__') and class_.__table__ == table: + criteria = [] + visited_constraints = [] + for key in keys: + if key.constraint not in visited_constraints: + visited_constraints.append(key.constraint) + subcriteria = [ + getattr(class_, column.key) == + getattr( + obj, + key.constraint.elements[index].column.key + ) + for index, column + in enumerate(key.constraint.columns) + ] + criteria.append(sa.and_(*subcriteria)) + + query = session.query(class_).filter( + sa.or_( + *criteria + ) + ) + chain.queries.append(query) + return chain def get_referencing_foreign_keys(mixed): diff --git a/tests/functions/test_dependencies.py b/tests/functions/test_dependencies.py new file mode 100644 index 0000000..40c6905 --- /dev/null +++ b/tests/functions/test_dependencies.py @@ -0,0 +1,90 @@ +import sqlalchemy as sa +from sqlalchemy_utils.functions.orm import dependencies +from tests import TestCase + + +class TestDependencies(TestCase): + def create_models(self): + class User(self.Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + first_name = sa.Column(sa.Unicode(255)) + last_name = sa.Column(sa.Unicode(255)) + + class Article(self.Base): + __tablename__ = 'article' + id = sa.Column(sa.Integer, primary_key=True) + author_id = sa.Column(sa.Integer, sa.ForeignKey('user.id')) + owner_id = sa.Column(sa.Integer, sa.ForeignKey('user.id')) + + author = sa.orm.relationship(User, foreign_keys=[author_id]) + owner = sa.orm.relationship(User, foreign_keys=[owner_id]) + + class BlogPost(self.Base): + __tablename__ = 'blog_post' + id = sa.Column(sa.Integer, primary_key=True) + author_id = sa.Column(sa.Integer, sa.ForeignKey('user.id')) + + author = sa.orm.relationship(User) + + self.User = User + self.Article = Article + self.BlogPost = BlogPost + + def test_multiple_refs(self): + user = self.User(first_name=u'John') + articles = [ + self.Article(author=user), + self.Article(), + self.Article(owner=user), + self.Article(author=user, owner=user) + ] + self.session.add_all(articles) + self.session.commit() + + deps = list(dependencies(user)) + assert len(deps) == 3 + assert articles[0] in deps + assert articles[2] in deps + assert articles[3] in deps + + +class TestDependenciesWithCompositeKeys(TestCase): + def create_models(self): + class User(self.Base): + __tablename__ = 'user' + first_name = sa.Column(sa.Unicode(255), primary_key=True) + last_name = sa.Column(sa.Unicode(255), primary_key=True) + + class Article(self.Base): + __tablename__ = 'article' + id = sa.Column(sa.Integer, primary_key=True) + author_first_name = sa.Column(sa.Unicode(255)) + author_last_name = sa.Column(sa.Unicode(255)) + __table_args__ = ( + sa.ForeignKeyConstraint( + [author_first_name, author_last_name], + [User.first_name, User.last_name] + ), + ) + + author = sa.orm.relationship(User) + + self.User = User + self.Article = Article + + def test_returns_all_dependent_objects(self): + user = self.User(first_name=u'John', last_name=u'Smith') + articles = [ + self.Article(author=user), + self.Article(), + self.Article(), + self.Article(author=user) + ] + self.session.add_all(articles) + self.session.commit() + + deps = list(dependencies(user)) + assert len(deps) == 2 + assert articles[0] in deps + assert articles[3] in deps