Initial support for many-to-many relationship batch fetching
This commit is contained in:
@@ -4,6 +4,12 @@ Changelog
|
|||||||
Here you can see the full list of changes between each SQLAlchemy-Utils release.
|
Here you can see the full list of changes between each SQLAlchemy-Utils release.
|
||||||
|
|
||||||
|
|
||||||
|
0.16.4 (2013-08-08)
|
||||||
|
^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
- Initial many-to-many relations support for batch_fetch
|
||||||
|
|
||||||
|
|
||||||
0.16.3 (2013-08-05)
|
0.16.3 (2013-08-05)
|
||||||
^^^^^^^^^^^^^^^^^^^
|
^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
2
setup.py
2
setup.py
@@ -55,7 +55,7 @@ for name, requirements in extras_require.items():
|
|||||||
|
|
||||||
setup(
|
setup(
|
||||||
name='SQLAlchemy-Utils',
|
name='SQLAlchemy-Utils',
|
||||||
version='0.16.3',
|
version='0.16.4',
|
||||||
url='https://github.com/kvesteri/sqlalchemy-utils',
|
url='https://github.com/kvesteri/sqlalchemy-utils',
|
||||||
license='BSD',
|
license='BSD',
|
||||||
author='Konsta Vesterinen',
|
author='Konsta Vesterinen',
|
||||||
|
@@ -34,7 +34,7 @@ from .types import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
__version__ = '0.16.3'
|
__version__ = '0.16.4'
|
||||||
|
|
||||||
|
|
||||||
__all__ = (
|
__all__ = (
|
||||||
|
@@ -390,7 +390,7 @@ def render_statement(statement, bind=None):
|
|||||||
return Compiler(bind.dialect, statement).process(statement)
|
return Compiler(bind.dialect, statement).process(statement)
|
||||||
|
|
||||||
|
|
||||||
def batch_fetch(entities, attr):
|
def batch_fetch(entities, *attr_paths):
|
||||||
"""
|
"""
|
||||||
Batch fetch given relationship attribute for collection of entities.
|
Batch fetch given relationship attribute for collection of entities.
|
||||||
|
|
||||||
@@ -412,41 +412,94 @@ def batch_fetch(entities, attr):
|
|||||||
|
|
||||||
batch_fetch(users, User.phonenumbers)
|
batch_fetch(users, User.phonenumbers)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if entities:
|
if entities:
|
||||||
first = entities[0]
|
first = entities[0]
|
||||||
if isinstance(attr, six.string_types):
|
|
||||||
attr = getattr(
|
|
||||||
first.__class__, attr
|
|
||||||
)
|
|
||||||
|
|
||||||
prop = attr.property
|
|
||||||
if not isinstance(prop, RelationshipProperty):
|
|
||||||
raise Exception(
|
|
||||||
'Given attribute is not a relationship property.'
|
|
||||||
)
|
|
||||||
|
|
||||||
model = prop.mapper.class_
|
|
||||||
session = object_session(first)
|
|
||||||
|
|
||||||
if len(prop.remote_side) > 1:
|
|
||||||
raise Exception(
|
|
||||||
'Only relationships with single remote side columns are '
|
|
||||||
'supported.'
|
|
||||||
)
|
|
||||||
|
|
||||||
column_name = list(prop.remote_side)[0].name
|
|
||||||
parent_ids = [entity.id for entity in entities]
|
parent_ids = [entity.id for entity in entities]
|
||||||
|
|
||||||
related_entities = (
|
|
||||||
session.query(model)
|
|
||||||
.filter(
|
|
||||||
getattr(model, column_name).in_(parent_ids)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
parent_dict = dict((entity.id, []) for entity in entities)
|
parent_dict = dict((entity.id, []) for entity in entities)
|
||||||
for entity in related_entities:
|
|
||||||
parent_dict[getattr(entity, column_name)].append(entity)
|
|
||||||
|
|
||||||
for entity in entities:
|
for attr_path in attr_paths:
|
||||||
set_committed_value(entity, prop.key, parent_dict[entity.id])
|
if isinstance(attr_path, six.string_types):
|
||||||
|
attrs = attr_path.split('.')
|
||||||
|
|
||||||
|
if len(attrs) > 1:
|
||||||
|
related_entities = []
|
||||||
|
for entity in entities:
|
||||||
|
related_entities.extend(getattr(entity, attrs[0]))
|
||||||
|
|
||||||
|
batch_fetch(related_entities, '.'.join(attrs[1:]))
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
attr = getattr(
|
||||||
|
first.__class__, attrs[0]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
attr = attr_path
|
||||||
|
|
||||||
|
prop = attr.property
|
||||||
|
if not isinstance(prop, RelationshipProperty):
|
||||||
|
raise Exception(
|
||||||
|
'Given attribute is not a relationship property.'
|
||||||
|
)
|
||||||
|
|
||||||
|
model = prop.mapper.class_
|
||||||
|
|
||||||
|
session = object_session(first)
|
||||||
|
|
||||||
|
if prop.secondary is None:
|
||||||
|
if len(prop.remote_side) > 1:
|
||||||
|
raise Exception(
|
||||||
|
'Only relationships with single remote side columns '
|
||||||
|
'are supported.'
|
||||||
|
)
|
||||||
|
|
||||||
|
column_name = list(prop.remote_side)[0].name
|
||||||
|
|
||||||
|
related_entities = (
|
||||||
|
session.query(model)
|
||||||
|
.filter(
|
||||||
|
getattr(model, column_name).in_(parent_ids)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
for entity in related_entities:
|
||||||
|
parent_dict[getattr(entity, column_name)].append(
|
||||||
|
entity
|
||||||
|
)
|
||||||
|
|
||||||
|
for entity in entities:
|
||||||
|
set_committed_value(
|
||||||
|
entity, prop.key, parent_dict[entity.id]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
column_name = None
|
||||||
|
for column in prop.remote_side:
|
||||||
|
for fk in column.foreign_keys:
|
||||||
|
# TODO: make this support inherited tables
|
||||||
|
if fk.column.table == first.__class__.__table__:
|
||||||
|
column_name = fk.parent.name
|
||||||
|
break
|
||||||
|
if column_name:
|
||||||
|
break
|
||||||
|
|
||||||
|
related_entities = (
|
||||||
|
session
|
||||||
|
.query(model, getattr(prop.secondary.c, column_name))
|
||||||
|
.join(
|
||||||
|
prop.secondary, prop.secondaryjoin
|
||||||
|
)
|
||||||
|
.filter(
|
||||||
|
getattr(prop.secondary.c, column_name).in_(
|
||||||
|
parent_ids
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
for entity, parent_id in related_entities:
|
||||||
|
parent_dict[parent_id].append(
|
||||||
|
entity
|
||||||
|
)
|
||||||
|
|
||||||
|
for entity in entities:
|
||||||
|
set_committed_value(
|
||||||
|
entity, prop.key, parent_dict[entity.id]
|
||||||
|
)
|
||||||
|
@@ -28,6 +28,7 @@ class TestCase(object):
|
|||||||
|
|
||||||
def setup_method(self, method):
|
def setup_method(self, method):
|
||||||
self.engine = create_engine(self.dns)
|
self.engine = create_engine(self.dns)
|
||||||
|
#self.engine.echo = True
|
||||||
self.connection = self.engine.connect()
|
self.connection = self.engine.connect()
|
||||||
self.Base = declarative_base()
|
self.Base = declarative_base()
|
||||||
|
|
||||||
|
@@ -1,22 +1,91 @@
|
|||||||
|
import sqlalchemy as sa
|
||||||
from pytest import raises
|
from pytest import raises
|
||||||
from sqlalchemy_utils import batch_fetch
|
from sqlalchemy_utils import batch_fetch
|
||||||
from tests import TestCase
|
from tests import TestCase
|
||||||
|
|
||||||
|
|
||||||
class TestBatchFetch(TestCase):
|
class TestBatchFetch(TestCase):
|
||||||
|
def create_models(self):
|
||||||
|
class User(self.Base):
|
||||||
|
__tablename__ = 'user'
|
||||||
|
id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
|
||||||
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
|
||||||
|
class Category(self.Base):
|
||||||
|
__tablename__ = 'category'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
|
||||||
|
class Article(self.Base):
|
||||||
|
__tablename__ = 'article'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
category_id = sa.Column(sa.Integer, sa.ForeignKey(Category.id))
|
||||||
|
|
||||||
|
category = sa.orm.relationship(
|
||||||
|
Category,
|
||||||
|
primaryjoin=category_id == Category.id,
|
||||||
|
backref=sa.orm.backref(
|
||||||
|
'articles'
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
article_tag = sa.Table(
|
||||||
|
'article_tag',
|
||||||
|
self.Base.metadata,
|
||||||
|
sa.Column(
|
||||||
|
'article_id',
|
||||||
|
sa.Integer,
|
||||||
|
sa.ForeignKey('article.id', ondelete='cascade')
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
'tag_id',
|
||||||
|
sa.Integer,
|
||||||
|
sa.ForeignKey('tag.id', ondelete='cascade')
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
class Tag(self.Base):
|
||||||
|
__tablename__ = 'tag'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
articles = sa.orm.relationship(
|
||||||
|
Article,
|
||||||
|
secondary=article_tag,
|
||||||
|
backref=sa.orm.backref(
|
||||||
|
'tags'
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.User = User
|
||||||
|
self.Category = Category
|
||||||
|
self.Article = Article
|
||||||
|
self.Tag = Tag
|
||||||
|
|
||||||
def setup_method(self, method):
|
def setup_method(self, method):
|
||||||
TestCase.setup_method(self, method)
|
TestCase.setup_method(self, method)
|
||||||
category = self.Category(name=u'Category #1')
|
articles = [
|
||||||
category.articles = [
|
|
||||||
self.Article(name=u'Article 1'),
|
self.Article(name=u'Article 1'),
|
||||||
self.Article(name=u'Article 2')
|
self.Article(name=u'Article 2'),
|
||||||
]
|
|
||||||
category2 = self.Category(name=u'Category #2')
|
|
||||||
category2.articles = [
|
|
||||||
self.Article(name=u'Article 3'),
|
self.Article(name=u'Article 3'),
|
||||||
self.Article(name=u'Article 4'),
|
self.Article(name=u'Article 4'),
|
||||||
self.Article(name=u'Article 5')
|
self.Article(name=u'Article 5')
|
||||||
]
|
]
|
||||||
|
self.session.add_all(articles)
|
||||||
|
self.session.flush()
|
||||||
|
|
||||||
|
tags = [
|
||||||
|
self.Tag(name=u'Tag 1'),
|
||||||
|
self.Tag(name=u'Tag 2'),
|
||||||
|
self.Tag(name=u'Tag 3')
|
||||||
|
]
|
||||||
|
articles[0].tags = tags
|
||||||
|
articles[3].tags = tags[1:]
|
||||||
|
|
||||||
|
category = self.Category(name=u'Category #1')
|
||||||
|
category.articles = articles[0:2]
|
||||||
|
category2 = self.Category(name=u'Category #2')
|
||||||
|
category2.articles = articles[2:]
|
||||||
self.session.add(category)
|
self.session.add(category)
|
||||||
self.session.add(category2)
|
self.session.add(category2)
|
||||||
self.session.commit()
|
self.session.commit()
|
||||||
@@ -32,3 +101,14 @@ class TestBatchFetch(TestCase):
|
|||||||
query_count = self.connection.query_count
|
query_count = self.connection.query_count
|
||||||
categories[0].articles # no lazy load should occur
|
categories[0].articles # no lazy load should occur
|
||||||
assert self.connection.query_count == query_count
|
assert self.connection.query_count == query_count
|
||||||
|
|
||||||
|
def test_multiple_relationships(self):
|
||||||
|
categories = self.session.query(self.Category).all()
|
||||||
|
batch_fetch(
|
||||||
|
categories,
|
||||||
|
'articles',
|
||||||
|
'articles.tags'
|
||||||
|
)
|
||||||
|
query_count = self.connection.query_count
|
||||||
|
categories[0].articles[0].tags
|
||||||
|
assert self.connection.query_count == query_count
|
||||||
|
Reference in New Issue
Block a user