diff --git a/taskflow/patterns/graph_flow.py b/taskflow/patterns/graph_flow.py index 2883c487..71ae17f6 100644 --- a/taskflow/patterns/graph_flow.py +++ b/taskflow/patterns/graph_flow.py @@ -107,13 +107,13 @@ class Flow(linear_flow.Flow): # connections instead of automatically doing it for them?? for n in self._graph.nodes_iter(): n_providers = {} - n_requires = set(utils.get_attr(n.task, 'requires', [])) + n_requires = n.requires if n_requires: LOG.debug("Finding providers of %s for %s", n_requires, n) for p in self._graph.nodes_iter(): if n is p: continue - p_provides = set(utils.get_attr(p.task, 'provides', [])) + p_provides = p.provides p_satisfies = n_requires & p_provides if p_satisfies: # P produces for N so thats why we link P->N diff --git a/taskflow/patterns/linear_flow.py b/taskflow/patterns/linear_flow.py index 9a9df158..3cd69ba5 100644 --- a/taskflow/patterns/linear_flow.py +++ b/taskflow/patterns/linear_flow.py @@ -78,12 +78,12 @@ class Flow(base.Flow): def _associate_providers(self, runner): # Ensure that some previous task provides this input. who_provides = {} - task_requires = set(utils.get_attr(runner.task, 'requires', [])) + task_requires = runner.requires LOG.debug("Finding providers of %s for %s", task_requires, runner) for r in task_requires: provider = None for before_me in runner.runs_before: - if r in set(utils.get_attr(before_me.task, 'provides', [])): + if r in before_me.provides: provider = before_me break if provider: diff --git a/taskflow/patterns/resumption/logbook.py b/taskflow/patterns/resumption/logbook.py index 604522d2..4fd422c0 100644 --- a/taskflow/patterns/resumption/logbook.py +++ b/taskflow/patterns/resumption/logbook.py @@ -91,10 +91,9 @@ class Resumption(object): return task_details def _get_details(self, flow_details, runner): - task_id = runner.uuid - if task_id not in flow_details: + if runner.uuid not in flow_details: return (False, None) - details = flow_details[task_id] + details = flow_details[runner.uuid] has_completed = False for state in details.metadata.get('states', []): if state in (states.SUCCESS, states.FAILURE): @@ -102,7 +101,7 @@ class Resumption(object): break if not has_completed: return (False, None) - immediate_version = utils.get_task_version(runner.task) + immediate_version = runner.version recorded_version = details.metadata.get('version') if recorded_version is not None: if not utils.is_version_compatible(recorded_version, diff --git a/taskflow/utils.py b/taskflow/utils.py index 114c01ec..19d82081 100644 --- a/taskflow/utils.py +++ b/taskflow/utils.py @@ -200,6 +200,18 @@ class Runner(object): self.runs_before = [] self.result = None + @property + def requires(self): + return set(get_attr(self.task, 'requires', [])) + + @property + def provides(self): + return set(get_attr(self.task, 'provides', [])) + + @property + def optional(self): + return set(get_attr(self.task, 'optional', [])) + @property def version(self): return get_task_version(self.task) @@ -222,12 +234,12 @@ class Runner(object): kwargs[k] = who_made.result[k] else: kwargs[k] = None - optional_keys = set(get_attr(self.task, 'optional', [])) + optional_keys = self.optional optional_missing_keys = optional_keys - set(kwargs.keys()) if optional_missing_keys: for k in optional_missing_keys: for r in self.runs_before: - r_provides = set(get_attr(r.task, 'provides', [])) + r_provides = r.provides if k in r_provides and r.result and k in r.result: kwargs[k] = r.result[k] break