Refactored batch fetch tests
This commit is contained in:
0
tests/batch_fetch/__init__.py
Normal file
0
tests/batch_fetch/__init__.py
Normal file
103
tests/batch_fetch/test_deep_relationships.py
Normal file
103
tests/batch_fetch/test_deep_relationships.py
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy_utils import batch_fetch
|
||||||
|
from tests import TestCase
|
||||||
|
|
||||||
|
|
||||||
|
class TestBatchFetch(TestCase):
|
||||||
|
def create_models(self):
|
||||||
|
class User(self.Base):
|
||||||
|
__tablename__ = 'user'
|
||||||
|
id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
|
||||||
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
|
||||||
|
class Category(self.Base):
|
||||||
|
__tablename__ = 'category'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
|
||||||
|
class Article(self.Base):
|
||||||
|
__tablename__ = 'article'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
category_id = sa.Column(sa.Integer, sa.ForeignKey(Category.id))
|
||||||
|
|
||||||
|
category = sa.orm.relationship(
|
||||||
|
Category,
|
||||||
|
primaryjoin=category_id == Category.id,
|
||||||
|
backref=sa.orm.backref(
|
||||||
|
'articles'
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
article_tag = sa.Table(
|
||||||
|
'article_tag',
|
||||||
|
self.Base.metadata,
|
||||||
|
sa.Column(
|
||||||
|
'article_id',
|
||||||
|
sa.Integer,
|
||||||
|
sa.ForeignKey('article.id', ondelete='cascade')
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
'tag_id',
|
||||||
|
sa.Integer,
|
||||||
|
sa.ForeignKey('tag.id', ondelete='cascade')
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
class Tag(self.Base):
|
||||||
|
__tablename__ = 'tag'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
articles = sa.orm.relationship(
|
||||||
|
Article,
|
||||||
|
secondary=article_tag,
|
||||||
|
backref=sa.orm.backref(
|
||||||
|
'tags'
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.User = User
|
||||||
|
self.Category = Category
|
||||||
|
self.Article = Article
|
||||||
|
self.Tag = Tag
|
||||||
|
|
||||||
|
def setup_method(self, method):
|
||||||
|
TestCase.setup_method(self, method)
|
||||||
|
articles = [
|
||||||
|
self.Article(name=u'Article 1'),
|
||||||
|
self.Article(name=u'Article 2'),
|
||||||
|
self.Article(name=u'Article 3'),
|
||||||
|
self.Article(name=u'Article 4'),
|
||||||
|
self.Article(name=u'Article 5')
|
||||||
|
]
|
||||||
|
self.session.add_all(articles)
|
||||||
|
self.session.flush()
|
||||||
|
|
||||||
|
tags = [
|
||||||
|
self.Tag(name=u'Tag 1'),
|
||||||
|
self.Tag(name=u'Tag 2'),
|
||||||
|
self.Tag(name=u'Tag 3')
|
||||||
|
]
|
||||||
|
articles[0].tags = tags
|
||||||
|
articles[3].tags = tags[1:]
|
||||||
|
|
||||||
|
category = self.Category(name=u'Category #1')
|
||||||
|
category.articles = articles[0:2]
|
||||||
|
category2 = self.Category(name=u'Category #2')
|
||||||
|
category2.articles = articles[2:]
|
||||||
|
self.session.add(category)
|
||||||
|
self.session.add(category2)
|
||||||
|
self.session.commit()
|
||||||
|
|
||||||
|
def test_multiple_relationships(self):
|
||||||
|
categories = self.session.query(self.Category).all()
|
||||||
|
batch_fetch(
|
||||||
|
categories,
|
||||||
|
'articles',
|
||||||
|
'articles.tags'
|
||||||
|
)
|
||||||
|
query_count = self.connection.query_count
|
||||||
|
categories[0].articles[0].tags
|
||||||
|
assert self.connection.query_count == query_count
|
||||||
|
categories[1].articles[1].tags
|
||||||
|
assert self.connection.query_count == query_count
|
67
tests/batch_fetch/test_simple_relationships.py
Normal file
67
tests/batch_fetch/test_simple_relationships.py
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
import sqlalchemy as sa
|
||||||
|
from pytest import raises
|
||||||
|
from sqlalchemy_utils import batch_fetch
|
||||||
|
from tests import TestCase
|
||||||
|
|
||||||
|
|
||||||
|
class TestBatchFetch(TestCase):
|
||||||
|
def create_models(self):
|
||||||
|
class User(self.Base):
|
||||||
|
__tablename__ = 'user'
|
||||||
|
id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
|
||||||
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
|
||||||
|
class Category(self.Base):
|
||||||
|
__tablename__ = 'category'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
|
||||||
|
class Article(self.Base):
|
||||||
|
__tablename__ = 'article'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
category_id = sa.Column(sa.Integer, sa.ForeignKey(Category.id))
|
||||||
|
|
||||||
|
category = sa.orm.relationship(
|
||||||
|
Category,
|
||||||
|
primaryjoin=category_id == Category.id,
|
||||||
|
backref=sa.orm.backref(
|
||||||
|
'articles'
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.User = User
|
||||||
|
self.Category = Category
|
||||||
|
self.Article = Article
|
||||||
|
|
||||||
|
def setup_method(self, method):
|
||||||
|
TestCase.setup_method(self, method)
|
||||||
|
articles = [
|
||||||
|
self.Article(name=u'Article 1'),
|
||||||
|
self.Article(name=u'Article 2'),
|
||||||
|
self.Article(name=u'Article 3'),
|
||||||
|
self.Article(name=u'Article 4'),
|
||||||
|
self.Article(name=u'Article 5')
|
||||||
|
]
|
||||||
|
self.session.add_all(articles)
|
||||||
|
self.session.flush()
|
||||||
|
|
||||||
|
category = self.Category(name=u'Category #1')
|
||||||
|
category.articles = articles[0:2]
|
||||||
|
category2 = self.Category(name=u'Category #2')
|
||||||
|
category2.articles = articles[2:]
|
||||||
|
self.session.add(category)
|
||||||
|
self.session.add(category2)
|
||||||
|
self.session.commit()
|
||||||
|
|
||||||
|
def test_raises_error_if_relationship_not_found(self):
|
||||||
|
categories = self.session.query(self.Category).all()
|
||||||
|
with raises(AttributeError):
|
||||||
|
batch_fetch(categories, 'unknown_relation')
|
||||||
|
|
||||||
|
def test_supports_relationship_attributes(self):
|
||||||
|
categories = self.session.query(self.Category).all()
|
||||||
|
batch_fetch(categories, self.Category.articles)
|
||||||
|
query_count = self.connection.query_count
|
||||||
|
categories[0].articles # no lazy load should occur
|
||||||
|
assert self.connection.query_count == query_count
|
Reference in New Issue
Block a user