diff --git a/taskflow/engines/action_engine/engine.py b/taskflow/engines/action_engine/engine.py index b0dbaa3a..7ececb9f 100644 --- a/taskflow/engines/action_engine/engine.py +++ b/taskflow/engines/action_engine/engine.py @@ -14,12 +14,13 @@ # License for the specific language governing permissions and limitations # under the License. -import abc +import collections import contextlib import threading from concurrent import futures from oslo.utils import excutils +import six from taskflow.engines.action_engine import compiler from taskflow.engines.action_engine import executor @@ -199,11 +200,6 @@ class ActionEngine(base.Engine): self._runtime.reset_all() self._change_state(states.PENDING) - @abc.abstractproperty - def _task_executor(self): - return self._task_executor_factory() - pass - @misc.cachedproperty def _compiler(self): return self._compiler_factory(self._flow) @@ -224,28 +220,105 @@ class SerialActionEngine(ActionEngine): """Engine that runs tasks in serial manner.""" _storage_factory = atom_storage.SingleThreadedStorage - @misc.cachedproperty - def _task_executor(self): - return executor.SerialTaskExecutor() + def __init__(self, flow, flow_detail, backend, options): + super(SerialActionEngine, self).__init__(flow, flow_detail, + backend, options) + self._task_executor = executor.SerialTaskExecutor() + + +class _ExecutorTypeMatch(collections.namedtuple('_ExecutorTypeMatch', + ['types', 'executor_cls'])): + def matches(self, executor): + return isinstance(executor, self.types) + + +class _ExecutorTextMatch(collections.namedtuple('_ExecutorTextMatch', + ['strings', 'executor_cls'])): + def matches(self, text): + return text.lower() in self.strings class ParallelActionEngine(ActionEngine): """Engine that runs tasks in parallel manner.""" _storage_factory = atom_storage.MultiThreadedStorage - @misc.cachedproperty - def _task_executor(self): - kwargs = { - 'executor': self._options.get('executor'), - 'max_workers': self._options.get('max_workers'), - } - # The 
reason we use the library/built-in futures is to allow for - # instances of that to be detected and handled correctly, instead of - # forcing everyone to use our derivatives... - if isinstance(kwargs['executor'], futures.ProcessPoolExecutor): - executor_cls = executor.ParallelProcessTaskExecutor - kwargs['dispatch_periodicity'] = self._options.get( - 'dispatch_periodicity') - else: - executor_cls = executor.ParallelThreadTaskExecutor + # One of these types should match when an object (non-string) is provided + # for the 'executor' option. + # + # NOTE(harlowja): the reason we use the library/built-in futures is to + # allow for instances of that to be detected and handled correctly, instead + # of forcing everyone to use our derivatives... + _executor_cls_matchers = [ + _ExecutorTypeMatch((futures.ThreadPoolExecutor,), + executor.ParallelThreadTaskExecutor), + _ExecutorTypeMatch((futures.ProcessPoolExecutor,), + executor.ParallelProcessTaskExecutor), + _ExecutorTypeMatch((futures.Executor,), + executor.ParallelThreadTaskExecutor), + ] + + # One of these should match when a string/text is provided for the + # 'executor' option (a mixed case equivalent is allowed since the match + # will be lower-cased before checking). + _executor_str_matchers = [ + _ExecutorTextMatch(frozenset(['processes', 'process']), + executor.ParallelProcessTaskExecutor), + _ExecutorTextMatch(frozenset(['thread', 'threads', 'threaded']), + executor.ParallelThreadTaskExecutor), + ] + + # Used when no executor is provided (either a string or object)... + _default_executor_cls = executor.ParallelThreadTaskExecutor + + def __init__(self, flow, flow_detail, backend, options): + super(ParallelActionEngine, self).__init__(flow, flow_detail, + backend, options) + # This ensures that any provided executor will be validated before + # we get too far in the compilation/execution pipeline... 
+ self._task_executor = self._fetch_task_executor(self._options) + + @classmethod + def _fetch_task_executor(cls, options): + kwargs = {} + executor_cls = cls._default_executor_cls + # Match the desired executor to a class that will work with it... + desired_executor = options.get('executor') + if isinstance(desired_executor, six.string_types): + matched_executor_cls = None + for m in cls._executor_str_matchers: + if m.matches(desired_executor): + matched_executor_cls = m.executor_cls + break + if matched_executor_cls is None: + expected = set() + for m in cls._executor_str_matchers: + expected.update(m.strings) + raise ValueError("Unknown executor string '%s' expected" + " one of %s (or mixed case equivalent)" + % (desired_executor, list(expected))) + else: + executor_cls = matched_executor_cls + elif desired_executor is not None: + matched_executor_cls = None + for m in cls._executor_cls_matchers: + if m.matches(desired_executor): + matched_executor_cls = m.executor_cls + break + if matched_executor_cls is None: + expected = set() + for m in cls._executor_cls_matchers: + expected.update(m.types) + raise TypeError("Unknown executor type '%s' expected an" + " instance of %s" % (type(desired_executor), + list(expected))) + else: + executor_cls = matched_executor_cls + kwargs['executor'] = desired_executor + for k in getattr(executor_cls, 'OPTIONS', []): + if k == 'executor': + continue + try: + kwargs[k] = options[k] + except KeyError: + pass return executor_cls(**kwargs) diff --git a/taskflow/engines/action_engine/executor.py b/taskflow/engines/action_engine/executor.py index d110c313..8a31fddf 100644 --- a/taskflow/engines/action_engine/executor.py +++ b/taskflow/engines/action_engine/executor.py @@ -373,6 +373,8 @@ class ParallelTaskExecutor(TaskExecutor): to concurrent.Futures.Executor. 
""" + OPTIONS = frozenset(['max_workers']) + def __init__(self, executor=None, max_workers=None): self._executor = executor self._max_workers = max_workers @@ -429,6 +431,8 @@ class ParallelProcessTaskExecutor(ParallelTaskExecutor): the parent are executed on events in the child. """ + OPTIONS = frozenset(['max_workers', 'dispatch_periodicity']) + def __init__(self, executor=None, max_workers=None, dispatch_periodicity=None): super(ParallelProcessTaskExecutor, self).__init__( diff --git a/taskflow/engines/worker_based/engine.py b/taskflow/engines/worker_based/engine.py index df915fc9..8011222c 100644 --- a/taskflow/engines/worker_based/engine.py +++ b/taskflow/engines/worker_based/engine.py @@ -18,7 +18,6 @@ from taskflow.engines.action_engine import engine from taskflow.engines.worker_based import executor from taskflow.engines.worker_based import protocol as pr from taskflow import storage as t_storage -from taskflow.utils import misc class WorkerBasedActionEngine(engine.ActionEngine): @@ -45,17 +44,30 @@ class WorkerBasedActionEngine(engine.ActionEngine): _storage_factory = t_storage.SingleThreadedStorage - @misc.cachedproperty - def _task_executor(self): + def __init__(self, flow, flow_detail, backend, options): + super(WorkerBasedActionEngine, self).__init__(flow, flow_detail, + backend, options) + # This ensures that any provided executor will be validated before + # we get to far in the compilation/execution pipeline... 
+ self._task_executor = self._fetch_task_executor(self._options, + self._flow_detail) + + @classmethod + def _fetch_task_executor(cls, options, flow_detail): try: - return self._options['executor'] + e = options['executor'] + if not isinstance(e, executor.WorkerTaskExecutor): + raise TypeError("Expected an instance of type '%s' instead of" + " type '%s' for 'executor' option" + % (executor.WorkerTaskExecutor, type(e))) + return e except KeyError: return executor.WorkerTaskExecutor( - uuid=self._flow_detail.uuid, - url=self._options.get('url'), - exchange=self._options.get('exchange', 'default'), - topics=self._options.get('topics', []), - transport=self._options.get('transport'), - transport_options=self._options.get('transport_options'), - transition_timeout=self._options.get('transition_timeout', - pr.REQUEST_TIMEOUT)) + uuid=flow_detail.uuid, + url=options.get('url'), + exchange=options.get('exchange', 'default'), + topics=options.get('topics', []), + transport=options.get('transport'), + transport_options=options.get('transport_options'), + transition_timeout=options.get('transition_timeout', + pr.REQUEST_TIMEOUT)) diff --git a/taskflow/tests/unit/action_engine/test_creation.py b/taskflow/tests/unit/action_engine/test_creation.py new file mode 100644 index 00000000..c8a0b436 --- /dev/null +++ b/taskflow/tests/unit/action_engine/test_creation.py @@ -0,0 +1,80 @@ +# -*- coding: utf-8 -*- + +# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the +# License for the specific language governing permissions and limitations +# under the License. + +import testtools + +from taskflow.engines.action_engine import engine +from taskflow.engines.action_engine import executor +from taskflow.patterns import linear_flow as lf +from taskflow.persistence import backends +from taskflow import test +from taskflow.tests import utils +from taskflow.types import futures as futures +from taskflow.utils import async_utils as au +from taskflow.utils import persistence_utils as pu + + +class ParallelCreationTest(test.TestCase): + @staticmethod + def _create_engine(**kwargs): + flow = lf.Flow('test-flow').add(utils.DummyTask()) + backend = backends.fetch({'connection': 'memory'}) + flow_detail = pu.create_flow_detail(flow, backend=backend) + options = kwargs.copy() + return engine.ParallelActionEngine(flow, flow_detail, + backend, options) + + def test_thread_string_creation(self): + for s in ['threads', 'threaded', 'thread']: + eng = self._create_engine(executor=s) + self.assertIsInstance(eng._task_executor, + executor.ParallelThreadTaskExecutor) + + def test_process_string_creation(self): + for s in ['process', 'processes']: + eng = self._create_engine(executor=s) + self.assertIsInstance(eng._task_executor, + executor.ParallelProcessTaskExecutor) + + def test_thread_executor_creation(self): + with futures.ThreadPoolExecutor(1) as e: + eng = self._create_engine(executor=e) + self.assertIsInstance(eng._task_executor, + executor.ParallelThreadTaskExecutor) + + def test_process_executor_creation(self): + with futures.ProcessPoolExecutor(1) as e: + eng = self._create_engine(executor=e) + self.assertIsInstance(eng._task_executor, + executor.ParallelProcessTaskExecutor) + + @testtools.skipIf(not au.EVENTLET_AVAILABLE, 'eventlet is not available') + def test_green_executor_creation(self): + with futures.GreenThreadPoolExecutor(1) as e: + eng = self._create_engine(executor=e) + self.assertIsInstance(eng._task_executor, + 
executor.ParallelThreadTaskExecutor) + + def test_sync_executor_creation(self): + with futures.SynchronousExecutor() as e: + eng = self._create_engine(executor=e) + self.assertIsInstance(eng._task_executor, + executor.ParallelThreadTaskExecutor) + + def test_invalid_creation(self): + self.assertRaises(ValueError, self._create_engine, executor='crap') + self.assertRaises(TypeError, self._create_engine, executor=2) + self.assertRaises(TypeError, self._create_engine, executor=object()) diff --git a/taskflow/tests/unit/worker_based/test_engine.py b/taskflow/tests/unit/worker_based/test_creation.py similarity index 50% rename from taskflow/tests/unit/worker_based/test_engine.py rename to taskflow/tests/unit/worker_based/test_creation.py index f274a829..6764926a 100644 --- a/taskflow/tests/unit/worker_based/test_engine.py +++ b/taskflow/tests/unit/worker_based/test_creation.py @@ -15,7 +15,9 @@ # under the License. from taskflow.engines.worker_based import engine +from taskflow.engines.worker_based import executor from taskflow.patterns import linear_flow as lf +from taskflow.persistence import backends from taskflow import test from taskflow.test import mock from taskflow.tests import utils @@ -23,24 +25,25 @@ from taskflow.utils import persistence_utils as pu class TestWorkerBasedActionEngine(test.MockTestCase): + @staticmethod + def _create_engine(**kwargs): + flow = lf.Flow('test-flow').add(utils.DummyTask()) + backend = backends.fetch({'connection': 'memory'}) + flow_detail = pu.create_flow_detail(flow, backend=backend) + options = kwargs.copy() + return engine.WorkerBasedActionEngine(flow, flow_detail, + backend, options) - def setUp(self): - super(TestWorkerBasedActionEngine, self).setUp() - self.broker_url = 'test-url' - self.exchange = 'test-exchange' - self.topics = ['test-topic1', 'test-topic2'] - - # patch classes - self.executor_mock, self.executor_inst_mock = self.patchClass( + def _patch_in_executor(self): + executor_mock, executor_inst_mock = 
self.patchClass( engine.executor, 'WorkerTaskExecutor', attach_as='executor') + return executor_mock, executor_inst_mock def test_creation_default(self): - flow = lf.Flow('test-flow').add(utils.DummyTask()) - _, flow_detail = pu.temporary_flow_detail() - engine.WorkerBasedActionEngine(flow, flow_detail, None, {}).compile() - + executor_mock, executor_inst_mock = self._patch_in_executor() + eng = self._create_engine() expected_calls = [ - mock.call.executor_class(uuid=flow_detail.uuid, + mock.call.executor_class(uuid=eng.storage.flow_uuid, url=None, exchange='default', topics=[], @@ -51,21 +54,34 @@ class TestWorkerBasedActionEngine(test.MockTestCase): self.assertEqual(self.master_mock.mock_calls, expected_calls) def test_creation_custom(self): - flow = lf.Flow('test-flow').add(utils.DummyTask()) - _, flow_detail = pu.temporary_flow_detail() - config = {'url': self.broker_url, 'exchange': self.exchange, - 'topics': self.topics, 'transport': 'memory', - 'transport_options': {}, 'transition_timeout': 200} - engine.WorkerBasedActionEngine( - flow, flow_detail, None, config).compile() - + executor_mock, executor_inst_mock = self._patch_in_executor() + topics = ['test-topic1', 'test-topic2'] + exchange = 'test-exchange' + broker_url = 'test-url' + eng = self._create_engine( + url=broker_url, + exchange=exchange, + transport='memory', + transport_options={}, + transition_timeout=200, + topics=topics) expected_calls = [ - mock.call.executor_class(uuid=flow_detail.uuid, - url=self.broker_url, - exchange=self.exchange, - topics=self.topics, + mock.call.executor_class(uuid=eng.storage.flow_uuid, + url=broker_url, + exchange=exchange, + topics=topics, transport='memory', transport_options={}, transition_timeout=200) ] self.assertEqual(self.master_mock.mock_calls, expected_calls) + + def test_creation_custom_executor(self): + ex = executor.WorkerTaskExecutor('a', 'test-exchange', ['test-topic']) + eng = self._create_engine(executor=ex) + self.assertIs(eng._task_executor, ex) + 
self.assertIsInstance(eng._task_executor, executor.WorkerTaskExecutor) + + def test_creation_invalid_custom_executor(self): + self.assertRaises(TypeError, self._create_engine, executor=2) + self.assertRaises(TypeError, self._create_engine, executor='blah')