Integrate better locking and a runner helper class.

Ensure that when a linear flow or derivatives is running
that it can not be modified by another thread at the same
time it is running by putting a lock around sensitive functions.

Also instead of using the raw task objects themselves
integrate a helper 'runner' class that provides useful
functionality that occurs before its task runs as well
as member variables that are associated with the contained
task.

This helper class currently provides the following:

- A uuid that can be returned to callers of the add
  method to identify their task (and later its results),
  allowing for multiple of the same tasks to be added.
- Automatic extraction of the needed required and optional
  inputs for the contained task.

Change-Id: Ib01939a4726155a629e4b4703656b9067868d8f3
This commit is contained in:
Joshua Harlow
2013-06-28 17:20:57 -07:00
parent 2d5f90fa24
commit e8e60e884f
7 changed files with 215 additions and 199 deletions

View File

@@ -96,6 +96,16 @@ def wraps(fn):
return wrapper
def locked(f):
    """Decorator that serializes calls to a method using the owner's lock.

    The decorated callable must be a method of an object exposing a
    ``_lock`` attribute (a lock or reentrant lock); the method body is
    executed only while that lock is held, and the lock is always released
    even if the method raises.
    """
    @wraps(f)
    def wrapper(self, *args, **kwargs):
        self._lock.acquire()
        try:
            return f(self, *args, **kwargs)
        finally:
            self._lock.release()
    return wrapper
def task(*args, **kwargs):
"""Decorates a given function and ensures that all needed attributes of
that function are set so that the function can be used as a task."""

View File

@@ -53,3 +53,13 @@ class UnclaimableJobException(TaskFlowException):
class JobNotFound(TaskFlowException):
"""Raised when a job entry can not be found."""
pass
class MissingDependencies(InvalidStateException):
    """Raised when a task has dependencies that can not be satisfied."""

    message = ("%(task)s requires %(requirements)s but no other task produces"
               " said requirements")

    def __init__(self, task, requirements):
        # Interpolate the offending task and its unmet requirements into
        # the class-level message template.
        super(MissingDependencies, self).__init__(
            self.message % {'task': task, 'requirements': requirements})

View File

@@ -17,14 +17,15 @@
# under the License.
import abc
import threading
from taskflow import decorators
from taskflow import exceptions as exc
from taskflow import states
class Flow(object):
"""The base abstract class of all flow implementations."""
__metaclass__ = abc.ABCMeta
RESETTABLE_STATES = set([
@@ -59,6 +60,9 @@ class Flow(object):
# can be implemented when you can track a flows progress.
self.task_listeners = []
self.listeners = []
# Ensure that modifications and/or multiple runs aren't happening
# at the same time in the same flow.
self._lock = threading.RLock()
@property
def state(self):
@@ -66,13 +70,21 @@ class Flow(object):
return self._state
def _change_state(self, context, new_state):
if self.state != new_state:
old_state = self.state
self._state = new_state
was_changed = False
old_state = self.state
with self._lock:
if self.state != new_state:
old_state = self.state
self._state = new_state
was_changed = True
if was_changed:
# Don't notify while holding the lock.
self._on_flow_state_change(context, old_state)
def __str__(self):
return "Flow: %s" % (self.name)
lines = ["Flow: %s" % (self.name)]
lines.append(" State: %s" % (self.state))
return "\n".join(lines)
def _on_flow_state_change(self, context, old_state):
# Notify any listeners that the internal state has changed.
@@ -96,12 +108,18 @@ class Flow(object):
@abc.abstractmethod
def add(self, task):
"""Adds a given task to this flow."""
"""Adds a given task to this flow.
Returns the uuid that is associated with the task for later operations
before and after it is ran."""
raise NotImplementedError()
@abc.abstractmethod
def add_many(self, tasks):
"""Adds many tasks to this flow."""
"""Adds many tasks to this flow.
Returns a list of uuids (one for each task added).
"""
raise NotImplementedError()
def interrupt(self):
@@ -113,8 +131,15 @@ class Flow(object):
if self.state in self.UNINTERRUPTIBLE_STATES:
raise exc.InvalidStateException(("Can not interrupt when"
" in state %s") % (self.state))
self._change_state(None, states.INTERRUPTED)
# Note(harlowja): Do *not* acquire the lock here so that the flow may
# be interrupted while running. This does mean that the above check may
# not be valid but we can worry about that if it becomes an issue.
old_state = self.state
if old_state != states.INTERRUPTED:
self._state = states.INTERRUPTED
self._on_flow_state_change(None, old_state)
@decorators.locked
def reset(self):
"""Fully resets the internal state of this flow, allowing for the flow
to be ran again. *Listeners are also reset*"""
@@ -125,6 +150,7 @@ class Flow(object):
self.listeners = []
self._change_state(None, states.PENDING)
@decorators.locked
def soft_reset(self):
"""Partially resets the internal state of this flow, allowing for the
flow to be ran again from an interrupted state *only*"""
@@ -133,14 +159,15 @@ class Flow(object):
" in state %s") % (self.state))
self._change_state(None, states.PENDING)
@decorators.locked
def run(self, context, *args, **kwargs):
"""Executes the workflow."""
if self.state not in self.RUNNABLE_STATES:
raise exc.InvalidStateException("Unable to run flow when "
"in state %s" % (self.state))
@abc.abstractmethod
@decorators.locked
def rollback(self, context, cause):
"""Performs rollback of this workflow and any attached parent workflows
if present."""
raise NotImplementedError()
pass

