79 lines
		
	
	
		
			2.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			79 lines
		
	
	
		
			2.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
import sqlalchemy as sa
 | 
						|
from pytest import raises
 | 
						|
from sqlalchemy_utils import batch_fetch
 | 
						|
from tests import TestCase
 | 
						|
 | 
						|
 | 
						|
class TestBatchFetchOneToManyRelationships(TestCase):
    """Exercise ``batch_fetch`` over a one-to-many Category -> Article link.

    ``setup_method`` persists two categories: the first owns articles 1-2,
    the second owns articles 3-5.
    """

    def create_models(self):
        """Declare the User, Category and Article models used by the tests.

        ``Article.category`` is a relationship to ``Category`` whose backref,
        ``Category.articles``, is the collection batch-fetched below.  The
        model classes are published on ``self`` for the test methods.
        """
        class User(self.Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
            name = sa.Column(sa.Unicode(255))

        class Category(self.Base):
            __tablename__ = 'category'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.Unicode(255))

        class Article(self.Base):
            __tablename__ = 'article'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.Unicode(255))
            category_id = sa.Column(sa.Integer, sa.ForeignKey(Category.id))

            # The backref installs the ``Category.articles`` collection.
            category = sa.orm.relationship(
                Category,
                primaryjoin=category_id == Category.id,
                backref=sa.orm.backref('articles'),
            )

        self.User, self.Category, self.Article = User, Category, Article

    def setup_method(self, method):
        """Persist five articles split across two categories."""
        TestCase.setup_method(self, method)

        titles = (
            u'Article 1',
            u'Article 2',
            u'Article 3',
            u'Article 4',
            u'Article 5',
        )
        articles = [self.Article(name=title) for title in titles]
        self.session.add_all(articles)
        # The articles are flushed while they still have no category; the
        # categories are attached to them afterwards.
        self.session.flush()

        first = self.Category(name=u'Category #1')
        first.articles = articles[:2]
        second = self.Category(name=u'Category #2')
        second.articles = articles[2:]
        self.session.add_all([first, second])
        self.session.commit()

    def test_raises_error_if_relationship_not_found(self):
        """An unknown relationship name makes batch_fetch raise AttributeError."""
        categories = self.session.query(self.Category).all()
        with raises(AttributeError):
            batch_fetch(categories, 'unknown_relation')

    def test_supports_relationship_attributes(self):
        """Passing ``Category.articles`` preloads each category's articles.

        Once the batch fetch has run, reading ``.articles`` on the loaded
        categories must not change the connection's query count.
        """
        categories = self.session.query(self.Category).all()
        batch_fetch(categories, self.Category.articles)
        count_before = self.connection.query_count

        # (category index, expected collection size, names that must appear)
        expectations = (
            (0, 2, ('Article 1', 'Article 2')),
            (1, 3, ('Article 3', 'Article 4', 'Article 5')),
        )
        for index, size, expected_names in expectations:
            loaded = categories[index].articles  # no lazy load should occur
            assert len(loaded) == size
            names = [article.name for article in loaded]
            for expected in expected_names:
                assert expected in names

        assert self.connection.query_count == count_before