diff --git a/tests/batch_fetch/__init__.py b/tests/batch_fetch/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/batch_fetch/test_deep_relationships.py b/tests/batch_fetch/test_deep_relationships.py new file mode 100644 index 0000000..9c43d66 --- /dev/null +++ b/tests/batch_fetch/test_deep_relationships.py @@ -0,0 +1,103 @@ +import sqlalchemy as sa +from sqlalchemy_utils import batch_fetch +from tests import TestCase + + +class TestBatchFetch(TestCase): + def create_models(self): + class User(self.Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, autoincrement=True, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + class Category(self.Base): + __tablename__ = 'category' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + class Article(self.Base): + __tablename__ = 'article' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + category_id = sa.Column(sa.Integer, sa.ForeignKey(Category.id)) + + category = sa.orm.relationship( + Category, + primaryjoin=category_id == Category.id, + backref=sa.orm.backref( + 'articles' + ) + ) + + article_tag = sa.Table( + 'article_tag', + self.Base.metadata, + sa.Column( + 'article_id', + sa.Integer, + sa.ForeignKey('article.id', ondelete='cascade') + ), + sa.Column( + 'tag_id', + sa.Integer, + sa.ForeignKey('tag.id', ondelete='cascade') + ) + ) + + class Tag(self.Base): + __tablename__ = 'tag' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + articles = sa.orm.relationship( + Article, + secondary=article_tag, + backref=sa.orm.backref( + 'tags' + ) + ) + + self.User = User + self.Category = Category + self.Article = Article + self.Tag = Tag + + def setup_method(self, method): + TestCase.setup_method(self, method) + articles = [ + self.Article(name=u'Article 1'), + self.Article(name=u'Article 2'), + self.Article(name=u'Article 3'), + self.Article(name=u'Article 4'), + self.Article(name=u'Article 5') + ] + self.session.add_all(articles) + self.session.flush() + + tags = [ + self.Tag(name=u'Tag 1'), + self.Tag(name=u'Tag 2'), + self.Tag(name=u'Tag 3') + ] + articles[0].tags = tags + articles[3].tags = tags[1:] + + category = self.Category(name=u'Category #1') + category.articles = articles[0:2] + category2 = self.Category(name=u'Category #2') + category2.articles = articles[2:] + self.session.add(category) + self.session.add(category2) + self.session.commit() + + def test_multiple_relationships(self): + categories = self.session.query(self.Category).all() + batch_fetch( + categories, + 'articles', + 'articles.tags' + ) + query_count = self.connection.query_count + categories[0].articles[0].tags + assert self.connection.query_count == query_count + categories[1].articles[1].tags + assert self.connection.query_count == query_count diff --git a/tests/batch_fetch/test_simple_relationships.py b/tests/batch_fetch/test_simple_relationships.py new file mode 100644 index 0000000..03dfed7 --- /dev/null +++ b/tests/batch_fetch/test_simple_relationships.py @@ -0,0 +1,67 @@ +import sqlalchemy as sa +from pytest import raises +from sqlalchemy_utils import batch_fetch +from tests import TestCase + + +class TestBatchFetch(TestCase): + def create_models(self): + class User(self.Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, autoincrement=True, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + class Category(self.Base): + __tablename__ = 'category' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + class Article(self.Base): + __tablename__ = 'article' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + category_id = sa.Column(sa.Integer, sa.ForeignKey(Category.id)) + + category = sa.orm.relationship( + Category, + primaryjoin=category_id == Category.id, + backref=sa.orm.backref( + 'articles' + ) + ) + + self.User = User + self.Category = Category + self.Article = Article + + def setup_method(self, method): + TestCase.setup_method(self, method) + articles = [ + self.Article(name=u'Article 1'), + self.Article(name=u'Article 2'), + self.Article(name=u'Article 3'), + self.Article(name=u'Article 4'), + self.Article(name=u'Article 5') + ] + self.session.add_all(articles) + self.session.flush() + + category = self.Category(name=u'Category #1') + category.articles = articles[0:2] + category2 = self.Category(name=u'Category #2') + category2.articles = articles[2:] + self.session.add(category) + self.session.add(category2) + self.session.commit() + + def test_raises_error_if_relationship_not_found(self): + categories = self.session.query(self.Category).all() + with raises(AttributeError): + batch_fetch(categories, 'unknown_relation') + + def test_supports_relationship_attributes(self): + categories = self.session.query(self.Category).all() + batch_fetch(categories, self.Category.articles) + query_count = self.connection.query_count + categories[0].articles # no lazy load should occur + assert self.connection.query_count == query_count