diff --git a/sqlalchemy_utils/functions/batch_fetch.py b/sqlalchemy_utils/functions/batch_fetch.py index e833b43..5dc4dee 100644 --- a/sqlalchemy_utils/functions/batch_fetch.py +++ b/sqlalchemy_utils/functions/batch_fetch.py @@ -1,4 +1,6 @@ +from collections import defaultdict import six +import sqlalchemy as sa from sqlalchemy.orm import RelationshipProperty from sqlalchemy.orm.attributes import set_committed_value from sqlalchemy.orm.session import object_session @@ -104,24 +106,14 @@ class FetchingCoordinator(object): if should_populate_backrefs: subpath = with_backrefs(subpath) - batch_fetch( - related_entities, + return self.__class__(related_entities).fetcher_for_attr_path( subpath ) - return else: - return getattr( + attr_path = getattr( self.first.__class__, attrs[0] ) - else: - return attr_path - - def fetch_relation_entities(self): - if len(self.prop.remote_side) > 1: - raise Exception( - 'Only relationships with single remote side columns ' - 'are supported.' - ) + return self.fetcher_for_property(attr_path.property) def fetcher_for_property(self, property_): if not isinstance(property_, RelationshipProperty): @@ -130,12 +122,18 @@ class FetchingCoordinator(object): ) if property_.secondary is not None: - return ManyToManyFetcher(self, property_) + fetcher_class = ManyToManyFetcher else: if property_.direction.name == 'MANYTOONE': - return ManyToOneFetcher(self, property_) + fetcher_class = ManyToOneFetcher else: - return OneToManyFetcher(self, property_) + fetcher_class = OneToManyFetcher + + return fetcher_class( + self.entities, + property_, + self.should_populate_backrefs + ) def fetcher_for_attr_path(self, attr_path): if isinstance(attr_path, with_backrefs): @@ -144,15 +142,19 @@ class FetchingCoordinator(object): else: self.should_populate_backrefs = False - attr = self.parse_attr_path(attr_path, self.should_populate_backrefs) - if not attr: - return - return self.fetcher_for_property(attr.property) + return self.parse_attr_path( + attr_path, + self.should_populate_backrefs + ) def __call__(self, attr_path): if isinstance(attr_path, compound_path): + fetchers = [] for path in attr_path.attr_paths: - self(path) + fetchers.append(self.fetcher_for_attr_path(path)) + + fetcher = CompoundFetcher(*fetchers) + print fetcher.condition else: fetcher = self.fetcher_for_attr_path(attr_path) if not fetcher: @@ -162,28 +164,33 @@ class FetchingCoordinator(object): class CompoundFetcher(object): - def __init__(self, coordinator, path): - self.coordinator = coordinator - self.entities = coordinator.entities - self.first = self.entities[0] - self.session = object_session(self.first) + def __init__(self, *fetchers): + if not all(fetchers[0].model == fetcher.model for fetcher in fetchers): + raise Exception( + 'Each relationship property must have the same class when ' + 'using CompoundFetcher.' + ) + self.fetchers = fetchers + + @property + def condition(self): + return sa.or_( + *[fetcher.condition for fetcher in self.fetchers] + ) + + @property + def local_values(self): + pass class Fetcher(object): - def __init__(self, coordinator, property_): - self.coordinator = coordinator + def __init__(self, entities, property_, populate_backrefs=False): + self.should_populate_backrefs = populate_backrefs + self.entities = entities self.prop = property_ self.model = self.prop.mapper.class_ - self.entities = coordinator.entities self.first = self.entities[0] self.session = object_session(self.first) - self.init_parent_dict() - - def init_parent_dict(self): - self.parent_dict = dict( - (self.local_values(entity), []) - for entity in self.entities - ) @property def local_values_list(self): @@ -225,7 +232,7 @@ class Fetcher(object): self.parent_dict[self.local_values(entity)] ) - if self.coordinator.should_populate_backrefs: + if self.should_populate_backrefs: self.populate_backrefs(self.related_entities) @property @@ -271,6 +278,7 @@ class ManyToManyFetcher(Fetcher): ) def fetch(self): + self.parent_dict = defaultdict(list) for entity, parent_id in self.related_entities: self.parent_dict[parent_id].append( entity @@ -278,19 +286,15 @@ class ManyToManyFetcher(Fetcher): class ManyToOneFetcher(Fetcher): - def init_parent_dict(self): - self.parent_dict = dict( - (self.local_values(entity), None) - for entity in self.entities - ) - def fetch(self): + self.parent_dict = defaultdict(lambda: None) for entity in self.related_entities: self.parent_dict[getattr(entity, self.remote_column_name)] = entity class OneToManyFetcher(Fetcher): def fetch(self): + self.parent_dict = defaultdict(list) for entity in self.related_entities: self.parent_dict[getattr(entity, self.remote_column_name)].append( entity diff --git a/tests/batch_fetch/test_compound_fetching.py b/tests/batch_fetch/test_compound_fetching.py index d320b63..ca6784f 100644 --- a/tests/batch_fetch/test_compound_fetching.py +++ b/tests/batch_fetch/test_compound_fetching.py @@ -26,7 +26,7 @@ class TestCompoundBatchFetching(TestCase): ) class Equipment(self.Base): - __tablename__ = 'article' + __tablename__ = 'equipment' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) building_id = sa.Column(sa.Integer, sa.ForeignKey(Building.id))