diff --git a/CHANGES.rst b/CHANGES.rst index fa54812..67c22ca 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -4,6 +4,12 @@ Changelog Here you can see the full list of changes between each SQLAlchemy-Utils release. +0.16.5 (2013-08-08) +^^^^^^^^^^^^^^^^^^^ + +- Initial backref population forcing for batch_fetch + + 0.16.4 (2013-08-08) ^^^^^^^^^^^^^^^^^^^ diff --git a/setup.py b/setup.py index 6936436..c4acf2b 100644 --- a/setup.py +++ b/setup.py @@ -55,7 +55,7 @@ for name, requirements in extras_require.items(): setup( name='SQLAlchemy-Utils', - version='0.16.4', + version='0.16.5', url='https://github.com/kvesteri/sqlalchemy-utils', license='BSD', author='Konsta Vesterinen', diff --git a/sqlalchemy_utils/__init__.py b/sqlalchemy_utils/__init__.py index dec4c2e..ea0259f 100644 --- a/sqlalchemy_utils/__init__.py +++ b/sqlalchemy_utils/__init__.py @@ -34,7 +34,7 @@ from .types import ( ) -__version__ = '0.16.4' +__version__ = '0.16.5' __all__ = ( diff --git a/sqlalchemy_utils/functions.py b/sqlalchemy_utils/functions.py index bb19f03..0ecaebb 100644 --- a/sqlalchemy_utils/functions.py +++ b/sqlalchemy_utils/functions.py @@ -430,8 +430,21 @@ def batch_fetch(entities, *attr_paths): clubs, 'teams', 'teams.players', - 'teams.players.friends' + 'teams.players.user_groups' ) + + You can also force populate backrefs: :: + + + clubs = session.query(Club).limit(20).all() + + batch_fetch( + clubs, + 'teams', + 'teams.players', + 'teams.players.user_groups -pb' + ) + """ if entities: @@ -440,6 +453,8 @@ def batch_fetch(entities, *attr_paths): for attr_path in attr_paths: parent_dict = dict((entity.id, []) for entity in entities) + populate_backrefs = False + if isinstance(attr_path, six.string_types): attrs = attr_path.split('.') @@ -448,11 +463,18 @@ def batch_fetch(entities, *attr_paths): for entity in entities: related_entities.extend(getattr(entity, attrs[0])) - batch_fetch(related_entities, '.'.join(attrs[1:])) + batch_fetch( + related_entities, + '.'.join(attrs[1:]) + ) continue else: + args = attrs[-1].split(' ') + if '-pb' in args: + populate_backrefs = True + attr = getattr( - first.__class__, attrs[0] + first.__class__, args[0] ) else: attr = attr_path @@ -488,10 +510,6 @@ def batch_fetch(entities, *attr_paths): entity ) - for entity in entities: - set_committed_value( - entity, prop.key, parent_dict[entity.id] - ) else: column_name = None for column in prop.remote_side: @@ -520,7 +538,19 @@ def batch_fetch(entities, *attr_paths): entity ) - for entity in entities: - set_committed_value( - entity, prop.key, parent_dict[entity.id] + for entity in entities: + set_committed_value( + entity, prop.key, parent_dict[entity.id] + ) + if populate_backrefs: + backref_dict = dict( + (entity.id, []) for entity, parent_id in related_entities + ) + for entity, parent_id in related_entities: + backref_dict[entity.id].append( + session.query(first.__class__).get(parent_id) + ) + for entity, parent_id in related_entities: + set_committed_value( + entity, prop.back_populates, backref_dict[entity.id] ) diff --git a/tests/batch_fetch/test_deep_relationships.py b/tests/batch_fetch/test_deep_relationships.py index 9c43d66..e4fa619 100644 --- a/tests/batch_fetch/test_deep_relationships.py +++ b/tests/batch_fetch/test_deep_relationships.py @@ -89,7 +89,7 @@ class TestBatchFetch(TestCase): self.session.add(category2) self.session.commit() - def test_multiple_relationships(self): + def test_deep_relationships(self): categories = self.session.query(self.Category).all() batch_fetch( categories, @@ -101,3 +101,19 @@ class TestBatchFetch(TestCase): assert self.connection.query_count == query_count categories[1].articles[1].tags assert self.connection.query_count == query_count + + def test_many_to_many_backref_population(self): + categories = self.session.query(self.Category).all() + batch_fetch( + categories, + 'articles', + 'articles.tags -pb', + ) + query_count = self.connection.query_count + tags = categories[0].articles[0].tags + tags2 = categories[1].articles[1].tags + tags[0].articles + tags2[0].articles + names = [article.name for article in tags[0].articles] + assert u'Article 1' in names + assert self.connection.query_count == query_count diff --git a/tests/test_batch_fetch.py b/tests/test_batch_fetch.py deleted file mode 100644 index 55c2411..0000000 --- a/tests/test_batch_fetch.py +++ /dev/null @@ -1,116 +0,0 @@ -import sqlalchemy as sa -from pytest import raises -from sqlalchemy_utils import batch_fetch -from tests import TestCase - - -class TestBatchFetch(TestCase): - def create_models(self): - class User(self.Base): - __tablename__ = 'user' - id = sa.Column(sa.Integer, autoincrement=True, primary_key=True) - name = sa.Column(sa.Unicode(255)) - - class Category(self.Base): - __tablename__ = 'category' - id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column(sa.Unicode(255)) - - class Article(self.Base): - __tablename__ = 'article' - id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column(sa.Unicode(255)) - category_id = sa.Column(sa.Integer, sa.ForeignKey(Category.id)) - - category = sa.orm.relationship( - Category, - primaryjoin=category_id == Category.id, - backref=sa.orm.backref( - 'articles' - ) - ) - - article_tag = sa.Table( - 'article_tag', - self.Base.metadata, - sa.Column( - 'article_id', - sa.Integer, - sa.ForeignKey('article.id', ondelete='cascade') - ), - sa.Column( - 'tag_id', - sa.Integer, - sa.ForeignKey('tag.id', ondelete='cascade') - ) - ) - - class Tag(self.Base): - __tablename__ = 'tag' - id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column(sa.Unicode(255)) - articles = sa.orm.relationship( - Article, - secondary=article_tag, - backref=sa.orm.backref( - 'tags' - ) - ) - - self.User = User - self.Category = Category - self.Article = Article - self.Tag = Tag - - def setup_method(self, method): - TestCase.setup_method(self, method) - articles = [ - self.Article(name=u'Article 1'), - self.Article(name=u'Article 2'), - self.Article(name=u'Article 3'), - self.Article(name=u'Article 4'), - self.Article(name=u'Article 5') - ] - self.session.add_all(articles) - self.session.flush() - - tags = [ - self.Tag(name=u'Tag 1'), - self.Tag(name=u'Tag 2'), - self.Tag(name=u'Tag 3') - ] - articles[0].tags = tags - articles[3].tags = tags[1:] - - category = self.Category(name=u'Category #1') - category.articles = articles[0:2] - category2 = self.Category(name=u'Category #2') - category2.articles = articles[2:] - self.session.add(category) - self.session.add(category2) - self.session.commit() - - def test_raises_error_if_relationship_not_found(self): - categories = self.session.query(self.Category).all() - with raises(AttributeError): - batch_fetch(categories, 'unknown_relation') - - def test_supports_relationship_attributes(self): - categories = self.session.query(self.Category).all() - batch_fetch(categories, self.Category.articles) - query_count = self.connection.query_count - categories[0].articles # no lazy load should occur - assert self.connection.query_count == query_count - - def test_multiple_relationships(self): - categories = self.session.query(self.Category).all() - batch_fetch( - categories, - 'articles', - 'articles.tags' - ) - query_count = self.connection.query_count - categories[0].articles[0].tags - assert self.connection.query_count == query_count - categories[1].articles[1].tags - assert self.connection.query_count == query_count