diff --git a/taskflow/engines/action_engine/compiler.py b/taskflow/engines/action_engine/compiler.py index e27b1f8f..ba8a1ba2 100644 --- a/taskflow/engines/action_engine/compiler.py +++ b/taskflow/engines/action_engine/compiler.py @@ -26,6 +26,7 @@ from taskflow import task from taskflow.types import graph as gr from taskflow.types import tree as tr from taskflow.utils import iter_utils +from taskflow.utils import misc from taskflow.flow import (LINK_INVARIANT, LINK_RETRY) # noqa @@ -104,10 +105,6 @@ def _add_update_edges(graph, nodes_from, nodes_to, attr_dict=None): class TaskCompiler(object): """Non-recursive compiler of tasks.""" - @staticmethod - def handles(obj): - return isinstance(obj, task.BaseTask) - def compile(self, task, parent=None): graph = gr.DiGraph(name=task.name) graph.add_node(task, kind=TASK) @@ -120,10 +117,6 @@ class TaskCompiler(object): class FlowCompiler(object): """Recursive compiler of flows.""" - @staticmethod - def handles(obj): - return isinstance(obj, flow.Flow) - def __init__(self, deep_compiler_func): self._deep_compiler_func = deep_compiler_func @@ -274,17 +267,20 @@ class PatternCompiler(object): self._freeze = freeze self._lock = threading.Lock() self._compilation = None - self._matchers = (FlowCompiler(self._compile), TaskCompiler()) + self._matchers = [ + (flow.Flow, FlowCompiler(self._compile)), + (task.BaseTask, TaskCompiler()), + ] self._level = 0 def _compile(self, item, parent=None): """Compiles a item (pattern, task) into a graph + tree node.""" - for m in self._matchers: - if m.handles(item): - self._pre_item_compile(item) - graph, node = m.compile(item, parent=parent) - self._post_item_compile(item, graph, node) - return graph, node + item_compiler = misc.match_type(item, self._matchers) + if item_compiler is not None: + self._pre_item_compile(item) + graph, node = item_compiler.compile(item, parent=parent) + self._post_item_compile(item, graph, node) + return graph, node else: raise TypeError("Unknown object '%s' (%s) requested to compile" % (item, type(item)))