diff --git a/sqlalchemy_utils/functions/batch_fetch.py b/sqlalchemy_utils/functions/batch_fetch.py index cbe4715..19796c0 100644 --- a/sqlalchemy_utils/functions/batch_fetch.py +++ b/sqlalchemy_utils/functions/batch_fetch.py @@ -143,30 +143,28 @@ def batch_fetch(entities, *attr_paths): """ if entities: - fetcher = FetchingCoordinator() - for attr_path in attr_paths: - fetcher(entities, attr_path) + for path in attr_paths: + fetcher = fetcher_factory(entities, path) + fetcher.fetch() + fetcher.populate() -class FetchingCoordinator(object): - def __call__(self, entities, path): - populate_backrefs = False - if isinstance(path, with_backrefs): - path = path.path - populate_backrefs = True +def fetcher_factory(entities, path): + populate_backrefs = False + if isinstance(path, with_backrefs): + path = path.path + populate_backrefs = True - if isinstance(path, CompositePath): - fetchers = [] - for path in path.paths: - fetchers.append( - Path.parse(entities, path, populate_backrefs).fetcher - ) + if isinstance(path, CompositePath): + fetchers = [] + for path in path.paths: + fetchers.append( + Path.parse(entities, path, populate_backrefs).fetcher + ) - fetcher = CompositeFetcher(*fetchers) - else: - fetcher = Path.parse(entities, path, populate_backrefs).fetcher - fetcher.fetch() - fetcher.populate() + return CompositeFetcher(*fetchers) + else: + return Path.parse(entities, path, populate_backrefs).fetcher class CompositeFetcher(object):