Add dependencies function
This commit is contained in:
@@ -12,7 +12,64 @@ from sqlalchemy.ext.hybrid import hybrid_property
|
||||
from sqlalchemy.orm.attributes import InstrumentedAttribute
|
||||
from sqlalchemy.orm.mapper import Mapper
|
||||
from sqlalchemy.orm.query import _ColumnEntity
|
||||
from sqlalchemy.orm.session import object_session
|
||||
from sqlalchemy.orm.util import AliasedInsp
|
||||
from ..query_chain import QueryChain
|
||||
|
||||
|
||||
def dependencies(obj, foreign_keys=None):
|
||||
"""
|
||||
Return a QueryChain that iterates through all dependent objects for given
|
||||
SQLAlchemy object.
|
||||
|
||||
:param obj: SQLAlchemy declarative model object
|
||||
:param foreign_keys:
|
||||
A sequence of foreign keys to use for searching the dependencies for
|
||||
given object. By default this is None, indicating that all foreign keys
|
||||
referencing the object will be used.
|
||||
|
||||
.. note::
|
||||
This function does not support exotic mappers that use multiple tables
|
||||
|
||||
.. versionadded: 0.26.0
|
||||
"""
|
||||
if foreign_keys is None:
|
||||
foreign_keys = get_referencing_foreign_keys(obj)
|
||||
|
||||
session = object_session(obj)
|
||||
|
||||
foreign_keys = sorted(
|
||||
foreign_keys, key=lambda key: key.constraint.table.name
|
||||
)
|
||||
chain = QueryChain([])
|
||||
classes = obj.__class__._decl_class_registry
|
||||
|
||||
for table, keys in groupby(foreign_keys, lambda key: key.constraint.table):
|
||||
for class_ in classes.values():
|
||||
if hasattr(class_, '__table__') and class_.__table__ == table:
|
||||
criteria = []
|
||||
visited_constraints = []
|
||||
for key in keys:
|
||||
if key.constraint not in visited_constraints:
|
||||
visited_constraints.append(key.constraint)
|
||||
subcriteria = [
|
||||
getattr(class_, column.key) ==
|
||||
getattr(
|
||||
obj,
|
||||
key.constraint.elements[index].column.key
|
||||
)
|
||||
for index, column
|
||||
in enumerate(key.constraint.columns)
|
||||
]
|
||||
criteria.append(sa.and_(*subcriteria))
|
||||
|
||||
query = session.query(class_).filter(
|
||||
sa.or_(
|
||||
*criteria
|
||||
)
|
||||
)
|
||||
chain.queries.append(query)
|
||||
return chain
|
||||
|
||||
|
||||
def get_referencing_foreign_keys(mixed):
|
||||
|
90
tests/functions/test_dependencies.py
Normal file
90
tests/functions/test_dependencies.py
Normal file
@@ -0,0 +1,90 @@
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy_utils.functions.orm import dependencies
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
class TestDependencies(TestCase):
|
||||
def create_models(self):
|
||||
class User(self.Base):
|
||||
__tablename__ = 'user'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
first_name = sa.Column(sa.Unicode(255))
|
||||
last_name = sa.Column(sa.Unicode(255))
|
||||
|
||||
class Article(self.Base):
|
||||
__tablename__ = 'article'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
author_id = sa.Column(sa.Integer, sa.ForeignKey('user.id'))
|
||||
owner_id = sa.Column(sa.Integer, sa.ForeignKey('user.id'))
|
||||
|
||||
author = sa.orm.relationship(User, foreign_keys=[author_id])
|
||||
owner = sa.orm.relationship(User, foreign_keys=[owner_id])
|
||||
|
||||
class BlogPost(self.Base):
|
||||
__tablename__ = 'blog_post'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
author_id = sa.Column(sa.Integer, sa.ForeignKey('user.id'))
|
||||
|
||||
author = sa.orm.relationship(User)
|
||||
|
||||
self.User = User
|
||||
self.Article = Article
|
||||
self.BlogPost = BlogPost
|
||||
|
||||
def test_multiple_refs(self):
|
||||
user = self.User(first_name=u'John')
|
||||
articles = [
|
||||
self.Article(author=user),
|
||||
self.Article(),
|
||||
self.Article(owner=user),
|
||||
self.Article(author=user, owner=user)
|
||||
]
|
||||
self.session.add_all(articles)
|
||||
self.session.commit()
|
||||
|
||||
deps = list(dependencies(user))
|
||||
assert len(deps) == 3
|
||||
assert articles[0] in deps
|
||||
assert articles[2] in deps
|
||||
assert articles[3] in deps
|
||||
|
||||
|
||||
class TestDependenciesWithCompositeKeys(TestCase):
|
||||
def create_models(self):
|
||||
class User(self.Base):
|
||||
__tablename__ = 'user'
|
||||
first_name = sa.Column(sa.Unicode(255), primary_key=True)
|
||||
last_name = sa.Column(sa.Unicode(255), primary_key=True)
|
||||
|
||||
class Article(self.Base):
|
||||
__tablename__ = 'article'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
author_first_name = sa.Column(sa.Unicode(255))
|
||||
author_last_name = sa.Column(sa.Unicode(255))
|
||||
__table_args__ = (
|
||||
sa.ForeignKeyConstraint(
|
||||
[author_first_name, author_last_name],
|
||||
[User.first_name, User.last_name]
|
||||
),
|
||||
)
|
||||
|
||||
author = sa.orm.relationship(User)
|
||||
|
||||
self.User = User
|
||||
self.Article = Article
|
||||
|
||||
def test_returns_all_dependent_objects(self):
|
||||
user = self.User(first_name=u'John', last_name=u'Smith')
|
||||
articles = [
|
||||
self.Article(author=user),
|
||||
self.Article(),
|
||||
self.Article(),
|
||||
self.Article(author=user)
|
||||
]
|
||||
self.session.add_all(articles)
|
||||
self.session.commit()
|
||||
|
||||
deps = list(dependencies(user))
|
||||
assert len(deps) == 2
|
||||
assert articles[0] in deps
|
||||
assert articles[3] in deps
|
Reference in New Issue
Block a user