163 lines
		
	
	
		
			5.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			163 lines
		
	
	
		
			5.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import sqlalchemy as sa
 | |
| from sqlalchemy_utils import batch_fetch
 | |
| from tests import TestCase
 | |
| 
 | |
| 
 | |
class JoinTableInheritanceTestCase(TestCase):
    """Fixture base: joined-table inheritance models plus sample rows.

    ``TextItem`` is the polymorphic parent (discriminated by ``type``);
    ``Article`` and ``BlogPost`` each own a table whose primary key is a
    foreign key to ``text_item.id``.  ``Category``, ``User`` and
    ``Attachment`` are plain models related to the text items.
    """

    def create_models(self):
        """Declare the model classes and publish them as attributes."""
        Base = self.Base

        class Category(Base):
            __tablename__ = 'category'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.Unicode(255))

        class User(Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.Unicode(255))

        class TextItem(Base):
            __tablename__ = 'text_item'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.Unicode(255))
            category_id = sa.Column(sa.Integer, sa.ForeignKey(Category.id))
            author_id = sa.Column(sa.Integer, sa.ForeignKey(User.id))
            author = sa.orm.relationship(User, backref='text_items')

            # Discriminator column selecting the concrete subclass.
            type = sa.Column(sa.Unicode(255))

            __mapper_args__ = {'polymorphic_on': type}

        class Article(TextItem):
            __tablename__ = 'article'
            id = sa.Column(
                sa.Integer,
                sa.ForeignKey(TextItem.id),
                primary_key=True,
            )
            __mapper_args__ = {'polymorphic_identity': u'article'}
            category = sa.orm.relationship(Category, backref='articles')

        class BlogPost(TextItem):
            __tablename__ = 'blog_post'
            id = sa.Column(
                sa.Integer,
                sa.ForeignKey(TextItem.id),
                primary_key=True,
            )
            __mapper_args__ = {'polymorphic_identity': u'blog_post'}
            category = sa.orm.relationship(Category, backref='blog_posts')

        class Attachment(Base):
            __tablename__ = 'attachment'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.Unicode(255), index=True)
            text_item_id = sa.Column(sa.Integer, sa.ForeignKey(TextItem.id))
            text_item = sa.orm.relationship(TextItem, backref='attachments')

        self.Article = Article
        self.Attachment = Attachment
        self.BlogPost = BlogPost
        self.Category = Category
        self.TextItem = TextItem
        self.User = User

    def setup_method(self, method):
        """Run the base setup, then insert the rows the tests query.

        Creates four articles and four blog posts split over two
        categories, two attachments on the first article, and an author
        for each of the first two articles.
        """
        TestCase.setup_method(self, method)

        articles = [
            self.Article(name=u'Article 1'),
            self.Article(name=u'Article 2'),
            self.Article(name=u'Article 3'),
            self.Article(name=u'Article 4'),
        ]
        blog_posts = [
            self.BlogPost(name=u'BlogPost 1'),
            self.BlogPost(name=u'BlogPost 2'),
            self.BlogPost(name=u'BlogPost 3'),
            self.BlogPost(name=u'BlogPost 4'),
        ]
        # Articles first, then blog posts: same add order as one flat list.
        self.session.add_all(articles + blog_posts)
        self.session.flush()

        first_category = self.Category(name=u'Category #1')
        first_category.articles = articles[:2]
        first_category.blog_posts = blog_posts[:3]

        second_category = self.Category(name=u'Category #2')
        second_category.articles = articles[2:]
        second_category.blog_posts = blog_posts[3:]

        self.session.add_all([first_category, second_category])

        first_article, second_article = articles[0], articles[1]
        first_article.attachments = [
            self.Attachment(id=22, name=u'Attachment 1'),
            self.Attachment(id=34, name=u'Attachment 2'),
        ]
        first_article.author = self.User(name=u'John Matrix')
        second_article.author = self.User(name=u'John Doe')

        self.session.commit()
| 
 | |
| 
 | |
class TestBatchFetchJoinTableInheritedModels(JoinTableInheritanceTestCase):
    """Exercise ``batch_fetch`` on the joined-table inheritance fixture."""

    def test_multiple_relationships(self):
        """Batch-fetched collections are readable without extra queries."""
        categories = self.session.query(self.Category).all()
        batch_fetch(categories, 'articles', 'blog_posts')

        query_count = self.connection.query_count
        # The fixture creates exactly two categories, each with at least
        # two articles and one blog post, so these index reads are valid.
        for category in categories[:2]:
            category.articles[1]
            category.blog_posts[0]
            assert self.connection.query_count == query_count

    def test_one_to_many_relationships(self):
        """Fetching a one-to-many backref yields the fixture attachments."""
        articles = (
            self.session.query(self.Article)
            .filter_by(name=u'Article 1')
            .all()
        )
        batch_fetch(articles, 'attachments')

        # Previously this test only checked that no exception was raised;
        # also verify the fetched collection holds the expected rows.
        assert len(articles) == 1
        attachment_names = sorted(
            attachment.name for attachment in articles[0].attachments
        )
        assert attachment_names == [u'Attachment 1', u'Attachment 2']

    def test_many_to_one_relationships(self):
        """Fetching a many-to-one relationship yields the fixture author."""
        articles = (
            self.session.query(self.Article)
            .filter_by(name=u'Article 1')
            .all()
        )
        batch_fetch(articles, 'author')

        # Previously this test only checked that no exception was raised;
        # also verify the related author is the one assigned in setup.
        assert len(articles) == 1
        assert articles[0].author is not None
        assert articles[0].author.name == u'John Matrix'
