diff --git a/taskflow/patterns/linear_flow.py b/taskflow/patterns/linear_flow.py index 9283c05c..3d6d330f 100644 --- a/taskflow/patterns/linear_flow.py +++ b/taskflow/patterns/linear_flow.py @@ -16,18 +16,40 @@ # License for the specific language governing permissions and limitations # under the License. +from taskflow import exceptions as exc from taskflow.patterns import ordered_flow class Flow(ordered_flow.Flow): - """A linear chain of *independent* tasks that can be applied as one unit or - rolled back as one unit.""" + """A linear chain of tasks that can be applied as one unit or + rolled back as one unit. Each task in the chain may have requirements + which are satisfied by the previous task in the chain.""" def __init__(self, name, tolerant=False, parents=None): super(Flow, self).__init__(name, tolerant, parents) self._tasks = [] + def _validate_provides(self, task): + last_provides = set() + last_provider = None + if self._tasks: + last_provider = self._tasks[-1] + last_provides = last_provider.provides() + # Ensure that the last task provides all the needed input for this + # task to run correctly. + req_diff = task.requires().difference(last_provides) + if req_diff: + if last_provider is None: + msg = ("There is no previous task providing the outputs %s" + " for %s to correctly execute.") % (req_diff, task) + else: + msg = ("%s does not provide the needed outputs %s for %s to" + " correctly execute.") + msg = msg % (last_provider, req_diff, task) + raise exc.InvalidStateException(msg) + def add(self, task): + self._validate_provides(task) self._tasks.append(task) def order(self): diff --git a/taskflow/tests/unit/test_linear_flow.py b/taskflow/tests/unit/test_linear_flow.py index e8d50e0d..1d575f7b 100644 --- a/taskflow/tests/unit/test_linear_flow.py +++ b/taskflow/tests/unit/test_linear_flow.py @@ -19,6 +19,7 @@ import functools import unittest +from taskflow import exceptions as exc from taskflow import states from taskflow import wrappers @@ -122,6 +123,33 @@ class LinearFlowTest(unittest.TestCase): self.assertEquals('reverted', run_context[1]) self.assertEquals(1, len(run_context)) + def test_not_satisfied_inputs_previous(self): + wf = lw.Flow("the-test-action") + + def task_a(context, *args, **kwargs): + pass + + def task_b(context, c, *args, **kwargs): + pass + + wf.add(wrappers.FunctorTask(None, task_a, null_functor, + extract_requires=True)) + self.assertRaises(exc.InvalidStateException, + wf.add, + wrappers.FunctorTask(None, task_b, null_functor, + extract_requires=True)) + + def test_not_satisfied_inputs_no_previous(self): + wf = lw.Flow("the-test-action") + + def task_a(context, c, *args, **kwargs): + pass + + self.assertRaises(exc.InvalidStateException, + wf.add, + wrappers.FunctorTask(None, task_a, null_functor, + extract_requires=True)) + def test_interrupt_flow(self): wf = lw.Flow("the-int-action")