MultiThreaded engine and parallel action

MultiThreaded engine was implemented to execute tasks in parallel:
- added parallel action that executes and reverts tasks in parallel;
- added thread-safe storage.

Change-Id: I4a1f78c95ae5d38660bd32ce21d2b3fb1b2af8ad
This commit is contained in:
Anastasia Karpinska
2013-09-04 13:05:57 +04:00
committed by Ivan A. Melnikov
parent a91dd6b0e4
commit 6ee4d32fc2
5 changed files with 289 additions and 1 deletions

View File

@@ -17,6 +17,9 @@
# License for the specific language governing permissions and limitations
# under the License.
from multiprocessing import pool
from taskflow.engines.action_engine import parallel_action
from taskflow.engines.action_engine import seq_action
from taskflow.engines.action_engine import task_action
@@ -89,3 +92,20 @@ class SingleThreadedActionEngine(ActionEngine):
blocks.LinearFlow: seq_action.SequentialAction,
blocks.ParallelFlow: seq_action.SequentialAction
}, t_storage.Storage(flow_detail))
class MultiThreadedActionEngine(ActionEngine):
    """Action engine that runs ``ParallelFlow`` blocks via ``ParallelAction``.

    Tasks and linear flows use the same action classes as the
    single-threaded engine; only the ``ParallelFlow`` mapping and the
    storage class (``ThreadSafeStorage``) differ.
    """

    def __init__(self, flow, flow_detail=None, thread_pool=None):
        """Build the engine.

        :param flow: flow (or single block) to run
        :param flow_detail: optional flow detail handed to the storage
        :param thread_pool: pool to run parallel actions on; when falsy
            (the default ``None``) a new ``pool.ThreadPool()`` is created
        """
        action_by_block = {
            blocks.Task: task_action.TaskAction,
            blocks.LinearFlow: seq_action.SequentialAction,
            blocks.ParallelFlow: parallel_action.ParallelAction,
        }
        storage = t_storage.ThreadSafeStorage(flow_detail)
        ActionEngine.__init__(self, flow, action_by_block, storage)
        # NOTE(review): a pool created here is never closed/joined by this
        # class -- confirm who is expected to shut it down.
        self._thread_pool = thread_pool if thread_pool else pool.ThreadPool()

    @property
    def thread_pool(self):
        """The thread pool this engine was given or created."""
        return self._thread_pool

View File

@@ -0,0 +1,51 @@
# -*- coding: utf-8 -*-
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright (C) 2012 Yahoo! Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from taskflow.engines.action_engine import base_action as base
from taskflow.utils import misc
class ParallelAction(base.Action):
    """Action that executes or reverts its child actions on a thread pool."""

    def __init__(self, pattern, engine):
        """Convert every child of ``pattern`` into an action via ``engine``."""
        children = []
        for child_pattern in pattern.children:
            children.append(engine.to_action(child_pattern))
        self._actions = children

    def _map(self, engine, fn):
        """Apply ``fn`` to every child action using ``engine.thread_pool``.

        Exceptions raised by ``fn`` are captured as ``misc.Failure`` objects
        inside the worker, so every child is attempted. After all results
        are consumed, the first failure seen (in completion order) is
        re-raised; any further failures are dropped.
        """
        thread_pool = engine.thread_pool

        def guarded(action):
            # Hand the error back as a value so that one failing child does
            # not stop iteration over the other results.
            try:
                fn(action)
            except Exception:
                return misc.Failure()
            return None

        outcomes = thread_pool.imap_unordered(guarded, self._actions)
        failures = [outcome for outcome in outcomes
                    if isinstance(outcome, misc.Failure)]
        if failures:
            failures[0].reraise()

    def execute(self, engine):
        """Execute all child actions in parallel."""
        def run_child(action):
            return action.execute(engine)

        self._map(engine, run_child)

    def revert(self, engine):
        """Revert all child actions in parallel."""
        def undo_child(action):
            return action.revert(engine)

        self._map(engine, undo_child)

View File

