Added preliminary support for composite key batch fetching
This commit is contained in:
@@ -201,7 +201,10 @@ class CompositeFetcher(object):
|
|||||||
def fetch(self):
|
def fetch(self):
|
||||||
for entity in self.related_entities:
|
for entity in self.related_entities:
|
||||||
for fetcher in self.fetchers:
|
for fetcher in self.fetchers:
|
||||||
if getattr(entity, fetcher.remote_column_name) is not None:
|
if any(
|
||||||
|
getattr(entity, name)
|
||||||
|
for name in fetcher.remote_column_names
|
||||||
|
):
|
||||||
fetcher.append_entity(entity)
|
fetcher.append_entity(entity)
|
||||||
|
|
||||||
def populate(self):
|
def populate(self):
|
||||||
@@ -227,29 +230,46 @@ class Fetcher(object):
|
|||||||
return self.path.session.query(self.path.model).filter(self.condition)
|
return self.path.session.query(self.path.model).filter(self.condition)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def remote_column_name(self):
|
def local_column_names(self):
|
||||||
return list(self.path.property.remote_side)[0].name
|
names = []
|
||||||
|
for local, remote in self.prop.local_remote_pairs:
|
||||||
|
for fk in remote.foreign_keys:
|
||||||
|
# TODO: make this support inherited tables
|
||||||
|
if fk.column.table in self.prop.parent.tables:
|
||||||
|
names.append(local.name)
|
||||||
|
return names
|
||||||
|
|
||||||
|
def parent_key(self, entity):
|
||||||
|
return tuple(
|
||||||
|
getattr(entity, name)
|
||||||
|
for name in self.remote_column_names
|
||||||
|
)
|
||||||
|
|
||||||
def local_values(self, entity):
|
def local_values(self, entity):
|
||||||
return getattr(entity, list(self.prop.local_columns)[0].name)
|
return tuple(
|
||||||
|
getattr(entity, name)
|
||||||
|
for name in self.local_column_names
|
||||||
|
)
|
||||||
|
|
||||||
def populate_backrefs(self, related_entities):
|
def populate_backrefs(self, related_entities):
|
||||||
"""
|
"""
|
||||||
Populates backrefs for given related entities.
|
Populates backrefs for given related entities.
|
||||||
"""
|
"""
|
||||||
backref_dict = dict(
|
backref_dict = dict(
|
||||||
(self.local_values(entity), [])
|
(self.local_values(value[0]), [])
|
||||||
for entity, parent_id in related_entities
|
for value in related_entities
|
||||||
)
|
)
|
||||||
for entity, parent_id in related_entities:
|
for value in related_entities:
|
||||||
backref_dict[self.local_values(entity)].append(
|
backref_dict[self.local_values(value[0])].append(
|
||||||
self.path.session.query(self.path.parent_model).get(parent_id)
|
self.path.session.query(self.path.parent_model).get(
|
||||||
|
tuple(value[1:])
|
||||||
)
|
)
|
||||||
for entity, parent_id in related_entities:
|
)
|
||||||
|
for value in related_entities:
|
||||||
set_committed_value(
|
set_committed_value(
|
||||||
entity,
|
value[0],
|
||||||
self.prop.back_populates,
|
self.prop.back_populates,
|
||||||
backref_dict[self.local_values(entity)]
|
backref_dict[self.local_values(value[0])]
|
||||||
)
|
)
|
||||||
|
|
||||||
def populate(self):
|
def populate(self):
|
||||||
@@ -257,6 +277,12 @@ class Fetcher(object):
|
|||||||
Populate batch fetched entities to parent objects.
|
Populate batch fetched entities to parent objects.
|
||||||
"""
|
"""
|
||||||
for entity in self.path.entities:
|
for entity in self.path.entities:
|
||||||
|
# print (
|
||||||
|
# "setting committed value for ",
|
||||||
|
# entity,
|
||||||
|
# " using local values ",
|
||||||
|
# self.local_values(entity)
|
||||||
|
# )
|
||||||
set_committed_value(
|
set_committed_value(
|
||||||
entity,
|
entity,
|
||||||
self.prop.key,
|
self.prop.key,
|
||||||
@@ -268,9 +294,25 @@ class Fetcher(object):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def condition(self):
|
def condition(self):
|
||||||
return getattr(self.path.model, self.remote_column_name).in_(
|
names = self.remote_column_names
|
||||||
self.local_values_list
|
if len(names) == 1:
|
||||||
|
return getattr(self.path.model, names[0]).in_(
|
||||||
|
value[0] for value in self.local_values_list
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
conditions = []
|
||||||
|
for entity in self.path.entities:
|
||||||
|
conditions.append(
|
||||||
|
sa.and_(
|
||||||
|
*[
|
||||||
|
getattr(self.path.model, remote.name)
|
||||||
|
==
|
||||||
|
getattr(entity, local.name)
|
||||||
|
for local, remote in self.prop.local_remote_pairs
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return sa.or_(*conditions)
|
||||||
|
|
||||||
def fetch(self):
|
def fetch(self):
|
||||||
for entity in self.related_entities:
|
for entity in self.related_entities:
|
||||||
@@ -279,12 +321,40 @@ class Fetcher(object):
|
|||||||
|
|
||||||
class ManyToManyFetcher(Fetcher):
|
class ManyToManyFetcher(Fetcher):
|
||||||
@property
|
@property
|
||||||
def remote_column_name(self):
|
def remote_column_names(self):
|
||||||
for column in self.prop.remote_side:
|
names = []
|
||||||
for fk in column.foreign_keys:
|
for local, remote in self.prop.local_remote_pairs:
|
||||||
|
for fk in remote.foreign_keys:
|
||||||
# TODO: make this support inherited tables
|
# TODO: make this support inherited tables
|
||||||
if fk.column.table == self.path.parent_model.__table__:
|
if fk.column.table == self.path.parent_model.__table__:
|
||||||
return fk.parent.name
|
names.append(fk.parent.name)
|
||||||
|
|
||||||
|
return names
|
||||||
|
|
||||||
|
@property
|
||||||
|
def condition(self):
|
||||||
|
if len(self.remote_column_names) == 1:
|
||||||
|
return (
|
||||||
|
getattr(self.prop.secondary.c, self.remote_column_names[0])
|
||||||
|
.in_(
|
||||||
|
[value[0] for value in self.local_values_list]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
conditions = []
|
||||||
|
for entity in self.path.entities:
|
||||||
|
conditions.append(
|
||||||
|
sa.and_(
|
||||||
|
*[
|
||||||
|
getattr(self.prop.secondary.c, remote.name)
|
||||||
|
==
|
||||||
|
getattr(entity, local.name)
|
||||||
|
for local, remote in self.prop.local_remote_pairs
|
||||||
|
if remote.name in self.remote_column_names
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return sa.or_(*conditions)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def related_entities(self):
|
def related_entities(self):
|
||||||
@@ -292,22 +362,23 @@ class ManyToManyFetcher(Fetcher):
|
|||||||
self.path.session
|
self.path.session
|
||||||
.query(
|
.query(
|
||||||
self.path.model,
|
self.path.model,
|
||||||
getattr(self.prop.secondary.c, self.remote_column_name)
|
*[
|
||||||
|
getattr(self.prop.secondary.c, name)
|
||||||
|
for name in self.remote_column_names
|
||||||
|
]
|
||||||
)
|
)
|
||||||
.join(
|
.join(
|
||||||
self.prop.secondary, self.prop.secondaryjoin
|
self.prop.secondary, self.prop.secondaryjoin
|
||||||
)
|
)
|
||||||
.filter(
|
.filter(
|
||||||
getattr(self.prop.secondary.c, self.remote_column_name).in_(
|
self.condition
|
||||||
self.local_values_list
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def fetch(self):
|
def fetch(self):
|
||||||
for entity, parent_id in self.related_entities:
|
for value in self.related_entities:
|
||||||
self.parent_dict[parent_id].append(
|
self.parent_dict[tuple(value[1:])].append(
|
||||||
entity
|
value[0]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -317,11 +388,38 @@ class ManyToOneFetcher(Fetcher):
|
|||||||
self.parent_dict = defaultdict(lambda: None)
|
self.parent_dict = defaultdict(lambda: None)
|
||||||
|
|
||||||
def append_entity(self, entity):
|
def append_entity(self, entity):
|
||||||
self.parent_dict[getattr(entity, self.remote_column_name)] = entity
|
#print 'appending entity ', entity, ' to key ', self.parent_key(entity)
|
||||||
|
self.parent_dict[self.parent_key(entity)] = entity
|
||||||
|
|
||||||
|
@property
|
||||||
|
def remote_column_names(self):
|
||||||
|
names = []
|
||||||
|
for local, remote in self.prop.local_remote_pairs:
|
||||||
|
names.append(remote.name)
|
||||||
|
return names
|
||||||
|
|
||||||
|
@property
|
||||||
|
def local_column_names(self):
|
||||||
|
names = []
|
||||||
|
for local, remote in self.prop.local_remote_pairs:
|
||||||
|
names.append(local.name)
|
||||||
|
return names
|
||||||
|
|
||||||
|
|
||||||
class OneToManyFetcher(Fetcher):
|
class OneToManyFetcher(Fetcher):
|
||||||
def append_entity(self, entity):
|
def append_entity(self, entity):
|
||||||
self.parent_dict[getattr(entity, self.remote_column_name)].append(
|
#print 'appending entity ', entity, ' to key ', self.parent_key(entity)
|
||||||
|
self.parent_dict[self.parent_key(entity)].append(
|
||||||
entity
|
entity
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def remote_column_names(self):
|
||||||
|
names = []
|
||||||
|
for local, remote in self.prop.local_remote_pairs:
|
||||||
|
for fk in remote.foreign_keys:
|
||||||
|
# TODO: make this support inherited tables
|
||||||
|
if fk.column.table == self.path.parent_model.__table__:
|
||||||
|
names.append(fk.parent.name)
|
||||||
|
|
||||||
|
return names
|
||||||
|
111
tests/batch_fetch/test_many_to_many_composite_keys.py
Normal file
111
tests/batch_fetch/test_many_to_many_composite_keys.py
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy_utils import batch_fetch, with_backrefs
|
||||||
|
from tests import TestCase
|
||||||
|
|
||||||
|
|
||||||
|
class TestBatchFetchManyToManyCompositeRelationships(TestCase):
|
||||||
|
def create_models(self):
|
||||||
|
class Article(self.Base):
|
||||||
|
__tablename__ = 'article'
|
||||||
|
id1 = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
id2 = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
|
||||||
|
article_tag = sa.Table(
|
||||||
|
'article_tag',
|
||||||
|
self.Base.metadata,
|
||||||
|
sa.Column(
|
||||||
|
'article_id1',
|
||||||
|
sa.Integer,
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
'article_id2',
|
||||||
|
sa.Integer,
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
'tag_id1',
|
||||||
|
sa.Integer,
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
'tag_id2',
|
||||||
|
sa.Integer,
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
['article_id1', 'article_id2'],
|
||||||
|
['article.id1', 'article.id2']
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
['tag_id1', 'tag_id2'],
|
||||||
|
['tag.id1', 'tag.id2']
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
class Tag(self.Base):
|
||||||
|
__tablename__ = 'tag'
|
||||||
|
id1 = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
id2 = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
articles = sa.orm.relationship(
|
||||||
|
Article,
|
||||||
|
secondary=article_tag,
|
||||||
|
backref=sa.orm.backref(
|
||||||
|
'tags',
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.Article = Article
|
||||||
|
self.Tag = Tag
|
||||||
|
|
||||||
|
def setup_method(self, method):
|
||||||
|
TestCase.setup_method(self, method)
|
||||||
|
articles = [
|
||||||
|
self.Article(id1=1, id2=2, name=u'Article 1'),
|
||||||
|
self.Article(id1=2, id2=2, name=u'Article 2'),
|
||||||
|
self.Article(id1=3, id2=3, name=u'Article 3'),
|
||||||
|
self.Article(id1=4, id2=3, name=u'Article 4'),
|
||||||
|
self.Article(id1=5, id2=3, name=u'Article 5')
|
||||||
|
]
|
||||||
|
self.session.add_all(articles)
|
||||||
|
self.session.flush()
|
||||||
|
|
||||||
|
tags = [
|
||||||
|
self.Tag(id1=1, id2=2, name=u'Tag 1'),
|
||||||
|
self.Tag(id1=2, id2=3, name=u'Tag 2'),
|
||||||
|
self.Tag(id1=3, id2=4, name=u'Tag 3')
|
||||||
|
]
|
||||||
|
articles[0].tags = tags
|
||||||
|
articles[3].tags = tags[1:]
|
||||||
|
self.session.commit()
|
||||||
|
|
||||||
|
def test_deep_relationships(self):
|
||||||
|
articles = (
|
||||||
|
self.session.query(self.Article)
|
||||||
|
.order_by(self.Article.id1).all()
|
||||||
|
)
|
||||||
|
batch_fetch(
|
||||||
|
articles,
|
||||||
|
'tags'
|
||||||
|
)
|
||||||
|
query_count = self.connection.query_count
|
||||||
|
assert articles[0].tags
|
||||||
|
articles[1].tags
|
||||||
|
assert articles[3].tags
|
||||||
|
assert self.connection.query_count == query_count
|
||||||
|
|
||||||
|
def test_many_to_many_backref_population(self):
|
||||||
|
articles = (
|
||||||
|
self.session.query(self.Article)
|
||||||
|
.order_by(self.Article.id1).all()
|
||||||
|
)
|
||||||
|
batch_fetch(
|
||||||
|
articles,
|
||||||
|
with_backrefs('tags'),
|
||||||
|
)
|
||||||
|
query_count = self.connection.query_count
|
||||||
|
tags = articles[0].tags
|
||||||
|
tags2 = articles[3].tags
|
||||||
|
tags[0].articles
|
||||||
|
tags2[0].articles
|
||||||
|
names = [article.name for article in tags[0].articles]
|
||||||
|
assert u'Article 1' in names
|
||||||
|
assert self.connection.query_count == query_count
|
Reference in New Issue
Block a user