Added more tests and refactored batch fetch
This commit is contained in:
@@ -74,49 +74,17 @@ def batch_fetch(entities, *attr_paths):
|
||||
"""
|
||||
|
||||
if entities:
|
||||
fetcher = BatchFetcher(entities)
|
||||
fetcher = FetchingCoordinator(entities)
|
||||
for attr_path in attr_paths:
|
||||
fetcher(attr_path)
|
||||
|
||||
|
||||
class BatchFetcher(object):
|
||||
class FetchingCoordinator(object):
|
||||
def __init__(self, entities):
|
||||
self.entities = entities
|
||||
self.first = entities[0]
|
||||
self.parent_ids = [entity.id for entity in entities]
|
||||
self.session = object_session(self.first)
|
||||
|
||||
def populate_backrefs(self, related_entities):
|
||||
"""
|
||||
Populates backrefs for given related entities.
|
||||
"""
|
||||
|
||||
backref_dict = dict(
|
||||
(entity.id, []) for entity, parent_id in related_entities
|
||||
)
|
||||
for entity, parent_id in related_entities:
|
||||
backref_dict[entity.id].append(
|
||||
self.session.query(self.first.__class__).get(parent_id)
|
||||
)
|
||||
for entity, parent_id in related_entities:
|
||||
set_committed_value(
|
||||
entity, self.prop.back_populates, backref_dict[entity.id]
|
||||
)
|
||||
|
||||
def populate_entities(self):
|
||||
"""
|
||||
Populate batch fetched entities to parent objects.
|
||||
"""
|
||||
for entity in self.entities:
|
||||
set_committed_value(
|
||||
entity,
|
||||
self.prop.key,
|
||||
self.parent_dict[entity.id]
|
||||
)
|
||||
|
||||
if self.should_populate_backrefs:
|
||||
self.populate_backrefs(self.related_entities)
|
||||
|
||||
def parse_attr_path(self, attr_path, should_populate_backrefs):
|
||||
if isinstance(attr_path, six.string_types):
|
||||
attrs = attr_path.split('.')
|
||||
@@ -150,21 +118,97 @@ class BatchFetcher(object):
|
||||
'are supported.'
|
||||
)
|
||||
|
||||
column_name = list(self.prop.remote_side)[0].name
|
||||
|
||||
self.related_entities = (
|
||||
self.session.query(self.model)
|
||||
.filter(
|
||||
getattr(self.model, column_name).in_(self.parent_ids)
|
||||
def fetcher(self, property_):
|
||||
if not isinstance(property_, RelationshipProperty):
|
||||
raise Exception(
|
||||
'Given attribute is not a relationship property.'
|
||||
)
|
||||
|
||||
if property_.secondary is not None:
|
||||
return ManyToManyFetcher(self, property_)
|
||||
else:
|
||||
if property_.direction.name == 'MANYTOONE':
|
||||
return ManyToOneFetcher(self, property_)
|
||||
else:
|
||||
return OneToManyFetcher(self, property_)
|
||||
|
||||
def __call__(self, attr_path):
|
||||
if isinstance(attr_path, with_backrefs):
|
||||
self.should_populate_backrefs = True
|
||||
attr_path = attr_path.attr_path
|
||||
else:
|
||||
self.should_populate_backrefs = False
|
||||
|
||||
attr = self.parse_attr_path(attr_path, self.should_populate_backrefs)
|
||||
if not attr:
|
||||
return
|
||||
|
||||
fetcher = self.fetcher(attr.property)
|
||||
fetcher.fetch()
|
||||
fetcher.populate()
|
||||
|
||||
|
||||
class Fetcher(object):
|
||||
def __init__(self, coordinator, property_):
|
||||
self.coordinator = coordinator
|
||||
self.prop = property_
|
||||
self.model = self.prop.mapper.class_
|
||||
self.entities = coordinator.entities
|
||||
self.first = self.entities[0]
|
||||
self.session = object_session(self.first)
|
||||
|
||||
for entity in self.entities:
|
||||
self.parent_dict = dict(
|
||||
(self.local_values(entity), [])
|
||||
for entity in self.entities
|
||||
)
|
||||
|
||||
@property
|
||||
def local_values_list(self):
|
||||
return [
|
||||
self.local_values(entity)
|
||||
for entity in self.entities
|
||||
]
|
||||
|
||||
def local_values(self, entity):
|
||||
return getattr(entity, list(self.prop.local_columns)[0].name)
|
||||
|
||||
def populate_backrefs(self, related_entities):
|
||||
"""
|
||||
Populates backrefs for given related entities.
|
||||
"""
|
||||
|
||||
backref_dict = dict(
|
||||
(entity.id, []) for entity, parent_id in related_entities
|
||||
)
|
||||
|
||||
for entity in self.related_entities:
|
||||
self.parent_dict[getattr(entity, column_name)].append(
|
||||
entity
|
||||
for entity, parent_id in related_entities:
|
||||
backref_dict[entity.id].append(
|
||||
self.session.query(self.first.__class__).get(parent_id)
|
||||
)
|
||||
for entity, parent_id in related_entities:
|
||||
set_committed_value(
|
||||
entity, self.prop.back_populates, backref_dict[entity.id]
|
||||
)
|
||||
|
||||
def fetch_association_entities(self):
|
||||
def populate(self):
|
||||
"""
|
||||
Populate batch fetched entities to parent objects.
|
||||
"""
|
||||
for entity in self.entities:
|
||||
set_committed_value(
|
||||
entity,
|
||||
self.prop.key,
|
||||
self.parent_dict[self.local_values(entity)]
|
||||
)
|
||||
|
||||
if self.coordinator.should_populate_backrefs:
|
||||
self.populate_backrefs(self.related_entities)
|
||||
|
||||
|
||||
class ManyToManyFetcher(Fetcher):
|
||||
def fetch(self):
|
||||
parent_ids = [entity.id for entity in self.entities]
|
||||
|
||||
column_name = None
|
||||
for column in self.prop.remote_side:
|
||||
for fk in column.foreign_keys:
|
||||
@@ -183,7 +227,7 @@ class BatchFetcher(object):
|
||||
)
|
||||
.filter(
|
||||
getattr(self.prop.secondary.c, column_name).in_(
|
||||
self.parent_ids
|
||||
parent_ids
|
||||
)
|
||||
)
|
||||
)
|
||||
@@ -192,30 +236,38 @@ class BatchFetcher(object):
|
||||
entity
|
||||
)
|
||||
|
||||
def __call__(self, attr_path):
|
||||
self.parent_dict = dict(
|
||||
(entity.id, []) for entity in self.entities
|
||||
|
||||
class ManyToOneFetcher(Fetcher):
|
||||
def fetch(self):
|
||||
column_name = list(self.prop.remote_side)[0].name
|
||||
|
||||
self.related_entities = (
|
||||
self.session.query(self.model)
|
||||
.filter(
|
||||
getattr(self.model, column_name).in_(self.local_values_list)
|
||||
)
|
||||
)
|
||||
if isinstance(attr_path, with_backrefs):
|
||||
self.should_populate_backrefs = True
|
||||
attr_path = attr_path.attr_path
|
||||
else:
|
||||
self.should_populate_backrefs = False
|
||||
|
||||
attr = self.parse_attr_path(attr_path, self.should_populate_backrefs)
|
||||
if not attr:
|
||||
return
|
||||
|
||||
self.prop = attr.property
|
||||
if not isinstance(self.prop, RelationshipProperty):
|
||||
raise Exception(
|
||||
'Given attribute is not a relationship property.'
|
||||
for entity in self.related_entities:
|
||||
self.parent_dict[getattr(entity, column_name)].append(
|
||||
entity
|
||||
)
|
||||
|
||||
self.model = self.prop.mapper.class_
|
||||
|
||||
if self.prop.secondary is None:
|
||||
self.fetch_relation_entities()
|
||||
else:
|
||||
self.fetch_association_entities()
|
||||
self.populate_entities()
|
||||
class OneToManyFetcher(Fetcher):
|
||||
def fetch(self):
|
||||
parent_ids = [entity.id for entity in self.entities]
|
||||
|
||||
column_name = list(self.prop.remote_side)[0].name
|
||||
|
||||
self.related_entities = (
|
||||
self.session.query(self.model)
|
||||
.filter(
|
||||
getattr(self.model, column_name).in_(parent_ids)
|
||||
)
|
||||
)
|
||||
|
||||
for entity in self.related_entities:
|
||||
self.parent_dict[getattr(entity, column_name)].append(
|
||||
entity
|
||||
)
|
||||
|
@@ -29,8 +29,16 @@ class TestBatchFetchManyToOneRelationships(TestCase):
|
||||
def setup_method(self, method):
|
||||
TestCase.setup_method(self, method)
|
||||
articles = [
|
||||
self.Article(name=u'Article 1', author=self.User(name=u'John')),
|
||||
self.Article(name=u'Article 2', author=self.User(name=u'Matt')),
|
||||
self.Article(
|
||||
id=1,
|
||||
name=u'Article 1',
|
||||
author=self.User(id=333, name=u'John')
|
||||
),
|
||||
self.Article(
|
||||
id=2,
|
||||
name=u'Article 2',
|
||||
author=self.User(id=334, name=u'Matt')
|
||||
),
|
||||
]
|
||||
self.session.add_all(articles)
|
||||
self.session.commit()
|
||||
|
Reference in New Issue
Block a user