Add dependencies function

This commit is contained in:
Konsta Vesterinen
2014-05-06 20:22:44 +03:00
parent 340c14e0b2
commit 9e2c47d50e
2 changed files with 147 additions and 0 deletions

View File

@@ -12,7 +12,64 @@ from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlalchemy.orm.mapper import Mapper
from sqlalchemy.orm.query import _ColumnEntity
from sqlalchemy.orm.session import object_session
from sqlalchemy.orm.util import AliasedInsp
from ..query_chain import QueryChain
def dependencies(obj, foreign_keys=None):
"""
Return a QueryChain that iterates through all dependent objects for given
SQLAlchemy object.
:param obj: SQLAlchemy declarative model object
:param foreign_keys:
A sequence of foreign keys to use for searching the dependencies for
given object. By default this is None, indicating that all foreign keys
referencing the object will be used.
.. note::
This function does not support exotic mappers that use multiple tables
.. versionadded: 0.26.0
"""
if foreign_keys is None:
foreign_keys = get_referencing_foreign_keys(obj)
session = object_session(obj)
foreign_keys = sorted(
foreign_keys, key=lambda key: key.constraint.table.name
)
chain = QueryChain([])
classes = obj.__class__._decl_class_registry
for table, keys in groupby(foreign_keys, lambda key: key.constraint.table):
for class_ in classes.values():
if hasattr(class_, '__table__') and class_.__table__ == table:
criteria = []
visited_constraints = []
for key in keys:
if key.constraint not in visited_constraints:
visited_constraints.append(key.constraint)
subcriteria = [
getattr(class_, column.key) ==
getattr(
obj,
key.constraint.elements[index].column.key
)
for index, column
in enumerate(key.constraint.columns)
]
criteria.append(sa.and_(*subcriteria))
query = session.query(class_).filter(
sa.or_(
*criteria
)
)
chain.queries.append(query)
return chain
def get_referencing_foreign_keys(mixed):

View File

@@ -0,0 +1,90 @@
import sqlalchemy as sa
from sqlalchemy_utils.functions.orm import dependencies
from tests import TestCase
class TestDependencies(TestCase):
def create_models(self):
class User(self.Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True)
first_name = sa.Column(sa.Unicode(255))
last_name = sa.Column(sa.Unicode(255))
class Article(self.Base):
__tablename__ = 'article'
id = sa.Column(sa.Integer, primary_key=True)
author_id = sa.Column(sa.Integer, sa.ForeignKey('user.id'))
owner_id = sa.Column(sa.Integer, sa.ForeignKey('user.id'))
author = sa.orm.relationship(User, foreign_keys=[author_id])
owner = sa.orm.relationship(User, foreign_keys=[owner_id])
class BlogPost(self.Base):
__tablename__ = 'blog_post'
id = sa.Column(sa.Integer, primary_key=True)
author_id = sa.Column(sa.Integer, sa.ForeignKey('user.id'))
author = sa.orm.relationship(User)
self.User = User
self.Article = Article
self.BlogPost = BlogPost
def test_multiple_refs(self):
user = self.User(first_name=u'John')
articles = [
self.Article(author=user),
self.Article(),
self.Article(owner=user),
self.Article(author=user, owner=user)
]
self.session.add_all(articles)
self.session.commit()
deps = list(dependencies(user))
assert len(deps) == 3
assert articles[0] in deps
assert articles[2] in deps
assert articles[3] in deps
class TestDependenciesWithCompositeKeys(TestCase):
def create_models(self):
class User(self.Base):
__tablename__ = 'user'
first_name = sa.Column(sa.Unicode(255), primary_key=True)
last_name = sa.Column(sa.Unicode(255), primary_key=True)
class Article(self.Base):
__tablename__ = 'article'
id = sa.Column(sa.Integer, primary_key=True)
author_first_name = sa.Column(sa.Unicode(255))
author_last_name = sa.Column(sa.Unicode(255))
__table_args__ = (
sa.ForeignKeyConstraint(
[author_first_name, author_last_name],
[User.first_name, User.last_name]
),
)
author = sa.orm.relationship(User)
self.User = User
self.Article = Article
def test_returns_all_dependent_objects(self):
user = self.User(first_name=u'John', last_name=u'Smith')
articles = [
self.Article(author=user),
self.Article(),
self.Article(),
self.Article(author=user)
]
self.session.add_all(articles)
self.session.commit()
deps = list(dependencies(user))
assert len(deps) == 2
assert articles[0] in deps
assert articles[3] in deps