From b1bb5b8f4362cecf39cba5b85090aeb1c1ec3b8b Mon Sep 17 00:00:00 2001 From: Konsta Vesterinen Date: Wed, 18 Sep 2013 15:05:19 +0300 Subject: [PATCH] Fixed fatal bug in inheritance handling of batch_fetch --- sqlalchemy_utils/functions/batch_fetch.py | 47 ++++++------- .../test_join_table_inheritance.py | 68 ++++++++++++++++++- 2 files changed, 85 insertions(+), 30 deletions(-) diff --git a/sqlalchemy_utils/functions/batch_fetch.py b/sqlalchemy_utils/functions/batch_fetch.py index 6b0430e..fcf8d5a 100644 --- a/sqlalchemy_utils/functions/batch_fetch.py +++ b/sqlalchemy_utils/functions/batch_fetch.py @@ -242,13 +242,7 @@ class Fetcher(object): @property def local_column_names(self): - names = [] - for local, remote in self.prop.local_remote_pairs: - for fk in remote.foreign_keys: - # TODO: make this support inherited tables - if fk.column.table in self.prop.parent.tables: - names.append(local.name) - return names + return [local.name for local, remote in self.prop.local_remote_pairs] def parent_key(self, entity): return tuple( @@ -314,7 +308,7 @@ class Fetcher(object): return getattr(self.remote, names[0]).in_( value[0] for value in self.local_values_list ) - else: + elif len(names) > 1: conditions = [] for entity in self.path.entities: conditions.append( @@ -329,17 +323,35 @@ class Fetcher(object): ) ) return sa.or_(*conditions) + else: + raise Exception( + 'Could not obtain remote column names.' + ) def fetch(self): for entity in self.related_entities: self.append_entity(entity) + @property + def remote_column_names(self): + return [remote.name for local, remote in self.prop.local_remote_pairs] + class ManyToManyFetcher(Fetcher): @property def remote(self): return self.prop.secondary.c + @property + def local_column_names(self): + names = [] + for local, remote in self.prop.local_remote_pairs: + for fk in remote.foreign_keys: + # TODO: make this support inherited tables + if fk.column.table in self.prop.parent.tables: + names.append(local.name) + return names + @property def remote_column_names(self): names = [] @@ -383,14 +395,6 @@ class ManyToOneFetcher(Fetcher): #print 'appending entity ', entity, ' to key ', self.parent_key(entity) self.parent_dict[self.parent_key(entity)] = entity - @property - def remote_column_names(self): - return [remote.name for local, remote in self.prop.local_remote_pairs] - - @property - def local_column_names(self): - return [local.name for local, remote in self.prop.local_remote_pairs] - class OneToManyFetcher(Fetcher): def append_entity(self, entity): @@ -398,14 +402,3 @@ class OneToManyFetcher(Fetcher): self.parent_dict[self.parent_key(entity)].append( entity ) - - @property - def remote_column_names(self): - names = [] - for local, remote in self.prop.local_remote_pairs: - for fk in remote.foreign_keys: - # TODO: make this support inherited tables - if fk.column.table == self.path.parent_model.__table__: - names.append(fk.parent.name) - - return names diff --git a/tests/batch_fetch/test_join_table_inheritance.py b/tests/batch_fetch/test_join_table_inheritance.py index d793d4e..a7f66be 100644 --- a/tests/batch_fetch/test_join_table_inheritance.py +++ b/tests/batch_fetch/test_join_table_inheritance.py @@ -3,13 +3,18 @@ from sqlalchemy_utils import batch_fetch from tests import TestCase -class TestBatchFetchJoinTableInheritedModels(TestCase): +class JoinTableInheritanceTestCase(TestCase): def create_models(self): class Category(self.Base): __tablename__ = 'category' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) + class User(self.Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + class TextItem(self.Base): __tablename__ = 'text_item' id = sa.Column(sa.Integer, primary_key=True) @@ -18,6 +23,16 @@ class TestBatchFetchJoinTableInheritedModels(TestCase): sa.Integer, sa.ForeignKey(Category.id) ) + author_id = sa.Column( + sa.Integer, + sa.ForeignKey(User.id) + ) + author = sa.orm.relationship( + User, + backref=sa.orm.backref( + 'text_items' + ) + ) type = sa.Column(sa.Unicode(255)) @@ -55,10 +70,26 @@ class TestBatchFetchJoinTableInheritedModels(TestCase): ) ) - self.TextItem = TextItem - self.Category = Category + class Attachment(self.Base): + __tablename__ = 'attachment' + id = sa.Column( + sa.Integer, primary_key=True + ) + name = sa.Column( + sa.Unicode(255), index=True + ) + text_item_id = sa.Column( + sa.Integer, + sa.ForeignKey(TextItem.id), + ) + text_item = sa.orm.relationship(TextItem, backref='attachments') + self.Article = Article + self.Attachment = Attachment self.BlogPost = BlogPost + self.Category = Category + self.TextItem = TextItem + self.User = User def setup_method(self, method): TestCase.setup_method(self, method) @@ -82,8 +113,17 @@ class TestBatchFetchJoinTableInheritedModels(TestCase): category2.blog_posts = text_items[-1:] self.session.add(category) self.session.add(category2) + text_items[0].attachments = [ + self.Attachment(id=22, name=u'Attachment 1'), + self.Attachment(id=34, name=u'Attachment 2') + ] + text_items[0].author = self.User(name=u'John Matrix') + text_items[1].author = self.User(name=u'John Doe') + self.session.commit() + +class TestBatchFetchJoinTableInheritedModels(JoinTableInheritanceTestCase): def test_multiple_relationships(self): categories = self.session.query(self.Category).all() batch_fetch( @@ -98,3 +138,25 @@ class TestBatchFetchJoinTableInheritedModels(TestCase): categories[1].articles[1] categories[1].blog_posts[0] assert self.connection.query_count == query_count + + def test_one_to_many_relationships(self): + articles = ( + self.session.query(self.Article) + .filter_by(name=u'Article 1') + .all() + ) + batch_fetch( + articles, + 'attachments' + ) + + def test_many_to_one_relationships(self): + articles = ( + self.session.query(self.Article) + .filter_by(name=u'Article 1') + .all() + ) + batch_fetch( + articles, + 'author' + )