diff --git a/sqlalchemy_utils/functions/batch_fetch.py b/sqlalchemy_utils/functions/batch_fetch.py index fc73282..c1afac1 100644 --- a/sqlalchemy_utils/functions/batch_fetch.py +++ b/sqlalchemy_utils/functions/batch_fetch.py @@ -201,7 +201,10 @@ class CompositeFetcher(object): def fetch(self): for entity in self.related_entities: for fetcher in self.fetchers: - if getattr(entity, fetcher.remote_column_name) is not None: + if any( + getattr(entity, name) + for name in fetcher.remote_column_names + ): fetcher.append_entity(entity) def populate(self): @@ -227,29 +230,46 @@ class Fetcher(object): return self.path.session.query(self.path.model).filter(self.condition) @property - def remote_column_name(self): - return list(self.path.property.remote_side)[0].name + def local_column_names(self): + names = [] + for local, remote in self.prop.local_remote_pairs: + for fk in remote.foreign_keys: + # TODO: make this support inherited tables + if fk.column.table in self.prop.parent.tables: + names.append(local.name) + return names + + def parent_key(self, entity): + return tuple( + getattr(entity, name) + for name in self.remote_column_names + ) def local_values(self, entity): - return getattr(entity, list(self.prop.local_columns)[0].name) + return tuple( + getattr(entity, name) + for name in self.local_column_names + ) def populate_backrefs(self, related_entities): """ Populates backrefs for given related entities. """ backref_dict = dict( - (self.local_values(entity), []) - for entity, parent_id in related_entities + (self.local_values(value[0]), []) + for value in related_entities ) - for entity, parent_id in related_entities: - backref_dict[self.local_values(entity)].append( - self.path.session.query(self.path.parent_model).get(parent_id) + for value in related_entities: + backref_dict[self.local_values(value[0])].append( + self.path.session.query(self.path.parent_model).get( + tuple(value[1:]) + ) ) - for entity, parent_id in related_entities: + for value in related_entities: set_committed_value( - entity, + value[0], self.prop.back_populates, - backref_dict[self.local_values(entity)] + backref_dict[self.local_values(value[0])] ) def populate(self): @@ -257,6 +277,12 @@ class Fetcher(object): Populate batch fetched entities to parent objects. """ for entity in self.path.entities: + # print ( + # "setting committed value for ", + # entity, + # " using local values ", + # self.local_values(entity) + # ) set_committed_value( entity, self.prop.key, @@ -268,9 +294,25 @@ class Fetcher(object): @property def condition(self): - return getattr(self.path.model, self.remote_column_name).in_( - self.local_values_list - ) + names = self.remote_column_names + if len(names) == 1: + return getattr(self.path.model, names[0]).in_( + value[0] for value in self.local_values_list + ) + else: + conditions = [] + for entity in self.path.entities: + conditions.append( + sa.and_( + *[ + getattr(self.path.model, remote.name) + == + getattr(entity, local.name) + for local, remote in self.prop.local_remote_pairs + ] + ) + ) + return sa.or_(*conditions) def fetch(self): for entity in self.related_entities: @@ -279,12 +321,40 @@ class Fetcher(object): class ManyToManyFetcher(Fetcher): @property - def remote_column_name(self): - for column in self.prop.remote_side: - for fk in column.foreign_keys: + def remote_column_names(self): + names = [] + for local, remote in self.prop.local_remote_pairs: + for fk in remote.foreign_keys: # TODO: make this support inherited tables if fk.column.table == self.path.parent_model.__table__: - return fk.parent.name + names.append(fk.parent.name) + + return names + + @property + def condition(self): + if len(self.remote_column_names) == 1: + return ( + getattr(self.prop.secondary.c, self.remote_column_names[0]) + .in_( + [value[0] for value in self.local_values_list] + ) + ) + else: + conditions = [] + for entity in self.path.entities: + conditions.append( + sa.and_( + *[ + getattr(self.prop.secondary.c, remote.name) + == + getattr(entity, local.name) + for local, remote in self.prop.local_remote_pairs + if remote.name in self.remote_column_names + ] + ) + ) + return sa.or_(*conditions) @property def related_entities(self): @@ -292,22 +362,23 @@ class ManyToManyFetcher(Fetcher): self.path.session .query( self.path.model, - getattr(self.prop.secondary.c, self.remote_column_name) + *[ + getattr(self.prop.secondary.c, name) + for name in self.remote_column_names + ] ) .join( self.prop.secondary, self.prop.secondaryjoin ) .filter( - getattr(self.prop.secondary.c, self.remote_column_name).in_( - self.local_values_list - ) + self.condition ) ) def fetch(self): - for entity, parent_id in self.related_entities: - self.parent_dict[parent_id].append( - entity + for value in self.related_entities: + self.parent_dict[tuple(value[1:])].append( + value[0] ) @@ -317,11 +388,38 @@ class ManyToOneFetcher(Fetcher): self.parent_dict = defaultdict(lambda: None) def append_entity(self, entity): - self.parent_dict[getattr(entity, self.remote_column_name)] = entity + #print 'appending entity ', entity, ' to key ', self.parent_key(entity) + self.parent_dict[self.parent_key(entity)] = entity + + @property + def remote_column_names(self): + names = [] + for local, remote in self.prop.local_remote_pairs: + names.append(remote.name) + return names + + @property + def local_column_names(self): + names = [] + for local, remote in self.prop.local_remote_pairs: + names.append(local.name) + return names class OneToManyFetcher(Fetcher): def append_entity(self, entity): - self.parent_dict[getattr(entity, self.remote_column_name)].append( + #print 'appending entity ', entity, ' to key ', self.parent_key(entity) + self.parent_dict[self.parent_key(entity)].append( entity ) + + @property + def remote_column_names(self): + names = [] + for local, remote in self.prop.local_remote_pairs: + for fk in remote.foreign_keys: + # TODO: make this support inherited tables + if fk.column.table == self.path.parent_model.__table__: + names.append(fk.parent.name) + + return names diff --git a/tests/batch_fetch/test_many_to_many_composite_keys.py b/tests/batch_fetch/test_many_to_many_composite_keys.py new file mode 100644 index 0000000..0d19578 --- /dev/null +++ b/tests/batch_fetch/test_many_to_many_composite_keys.py @@ -0,0 +1,111 @@ +import sqlalchemy as sa +from sqlalchemy_utils import batch_fetch, with_backrefs +from tests import TestCase + + +class TestBatchFetchManyToManyCompositeRelationships(TestCase): + def create_models(self): + class Article(self.Base): + __tablename__ = 'article' + id1 = sa.Column(sa.Integer, primary_key=True) + id2 = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + article_tag = sa.Table( + 'article_tag', + self.Base.metadata, + sa.Column( + 'article_id1', + sa.Integer, + ), + sa.Column( + 'article_id2', + sa.Integer, + ), + sa.Column( + 'tag_id1', + sa.Integer, + ), + sa.Column( + 'tag_id2', + sa.Integer, + ), + sa.ForeignKeyConstraint( + ['article_id1', 'article_id2'], + ['article.id1', 'article.id2'] + ), + sa.ForeignKeyConstraint( + ['tag_id1', 'tag_id2'], + ['tag.id1', 'tag.id2'] + ) + ) + + class Tag(self.Base): + __tablename__ = 'tag' + id1 = sa.Column(sa.Integer, primary_key=True) + id2 = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + articles = sa.orm.relationship( + Article, + secondary=article_tag, + backref=sa.orm.backref( + 'tags', + ), + ) + + self.Article = Article + self.Tag = Tag + + def setup_method(self, method): + TestCase.setup_method(self, method) + articles = [ + self.Article(id1=1, id2=2, name=u'Article 1'), + self.Article(id1=2, id2=2, name=u'Article 2'), + self.Article(id1=3, id2=3, name=u'Article 3'), + self.Article(id1=4, id2=3, name=u'Article 4'), + self.Article(id1=5, id2=3, name=u'Article 5') + ] + self.session.add_all(articles) + self.session.flush() + + tags = [ + self.Tag(id1=1, id2=2, name=u'Tag 1'), + self.Tag(id1=2, id2=3, name=u'Tag 2'), + self.Tag(id1=3, id2=4, name=u'Tag 3') + ] + articles[0].tags = tags + articles[3].tags = tags[1:] + self.session.commit() + + def test_deep_relationships(self): + articles = ( + self.session.query(self.Article) + .order_by(self.Article.id1).all() + ) + batch_fetch( + articles, + 'tags' + ) + query_count = self.connection.query_count + assert articles[0].tags + articles[1].tags + assert articles[3].tags + assert self.connection.query_count == query_count + + def test_many_to_many_backref_population(self): + articles = ( + self.session.query(self.Article) + .order_by(self.Article.id1).all() + ) + batch_fetch( + articles, + with_backrefs('tags'), + ) + query_count = self.connection.query_count + tags = articles[0].tags + tags2 = articles[3].tags + tags[0].articles + tags2[0].articles + names = [article.name for article in tags[0].articles] + assert u'Article 1' in names + assert self.connection.query_count == query_count diff --git a/tests/batch_fetch/test_composite_keys.py b/tests/batch_fetch/test_one_to_many_composite_keys.py similarity index 100% rename from tests/batch_fetch/test_composite_keys.py rename to tests/batch_fetch/test_one_to_many_composite_keys.py