diff --git a/taskflow/patterns/linear_flow.py b/taskflow/patterns/linear_flow.py index bcf7f960..643ccced 100644 --- a/taskflow/patterns/linear_flow.py +++ b/taskflow/patterns/linear_flow.py @@ -20,6 +20,16 @@ from taskflow import exceptions as exc from taskflow.patterns import ordered_flow +def _convert_to_set(items): + if not items: + return set() + if isinstance(items, set): + return items + if isinstance(items, dict): + return items.keys() + return set(iter(items)) + + class Flow(ordered_flow.Flow): """A linear chain of tasks that can be applied as one unit or rolled back as one unit. Each task in the chain may have requirements @@ -39,14 +49,15 @@ class Flow(ordered_flow.Flow): return inputs def _validate_provides(self, task): + requires = _convert_to_set(task.requires()) last_provides = set() last_provider = None if self._tasks: last_provider = self._tasks[-1] - last_provides = last_provider.provides() + last_provides = _convert_to_set(last_provider.provides()) # Ensure that the last task provides all the needed input for this # task to run correctly. - req_diff = task.requires().difference(last_provides) + req_diff = requires.difference(last_provides) if req_diff: if last_provider is None: msg = ("There is no previous task providing the outputs %s"