diff --git a/CHANGES.rst b/CHANGES.rst index 6b58280..fa54812 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -4,6 +4,12 @@ Changelog Here you can see the full list of changes between each SQLAlchemy-Utils release. +0.16.4 (2013-08-08) +^^^^^^^^^^^^^^^^^^^ + +- Initial many-to-many relations support for batch_fetch + + 0.16.3 (2013-08-05) ^^^^^^^^^^^^^^^^^^^ diff --git a/setup.py b/setup.py index 4cb150a..6936436 100644 --- a/setup.py +++ b/setup.py @@ -55,7 +55,7 @@ for name, requirements in extras_require.items(): setup( name='SQLAlchemy-Utils', - version='0.16.3', + version='0.16.4', url='https://github.com/kvesteri/sqlalchemy-utils', license='BSD', author='Konsta Vesterinen', diff --git a/sqlalchemy_utils/__init__.py b/sqlalchemy_utils/__init__.py index ec20313..dec4c2e 100644 --- a/sqlalchemy_utils/__init__.py +++ b/sqlalchemy_utils/__init__.py @@ -34,7 +34,7 @@ from .types import ( ) -__version__ = '0.16.3' +__version__ = '0.16.4' __all__ = ( diff --git a/sqlalchemy_utils/functions.py b/sqlalchemy_utils/functions.py index b152469..2b961e6 100644 --- a/sqlalchemy_utils/functions.py +++ b/sqlalchemy_utils/functions.py @@ -390,7 +390,7 @@ def render_statement(statement, bind=None): return Compiler(bind.dialect, statement).process(statement) -def batch_fetch(entities, attr): +def batch_fetch(entities, *attr_paths): """ Batch fetch given relationship attribute for collection of entities. @@ -412,41 +412,94 @@ def batch_fetch(entities, attr): batch_fetch(users, User.phonenumbers) """ + if entities: first = entities[0] - if isinstance(attr, six.string_types): - attr = getattr( - first.__class__, attr - ) - - prop = attr.property - if not isinstance(prop, RelationshipProperty): - raise Exception( - 'Given attribute is not a relationship property.' - ) - - model = prop.mapper.class_ - session = object_session(first) - - if len(prop.remote_side) > 1: - raise Exception( - 'Only relationships with single remote side columns are ' - 'supported.' - ) - - column_name = list(prop.remote_side)[0].name parent_ids = [entity.id for entity in entities] - - related_entities = ( - session.query(model) - .filter( - getattr(model, column_name).in_(parent_ids) - ) - ) - parent_dict = dict((entity.id, []) for entity in entities) - for entity in related_entities: - parent_dict[getattr(entity, column_name)].append(entity) - for entity in entities: - set_committed_value(entity, prop.key, parent_dict[entity.id]) + for attr_path in attr_paths: + if isinstance(attr_path, six.string_types): + attrs = attr_path.split('.') + + if len(attrs) > 1: + related_entities = [] + for entity in entities: + related_entities.extend(getattr(entity, attrs[0])) + + batch_fetch(related_entities, '.'.join(attrs[1:])) + continue + else: + attr = getattr( + first.__class__, attrs[0] + ) + else: + attr = attr_path + + prop = attr.property + if not isinstance(prop, RelationshipProperty): + raise Exception( + 'Given attribute is not a relationship property.' + ) + + model = prop.mapper.class_ + + session = object_session(first) + + if prop.secondary is None: + if len(prop.remote_side) > 1: + raise Exception( + 'Only relationships with single remote side columns ' + 'are supported.' + ) + + column_name = list(prop.remote_side)[0].name + + related_entities = ( + session.query(model) + .filter( + getattr(model, column_name).in_(parent_ids) + ) + ) + + for entity in related_entities: + parent_dict[getattr(entity, column_name)].append( + entity + ) + + for entity in entities: + set_committed_value( + entity, prop.key, parent_dict[entity.id] + ) + else: + column_name = None + for column in prop.remote_side: + for fk in column.foreign_keys: + # TODO: make this support inherited tables + if fk.column.table == first.__class__.__table__: + column_name = fk.parent.name + break + if column_name: + break + + related_entities = ( + session + .query(model, getattr(prop.secondary.c, column_name)) + .join( + prop.secondary, prop.secondaryjoin + ) + .filter( + getattr(prop.secondary.c, column_name).in_( + parent_ids + ) + ) + ) + for entity, parent_id in related_entities: + parent_dict[parent_id].append( + entity + ) + + for entity in entities: + set_committed_value( + entity, prop.key, parent_dict[entity.id] + ) diff --git a/tests/__init__.py b/tests/__init__.py index e7c1894..ce022c8 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -28,6 +28,7 @@ class TestCase(object): def setup_method(self, method): self.engine = create_engine(self.dns) + #self.engine.echo = True self.connection = self.engine.connect() self.Base = declarative_base() diff --git a/tests/test_batch_fetch.py b/tests/test_batch_fetch.py index 0067943..236c056 100644 --- a/tests/test_batch_fetch.py +++ b/tests/test_batch_fetch.py @@ -1,22 +1,91 @@ +import sqlalchemy as sa from pytest import raises from sqlalchemy_utils import batch_fetch from tests import TestCase class TestBatchFetch(TestCase): + def create_models(self): + class User(self.Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, autoincrement=True, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + class Category(self.Base): + __tablename__ = 'category' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + class Article(self.Base): + __tablename__ = 'article' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + category_id = sa.Column(sa.Integer, sa.ForeignKey(Category.id)) + + category = sa.orm.relationship( + Category, + primaryjoin=category_id == Category.id, + backref=sa.orm.backref( + 'articles' + ) + ) + + article_tag = sa.Table( + 'article_tag', + self.Base.metadata, + sa.Column( + 'article_id', + sa.Integer, + sa.ForeignKey('article.id', ondelete='cascade') + ), + sa.Column( + 'tag_id', + sa.Integer, + sa.ForeignKey('tag.id', ondelete='cascade') + ) + ) + + class Tag(self.Base): + __tablename__ = 'tag' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + articles = sa.orm.relationship( + Article, + secondary=article_tag, + backref=sa.orm.backref( + 'tags' + ) + ) + + self.User = User + self.Category = Category + self.Article = Article + self.Tag = Tag + def setup_method(self, method): TestCase.setup_method(self, method) - category = self.Category(name=u'Category #1') - category.articles = [ + articles = [ self.Article(name=u'Article 1'), - self.Article(name=u'Article 2') - ] - category2 = self.Category(name=u'Category #2') - category2.articles = [ + self.Article(name=u'Article 2'), self.Article(name=u'Article 3'), self.Article(name=u'Article 4'), self.Article(name=u'Article 5') ] + self.session.add_all(articles) + self.session.flush() + + tags = [ + self.Tag(name=u'Tag 1'), + self.Tag(name=u'Tag 2'), + self.Tag(name=u'Tag 3') + ] + articles[0].tags = tags + articles[3].tags = tags[1:] + + category = self.Category(name=u'Category #1') + category.articles = articles[0:2] + category2 = self.Category(name=u'Category #2') + category2.articles = articles[2:] self.session.add(category) self.session.add(category2) self.session.commit() @@ -32,3 +101,14 @@ class TestBatchFetch(TestCase): query_count = self.connection.query_count categories[0].articles # no lazy load should occur assert self.connection.query_count == query_count + + def test_multiple_relationships(self): + categories = self.session.query(self.Category).all() + batch_fetch( + categories, + 'articles', + 'articles.tags' + ) + query_count = self.connection.query_count + categories[0].articles[0].tags + assert self.connection.query_count == query_count