From ee8b97f0f78b39073ac6da21eeeb2ff07d7c951e Mon Sep 17 00:00:00 2001 From: Konsta Vesterinen Date: Sat, 17 Aug 2013 18:49:02 +0300 Subject: [PATCH] Added more tests and refactored batch fetch --- sqlalchemy_utils/functions/batch_fetch.py | 188 +++++++++++------- .../test_many_to_one_relationships.py | 12 +- 2 files changed, 130 insertions(+), 70 deletions(-) diff --git a/sqlalchemy_utils/functions/batch_fetch.py b/sqlalchemy_utils/functions/batch_fetch.py index e7fc2f7..7e67375 100644 --- a/sqlalchemy_utils/functions/batch_fetch.py +++ b/sqlalchemy_utils/functions/batch_fetch.py @@ -74,49 +74,17 @@ def batch_fetch(entities, *attr_paths): """ if entities: - fetcher = BatchFetcher(entities) + fetcher = FetchingCoordinator(entities) for attr_path in attr_paths: fetcher(attr_path) -class BatchFetcher(object): +class FetchingCoordinator(object): def __init__(self, entities): self.entities = entities self.first = entities[0] - self.parent_ids = [entity.id for entity in entities] self.session = object_session(self.first) - def populate_backrefs(self, related_entities): - """ - Populates backrefs for given related entities. - """ - - backref_dict = dict( - (entity.id, []) for entity, parent_id in related_entities - ) - for entity, parent_id in related_entities: - backref_dict[entity.id].append( - self.session.query(self.first.__class__).get(parent_id) - ) - for entity, parent_id in related_entities: - set_committed_value( - entity, self.prop.back_populates, backref_dict[entity.id] - ) - - def populate_entities(self): - """ - Populate batch fetched entities to parent objects. - """ - for entity in self.entities: - set_committed_value( - entity, - self.prop.key, - self.parent_dict[entity.id] - ) - - if self.should_populate_backrefs: - self.populate_backrefs(self.related_entities) - def parse_attr_path(self, attr_path, should_populate_backrefs): if isinstance(attr_path, six.string_types): attrs = attr_path.split('.') @@ -150,21 +118,97 @@ class BatchFetcher(object): 'are supported.' ) - column_name = list(self.prop.remote_side)[0].name - - self.related_entities = ( - self.session.query(self.model) - .filter( - getattr(self.model, column_name).in_(self.parent_ids) + def fetcher(self, property_): + if not isinstance(property_, RelationshipProperty): + raise Exception( + 'Given attribute is not a relationship property.' ) + + if property_.secondary is not None: + return ManyToManyFetcher(self, property_) + else: + if property_.direction.name == 'MANYTOONE': + return ManyToOneFetcher(self, property_) + else: + return OneToManyFetcher(self, property_) + + def __call__(self, attr_path): + if isinstance(attr_path, with_backrefs): + self.should_populate_backrefs = True + attr_path = attr_path.attr_path + else: + self.should_populate_backrefs = False + + attr = self.parse_attr_path(attr_path, self.should_populate_backrefs) + if not attr: + return + + fetcher = self.fetcher(attr.property) + fetcher.fetch() + fetcher.populate() + + +class Fetcher(object): + def __init__(self, coordinator, property_): + self.coordinator = coordinator + self.prop = property_ + self.model = self.prop.mapper.class_ + self.entities = coordinator.entities + self.first = self.entities[0] + self.session = object_session(self.first) + + for entity in self.entities: + self.parent_dict = dict( + (self.local_values(entity), []) + for entity in self.entities + ) + + @property + def local_values_list(self): + return [ + self.local_values(entity) + for entity in self.entities + ] + + def local_values(self, entity): + return getattr(entity, list(self.prop.local_columns)[0].name) + + def populate_backrefs(self, related_entities): + """ + Populates backrefs for given related entities. + """ + + backref_dict = dict( + (entity.id, []) for entity, parent_id in related_entities ) - - for entity in self.related_entities: - self.parent_dict[getattr(entity, column_name)].append( - entity + for entity, parent_id in related_entities: + backref_dict[entity.id].append( + self.session.query(self.first.__class__).get(parent_id) + ) + for entity, parent_id in related_entities: + set_committed_value( + entity, self.prop.back_populates, backref_dict[entity.id] ) - def fetch_association_entities(self): + def populate(self): + """ + Populate batch fetched entities to parent objects. + """ + for entity in self.entities: + set_committed_value( + entity, + self.prop.key, + self.parent_dict[self.local_values(entity)] + ) + + if self.coordinator.should_populate_backrefs: + self.populate_backrefs(self.related_entities) + + +class ManyToManyFetcher(Fetcher): + def fetch(self): + parent_ids = [entity.id for entity in self.entities] + column_name = None for column in self.prop.remote_side: for fk in column.foreign_keys: @@ -183,7 +227,7 @@ class BatchFetcher(object): ) .filter( getattr(self.prop.secondary.c, column_name).in_( - self.parent_ids + parent_ids ) ) ) @@ -192,30 +236,38 @@ class BatchFetcher(object): entity ) - def __call__(self, attr_path): - self.parent_dict = dict( - (entity.id, []) for entity in self.entities + +class ManyToOneFetcher(Fetcher): + def fetch(self): + column_name = list(self.prop.remote_side)[0].name + + self.related_entities = ( + self.session.query(self.model) + .filter( + getattr(self.model, column_name).in_(self.local_values_list) + ) ) - if isinstance(attr_path, with_backrefs): - self.should_populate_backrefs = True - attr_path = attr_path.attr_path - else: - self.should_populate_backrefs = False - attr = self.parse_attr_path(attr_path, self.should_populate_backrefs) - if not attr: - return - - self.prop = attr.property - if not isinstance(self.prop, RelationshipProperty): - raise Exception( - 'Given attribute is not a relationship property.' + for entity in self.related_entities: + self.parent_dict[getattr(entity, column_name)].append( + entity ) - self.model = self.prop.mapper.class_ - if self.prop.secondary is None: - self.fetch_relation_entities() - else: - self.fetch_association_entities() - self.populate_entities() +class OneToManyFetcher(Fetcher): + def fetch(self): + parent_ids = [entity.id for entity in self.entities] + + column_name = list(self.prop.remote_side)[0].name + + self.related_entities = ( + self.session.query(self.model) + .filter( + getattr(self.model, column_name).in_(parent_ids) + ) + ) + + for entity in self.related_entities: + self.parent_dict[getattr(entity, column_name)].append( + entity + ) diff --git a/tests/batch_fetch/test_many_to_one_relationships.py b/tests/batch_fetch/test_many_to_one_relationships.py index c8a0e6f..36b0af4 100644 --- a/tests/batch_fetch/test_many_to_one_relationships.py +++ b/tests/batch_fetch/test_many_to_one_relationships.py @@ -29,8 +29,16 @@ class TestBatchFetchManyToOneRelationships(TestCase): def setup_method(self, method): TestCase.setup_method(self, method) articles = [ - self.Article(name=u'Article 1', author=self.User(name=u'John')), - self.Article(name=u'Article 2', author=self.User(name=u'Matt')), + self.Article( + id=1, + name=u'Article 1', + author=self.User(id=333, name=u'John') + ), + self.Article( + id=2, + name=u'Article 2', + author=self.User(id=334, name=u'Matt') + ), ] self.session.add_all(articles) self.session.commit()