diff --git a/sqlalchemy_utils/functions/batch_fetch.py b/sqlalchemy_utils/functions/batch_fetch.py index ec67dc3..d8b8a58 100644 --- a/sqlalchemy_utils/functions/batch_fetch.py +++ b/sqlalchemy_utils/functions/batch_fetch.py @@ -189,15 +189,11 @@ class CompoundFetcher(AbstractFetcher): *[fetcher.condition for fetcher in self.fetchers] ) - def fetcher_for_entity(self, entity): - for fetcher in self.fetchers: - if getattr(entity, fetcher.remote_column_name) is not None: - return fetcher - def fetch(self): - self.parent_dict = defaultdict(list) for entity in self.related_entities: - self.fetcher_for_entity(entity).append_entity(entity) + for fetcher in self.fetchers: + if getattr(entity, fetcher.remote_column_name) is not None: + fetcher.append_entity(entity) def populate(self): for fetcher in self.fetchers: diff --git a/tests/batch_fetch/test_compound_fetching.py b/tests/batch_fetch/test_compound_fetching.py index fa93843..329e5b0 100644 --- a/tests/batch_fetch/test_compound_fetching.py +++ b/tests/batch_fetch/test_compound_fetching.py @@ -107,3 +107,105 @@ class TestCompoundOneToManyBatchFetching(TestCase): assert self.business_premises[2].equipment assert self.business_premises[2].equipment[0].name == 'E 4' assert self.connection.query_count == query_count + + +class TestCompoundManyToOneBatchFetching(TestCase): + def create_models(self): + class Equipment(self.Base): + __tablename__ = 'equipment' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + class Building(self.Base): + __tablename__ = 'building' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + equipment_id = sa.Column(sa.Integer, sa.ForeignKey(Equipment.id)) + + equipment = sa.orm.relationship(Equipment) + + class BusinessPremise(self.Base): + __tablename__ = 'business_premise' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + building_id = sa.Column(sa.Integer, sa.ForeignKey(Building.id)) + + building = sa.orm.relationship( + Building, + backref=sa.orm.backref( + 'business_premises' + ) + ) + + equipment_id = sa.Column(sa.Integer, sa.ForeignKey(Equipment.id)) + + equipment = sa.orm.relationship(Equipment) + + self.Building = Building + self.BusinessPremise = BusinessPremise + self.Equipment = Equipment + + def setup_method(self, method): + TestCase.setup_method(self, method) + self.equipment = [ + self.Equipment( + id=2, name=u'E 1', + ), + self.Equipment( + id=4, name=u'E 2', + ), + self.Equipment( + id=6, name=u'E 3', + ), + self.Equipment( + id=8, name=u'E 4', + ), + ] + self.buildings = [ + self.Building(id=12, name=u'B 1', equipment=self.equipment[0]), + self.Building(id=15, name=u'B 2', equipment=self.equipment[1]), + self.Building(id=19, name=u'B 3'), + ] + self.business_premises = [ + self.BusinessPremise( + id=22, + name=u'BP 1', + building=self.buildings[0] + ), + self.BusinessPremise( + id=33, + name=u'BP 2', + building=self.buildings[0], + equipment=self.equipment[2] + ), + self.BusinessPremise( + id=44, + name=u'BP 3', + building=self.buildings[2], + equipment=self.equipment[1] + ), + ] + + self.session.add_all(self.buildings) + self.session.add_all(self.business_premises) + self.session.add_all(self.equipment) + self.session.commit() + + def test_compound_fetching(self): + buildings = self.session.query(self.Building).all() + batch_fetch( + buildings, + 'business_premises', + compound_path( + 'equipment', + 'business_premises.equipment' + ) + ) + query_count = self.connection.query_count + + assert buildings[0].equipment.name == 'E 1' + assert buildings[1].equipment.name == 'E 2' + assert buildings[2].business_premises[0].equipment.name == 'E 2' + assert self.connection.query_count == query_count