From 55110111851939cef650ff400f57598f7fb484d2 Mon Sep 17 00:00:00 2001 From: Min Pae Date: Sun, 15 Feb 2015 16:55:06 -0800 Subject: [PATCH] adding check for str/unicode type in requires When the requires argument for an Atom is passed in as a string, each character of the string is iterated over to build up a requirement list. This works for simple one letter argument names but not for long argument names. Added check for str and unicode types to prevent iterating over a string. Change-Id: Ida584221b48966d26935fb2ede0075aabb7ce972 --- taskflow/atom.py | 5 ++++- taskflow/tests/unit/test_arguments_passing.py | 10 ++++++++++ taskflow/tests/unit/worker_based/test_worker.py | 2 +- taskflow/tests/utils.py | 6 ++++++ 4 files changed, 21 insertions(+), 2 deletions(-) diff --git a/taskflow/atom.py b/taskflow/atom.py index 1c5e61ef..82f7a5e3 100644 --- a/taskflow/atom.py +++ b/taskflow/atom.py @@ -100,7 +100,10 @@ def _build_arg_mapping(atom_name, reqs, rebind_args, function, do_infer, required = {} # add reqs to required mappings if reqs: - required.update((a, a) for a in reqs) + if isinstance(reqs, six.string_types): + required.update({reqs: reqs}) + else: + required.update((a, a) for a in reqs) # add req_args to required mappings if do_infer is set if do_infer: diff --git a/taskflow/tests/unit/test_arguments_passing.py b/taskflow/tests/unit/test_arguments_passing.py index fb4744bd..c84d8534 100644 --- a/taskflow/tests/unit/test_arguments_passing.py +++ b/taskflow/tests/unit/test_arguments_passing.py @@ -149,6 +149,16 @@ class ArgumentsPassingTest(utils.EngineTestBase): utils.TaskOneArg, rebind=object()) + def test_long_arg_name(self): + flow = utils.LongArgNameTask(requires='long_arg_name', + provides='result') + engine = self._make_engine(flow) + engine.storage.inject({'long_arg_name': 1}) + engine.run() + self.assertEqual(engine.storage.fetch_all(), { + 'long_arg_name': 1, 'result': 1 + }) + class SingleThreadedEngineTest(ArgumentsPassingTest, test.TestCase): diff --git a/taskflow/tests/unit/worker_based/test_worker.py b/taskflow/tests/unit/worker_based/test_worker.py index 7020a931..597a64a4 100644 --- a/taskflow/tests/unit/worker_based/test_worker.py +++ b/taskflow/tests/unit/worker_based/test_worker.py @@ -34,7 +34,7 @@ class TestWorker(test.MockTestCase): self.exchange = 'test-exchange' self.topic = 'test-topic' self.threads_count = 5 - self.endpoint_count = 22 + self.endpoint_count = 23 # patch classes self.executor_mock, self.executor_inst_mock = self.patchClass( diff --git a/taskflow/tests/utils.py b/taskflow/tests/utils.py index fbbb83c7..9e6e2a34 100644 --- a/taskflow/tests/utils.py +++ b/taskflow/tests/utils.py @@ -95,6 +95,12 @@ class FakeTask(object): pass +class LongArgNameTask(task.Task): + + def execute(self, long_arg_name): + return long_arg_name + + if six.PY3: RUNTIME_ERROR_CLASSES = ['RuntimeError', 'Exception', 'BaseException', 'object']