View File

@@ -23,6 +23,7 @@ from networkx.algorithms import dag
from networkx.classes import digraph
from networkx import exception as g_exc
from taskflow import decorators
from taskflow import exceptions as exc
from taskflow.patterns import linear_flow
from taskflow import utils
@@ -35,25 +36,25 @@ class Flow(linear_flow.Flow):
a linear topological ordering (and reverse using the same linear
topological order)"""
def __init__(self, name, parents=None, allow_same_inputs=True):
def __init__(self, name, parents=None):
super(Flow, self).__init__(name, parents)
self._graph = digraph.DiGraph()
self._connected = False
self._allow_same_inputs = allow_same_inputs
@decorators.locked
def add(self, task):
# Only insert the node to start, connect all the edges
# together later after all nodes have been added since if we try
# to infer the edges at this stage we likely will fail finding
# dependencies from nodes that don't exist.
assert isinstance(task, collections.Callable)
if not self._graph.has_node(task):
self._graph.add_node(task)
self._connected = False
r = utils.Runner(task)
self._graph.add_node(r, uuid=r.uuid)
self._runners = []
return r.uuid
def add_many(self, tasks):
for t in tasks:
self.add(t)
def _add_dependency(self, provider, requirer):
if not self._graph.has_edge(provider, requirer):
self._graph.add_edge(provider, requirer)
def __str__(self):
lines = ["GraphFlow: %s" % (self.name)]
@@ -63,46 +64,9 @@ class Flow(linear_flow.Flow):
lines.append(" State: %s" % (self.state))
return "\n".join(lines)
def _fetch_task_inputs(self, task):
def extract_inputs(place_where, would_like, is_optional=False):
for n in would_like:
for (them, there_result) in self.results:
they_provide = utils.get_attr(them, 'provides', [])
if n not in set(they_provide):
continue
if ((not is_optional and
not self._graph.has_edge(them, task))):
continue
if there_result and n in there_result:
place_where[n].append(there_result[n])
if is_optional:
# Take the first task that provides this optional
# item.
break
elif not is_optional:
place_where[n].append(None)
required_inputs = set(utils.get_attr(task, 'requires', []))
optional_inputs = set(utils.get_attr(task, 'optional', []))
optional_inputs = optional_inputs - required_inputs
task_inputs = collections.defaultdict(list)
extract_inputs(task_inputs, required_inputs)
extract_inputs(task_inputs, optional_inputs, is_optional=True)
def collapse_functor(k_v):
(k, v) = k_v
if len(v) == 1:
v = v[0]
return (k, v)
return dict(map(collapse_functor, task_inputs.iteritems()))
def _ordering(self):
self._connect()
try:
return dag.topological_sort(self._graph)
return self._connect()
except g_exc.NetworkXUnfeasible:
raise exc.InvalidStateException("Unable to correctly determine "
"the path through the provided "
@@ -110,49 +74,45 @@ class Flow(linear_flow.Flow):
"tasks needed inputs and outputs.")
def _connect(self):
"""Connects the nodes & edges of the graph together."""
if self._connected or len(self._graph) == 0:
return
"""Connects the nodes & edges of the graph together by examining who
the requirements of each node and finding another node that will
create said dependency."""
if len(self._graph) == 0:
return []
if self._runners:
return self._runners
# Figure out the provider of items and the requirers of items.
provides_what = collections.defaultdict(list)
requires_what = collections.defaultdict(list)
for t in self._graph.nodes_iter():
for r in utils.get_attr(t, 'requires', []):
requires_what[r].append(t)
for p in utils.get_attr(t, 'provides', []):
provides_what[p].append(t)
# Link providers to requirers.
#
# TODO(harlowja): allow for developers to manually establish these
# connections instead of automatically doing it for them??
for n in self._graph.nodes_iter():
n_requires = set(utils.get_attr(n.task, 'requires', []))
LOG.debug("Finding providers of %s for %s", n_requires, n)
for p in self._graph.nodes_iter():
if not n_requires:
break
if n is p:
continue
p_provides = set(utils.get_attr(p.task, 'provides', []))
p_satisfies = n_requires & p_provides
if p_satisfies:
# P produces for N so thats why we link P->N and not N->P
self._add_dependency(p, n)
for k in p_satisfies:
n.providers[k] = p
LOG.debug("Found provider of %s from %s", p_satisfies, p)
n_requires = n_requires - p_satisfies
if n_requires:
raise exc.MissingDependencies(n, sorted(n_requires))
def get_providers(node, want_what):
providers = []
for (producer, me) in self._graph.in_edges_iter(node):
providing_what = self._graph.get_edge_data(producer, me)
if want_what in providing_what:
providers.append(producer)
return providers
# Link providers to consumers of items.
for (want_what, who_wants) in requires_what.iteritems():
who_provided = 0
for p in provides_what[want_what]:
# P produces for N so thats why we link P->N and not N->P
for n in who_wants:
if p is n:
# No self-referencing allowed.
continue
if ((len(get_providers(n, want_what))
and not self._allow_same_inputs)):
msg = "Multiple providers of %s not allowed."
raise exc.InvalidStateException(msg % (want_what))
self._graph.add_edge(p, n, attr_dict={
want_what: True,
})
who_provided += 1
if not who_provided:
who_wants = ", ".join([str(a) for a in who_wants])
raise exc.InvalidStateException("%s requires input %s "
"but no other task produces "
"said output." % (who_wants,
want_what))
self._connected = True
# Now figure out the order so that we can give the runners there
# optional item providers as well as figure out the topological run
# order.
run_order = dag.topological_sort(self._graph)
run_stack = []
for r in run_order:
r.runs_before = list(reversed(run_stack))
run_stack.append(r)
self._runners = run_order
return run_order

View File

@@ -23,6 +23,7 @@ import logging
from taskflow.openstack.common import excutils
from taskflow import decorators
from taskflow import exceptions as exc
from taskflow import states
from taskflow import utils
@@ -55,70 +56,61 @@ class Flow(base.Flow):
self.result_fetcher = None
# Tasks results are stored here...
self.results = []
# The last task index in the order we left off at before being
# The last index in the order we left off at before being
# interrupted (or failing).
self._left_off_at = 0
# All tasks to run are collected here.
self._tasks = []
# All runners to run are collected here.
self._runners = []
@decorators.locked
def add_many(self, tasks):
uuids = []
for t in tasks:
self.add(t)
uuids.append(self.add(t))
return uuids
@decorators.locked
def add(self, task):
"""Adds a given task to this flow."""
assert isinstance(task, collections.Callable)
self._validate_provides(task)
self._tasks.append(task)
r = utils.Runner(task)
r.runs_before = list(reversed(self._runners))
self._associate_providers(r)
self._runners.append(r)
return r.uuid
def _validate_provides(self, task):
def _associate_providers(self, runner):
# Ensure that some previous task provides this input.
missing_requires = []
for r in utils.get_attr(task, 'requires', []):
found_provider = False
for prev_task in reversed(self._tasks):
if r in utils.get_attr(prev_task, 'provides', []):
found_provider = True
who_provides = {}
task_requires = set(utils.get_attr(runner.task, 'requires', []))
LOG.debug("Finding providers of %s for %s", task_requires, runner)
for r in task_requires:
provider = None
for before_me in runner.runs_before:
if r in set(utils.get_attr(before_me.task, 'provides', [])):
provider = before_me
break
if not found_provider:
missing_requires.append(r)
if provider:
LOG.debug("Found provider of %s from %s", r, provider)
who_provides[r] = provider
# Ensure that the last task provides all the needed input for this
# task to run correctly.
if len(missing_requires):
msg = ("There is no previous task providing the outputs %s"
" for %s to correctly execute.") % (missing_requires, task)
raise exc.InvalidStateException(msg)
missing_requires = task_requires - set(who_provides.keys())
if missing_requires:
raise exc.MissingDependencies(runner, sorted(missing_requires))
runner.providers.update(who_provides)
def __str__(self):
lines = ["LinearFlow: %s" % (self.name)]
lines.append(" Number of tasks: %s" % (len(self._tasks)))
lines.append(" Number of tasks: %s" % (len(self._runners)))
lines.append(" Last index: %s" % (self._left_off_at))
lines.append(" State: %s" % (self.state))
return "\n".join(lines)
def _ordering(self):
return list(self._tasks)
def _fetch_task_inputs(self, task):
"""Retrieves and additional kwargs inputs to provide to the task when
said task is being applied."""
would_like = set(utils.get_attr(task, 'requires', []))
would_like.update(utils.get_attr(task, 'optional', []))
inputs = {}
for n in would_like:
# Find the last task that provided this.
for (last_task, last_results) in reversed(self.results):
if n not in utils.get_attr(last_task, 'provides', []):
continue
if last_results and n in last_results:
inputs[n] = last_results[n]
else:
inputs[n] = None
# Some task said they had it, get the next requirement.
break
return inputs
return self._runners
@decorators.locked
def run(self, context, *args, **kwargs):
super(Flow, self).run(context, *args, **kwargs)
@@ -129,9 +121,9 @@ class Flow(base.Flow):
self._change_state(context, states.STARTED)
try:
task_order = self._ordering()
run_order = self._ordering()
if self._left_off_at > 0:
task_order = task_order[self._left_off_at:]
run_order = run_order[self._left_off_at:]
except Exception:
with excutils.save_and_reraise_exception():
try:
@@ -140,27 +132,23 @@ class Flow(base.Flow):
LOG.exception("Dropping exception catched when"
" notifying about ordering failure.")
def run_task(task, failed=False, result=None, simulate_run=False):
def run_it(runner, failed=False, result=None, simulate_run=False):
try:
self._on_task_start(context, task)
self._on_task_start(context, runner.task)
# Add the task to be rolled back *immediately* so that even if
# the task fails while producing results it will be given a
# chance to rollback.
rb = utils.RollbackTask(context, task, result=None)
rb = utils.RollbackTask(context, runner.task, result=None)
self._accumulator.add(rb)
if not simulate_run:
inputs = self._fetch_task_inputs(task)
if not inputs:
inputs = {}
inputs.update(kwargs)
result = task(context, *args, **inputs)
result = runner(context, *args, **kwargs)
else:
if failed:
if not result:
# If no exception or exception message was provided
# or captured from the previous run then we need to
# form one for this task.
result = "%s failed running." % (task)
result = "%s failed running." % (runner.task)
if isinstance(result, basestring):
result = exc.InvalidStateException(result)
if not isinstance(result, Exception):
@@ -182,53 +170,59 @@ class Flow(base.Flow):
# some task could alter this result intentionally or not
# intentionally).
rb.result = result
runner.result = result
# Alter the index we have ran at.
self._left_off_at += 1
result_copy = copy.deepcopy(result)
self.results.append((task, result_copy))
self._on_task_finish(context, task, result_copy)
self.results.append((runner.task, copy.deepcopy(result)))
self._on_task_finish(context, runner.task, result)
except Exception as e:
cause = utils.FlowFailure(task, self, e)
cause = utils.FlowFailure(runner.task, self, e)
with excutils.save_and_reraise_exception():
try:
self._on_task_error(context, task, e)
self._on_task_error(context, runner.task, e)
except Exception:
LOG.exception("Dropping exception catched when"
" notifying about task failure.")
self.rollback(context, cause)
last_task = 0
# Ensure in a ready to run state.
for runner in run_order:
runner.reset()
last_runner = 0
was_interrupted = False
if result_fetcher:
self._change_state(context, states.RESUMING)
for (i, task) in enumerate(task_order):
for (i, runner) in enumerate(run_order):
if self.state == states.INTERRUPTED:
was_interrupted = True
break
(has_result, was_error, result) = result_fetcher(self, task)
(has_result, was_error, result) = result_fetcher(self,
runner.task)
if not has_result:
break
# Fake running the task so that we trigger the same
# notifications and state changes (and rollback that
# would have happened in a normal flow).
last_task = i + 1
run_task(task, failed=was_error, result=result,
simulate_run=True)
last_runner = i + 1
run_it(runner, failed=was_error, result=result,
simulate_run=True)
if was_interrupted:
return
self._change_state(context, states.RUNNING)
for task in task_order[last_task:]:
for runner in run_order[last_runner:]:
if self.state == states.INTERRUPTED:
was_interrupted = True
break
run_task(task)
run_it(runner)
if not was_interrupted:
# Only gets here if everything went successfully.
self._change_state(context, states.SUCCESS)
@decorators.locked
def reset(self):
super(Flow, self).reset()
self.results = []
@@ -236,6 +230,7 @@ class Flow(base.Flow):
self._accumulator.reset()
self._left_off_at = 0
@decorators.locked
def rollback(self, context, cause):
# Performs basic task by task rollback by going through the reverse
# order that tasks have finished and asking said task to undo whatever

View File

@@ -55,39 +55,6 @@ class GraphFlowTest(unittest2.TestCase):
self.assertEquals(states.FAILURE, flo.state)
self.assertEquals(['run1'], reverted)
def test_multi_provider_disallowed(self):
flo = gw.Flow("test-flow", allow_same_inputs=False)
flo.add(utils.ProvidesRequiresTask('test6',
provides=['y'],
requires=[]))
flo.add(utils.ProvidesRequiresTask('test7',
provides=['y'],
requires=[]))
flo.add(utils.ProvidesRequiresTask('test8',
provides=[],
requires=['y']))
self.assertEquals(states.PENDING, flo.state)
self.assertRaises(excp.InvalidStateException, flo.run, {})
self.assertEquals(states.FAILURE, flo.state)
def test_multi_provider_allowed(self):
flo = gw.Flow("test-flow", allow_same_inputs=True)
flo.add(utils.ProvidesRequiresTask('test6',
provides=['y', 'z'],
requires=[]))
flo.add(utils.ProvidesRequiresTask('test7',
provides=['y'],
requires=['z']))
flo.add(utils.ProvidesRequiresTask('test8',
provides=[],
requires=['y', 'z']))
ctx = {}
flo.run(ctx)
self.assertEquals(['test6', 'test7', 'test8'], ctx[utils.ORDER_KEY])
(_task, results) = flo.results[2]
self.assertEquals([True, True], results[utils.KWARGS_KEY]['y'])
self.assertEquals(True, results[utils.KWARGS_KEY]['z'])
def test_no_requires_provider(self):
flo = gw.Flow("test-flow")
flo.add(utils.ProvidesRequiresTask('test1',

View File

@@ -23,6 +23,8 @@ import sys
import threading
import time
from taskflow.openstack.common import uuidutils
from taskflow import decorators
LOG = logging.getLogger(__name__)
@@ -97,6 +99,51 @@ class RollbackTask(object):
self.task.revert(self.context, self.result, cause)
class Runner(object):
    """A helper that wraps a task with run-related bookkeeping.

    Each runner carries a generated uuid (so the same task may be added to a
    flow more than once and later identified), a mapping of required input
    names to the runners that produce them, the list of runners scheduled
    before it, and the result of its most recent run.

    TODO(harlowja): replace with the task details object or a subclass of
    that???
    """

    def __init__(self, task):
        assert isinstance(task, collections.Callable)
        self.task = task
        # Required input name -> runner that produces that output.
        self.providers = {}
        self.uuid = uuidutils.generate_uuid()
        # Runners that are scheduled to execute before this one.
        self.runs_before = []
        self.result = None

    def reset(self):
        # Forget the previous run's result so this runner can be ran again.
        self.result = None

    def __str__(self):
        return "%s@%s" % (self.task, self.uuid)

    def __call__(self, *args, **kwargs):
        # Gather required inputs from the runners known to provide them,
        # falling back to None when a provider has no value for that key.
        inputs = dict(kwargs)
        for (needed, provider) in self.providers.iteritems():
            produced = provider.result
            if produced and needed in produced:
                inputs[needed] = produced[needed]
            else:
                inputs[needed] = None
        # Optional inputs are taken from the first earlier runner that
        # actually produced a value for them (if any did); unlike required
        # inputs they are simply left out when nothing provides them.
        maybe_wanted = set(get_attr(self.task, 'optional', []))
        for needed in maybe_wanted - set(inputs.keys()):
            for prior in self.runs_before:
                if needed not in set(get_attr(prior.task, 'provides', [])):
                    continue
                if prior.result and needed in prior.result:
                    inputs[needed] = prior.result[needed]
                    break
        # And now finally run, remembering the result for later consumers.
        self.result = self.task(*args, **inputs)
        return self.result
class RollbackAccumulator(object):
"""A utility class that can help in organizing 'undo' like code
so that said code be rolled back on failure (automatically or manually)