Fixed fatal bug in inheritance handling of batch_fetch

This commit is contained in:
Konsta Vesterinen
2013-09-18 15:05:19 +03:00
parent 97d42d3d0a
commit b1bb5b8f43
2 changed files with 85 additions and 30 deletions

View File

@@ -242,13 +242,7 @@ class Fetcher(object):
@property @property
def local_column_names(self): def local_column_names(self):
names = [] return [local.name for local, remote in self.prop.local_remote_pairs]
for local, remote in self.prop.local_remote_pairs:
for fk in remote.foreign_keys:
# TODO: make this support inherited tables
if fk.column.table in self.prop.parent.tables:
names.append(local.name)
return names
def parent_key(self, entity): def parent_key(self, entity):
return tuple( return tuple(
@@ -314,7 +308,7 @@ class Fetcher(object):
return getattr(self.remote, names[0]).in_( return getattr(self.remote, names[0]).in_(
value[0] for value in self.local_values_list value[0] for value in self.local_values_list
) )
else: elif len(names) > 1:
conditions = [] conditions = []
for entity in self.path.entities: for entity in self.path.entities:
conditions.append( conditions.append(
@@ -329,17 +323,35 @@ class Fetcher(object):
) )
) )
return sa.or_(*conditions) return sa.or_(*conditions)
else:
raise Exception(
'Could not obtain remote column names.'
)
def fetch(self): def fetch(self):
for entity in self.related_entities: for entity in self.related_entities:
self.append_entity(entity) self.append_entity(entity)
@property
def remote_column_names(self):
return [remote.name for local, remote in self.prop.local_remote_pairs]
class ManyToManyFetcher(Fetcher): class ManyToManyFetcher(Fetcher):
@property @property
def remote(self): def remote(self):
return self.prop.secondary.c return self.prop.secondary.c
@property
def local_column_names(self):
names = []
for local, remote in self.prop.local_remote_pairs:
for fk in remote.foreign_keys:
# TODO: make this support inherited tables
if fk.column.table in self.prop.parent.tables:
names.append(local.name)
return names
@property @property
def remote_column_names(self): def remote_column_names(self):
names = [] names = []
@@ -383,14 +395,6 @@ class ManyToOneFetcher(Fetcher):
#print 'appending entity ', entity, ' to key ', self.parent_key(entity) #print 'appending entity ', entity, ' to key ', self.parent_key(entity)
self.parent_dict[self.parent_key(entity)] = entity self.parent_dict[self.parent_key(entity)] = entity
@property
def remote_column_names(self):
return [remote.name for local, remote in self.prop.local_remote_pairs]
@property
def local_column_names(self):
return [local.name for local, remote in self.prop.local_remote_pairs]
class OneToManyFetcher(Fetcher): class OneToManyFetcher(Fetcher):
def append_entity(self, entity): def append_entity(self, entity):
@@ -398,14 +402,3 @@ class OneToManyFetcher(Fetcher):
self.parent_dict[self.parent_key(entity)].append( self.parent_dict[self.parent_key(entity)].append(
entity entity
) )
@property
def remote_column_names(self):
names = []
for local, remote in self.prop.local_remote_pairs:
for fk in remote.foreign_keys:
# TODO: make this support inherited tables
if fk.column.table == self.path.parent_model.__table__:
names.append(fk.parent.name)
return names

View File

@@ -3,13 +3,18 @@ from sqlalchemy_utils import batch_fetch
from tests import TestCase from tests import TestCase
class TestBatchFetchJoinTableInheritedModels(TestCase): class JoinTableInheritanceTestCase(TestCase):
def create_models(self): def create_models(self):
class Category(self.Base): class Category(self.Base):
__tablename__ = 'category' __tablename__ = 'category'
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255)) name = sa.Column(sa.Unicode(255))
class User(self.Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
class TextItem(self.Base): class TextItem(self.Base):
__tablename__ = 'text_item' __tablename__ = 'text_item'
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column(sa.Integer, primary_key=True)
@@ -18,6 +23,16 @@ class TestBatchFetchJoinTableInheritedModels(TestCase):
sa.Integer, sa.Integer,
sa.ForeignKey(Category.id) sa.ForeignKey(Category.id)
) )
author_id = sa.Column(
sa.Integer,
sa.ForeignKey(User.id)
)
author = sa.orm.relationship(
User,
backref=sa.orm.backref(
'text_items'
)
)
type = sa.Column(sa.Unicode(255)) type = sa.Column(sa.Unicode(255))
@@ -55,10 +70,26 @@ class TestBatchFetchJoinTableInheritedModels(TestCase):
) )
) )
self.TextItem = TextItem class Attachment(self.Base):
self.Category = Category __tablename__ = 'attachment'
id = sa.Column(
sa.Integer, primary_key=True
)
name = sa.Column(
sa.Unicode(255), index=True
)
text_item_id = sa.Column(
sa.Integer,
sa.ForeignKey(TextItem.id),
)
text_item = sa.orm.relationship(TextItem, backref='attachments')
self.Article = Article self.Article = Article
self.Attachment = Attachment
self.BlogPost = BlogPost self.BlogPost = BlogPost
self.Category = Category
self.TextItem = TextItem
self.User = User
def setup_method(self, method): def setup_method(self, method):
TestCase.setup_method(self, method) TestCase.setup_method(self, method)
@@ -82,8 +113,17 @@ class TestBatchFetchJoinTableInheritedModels(TestCase):
category2.blog_posts = text_items[-1:] category2.blog_posts = text_items[-1:]
self.session.add(category) self.session.add(category)
self.session.add(category2) self.session.add(category2)
text_items[0].attachments = [
self.Attachment(id=22, name=u'Attachment 1'),
self.Attachment(id=34, name=u'Attachment 2')
]
text_items[0].author = self.User(name=u'John Matrix')
text_items[1].author = self.User(name=u'John Doe')
self.session.commit() self.session.commit()
class TestBatchFetchJoinTableInheritedModels(JoinTableInheritanceTestCase):
def test_multiple_relationships(self): def test_multiple_relationships(self):
categories = self.session.query(self.Category).all() categories = self.session.query(self.Category).all()
batch_fetch( batch_fetch(
@@ -98,3 +138,25 @@ class TestBatchFetchJoinTableInheritedModels(TestCase):
categories[1].articles[1] categories[1].articles[1]
categories[1].blog_posts[0] categories[1].blog_posts[0]
assert self.connection.query_count == query_count assert self.connection.query_count == query_count
def test_one_to_many_relationships(self):
articles = (
self.session.query(self.Article)
.filter_by(name=u'Article 1')
.all()
)
batch_fetch(
articles,
'attachments'
)
def test_many_to_one_relationships(self):
articles = (
self.session.query(self.Article)
.filter_by(name=u'Article 1')
.all()
)
batch_fetch(
articles,
'author'
)