From 32f84a5c5df2c7fa79c33b41747c687ca3fab518 Mon Sep 17 00:00:00 2001 From: Konsta Vesterinen Date: Sat, 17 Aug 2013 19:04:06 +0300 Subject: [PATCH] Smarter local value checking for batch fetch --- sqlalchemy_utils/functions/batch_fetch.py | 27 ++++++++++------------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/sqlalchemy_utils/functions/batch_fetch.py b/sqlalchemy_utils/functions/batch_fetch.py index 7e67375..7c8c1cf 100644 --- a/sqlalchemy_utils/functions/batch_fetch.py +++ b/sqlalchemy_utils/functions/batch_fetch.py @@ -157,11 +157,10 @@ class Fetcher(object): self.first = self.entities[0] self.session = object_session(self.first) - for entity in self.entities: - self.parent_dict = dict( - (self.local_values(entity), []) - for entity in self.entities - ) + self.parent_dict = dict( + (self.local_values(entity), []) + for entity in self.entities + ) @property def local_values_list(self): @@ -177,17 +176,19 @@ class Fetcher(object): """ Populates backrefs for given related entities. """ - backref_dict = dict( - (entity.id, []) for entity, parent_id in related_entities + (self.local_values(entity), []) + for entity, parent_id in related_entities ) for entity, parent_id in related_entities: - backref_dict[entity.id].append( + backref_dict[self.local_values(entity)].append( self.session.query(self.first.__class__).get(parent_id) ) for entity, parent_id in related_entities: set_committed_value( - entity, self.prop.back_populates, backref_dict[entity.id] + entity, + self.prop.back_populates, + backref_dict[self.local_values(entity)] ) def populate(self): @@ -207,8 +208,6 @@ class Fetcher(object): class ManyToManyFetcher(Fetcher): def fetch(self): - parent_ids = [entity.id for entity in self.entities] - column_name = None for column in self.prop.remote_side: for fk in column.foreign_keys: @@ -227,7 +226,7 @@ class ManyToManyFetcher(Fetcher): ) .filter( getattr(self.prop.secondary.c, column_name).in_( - parent_ids + self.local_values_list ) ) ) @@ -256,14 +255,12 @@ class ManyToOneFetcher(Fetcher): class OneToManyFetcher(Fetcher): def fetch(self): - parent_ids = [entity.id for entity in self.entities] - column_name = list(self.prop.remote_side)[0].name self.related_entities = ( self.session.query(self.model) .filter( - getattr(self.model, column_name).in_(parent_ids) + getattr(self.model, column_name).in_(self.local_values_list) ) )