diff --git a/taskflow/engines/action_engine/engine.py b/taskflow/engines/action_engine/engine.py index 1acfed738..1adc49245 100644 --- a/taskflow/engines/action_engine/engine.py +++ b/taskflow/engines/action_engine/engine.py @@ -17,6 +17,9 @@ # License for the specific language governing permissions and limitations # under the License. +from multiprocessing import pool + +from taskflow.engines.action_engine import parallel_action from taskflow.engines.action_engine import seq_action from taskflow.engines.action_engine import task_action @@ -89,3 +92,20 @@ class SingleThreadedActionEngine(ActionEngine): blocks.LinearFlow: seq_action.SequentialAction, blocks.ParallelFlow: seq_action.SequentialAction }, t_storage.Storage(flow_detail)) + + +class MultiThreadedActionEngine(ActionEngine): + def __init__(self, flow, flow_detail=None, thread_pool=None): + ActionEngine.__init__(self, flow, { + blocks.Task: task_action.TaskAction, + blocks.LinearFlow: seq_action.SequentialAction, + blocks.ParallelFlow: parallel_action.ParallelAction + }, t_storage.ThreadSafeStorage(flow_detail)) + if thread_pool: + self._thread_pool = thread_pool + else: + self._thread_pool = pool.ThreadPool() + + @property + def thread_pool(self): + return self._thread_pool diff --git a/taskflow/engines/action_engine/parallel_action.py b/taskflow/engines/action_engine/parallel_action.py new file mode 100644 index 000000000..4c883d1c4 --- /dev/null +++ b/taskflow/engines/action_engine/parallel_action.py @@ -0,0 +1,51 @@ +# -*- coding: utf-8 -*- + +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright (C) 2012 Yahoo! Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from taskflow.engines.action_engine import base_action as base +from taskflow.utils import misc + + +class ParallelAction(base.Action): + + def __init__(self, pattern, engine): + self._actions = [engine.to_action(pat) for pat in pattern.children] + + def _map(self, engine, fn): + pool = engine.thread_pool + + def call_fn(action): + try: + fn(action) + except Exception: + return misc.Failure() + else: + return None + + failures = [] + result_iter = pool.imap_unordered(call_fn, self._actions) + for result in result_iter: + if isinstance(result, misc.Failure): + failures.append(result) + if failures: + failures[0].reraise() + + def execute(self, engine): + self._map(engine, lambda action: action.execute(engine)) + + def revert(self, engine): + self._map(engine, lambda action: action.revert(engine)) diff --git a/taskflow/storage.py b/taskflow/storage.py index 5e2776d32..679f68597 100644 --- a/taskflow/storage.py +++ b/taskflow/storage.py @@ -22,6 +22,7 @@ from taskflow.persistence import flowdetail from taskflow.persistence import logbook from taskflow.persistence import taskdetail from taskflow import states +from taskflow.utils import threading_utils def temporary_flow_detail(): @@ -188,3 +189,7 @@ class Storage(object): def get_flow_state(self): """Set state from flowdetails""" return self._flowdetail.state + + +class ThreadSafeStorage(Storage): + __metaclass__ = threading_utils.ThreadSafeMeta diff --git a/taskflow/tests/unit/test_action_engine.py b/taskflow/tests/unit/test_action_engine.py index 3f1c31ff9..5ea041d70 100644 --- a/taskflow/tests/unit/test_action_engine.py +++ b/taskflow/tests/unit/test_action_engine.py @@ -16,6 +16,9 @@ # License for the specific language governing permissions and limitations # under the License. +from multiprocessing import pool +import time + from taskflow import blocks from taskflow import exceptions from taskflow.persistence import taskdetail @@ -29,18 +32,23 @@ from taskflow.engines.action_engine import engine as eng class TestTask(task.Task): - def __init__(self, values=None, name=None): + def __init__(self, values=None, name=None, sleep=None): super(TestTask, self).__init__(name) if values is None: self.values = [] else: self.values = values + self._sleep = sleep def execute(self, **kwargs): + if self._sleep: + time.sleep(self._sleep) self.values.append(self.name) return 5 def revert(self, **kwargs): + if self._sleep: + time.sleep(self._sleep) self.values.append(self.name + ' reverted(%s)' % kwargs.get('result')) @@ -48,6 +56,8 @@ class TestTask(task.Task): class FailingTask(TestTask): def execute(self, **kwargs): + if self._sleep: + time.sleep(self._sleep) raise RuntimeError('Woot!') @@ -326,6 +336,46 @@ class EngineLinearFlowTest(EngineTestBase): 'fail reverted(Failure: RuntimeError: Woot!)', 'task2 reverted(5)', 'task1 reverted(5)']) + +class EngineParallelFlowTest(EngineTestBase): + + def test_parallel_flow_one_task(self): + flow = blocks.ParallelFlow().add( + blocks.Task(TestTask(self.values, name='task1', sleep=0.01)) + ) + self._make_engine(flow).run() + self.assertEquals(self.values, ['task1']) + + def test_parallel_flow_two_tasks(self): + flow = blocks.ParallelFlow().add( + blocks.Task(TestTask(self.values, name='task1', sleep=0.01)), + blocks.Task(TestTask(self.values, name='task2', sleep=0.01)) + ) + self._make_engine(flow).run() + + result = set(self.values) + self.assertEquals(result, set(['task1', 'task2'])) + + def test_parallel_revert_common(self): + flow = blocks.ParallelFlow().add( + blocks.Task(TestTask(self.values, name='task1')), + blocks.Task(FailingTask(self.values, sleep=0.01)), + blocks.Task(TestTask(self.values, name='task2')) + ) + engine = self._make_engine(flow) + with self.assertRaisesRegexp(RuntimeError, '^Woot'): + engine.run() + + def test_parallel_revert_exception_is_reraised(self): + flow = blocks.ParallelFlow().add( + blocks.Task(TestTask(self.values, name='task1')), + blocks.Task(NastyTask()), + blocks.Task(FailingTask(self.values, sleep=0.1)) + ) + engine = self._make_engine(flow) + with self.assertRaisesRegexp(RuntimeError, '^Gotcha'): + engine.run() + def test_sequential_flow_two_tasks_with_resumption(self): flow = blocks.LinearFlow().add( blocks.Task(TestTask(self.values, name='task1'), save_as='x1'), @@ -350,6 +400,149 @@ class EngineLinearFlowTest(EngineTestBase): class SingleThreadedEngineTest(EngineTaskTest, EngineLinearFlowTest, + EngineParallelFlowTest, test.TestCase): def _make_engine(self, flow, flow_detail=None): return eng.SingleThreadedActionEngine(flow, flow_detail=flow_detail) + + +class MultiThreadedEngineTest(EngineTaskTest, + EngineLinearFlowTest, + EngineParallelFlowTest, + test.TestCase): + + @classmethod + def setUpClass(cls): + cls.thread_pool = pool.ThreadPool() + + @classmethod + def tearDownClass(cls): + cls.thread_pool.close() + cls.thread_pool.join() + + def _make_engine(self, flow, flow_detail=None): + return eng.MultiThreadedActionEngine(flow, flow_detail=flow_detail, + thread_pool=self.thread_pool) + + def test_using_common_pool(self): + flow = blocks.Task(TestTask(self.values, name='task1')) + thread_pool = pool.ThreadPool() + e1 = eng.MultiThreadedActionEngine(flow, thread_pool=thread_pool) + e2 = eng.MultiThreadedActionEngine(flow, thread_pool=thread_pool) + self.assertIs(e1.thread_pool, e2.thread_pool) + + def test_parallel_revert_specific(self): + flow = blocks.ParallelFlow().add( + blocks.Task(TestTask(self.values, name='task1', sleep=0.01)), + blocks.Task(FailingTask(sleep=0.01)), + blocks.Task(TestTask(self.values, name='task2', sleep=0.01)) + ) + engine = self._make_engine(flow) + with self.assertRaisesRegexp(RuntimeError, '^Woot'): + engine.run() + result = set(self.values) + self.assertEquals(result, + set(['task1', 'task2', + 'task2 reverted(5)', 'task1 reverted(5)'])) + + def test_parallel_revert_exception_is_reraised_(self): + flow = blocks.ParallelFlow().add( + blocks.Task(TestTask(self.values, name='task1', sleep=0.01)), + blocks.Task(NastyTask()), + blocks.Task(FailingTask(sleep=0.01)), + blocks.Task(TestTask) # this should not get reverted + ) + engine = self._make_engine(flow) + with self.assertRaisesRegexp(RuntimeError, '^Gotcha'): + engine.run() + result = set(self.values) + self.assertEquals(result, set(['task1', 'task1 reverted(5)'])) + + def test_nested_parallel_revert_exception_is_reraised(self): + flow = blocks.ParallelFlow().add( + blocks.Task(TestTask(self.values, name='task1')), + blocks.Task(TestTask(self.values, name='task2')), + blocks.ParallelFlow().add( + blocks.Task(TestTask(self.values, name='task3', sleep=0.1)), + blocks.Task(NastyTask()), + blocks.Task(FailingTask(sleep=0.01)) + ) + ) + engine = self._make_engine(flow) + with self.assertRaisesRegexp(RuntimeError, '^Gotcha'): + engine.run() + result = set(self.values) + self.assertEquals(result, set(['task1', 'task1 reverted(5)', + 'task2', 'task2 reverted(5)', + 'task3', 'task3 reverted(5)'])) + + def test_parallel_revert_exception_do_not_revert_linear_tasks(self): + flow = blocks.LinearFlow().add( + blocks.Task(TestTask(self.values, name='task1')), + blocks.Task(TestTask(self.values, name='task2')), + blocks.ParallelFlow().add( + blocks.Task(TestTask(self.values, name='task3', sleep=0.1)), + blocks.Task(NastyTask()), + blocks.Task(FailingTask(sleep=0.01)) + ) + ) + engine = self._make_engine(flow) + with self.assertRaisesRegexp(RuntimeError, '^Gotcha'): + engine.run() + result = set(self.values) + self.assertEquals(result, set(['task1', 'task2', + 'task3', 'task3 reverted(5)'])) + + def test_parallel_nested_to_linear_revert(self): + flow = blocks.LinearFlow().add( + blocks.Task(TestTask(self.values, name='task1')), + blocks.Task(TestTask(self.values, name='task2')), + blocks.ParallelFlow().add( + blocks.Task(TestTask(self.values, name='task3', sleep=0.1)), + blocks.Task(FailingTask(sleep=0.01)) + ) + ) + engine = self._make_engine(flow) + with self.assertRaisesRegexp(RuntimeError, '^Woot'): + engine.run() + result = set(self.values) + self.assertEquals(result, set(['task1', 'task1 reverted(5)', + 'task2', 'task2 reverted(5)', + 'task3', 'task3 reverted(5)'])) + + def test_linear_nested_to_parallel_revert(self): + flow = blocks.ParallelFlow().add( + blocks.Task(TestTask(self.values, name='task1')), + blocks.Task(TestTask(self.values, name='task2')), + blocks.LinearFlow().add( + blocks.Task(TestTask(self.values, name='task3', sleep=0.1)), + blocks.Task(FailingTask(self.values, name='fail', sleep=0.01)) + ) + ) + engine = self._make_engine(flow) + with self.assertRaisesRegexp(RuntimeError, '^Woot'): + engine.run() + result = set(self.values) + self.assertEquals(result, + set(['task1', 'task1 reverted(5)', + 'task2', 'task2 reverted(5)', + 'task3', 'task3 reverted(5)', + 'fail reverted(Failure: RuntimeError: Woot!)'])) + + def test_linear_nested_to_parallel_revert_exception(self): + flow = blocks.ParallelFlow().add( + blocks.Task(TestTask(self.values, name='task1', sleep=0.01)), + blocks.Task(TestTask(self.values, name='task2', sleep=0.01)), + blocks.LinearFlow().add( + blocks.Task(TestTask(self.values, name='task3')), + blocks.Task(NastyTask()), + blocks.Task(FailingTask(sleep=0.01)) + ) + ) + engine = self._make_engine(flow) + with self.assertRaisesRegexp(RuntimeError, '^Gotcha'): + engine.run() + result = set(self.values) + self.assertEquals(result, set(['task1', 'task1 reverted(5)', + 'task2', 'task2 reverted(5)', + 'task3'])) diff --git a/taskflow/utils/threading_utils.py b/taskflow/utils/threading_utils.py index 05c5e3565..127ed7185 100644 --- a/taskflow/utils/threading_utils.py +++ b/taskflow/utils/threading_utils.py @@ -20,6 +20,7 @@ import logging import threading import threading2 import time +import types LOG = logging.getLogger(__name__) @@ -144,3 +145,21 @@ class ThreadGroupExecutor(object): if not self._threads: return return self._group.join(timeout) + + +class ThreadSafeMeta(type): + """Metaclass that adds locking to all pubic methods of a class""" + + def __new__(cls, name, bases, attrs): + from taskflow import decorators + for attr_name, attr_value in attrs.iteritems(): + if isinstance(attr_value, types.FunctionType): + if attr_name[0] != '_': + attrs[attr_name] = decorators.locked(attr_value) + return super(ThreadSafeMeta, cls).__new__(cls, name, bases, attrs) + + def __call__(cls, *args, **kwargs): + instance = super(ThreadSafeMeta, cls).__call__(*args, **kwargs) + if not hasattr(instance, '_lock'): + instance._lock = threading.RLock() + return instance