@@ -22,6 +22,7 @@ from taskflow.persistence import flowdetail
from taskflow.persistence import logbook
from taskflow.persistence import taskdetail
from taskflow import states
from taskflow.utils import threading_utils
def temporary_flow_detail():
@@ -188,3 +189,7 @@ class Storage(object):
def get_flow_state(self):
    """Get state from flowdetails"""
    return self._flowdetail.state
class ThreadSafeStorage(Storage):
    """``Storage`` subclass created through ``ThreadSafeMeta``.

    NOTE(review): this class body defines no methods of its own, so any
    locking has to apply to the methods inherited from ``Storage`` --
    confirm that the metaclass actually covers inherited methods.
    """
    # Python 2 style metaclass declaration.
    __metaclass__ = threading_utils.ThreadSafeMeta

View File

@@ -16,6 +16,9 @@
# License for the specific language governing permissions and limitations
# under the License.
from multiprocessing import pool
import time
from taskflow import blocks
from taskflow import exceptions
from taskflow.persistence import taskdetail
@@ -29,18 +32,23 @@ from taskflow.engines.action_engine import engine as eng
class TestTask(task.Task):
    """Task that records its execution and reversion in a shared list.

    The stale pre-change ``__init__(self, values=None, name=None)`` header
    that appeared next to the new one has been dropped; only the signature
    with ``sleep`` is kept, which stays compatible with two-argument callers.
    """

    def __init__(self, values=None, name=None, sleep=None):
        """Create the task.

        :param values: list to append events to; a fresh list when ``None``
        :param name: task name passed to the base class
        :param sleep: optional delay in seconds applied before execute and
            revert do their work
        """
        super(TestTask, self).__init__(name)
        if values is None:
            self.values = []
        else:
            self.values = values
        self._sleep = sleep

    def execute(self, **kwargs):
        """Optionally sleep, record the task name and return ``5``."""
        if self._sleep:
            time.sleep(self._sleep)
        self.values.append(self.name)
        return 5

    def revert(self, **kwargs):
        """Optionally sleep, then record ``'<name> reverted(<result>)'``."""
        if self._sleep:
            time.sleep(self._sleep)
        self.values.append(self.name + ' reverted(%s)'
                           % kwargs.get('result'))
@@ -48,6 +56,8 @@ class TestTask(task.Task):
class FailingTask(TestTask):
    """``TestTask`` whose execution always fails with ``RuntimeError``."""

    def execute(self, **kwargs):
        """Optionally sleep, then raise ``RuntimeError('Woot!')``."""
        delay = self._sleep
        if delay:
            time.sleep(delay)
        raise RuntimeError('Woot!')
@@ -326,6 +336,46 @@ class EngineLinearFlowTest(EngineTestBase):
'fail reverted(Failure: RuntimeError: Woot!)',
'task2 reverted(5)', 'task1 reverted(5)'])
class EngineParallelFlowTest(EngineTestBase):
    """Parallel-flow scenarios run against whatever ``_make_engine`` builds."""

    def test_parallel_flow_one_task(self):
        """A single task in a parallel flow records its name once."""
        parallel = blocks.ParallelFlow()
        flow = parallel.add(
            blocks.Task(TestTask(self.values, name='task1', sleep=0.01))
        )
        engine = self._make_engine(flow)
        engine.run()
        self.assertEquals(self.values, ['task1'])

    def test_parallel_flow_two_tasks(self):
        """Both tasks run; their order is not checked."""
        parallel = blocks.ParallelFlow()
        flow = parallel.add(
            blocks.Task(TestTask(self.values, name='task1', sleep=0.01)),
            blocks.Task(TestTask(self.values, name='task2', sleep=0.01))
        )
        engine = self._make_engine(flow)
        engine.run()
        self.assertEquals(set(self.values), set(['task1', 'task2']))

    def test_parallel_revert_common(self):
        """A failing task makes ``run`` raise its ``RuntimeError``."""
        parallel = blocks.ParallelFlow()
        flow = parallel.add(
            blocks.Task(TestTask(self.values, name='task1')),
            blocks.Task(FailingTask(self.values, sleep=0.01)),
            blocks.Task(TestTask(self.values, name='task2'))
        )
        engine = self._make_engine(flow)
        self.assertRaisesRegexp(RuntimeError, '^Woot', engine.run)

    def test_parallel_revert_exception_is_reraised(self):
        """With a ``NastyTask`` in the flow, ``run`` raises 'Gotcha'."""
        parallel = blocks.ParallelFlow()
        flow = parallel.add(
            blocks.Task(TestTask(self.values, name='task1')),
            blocks.Task(NastyTask()),
            blocks.Task(FailingTask(self.values, sleep=0.1))
        )
        engine = self._make_engine(flow)
        self.assertRaisesRegexp(RuntimeError, '^Gotcha', engine.run)
