Integrate better locking and a runner helper class.
Ensure that while a linear flow (or one of its derivatives) is running it cannot be modified by another thread, by putting a lock around the sensitive functions. Also, instead of using the raw task objects themselves, integrate a helper 'runner' class that provides useful functionality before its task runs, as well as member variables associated with the contained task. This helper class currently provides the following:

- A uuid that is returned to callers of the add method to identify their task (and later its results), allowing multiple instances of the same task to be added.
- Automatic extraction of the required and optional inputs needed by the contained task.

Change-Id: Ib01939a4726155a629e4b4703656b9067868d8f3
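To illustrate the caller-facing effect of the runner change (a minimal sketch only; the flow name and the task objects below are illustrative and not part of this change), callers now keep the uuid handed back by add() rather than the task object itself:

    from taskflow.patterns import linear_flow

    flow = linear_flow.Flow('example-flow')
    # The same task object can now be added more than once; each call to
    # add() returns a distinct uuid for the contained runner.
    uuid_a = flow.add(make_volume_task)
    uuid_b = flow.add(make_volume_task)
    uuid_c = flow.add(attach_volume_task)
    flow.run({})
    # The uuids identify each addition (and later its results).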
@@ -96,6 +96,16 @@ def wraps(fn):
    return wrapper


def locked(f):

    @wraps(f)
    def wrapper(self, *args, **kwargs):
        with self._lock:
            return f(self, *args, **kwargs)

    return wrapper


def task(*args, **kwargs):
    """Decorates a given function and ensures that all needed attributes of
    that function are set so that the function can be used as a task."""

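The locked decorator above assumes the decorated method belongs to an object that owns a _lock attribute (Flow now creates a threading.RLock in its constructor for exactly this). A minimal sketch of that pattern, with an illustrative class name:

    import threading

    from taskflow import decorators


    class Counter(object):
        def __init__(self):
            self._lock = threading.RLock()
            self.value = 0

        @decorators.locked
        def increment(self):
            # The body executes with self._lock held, so concurrent callers
            # are serialized, just like the locked Flow methods below.
            self.value += 1

Because the lock is an RLock, one locked method can safely call another locked method on the same object without deadlocking.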
@@ -53,3 +53,13 @@ class UnclaimableJobException(TaskFlowException):
class JobNotFound(TaskFlowException):
    """Raised when a job entry can not be found."""
    pass


class MissingDependencies(InvalidStateException):
    """Raised when a task has dependencies that can not be satisfied."""
    message = ("%(task)s requires %(requirements)s but no other task produces"
               " said requirements")

    def __init__(self, task, requirements):
        message = self.message % {'task': task, 'requirements': requirements}
        super(MissingDependencies, self).__init__(message)

@@ -17,14 +17,15 @@
# under the License.

import abc
import threading

from taskflow import decorators
from taskflow import exceptions as exc
from taskflow import states


class Flow(object):
    """The base abstract class of all flow implementations."""

    __metaclass__ = abc.ABCMeta

    RESETTABLE_STATES = set([
@@ -59,6 +60,9 @@ class Flow(object):
        # can be implemented when you can track a flow's progress.
        self.task_listeners = []
        self.listeners = []
        # Ensure that modifications and/or multiple runs aren't happening
        # at the same time in the same flow.
        self._lock = threading.RLock()

    @property
    def state(self):
@@ -66,13 +70,21 @@ class Flow(object):
        return self._state

    def _change_state(self, context, new_state):
        if self.state != new_state:
            old_state = self.state
            self._state = new_state
        was_changed = False
        old_state = self.state
        with self._lock:
            if self.state != new_state:
                old_state = self.state
                self._state = new_state
                was_changed = True
        if was_changed:
            # Don't notify while holding the lock.
            self._on_flow_state_change(context, old_state)

    def __str__(self):
        return "Flow: %s" % (self.name)
        lines = ["Flow: %s" % (self.name)]
        lines.append(" State: %s" % (self.state))
        return "\n".join(lines)

    def _on_flow_state_change(self, context, old_state):
        # Notify any listeners that the internal state has changed.
@@ -96,12 +108,18 @@ class Flow(object):
    @abc.abstractmethod
    def add(self, task):
        """Adds a given task to this flow."""
        """Adds a given task to this flow.

        Returns the uuid that is associated with the task for later operations
        before and after it is run."""
        raise NotImplementedError()

    @abc.abstractmethod
    def add_many(self, tasks):
        """Adds many tasks to this flow."""
        """Adds many tasks to this flow.

        Returns a list of uuids (one for each task added).
        """
        raise NotImplementedError()

    def interrupt(self):
@@ -113,8 +131,15 @@ class Flow(object):
        if self.state in self.UNINTERRUPTIBLE_STATES:
            raise exc.InvalidStateException(("Can not interrupt when"
                                             " in state %s") % (self.state))
        self._change_state(None, states.INTERRUPTED)
        # Note(harlowja): Do *not* acquire the lock here so that the flow may
        # be interrupted while running. This does mean that the above check
        # may not be valid but we can worry about that if it becomes an issue.
        old_state = self.state
        if old_state != states.INTERRUPTED:
            self._state = states.INTERRUPTED
            self._on_flow_state_change(None, old_state)

    @decorators.locked
    def reset(self):
        """Fully resets the internal state of this flow, allowing for the flow
        to be run again. *Listeners are also reset*"""
@@ -125,6 +150,7 @@ class Flow(object):
        self.listeners = []
        self._change_state(None, states.PENDING)

    @decorators.locked
    def soft_reset(self):
        """Partially resets the internal state of this flow, allowing for the
        flow to be run again from an interrupted state *only*"""
@@ -133,14 +159,15 @@ class Flow(object):
                                            " in state %s") % (self.state))
        self._change_state(None, states.PENDING)

    @decorators.locked
    def run(self, context, *args, **kwargs):
        """Executes the workflow."""
        if self.state not in self.RUNNABLE_STATES:
            raise exc.InvalidStateException("Unable to run flow when "
                                            "in state %s" % (self.state))

    @abc.abstractmethod
    @decorators.locked
    def rollback(self, context, cause):
        """Performs rollback of this workflow and any attached parent workflows
        if present."""
        raise NotImplementedError()
        pass

@@ -23,6 +23,7 @@ from networkx.algorithms import dag
from networkx.classes import digraph
from networkx import exception as g_exc

from taskflow import decorators
from taskflow import exceptions as exc
from taskflow.patterns import linear_flow
from taskflow import utils
@@ -35,25 +36,25 @@ class Flow(linear_flow.Flow):
    a linear topological ordering (and reverse using the same linear
    topological order)"""

    def __init__(self, name, parents=None, allow_same_inputs=True):
    def __init__(self, name, parents=None):
        super(Flow, self).__init__(name, parents)
        self._graph = digraph.DiGraph()
        self._connected = False
        self._allow_same_inputs = allow_same_inputs

    @decorators.locked
    def add(self, task):
        # Only insert the node to start, connect all the edges
        # together later after all nodes have been added since if we try
        # to infer the edges at this stage we likely will fail finding
        # dependencies from nodes that don't exist.
        assert isinstance(task, collections.Callable)
        if not self._graph.has_node(task):
            self._graph.add_node(task)
            self._connected = False
        r = utils.Runner(task)
        self._graph.add_node(r, uuid=r.uuid)
        self._runners = []
        return r.uuid

    def add_many(self, tasks):
        for t in tasks:
            self.add(t)
    def _add_dependency(self, provider, requirer):
        if not self._graph.has_edge(provider, requirer):
            self._graph.add_edge(provider, requirer)

    def __str__(self):
        lines = ["GraphFlow: %s" % (self.name)]
@@ -63,46 +64,9 @@ class Flow(linear_flow.Flow):
        lines.append(" State: %s" % (self.state))
        return "\n".join(lines)

    def _fetch_task_inputs(self, task):

        def extract_inputs(place_where, would_like, is_optional=False):
            for n in would_like:
                for (them, there_result) in self.results:
                    they_provide = utils.get_attr(them, 'provides', [])
                    if n not in set(they_provide):
                        continue
                    if ((not is_optional and
                         not self._graph.has_edge(them, task))):
                        continue
                    if there_result and n in there_result:
                        place_where[n].append(there_result[n])
                        if is_optional:
                            # Take the first task that provides this optional
                            # item.
                            break
                    elif not is_optional:
                        place_where[n].append(None)

        required_inputs = set(utils.get_attr(task, 'requires', []))
        optional_inputs = set(utils.get_attr(task, 'optional', []))
        optional_inputs = optional_inputs - required_inputs

        task_inputs = collections.defaultdict(list)
        extract_inputs(task_inputs, required_inputs)
        extract_inputs(task_inputs, optional_inputs, is_optional=True)

        def collapse_functor(k_v):
            (k, v) = k_v
            if len(v) == 1:
                v = v[0]
            return (k, v)

        return dict(map(collapse_functor, task_inputs.iteritems()))

    def _ordering(self):
        self._connect()
        try:
            return dag.topological_sort(self._graph)
            return self._connect()
        except g_exc.NetworkXUnfeasible:
            raise exc.InvalidStateException("Unable to correctly determine "
                                            "the path through the provided "
@@ -110,49 +74,45 @@ class Flow(linear_flow.Flow):
                                            "tasks needed inputs and outputs.")

    def _connect(self):
        """Connects the nodes & edges of the graph together."""
        if self._connected or len(self._graph) == 0:
            return
        """Connects the nodes & edges of the graph together by examining the
        requirements of each node and finding another node that will
        create said dependency."""
        if len(self._graph) == 0:
            return []
        if self._runners:
            return self._runners

        # Figure out the providers of items and the requirers of items.
        provides_what = collections.defaultdict(list)
        requires_what = collections.defaultdict(list)
        for t in self._graph.nodes_iter():
            for r in utils.get_attr(t, 'requires', []):
                requires_what[r].append(t)
            for p in utils.get_attr(t, 'provides', []):
                provides_what[p].append(t)
        # Link providers to requirers.
        #
        # TODO(harlowja): allow for developers to manually establish these
        # connections instead of automatically doing it for them??
        for n in self._graph.nodes_iter():
            n_requires = set(utils.get_attr(n.task, 'requires', []))
            LOG.debug("Finding providers of %s for %s", n_requires, n)
            for p in self._graph.nodes_iter():
                if not n_requires:
                    break
                if n is p:
                    continue
                p_provides = set(utils.get_attr(p.task, 'provides', []))
                p_satisfies = n_requires & p_provides
                if p_satisfies:
                    # P produces for N so that's why we link P->N and not N->P
                    self._add_dependency(p, n)
                    for k in p_satisfies:
                        n.providers[k] = p
                    LOG.debug("Found provider of %s from %s", p_satisfies, p)
                    n_requires = n_requires - p_satisfies
            if n_requires:
                raise exc.MissingDependencies(n, sorted(n_requires))

        def get_providers(node, want_what):
            providers = []
            for (producer, me) in self._graph.in_edges_iter(node):
                providing_what = self._graph.get_edge_data(producer, me)
                if want_what in providing_what:
                    providers.append(producer)
            return providers

        # Link providers to consumers of items.
        for (want_what, who_wants) in requires_what.iteritems():
            who_provided = 0
            for p in provides_what[want_what]:
                # P produces for N so that's why we link P->N and not N->P
                for n in who_wants:
                    if p is n:
                        # No self-referencing allowed.
                        continue
                    if ((len(get_providers(n, want_what))
                         and not self._allow_same_inputs)):
                        msg = "Multiple providers of %s not allowed."
                        raise exc.InvalidStateException(msg % (want_what))
                    self._graph.add_edge(p, n, attr_dict={
                        want_what: True,
                    })
                    who_provided += 1
            if not who_provided:
                who_wants = ", ".join([str(a) for a in who_wants])
                raise exc.InvalidStateException("%s requires input %s "
                                                "but no other task produces "
                                                "said output." % (who_wants,
                                                                  want_what))

        self._connected = True
        # Now figure out the order so that we can give the runners their
        # optional item providers as well as figure out the topological run
        # order.
        run_order = dag.topological_sort(self._graph)
        run_stack = []
        for r in run_order:
            r.runs_before = list(reversed(run_stack))
            run_stack.append(r)
        self._runners = run_order
        return run_order

@@ -23,6 +23,7 @@ import logging

from taskflow.openstack.common import excutils

from taskflow import decorators
from taskflow import exceptions as exc
from taskflow import states
from taskflow import utils
@@ -55,70 +56,61 @@ class Flow(base.Flow):
        self.result_fetcher = None
        # Task results are stored here...
        self.results = []
        # The last task index in the order we left off at before being
        # The last index in the order we left off at before being
        # interrupted (or failing).
        self._left_off_at = 0
        # All tasks to run are collected here.
        self._tasks = []
        # All runners to run are collected here.
        self._runners = []

    @decorators.locked
    def add_many(self, tasks):
        uuids = []
        for t in tasks:
            self.add(t)
            uuids.append(self.add(t))
        return uuids

    @decorators.locked
    def add(self, task):
        """Adds a given task to this flow."""
        assert isinstance(task, collections.Callable)
        self._validate_provides(task)
        self._tasks.append(task)
        r = utils.Runner(task)
        r.runs_before = list(reversed(self._runners))
        self._associate_providers(r)
        self._runners.append(r)
        return r.uuid

    def _validate_provides(self, task):
    def _associate_providers(self, runner):
        # Ensure that some previous task provides this input.
        missing_requires = []
        for r in utils.get_attr(task, 'requires', []):
            found_provider = False
            for prev_task in reversed(self._tasks):
                if r in utils.get_attr(prev_task, 'provides', []):
                    found_provider = True
        who_provides = {}
        task_requires = set(utils.get_attr(runner.task, 'requires', []))
        LOG.debug("Finding providers of %s for %s", task_requires, runner)
        for r in task_requires:
            provider = None
            for before_me in runner.runs_before:
                if r in set(utils.get_attr(before_me.task, 'provides', [])):
                    provider = before_me
                    break
            if not found_provider:
                missing_requires.append(r)
            if provider:
                LOG.debug("Found provider of %s from %s", r, provider)
                who_provides[r] = provider
        # Ensure that the last task provides all the needed input for this
        # task to run correctly.
        if len(missing_requires):
            msg = ("There is no previous task providing the outputs %s"
                   " for %s to correctly execute.") % (missing_requires, task)
            raise exc.InvalidStateException(msg)
        missing_requires = task_requires - set(who_provides.keys())
        if missing_requires:
            raise exc.MissingDependencies(runner, sorted(missing_requires))
        runner.providers.update(who_provides)

    def __str__(self):
        lines = ["LinearFlow: %s" % (self.name)]
        lines.append(" Number of tasks: %s" % (len(self._tasks)))
        lines.append(" Number of tasks: %s" % (len(self._runners)))
        lines.append(" Last index: %s" % (self._left_off_at))
        lines.append(" State: %s" % (self.state))
        return "\n".join(lines)

    def _ordering(self):
        return list(self._tasks)

    def _fetch_task_inputs(self, task):
        """Retrieves any additional kwargs inputs to provide to the task when
        said task is being applied."""
        would_like = set(utils.get_attr(task, 'requires', []))
        would_like.update(utils.get_attr(task, 'optional', []))

        inputs = {}
        for n in would_like:
            # Find the last task that provided this.
            for (last_task, last_results) in reversed(self.results):
                if n not in utils.get_attr(last_task, 'provides', []):
                    continue
                if last_results and n in last_results:
                    inputs[n] = last_results[n]
                else:
                    inputs[n] = None
                # Some task said they had it, get the next requirement.
                break
        return inputs
        return self._runners

    @decorators.locked
    def run(self, context, *args, **kwargs):
        super(Flow, self).run(context, *args, **kwargs)

@@ -129,9 +121,9 @@ class Flow(base.Flow):

        self._change_state(context, states.STARTED)
        try:
            task_order = self._ordering()
            run_order = self._ordering()
            if self._left_off_at > 0:
                task_order = task_order[self._left_off_at:]
                run_order = run_order[self._left_off_at:]
        except Exception:
            with excutils.save_and_reraise_exception():
                try:
@@ -140,27 +132,23 @@ class Flow(base.Flow):
                    LOG.exception("Dropping exception caught when"
                                  " notifying about ordering failure.")

        def run_task(task, failed=False, result=None, simulate_run=False):
        def run_it(runner, failed=False, result=None, simulate_run=False):
            try:
                self._on_task_start(context, task)
                self._on_task_start(context, runner.task)
                # Add the task to be rolled back *immediately* so that even if
                # the task fails while producing results it will be given a
                # chance to rollback.
                rb = utils.RollbackTask(context, task, result=None)
                rb = utils.RollbackTask(context, runner.task, result=None)
                self._accumulator.add(rb)
                if not simulate_run:
                    inputs = self._fetch_task_inputs(task)
                    if not inputs:
                        inputs = {}
                    inputs.update(kwargs)
                    result = task(context, *args, **inputs)
                    result = runner(context, *args, **kwargs)
                else:
                    if failed:
                        if not result:
                            # If no exception or exception message was provided
                            # or captured from the previous run then we need to
                            # form one for this task.
                            result = "%s failed running." % (task)
                            result = "%s failed running." % (runner.task)
                        if isinstance(result, basestring):
                            result = exc.InvalidStateException(result)
                        if not isinstance(result, Exception):
@@ -182,53 +170,59 @@ class Flow(base.Flow):
                # some task could alter this result intentionally or not
                # intentionally).
                rb.result = result
                runner.result = result
                # Alter the index we have run at.
                self._left_off_at += 1
                result_copy = copy.deepcopy(result)
                self.results.append((task, result_copy))
                self._on_task_finish(context, task, result_copy)
                self.results.append((runner.task, copy.deepcopy(result)))
                self._on_task_finish(context, runner.task, result)
            except Exception as e:
                cause = utils.FlowFailure(task, self, e)
                cause = utils.FlowFailure(runner.task, self, e)
                with excutils.save_and_reraise_exception():
                    try:
                        self._on_task_error(context, task, e)
                        self._on_task_error(context, runner.task, e)
                    except Exception:
                        LOG.exception("Dropping exception caught when"
                                      " notifying about task failure.")
                    self.rollback(context, cause)

        last_task = 0
        # Ensure in a ready to run state.
        for runner in run_order:
            runner.reset()

        last_runner = 0
        was_interrupted = False
        if result_fetcher:
            self._change_state(context, states.RESUMING)
            for (i, task) in enumerate(task_order):
            for (i, runner) in enumerate(run_order):
                if self.state == states.INTERRUPTED:
                    was_interrupted = True
                    break
                (has_result, was_error, result) = result_fetcher(self, task)
                (has_result, was_error, result) = result_fetcher(self,
                                                                 runner.task)
                if not has_result:
                    break
                # Fake running the task so that we trigger the same
                # notifications and state changes (and rollback that
                # would have happened in a normal flow).
                last_task = i + 1
                run_task(task, failed=was_error, result=result,
                         simulate_run=True)
                last_runner = i + 1
                run_it(runner, failed=was_error, result=result,
                       simulate_run=True)

        if was_interrupted:
            return

        self._change_state(context, states.RUNNING)
        for task in task_order[last_task:]:
        for runner in run_order[last_runner:]:
            if self.state == states.INTERRUPTED:
                was_interrupted = True
                break
            run_task(task)
            run_it(runner)

        if not was_interrupted:
            # Only gets here if everything went successfully.
            self._change_state(context, states.SUCCESS)

    @decorators.locked
    def reset(self):
        super(Flow, self).reset()
        self.results = []
@@ -236,6 +230,7 @@ class Flow(base.Flow):
        self._accumulator.reset()
        self._left_off_at = 0

    @decorators.locked
    def rollback(self, context, cause):
        # Performs basic task by task rollback by going through the reverse
        # order that tasks have finished and asking said task to undo whatever

@@ -55,39 +55,6 @@ class GraphFlowTest(unittest2.TestCase):
        self.assertEquals(states.FAILURE, flo.state)
        self.assertEquals(['run1'], reverted)

    def test_multi_provider_disallowed(self):
        flo = gw.Flow("test-flow", allow_same_inputs=False)
        flo.add(utils.ProvidesRequiresTask('test6',
                                           provides=['y'],
                                           requires=[]))
        flo.add(utils.ProvidesRequiresTask('test7',
                                           provides=['y'],
                                           requires=[]))
        flo.add(utils.ProvidesRequiresTask('test8',
                                           provides=[],
                                           requires=['y']))
        self.assertEquals(states.PENDING, flo.state)
        self.assertRaises(excp.InvalidStateException, flo.run, {})
        self.assertEquals(states.FAILURE, flo.state)

    def test_multi_provider_allowed(self):
        flo = gw.Flow("test-flow", allow_same_inputs=True)
        flo.add(utils.ProvidesRequiresTask('test6',
                                           provides=['y', 'z'],
                                           requires=[]))
        flo.add(utils.ProvidesRequiresTask('test7',
                                           provides=['y'],
                                           requires=['z']))
        flo.add(utils.ProvidesRequiresTask('test8',
                                           provides=[],
                                           requires=['y', 'z']))
        ctx = {}
        flo.run(ctx)
        self.assertEquals(['test6', 'test7', 'test8'], ctx[utils.ORDER_KEY])
        (_task, results) = flo.results[2]
        self.assertEquals([True, True], results[utils.KWARGS_KEY]['y'])
        self.assertEquals(True, results[utils.KWARGS_KEY]['z'])

    def test_no_requires_provider(self):
        flo = gw.Flow("test-flow")
        flo.add(utils.ProvidesRequiresTask('test1',

@@ -23,6 +23,8 @@ import sys
import threading
import time

from taskflow.openstack.common import uuidutils

from taskflow import decorators

LOG = logging.getLogger(__name__)
@@ -97,6 +99,51 @@ class RollbackTask(object):
        self.task.revert(self.context, self.result, cause)


class Runner(object):
    """A helper class that wraps a task and can find the needed inputs for
    the task to run, as well as providing a uuid and other useful functionality
    for users of the task.

    TODO(harlowja): replace with the task details object or a subclass of
    that???
    """

    def __init__(self, task):
        assert isinstance(task, collections.Callable)
        self.task = task
        self.providers = {}
        self.uuid = uuidutils.generate_uuid()
        self.runs_before = []
        self.result = None

    def reset(self):
        self.result = None

    def __str__(self):
        return "%s@%s" % (self.task, self.uuid)

    def __call__(self, *args, **kwargs):
        # Find all of our inputs first.
        kwargs = dict(kwargs)
        for (k, who_made) in self.providers.iteritems():
            if who_made.result and k in who_made.result:
                kwargs[k] = who_made.result[k]
            else:
                kwargs[k] = None
        optional_keys = set(get_attr(self.task, 'optional', []))
        optional_missing_keys = optional_keys - set(kwargs.keys())
        if optional_missing_keys:
            for k in optional_missing_keys:
                for r in self.runs_before:
                    r_provides = set(get_attr(r.task, 'provides', []))
                    if k in r_provides and r.result and k in r.result:
                        kwargs[k] = r.result[k]
                        break
        # And now finally run.
        self.result = self.task(*args, **kwargs)
        return self.result


class RollbackAccumulator(object):
    """A utility class that can help in organizing 'undo' like code
    so that said code can be rolled back on failure (automatically or manually)
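As a rough usage sketch of the Runner wrapper defined above (illustrative only; flows normally create runners internally from add(), and the load_data function here is a hypothetical stand-in for a task):

    from taskflow import utils

    def load_data(context, **kwargs):
        return {'data': [1, 2, 3]}

    r = utils.Runner(load_data)
    print(r.uuid)        # the identifier that Flow.add() hands back
    result = r({})       # injects provider results as kwargs, then calls the task
    assert result is r.result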