Added support for compound one to many batch fetching

This commit is contained in:
Konsta Vesterinen
2013-08-22 13:46:37 +03:00
parent e197a56f53
commit 202e33eef4
3 changed files with 79 additions and 38 deletions

View File

@@ -154,16 +154,19 @@ class FetchingCoordinator(object):
fetchers.append(self.fetcher_for_attr_path(path)) fetchers.append(self.fetcher_for_attr_path(path))
fetcher = CompoundFetcher(*fetchers) fetcher = CompoundFetcher(*fetchers)
print fetcher.condition
else: else:
fetcher = self.fetcher_for_attr_path(attr_path) fetcher = self.fetcher_for_attr_path(attr_path)
if not fetcher: fetcher.fetch()
return fetcher.populate()
fetcher.fetch()
fetcher.populate()
class CompoundFetcher(object): class AbstractFetcher(object):
@property
def related_entities(self):
return self.session.query(self.model).filter(self.condition)
class CompoundFetcher(AbstractFetcher):
def __init__(self, *fetchers): def __init__(self, *fetchers):
if not all(fetchers[0].model == fetcher.model for fetcher in fetchers): if not all(fetchers[0].model == fetcher.model for fetcher in fetchers):
raise Exception( raise Exception(
@@ -172,18 +175,36 @@ class CompoundFetcher(object):
) )
self.fetchers = fetchers self.fetchers = fetchers
@property
def session(self):
return self.fetchers[0].session
@property
def model(self):
return self.fetchers[0].model
@property @property
def condition(self): def condition(self):
return sa.or_( return sa.or_(
*[fetcher.condition for fetcher in self.fetchers] *[fetcher.condition for fetcher in self.fetchers]
) )
@property def fetcher_for_entity(self, entity):
def local_values(self): for fetcher in self.fetchers:
pass if getattr(entity, fetcher.remote_column_name) is not None:
return fetcher
def fetch(self):
self.parent_dict = defaultdict(list)
for entity in self.related_entities:
self.fetcher_for_entity(entity).append_entity(entity)
def populate(self):
for fetcher in self.fetchers:
fetcher.populate()
class Fetcher(object): class Fetcher(AbstractFetcher):
def __init__(self, entities, property_, populate_backrefs=False): def __init__(self, entities, property_, populate_backrefs=False):
self.should_populate_backrefs = populate_backrefs self.should_populate_backrefs = populate_backrefs
self.entities = entities self.entities = entities
@@ -191,6 +212,7 @@ class Fetcher(object):
self.model = self.prop.mapper.class_ self.model = self.prop.mapper.class_
self.first = self.entities[0] self.first = self.entities[0]
self.session = object_session(self.first) self.session = object_session(self.first)
self.parent_dict = defaultdict(list)
@property @property
def local_values_list(self): def local_values_list(self):
@@ -245,9 +267,9 @@ class Fetcher(object):
self.local_values_list self.local_values_list
) )
@property def fetch(self):
def related_entities(self): for entity in self.related_entities:
return self.session.query(self.model).filter(self.condition) self.append_entity(entity)
class ManyToManyFetcher(Fetcher): class ManyToManyFetcher(Fetcher):
@@ -278,7 +300,6 @@ class ManyToManyFetcher(Fetcher):
) )
def fetch(self): def fetch(self):
self.parent_dict = defaultdict(list)
for entity, parent_id in self.related_entities: for entity, parent_id in self.related_entities:
self.parent_dict[parent_id].append( self.parent_dict[parent_id].append(
entity entity
@@ -286,16 +307,16 @@ class ManyToManyFetcher(Fetcher):
class ManyToOneFetcher(Fetcher): class ManyToOneFetcher(Fetcher):
def fetch(self): def __init__(self, entities, property_, populate_backrefs=False):
Fetcher.__init__(self, entities, property_, populate_backrefs)
self.parent_dict = defaultdict(lambda: None) self.parent_dict = defaultdict(lambda: None)
for entity in self.related_entities:
self.parent_dict[getattr(entity, self.remote_column_name)] = entity def append_entity(self, entity):
self.parent_dict[getattr(entity, self.remote_column_name)] = entity
class OneToManyFetcher(Fetcher): class OneToManyFetcher(Fetcher):
def fetch(self): def append_entity(self, entity):
self.parent_dict = defaultdict(list) self.parent_dict[getattr(entity, self.remote_column_name)].append(
for entity in self.related_entities: entity
self.parent_dict[getattr(entity, self.remote_column_name)].append( )
entity
)

View File

@@ -4,7 +4,7 @@ from sqlalchemy_utils.functions import compound_path
from tests import TestCase from tests import TestCase
class TestCompoundBatchFetching(TestCase): class TestCompoundOneToManyBatchFetching(TestCase):
def create_models(self): def create_models(self):
class Building(self.Base): class Building(self.Base):
__tablename__ = 'building' __tablename__ = 'building'
@@ -54,27 +54,33 @@ class TestCompoundBatchFetching(TestCase):
def setup_method(self, method): def setup_method(self, method):
TestCase.setup_method(self, method) TestCase.setup_method(self, method)
self.buildings = [ self.buildings = [
self.Building(name=u'B 1'), self.Building(id=12, name=u'B 1'),
self.Building(name=u'B 2'), self.Building(id=15, name=u'B 2'),
self.Building(name=u'B 3'), self.Building(id=19, name=u'B 3'),
] ]
self.business_premises = [ self.business_premises = [
self.BusinessPremise(name=u'BP 1', building=self.buildings[0]), self.BusinessPremise(
self.BusinessPremise(name=u'BP 2', building=self.buildings[0]), id=22, name=u'BP 1', building=self.buildings[0]
self.BusinessPremise(name=u'BP 3', building=self.buildings[2]), ),
self.BusinessPremise(
id=33, name=u'BP 2', building=self.buildings[0]
),
self.BusinessPremise(
id=44, name=u'BP 3', building=self.buildings[2]
),
] ]
self.equipment = [ self.equipment = [
self.Equipment( self.Equipment(
name=u'E 1', building=self.buildings[0] id=2, name=u'E 1', building=self.buildings[0]
), ),
self.Equipment( self.Equipment(
name=u'E 2', building=self.buildings[2] id=4, name=u'E 2', building=self.buildings[2]
), ),
self.Equipment( self.Equipment(
name=u'E 3', business_premise=self.business_premises[0] id=6, name=u'E 3', business_premise=self.business_premises[0]
), ),
self.Equipment( self.Equipment(
name=u'E 4', business_premise=self.business_premises[2] id=8, name=u'E 4', business_premise=self.business_premises[2]
), ),
] ]
self.session.add_all(self.buildings) self.session.add_all(self.buildings)
@@ -94,7 +100,10 @@ class TestCompoundBatchFetching(TestCase):
) )
query_count = self.connection.query_count query_count = self.connection.query_count
buildings[0].equipment assert len(buildings[0].equipment) == 1
buildings[1].equipment assert buildings[0].equipment[0].name == 'E 1'
buildings[0].business_premises[0].equipment assert not buildings[1].equipment
assert buildings[0].business_premises[0].equipment
assert self.business_premises[2].equipment
assert self.business_premises[2].equipment[0].name == 'E 4'
assert self.connection.query_count == query_count assert self.connection.query_count == query_count

View File

@@ -63,5 +63,16 @@ class TestBatchFetchOneToManyRelationships(TestCase):
categories = self.session.query(self.Category).all() categories = self.session.query(self.Category).all()
batch_fetch(categories, self.Category.articles) batch_fetch(categories, self.Category.articles)
query_count = self.connection.query_count query_count = self.connection.query_count
categories[0].articles # no lazy load should occur articles = categories[0].articles # no lazy load should occur
assert len(articles) == 2
article_names = [article.name for article in articles]
assert 'Article 1' in article_names
assert 'Article 2' in article_names
articles = categories[1].articles # no lazy load should occur
assert len(articles) == 3
article_names = [article.name for article in articles]
assert 'Article 3' in article_names
assert 'Article 4' in article_names
assert 'Article 5' in article_names
assert self.connection.query_count == query_count assert self.connection.query_count == query_count