Added preliminary support for composite key batch fetching

This commit is contained in:
Konsta Vesterinen
2013-08-26 17:58:57 +03:00
parent f7307f3bcb
commit 523dacd0bb
3 changed files with 237 additions and 28 deletions

View File

@@ -201,7 +201,10 @@ class CompositeFetcher(object):
def fetch(self): def fetch(self):
for entity in self.related_entities: for entity in self.related_entities:
for fetcher in self.fetchers: for fetcher in self.fetchers:
if getattr(entity, fetcher.remote_column_name) is not None: if any(
getattr(entity, name)
for name in fetcher.remote_column_names
):
fetcher.append_entity(entity) fetcher.append_entity(entity)
def populate(self): def populate(self):
@@ -227,29 +230,46 @@ class Fetcher(object):
return self.path.session.query(self.path.model).filter(self.condition) return self.path.session.query(self.path.model).filter(self.condition)
@property @property
def remote_column_name(self): def local_column_names(self):
return list(self.path.property.remote_side)[0].name names = []
for local, remote in self.prop.local_remote_pairs:
for fk in remote.foreign_keys:
# TODO: make this support inherited tables
if fk.column.table in self.prop.parent.tables:
names.append(local.name)
return names
def parent_key(self, entity):
return tuple(
getattr(entity, name)
for name in self.remote_column_names
)
def local_values(self, entity): def local_values(self, entity):
return getattr(entity, list(self.prop.local_columns)[0].name) return tuple(
getattr(entity, name)
for name in self.local_column_names
)
def populate_backrefs(self, related_entities): def populate_backrefs(self, related_entities):
""" """
Populates backrefs for given related entities. Populates backrefs for given related entities.
""" """
backref_dict = dict( backref_dict = dict(
(self.local_values(entity), []) (self.local_values(value[0]), [])
for entity, parent_id in related_entities for value in related_entities
) )
for entity, parent_id in related_entities: for value in related_entities:
backref_dict[self.local_values(entity)].append( backref_dict[self.local_values(value[0])].append(
self.path.session.query(self.path.parent_model).get(parent_id) self.path.session.query(self.path.parent_model).get(
tuple(value[1:])
) )
for entity, parent_id in related_entities: )
for value in related_entities:
set_committed_value( set_committed_value(
entity, value[0],
self.prop.back_populates, self.prop.back_populates,
backref_dict[self.local_values(entity)] backref_dict[self.local_values(value[0])]
) )
def populate(self): def populate(self):
@@ -257,6 +277,12 @@ class Fetcher(object):
Populate batch fetched entities to parent objects. Populate batch fetched entities to parent objects.
""" """
for entity in self.path.entities: for entity in self.path.entities:
# print (
# "setting committed value for ",
# entity,
# " using local values ",
# self.local_values(entity)
# )
set_committed_value( set_committed_value(
entity, entity,
self.prop.key, self.prop.key,
@@ -268,9 +294,25 @@ class Fetcher(object):
@property @property
def condition(self): def condition(self):
return getattr(self.path.model, self.remote_column_name).in_( names = self.remote_column_names
self.local_values_list if len(names) == 1:
return getattr(self.path.model, names[0]).in_(
value[0] for value in self.local_values_list
) )
else:
conditions = []
for entity in self.path.entities:
conditions.append(
sa.and_(
*[
getattr(self.path.model, remote.name)
==
getattr(entity, local.name)
for local, remote in self.prop.local_remote_pairs
]
)
)
return sa.or_(*conditions)
def fetch(self): def fetch(self):
for entity in self.related_entities: for entity in self.related_entities:
@@ -279,12 +321,40 @@ class Fetcher(object):
class ManyToManyFetcher(Fetcher): class ManyToManyFetcher(Fetcher):
@property @property
def remote_column_name(self): def remote_column_names(self):
for column in self.prop.remote_side: names = []
for fk in column.foreign_keys: for local, remote in self.prop.local_remote_pairs:
for fk in remote.foreign_keys:
# TODO: make this support inherited tables # TODO: make this support inherited tables
if fk.column.table == self.path.parent_model.__table__: if fk.column.table == self.path.parent_model.__table__:
return fk.parent.name names.append(fk.parent.name)
return names
@property
def condition(self):
if len(self.remote_column_names) == 1:
return (
getattr(self.prop.secondary.c, self.remote_column_names[0])
.in_(
[value[0] for value in self.local_values_list]
)
)
else:
conditions = []
for entity in self.path.entities:
conditions.append(
sa.and_(
*[
getattr(self.prop.secondary.c, remote.name)
==
getattr(entity, local.name)
for local, remote in self.prop.local_remote_pairs
if remote.name in self.remote_column_names
]
)
)
return sa.or_(*conditions)
@property @property
def related_entities(self): def related_entities(self):
@@ -292,22 +362,23 @@ class ManyToManyFetcher(Fetcher):
self.path.session self.path.session
.query( .query(
self.path.model, self.path.model,
getattr(self.prop.secondary.c, self.remote_column_name) *[
getattr(self.prop.secondary.c, name)
for name in self.remote_column_names
]
) )
.join( .join(
self.prop.secondary, self.prop.secondaryjoin self.prop.secondary, self.prop.secondaryjoin
) )
.filter( .filter(
getattr(self.prop.secondary.c, self.remote_column_name).in_( self.condition
self.local_values_list
)
) )
) )
def fetch(self): def fetch(self):
for entity, parent_id in self.related_entities: for value in self.related_entities:
self.parent_dict[parent_id].append( self.parent_dict[tuple(value[1:])].append(
entity value[0]
) )
@@ -317,11 +388,38 @@ class ManyToOneFetcher(Fetcher):
self.parent_dict = defaultdict(lambda: None) self.parent_dict = defaultdict(lambda: None)
def append_entity(self, entity): def append_entity(self, entity):
self.parent_dict[getattr(entity, self.remote_column_name)] = entity #print 'appending entity ', entity, ' to key ', self.parent_key(entity)
self.parent_dict[self.parent_key(entity)] = entity
@property
def remote_column_names(self):
names = []
for local, remote in self.prop.local_remote_pairs:
names.append(remote.name)
return names
@property
def local_column_names(self):
names = []
for local, remote in self.prop.local_remote_pairs:
names.append(local.name)
return names
class OneToManyFetcher(Fetcher): class OneToManyFetcher(Fetcher):
def append_entity(self, entity): def append_entity(self, entity):
self.parent_dict[getattr(entity, self.remote_column_name)].append( #print 'appending entity ', entity, ' to key ', self.parent_key(entity)
self.parent_dict[self.parent_key(entity)].append(
entity entity
) )
@property
def remote_column_names(self):
names = []
for local, remote in self.prop.local_remote_pairs:
for fk in remote.foreign_keys:
# TODO: make this support inherited tables
if fk.column.table == self.path.parent_model.__table__:
names.append(fk.parent.name)
return names