def test_sequential_flow_two_tasks_with_resumption(self):
flow = blocks.LinearFlow().add(
blocks.Task(TestTask(self.values, name='task1'), save_as='x1'),
@@ -350,6 +400,149 @@ class EngineLinearFlowTest(EngineTestBase):
class SingleThreadedEngineTest(EngineTaskTest,
                               EngineLinearFlowTest,
                               EngineParallelFlowTest,
                               test.TestCase):
    """Runs the shared engine scenarios on ``SingleThreadedActionEngine``."""

    def _make_engine(self, flow, flow_detail=None):
        """Return a single-threaded engine for ``flow``."""
        engine = eng.SingleThreadedActionEngine(flow, flow_detail=flow_detail)
        return engine
class MultiThreadedEngineTest(EngineTaskTest,
                              EngineLinearFlowTest,
                              EngineParallelFlowTest,
                              test.TestCase):
    """Runs the shared engine scenarios on ``MultiThreadedActionEngine``.

    One thread pool is created per test class and shared by every engine
    built through ``_make_engine``.
    """

    @classmethod
    def setUpClass(cls):
        """Create the thread pool shared by all tests of this class."""
        cls.thread_pool = pool.ThreadPool()

    @classmethod
    def tearDownClass(cls):
        """Shut the shared pool down and wait for its workers."""
        cls.thread_pool.close()
        cls.thread_pool.join()

    def _make_engine(self, flow, flow_detail=None):
        """Return a multi-threaded engine bound to the shared pool."""
        return eng.MultiThreadedActionEngine(flow, flow_detail=flow_detail,
                                             thread_pool=self.thread_pool)

    def test_using_common_pool(self):
        """Two engines given the same pool expose the very same object."""
        flow = blocks.Task(TestTask(self.values, name='task1'))
        thread_pool = pool.ThreadPool()
        # The pool is local to this test, so it must be shut down here;
        # otherwise its worker threads outlive the test.
        try:
            e1 = eng.MultiThreadedActionEngine(flow, thread_pool=thread_pool)
            e2 = eng.MultiThreadedActionEngine(flow, thread_pool=thread_pool)
            self.assertIs(e1.thread_pool, e2.thread_pool)
        finally:
            thread_pool.close()
            thread_pool.join()

    def test_parallel_revert_specific(self):
        """Siblings of a failing task are executed and then reverted."""
        flow = blocks.ParallelFlow().add(
            blocks.Task(TestTask(self.values, name='task1', sleep=0.01)),
            blocks.Task(FailingTask(sleep=0.01)),
            blocks.Task(TestTask(self.values, name='task2', sleep=0.01))
        )
        engine = self._make_engine(flow)
        with self.assertRaisesRegexp(RuntimeError, '^Woot'):
            engine.run()
        result = set(self.values)
        self.assertEquals(result,
                          set(['task1', 'task2',
                               'task2 reverted(5)', 'task1 reverted(5)']))

    def test_parallel_revert_exception_is_reraised_(self):
        """'Gotcha' is raised and only task1 is recorded as reverted."""
        flow = blocks.ParallelFlow().add(
            blocks.Task(TestTask(self.values, name='task1', sleep=0.01)),
            blocks.Task(NastyTask()),
            blocks.Task(FailingTask(sleep=0.01)),
            # NOTE(review): the TestTask class itself (not an instance) is
            # passed here -- confirm this is intentional.
            blocks.Task(TestTask)  # this should not get reverted
        )
        engine = self._make_engine(flow)
        with self.assertRaisesRegexp(RuntimeError, '^Gotcha'):
            engine.run()
        result = set(self.values)
        self.assertEquals(result, set(['task1', 'task1 reverted(5)']))

    def test_nested_parallel_revert_exception_is_reraised(self):
        """'Gotcha' from a nested parallel flow reverts tasks 1, 2 and 3."""
        flow = blocks.ParallelFlow().add(
            blocks.Task(TestTask(self.values, name='task1')),
            blocks.Task(TestTask(self.values, name='task2')),
            blocks.ParallelFlow().add(
                blocks.Task(TestTask(self.values, name='task3', sleep=0.1)),
                blocks.Task(NastyTask()),
                blocks.Task(FailingTask(sleep=0.01))
            )
        )
        engine = self._make_engine(flow)
        with self.assertRaisesRegexp(RuntimeError, '^Gotcha'):
            engine.run()
        result = set(self.values)
        self.assertEquals(result, set(['task1', 'task1 reverted(5)',
                                       'task2', 'task2 reverted(5)',
                                       'task3', 'task3 reverted(5)']))

    def test_parallel_revert_exception_do_not_revert_linear_tasks(self):
        """'Gotcha' in a nested parallel flow leaves task1/task2 unreverted."""
        flow = blocks.LinearFlow().add(
            blocks.Task(TestTask(self.values, name='task1')),
            blocks.Task(TestTask(self.values, name='task2')),
            blocks.ParallelFlow().add(
                blocks.Task(TestTask(self.values, name='task3', sleep=0.1)),
                blocks.Task(NastyTask()),
                blocks.Task(FailingTask(sleep=0.01))
            )
        )
        engine = self._make_engine(flow)
        with self.assertRaisesRegexp(RuntimeError, '^Gotcha'):
            engine.run()
        result = set(self.values)
        self.assertEquals(result, set(['task1', 'task2',
                                       'task3', 'task3 reverted(5)']))

    def test_parallel_nested_to_linear_revert(self):
        """A failure in a nested parallel flow reverts the linear tasks too."""
        flow = blocks.LinearFlow().add(
            blocks.Task(TestTask(self.values, name='task1')),
            blocks.Task(TestTask(self.values, name='task2')),
            blocks.ParallelFlow().add(
                blocks.Task(TestTask(self.values, name='task3', sleep=0.1)),
                blocks.Task(FailingTask(sleep=0.01))
            )
        )
        engine = self._make_engine(flow)
        with self.assertRaisesRegexp(RuntimeError, '^Woot'):
            engine.run()
        result = set(self.values)
        self.assertEquals(result, set(['task1', 'task1 reverted(5)',
                                       'task2', 'task2 reverted(5)',
                                       'task3', 'task3 reverted(5)']))

    def test_linear_nested_to_parallel_revert(self):
        """A failure in a nested linear flow reverts the parallel siblings."""
        flow = blocks.ParallelFlow().add(
            blocks.Task(TestTask(self.values, name='task1')),
            blocks.Task(TestTask(self.values, name='task2')),
            blocks.LinearFlow().add(
                blocks.Task(TestTask(self.values, name='task3', sleep=0.1)),
                blocks.Task(FailingTask(self.values, name='fail', sleep=0.01))
            )
        )
        engine = self._make_engine(flow)
        with self.assertRaisesRegexp(RuntimeError, '^Woot'):
            engine.run()
        result = set(self.values)
        self.assertEquals(result,
                          set(['task1', 'task1 reverted(5)',
                               'task2', 'task2 reverted(5)',
                               'task3', 'task3 reverted(5)',
                               'fail reverted(Failure: RuntimeError: Woot!)']))

    def test_linear_nested_to_parallel_revert_exception(self):
        """'Gotcha' in a nested linear flow: task3 runs but is not reverted."""
        flow = blocks.ParallelFlow().add(
            blocks.Task(TestTask(self.values, name='task1', sleep=0.01)),
            blocks.Task(TestTask(self.values, name='task2', sleep=0.01)),
            blocks.LinearFlow().add(
                blocks.Task(TestTask(self.values, name='task3')),
                blocks.Task(NastyTask()),
                blocks.Task(FailingTask(sleep=0.01))
            )
        )
        engine = self._make_engine(flow)
        with self.assertRaisesRegexp(RuntimeError, '^Gotcha'):
            engine.run()
        result = set(self.values)
        self.assertEquals(result, set(['task1', 'task1 reverted(5)',
                                       'task2', 'task2 reverted(5)',
                                       'task3']))

View File

@@ -20,6 +20,7 @@ import logging
import threading
import threading2
import time
import types
LOG = logging.getLogger(__name__)
@@ -144,3 +145,21 @@ class ThreadGroupExecutor(object):
if not self._threads:
return
return self._group.join(timeout)
class ThreadSafeMeta(type):
    """Metaclass that adds locking to all public methods of a class.

    Public means "name does not start with an underscore". Plain functions
    defined in the class body are wrapped with ``decorators.locked``, and so
    are public functions inherited from base classes that were not
    themselves created by this metaclass. Without the second step, a
    subclass that only declares the metaclass and defines no methods of its
    own would get no locking at all, because the class-body namespace handed
    to ``__new__`` never contains inherited attributes.

    Each instance gets a reentrant ``_lock`` attribute before ``__init__``
    runs, so wrapped methods called during construction can already rely on
    it (assuming ``decorators.locked`` reads ``self._lock`` -- TODO confirm).
    """

    def __new__(cls, name, bases, attrs):
        """Create the class with its public methods wrapped in a lock."""
        # Imported inside the method, as the original code did (presumably
        # to avoid an import cycle -- TODO confirm).
        from taskflow import decorators

        def needs_lock(attr_name, attr_value):
            return (isinstance(attr_value, types.FunctionType)
                    and not attr_name.startswith('_'))

        # Work on a copy so the namespace is not modified while iterating.
        own_attrs = dict(attrs)
        for attr_name, attr_value in attrs.items():
            if needs_lock(attr_name, attr_value):
                own_attrs[attr_name] = decorators.locked(attr_value)
        new_cls = super(ThreadSafeMeta, cls).__new__(cls, name, bases,
                                                     own_attrs)

        # Walk the remaining MRO so inherited public methods are locked too.
        # Only the first definition of each name matters, since that is the
        # one attribute lookup would resolve to.
        resolved = set(own_attrs)
        for klass in new_cls.__mro__[1:]:
            # Classes built by this metaclass already carry locked methods;
            # wrapping them again would only add a redundant lock acquire.
            already_locked = isinstance(klass, ThreadSafeMeta)
            for attr_name, attr_value in list(vars(klass).items()):
                if attr_name in resolved:
                    continue
                resolved.add(attr_name)
                if already_locked:
                    continue
                if needs_lock(attr_name, attr_value):
                    setattr(new_cls, attr_name,
                            decorators.locked(attr_value))
        return new_cls

    def __call__(cls, *args, **kwargs):
        """Instantiate ``cls`` and make sure the instance has a ``_lock``.

        Mirrors the default ``type.__call__`` sequence (``__new__`` followed
        by ``__init__`` when the object is an instance of ``cls``), but
        installs the lock between the two steps. An existing ``_lock``
        attribute is left untouched, and ``__init__`` may still replace it.
        """
        instance = cls.__new__(cls, *args, **kwargs)
        if not hasattr(instance, '_lock'):
            instance._lock = threading.RLock()
        if isinstance(instance, cls):
            instance.__init__(*args, **kwargs)
        return instance