diff --git a/taskflow/job.py b/taskflow/job.py index 82ca224c..91873409 100644 --- a/taskflow/job.py +++ b/taskflow/job.py @@ -26,15 +26,13 @@ from taskflow.openstack.common import uuidutils def task_and_state(task, state): - name = None - for a in ('name', '__name__'): - if hasattr(task, a): - attr = getattr(task, a) - if attr is not None: - name = str(attr) - break - if name is None: - name = str(task) + try: + name = task.name + except AttributeError: + try: + name = task.__name__ + except AttributeError: + name = str(task) return "%s:%s" % (name, state) diff --git a/taskflow/tests/unit/test_linear_flow.py b/taskflow/tests/unit/test_linear_flow.py index 2334c88a..115ad5c5 100644 --- a/taskflow/tests/unit/test_linear_flow.py +++ b/taskflow/tests/unit/test_linear_flow.py @@ -238,6 +238,9 @@ class LinearFlowTest(unittest.TestCase): # And now reset and resume. wf.reset() + wf.result_fetcher = result_fetcher + wf.task_listeners.append(task_listener) + self.assertEquals(states.PENDING, wf.state) wf.run(context) self.assertEquals(2, len(context)) diff --git a/taskflow/tests/unit/test_memory.py b/taskflow/tests/unit/test_memory.py index f41b56df..00162c20 100644 --- a/taskflow/tests/unit/test_memory.py +++ b/taskflow/tests/unit/test_memory.py @@ -155,6 +155,7 @@ class MemoryBackendTest(unittest.TestCase): self.assertEquals(1, len(call_log)) wf.reset() + j.associate(wf) self.assertEquals(states.PENDING, wf.state) wf.run(j.context) diff --git a/taskflow/tests/unit/test_sql_db_api.py b/taskflow/tests/unit/test_sql_db_api.py index 9fc65163..cffe56b9 100644 --- a/taskflow/tests/unit/test_sql_db_api.py +++ b/taskflow/tests/unit/test_sql_db_api.py @@ -304,9 +304,9 @@ class WorkflowTest(unittest.TestCase): @classmethod def teardownClass(cls): - for id in tsk_ids: + for id in cls.tsk_ids: db_api.task_destroy('', id) - for name in wf_names: + for name in cls.wf_names: db_api.workflow_destroy('', name) cls.tsk_ids = [] cls.tsk_names = [] @@ -408,7 +408,7 @@ class TaskTest(unittest.TestCase): @classmethod def teardownClass(cls): - for id in tsk_ids: + for id in cls.tsk_ids: db_api.task_destroy('', id) cls.tsk_ids = [] cls.tsk_names = []