diff --git a/taskflow/engines/action_engine/runtime.py b/taskflow/engines/action_engine/runtime.py index 0fba861d3..061cca4cb 100644 --- a/taskflow/engines/action_engine/runtime.py +++ b/taskflow/engines/action_engine/runtime.py @@ -52,18 +52,25 @@ class Runtime(object): progress=0.0), 'retry': self.retry_action.change_state, } + schedulers = { + 'retry': self.retry_scheduler, + 'task': self.task_scheduler, + } for atom in self.analyzer.iterate_all_nodes(): metadata = {} walker = sc.ScopeWalker(self.compilation, atom, names_only=True) if isinstance(atom, task.BaseTask): check_transition_handler = st.check_task_transition change_state_handler = change_state_handlers['task'] + scheduler = schedulers['task'] else: check_transition_handler = st.check_retry_transition change_state_handler = change_state_handlers['retry'] + scheduler = schedulers['retry'] metadata['scope_walker'] = walker metadata['check_transition_handler'] = check_transition_handler metadata['change_state_handler'] = change_state_handler + metadata['scheduler'] = scheduler self._atom_cache[atom.name] = metadata @property @@ -90,6 +97,14 @@ class Runtime(object): def scheduler(self): return sched.Scheduler(self) + @misc.cachedproperty + def task_scheduler(self): + return sched.TaskScheduler(self) + + @misc.cachedproperty + def retry_scheduler(self): + return sched.RetryScheduler(self) + @misc.cachedproperty def retry_action(self): return ra.RetryAction(self._storage, @@ -110,6 +125,14 @@ class Runtime(object): check_transition_handler = metadata['check_transition_handler'] return check_transition_handler(current_state, target_state) + def fetch_scheduler(self, atom): + """Fetches the cached specific scheduler for the given atom.""" + # This does not check if the name exists (since this is only used + # internally to the engine, and is not exposed to atoms that will + # not exist and therefore doesn't need to handle that case). + metadata = self._atom_cache[atom.name] + return metadata['scheduler'] + def fetch_scopes_for(self, atom_name): """Fetches a walker of the visible scopes for the given atom.""" try: diff --git a/taskflow/engines/action_engine/scheduler.py b/taskflow/engines/action_engine/scheduler.py index 202218306..4ab0b0e14 100644 --- a/taskflow/engines/action_engine/scheduler.py +++ b/taskflow/engines/action_engine/scheduler.py @@ -17,22 +17,18 @@ import weakref from taskflow import exceptions as excp -from taskflow import retry as retry_atom from taskflow import states as st -from taskflow import task as task_atom from taskflow.types import failure -class _RetryScheduler(object): +class RetryScheduler(object): + """Schedules retry atoms.""" + def __init__(self, runtime): self._runtime = weakref.proxy(runtime) self._retry_action = runtime.retry_action self._storage = runtime.storage - @staticmethod - def handles(atom): - return isinstance(atom, retry_atom.Retry) - def schedule(self, retry): """Schedules the given retry atom for *future* completion. @@ -53,15 +49,13 @@ class _RetryScheduler(object): " intention: %s" % intention) -class _TaskScheduler(object): +class TaskScheduler(object): + """Schedules task atoms.""" + def __init__(self, runtime): self._storage = runtime.storage self._task_action = runtime.task_action - @staticmethod - def handles(atom): - return isinstance(atom, task_atom.BaseTask) - def schedule(self, task): """Schedules the given task atom for *future* completion. @@ -79,39 +73,28 @@ class _TaskScheduler(object): class Scheduler(object): - """Schedules atoms using actions to schedule.""" + """Safely schedules atoms using a runtime ``fetch_scheduler`` routine.""" def __init__(self, runtime): - self._schedulers = [ - _RetryScheduler(runtime), - _TaskScheduler(runtime), - ] + self._fetch_scheduler = runtime.fetch_scheduler - def _schedule_node(self, node): - """Schedule a single node for execution.""" - for sched in self._schedulers: - if sched.handles(node): - return sched.schedule(node) - else: - raise TypeError("Unknown how to schedule '%s' (%s)" - % (node, type(node))) + def schedule(self, atoms): + """Schedules the provided atoms for *future* completion. - def schedule(self, nodes): - """Schedules the provided nodes for *future* completion. - - This method should schedule a future for each node provided and return + This method should schedule a future for each atom provided and return a set of those futures to be waited on (or used for other similar purposes). It should also return any failure objects that represented scheduling failures that may have occurred during this scheduling process. """ futures = set() - for node in nodes: + for atom in atoms: + scheduler = self._fetch_scheduler(atom) try: - futures.add(self._schedule_node(node)) + futures.add(scheduler.schedule(atom)) except Exception: # Immediately stop scheduling future work so that we can - # exit execution early (rather than later) if a single task + # exit execution early (rather than later) if a single atom # fails to schedule correctly. return (futures, [failure.Failure()]) return (futures, [])