diff --git a/taskflow/engines/action_engine/compiler.py b/taskflow/engines/action_engine/compiler.py index fb81ba80d..2486b32e2 100644 --- a/taskflow/engines/action_engine/compiler.py +++ b/taskflow/engines/action_engine/compiler.py @@ -20,7 +20,6 @@ import threading from taskflow import exceptions as exc from taskflow import flow from taskflow import logging -from taskflow import retry from taskflow import task from taskflow.types import graph as gr from taskflow.types import tree as tr @@ -281,34 +280,22 @@ class PatternCompiler(object): self._freeze = freeze self._lock = threading.Lock() self._compilation = None + self._flatten_matchers = [ + ((flow.Flow,), self._flatten_flow), + ((task.BaseTask,), self._flatten_task), + ] def _flatten(self, item, parent): """Flattens a item (pattern, task) into a graph + tree node.""" - functor = self._find_flattener(item, parent) + functor = misc.match_type_handler(item, self._flatten_matchers) + if not functor: + raise TypeError("Unknown item '%s' (%s) requested to flatten" + % (item, type(item))) self._pre_item_flatten(item) graph, node = functor(item, parent) self._post_item_flatten(item, graph, node) return graph, node - def _find_flattener(self, item, parent): - """Locates the flattening function to use to flatten the given item.""" - if isinstance(item, flow.Flow): - return self._flatten_flow - elif isinstance(item, task.BaseTask): - return self._flatten_task - elif isinstance(item, retry.Retry): - if parent is None: - raise TypeError("Retry controller '%s' (%s) must only be used" - " as a flow constructor parameter and not as a" - " root component" % (item, type(item))) - else: - raise TypeError("Retry controller '%s' (%s) must only be used" - " as a flow constructor parameter and not as a" - " flow added component" % (item, type(item))) - else: - raise TypeError("Unknown item '%s' (%s) requested to flatten" - % (item, type(item))) - def _connect_retry(self, retry, graph): graph.add_node(retry) diff --git a/taskflow/storage.py b/taskflow/storage.py index e96874acb..af29218bd 100644 --- a/taskflow/storage.py +++ b/taskflow/storage.py @@ -132,6 +132,10 @@ class Storage(object): self._transients = {} self._injected_args = {} self._lock = lock_utils.ReaderWriterLock() + self._ensure_matchers = [ + ((task.BaseTask,), self._ensure_task), + ((retry.Retry,), self._ensure_retry), + ] # NOTE(imelnikov): failure serialization looses information, # so we cache failures here, in atom name -> failure mapping. @@ -168,17 +172,14 @@ class Storage(object): Returns uuid for the atomdetail that is/was created. """ - if isinstance(atom, task.BaseTask): - return self._ensure_task(atom.name, - misc.get_version_string(atom), - atom.save_as) - elif isinstance(atom, retry.Retry): - return self._ensure_retry(atom.name, - misc.get_version_string(atom), - atom.save_as) + functor = misc.match_type_handler(atom, self._ensure_matchers) + if not functor: + raise TypeError("Unknown item '%s' (%s) requested to ensure" + % (atom, type(atom))) else: - raise TypeError("Object of type 'atom' expected not" - " '%s' (%s)" % (atom, type(atom))) + return functor(atom.name, + misc.get_version_string(atom), + atom.save_as) def _ensure_task(self, task_name, task_version, result_mapping): """Ensures there is a taskdetail that corresponds to the task info. diff --git a/taskflow/tests/unit/action_engine/test_compile.py b/taskflow/tests/unit/action_engine/test_compile.py index a290c50be..445fd7c17 100644 --- a/taskflow/tests/unit/action_engine/test_compile.py +++ b/taskflow/tests/unit/action_engine/test_compile.py @@ -34,9 +34,7 @@ class PatternCompileTest(test.TestCase): def test_retry(self): r = retry.AlwaysRevert('r1') - msg_regex = "^Retry controller .* must only be used .*" - self.assertRaisesRegexp(TypeError, msg_regex, - compiler.PatternCompiler(r).compile) + self.assertRaises(TypeError, compiler.PatternCompiler(r).compile) def test_wrong_object(self): msg_regex = '^Unknown item .* requested to flatten' diff --git a/taskflow/tests/unit/test_retries.py b/taskflow/tests/unit/test_retries.py index b459184b4..edcc6d8b5 100644 --- a/taskflow/tests/unit/test_retries.py +++ b/taskflow/tests/unit/test_retries.py @@ -377,12 +377,12 @@ class RetryTest(utils.EngineTestBase): def test_run_just_retry(self): flow = utils.OneReturnRetry(provides='x') engine = self._make_engine(flow) - self.assertRaisesRegexp(TypeError, 'Retry controller', engine.run) + self.assertRaises(TypeError, engine.run) def test_use_retry_as_a_task(self): flow = lf.Flow('test').add(utils.OneReturnRetry(provides='x')) engine = self._make_engine(flow) - self.assertRaisesRegexp(TypeError, 'Retry controller', engine.run) + self.assertRaises(TypeError, engine.run) def test_resume_flow_that_had_been_interrupted_during_retrying(self): flow = lf.Flow('flow-1', retry.Times(3, 'r1')).add( diff --git a/taskflow/utils/misc.py b/taskflow/utils/misc.py index 39708b5cd..a3bbe274b 100644 --- a/taskflow/utils/misc.py +++ b/taskflow/utils/misc.py @@ -87,6 +87,18 @@ def find_monotonic(allow_time_time=False): return None +def match_type_handler(item, type_handlers): + """Matches a given items type using the given match types + handlers. + + Returns the handler if a type match occurs, otherwise none. + """ + for (match_types, handler_func) in type_handlers: + if isinstance(item, match_types): + return handler_func + else: + return None + + def countdown_iter(start_at, decr=1): """Generator that decrements after each generation until <= zero.