Initial support for many-to-many relationship batch fetching

This commit is contained in:
Konsta Vesterinen
2013-08-08 13:21:49 +03:00
parent 29ad0343a3
commit 6657fae37c
6 changed files with 182 additions and 42 deletions

View File

@@ -4,6 +4,12 @@ Changelog
Here you can see the full list of changes between each SQLAlchemy-Utils release.
0.16.4 (2013-08-08)
^^^^^^^^^^^^^^^^^^^
- Initial many-to-many relations support for batch_fetch
0.16.3 (2013-08-05)
^^^^^^^^^^^^^^^^^^^

View File

@@ -55,7 +55,7 @@ for name, requirements in extras_require.items():
setup(
name='SQLAlchemy-Utils',
version='0.16.3',
version='0.16.4',
url='https://github.com/kvesteri/sqlalchemy-utils',
license='BSD',
author='Konsta Vesterinen',

View File

@@ -34,7 +34,7 @@ from .types import (
)
__version__ = '0.16.3'
__version__ = '0.16.4'
__all__ = (

View File

@@ -390,7 +390,7 @@ def render_statement(statement, bind=None):
return Compiler(bind.dialect, statement).process(statement)
def batch_fetch(entities, attr):
def batch_fetch(entities, *attr_paths):
"""
Batch fetch given relationship attribute for collection of entities.
@@ -412,41 +412,94 @@ def batch_fetch(entities, attr):
batch_fetch(users, User.phonenumbers)
"""
if entities:
first = entities[0]
if isinstance(attr, six.string_types):
attr = getattr(
first.__class__, attr
)
prop = attr.property
if not isinstance(prop, RelationshipProperty):
raise Exception(
'Given attribute is not a relationship property.'
)
model = prop.mapper.class_
session = object_session(first)
if len(prop.remote_side) > 1:
raise Exception(
'Only relationships with single remote side columns are '
'supported.'
)
column_name = list(prop.remote_side)[0].name
parent_ids = [entity.id for entity in entities]
related_entities = (
session.query(model)
.filter(
getattr(model, column_name).in_(parent_ids)
)
)
parent_dict = dict((entity.id, []) for entity in entities)
for entity in related_entities:
parent_dict[getattr(entity, column_name)].append(entity)
for entity in entities:
set_committed_value(entity, prop.key, parent_dict[entity.id])
for attr_path in attr_paths:
if isinstance(attr_path, six.string_types):
attrs = attr_path.split('.')
if len(attrs) > 1:
related_entities = []
for entity in entities:
related_entities.extend(getattr(entity, attrs[0]))
batch_fetch(related_entities, '.'.join(attrs[1:]))
continue
else:
attr = getattr(
first.__class__, attrs[0]
)
else:
attr = attr_path
prop = attr.property
if not isinstance(prop, RelationshipProperty):
raise Exception(
'Given attribute is not a relationship property.'
)
model = prop.mapper.class_
session = object_session(first)
if prop.secondary is None:
if len(prop.remote_side) > 1:
raise Exception(
'Only relationships with single remote side columns '
'are supported.'
)
column_name = list(prop.remote_side)[0].name
related_entities = (
session.query(model)
.filter(
getattr(model, column_name).in_(parent_ids)
)
)
for entity in related_entities:
parent_dict[getattr(entity, column_name)].append(
entity
)
for entity in entities:
set_committed_value(
entity, prop.key, parent_dict[entity.id]
)
else:
column_name = None
for column in prop.remote_side:
for fk in column.foreign_keys:
# TODO: make this support inherited tables
if fk.column.table == first.__class__.__table__:
column_name = fk.parent.name
break
if column_name:
break
related_entities = (
session
.query(model, getattr(prop.secondary.c, column_name))
.join(
prop.secondary, prop.secondaryjoin
)
.filter(
getattr(prop.secondary.c, column_name).in_(
parent_ids
)
)
)
for entity, parent_id in related_entities:
parent_dict[parent_id].append(
entity
)
for entity in entities:
set_committed_value(
entity, prop.key, parent_dict[entity.id]
)

View File

@@ -28,6 +28,7 @@ class TestCase(object):
def setup_method(self, method):
self.engine = create_engine(self.dns)
#self.engine.echo = True
self.connection = self.engine.connect()
self.Base = declarative_base()

View File

@@ -1,22 +1,91 @@
import sqlalchemy as sa
from pytest import raises
from sqlalchemy_utils import batch_fetch
from tests import TestCase
class TestBatchFetch(TestCase):
def create_models(self):
class User(self.Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
name = sa.Column(sa.Unicode(255))
class Category(self.Base):
__tablename__ = 'category'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
class Article(self.Base):
__tablename__ = 'article'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
category_id = sa.Column(sa.Integer, sa.ForeignKey(Category.id))
category = sa.orm.relationship(
Category,
primaryjoin=category_id == Category.id,
backref=sa.orm.backref(
'articles'
)
)
article_tag = sa.Table(
'article_tag',
self.Base.metadata,
sa.Column(
'article_id',
sa.Integer,
sa.ForeignKey('article.id', ondelete='cascade')
),
sa.Column(
'tag_id',
sa.Integer,
sa.ForeignKey('tag.id', ondelete='cascade')
)
)
class Tag(self.Base):
__tablename__ = 'tag'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
articles = sa.orm.relationship(
Article,
secondary=article_tag,
backref=sa.orm.backref(
'tags'
)
)
self.User = User
self.Category = Category
self.Article = Article
self.Tag = Tag
def setup_method(self, method):
TestCase.setup_method(self, method)
category = self.Category(name=u'Category #1')
category.articles = [
articles = [
self.Article(name=u'Article 1'),
self.Article(name=u'Article 2')
]
category2 = self.Category(name=u'Category #2')
category2.articles = [
self.Article(name=u'Article 2'),
self.Article(name=u'Article 3'),
self.Article(name=u'Article 4'),
self.Article(name=u'Article 5')
]
self.session.add_all(articles)
self.session.flush()
tags = [
self.Tag(name=u'Tag 1'),
self.Tag(name=u'Tag 2'),
self.Tag(name=u'Tag 3')
]
articles[0].tags = tags
articles[3].tags = tags[1:]
category = self.Category(name=u'Category #1')
category.articles = articles[0:2]
category2 = self.Category(name=u'Category #2')
category2.articles = articles[2:]
self.session.add(category)
self.session.add(category2)
self.session.commit()
@@ -32,3 +101,14 @@ class TestBatchFetch(TestCase):
query_count = self.connection.query_count
categories[0].articles # no lazy load should occur
assert self.connection.query_count == query_count
def test_multiple_relationships(self):
categories = self.session.query(self.Category).all()
batch_fetch(
categories,
'articles',
'articles.tags'
)
query_count = self.connection.query_count
categories[0].articles[0].tags
assert self.connection.query_count == query_count