diff --git a/taskflow/decorators.py b/taskflow/decorators.py index a41f1717..b0b225ee 100644 --- a/taskflow/decorators.py +++ b/taskflow/decorators.py @@ -68,6 +68,10 @@ def task(*args, **kwargs): requires_what = kwargs.pop('requires', []) f = requires(*requires_what, **kwargs)(f) + # Attach any optional requirements this function needs for running. + optional_what = kwargs.pop('optional', []) + f = optional(*optional_what, **kwargs)(f) + # Attach any items this function provides as output provides_what = kwargs.pop('provides', []) f = provides(*provides_what, **kwargs)(f) @@ -111,6 +115,33 @@ def versionize(major, minor=None): return decorator +def optional(*args, **kwargs): + """Attaches a set of items that the decorated function would like as input + to the functions underlying dictionary.""" + + def decorator(f): + if not hasattr(f, 'optional'): + f.optional = set() + + f.optional.update([a for a in args if _take_arg(a)]) + + @wraps(f) + def wrapper(*args, **kwargs): + return f(*args, **kwargs) + + return wrapper + + # This is needed to handle when the decorator has args or the decorator + # doesn't have args, python is rather weird here... + if kwargs or not args: + return decorator + else: + if isinstance(args[0], collections.Callable): + return decorator(args[0]) + else: + return decorator + + def requires(*args, **kwargs): """Attaches a set of items that the decorated function requires as input to the functions underlying dictionary.""" diff --git a/taskflow/patterns/graph_flow.py b/taskflow/patterns/graph_flow.py index 1db5a564..5e1ddab9 100644 --- a/taskflow/patterns/graph_flow.py +++ b/taskflow/patterns/graph_flow.py @@ -50,17 +50,31 @@ class Flow(ordered_flow.Flow): self._connected = False def _fetch_task_inputs(self, task): - inputs = collections.defaultdict(list) - for n in getattr(task, 'requires', []): - for (them, there_result) in self.results: - if (not self._graph.has_edge(them, task) or - not n in getattr(them, 'provides', [])): - continue - if there_result and n in there_result: - inputs[n].append(there_result[n]) - else: - inputs[n].append(None) + def extract_inputs(place_where, would_like, is_optional=False): + for n in would_like: + for (them, there_result) in self.results: + if not n in set(getattr(them, 'provides', [])): + continue + if (not is_optional and + not self._graph.has_edge(them, task)): + continue + if there_result and n in there_result: + place_where[n].append(there_result[n]) + if is_optional: + # Take the first task that provides this optional + # item. + break + elif not is_optional: + place_where[n].append(None) + + required_inputs = set(getattr(task, 'requires', [])) + optional_inputs = set(getattr(task, 'optional', [])) + optional_inputs = optional_inputs - required_inputs + + task_inputs = collections.defaultdict(list) + extract_inputs(task_inputs, required_inputs) + extract_inputs(task_inputs, optional_inputs, is_optional=True) def collapse_functor(k_v): (k, v) = k_v @@ -68,7 +82,7 @@ class Flow(ordered_flow.Flow): v = v[0] return (k, v) - return dict(map(collapse_functor, inputs.iteritems())) + return dict(map(collapse_functor, task_inputs.iteritems())) def order(self): self.connect() diff --git a/taskflow/patterns/linear_flow.py b/taskflow/patterns/linear_flow.py index 9a601b76..2f56560e 100644 --- a/taskflow/patterns/linear_flow.py +++ b/taskflow/patterns/linear_flow.py @@ -30,16 +30,19 @@ class Flow(ordered_flow.Flow): self._tasks = [] def _fetch_task_inputs(self, task): + would_like = set(getattr(task, 'requires', [])) + would_like.update(getattr(task, 'optional', [])) + inputs = {} - for r in getattr(task, 'requires', []): + for n in would_like: # Find the last task that provided this. for (last_task, last_results) in reversed(self.results): - if r not in getattr(last_task, 'provides', []): + if n not in getattr(last_task, 'provides', []): continue - if last_results and r in last_results: - inputs[r] = last_results[r] + if last_results and n in last_results: + inputs[n] = last_results[n] else: - inputs[r] = None + inputs[n] = None # Some task said they had it, get the next requirement. break return inputs diff --git a/taskflow/task.py b/taskflow/task.py index 2ad11735..ecf2d69d 100644 --- a/taskflow/task.py +++ b/taskflow/task.py @@ -32,6 +32,10 @@ class Task(object): # An *immutable* input 'resource' name set this task depends # on existing before this task can be applied. self.requires = set() + # An *immutable* input 'resource' name set this task would like to + # depends on existing before this task can be applied (but does not + # strongly depend on existing). + self.optional = set() # An *immutable* output 'resource' name set this task # produces that other tasks may depend on this task providing. self.provides = set()