diff --git a/taskflow/tests/unit/test_memory.py b/taskflow/tests/unit/test_memory.py index e46b5b43..4d5a0aa6 100644 --- a/taskflow/tests/unit/test_memory.py +++ b/taskflow/tests/unit/test_memory.py @@ -29,6 +29,7 @@ from taskflow import job from taskflow import logbook from taskflow import states from taskflow import task +from taskflow import wrappers as wrap from taskflow.backends import memory from taskflow.patterns import linear_workflow as lw @@ -49,53 +50,7 @@ def close_all(*args): a.close() -class FunctorTask(task.Task): - def __init__(self, apply_functor, revert_functor): - super(FunctorTask, self).__init__("%s+%s" % (apply_functor.__name__, - revert_functor.__name__)) - self._apply_functor = apply_functor - self._revert_functor = revert_functor - - def apply(self, context, *args, **kwargs): - return self._apply_functor(context, *args, **kwargs) - - def revert(self, context, result, cause): - return self._revert_functor(context, result, cause) - - class MemoryBackendTest(unittest.TestCase): - def _createDummyWorkflow(self, j, name='dummy'): - wf = lw.Workflow(name) - - def wf_state_change_listener(context, wf, old_state): - if wf.name in j.logbook: - return - j.logbook.add_workflow(wf.name) - - def task_state_change_listener(context, state, wf, task, result=None): - metadata = None - wf_details = j.logbook.fetch_workflow(wf.name) - if state in [states.SUCCESS]: - metadata = { - 'result': result, - } - td_name = gen_task_name(task, state) - if td_name not in wf_details: - wf_details.add_task(logbook.TaskDetail(td_name, metadata)) - - def task_result_fetcher(context, wf, task): - wf_details = j.logbook.fetch_workflow(wf.name) - td_name = gen_task_name(task, states.SUCCESS) - if td_name in wf_details: - task_details = wf_details.fetch_tasks(td_name)[0] - return (True, task_details.metadata['result']) - return (False, None) - - wf.task_listeners.append(task_state_change_listener) - wf.listeners.append(wf_state_change_listener) - wf.result_fetcher = task_result_fetcher - return wf - def _createMemoryImpl(self, cons=1): worker_group = [] poisons = [] @@ -140,9 +95,11 @@ class MemoryBackendTest(unittest.TestCase): j.state = states.PENDING for j in my_jobs: # Create some dummy workflow for the job - wf = self._createDummyWorkflow(j) + wf = lw.Workflow('dummy') for i in range(0, 5): - wf.add(FunctorTask(null_functor, null_functor)) + t = wrap.FunctorTask(null_functor, null_functor) + wf.add(t) + j.associate(wf) j.state = states.RUNNING wf.run(j.context) j.state = states.SUCCESS @@ -184,7 +141,8 @@ class MemoryBackendTest(unittest.TestCase): self.assertEquals(states.CLAIMED, j.state) self.assertEquals('me', j.owner) - wf = self._createDummyWorkflow(j, "the-int-action") + wf = lw.Workflow("the-int-action") + j.associate(wf) self.assertEquals(states.PENDING, wf.state) call_log = [] @@ -198,9 +156,9 @@ class MemoryBackendTest(unittest.TestCase): def do_interrupt(context, *args, **kwargs): wf.interrupt() - task_1 = FunctorTask(do_1, null_functor) - task_1_5 = FunctorTask(do_interrupt, null_functor) - task_2 = FunctorTask(do_2, null_functor) + task_1 = wrap.FunctorTask(do_1, null_functor) + task_1_5 = wrap.FunctorTask(do_interrupt, null_functor) + task_2 = wrap.FunctorTask(do_2, null_functor) wf.add(task_1) wf.add(task_1_5) # Interrupt it after task_1 finishes @@ -231,8 +189,9 @@ class MemoryBackendTest(unittest.TestCase): self.assertEquals(states.CLAIMED, j.state) self.assertEquals('me', j.owner) - wf = self._createDummyWorkflow(j, 'the-line-action') + wf = lw.Workflow('the-line-action') self.assertEquals(states.PENDING, wf.state) + j.associate(wf) call_log = [] @@ -242,8 +201,8 @@ class MemoryBackendTest(unittest.TestCase): def do_2(context, *args, **kwargs): call_log.append(2) - wf.add(FunctorTask(do_1, null_functor)) - wf.add(FunctorTask(do_2, null_functor)) + wf.add(wrap.FunctorTask(do_1, null_functor)) + wf.add(wrap.FunctorTask(do_2, null_functor)) wf.run(j.context) self.assertEquals(1, len(j.logbook))