Smarter local value checking for batch fetch

This commit is contained in:
Konsta Vesterinen
2013-08-17 19:04:06 +03:00
parent ee8b97f0f7
commit 32f84a5c5d

View File

@@ -157,11 +157,10 @@ class Fetcher(object):
self.first = self.entities[0] self.first = self.entities[0]
self.session = object_session(self.first) self.session = object_session(self.first)
for entity in self.entities: self.parent_dict = dict(
self.parent_dict = dict( (self.local_values(entity), [])
(self.local_values(entity), []) for entity in self.entities
for entity in self.entities )
)
@property @property
def local_values_list(self): def local_values_list(self):
@@ -177,17 +176,19 @@ class Fetcher(object):
""" """
Populates backrefs for given related entities. Populates backrefs for given related entities.
""" """
backref_dict = dict( backref_dict = dict(
(entity.id, []) for entity, parent_id in related_entities (self.local_values(entity), [])
for entity, parent_id in related_entities
) )
for entity, parent_id in related_entities: for entity, parent_id in related_entities:
backref_dict[entity.id].append( backref_dict[self.local_values(entity)].append(
self.session.query(self.first.__class__).get(parent_id) self.session.query(self.first.__class__).get(parent_id)
) )
for entity, parent_id in related_entities: for entity, parent_id in related_entities:
set_committed_value( set_committed_value(
entity, self.prop.back_populates, backref_dict[entity.id] entity,
self.prop.back_populates,
backref_dict[self.local_values(entity)]
) )
def populate(self): def populate(self):
@@ -207,8 +208,6 @@ class Fetcher(object):
class ManyToManyFetcher(Fetcher): class ManyToManyFetcher(Fetcher):
def fetch(self): def fetch(self):
parent_ids = [entity.id for entity in self.entities]
column_name = None column_name = None
for column in self.prop.remote_side: for column in self.prop.remote_side:
for fk in column.foreign_keys: for fk in column.foreign_keys:
@@ -227,7 +226,7 @@ class ManyToManyFetcher(Fetcher):
) )
.filter( .filter(
getattr(self.prop.secondary.c, column_name).in_( getattr(self.prop.secondary.c, column_name).in_(
parent_ids self.local_values_list
) )
) )
) )
@@ -256,14 +255,12 @@ class ManyToOneFetcher(Fetcher):
class OneToManyFetcher(Fetcher): class OneToManyFetcher(Fetcher):
def fetch(self): def fetch(self):
parent_ids = [entity.id for entity in self.entities]
column_name = list(self.prop.remote_side)[0].name column_name = list(self.prop.remote_side)[0].name
self.related_entities = ( self.related_entities = (
self.session.query(self.model) self.session.query(self.model)
.filter( .filter(
getattr(self.model, column_name).in_(parent_ids) getattr(self.model, column_name).in_(self.local_values_list)
) )
) )