Fixed fatal bug in inheritance handling of batch_fetch
This commit is contained in:
@@ -242,13 +242,7 @@ class Fetcher(object):
|
||||
|
||||
@property
|
||||
def local_column_names(self):
|
||||
names = []
|
||||
for local, remote in self.prop.local_remote_pairs:
|
||||
for fk in remote.foreign_keys:
|
||||
# TODO: make this support inherited tables
|
||||
if fk.column.table in self.prop.parent.tables:
|
||||
names.append(local.name)
|
||||
return names
|
||||
return [local.name for local, remote in self.prop.local_remote_pairs]
|
||||
|
||||
def parent_key(self, entity):
|
||||
return tuple(
|
||||
@@ -314,7 +308,7 @@ class Fetcher(object):
|
||||
return getattr(self.remote, names[0]).in_(
|
||||
value[0] for value in self.local_values_list
|
||||
)
|
||||
else:
|
||||
elif len(names) > 1:
|
||||
conditions = []
|
||||
for entity in self.path.entities:
|
||||
conditions.append(
|
||||
@@ -329,17 +323,35 @@ class Fetcher(object):
|
||||
)
|
||||
)
|
||||
return sa.or_(*conditions)
|
||||
else:
|
||||
raise Exception(
|
||||
'Could not obtain remote column names.'
|
||||
)
|
||||
|
||||
def fetch(self):
|
||||
for entity in self.related_entities:
|
||||
self.append_entity(entity)
|
||||
|
||||
@property
|
||||
def remote_column_names(self):
|
||||
return [remote.name for local, remote in self.prop.local_remote_pairs]
|
||||
|
||||
|
||||
class ManyToManyFetcher(Fetcher):
|
||||
@property
|
||||
def remote(self):
|
||||
return self.prop.secondary.c
|
||||
|
||||
@property
|
||||
def local_column_names(self):
|
||||
names = []
|
||||
for local, remote in self.prop.local_remote_pairs:
|
||||
for fk in remote.foreign_keys:
|
||||
# TODO: make this support inherited tables
|
||||
if fk.column.table in self.prop.parent.tables:
|
||||
names.append(local.name)
|
||||
return names
|
||||
|
||||
@property
|
||||
def remote_column_names(self):
|
||||
names = []
|
||||
@@ -383,14 +395,6 @@ class ManyToOneFetcher(Fetcher):
|
||||
#print 'appending entity ', entity, ' to key ', self.parent_key(entity)
|
||||
self.parent_dict[self.parent_key(entity)] = entity
|
||||
|
||||
@property
|
||||
def remote_column_names(self):
|
||||
return [remote.name for local, remote in self.prop.local_remote_pairs]
|
||||
|
||||
@property
|
||||
def local_column_names(self):
|
||||
return [local.name for local, remote in self.prop.local_remote_pairs]
|
||||
|
||||
|
||||
class OneToManyFetcher(Fetcher):
|
||||
def append_entity(self, entity):
|
||||
@@ -398,14 +402,3 @@ class OneToManyFetcher(Fetcher):
|
||||
self.parent_dict[self.parent_key(entity)].append(
|
||||
entity
|
||||
)
|
||||
|
||||
@property
|
||||
def remote_column_names(self):
|
||||
names = []
|
||||
for local, remote in self.prop.local_remote_pairs:
|
||||
for fk in remote.foreign_keys:
|
||||
# TODO: make this support inherited tables
|
||||
if fk.column.table == self.path.parent_model.__table__:
|
||||
names.append(fk.parent.name)
|
||||
|
||||
return names
|
||||
|
@@ -3,13 +3,18 @@ from sqlalchemy_utils import batch_fetch
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
class TestBatchFetchJoinTableInheritedModels(TestCase):
|
||||
class JoinTableInheritanceTestCase(TestCase):
|
||||
def create_models(self):
|
||||
class Category(self.Base):
|
||||
__tablename__ = 'category'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
class User(self.Base):
|
||||
__tablename__ = 'user'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
class TextItem(self.Base):
|
||||
__tablename__ = 'text_item'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
@@ -18,6 +23,16 @@ class TestBatchFetchJoinTableInheritedModels(TestCase):
|
||||
sa.Integer,
|
||||
sa.ForeignKey(Category.id)
|
||||
)
|
||||
author_id = sa.Column(
|
||||
sa.Integer,
|
||||
sa.ForeignKey(User.id)
|
||||
)
|
||||
author = sa.orm.relationship(
|
||||
User,
|
||||
backref=sa.orm.backref(
|
||||
'text_items'
|
||||
)
|
||||
)
|
||||
|
||||
type = sa.Column(sa.Unicode(255))
|
||||
|
||||
@@ -55,10 +70,26 @@ class TestBatchFetchJoinTableInheritedModels(TestCase):
|
||||
)
|
||||
)
|
||||
|
||||
self.TextItem = TextItem
|
||||
self.Category = Category
|
||||
class Attachment(self.Base):
|
||||
__tablename__ = 'attachment'
|
||||
id = sa.Column(
|
||||
sa.Integer, primary_key=True
|
||||
)
|
||||
name = sa.Column(
|
||||
sa.Unicode(255), index=True
|
||||
)
|
||||
text_item_id = sa.Column(
|
||||
sa.Integer,
|
||||
sa.ForeignKey(TextItem.id),
|
||||
)
|
||||
text_item = sa.orm.relationship(TextItem, backref='attachments')
|
||||
|
||||
self.Article = Article
|
||||
self.Attachment = Attachment
|
||||
self.BlogPost = BlogPost
|
||||
self.Category = Category
|
||||
self.TextItem = TextItem
|
||||
self.User = User
|
||||
|
||||
def setup_method(self, method):
|
||||
TestCase.setup_method(self, method)
|
||||
@@ -82,8 +113,17 @@ class TestBatchFetchJoinTableInheritedModels(TestCase):
|
||||
category2.blog_posts = text_items[-1:]
|
||||
self.session.add(category)
|
||||
self.session.add(category2)
|
||||
text_items[0].attachments = [
|
||||
self.Attachment(id=22, name=u'Attachment 1'),
|
||||
self.Attachment(id=34, name=u'Attachment 2')
|
||||
]
|
||||
text_items[0].author = self.User(name=u'John Matrix')
|
||||
text_items[1].author = self.User(name=u'John Doe')
|
||||
|
||||
self.session.commit()
|
||||
|
||||
|
||||
class TestBatchFetchJoinTableInheritedModels(JoinTableInheritanceTestCase):
|
||||
def test_multiple_relationships(self):
|
||||
categories = self.session.query(self.Category).all()
|
||||
batch_fetch(
|
||||
@@ -98,3 +138,25 @@ class TestBatchFetchJoinTableInheritedModels(TestCase):
|
||||
categories[1].articles[1]
|
||||
categories[1].blog_posts[0]
|
||||
assert self.connection.query_count == query_count
|
||||
|
||||
def test_one_to_many_relationships(self):
|
||||
articles = (
|
||||
self.session.query(self.Article)
|
||||
.filter_by(name=u'Article 1')
|
||||
.all()
|
||||
)
|
||||
batch_fetch(
|
||||
articles,
|
||||
'attachments'
|
||||
)
|
||||
|
||||
def test_many_to_one_relationships(self):
|
||||
articles = (
|
||||
self.session.query(self.Article)
|
||||
.filter_by(name=u'Article 1')
|
||||
.all()
|
||||
)
|
||||
batch_fetch(
|
||||
articles,
|
||||
'author'
|
||||
)
|
||||
|
Reference in New Issue
Block a user