Added backref population forcing for batch_fetch
This commit is contained in:
@@ -4,6 +4,12 @@ Changelog
|
||||
Here you can see the full list of changes between each SQLAlchemy-Utils release.
|
||||
|
||||
|
||||
0.16.5 (2013-08-08)
|
||||
^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
- Initial backref population forcing for batch_fetch
|
||||
|
||||
|
||||
0.16.4 (2013-08-08)
|
||||
^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
|
2
setup.py
2
setup.py
@@ -55,7 +55,7 @@ for name, requirements in extras_require.items():
|
||||
|
||||
setup(
|
||||
name='SQLAlchemy-Utils',
|
||||
version='0.16.4',
|
||||
version='0.16.5',
|
||||
url='https://github.com/kvesteri/sqlalchemy-utils',
|
||||
license='BSD',
|
||||
author='Konsta Vesterinen',
|
||||
|
@@ -34,7 +34,7 @@ from .types import (
|
||||
)
|
||||
|
||||
|
||||
__version__ = '0.16.4'
|
||||
__version__ = '0.16.5'
|
||||
|
||||
|
||||
__all__ = (
|
||||
|
@@ -430,8 +430,21 @@ def batch_fetch(entities, *attr_paths):
|
||||
clubs,
|
||||
'teams',
|
||||
'teams.players',
|
||||
'teams.players.friends'
|
||||
'teams.players.user_groups'
|
||||
)
|
||||
|
||||
You can also force populate backrefs: ::
|
||||
|
||||
|
||||
clubs = session.query(Club).limit(20).all()
|
||||
|
||||
batch_fetch(
|
||||
clubs,
|
||||
'teams',
|
||||
'teams.players',
|
||||
'teams.players.user_groups -pb'
|
||||
)
|
||||
|
||||
"""
|
||||
|
||||
if entities:
|
||||
@@ -440,6 +453,8 @@ def batch_fetch(entities, *attr_paths):
|
||||
|
||||
for attr_path in attr_paths:
|
||||
parent_dict = dict((entity.id, []) for entity in entities)
|
||||
populate_backrefs = False
|
||||
|
||||
if isinstance(attr_path, six.string_types):
|
||||
attrs = attr_path.split('.')
|
||||
|
||||
@@ -448,11 +463,18 @@ def batch_fetch(entities, *attr_paths):
|
||||
for entity in entities:
|
||||
related_entities.extend(getattr(entity, attrs[0]))
|
||||
|
||||
batch_fetch(related_entities, '.'.join(attrs[1:]))
|
||||
batch_fetch(
|
||||
related_entities,
|
||||
'.'.join(attrs[1:])
|
||||
)
|
||||
continue
|
||||
else:
|
||||
args = attrs[-1].split(' ')
|
||||
if '-pb' in args:
|
||||
populate_backrefs = True
|
||||
|
||||
attr = getattr(
|
||||
first.__class__, attrs[0]
|
||||
first.__class__, args[0]
|
||||
)
|
||||
else:
|
||||
attr = attr_path
|
||||
@@ -488,10 +510,6 @@ def batch_fetch(entities, *attr_paths):
|
||||
entity
|
||||
)
|
||||
|
||||
for entity in entities:
|
||||
set_committed_value(
|
||||
entity, prop.key, parent_dict[entity.id]
|
||||
)
|
||||
else:
|
||||
column_name = None
|
||||
for column in prop.remote_side:
|
||||
@@ -520,7 +538,19 @@ def batch_fetch(entities, *attr_paths):
|
||||
entity
|
||||
)
|
||||
|
||||
for entity in entities:
|
||||
set_committed_value(
|
||||
entity, prop.key, parent_dict[entity.id]
|
||||
for entity in entities:
|
||||
set_committed_value(
|
||||
entity, prop.key, parent_dict[entity.id]
|
||||
)
|
||||
if populate_backrefs:
|
||||
backref_dict = dict(
|
||||
(entity.id, []) for entity, parent_id in related_entities
|
||||
)
|
||||
for entity, parent_id in related_entities:
|
||||
backref_dict[entity.id].append(
|
||||
session.query(first.__class__).get(parent_id)
|
||||
)
|
||||
for entity, parent_id in related_entities:
|
||||
set_committed_value(
|
||||
entity, prop.back_populates, backref_dict[entity.id]
|
||||
)
|
||||
|
@@ -89,7 +89,7 @@ class TestBatchFetch(TestCase):
|
||||
self.session.add(category2)
|
||||
self.session.commit()
|
||||
|
||||
def test_multiple_relationships(self):
|
||||
def test_deep_relationships(self):
|
||||
categories = self.session.query(self.Category).all()
|
||||
batch_fetch(
|
||||
categories,
|
||||
@@ -101,3 +101,19 @@ class TestBatchFetch(TestCase):
|
||||
assert self.connection.query_count == query_count
|
||||
categories[1].articles[1].tags
|
||||
assert self.connection.query_count == query_count
|
||||
|
||||
def test_many_to_many_backref_population(self):
|
||||
categories = self.session.query(self.Category).all()
|
||||
batch_fetch(
|
||||
categories,
|
||||
'articles',
|
||||
'articles.tags -pb',
|
||||
)
|
||||
query_count = self.connection.query_count
|
||||
tags = categories[0].articles[0].tags
|
||||
tags2 = categories[1].articles[1].tags
|
||||
tags[0].articles
|
||||
tags2[0].articles
|
||||
names = [article.name for article in tags[0].articles]
|
||||
assert u'Article 1' in names
|
||||
assert self.connection.query_count == query_count
|
||||
|
@@ -1,116 +0,0 @@
|
||||
import sqlalchemy as sa
|
||||
from pytest import raises
|
||||
from sqlalchemy_utils import batch_fetch
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
class TestBatchFetch(TestCase):
|
||||
def create_models(self):
|
||||
class User(self.Base):
|
||||
__tablename__ = 'user'
|
||||
id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
class Category(self.Base):
|
||||
__tablename__ = 'category'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
class Article(self.Base):
|
||||
__tablename__ = 'article'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
category_id = sa.Column(sa.Integer, sa.ForeignKey(Category.id))
|
||||
|
||||
category = sa.orm.relationship(
|
||||
Category,
|
||||
primaryjoin=category_id == Category.id,
|
||||
backref=sa.orm.backref(
|
||||
'articles'
|
||||
)
|
||||
)
|
||||
|
||||
article_tag = sa.Table(
|
||||
'article_tag',
|
||||
self.Base.metadata,
|
||||
sa.Column(
|
||||
'article_id',
|
||||
sa.Integer,
|
||||
sa.ForeignKey('article.id', ondelete='cascade')
|
||||
),
|
||||
sa.Column(
|
||||
'tag_id',
|
||||
sa.Integer,
|
||||
sa.ForeignKey('tag.id', ondelete='cascade')
|
||||
)
|
||||
)
|
||||
|
||||
class Tag(self.Base):
|
||||
__tablename__ = 'tag'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
articles = sa.orm.relationship(
|
||||
Article,
|
||||
secondary=article_tag,
|
||||
backref=sa.orm.backref(
|
||||
'tags'
|
||||
)
|
||||
)
|
||||
|
||||
self.User = User
|
||||
self.Category = Category
|
||||
self.Article = Article
|
||||
self.Tag = Tag
|
||||
|
||||
def setup_method(self, method):
|
||||
TestCase.setup_method(self, method)
|
||||
articles = [
|
||||
self.Article(name=u'Article 1'),
|
||||
self.Article(name=u'Article 2'),
|
||||
self.Article(name=u'Article 3'),
|
||||
self.Article(name=u'Article 4'),
|
||||
self.Article(name=u'Article 5')
|
||||
]
|
||||
self.session.add_all(articles)
|
||||
self.session.flush()
|
||||
|
||||
tags = [
|
||||
self.Tag(name=u'Tag 1'),
|
||||
self.Tag(name=u'Tag 2'),
|
||||
self.Tag(name=u'Tag 3')
|
||||
]
|
||||
articles[0].tags = tags
|
||||
articles[3].tags = tags[1:]
|
||||
|
||||
category = self.Category(name=u'Category #1')
|
||||
category.articles = articles[0:2]
|
||||
category2 = self.Category(name=u'Category #2')
|
||||
category2.articles = articles[2:]
|
||||
self.session.add(category)
|
||||
self.session.add(category2)
|
||||
self.session.commit()
|
||||
|
||||
def test_raises_error_if_relationship_not_found(self):
|
||||
categories = self.session.query(self.Category).all()
|
||||
with raises(AttributeError):
|
||||
batch_fetch(categories, 'unknown_relation')
|
||||
|
||||
def test_supports_relationship_attributes(self):
|
||||
categories = self.session.query(self.Category).all()
|
||||
batch_fetch(categories, self.Category.articles)
|
||||
query_count = self.connection.query_count
|
||||
categories[0].articles # no lazy load should occur
|
||||
assert self.connection.query_count == query_count
|
||||
|
||||
def test_multiple_relationships(self):
|
||||
categories = self.session.query(self.Category).all()
|
||||
batch_fetch(
|
||||
categories,
|
||||
'articles',
|
||||
'articles.tags'
|
||||
)
|
||||
query_count = self.connection.query_count
|
||||
categories[0].articles[0].tags
|
||||
assert self.connection.query_count == query_count
|
||||
categories[1].articles[1].tags
|
||||
assert self.connection.query_count == query_count
|
Reference in New Issue
Block a user