From 202e33eef450beb2e6873ef21d424fe165584e95 Mon Sep 17 00:00:00 2001 From: Konsta Vesterinen Date: Thu, 22 Aug 2013 13:46:37 +0300 Subject: [PATCH] Added support for compound one to many batch fetching --- sqlalchemy_utils/functions/batch_fetch.py | 67 ++++++++++++------- tests/batch_fetch/test_compound_fetching.py | 37 ++++++---- .../test_one_to_many_relationships.py | 13 +++- 3 files changed, 79 insertions(+), 38 deletions(-) diff --git a/sqlalchemy_utils/functions/batch_fetch.py b/sqlalchemy_utils/functions/batch_fetch.py index 5dc4dee..ec67dc3 100644 --- a/sqlalchemy_utils/functions/batch_fetch.py +++ b/sqlalchemy_utils/functions/batch_fetch.py @@ -154,16 +154,19 @@ class FetchingCoordinator(object): fetchers.append(self.fetcher_for_attr_path(path)) fetcher = CompoundFetcher(*fetchers) - print fetcher.condition else: fetcher = self.fetcher_for_attr_path(attr_path) - if not fetcher: - return - fetcher.fetch() - fetcher.populate() + fetcher.fetch() + fetcher.populate() -class CompoundFetcher(object): +class AbstractFetcher(object): + @property + def related_entities(self): + return self.session.query(self.model).filter(self.condition) + + +class CompoundFetcher(AbstractFetcher): def __init__(self, *fetchers): if not all(fetchers[0].model == fetcher.model for fetcher in fetchers): raise Exception( @@ -172,18 +175,36 @@ class CompoundFetcher(object): ) self.fetchers = fetchers + @property + def session(self): + return self.fetchers[0].session + + @property + def model(self): + return self.fetchers[0].model + @property def condition(self): return sa.or_( *[fetcher.condition for fetcher in self.fetchers] ) - @property - def local_values(self): - pass + def fetcher_for_entity(self, entity): + for fetcher in self.fetchers: + if getattr(entity, fetcher.remote_column_name) is not None: + return fetcher + + def fetch(self): + self.parent_dict = defaultdict(list) + for entity in self.related_entities: + self.fetcher_for_entity(entity).append_entity(entity) + + def populate(self): + for fetcher in self.fetchers: + fetcher.populate() -class Fetcher(object): +class Fetcher(AbstractFetcher): def __init__(self, entities, property_, populate_backrefs=False): self.should_populate_backrefs = populate_backrefs self.entities = entities @@ -191,6 +212,7 @@ class Fetcher(object): self.model = self.prop.mapper.class_ self.first = self.entities[0] self.session = object_session(self.first) + self.parent_dict = defaultdict(list) @property def local_values_list(self): @@ -245,9 +267,9 @@ class Fetcher(object): self.local_values_list ) - @property - def related_entities(self): - return self.session.query(self.model).filter(self.condition) + def fetch(self): + for entity in self.related_entities: + self.append_entity(entity) class ManyToManyFetcher(Fetcher): @@ -278,7 +300,6 @@ class ManyToManyFetcher(Fetcher): ) def fetch(self): - self.parent_dict = defaultdict(list) for entity, parent_id in self.related_entities: self.parent_dict[parent_id].append( entity @@ -286,16 +307,16 @@ class ManyToManyFetcher(Fetcher): class ManyToOneFetcher(Fetcher): - def fetch(self): + def __init__(self, entities, property_, populate_backrefs=False): + Fetcher.__init__(self, entities, property_, populate_backrefs) self.parent_dict = defaultdict(lambda: None) - for entity in self.related_entities: - self.parent_dict[getattr(entity, self.remote_column_name)] = entity + + def append_entity(self, entity): + self.parent_dict[getattr(entity, self.remote_column_name)] = entity class OneToManyFetcher(Fetcher): - def fetch(self): - self.parent_dict = defaultdict(list) - for entity in self.related_entities: - self.parent_dict[getattr(entity, self.remote_column_name)].append( - entity - ) + def append_entity(self, entity): + self.parent_dict[getattr(entity, self.remote_column_name)].append( + entity + ) diff --git a/tests/batch_fetch/test_compound_fetching.py b/tests/batch_fetch/test_compound_fetching.py index ca6784f..fa93843 100644 --- a/tests/batch_fetch/test_compound_fetching.py +++ b/tests/batch_fetch/test_compound_fetching.py @@ -4,7 +4,7 @@ from sqlalchemy_utils.functions import compound_path from tests import TestCase -class TestCompoundBatchFetching(TestCase): +class TestCompoundOneToManyBatchFetching(TestCase): def create_models(self): class Building(self.Base): __tablename__ = 'building' @@ -54,27 +54,33 @@ class TestCompoundBatchFetching(TestCase): def setup_method(self, method): TestCase.setup_method(self, method) self.buildings = [ - self.Building(name=u'B 1'), - self.Building(name=u'B 2'), - self.Building(name=u'B 3'), + self.Building(id=12, name=u'B 1'), + self.Building(id=15, name=u'B 2'), + self.Building(id=19, name=u'B 3'), ] self.business_premises = [ - self.BusinessPremise(name=u'BP 1', building=self.buildings[0]), - self.BusinessPremise(name=u'BP 2', building=self.buildings[0]), - self.BusinessPremise(name=u'BP 3', building=self.buildings[2]), + self.BusinessPremise( + id=22, name=u'BP 1', building=self.buildings[0] + ), + self.BusinessPremise( + id=33, name=u'BP 2', building=self.buildings[0] + ), + self.BusinessPremise( + id=44, name=u'BP 3', building=self.buildings[2] + ), ] self.equipment = [ self.Equipment( - name=u'E 1', building=self.buildings[0] + id=2, name=u'E 1', building=self.buildings[0] ), self.Equipment( - name=u'E 2', building=self.buildings[2] + id=4, name=u'E 2', building=self.buildings[2] ), self.Equipment( - name=u'E 3', business_premise=self.business_premises[0] + id=6, name=u'E 3', business_premise=self.business_premises[0] ), self.Equipment( - name=u'E 4', business_premise=self.business_premises[2] + id=8, name=u'E 4', business_premise=self.business_premises[2] ), ] self.session.add_all(self.buildings) @@ -94,7 +100,10 @@ class TestCompoundBatchFetching(TestCase): ) query_count = self.connection.query_count - buildings[0].equipment - buildings[1].equipment - buildings[0].business_premises[0].equipment + assert len(buildings[0].equipment) == 1 + assert buildings[0].equipment[0].name == 'E 1' + assert not buildings[1].equipment + assert buildings[0].business_premises[0].equipment + assert self.business_premises[2].equipment + assert self.business_premises[2].equipment[0].name == 'E 4' assert self.connection.query_count == query_count diff --git a/tests/batch_fetch/test_one_to_many_relationships.py b/tests/batch_fetch/test_one_to_many_relationships.py index 6ae00ea..876ae30 100644 --- a/tests/batch_fetch/test_one_to_many_relationships.py +++ b/tests/batch_fetch/test_one_to_many_relationships.py @@ -63,5 +63,16 @@ class TestBatchFetchOneToManyRelationships(TestCase): categories = self.session.query(self.Category).all() batch_fetch(categories, self.Category.articles) query_count = self.connection.query_count - categories[0].articles # no lazy load should occur + articles = categories[0].articles # no lazy load should occur + assert len(articles) == 2 + article_names = [article.name for article in articles] + + assert 'Article 1' in article_names + assert 'Article 2' in article_names + articles = categories[1].articles # no lazy load should occur + assert len(articles) == 3 + article_names = [article.name for article in articles] + assert 'Article 3' in article_names + assert 'Article 4' in article_names + assert 'Article 5' in article_names assert self.connection.query_count == query_count