Initial support for many-to-many relationship batch fetching
This commit is contained in:
		@@ -4,6 +4,12 @@ Changelog
 | 
			
		||||
Here you can see the full list of changes between each SQLAlchemy-Utils release.
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
0.16.4 (2013-08-08)
 | 
			
		||||
^^^^^^^^^^^^^^^^^^^
 | 
			
		||||
 | 
			
		||||
- Initial many-to-many relations support for batch_fetch
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
0.16.3 (2013-08-05)
 | 
			
		||||
^^^^^^^^^^^^^^^^^^^
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										2
									
								
								setup.py
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								setup.py
									
									
									
									
									
								
							@@ -55,7 +55,7 @@ for name, requirements in extras_require.items():
 | 
			
		||||
 | 
			
		||||
setup(
 | 
			
		||||
    name='SQLAlchemy-Utils',
 | 
			
		||||
    version='0.16.3',
 | 
			
		||||
    version='0.16.4',
 | 
			
		||||
    url='https://github.com/kvesteri/sqlalchemy-utils',
 | 
			
		||||
    license='BSD',
 | 
			
		||||
    author='Konsta Vesterinen',
 | 
			
		||||
 
 | 
			
		||||
@@ -34,7 +34,7 @@ from .types import (
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
__version__ = '0.16.3'
 | 
			
		||||
__version__ = '0.16.4'
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
__all__ = (
 | 
			
		||||
 
 | 
			
		||||
@@ -390,7 +390,7 @@ def render_statement(statement, bind=None):
 | 
			
		||||
    return Compiler(bind.dialect, statement).process(statement)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def batch_fetch(entities, attr):
 | 
			
		||||
def batch_fetch(entities, *attr_paths):
 | 
			
		||||
    """
 | 
			
		||||
    Batch fetch given relationship attribute for collection of entities.
 | 
			
		||||
 | 
			
		||||
@@ -412,41 +412,94 @@ def batch_fetch(entities, attr):
 | 
			
		||||
 | 
			
		||||
        batch_fetch(users, User.phonenumbers)
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    if entities:
 | 
			
		||||
        first = entities[0]
 | 
			
		||||
        if isinstance(attr, six.string_types):
 | 
			
		||||
            attr = getattr(
 | 
			
		||||
                first.__class__, attr
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        prop = attr.property
 | 
			
		||||
        if not isinstance(prop, RelationshipProperty):
 | 
			
		||||
            raise Exception(
 | 
			
		||||
                'Given attribute is not a relationship property.'
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        model = prop.mapper.class_
 | 
			
		||||
        session = object_session(first)
 | 
			
		||||
 | 
			
		||||
        if len(prop.remote_side) > 1:
 | 
			
		||||
            raise Exception(
 | 
			
		||||
                'Only relationships with single remote side columns are '
 | 
			
		||||
                'supported.'
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        column_name = list(prop.remote_side)[0].name
 | 
			
		||||
        parent_ids = [entity.id for entity in entities]
 | 
			
		||||
 | 
			
		||||
        related_entities = (
 | 
			
		||||
            session.query(model)
 | 
			
		||||
            .filter(
 | 
			
		||||
                getattr(model, column_name).in_(parent_ids)
 | 
			
		||||
            )
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        parent_dict = dict((entity.id, []) for entity in entities)
 | 
			
		||||
        for entity in related_entities:
 | 
			
		||||
            parent_dict[getattr(entity, column_name)].append(entity)
 | 
			
		||||
 | 
			
		||||
        for entity in entities:
 | 
			
		||||
            set_committed_value(entity, prop.key, parent_dict[entity.id])
 | 
			
		||||
        for attr_path in attr_paths:
 | 
			
		||||
            if isinstance(attr_path, six.string_types):
 | 
			
		||||
                attrs = attr_path.split('.')
 | 
			
		||||
 | 
			
		||||
                if len(attrs) > 1:
 | 
			
		||||
                    related_entities = []
 | 
			
		||||
                    for entity in entities:
 | 
			
		||||
                        related_entities.extend(getattr(entity, attrs[0]))
 | 
			
		||||
 | 
			
		||||
                    batch_fetch(related_entities, '.'.join(attrs[1:]))
 | 
			
		||||
                    continue
 | 
			
		||||
                else:
 | 
			
		||||
                    attr = getattr(
 | 
			
		||||
                        first.__class__, attrs[0]
 | 
			
		||||
                    )
 | 
			
		||||
            else:
 | 
			
		||||
                attr = attr_path
 | 
			
		||||
 | 
			
		||||
            prop = attr.property
 | 
			
		||||
            if not isinstance(prop, RelationshipProperty):
 | 
			
		||||
                raise Exception(
 | 
			
		||||
                    'Given attribute is not a relationship property.'
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
            model = prop.mapper.class_
 | 
			
		||||
 | 
			
		||||
            session = object_session(first)
 | 
			
		||||
 | 
			
		||||
            if prop.secondary is None:
 | 
			
		||||
                if len(prop.remote_side) > 1:
 | 
			
		||||
                    raise Exception(
 | 
			
		||||
                        'Only relationships with single remote side columns '
 | 
			
		||||
                        'are supported.'
 | 
			
		||||
                    )
 | 
			
		||||
 | 
			
		||||
                column_name = list(prop.remote_side)[0].name
 | 
			
		||||
 | 
			
		||||
                related_entities = (
 | 
			
		||||
                    session.query(model)
 | 
			
		||||
                    .filter(
 | 
			
		||||
                        getattr(model, column_name).in_(parent_ids)
 | 
			
		||||
                    )
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
                for entity in related_entities:
 | 
			
		||||
                    parent_dict[getattr(entity, column_name)].append(
 | 
			
		||||
                        entity
 | 
			
		||||
                    )
 | 
			
		||||
 | 
			
		||||
                for entity in entities:
 | 
			
		||||
                    set_committed_value(
 | 
			
		||||
                        entity, prop.key, parent_dict[entity.id]
 | 
			
		||||
                    )
 | 
			
		||||
            else:
 | 
			
		||||
                column_name = None
 | 
			
		||||
                for column in prop.remote_side:
 | 
			
		||||
                    for fk in column.foreign_keys:
 | 
			
		||||
                        # TODO: make this support inherited tables
 | 
			
		||||
                        if fk.column.table == first.__class__.__table__:
 | 
			
		||||
                            column_name = fk.parent.name
 | 
			
		||||
                            break
 | 
			
		||||
                    if column_name:
 | 
			
		||||
                        break
 | 
			
		||||
 | 
			
		||||
                related_entities = (
 | 
			
		||||
                    session
 | 
			
		||||
                    .query(model, getattr(prop.secondary.c, column_name))
 | 
			
		||||
                    .join(
 | 
			
		||||
                        prop.secondary, prop.secondaryjoin
 | 
			
		||||
                    )
 | 
			
		||||
                    .filter(
 | 
			
		||||
                        getattr(prop.secondary.c, column_name).in_(
 | 
			
		||||
                            parent_ids
 | 
			
		||||
                        )
 | 
			
		||||
                    )
 | 
			
		||||
                )
 | 
			
		||||
                for entity, parent_id in related_entities:
 | 
			
		||||
                    parent_dict[parent_id].append(
 | 
			
		||||
                        entity
 | 
			
		||||
                    )
 | 
			
		||||
 | 
			
		||||
                for entity in entities:
 | 
			
		||||
                    set_committed_value(
 | 
			
		||||
                        entity, prop.key, parent_dict[entity.id]
 | 
			
		||||
                    )
 | 
			
		||||
 
 | 
			
		||||
@@ -28,6 +28,7 @@ class TestCase(object):
 | 
			
		||||
 | 
			
		||||
    def setup_method(self, method):
 | 
			
		||||
        self.engine = create_engine(self.dns)
 | 
			
		||||
        #self.engine.echo = True
 | 
			
		||||
        self.connection = self.engine.connect()
 | 
			
		||||
        self.Base = declarative_base()
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -1,22 +1,91 @@
 | 
			
		||||
import sqlalchemy as sa
 | 
			
		||||
from pytest import raises
 | 
			
		||||
from sqlalchemy_utils import batch_fetch
 | 
			
		||||
from tests import TestCase
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestBatchFetch(TestCase):
 | 
			
		||||
    def create_models(self):
 | 
			
		||||
        class User(self.Base):
 | 
			
		||||
            __tablename__ = 'user'
 | 
			
		||||
            id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
 | 
			
		||||
            name = sa.Column(sa.Unicode(255))
 | 
			
		||||
 | 
			
		||||
        class Category(self.Base):
 | 
			
		||||
            __tablename__ = 'category'
 | 
			
		||||
            id = sa.Column(sa.Integer, primary_key=True)
 | 
			
		||||
            name = sa.Column(sa.Unicode(255))
 | 
			
		||||
 | 
			
		||||
        class Article(self.Base):
 | 
			
		||||
            __tablename__ = 'article'
 | 
			
		||||
            id = sa.Column(sa.Integer, primary_key=True)
 | 
			
		||||
            name = sa.Column(sa.Unicode(255))
 | 
			
		||||
            category_id = sa.Column(sa.Integer, sa.ForeignKey(Category.id))
 | 
			
		||||
 | 
			
		||||
            category = sa.orm.relationship(
 | 
			
		||||
                Category,
 | 
			
		||||
                primaryjoin=category_id == Category.id,
 | 
			
		||||
                backref=sa.orm.backref(
 | 
			
		||||
                    'articles'
 | 
			
		||||
                )
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        article_tag = sa.Table(
 | 
			
		||||
            'article_tag',
 | 
			
		||||
            self.Base.metadata,
 | 
			
		||||
            sa.Column(
 | 
			
		||||
                'article_id',
 | 
			
		||||
                sa.Integer,
 | 
			
		||||
                sa.ForeignKey('article.id', ondelete='cascade')
 | 
			
		||||
            ),
 | 
			
		||||
            sa.Column(
 | 
			
		||||
                'tag_id',
 | 
			
		||||
                sa.Integer,
 | 
			
		||||
                sa.ForeignKey('tag.id', ondelete='cascade')
 | 
			
		||||
            )
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        class Tag(self.Base):
 | 
			
		||||
            __tablename__ = 'tag'
 | 
			
		||||
            id = sa.Column(sa.Integer, primary_key=True)
 | 
			
		||||
            name = sa.Column(sa.Unicode(255))
 | 
			
		||||
            articles = sa.orm.relationship(
 | 
			
		||||
                Article,
 | 
			
		||||
                secondary=article_tag,
 | 
			
		||||
                backref=sa.orm.backref(
 | 
			
		||||
                    'tags'
 | 
			
		||||
                )
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        self.User = User
 | 
			
		||||
        self.Category = Category
 | 
			
		||||
        self.Article = Article
 | 
			
		||||
        self.Tag = Tag
 | 
			
		||||
 | 
			
		||||
    def setup_method(self, method):
 | 
			
		||||
        TestCase.setup_method(self, method)
 | 
			
		||||
        category = self.Category(name=u'Category #1')
 | 
			
		||||
        category.articles = [
 | 
			
		||||
        articles = [
 | 
			
		||||
            self.Article(name=u'Article 1'),
 | 
			
		||||
            self.Article(name=u'Article 2')
 | 
			
		||||
        ]
 | 
			
		||||
        category2 = self.Category(name=u'Category #2')
 | 
			
		||||
        category2.articles = [
 | 
			
		||||
            self.Article(name=u'Article 2'),
 | 
			
		||||
            self.Article(name=u'Article 3'),
 | 
			
		||||
            self.Article(name=u'Article 4'),
 | 
			
		||||
            self.Article(name=u'Article 5')
 | 
			
		||||
        ]
 | 
			
		||||
        self.session.add_all(articles)
 | 
			
		||||
        self.session.flush()
 | 
			
		||||
 | 
			
		||||
        tags = [
 | 
			
		||||
            self.Tag(name=u'Tag 1'),
 | 
			
		||||
            self.Tag(name=u'Tag 2'),
 | 
			
		||||
            self.Tag(name=u'Tag 3')
 | 
			
		||||
        ]
 | 
			
		||||
        articles[0].tags = tags
 | 
			
		||||
        articles[3].tags = tags[1:]
 | 
			
		||||
 | 
			
		||||
        category = self.Category(name=u'Category #1')
 | 
			
		||||
        category.articles = articles[0:2]
 | 
			
		||||
        category2 = self.Category(name=u'Category #2')
 | 
			
		||||
        category2.articles = articles[2:]
 | 
			
		||||
        self.session.add(category)
 | 
			
		||||
        self.session.add(category2)
 | 
			
		||||
        self.session.commit()
 | 
			
		||||
@@ -32,3 +101,14 @@ class TestBatchFetch(TestCase):
 | 
			
		||||
        query_count = self.connection.query_count
 | 
			
		||||
        categories[0].articles  # no lazy load should occur
 | 
			
		||||
        assert self.connection.query_count == query_count
 | 
			
		||||
 | 
			
		||||
    def test_multiple_relationships(self):
 | 
			
		||||
        categories = self.session.query(self.Category).all()
 | 
			
		||||
        batch_fetch(
 | 
			
		||||
            categories,
 | 
			
		||||
            'articles',
 | 
			
		||||
            'articles.tags'
 | 
			
		||||
        )
 | 
			
		||||
        query_count = self.connection.query_count
 | 
			
		||||
        categories[0].articles[0].tags
 | 
			
		||||
        assert self.connection.query_count == query_count
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user