Added support for compound one to many batch fetching
This commit is contained in:
		| @@ -154,16 +154,19 @@ class FetchingCoordinator(object): | ||||
|                 fetchers.append(self.fetcher_for_attr_path(path)) | ||||
|  | ||||
|             fetcher = CompoundFetcher(*fetchers) | ||||
|             print fetcher.condition | ||||
|         else: | ||||
|             fetcher = self.fetcher_for_attr_path(attr_path) | ||||
|             if not fetcher: | ||||
|                 return | ||||
|             fetcher.fetch() | ||||
|             fetcher.populate() | ||||
|         fetcher.fetch() | ||||
|         fetcher.populate() | ||||
|  | ||||
|  | ||||
| class CompoundFetcher(object): | ||||
| class AbstractFetcher(object): | ||||
|     @property | ||||
|     def related_entities(self): | ||||
|         return self.session.query(self.model).filter(self.condition) | ||||
|  | ||||
|  | ||||
| class CompoundFetcher(AbstractFetcher): | ||||
|     def __init__(self, *fetchers): | ||||
|         if not all(fetchers[0].model == fetcher.model for fetcher in fetchers): | ||||
|             raise Exception( | ||||
| @@ -172,18 +175,36 @@ class CompoundFetcher(object): | ||||
|             ) | ||||
|         self.fetchers = fetchers | ||||
|  | ||||
|     @property | ||||
|     def session(self): | ||||
|         return self.fetchers[0].session | ||||
|  | ||||
|     @property | ||||
|     def model(self): | ||||
|         return self.fetchers[0].model | ||||
|  | ||||
|     @property | ||||
|     def condition(self): | ||||
|         return sa.or_( | ||||
|             *[fetcher.condition for fetcher in self.fetchers] | ||||
|         ) | ||||
|  | ||||
|     @property | ||||
|     def local_values(self): | ||||
|         pass | ||||
|     def fetcher_for_entity(self, entity): | ||||
|         for fetcher in self.fetchers: | ||||
|             if getattr(entity, fetcher.remote_column_name) is not None: | ||||
|                 return fetcher | ||||
|  | ||||
|     def fetch(self): | ||||
|         self.parent_dict = defaultdict(list) | ||||
|         for entity in self.related_entities: | ||||
|             self.fetcher_for_entity(entity).append_entity(entity) | ||||
|  | ||||
|     def populate(self): | ||||
|         for fetcher in self.fetchers: | ||||
|             fetcher.populate() | ||||
|  | ||||
|  | ||||
| class Fetcher(object): | ||||
| class Fetcher(AbstractFetcher): | ||||
|     def __init__(self, entities, property_, populate_backrefs=False): | ||||
|         self.should_populate_backrefs = populate_backrefs | ||||
|         self.entities = entities | ||||
| @@ -191,6 +212,7 @@ class Fetcher(object): | ||||
|         self.model = self.prop.mapper.class_ | ||||
|         self.first = self.entities[0] | ||||
|         self.session = object_session(self.first) | ||||
|         self.parent_dict = defaultdict(list) | ||||
|  | ||||
|     @property | ||||
|     def local_values_list(self): | ||||
| @@ -245,9 +267,9 @@ class Fetcher(object): | ||||
|             self.local_values_list | ||||
|         ) | ||||
|  | ||||
|     @property | ||||
|     def related_entities(self): | ||||
|         return self.session.query(self.model).filter(self.condition) | ||||
|     def fetch(self): | ||||
|         for entity in self.related_entities: | ||||
|             self.append_entity(entity) | ||||
|  | ||||
|  | ||||
| class ManyToManyFetcher(Fetcher): | ||||
| @@ -278,7 +300,6 @@ class ManyToManyFetcher(Fetcher): | ||||
|         ) | ||||
|  | ||||
|     def fetch(self): | ||||
|         self.parent_dict = defaultdict(list) | ||||
|         for entity, parent_id in self.related_entities: | ||||
|             self.parent_dict[parent_id].append( | ||||
|                 entity | ||||
| @@ -286,16 +307,16 @@ class ManyToManyFetcher(Fetcher): | ||||
|  | ||||
|  | ||||
| class ManyToOneFetcher(Fetcher): | ||||
|     def fetch(self): | ||||
|     def __init__(self, entities, property_, populate_backrefs=False): | ||||
|         Fetcher.__init__(self, entities, property_, populate_backrefs) | ||||
|         self.parent_dict = defaultdict(lambda: None) | ||||
|         for entity in self.related_entities: | ||||
|             self.parent_dict[getattr(entity, self.remote_column_name)] = entity | ||||
|  | ||||
|     def append_entity(self, entity): | ||||
|         self.parent_dict[getattr(entity, self.remote_column_name)] = entity | ||||
|  | ||||
|  | ||||
| class OneToManyFetcher(Fetcher): | ||||
|     def fetch(self): | ||||
|         self.parent_dict = defaultdict(list) | ||||
|         for entity in self.related_entities: | ||||
|             self.parent_dict[getattr(entity, self.remote_column_name)].append( | ||||
|                 entity | ||||
|             ) | ||||
|     def append_entity(self, entity): | ||||
|         self.parent_dict[getattr(entity, self.remote_column_name)].append( | ||||
|             entity | ||||
|         ) | ||||
|   | ||||
| @@ -4,7 +4,7 @@ from sqlalchemy_utils.functions import compound_path | ||||
| from tests import TestCase | ||||
|  | ||||
|  | ||||
| class TestCompoundBatchFetching(TestCase): | ||||
| class TestCompoundOneToManyBatchFetching(TestCase): | ||||
|     def create_models(self): | ||||
|         class Building(self.Base): | ||||
|             __tablename__ = 'building' | ||||
| @@ -54,27 +54,33 @@ class TestCompoundBatchFetching(TestCase): | ||||
|     def setup_method(self, method): | ||||
|         TestCase.setup_method(self, method) | ||||
|         self.buildings = [ | ||||
|             self.Building(name=u'B 1'), | ||||
|             self.Building(name=u'B 2'), | ||||
|             self.Building(name=u'B 3'), | ||||
|             self.Building(id=12, name=u'B 1'), | ||||
|             self.Building(id=15, name=u'B 2'), | ||||
|             self.Building(id=19, name=u'B 3'), | ||||
|         ] | ||||
|         self.business_premises = [ | ||||
|             self.BusinessPremise(name=u'BP 1', building=self.buildings[0]), | ||||
|             self.BusinessPremise(name=u'BP 2', building=self.buildings[0]), | ||||
|             self.BusinessPremise(name=u'BP 3', building=self.buildings[2]), | ||||
|             self.BusinessPremise( | ||||
|                 id=22, name=u'BP 1', building=self.buildings[0] | ||||
|             ), | ||||
|             self.BusinessPremise( | ||||
|                 id=33, name=u'BP 2', building=self.buildings[0] | ||||
|             ), | ||||
|             self.BusinessPremise( | ||||
|                 id=44, name=u'BP 3', building=self.buildings[2] | ||||
|             ), | ||||
|         ] | ||||
|         self.equipment = [ | ||||
|             self.Equipment( | ||||
|                 name=u'E 1', building=self.buildings[0] | ||||
|                 id=2, name=u'E 1', building=self.buildings[0] | ||||
|             ), | ||||
|             self.Equipment( | ||||
|                 name=u'E 2', building=self.buildings[2] | ||||
|                 id=4, name=u'E 2', building=self.buildings[2] | ||||
|             ), | ||||
|             self.Equipment( | ||||
|                 name=u'E 3', business_premise=self.business_premises[0] | ||||
|                 id=6, name=u'E 3', business_premise=self.business_premises[0] | ||||
|             ), | ||||
|             self.Equipment( | ||||
|                 name=u'E 4', business_premise=self.business_premises[2] | ||||
|                 id=8, name=u'E 4', business_premise=self.business_premises[2] | ||||
|             ), | ||||
|         ] | ||||
|         self.session.add_all(self.buildings) | ||||
| @@ -94,7 +100,10 @@ class TestCompoundBatchFetching(TestCase): | ||||
|         ) | ||||
|         query_count = self.connection.query_count | ||||
|  | ||||
|         buildings[0].equipment | ||||
|         buildings[1].equipment | ||||
|         buildings[0].business_premises[0].equipment | ||||
|         assert len(buildings[0].equipment) == 1 | ||||
|         assert buildings[0].equipment[0].name == 'E 1' | ||||
|         assert not buildings[1].equipment | ||||
|         assert buildings[0].business_premises[0].equipment | ||||
|         assert self.business_premises[2].equipment | ||||
|         assert self.business_premises[2].equipment[0].name == 'E 4' | ||||
|         assert self.connection.query_count == query_count | ||||
|   | ||||
| @@ -63,5 +63,16 @@ class TestBatchFetchOneToManyRelationships(TestCase): | ||||
|         categories = self.session.query(self.Category).all() | ||||
|         batch_fetch(categories, self.Category.articles) | ||||
|         query_count = self.connection.query_count | ||||
|         categories[0].articles  # no lazy load should occur | ||||
|         articles = categories[0].articles  # no lazy load should occur | ||||
|         assert len(articles) == 2 | ||||
|         article_names = [article.name for article in articles] | ||||
|  | ||||
|         assert 'Article 1' in article_names | ||||
|         assert 'Article 2' in article_names | ||||
|         articles = categories[1].articles  # no lazy load should occur | ||||
|         assert len(articles) == 3 | ||||
|         article_names = [article.name for article in articles] | ||||
|         assert 'Article 3' in article_names | ||||
|         assert 'Article 4' in article_names | ||||
|         assert 'Article 5' in article_names | ||||
|         assert self.connection.query_count == query_count | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Konsta Vesterinen
					Konsta Vesterinen