View File

@@ -0,0 +1,111 @@
import sqlalchemy as sa
from sqlalchemy_utils import batch_fetch, with_backrefs
from tests import TestCase
class TestBatchFetchManyToManyCompositeRelationships(TestCase):
def create_models(self):
class Article(self.Base):
__tablename__ = 'article'
id1 = sa.Column(sa.Integer, primary_key=True)
id2 = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
article_tag = sa.Table(
'article_tag',
self.Base.metadata,
sa.Column(
'article_id1',
sa.Integer,
),
sa.Column(
'article_id2',
sa.Integer,
),
sa.Column(
'tag_id1',
sa.Integer,
),
sa.Column(
'tag_id2',
sa.Integer,
),
sa.ForeignKeyConstraint(
['article_id1', 'article_id2'],
['article.id1', 'article.id2']
),
sa.ForeignKeyConstraint(
['tag_id1', 'tag_id2'],
['tag.id1', 'tag.id2']
)
)
class Tag(self.Base):
__tablename__ = 'tag'
id1 = sa.Column(sa.Integer, primary_key=True)
id2 = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
articles = sa.orm.relationship(
Article,
secondary=article_tag,
backref=sa.orm.backref(
'tags',
),
)
self.Article = Article
self.Tag = Tag
def setup_method(self, method):
TestCase.setup_method(self, method)
articles = [
self.Article(id1=1, id2=2, name=u'Article 1'),
self.Article(id1=2, id2=2, name=u'Article 2'),
self.Article(id1=3, id2=3, name=u'Article 3'),
self.Article(id1=4, id2=3, name=u'Article 4'),
self.Article(id1=5, id2=3, name=u'Article 5')
]
self.session.add_all(articles)
self.session.flush()
tags = [
self.Tag(id1=1, id2=2, name=u'Tag 1'),
self.Tag(id1=2, id2=3, name=u'Tag 2'),
self.Tag(id1=3, id2=4, name=u'Tag 3')
]
articles[0].tags = tags
articles[3].tags = tags[1:]
self.session.commit()
def test_deep_relationships(self):
articles = (
self.session.query(self.Article)
.order_by(self.Article.id1).all()
)
batch_fetch(
articles,
'tags'
)
query_count = self.connection.query_count
assert articles[0].tags
articles[1].tags
assert articles[3].tags
assert self.connection.query_count == query_count
def test_many_to_many_backref_population(self):
articles = (
self.session.query(self.Article)
.order_by(self.Article.id1).all()
)
batch_fetch(
articles,
with_backrefs('tags'),
)
query_count = self.connection.query_count
tags = articles[0].tags
tags2 = articles[3].tags
tags[0].articles
tags2[0].articles
names = [article.name for article in tags[0].articles]
assert u'Article 1' in names
assert self.connection.query_count == query_count