Only do differences on set objects.
This commit is contained in:
@@ -20,6 +20,16 @@ from taskflow import exceptions as exc
|
||||
from taskflow.patterns import ordered_flow
|
||||
|
||||
|
||||
def _convert_to_set(items):
|
||||
if not items:
|
||||
return set()
|
||||
if isinstance(items, set):
|
||||
return items
|
||||
if isinstance(items, dict):
|
||||
return items.keys()
|
||||
return set(iter(items))
|
||||
|
||||
|
||||
class Flow(ordered_flow.Flow):
|
||||
"""A linear chain of tasks that can be applied as one unit or
|
||||
rolled back as one unit. Each task in the chain may have requirements
|
||||
@@ -39,14 +49,15 @@ class Flow(ordered_flow.Flow):
|
||||
return inputs
|
||||
|
||||
def _validate_provides(self, task):
|
||||
requires = _convert_to_set(task.requires())
|
||||
last_provides = set()
|
||||
last_provider = None
|
||||
if self._tasks:
|
||||
last_provider = self._tasks[-1]
|
||||
last_provides = last_provider.provides()
|
||||
last_provides = _convert_to_set(last_provider.provides())
|
||||
# Ensure that the last task provides all the needed input for this
|
||||
# task to run correctly.
|
||||
req_diff = task.requires().difference(last_provides)
|
||||
req_diff = requires.difference(last_provides)
|
||||
if req_diff:
|
||||
if last_provider is None:
|
||||
msg = ("There is no previous task providing the outputs %s"
|
||||
|
||||
Reference in New Issue
Block a user