diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 2f5f6beb8..4f57b6315 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,14 +12,15 @@ repos: - id: debug-statements - id: check-yaml files: .*\.(yaml|yml)$ + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.15.5 + hooks: + - id: ruff-check + args: ['--fix', '--unsafe-fixes'] + - id: ruff-format - repo: https://opendev.org/openstack/hacking rev: 8.0.0 hooks: - id: hacking additional_dependencies: [] exclude: '^(doc|releasenotes|tools)/.*$' - - repo: https://github.com/asottile/pyupgrade - rev: v3.21.2 - hooks: - - id: pyupgrade - args: [--py310-plus] diff --git a/doc/source/conf.py b/doc/source/conf.py index 93de060a7..bf4eda2e3 100644 --- a/doc/source/conf.py +++ b/doc/source/conf.py @@ -25,7 +25,7 @@ extensions = [ 'sphinx.ext.extlinks', 'sphinx.ext.inheritance_diagram', 'sphinx.ext.viewcode', - 'openstackdocstheme' + 'openstackdocstheme', ] # openstackdocstheme options diff --git a/pyproject.toml b/pyproject.toml index bb78e801d..694fde340 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -95,7 +95,21 @@ parallel = "taskflow.engines.action_engine.engine:ParallelActionEngine" worker-based = "taskflow.engines.worker_based.engine:WorkerBasedActionEngine" workers = "taskflow.engines.worker_based.engine:WorkerBasedActionEngine" -[tool.setuptools] -packages = [ - "taskflow" -] +[tool.setuptools.packages.find] +include = ["taskflow"] + +[tool.ruff] +line-length = 79 + +[tool.ruff.format] +quote-style = "preserve" +docstring-code-format = true + +[tool.ruff.lint] +select = ["E4", "E5", "E7", "E9", "F", "G", "LOG", "S"] +external = ["H"] +ignore = ["E402", "E721", "E731", "E741"] + +[tool.ruff.lint.per-file-ignores] +"taskflow/examples/*" = ["S"] +"taskflow/tests/*" = ["S"] diff --git a/releasenotes/source/conf.py b/releasenotes/source/conf.py index ea724b598..5f9ab232f 100644 --- a/releasenotes/source/conf.py +++ b/releasenotes/source/conf.py @@ -88,9 +88,13 @@ 
html_static_path = ['_static'] # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ - ('index', 'taskflowReleaseNotes.tex', - 'taskflow Release Notes Documentation', - 'taskflow Developers', 'manual'), + ( + 'index', + 'taskflowReleaseNotes.tex', + 'taskflow Release Notes Documentation', + 'taskflow Developers', + 'manual', + ), ] diff --git a/setup.py b/setup.py index cd35c3c35..481505b03 100644 --- a/setup.py +++ b/setup.py @@ -15,6 +15,4 @@ import setuptools -setuptools.setup( - setup_requires=['pbr>=2.0.0'], - pbr=True) +setuptools.setup(setup_requires=['pbr>=2.0.0'], pbr=True) diff --git a/taskflow/atom.py b/taskflow/atom.py index a6eba50e5..122ae687d 100644 --- a/taskflow/atom.py +++ b/taskflow/atom.py @@ -52,8 +52,9 @@ def _save_as_to_mapping(save_as): # NOTE(harlowja): this means that your atom will return a indexable # object, like a list or tuple and the results can be mapped by index # to that tuple/list that is returned for others to use. - return collections.OrderedDict((key, num) - for num, key in enumerate(save_as)) + return collections.OrderedDict( + (key, num) for num, key in enumerate(save_as) + ) elif isinstance(save_as, _set_types): # NOTE(harlowja): in the case where a set is given we will not be # able to determine the numeric ordering in a reliable way (since it @@ -61,8 +62,10 @@ def _save_as_to_mapping(save_as): # result of the atom will be via the key itself. 
return collections.OrderedDict((key, key) for key in save_as) else: - raise TypeError('Atom provides parameter ' - 'should be str, set or tuple/list, not %r' % save_as) + raise TypeError( + 'Atom provides parameter ' + 'should be str, set or tuple/list, not %r' % save_as + ) def _build_rebind_dict(req_args, rebind_args): @@ -84,17 +87,19 @@ def _build_rebind_dict(req_args, rebind_args): # Extra things were rebound, that may be because of *args # or **kwargs (or some other reason); so just keep all of them # using 1:1 rebinding... - rebind.update((a, a) for a in rebind_args[len(req_args):]) + rebind.update((a, a) for a in rebind_args[len(req_args) :]) return rebind elif isinstance(rebind_args, dict): return rebind_args else: - raise TypeError("Invalid rebind value '%s' (%s)" - % (rebind_args, type(rebind_args))) + raise TypeError( + "Invalid rebind value '%s' (%s)" % (rebind_args, type(rebind_args)) + ) -def _build_arg_mapping(atom_name, reqs, rebind_args, function, do_infer, - ignore_list=None): +def _build_arg_mapping( + atom_name, reqs, rebind_args, function, do_infer, ignore_list=None +): """Builds an input argument mapping for a given function. Given a function, its requirements and a rebind mapping this helper @@ -135,8 +140,9 @@ def _build_arg_mapping(atom_name, reqs, rebind_args, function, do_infer, # Determine if there are optional arguments that we may or may not take. 
if do_infer: opt_args = sets.OrderedSet(all_args) - opt_args = opt_args - set(itertools.chain(required.keys(), - iter(ignore_list))) + opt_args = opt_args - set( + itertools.chain(required.keys(), iter(ignore_list)) + ) optional = collections.OrderedDict((a, a) for a in opt_args) else: optional = collections.OrderedDict() @@ -146,14 +152,17 @@ def _build_arg_mapping(atom_name, reqs, rebind_args, function, do_infer, extra_args = sets.OrderedSet(required.keys()) extra_args -= all_args if extra_args: - raise ValueError('Extra arguments given to atom %s: %s' - % (atom_name, list(extra_args))) + raise ValueError( + 'Extra arguments given to atom %s: %s' + % (atom_name, list(extra_args)) + ) # NOTE(imelnikov): don't use set to preserve order in error message missing_args = [arg for arg in req_args if arg not in required] if missing_args: - raise ValueError('Missing arguments for atom %s: %s' - % (atom_name, missing_args)) + raise ValueError( + 'Missing arguments for atom %s: %s' % (atom_name, missing_args) + ) return required, optional @@ -244,9 +253,18 @@ class Atom(metaclass=abc.ABCMeta): default_provides = None - def __init__(self, name=None, provides=None, requires=None, - auto_extract=True, rebind=None, inject=None, - ignore_list=None, revert_rebind=None, revert_requires=None): + def __init__( + self, + name=None, + provides=None, + requires=None, + auto_extract=True, + rebind=None, + inject=None, + ignore_list=None, + revert_rebind=None, + revert_requires=None, + ): if provides is None: provides = self.default_provides @@ -263,8 +281,9 @@ class Atom(metaclass=abc.ABCMeta): self.rebind, exec_requires, self.optional = self._build_arg_mapping( self.execute, requires=requires, - rebind=rebind, auto_extract=auto_extract, - ignore_list=ignore_list + rebind=rebind, + auto_extract=auto_extract, + ignore_list=ignore_list, ) revert_ignore = ignore_list + list(_default_revert_args) @@ -273,10 +292,11 @@ class Atom(metaclass=abc.ABCMeta): requires=revert_requires or requires, 
rebind=revert_rebind or rebind, auto_extract=auto_extract, - ignore_list=revert_ignore + ignore_list=revert_ignore, + ) + (self.revert_rebind, addl_requires, self.revert_optional) = ( + revert_mapping ) - (self.revert_rebind, addl_requires, - self.revert_optional) = revert_mapping # TODO(bnemec): This should be documented as an ivar, but can't be due # to https://github.com/sphinx-doc/sphinx/issues/2549 @@ -284,18 +304,30 @@ class Atom(metaclass=abc.ABCMeta): #: requires to function. self.requires = exec_requires.union(addl_requires) - def _build_arg_mapping(self, executor, requires=None, rebind=None, - auto_extract=True, ignore_list=None): + def _build_arg_mapping( + self, + executor, + requires=None, + rebind=None, + auto_extract=True, + ignore_list=None, + ): - required, optional = _build_arg_mapping(self.name, requires, rebind, - executor, auto_extract, - ignore_list=ignore_list) + required, optional = _build_arg_mapping( + self.name, + requires, + rebind, + executor, + auto_extract, + ignore_list=ignore_list, + ) # Form the real rebind mapping, if a key name is the same as the # key value, then well there is no rebinding happening, otherwise # there will be. 
rebind = collections.OrderedDict() - for (arg_name, bound_name) in itertools.chain(required.items(), - optional.items()): + for arg_name, bound_name in itertools.chain( + required.items(), optional.items() + ): rebind.setdefault(arg_name, bound_name) requires = sets.OrderedSet(required.values()) optional = sets.OrderedSet(optional.values()) diff --git a/taskflow/conductors/backends/__init__.py b/taskflow/conductors/backends/__init__.py index 4fd3c028e..b389b0a67 100644 --- a/taskflow/conductors/backends/__init__.py +++ b/taskflow/conductors/backends/__init__.py @@ -34,10 +34,12 @@ def fetch(kind, name, jobboard, namespace=CONDUCTOR_NAMESPACE, **kwargs): LOG.debug('Looking for %r conductor driver in %r', kind, namespace) try: mgr = stevedore.driver.DriverManager( - namespace, kind, + namespace, + kind, invoke_on_load=True, invoke_args=(name, jobboard), - invoke_kwds=kwargs) + invoke_kwds=kwargs, + ) return mgr.driver except RuntimeError as e: raise exc.NotFound("Could not find conductor %s" % (kind), e) diff --git a/taskflow/conductors/backends/impl_blocking.py b/taskflow/conductors/backends/impl_blocking.py index c3d9b3f74..774070cc6 100644 --- a/taskflow/conductors/backends/impl_blocking.py +++ b/taskflow/conductors/backends/impl_blocking.py @@ -27,13 +27,24 @@ class BlockingConductor(impl_executor.ExecutorConductor): def _executor_factory(): return futurist.SynchronousExecutor() - def __init__(self, name, jobboard, - persistence=None, engine=None, - engine_options=None, wait_timeout=None, - log=None, max_simultaneous_jobs=MAX_SIMULTANEOUS_JOBS): + def __init__( + self, + name, + jobboard, + persistence=None, + engine=None, + engine_options=None, + wait_timeout=None, + log=None, + max_simultaneous_jobs=MAX_SIMULTANEOUS_JOBS, + ): super().__init__( - name, jobboard, - persistence=persistence, engine=engine, + name, + jobboard, + persistence=persistence, + engine=engine, engine_options=engine_options, - wait_timeout=wait_timeout, log=log, - 
max_simultaneous_jobs=max_simultaneous_jobs) + wait_timeout=wait_timeout, + log=log, + max_simultaneous_jobs=max_simultaneous_jobs, + ) diff --git a/taskflow/conductors/backends/impl_executor.py b/taskflow/conductors/backends/impl_executor.py index af69b9573..65679de19 100644 --- a/taskflow/conductors/backends/impl_executor.py +++ b/taskflow/conductors/backends/impl_executor.py @@ -78,47 +78,72 @@ class ExecutorConductor(base.Conductor, metaclass=abc.ABCMeta): """ #: Exceptions that will **not** cause consumption to occur. - NO_CONSUME_EXCEPTIONS = tuple([ - excp.ExecutionFailure, - excp.StorageFailure, - ]) + NO_CONSUME_EXCEPTIONS = tuple( + [ + excp.ExecutionFailure, + excp.StorageFailure, + ] + ) _event_factory = threading.Event """This attribute *can* be overridden by subclasses (for example if an eventlet *green* event works better for the conductor user).""" - EVENTS_EMITTED = tuple([ - 'compilation_start', 'compilation_end', - 'preparation_start', 'preparation_end', - 'validation_start', 'validation_end', - 'running_start', 'running_end', - 'job_consumed', 'job_abandoned', - ]) + EVENTS_EMITTED = tuple( + [ + 'compilation_start', + 'compilation_end', + 'preparation_start', + 'preparation_end', + 'validation_start', + 'validation_end', + 'running_start', + 'running_end', + 'job_consumed', + 'job_abandoned', + ] + ) """Events will be emitted for each of the events above. The event is emitted to listeners registered with the conductor. 
""" - def __init__(self, name, jobboard, - persistence=None, engine=None, - engine_options=None, wait_timeout=None, - log=None, max_simultaneous_jobs=MAX_SIMULTANEOUS_JOBS): + def __init__( + self, + name, + jobboard, + persistence=None, + engine=None, + engine_options=None, + wait_timeout=None, + log=None, + max_simultaneous_jobs=MAX_SIMULTANEOUS_JOBS, + ): super().__init__( - name, jobboard, persistence=persistence, - engine=engine, engine_options=engine_options) + name, + jobboard, + persistence=persistence, + engine=engine, + engine_options=engine_options, + ) self._wait_timeout = tt.convert_to_timeout( - value=wait_timeout, default_value=self.WAIT_TIMEOUT, - event_factory=self._event_factory) + value=wait_timeout, + default_value=self.WAIT_TIMEOUT, + event_factory=self._event_factory, + ) self._dead = self._event_factory() self._log = misc.pick_first_not_none(log, self.LOG, LOG) self._max_simultaneous_jobs = int( - misc.pick_first_not_none(max_simultaneous_jobs, - self.MAX_SIMULTANEOUS_JOBS)) + misc.pick_first_not_none( + max_simultaneous_jobs, self.MAX_SIMULTANEOUS_JOBS + ) + ) self._dispatched = set() def _executor_factory(self): """Creates an executor to be used during dispatching.""" - raise excp.NotImplementedError("This method must be implemented but" - " it has not been") + raise excp.NotImplementedError( + "This method must be implemented but it has not been" + ) def stop(self): self._wait_timeout.interrupt() @@ -134,8 +159,9 @@ class ExecutorConductor(base.Conductor, metaclass=abc.ABCMeta): def _listeners_from_job(self, job, engine): listeners = super()._listeners_from_job(job, engine) - listeners.append(logging_listener.LoggingListener(engine, - log=self._log)) + listeners.append( + logging_listener.LoggingListener(engine, log=self._log) + ) return listeners def _dispatch_job(self, job): @@ -156,17 +182,22 @@ class ExecutorConductor(base.Conductor, metaclass=abc.ABCMeta): has_suspended = False for _state in engine.run_iter(): if not has_suspended and 
self._wait_timeout.is_stopped(): - self._log.info("Conductor stopped, requesting " - "suspension of engine running " - "job %s", job) + self._log.info( + "Conductor stopped, requesting " + "suspension of engine running " + "job %s", + job, + ) engine.suspend() has_suspended = True try: - for stage_func, event_name in [(engine.compile, 'compilation'), - (engine.prepare, 'preparation'), - (engine.validate, 'validation'), - (_run_engine, 'running')]: + for stage_func, event_name in [ + (engine.compile, 'compilation'), + (engine.prepare, 'preparation'), + (engine.validate, 'validation'), + (_run_engine, 'running'), + ]: self._notifier.notify("%s_start" % event_name, details) stage_func() self._notifier.notify("%s_end" % event_name, details) @@ -177,23 +208,35 @@ class ExecutorConductor(base.Conductor, metaclass=abc.ABCMeta): if consume: self._log.warning( "Job execution failed (consumption being" - " skipped): %s [%s failures]", job, len(e)) + " skipped): %s [%s failures]", + job, + len(e), + ) else: self._log.warning( "Job execution failed (consumption" - " proceeding): %s [%s failures]", job, len(e)) + " proceeding): %s [%s failures]", + job, + len(e), + ) # Show the failure/s + traceback (if possible)... for i, f in enumerate(e): - self._log.warning("%s. %s", i + 1, - f.pformat(traceback=True)) + self._log.warning( + "%s. 
%s", i + 1, f.pformat(traceback=True) + ) except self.NO_CONSUME_EXCEPTIONS: - self._log.warning("Job execution failed (consumption being" - " skipped): %s", job, exc_info=True) + self._log.warning( + "Job execution failed (consumption being skipped): %s", + job, + exc_info=True, + ) consume = False except Exception: self._log.warning( "Job execution failed (consumption proceeding): %s", - job, exc_info=True) + job, + exc_info=True, + ) else: if engine.storage.get_flow_state() == states.SUSPENDED: self._log.info("Job execution was suspended: %s", job) @@ -206,32 +249,43 @@ class ExecutorConductor(base.Conductor, metaclass=abc.ABCMeta): try: if consume: self._jobboard.consume(job, self._name) - self._notifier.notify("job_consumed", { - 'job': job, - 'conductor': self, - 'persistence': self._persistence, - }) + self._notifier.notify( + "job_consumed", + { + 'job': job, + 'conductor': self, + 'persistence': self._persistence, + }, + ) elif trash: self._jobboard.trash(job, self._name) - self._notifier.notify("job_trashed", { - 'job': job, - 'conductor': self, - 'persistence': self._persistence, - }) + self._notifier.notify( + "job_trashed", + { + 'job': job, + 'conductor': self, + 'persistence': self._persistence, + }, + ) else: self._jobboard.abandon(job, self._name) - self._notifier.notify("job_abandoned", { - 'job': job, - 'conductor': self, - 'persistence': self._persistence, - }) + self._notifier.notify( + "job_abandoned", + { + 'job': job, + 'conductor': self, + 'persistence': self._persistence, + }, + ) except (excp.JobFailure, excp.NotFound): if consume: - self._log.warn("Failed job consumption: %s", job, - exc_info=True) + self._log.warn( + "Failed job consumption: %s", job, exc_info=True + ) else: - self._log.warn("Failed job abandonment: %s", job, - exc_info=True) + self._log.warn( + "Failed job abandonment: %s", job, exc_info=True + ) def _on_job_done(self, job, fut): consume = False @@ -273,7 +327,8 @@ class ExecutorConductor(base.Conductor, 
metaclass=abc.ABCMeta): if max_dispatches == 0: raise StopIteration fresh_period = timeutils.StopWatch( - duration=self.REFRESH_PERIODICITY) + duration=self.REFRESH_PERIODICITY + ) fresh_period.start() while not is_stopped(): any_dispatched = False @@ -284,28 +339,32 @@ class ExecutorConductor(base.Conductor, metaclass=abc.ABCMeta): ensure_fresh = False job_it = itertools.takewhile( self._can_claim_more_jobs, - self._jobboard.iterjobs(ensure_fresh=ensure_fresh)) + self._jobboard.iterjobs(ensure_fresh=ensure_fresh), + ) for job in job_it: self._log.debug("Trying to claim job: %s", job) try: self._jobboard.claim(job, self._name) except (excp.UnclaimableJob, excp.NotFound): - self._log.debug("Job already claimed or" - " consumed: %s", job) + self._log.debug( + "Job already claimed or consumed: %s", job + ) else: try: fut = executor.submit(self._dispatch_job, job) except RuntimeError: with excutils.save_and_reraise_exception(): - self._log.warn("Job dispatch submitting" - " failed: %s", job) + self._log.warn( + "Job dispatch submitting failed: %s", job + ) self._try_finish_job(job, False) else: fut.job = job self._dispatched.add(fut) any_dispatched = True fut.add_done_callback( - functools.partial(self._on_job_done, job)) + functools.partial(self._on_job_done, job) + ) total_dispatched = next(dispatch_gen) if not any_dispatched and not is_stopped(): self._wait_timeout.wait() @@ -314,8 +373,9 @@ class ExecutorConductor(base.Conductor, metaclass=abc.ABCMeta): # max dispatch number (which implies we should do no more work). 
with excutils.save_and_reraise_exception(): if max_dispatches >= 0 and total_dispatched >= max_dispatches: - self._log.info("Maximum dispatch limit of %s reached", - max_dispatches) + self._log.info( + "Maximum dispatch limit of %s reached", max_dispatches + ) def run(self, max_dispatches=None): self._dead.clear() @@ -323,8 +383,7 @@ class ExecutorConductor(base.Conductor, metaclass=abc.ABCMeta): try: self._jobboard.register_entity(self.conductor) with self._executor_factory() as executor: - self._run_until_dead(executor, - max_dispatches=max_dispatches) + self._run_until_dead(executor, max_dispatches=max_dispatches) except StopIteration: pass except KeyboardInterrupt: diff --git a/taskflow/conductors/backends/impl_nonblocking.py b/taskflow/conductors/backends/impl_nonblocking.py index 97f54fbcf..873da51fb 100644 --- a/taskflow/conductors/backends/impl_nonblocking.py +++ b/taskflow/conductors/backends/impl_nonblocking.py @@ -47,20 +47,34 @@ class NonBlockingConductor(impl_executor.ExecutorConductor): max_workers = max_simultaneous_jobs return futurist.ThreadPoolExecutor(max_workers=max_workers) - def __init__(self, name, jobboard, - persistence=None, engine=None, - engine_options=None, wait_timeout=None, - log=None, max_simultaneous_jobs=MAX_SIMULTANEOUS_JOBS, - executor_factory=None): + def __init__( + self, + name, + jobboard, + persistence=None, + engine=None, + engine_options=None, + wait_timeout=None, + log=None, + max_simultaneous_jobs=MAX_SIMULTANEOUS_JOBS, + executor_factory=None, + ): super().__init__( - name, jobboard, - persistence=persistence, engine=engine, - engine_options=engine_options, wait_timeout=wait_timeout, - log=log, max_simultaneous_jobs=max_simultaneous_jobs) + name, + jobboard, + persistence=persistence, + engine=engine, + engine_options=engine_options, + wait_timeout=wait_timeout, + log=log, + max_simultaneous_jobs=max_simultaneous_jobs, + ) if executor_factory is None: self._executor_factory = self._default_executor_factory else: if not 
callable(executor_factory): - raise ValueError("Provided keyword argument 'executor_factory'" - " must be callable") + raise ValueError( + "Provided keyword argument 'executor_factory'" + " must be callable" + ) self._executor_factory = executor_factory diff --git a/taskflow/conductors/base.py b/taskflow/conductors/base.py index 5fcd45039..8a774a42a 100644 --- a/taskflow/conductors/base.py +++ b/taskflow/conductors/base.py @@ -37,8 +37,14 @@ class Conductor(metaclass=abc.ABCMeta): #: Entity kind used when creating new entity objects ENTITY_KIND = 'conductor' - def __init__(self, name, jobboard, - persistence=None, engine=None, engine_options=None): + def __init__( + self, + name, + jobboard, + persistence=None, + engine=None, + engine_options=None, + ): self._name = name self._jobboard = jobboard self._engine = engine @@ -91,9 +97,11 @@ class Conductor(metaclass=abc.ABCMeta): flow_uuid = job.details["flow_uuid"] flow_detail = book.find(flow_uuid) if flow_detail is None: - raise excp.NotFound("No matching flow detail found in" - " jobs book for flow detail" - " with uuid %s" % flow_uuid) + raise excp.NotFound( + "No matching flow detail found in" + " jobs book for flow detail" + " with uuid %s" % flow_uuid + ) else: choices = len(book) if choices == 1: @@ -101,8 +109,10 @@ class Conductor(metaclass=abc.ABCMeta): elif choices == 0: raise excp.NotFound("No flow detail(s) found in jobs book") else: - raise excp.MultipleChoices("No matching flow detail found (%s" - " choices) in jobs book" % choices) + raise excp.MultipleChoices( + "No matching flow detail found (%s" + " choices) in jobs book" % choices + ) return flow_detail def _engine_from_job(self, job): @@ -116,10 +126,13 @@ class Conductor(metaclass=abc.ABCMeta): if job.details and 'store' in job.details: store.update(job.details["store"]) - engine = engines.load_from_detail(flow_detail, store=store, - engine=self._engine, - backend=self._persistence, - **self._engine_options) + engine = engines.load_from_detail( 
+ flow_detail, + store=store, + engine=self._engine, + backend=self._persistence, + **self._engine_options, + ) return engine def _listeners_from_job(self, job, engine): diff --git a/taskflow/deciders.py b/taskflow/deciders.py index 5690df824..d16d7fa72 100644 --- a/taskflow/deciders.py +++ b/taskflow/deciders.py @@ -71,23 +71,32 @@ class Depth(misc.StrEnum): # Nothing to do in the first place... return desired_depth if not isinstance(desired_depth, str): - raise TypeError("Unexpected desired depth type, string type" - " expected, not %s" % type(desired_depth)) + raise TypeError( + "Unexpected desired depth type, string type" + " expected, not %s" % type(desired_depth) + ) try: return cls(desired_depth.upper()) except ValueError: pretty_depths = sorted([a_depth.name for a_depth in cls]) - raise ValueError("Unexpected decider depth value, one of" - " %s (case-insensitive) is expected and" - " not '%s'" % (pretty_depths, desired_depth)) + raise ValueError( + "Unexpected decider depth value, one of" + " %s (case-insensitive) is expected and" + " not '%s'" % (pretty_depths, desired_depth) + ) # Depth area of influence order (from greater influence to least). # # Order very much matters here... -_ORDERING = tuple([ - Depth.ALL, Depth.FLOW, Depth.NEIGHBORS, Depth.ATOM, -]) +_ORDERING = tuple( + [ + Depth.ALL, + Depth.FLOW, + Depth.NEIGHBORS, + Depth.ATOM, + ] +) def pick_widest(depths): diff --git a/taskflow/engines/__init__.py b/taskflow/engines/__init__.py index 30ad33324..7098d6fc6 100644 --- a/taskflow/engines/__init__.py +++ b/taskflow/engines/__init__.py @@ -18,13 +18,14 @@ from oslo_utils import eventletutils as _eventletutils # are highly recommended to be patched (or otherwise bad things could # happen). _eventletutils.warn_eventlet_not_patched( - expected_patched_modules=['time', 'thread']) + expected_patched_modules=['time', 'thread'] +) # Promote helpers to this module namespace (for easy access). 
-from taskflow.engines.helpers import flow_from_detail # noqa -from taskflow.engines.helpers import load # noqa -from taskflow.engines.helpers import load_from_detail # noqa +from taskflow.engines.helpers import flow_from_detail # noqa +from taskflow.engines.helpers import load # noqa +from taskflow.engines.helpers import load_from_detail # noqa from taskflow.engines.helpers import load_from_factory # noqa -from taskflow.engines.helpers import run # noqa +from taskflow.engines.helpers import run # noqa from taskflow.engines.helpers import save_factory_details # noqa diff --git a/taskflow/engines/action_engine/actions/base.py b/taskflow/engines/action_engine/actions/base.py index ef201d0ff..a704b0314 100644 --- a/taskflow/engines/action_engine/actions/base.py +++ b/taskflow/engines/action_engine/actions/base.py @@ -26,8 +26,12 @@ class Action(metaclass=abc.ABCMeta): """ #: States that are expected to have a result to save... - SAVE_RESULT_STATES = (states.SUCCESS, states.FAILURE, - states.REVERTED, states.REVERT_FAILURE) + SAVE_RESULT_STATES = ( + states.SUCCESS, + states.FAILURE, + states.REVERTED, + states.REVERT_FAILURE, + ) def __init__(self, storage, notifier): self._storage = storage diff --git a/taskflow/engines/action_engine/actions/retry.py b/taskflow/engines/action_engine/actions/retry.py index bdc6b65bd..5b6eb4b8e 100644 --- a/taskflow/engines/action_engine/actions/retry.py +++ b/taskflow/engines/action_engine/actions/retry.py @@ -30,13 +30,13 @@ class RetryAction(base.Action): arguments = self._storage.fetch_mapped_args( retry.revert_rebind, atom_name=retry.name, - optional_args=retry.revert_optional + optional_args=retry.revert_optional, ) else: arguments = self._storage.fetch_mapped_args( retry.rebind, atom_name=retry.name, - optional_args=retry.optional + optional_args=retry.optional, ) history = self._storage.get_retry_history(retry.name) arguments[retry_atom.EXECUTE_REVERT_HISTORY] = history @@ -74,7 +74,8 @@ class RetryAction(base.Action): def 
schedule_execution(self, retry): self.change_state(retry, states.RUNNING) return self._retry_executor.execute_retry( - retry, self._get_retry_args(retry)) + retry, self._get_retry_args(retry) + ) def complete_reversion(self, retry, result): if isinstance(result, failure.Failure): @@ -94,7 +95,8 @@ class RetryAction(base.Action): retry_atom.REVERT_FLOW_FAILURES: self._storage.get_failures(), } return self._retry_executor.revert_retry( - retry, self._get_retry_args(retry, addons=arg_addons, revert=True)) + retry, self._get_retry_args(retry, addons=arg_addons, revert=True) + ) def on_failure(self, retry, atom, last_failure): self._storage.save_retry_failure(retry.name, atom.name, last_failure) diff --git a/taskflow/engines/action_engine/actions/task.py b/taskflow/engines/action_engine/actions/task.py index 05af46a6b..d0271cd07 100644 --- a/taskflow/engines/action_engine/actions/task.py +++ b/taskflow/engines/action_engine/actions/task.py @@ -47,11 +47,13 @@ class TaskAction(base.Action): return False return True - def change_state(self, task, state, - progress=None, result=base.Action.NO_RESULT): + def change_state( + self, task, state, progress=None, result=base.Action.NO_RESULT + ): old_state = self._storage.get_atom_state(task.name) - if self._is_identity_transition(old_state, state, task, - progress=progress): + if self._is_identity_transition( + old_state, state, task, progress=progress + ): # NOTE(imelnikov): ignore identity transitions in order # to avoid extra write to storage backend and, what's # more important, extra notifications. @@ -85,60 +87,71 @@ class TaskAction(base.Action): pass else: try: - self._storage.set_task_progress(task.name, progress, - details=details) + self._storage.set_task_progress( + task.name, progress, details=details + ) except Exception: # Update progress callbacks should never fail, so capture and # log the emitted exception instead of raising it. 
- LOG.exception("Failed setting task progress for %s to %0.3f", - task, progress) + LOG.exception( + "Failed setting task progress for %s to %0.3f", + task, + progress, + ) def schedule_execution(self, task): self.change_state(task, states.RUNNING, progress=0.0) arguments = self._storage.fetch_mapped_args( - task.rebind, - atom_name=task.name, - optional_args=task.optional + task.rebind, atom_name=task.name, optional_args=task.optional ) if task.notifier.can_be_registered(task_atom.EVENT_UPDATE_PROGRESS): - progress_callback = functools.partial(self._on_update_progress, - task) + progress_callback = functools.partial( + self._on_update_progress, task + ) else: progress_callback = None task_uuid = self._storage.get_atom_uuid(task.name) return self._task_executor.execute_task( - task, task_uuid, arguments, - progress_callback=progress_callback) + task, task_uuid, arguments, progress_callback=progress_callback + ) def complete_execution(self, task, result): if isinstance(result, failure.Failure): self.change_state(task, states.FAILURE, result=result) else: - self.change_state(task, states.SUCCESS, - result=result, progress=1.0) + self.change_state( + task, states.SUCCESS, result=result, progress=1.0 + ) def schedule_reversion(self, task): self.change_state(task, states.REVERTING, progress=0.0) arguments = self._storage.fetch_mapped_args( task.revert_rebind, atom_name=task.name, - optional_args=task.revert_optional + optional_args=task.revert_optional, ) task_uuid = self._storage.get_atom_uuid(task.name) task_result = self._storage.get(task.name) failures = self._storage.get_failures() if task.notifier.can_be_registered(task_atom.EVENT_UPDATE_PROGRESS): - progress_callback = functools.partial(self._on_update_progress, - task) + progress_callback = functools.partial( + self._on_update_progress, task + ) else: progress_callback = None return self._task_executor.revert_task( - task, task_uuid, arguments, task_result, failures, - progress_callback=progress_callback) + 
task, + task_uuid, + arguments, + task_result, + failures, + progress_callback=progress_callback, + ) def complete_reversion(self, task, result): if isinstance(result, failure.Failure): self.change_state(task, states.REVERT_FAILURE, result=result) else: - self.change_state(task, states.REVERTED, progress=1.0, - result=result) + self.change_state( + task, states.REVERTED, progress=1.0, result=result + ) diff --git a/taskflow/engines/action_engine/builder.py b/taskflow/engines/action_engine/builder.py index 30737e787..198d41f4b 100644 --- a/taskflow/engines/action_engine/builder.py +++ b/taskflow/engines/action_engine/builder.py @@ -143,9 +143,12 @@ class MachineBuilder: def do_schedule(next_nodes): with self._storage.lock.write_lock(): return self._scheduler.schedule( - sorted(next_nodes, - key=lambda node: getattr(node, 'priority', 0), - reverse=True)) + sorted( + next_nodes, + key=lambda node: getattr(node, 'priority', 0), + reverse=True, + ) + ) def iter_next_atoms(atom=None, apply_deciders=True): # Yields and filters and tweaks the next atoms to run... @@ -165,8 +168,10 @@ class MachineBuilder: # that are now ready to be ran. with self._storage.lock.write_lock(): memory.next_up.update( - iter_utils.unique_seen((self._completer.resume(), - iter_next_atoms()))) + iter_utils.unique_seen( + (self._completer.resume(), iter_next_atoms()) + ) + ) return SCHEDULE def game_over(old_state, new_state, event): @@ -181,13 +186,17 @@ class MachineBuilder: # Avoid activating the deciders, since at this point # the engine is finishing and there will be no more further # work done anyway... - iter_next_atoms(apply_deciders=False)) + iter_next_atoms(apply_deciders=False) + ) if leftover_atoms: # Ok we didn't finish (either reverting or executing...) so # that means we must of been stopped at some point... 
- LOG.trace("Suspension determined to have been reacted to" - " since (at least) %s atoms have been left in an" - " unfinished state", leftover_atoms) + LOG.trace( + "Suspension determined to have been reacted to" + " since (at least) %s atoms have been left in an" + " unfinished state", + leftover_atoms, + ) return SUSPENDED elif self._runtime.is_success(): return SUCCESS @@ -239,11 +248,16 @@ class MachineBuilder: # would suck...) if LOG.isEnabledFor(logging.DEBUG): intention = get_atom_intention(atom.name) - LOG.debug("Discarding failure '%s' (in response" - " to outcome '%s') under completion" - " units request during completion of" - " atom '%s' (intention is to %s)", - result, outcome, atom, intention) + LOG.debug( + "Discarding failure '%s' (in response" + " to outcome '%s') under completion" + " units request during completion of" + " atom '%s' (intention is to %s)", + result, + outcome, + atom, + intention, + ) if gather_statistics: statistics['discarded_failures'] += 1 if gather_statistics: @@ -256,8 +270,7 @@ class MachineBuilder: return WAS_CANCELLED except Exception: memory.failures.append(failure.Failure()) - LOG.exception("Engine '%s' atom post-completion" - " failed", atom) + LOG.exception("Engine '%s' atom post-completion failed", atom) return FAILED_COMPLETING else: return SUCCESSFULLY_COMPLETED @@ -286,8 +299,10 @@ class MachineBuilder: # before we iterate over any successors or predecessors # that we know it has been completed and saved and so on... 
completion_status = complete_an_atom(fut) - if (not memory.failures - and completion_status != WAS_CANCELLED): + if ( + not memory.failures + and completion_status != WAS_CANCELLED + ): atom = fut.atom try: more_work = set(iter_next_atoms(atom=atom)) @@ -295,12 +310,17 @@ class MachineBuilder: memory.failures.append(failure.Failure()) LOG.exception( "Engine '%s' atom post-completion" - " next atom searching failed", atom) + " next atom searching failed", + atom, + ) else: next_up.update(more_work) current_flow_state = self._storage.get_flow_state() - if (current_flow_state == st.RUNNING - and next_up and not memory.failures): + if ( + current_flow_state == st.RUNNING + and next_up + and not memory.failures + ): memory.next_up.update(next_up) return SCHEDULE elif memory.not_done: @@ -311,8 +331,11 @@ class MachineBuilder: return FINISH def on_exit(old_state, event): - LOG.trace("Exiting old state '%s' in response to event '%s'", - old_state, event) + LOG.trace( + "Exiting old state '%s' in response to event '%s'", + old_state, + event, + ) if gather_statistics: if old_state in watches: w = watches[old_state] @@ -324,8 +347,11 @@ class MachineBuilder: statistics['awaiting'] = len(memory.next_up) def on_enter(new_state, event): - LOG.trace("Entering new state '%s' in response to event '%s'", - new_state, event) + LOG.trace( + "Entering new state '%s' in response to event '%s'", + new_state, + event, + ) if gather_statistics and new_state in watches: watches[new_state].restart() diff --git a/taskflow/engines/action_engine/compiler.py b/taskflow/engines/action_engine/compiler.py index 92517d460..92d9717ed 100644 --- a/taskflow/engines/action_engine/compiler.py +++ b/taskflow/engines/action_engine/compiler.py @@ -25,7 +25,7 @@ from taskflow.types import tree as tr from taskflow.utils import iter_utils from taskflow.utils import misc -from taskflow.flow import (LINK_INVARIANT, LINK_RETRY) # noqa +from taskflow.flow import LINK_INVARIANT, LINK_RETRY # noqa LOG = 
logging.getLogger(__name__) @@ -105,8 +105,9 @@ class Compilation: def _overlap_occurrence_detector(to_graph, from_graph): """Returns how many nodes in 'from' graph are in 'to' graph (if any).""" - return iter_utils.count(node for node in from_graph.nodes - if node in to_graph) + return iter_utils.count( + node for node in from_graph.nodes if node in to_graph + ) def _add_update_edges(graph, nodes_from, nodes_to, attr_dict=None): @@ -162,22 +163,30 @@ class FlowCompiler: tree_node.add(tr.Node(flow.retry, kind=RETRY)) decomposed = { child: self._deep_compiler_func(child, parent=tree_node)[0] - for child in flow} + for child in flow + } decomposed_graphs = list(decomposed.values()) - graph = gr.merge_graphs(graph, *decomposed_graphs, - overlap_detector=_overlap_occurrence_detector) + graph = gr.merge_graphs( + graph, + *decomposed_graphs, + overlap_detector=_overlap_occurrence_detector, + ) for u, v, attr_dict in flow.iter_links(): u_graph = decomposed[u] v_graph = decomposed[v] - _add_update_edges(graph, u_graph.no_successors_iter(), - list(v_graph.no_predecessors_iter()), - attr_dict=attr_dict) + _add_update_edges( + graph, + u_graph.no_successors_iter(), + list(v_graph.no_predecessors_iter()), + attr_dict=attr_dict, + ) # Insert the flow(s) retry if needed, and always make sure it # is the **immediate** successor of the flow node itself. if flow.retry is not None: graph.add_node(flow.retry, kind=RETRY) - _add_update_edges(graph, [flow], [flow.retry], - attr_dict={LINK_INVARIANT: True}) + _add_update_edges( + graph, [flow], [flow.retry], attr_dict={LINK_INVARIANT: True} + ) for node in graph.nodes: if node is not flow.retry and node is not flow: graph.nodes[node].setdefault(RETRY, flow.retry) @@ -192,10 +201,16 @@ class FlowCompiler: # us to easily know when we have entered a flow (when running) and # do special and/or smart things such as only traverse up to the # start of a flow when looking for node deciders. 
- _add_update_edges(graph, from_nodes, [ - node for node in graph.no_predecessors_iter() - if node is not flow - ], attr_dict=attr_dict) + _add_update_edges( + graph, + from_nodes, + [ + node + for node in graph.no_predecessors_iter() + if node is not flow + ], + attr_dict=attr_dict, + ) # Connect all nodes with no successors into a special terminator # that is used to identify the end of the flow and ensure that all # execution traversals will traverse over this node before executing @@ -214,10 +229,16 @@ class FlowCompiler: # that networkx provides?? flow_term = Terminator(flow) graph.add_node(flow_term, kind=FLOW_END, noop=True) - _add_update_edges(graph, [ - node for node in graph.no_successors_iter() - if node is not flow_term - ], [flow_term], attr_dict={LINK_INVARIANT: True}) + _add_update_edges( + graph, + [ + node + for node in graph.no_successors_iter() + if node is not flow_term + ], + [flow_term], + attr_dict={LINK_INVARIANT: True}, + ) return graph, tree_node @@ -337,15 +358,19 @@ class PatternCompiler: self._post_item_compile(item, graph, node) return graph, node else: - raise TypeError("Unknown object '%s' (%s) requested to compile" - % (item, type(item))) + raise TypeError( + "Unknown object '%s' (%s) requested to compile" + % (item, type(item)) + ) def _pre_item_compile(self, item): """Called before a item is compiled; any pre-compilation actions.""" if item in self._history: - raise ValueError("Already compiled item '%s' (%s), duplicate" - " and/or recursive compiling is not" - " supported" % (item, type(item))) + raise ValueError( + "Already compiled item '%s' (%s), duplicate" + " and/or recursive compiling is not" + " supported" % (item, type(item)) + ) self._history.add(item) if LOG.isEnabledFor(logging.TRACE): LOG.trace("%sCompiling '%s'", " " * self._level, item) diff --git a/taskflow/engines/action_engine/completer.py b/taskflow/engines/action_engine/completer.py index 922c292c4..117bb22ea 100644 --- 
a/taskflow/engines/action_engine/completer.py +++ b/taskflow/engines/action_engine/completer.py @@ -58,10 +58,14 @@ class RevertAndRetry(Strategy): self._retry = retry def apply(self): - tweaked = self._runtime.reset_atoms([self._retry], state=None, - intention=st.RETRY) - tweaked.extend(self._runtime.reset_subgraph(self._retry, state=None, - intention=st.REVERT)) + tweaked = self._runtime.reset_atoms( + [self._retry], state=None, intention=st.RETRY + ) + tweaked.extend( + self._runtime.reset_subgraph( + self._retry, state=None, intention=st.REVERT + ) + ) return tweaked @@ -76,7 +80,9 @@ class RevertAll(Strategy): def apply(self): return self._runtime.reset_atoms( self._runtime.iterate_nodes(co.ATOMS), - state=None, intention=st.REVERT) + state=None, + intention=st.REVERT, + ) class Revert(Strategy): @@ -89,10 +95,14 @@ class Revert(Strategy): self._atom = atom def apply(self): - tweaked = self._runtime.reset_atoms([self._atom], state=None, - intention=st.REVERT) - tweaked.extend(self._runtime.reset_subgraph(self._atom, state=None, - intention=st.REVERT)) + tweaked = self._runtime.reset_atoms( + [self._atom], state=None, intention=st.REVERT + ) + tweaked.extend( + self._runtime.reset_subgraph( + self._atom, state=None, intention=st.REVERT + ) + ) return tweaked @@ -104,9 +114,11 @@ class Completer: self._storage = runtime.storage self._undefined_resolver = RevertAll(self._runtime) self._defer_reverts = strutils.bool_from_string( - self._runtime.options.get('defer_reverts', False)) + self._runtime.options.get('defer_reverts', False) + ) self._resolve = not strutils.bool_from_string( - self._runtime.options.get('never_resolve', False)) + self._runtime.options.get('never_resolve', False) + ) def resume(self): """Resumes atoms in the contained graph. @@ -120,14 +132,16 @@ class Completer: attempt not previously finishing). 
""" atoms = list(self._runtime.iterate_nodes(co.ATOMS)) - atom_states = self._storage.get_atoms_states(atom.name - for atom in atoms) + atom_states = self._storage.get_atoms_states( + atom.name for atom in atoms + ) if self._resolve: for atom in atoms: atom_state, _atom_intention = atom_states[atom.name] if atom_state == st.FAILURE: self._process_atom_failure( - atom, self._storage.get(atom.name)) + atom, self._storage.get(atom.name) + ) for retry in self._runtime.iterate_retries(st.RETRYING): retry_affected_atoms_it = self._runtime.retry_subflow(retry) for atom, state, intention in retry_affected_atoms_it: @@ -138,8 +152,11 @@ class Completer: atom_state, _atom_intention = atom_states[atom.name] if atom_state in (st.RUNNING, st.REVERTING): unfinished_atoms.add(atom) - LOG.trace("Resuming atom '%s' since it was left in" - " state %s", atom, atom_state) + LOG.trace( + "Resuming atom '%s' since it was left in state %s", + atom, + atom_state, + ) return unfinished_atoms def complete_failure(self, node, outcome, failure): @@ -192,8 +209,10 @@ class Completer: elif strategy == retry_atom.REVERT_ALL: return RevertAll(self._runtime) else: - raise ValueError("Unknown atom failure resolution" - " action/strategy '%s'" % strategy) + raise ValueError( + "Unknown atom failure resolution" + " action/strategy '%s'" % strategy + ) else: return self._undefined_resolver @@ -207,14 +226,24 @@ class Completer: the failure can be worked around. """ resolver = self._determine_resolution(atom, failure) - LOG.debug("Applying resolver '%s' to resolve failure '%s'" - " of atom '%s'", resolver, failure, atom) + LOG.debug( + "Applying resolver '%s' to resolve failure '%s' of atom '%s'", + resolver, + failure, + atom, + ) tweaked = resolver.apply() # Only show the tweaked node list when trace is on, otherwise # just show the amount/count of nodes tweaks... 
if LOG.isEnabledFor(logging.TRACE): - LOG.trace("Modified/tweaked %s nodes while applying" - " resolver '%s'", tweaked, resolver) + LOG.trace( + "Modified/tweaked %s nodes while applying resolver '%s'", + tweaked, + resolver, + ) else: - LOG.debug("Modified/tweaked %s nodes while applying" - " resolver '%s'", len(tweaked), resolver) + LOG.debug( + "Modified/tweaked %s nodes while applying resolver '%s'", + len(tweaked), + resolver, + ) diff --git a/taskflow/engines/action_engine/deciders.py b/taskflow/engines/action_engine/deciders.py index d712acffd..0cfac80e3 100644 --- a/taskflow/engines/action_engine/deciders.py +++ b/taskflow/engines/action_engine/deciders.py @@ -67,23 +67,34 @@ class Decider(metaclass=abc.ABCMeta): def _affect_all_successors(atom, runtime): execution_graph = runtime.compilation.execution_graph successors_iter = traversal.depth_first_iterate( - execution_graph, atom, traversal.Direction.FORWARD) - runtime.reset_atoms(itertools.chain([atom], successors_iter), - state=states.IGNORE, intention=states.IGNORE) + execution_graph, atom, traversal.Direction.FORWARD + ) + runtime.reset_atoms( + itertools.chain([atom], successors_iter), + state=states.IGNORE, + intention=states.IGNORE, + ) def _affect_successor_tasks_in_same_flow(atom, runtime): execution_graph = runtime.compilation.execution_graph successors_iter = traversal.depth_first_iterate( - execution_graph, atom, traversal.Direction.FORWARD, + execution_graph, + atom, + traversal.Direction.FORWARD, # Do not go through nested flows but do follow *all* tasks that # are directly connected in this same flow (thus the reason this is # called the same flow decider); retries are direct successors # of flows, so they should also be not traversed through, but # setting this explicitly ensures that. 
- through_flows=False, through_retries=False) - runtime.reset_atoms(itertools.chain([atom], successors_iter), - state=states.IGNORE, intention=states.IGNORE) + through_flows=False, + through_retries=False, + ) + runtime.reset_atoms( + itertools.chain([atom], successors_iter), + state=states.IGNORE, + intention=states.IGNORE, + ) def _affect_atom(atom, runtime): @@ -97,9 +108,13 @@ def _affect_direct_task_neighbors(atom, runtime): node_data = execution_graph.nodes[node] if node_data['kind'] == compiler.TASK: yield node + successors_iter = _walk_neighbors() - runtime.reset_atoms(itertools.chain([atom], successors_iter), - state=states.IGNORE, intention=states.IGNORE) + runtime.reset_atoms( + itertools.chain([atom], successors_iter), + state=states.IGNORE, + intention=states.IGNORE, + ) class IgnoreDecider(Decider): @@ -128,18 +143,23 @@ class IgnoreDecider(Decider): # that those results can be used by the decider(s) that are # making a decision as to pass or not pass... states_intentions = runtime.storage.get_atoms_states( - ed.from_node.name for ed in self._edge_deciders - if ed.kind in compiler.ATOMS) + ed.from_node.name + for ed in self._edge_deciders + if ed.kind in compiler.ATOMS + ) for atom_name in states_intentions.keys(): atom_state, _atom_intention = states_intentions[atom_name] if atom_state != states.IGNORE: history[atom_name] = runtime.storage.get(atom_name) for ed in self._edge_deciders: - if (ed.kind in compiler.ATOMS and - # It was an ignored atom (not included in history and - # the only way that is possible is via above loop - # skipping it...) - ed.from_node.name not in history): + if ( + ed.kind in compiler.ATOMS + and + # It was an ignored atom (not included in history and + # the only way that is possible is via above loop + # skipping it...) 
+ ed.from_node.name not in history + ): voters['ignored'].append(ed) continue if not ed.decider(history=history): @@ -147,15 +167,17 @@ class IgnoreDecider(Decider): else: voters['run_it'].append(ed) if LOG.isEnabledFor(logging.TRACE): - LOG.trace("Out of %s deciders there were %s 'do no run it'" - " voters, %s 'do run it' voters and %s 'ignored'" - " voters for transition to atom '%s' given history %s", - sum(len(eds) for eds in voters.values()), - list(ed.from_node.name - for ed in voters['do_not_run_it']), - list(ed.from_node.name for ed in voters['run_it']), - list(ed.from_node.name for ed in voters['ignored']), - self._atom.name, history) + LOG.trace( + "Out of %s deciders there were %s 'do no run it'" + " voters, %s 'do run it' voters and %s 'ignored'" + " voters for transition to atom '%s' given history %s", + sum(len(eds) for eds in voters.values()), + list(ed.from_node.name for ed in voters['do_not_run_it']), + list(ed.from_node.name for ed in voters['run_it']), + list(ed.from_node.name for ed in voters['ignored']), + self._atom.name, + history, + ) return voters['do_not_run_it'] def affect(self, runtime, nay_voters): diff --git a/taskflow/engines/action_engine/engine.py b/taskflow/engines/action_engine/engine.py index 6b29ae465..a2e35f1db 100644 --- a/taskflow/engines/action_engine/engine.py +++ b/taskflow/engines/action_engine/engine.py @@ -55,8 +55,9 @@ def _start_stop(task_executor, retry_executor): task_executor.stop() -def _pre_check(check_compiled=True, check_storage_ensured=True, - check_validated=True): +def _pre_check( + check_compiled=True, check_storage_ensured=True, check_validated=True +): """Engine state precondition checking decorator.""" def decorator(meth): @@ -65,15 +66,21 @@ def _pre_check(check_compiled=True, check_storage_ensured=True, @functools.wraps(meth) def wrapper(self, *args, **kwargs): if check_compiled and not self._compiled: - raise exc.InvalidState("Can not %s an engine which" - " has not been compiled" % do_what) + raise 
exc.InvalidState( + "Can not %s an engine which" + " has not been compiled" % do_what + ) if check_storage_ensured and not self._storage_ensured: - raise exc.InvalidState("Can not %s an engine" - " which has not had its storage" - " populated" % do_what) + raise exc.InvalidState( + "Can not %s an engine" + " which has not had its storage" + " populated" % do_what + ) if check_validated and not self._validated: - raise exc.InvalidState("Can not %s an engine which" - " has not been validated" % do_what) + raise exc.InvalidState( + "Can not %s an engine which" + " has not been validated" % do_what + ) return meth(self, *args, **kwargs) return wrapper @@ -148,8 +155,16 @@ class ActionEngine(base.Engine): """ IGNORABLE_STATES = frozenset( - itertools.chain([states.SCHEDULING, states.WAITING, states.RESUMING, - states.ANALYZING], builder.META_STATES)) + itertools.chain( + [ + states.SCHEDULING, + states.WAITING, + states.RESUMING, + states.ANALYZING, + ], + builder.META_STATES, + ) + ) """ Informational states this engines internal machine yields back while running, not useful to have the engine record but useful to provide to @@ -175,18 +190,22 @@ class ActionEngine(base.Engine): # or thread (this could change in the future if we desire it to). self._retry_executor = executor.SerialRetryExecutor() self._inject_transient = strutils.bool_from_string( - self._options.get('inject_transient', True)) + self._options.get('inject_transient', True) + ) self._gather_statistics = strutils.bool_from_string( - self._options.get('gather_statistics', True)) + self._options.get('gather_statistics', True) + ) self._statistics = {} - @_pre_check(check_compiled=True, - # NOTE(harlowja): We can alter the state of the - # flow without ensuring its storage is setup for - # its atoms (since this state change does not affect - # those units). 
- check_storage_ensured=False, - check_validated=False) + @_pre_check( + check_compiled=True, + # NOTE(harlowja): We can alter the state of the + # flow without ensuring its storage is setup for + # its atoms (since this state change does not affect + # those units). + check_storage_ensured=False, + check_validated=False, + ) def suspend(self): self._change_state(states.SUSPENDING) @@ -221,14 +240,18 @@ class ActionEngine(base.Engine): the actual runtime lookup strategy, which typically will be, but is not always different). """ + def _scope_fetcher(atom_name): if self._compiled: return self._runtime.fetch_scopes_for(atom_name) else: return None - return storage.Storage(self._flow_detail, - backend=self._backend, - scope_fetcher=_scope_fetcher) + + return storage.Storage( + self._flow_detail, + backend=self._backend, + scope_fetcher=_scope_fetcher, + ) def run(self, timeout=None): """Runs the engine (or die trying). @@ -239,8 +262,9 @@ class ActionEngine(base.Engine): """ with fasteners.try_lock(self._lock) as was_locked: if not was_locked: - raise exc.ExecutionFailure("Engine currently locked, please" - " try again later") + raise exc.ExecutionFailure( + "Engine currently locked, please try again later" + ) for _state in self.run_iter(timeout=timeout): pass @@ -272,7 +296,8 @@ class ActionEngine(base.Engine): # are quite useful to log (and the performance of tracking this # should be negligible). 
last_transitions = collections.deque( - maxlen=max(1, self.MAX_MACHINE_STATES_RETAINED)) + maxlen=max(1, self.MAX_MACHINE_STATES_RETAINED) + ) with _start_stop(self._task_executor, self._retry_executor): self._change_state(states.RUNNING) if self._gather_statistics: @@ -284,8 +309,10 @@ class ActionEngine(base.Engine): try: closed = False machine, memory = self._runtime.builder.build( - self._statistics, timeout=timeout, - gather_statistics=self._gather_statistics) + self._statistics, + timeout=timeout, + gather_statistics=self._gather_statistics, + ) r = runners.FiniteRunner(machine) for transition in r.run_iter(builder.START): last_transitions.append(transition) @@ -317,11 +344,13 @@ class ActionEngine(base.Engine): self.suspend() except Exception: with excutils.save_and_reraise_exception(): - LOG.exception("Engine execution has failed, something" - " bad must have happened (last" - " %s machine transitions were %s)", - last_transitions.maxlen, - list(last_transitions)) + LOG.exception( + "Engine execution has failed, something" + " bad must have happened (last" + " %s machine transitions were %s)", + last_transitions.maxlen, + list(last_transitions), + ) self._change_state(states.FAILURE) else: if last_transitions: @@ -332,8 +361,8 @@ class ActionEngine(base.Engine): e_failures = self.storage.get_execute_failures() r_failures = self.storage.get_revert_failures() er_failures = itertools.chain( - e_failures.values(), - r_failures.values()) + e_failures.values(), r_failures.values() + ) failure.Failure.reraise_if_any(er_failures) finally: if w is not None: @@ -355,7 +384,8 @@ class ActionEngine(base.Engine): seen.add(atom_name) if dups: raise exc.Duplicate( - "Atoms with duplicate names found: %s" % (sorted(dups))) + "Atoms with duplicate names found: %s" % (sorted(dups)) + ) return compilation def _change_state(self, state): @@ -371,12 +401,12 @@ class ActionEngine(base.Engine): def _ensure_storage(self): """Ensure all contained atoms exist in the storage unit.""" 
- self.storage.ensure_atoms( - self._runtime.iterate_nodes(compiler.ATOMS)) + self.storage.ensure_atoms(self._runtime.iterate_nodes(compiler.ATOMS)) for atom in self._runtime.iterate_nodes(compiler.ATOMS): if atom.inject: - self.storage.inject_atom_args(atom.name, atom.inject, - transient=self._inject_transient) + self.storage.inject_atom_args( + atom.name, atom.inject, transient=self._inject_transient + ) @fasteners.locked @_pre_check(check_validated=False) @@ -387,11 +417,14 @@ class ActionEngine(base.Engine): # by failing at validation time). if LOG.isEnabledFor(logging.TRACE): execution_graph = self._compilation.execution_graph - LOG.trace("Validating scoping and argument visibility for" - " execution graph with %s nodes and %s edges with" - " density %0.3f", execution_graph.number_of_nodes(), - execution_graph.number_of_edges(), - nx.density(execution_graph)) + LOG.trace( + "Validating scoping and argument visibility for" + " execution graph with %s nodes and %s edges with" + " density %0.3f", + execution_graph.number_of_nodes(), + execution_graph.number_of_edges(), + nx.density(execution_graph), + ) missing = set() # Attempt to retain a chain of what was missing (so that the final # raised exception for the flow has the nodes that had missing @@ -401,18 +434,25 @@ class ActionEngine(base.Engine): missing_nodes = 0 for atom in self._runtime.iterate_nodes(compiler.ATOMS): exec_missing = self.storage.fetch_unsatisfied_args( - atom.name, atom.rebind, optional_args=atom.optional) + atom.name, atom.rebind, optional_args=atom.optional + ) revert_missing = self.storage.fetch_unsatisfied_args( - atom.name, atom.revert_rebind, - optional_args=atom.revert_optional) - atom_missing = (('execute', exec_missing), - ('revert', revert_missing)) + atom.name, + atom.revert_rebind, + optional_args=atom.revert_optional, + ) + atom_missing = ( + ('execute', exec_missing), + ('revert', revert_missing), + ) for method, method_missing in atom_missing: if method_missing: - cause = 
exc.MissingDependencies(atom, - sorted(method_missing), - cause=last_cause, - method=method) + cause = exc.MissingDependencies( + atom, + sorted(method_missing), + cause=last_cause, + method=method, + ) last_cause = cause last_node = atom missing_nodes += 1 @@ -424,9 +464,9 @@ class ActionEngine(base.Engine): if missing_nodes == 1 and last_node is self._flow: raise last_cause else: - raise exc.MissingDependencies(self._flow, - sorted(missing), - cause=last_cause) + raise exc.MissingDependencies( + self._flow, sorted(missing), cause=last_cause + ) self._validated = True @fasteners.locked @@ -458,12 +498,14 @@ class ActionEngine(base.Engine): if self._compiled: return self._compilation = self._check_compilation(self._compiler.compile()) - self._runtime = runtime.Runtime(self._compilation, - self.storage, - self.atom_notifier, - self._task_executor, - self._retry_executor, - options=self._options) + self._runtime = runtime.Runtime( + self._compilation, + self.storage, + self.atom_notifier, + self._task_executor, + self._retry_executor, + options=self._options, + ) self._runtime.compile() self._compiled = True @@ -476,14 +518,16 @@ class SerialActionEngine(ActionEngine): self._task_executor = executor.SerialTaskExecutor() -class _ExecutorTypeMatch(collections.namedtuple('_ExecutorTypeMatch', - ['types', 'executor_cls'])): +class _ExecutorTypeMatch( + collections.namedtuple('_ExecutorTypeMatch', ['types', 'executor_cls']) +): def matches(self, executor): return isinstance(executor, self.types) -class _ExecutorTextMatch(collections.namedtuple('_ExecutorTextMatch', - ['strings', 'executor_cls'])): +class _ExecutorTextMatch( + collections.namedtuple('_ExecutorTextMatch', ['strings', 'executor_cls']) +): def matches(self, text): return text.lower() in self.strings @@ -494,51 +538,52 @@ class ParallelActionEngine(ActionEngine): **Additional engine options:** * ``executor``: a object that implements a :pep:`3148` compatible executor - interface; it will be used for scheduling 
tasks. The following - type are applicable (other unknown types passed will cause a type - error to be raised). + interface; it will be used for scheduling tasks. The following type are + applicable (other unknown types passed will cause a type error to be + raised). -========================= =============================================== -Type provided Executor used -========================= =============================================== -|cft|.ThreadPoolExecutor :class:`~.executor.ParallelThreadTaskExecutor` -|cfp|.ProcessPoolExecutor :class:`~.|pe|.ParallelProcessTaskExecutor` -|cf|._base.Executor :class:`~.executor.ParallelThreadTaskExecutor` -========================= =============================================== + ========================= ============================================== + Type provided Executor used + ========================= ============================================== + |cft|.ThreadPoolExecutor :class:`~.executor.ParallelThreadTaskExecutor` + |cfp|.ProcessPoolExecutor :class:`~.|pe|.ParallelProcessTaskExecutor` + |cf|._base.Executor :class:`~.executor.ParallelThreadTaskExecutor` + ========================= ============================================== * ``executor``: a string that will be used to select a :pep:`3148` compatible executor; it will be used for scheduling tasks. The following string are applicable (other unknown strings passed will cause a value error to be raised). 
-=========================== =============================================== -String (case insensitive) Executor used -=========================== =============================================== -``process`` :class:`~.|pe|.ParallelProcessTaskExecutor` -``processes`` :class:`~.|pe|.ParallelProcessTaskExecutor` -``thread`` :class:`~.executor.ParallelThreadTaskExecutor` -``threaded`` :class:`~.executor.ParallelThreadTaskExecutor` -``threads`` :class:`~.executor.ParallelThreadTaskExecutor` -``greenthread`` :class:`~.executor.ParallelThreadTaskExecutor` - (greened version) -``greedthreaded`` :class:`~.executor.ParallelThreadTaskExecutor` - (greened version) -``greenthreads`` :class:`~.executor.ParallelThreadTaskExecutor` - (greened version) -=========================== =============================================== + =========================== ============================================== + String (case insensitive) Executor used + =========================== ============================================== + ``process`` :class:`~.|pe|.ParallelProcessTaskExecutor` + ``processes`` :class:`~.|pe|.ParallelProcessTaskExecutor` + ``thread`` :class:`~.executor.ParallelThreadTaskExecutor` + ``threaded`` :class:`~.executor.ParallelThreadTaskExecutor` + ``threads`` :class:`~.executor.ParallelThreadTaskExecutor` + ``greenthread`` :class:`~.executor.ParallelThreadTaskExecutor` + (greened version) + ``greedthreaded`` :class:`~.executor.ParallelThreadTaskExecutor` + (greened version) + ``greenthreads`` :class:`~.executor.ParallelThreadTaskExecutor` + (greened version) + =========================== ============================================== * ``max_workers``: a integer that will affect the number of parallel workers that are used to dispatch tasks into (this number is bounded by the maximum parallelization your workflow can support). 
* ``wait_timeout``: a float (in seconds) that will affect the - parallel process task executor (and therefore is **only** applicable when - the executor provided above is of the process variant). This number - affects how much time the process task executor waits for messages from - child processes (typically indicating they have finished or failed). A - lower number will have high granularity but *currently* involves more - polling while a higher number will involve less polling but a slower time - for an engine to notice a task has completed. + parallel process task executor (and therefore is **only** applicable + when the executor provided above is of the process variant). This + number affects how much time the process task executor waits for + messages from child processes (typically indicating they have + finished or failed). A lower number will have high granularity but + *currently* involves more polling while a higher number will involve + less polling but a slower time for an engine to notice a task has + completed. .. |cfp| replace:: concurrent.futures.process .. |cft| replace:: concurrent.futures.thread @@ -552,21 +597,26 @@ String (case insensitive) Executor used # allow for instances of that to be detected and handled correctly, instead # of forcing everyone to use our derivatives (futurist or other)... _executor_cls_matchers = [ - _ExecutorTypeMatch((futures.ThreadPoolExecutor,), - executor.ParallelThreadTaskExecutor), - _ExecutorTypeMatch((futures.Executor,), - executor.ParallelThreadTaskExecutor), + _ExecutorTypeMatch( + (futures.ThreadPoolExecutor,), executor.ParallelThreadTaskExecutor + ), + _ExecutorTypeMatch( + (futures.Executor,), executor.ParallelThreadTaskExecutor + ), ] # One of these should match when a string/text is provided for the # 'executor' option (a mixed case equivalent is allowed since the match # will be lower-cased before checking). 
_executor_str_matchers = [ - _ExecutorTextMatch(frozenset(['thread', 'threads', 'threaded']), - executor.ParallelThreadTaskExecutor), - _ExecutorTextMatch(frozenset(['greenthread', 'greenthreads', - 'greenthreaded']), - executor.ParallelGreenThreadTaskExecutor), + _ExecutorTextMatch( + frozenset(['thread', 'threads', 'threaded']), + executor.ParallelThreadTaskExecutor, + ), + _ExecutorTextMatch( + frozenset(['greenthread', 'greenthreads', 'greenthreaded']), + executor.ParallelGreenThreadTaskExecutor, + ), ] # Used when no executor is provided (either a string or object)... @@ -594,9 +644,11 @@ String (case insensitive) Executor used expected = set() for m in cls._executor_str_matchers: expected.update(m.strings) - raise ValueError("Unknown executor string '%s' expected" - " one of %s (or mixed case equivalent)" - % (desired_executor, list(expected))) + raise ValueError( + "Unknown executor string '%s' expected" + " one of %s (or mixed case equivalent)" + % (desired_executor, list(expected)) + ) else: executor_cls = matched_executor_cls elif desired_executor is not None: @@ -609,15 +661,20 @@ String (case insensitive) Executor used expected = set() for m in cls._executor_cls_matchers: expected.update(m.types) - raise TypeError("Unknown executor '%s' (%s) expected an" - " instance of %s" % (desired_executor, - type(desired_executor), - list(expected))) + raise TypeError( + "Unknown executor '%s' (%s) expected an" + " instance of %s" + % ( + desired_executor, + type(desired_executor), + list(expected), + ) + ) else: executor_cls = matched_executor_cls kwargs['executor'] = desired_executor try: - for (k, value_converter) in executor_cls.constructor_options: + for k, value_converter in executor_cls.constructor_options: try: kwargs[k] = value_converter(options[k]) except KeyError: diff --git a/taskflow/engines/action_engine/executor.py b/taskflow/engines/action_engine/executor.py index 066a40a94..97f65e06d 100644 --- a/taskflow/engines/action_engine/executor.py +++ 
b/taskflow/engines/action_engine/executor.py @@ -42,9 +42,9 @@ def _revert_retry(retry, arguments): def _execute_task(task, arguments, progress_callback=None): - with notifier.register_deregister(task.notifier, - ta.EVENT_UPDATE_PROGRESS, - callback=progress_callback): + with notifier.register_deregister( + task.notifier, ta.EVENT_UPDATE_PROGRESS, callback=progress_callback + ): try: task.pre_execute() result = task.execute(**arguments) @@ -61,9 +61,9 @@ def _revert_task(task, arguments, result, failures, progress_callback=None): arguments = arguments.copy() arguments[ta.REVERT_RESULT] = result arguments[ta.REVERT_FLOW_FAILURES] = failures - with notifier.register_deregister(task.notifier, - ta.EVENT_UPDATE_PROGRESS, - callback=progress_callback): + with notifier.register_deregister( + task.notifier, ta.EVENT_UPDATE_PROGRESS, callback=progress_callback + ): try: task.pre_revert() result = task.revert(**arguments) @@ -112,13 +112,19 @@ class TaskExecutor(metaclass=abc.ABCMeta): """ @abc.abstractmethod - def execute_task(self, task, task_uuid, arguments, - progress_callback=None): + def execute_task(self, task, task_uuid, arguments, progress_callback=None): """Schedules task execution.""" @abc.abstractmethod - def revert_task(self, task, task_uuid, arguments, result, failures, - progress_callback=None): + def revert_task( + self, + task, + task_uuid, + arguments, + result, + failures, + progress_callback=None, + ): """Schedules task reversion.""" def start(self): @@ -141,17 +147,29 @@ class SerialTaskExecutor(TaskExecutor): self._executor.shutdown() def execute_task(self, task, task_uuid, arguments, progress_callback=None): - fut = self._executor.submit(_execute_task, - task, arguments, - progress_callback=progress_callback) + fut = self._executor.submit( + _execute_task, task, arguments, progress_callback=progress_callback + ) fut.atom = task return fut - def revert_task(self, task, task_uuid, arguments, result, failures, - progress_callback=None): - fut = 
self._executor.submit(_revert_task, - task, arguments, result, failures, - progress_callback=progress_callback) + def revert_task( + self, + task, + task_uuid, + arguments, + result, + failures, + progress_callback=None, + ): + fut = self._executor.submit( + _revert_task, + task, + arguments, + result, + failures, + progress_callback=progress_callback, + ) fut.atom = task return fut @@ -188,18 +206,33 @@ class ParallelTaskExecutor(TaskExecutor): return fut def execute_task(self, task, task_uuid, arguments, progress_callback=None): - return self._submit_task(_execute_task, task, arguments, - progress_callback=progress_callback) + return self._submit_task( + _execute_task, task, arguments, progress_callback=progress_callback + ) - def revert_task(self, task, task_uuid, arguments, result, failures, - progress_callback=None): - return self._submit_task(_revert_task, task, arguments, result, - failures, progress_callback=progress_callback) + def revert_task( + self, + task, + task_uuid, + arguments, + result, + failures, + progress_callback=None, + ): + return self._submit_task( + _revert_task, + task, + arguments, + result, + failures, + progress_callback=progress_callback, + ) def start(self): if self._own_executor: self._executor = self._create_executor( - max_workers=self._max_workers) + max_workers=self._max_workers + ) def stop(self): if self._own_executor: diff --git a/taskflow/engines/action_engine/runtime.py b/taskflow/engines/action_engine/runtime.py index c70c4162a..8b72f8dc8 100644 --- a/taskflow/engines/action_engine/runtime.py +++ b/taskflow/engines/action_engine/runtime.py @@ -32,11 +32,12 @@ from taskflow import logging from taskflow import states as st from taskflow.utils import misc -from taskflow.flow import (LINK_DECIDER, LINK_DECIDER_DEPTH) # noqa +from taskflow.flow import LINK_DECIDER, LINK_DECIDER_DEPTH # noqa # Small helper to make the edge decider tuples more easily useable... 
-_EdgeDecider = collections.namedtuple('_EdgeDecider', - 'from_node,kind,decider,depth') +_EdgeDecider = collections.namedtuple( + '_EdgeDecider', 'from_node,kind,decider,depth' +) LOG = logging.getLogger(__name__) @@ -49,9 +50,15 @@ class Runtime: action engine to run to completion. """ - def __init__(self, compilation, storage, atom_notifier, - task_executor, retry_executor, - options=None): + def __init__( + self, + compilation, + storage, + atom_notifier, + task_executor, + retry_executor, + options=None, + ): self._atom_notifier = atom_notifier self._task_executor = task_executor self._retry_executor = retry_executor @@ -65,8 +72,9 @@ class Runtime: # This is basically a reverse breadth first exploration, with # special logic to further traverse down flow nodes as needed... predecessors_iter = graph.predecessors - nodes = collections.deque((u_node, atom) - for u_node in predecessors_iter(atom)) + nodes = collections.deque( + (u_node, atom) for u_node in predecessors_iter(atom) + ) visited = set() while nodes: u_node, v_node = nodes.popleft() @@ -77,8 +85,7 @@ class Runtime: decider_depth = u_v_data.get(LINK_DECIDER_DEPTH) if decider_depth is None: decider_depth = de.Depth.ALL - yield _EdgeDecider(u_node, u_node_kind, - decider, decider_depth) + yield _EdgeDecider(u_node, u_node_kind, decider, decider_depth) except KeyError: pass if u_node_kind == com.FLOW and u_node not in visited: @@ -89,8 +96,10 @@ class Runtime: # sure that any prior decider that was directed at this flow # node also gets used during future decisions about this # atom node. - nodes.extend((u_u_node, u_node) - for u_u_node in predecessors_iter(u_node)) + nodes.extend( + (u_u_node, u_node) + for u_u_node in predecessors_iter(u_node) + ) def compile(self): """Compiles & caches frequently used execution helper objects. @@ -102,8 +111,9 @@ class Runtime: specific scheduler and so-on). 
""" change_state_handlers = { - com.TASK: functools.partial(self.task_action.change_state, - progress=0.0), + com.TASK: functools.partial( + self.task_action.change_state, progress=0.0 + ), com.RETRY: self.retry_action.change_state, } schedulers = { @@ -129,8 +139,9 @@ class Runtime: scheduler = schedulers[node_kind] action = actions[node_kind] else: - raise exc.CompilationFailure("Unknown node kind '%s'" - " encountered" % node_kind) + raise exc.CompilationFailure( + "Unknown node kind '%s' encountered" % node_kind + ) metadata = {} deciders_it = self._walk_edge_deciders(graph, node) walker = sc.ScopeWalker(self.compilation, node, names_only=True) @@ -140,8 +151,12 @@ class Runtime: metadata['scheduler'] = scheduler metadata['edge_deciders'] = tuple(deciders_it) metadata['action'] = action - LOG.trace("Compiled %s metadata for node %s (%s)", - metadata, node.name, node_kind) + LOG.trace( + "Compiled %s metadata for node %s (%s)", + metadata, + node.name, + node_kind, + ) self._atom_cache[node.name] = metadata # TODO(harlowja): optimize the different decider depths to avoid # repeated full successor searching; this can be done by searching @@ -186,15 +201,15 @@ class Runtime: @misc.cachedproperty def retry_action(self): - return ra.RetryAction(self._storage, - self._atom_notifier, - self._retry_executor) + return ra.RetryAction( + self._storage, self._atom_notifier, self._retry_executor + ) @misc.cachedproperty def task_action(self): - return ta.TaskAction(self._storage, - self._atom_notifier, - self._task_executor) + return ta.TaskAction( + self._storage, self._atom_notifier, self._task_executor + ) def _fetch_atom_metadata_entry(self, atom_name, metadata_key): return self._atom_cache[atom_name][metadata_key] @@ -205,7 +220,8 @@ class Runtime: # internally to the engine, and is not exposed to atoms that will # not exist and therefore doesn't need to handle that case). 
check_transition_handler = self._fetch_atom_metadata_entry( - atom.name, 'check_transition_handler') + atom.name, 'check_transition_handler' + ) return check_transition_handler(current_state, target_state) def fetch_edge_deciders(self, atom): @@ -250,8 +266,9 @@ class Runtime: """ if state: atoms = list(self.iterate_nodes((com.RETRY,))) - atom_states = self._storage.get_atoms_states(atom.name - for atom in atoms) + atom_states = self._storage.get_atoms_states( + atom.name for atom in atoms + ) for atom in atoms: atom_state, _atom_intention = atom_states[atom.name] if atom_state == state: @@ -270,8 +287,9 @@ class Runtime: def is_success(self): """Checks if all atoms in the execution graph are in 'happy' state.""" atoms = list(self.iterate_nodes(com.ATOMS)) - atom_states = self._storage.get_atoms_states(atom.name - for atom in atoms) + atom_states = self._storage.get_atoms_states( + atom.name for atom in atoms + ) for atom in atoms: atom_state, _atom_intention = atom_states[atom.name] if atom_state == st.IGNORE: @@ -305,7 +323,8 @@ class Runtime: tweaked.append((atom, state, intention)) if state: change_state_handler = self._fetch_atom_metadata_entry( - atom.name, 'change_state_handler') + atom.name, 'change_state_handler' + ) change_state_handler(atom, state) if intention: self.storage.set_atom_intention(atom.name, intention) @@ -313,8 +332,9 @@ class Runtime: def reset_all(self, state=st.PENDING, intention=st.EXECUTE): """Resets all atoms to the given state and intention.""" - return self.reset_atoms(self.iterate_nodes(com.ATOMS), - state=state, intention=intention) + return self.reset_atoms( + self.iterate_nodes(com.ATOMS), state=state, intention=intention + ) def reset_subgraph(self, atom, state=st.PENDING, intention=st.EXECUTE): """Resets a atoms subgraph to the given state and intention. @@ -322,8 +342,9 @@ class Runtime: The subgraph is contained of **all** of the atoms successors. 
""" execution_graph = self._compilation.execution_graph - atoms_it = tr.depth_first_iterate(execution_graph, atom, - tr.Direction.FORWARD) + atoms_it = tr.depth_first_iterate( + execution_graph, atom, tr.Direction.FORWARD + ) return self.reset_atoms(atoms_it, state=state, intention=intention) def retry_subflow(self, retry): diff --git a/taskflow/engines/action_engine/scheduler.py b/taskflow/engines/action_engine/scheduler.py index 2e0661dac..c953202f3 100644 --- a/taskflow/engines/action_engine/scheduler.py +++ b/taskflow/engines/action_engine/scheduler.py @@ -46,8 +46,9 @@ class RetryScheduler: self._runtime.retry_subflow(retry) return self._retry_action.schedule_execution(retry) else: - raise excp.ExecutionFailure("Unknown how to schedule retry with" - " intention: %s" % intention) + raise excp.ExecutionFailure( + "Unknown how to schedule retry with intention: %s" % intention + ) class TaskScheduler: @@ -69,8 +70,9 @@ class TaskScheduler: elif intention == st.REVERT: return self._task_action.schedule_reversion(task) else: - raise excp.ExecutionFailure("Unknown how to schedule task with" - " intention: %s" % intention) + raise excp.ExecutionFailure( + "Unknown how to schedule task with intention: %s" % intention + ) class Scheduler: diff --git a/taskflow/engines/action_engine/scopes.py b/taskflow/engines/action_engine/scopes.py index 707266a9e..f43f58142 100644 --- a/taskflow/engines/action_engine/scopes.py +++ b/taskflow/engines/action_engine/scopes.py @@ -32,8 +32,9 @@ class ScopeWalker: def __init__(self, compilation, atom, names_only=False): self._node = compilation.hierarchy.find(atom) if self._node is None: - raise ValueError("Unable to find atom '%s' in compilation" - " hierarchy" % atom) + raise ValueError( + "Unable to find atom '%s' in compilation hierarchy" % atom + ) self._level_cache = {} self._atom = atom self._execution_graph = compilation.execution_graph @@ -78,8 +79,10 @@ class ScopeWalker: graph = self._execution_graph if self._predecessors is 
None: predecessors = { - node for node in graph.bfs_predecessors_iter(self._atom) - if graph.nodes[node]['kind'] in co.ATOMS} + node + for node in graph.bfs_predecessors_iter(self._atom) + if graph.nodes[node]['kind'] in co.ATOMS + } self._predecessors = predecessors.copy() else: predecessors = self._predecessors.copy() @@ -95,7 +98,8 @@ class ScopeWalker: visible = [] removals = set() atom_it = tr.depth_first_reverse_iterate( - parent, start_from_idx=last_idx) + parent, start_from_idx=last_idx + ) for atom in atom_it: if atom in predecessors: predecessors.remove(atom) @@ -106,9 +110,14 @@ class ScopeWalker: self._level_cache[lvl] = (visible, removals) if LOG.isEnabledFor(logging.TRACE): visible_names = [a.name for a in visible] - LOG.trace("Scope visible to '%s' (limited by parent '%s'" - " index < %s) is: %s", self._atom, - parent.item.name, last_idx, visible_names) + LOG.trace( + "Scope visible to '%s' (limited by parent '%s'" + " index < %s) is: %s", + self._atom, + parent.item.name, + last_idx, + visible_names, + ) if self._names_only: yield [a.name for a in visible] else: diff --git a/taskflow/engines/action_engine/selector.py b/taskflow/engines/action_engine/selector.py index 5da9e5af3..e33f5aac2 100644 --- a/taskflow/engines/action_engine/selector.py +++ b/taskflow/engines/action_engine/selector.py @@ -43,16 +43,22 @@ class Selector: def iter_next_atoms(self, atom=None): """Iterate next atoms to run (originating from atom or all atoms).""" if atom is None: - return iter_utils.unique_seen((self._browse_atoms_for_execute(), - self._browse_atoms_for_revert()), - seen_selector=operator.itemgetter(0)) + return iter_utils.unique_seen( + ( + self._browse_atoms_for_execute(), + self._browse_atoms_for_revert(), + ), + seen_selector=operator.itemgetter(0), + ) state = self._storage.get_atom_state(atom.name) intention = self._storage.get_atom_intention(atom.name) if state == st.SUCCESS: if intention == st.REVERT: - return iter([ - (atom, deciders.NoOpDecider()), - ]) 
+ return iter( + [ + (atom, deciders.NoOpDecider()), + ] + ) elif intention == st.EXECUTE: return self._browse_atoms_for_execute(atom=atom) else: @@ -82,7 +88,8 @@ class Selector: # problematic to determine as top levels can have their deciders # applied **after** going deeper). atom_it = traversal.breadth_first_iterate( - self._execution_graph, atom, traversal.Direction.FORWARD) + self._execution_graph, atom, traversal.Direction.FORWARD + ) for atom in atom_it: is_ready, late_decider = self._get_maybe_ready_for_execute(atom) if is_ready: @@ -100,22 +107,32 @@ class Selector: atom_it = self._runtime.iterate_nodes(co.ATOMS) else: atom_it = traversal.breadth_first_iterate( - self._execution_graph, atom, traversal.Direction.BACKWARD, + self._execution_graph, + atom, + traversal.Direction.BACKWARD, # Stop at the retry boundary (as retries 'control' there # surronding atoms, and we don't want to back track over # them so that they can correctly affect there associated # atoms); we do though need to jump through all tasks since # if a predecessor Y was ignored and a predecessor Z before Y # was not it should be eligible to now revert... - through_retries=False) + through_retries=False, + ) for atom in atom_it: is_ready, late_decider = self._get_maybe_ready_for_revert(atom) if is_ready: yield (atom, late_decider) - def _get_maybe_ready(self, atom, transition_to, allowed_intentions, - connected_fetcher, ready_checker, - decider_fetcher, for_what="?"): + def _get_maybe_ready( + self, + atom, + transition_to, + allowed_intentions, + connected_fetcher, + ready_checker, + decider_fetcher, + for_what="?", + ): def iter_connected_states(): # Lazily iterate over connected states so that ready checkers # can stop early (vs having to consume and check all the @@ -126,6 +143,7 @@ class Selector: # to avoid two calls into storage). atom_states = self._storage.get_atoms_states([atom.name]) yield (atom, atom_states[atom.name]) + # NOTE(harlowja): How this works is the following... 
# # 1. First check if the current atom can even transition to the @@ -144,18 +162,29 @@ class Selector: # which can (if it desires) affect this ready result (but does # so right before the atom is about to be scheduled). state = self._storage.get_atom_state(atom.name) - ok_to_transition = self._runtime.check_atom_transition(atom, state, - transition_to) + ok_to_transition = self._runtime.check_atom_transition( + atom, state, transition_to + ) if not ok_to_transition: - LOG.trace("Atom '%s' is not ready to %s since it can not" - " transition to %s from its current state %s", - atom, for_what, transition_to, state) + LOG.trace( + "Atom '%s' is not ready to %s since it can not" + " transition to %s from its current state %s", + atom, + for_what, + transition_to, + state, + ) return (False, None) intention = self._storage.get_atom_intention(atom.name) if intention not in allowed_intentions: - LOG.trace("Atom '%s' is not ready to %s since its current" - " intention %s is not in allowed intentions %s", - atom, for_what, intention, allowed_intentions) + LOG.trace( + "Atom '%s' is not ready to %s since its current" + " intention %s is not in allowed intentions %s", + atom, + for_what, + intention, + allowed_intentions, + ) return (False, None) ok_to_run = ready_checker(iter_connected_states()) if not ok_to_run: @@ -165,62 +194,91 @@ class Selector: def _get_maybe_ready_for_execute(self, atom): """Returns if an atom is *likely* ready to be executed.""" + def ready_checker(pred_connected_it): for pred in pred_connected_it: pred_atom, (pred_atom_state, pred_atom_intention) = pred - if (pred_atom_state in (st.SUCCESS, st.IGNORE) and - pred_atom_intention in (st.EXECUTE, st.IGNORE)): + if pred_atom_state in ( + st.SUCCESS, + st.IGNORE, + ) and pred_atom_intention in (st.EXECUTE, st.IGNORE): continue - LOG.trace("Unable to begin to execute since predecessor" - " atom '%s' is in state %s with intention %s", - pred_atom, pred_atom_state, pred_atom_intention) + LOG.trace( + "Unable 
to begin to execute since predecessor" + " atom '%s' is in state %s with intention %s", + pred_atom, + pred_atom_state, + pred_atom_intention, + ) return False LOG.trace("Able to let '%s' execute", atom) return True - decider_fetcher = lambda: \ - deciders.IgnoreDecider( - atom, self._runtime.fetch_edge_deciders(atom)) - connected_fetcher = lambda: \ - traversal.depth_first_iterate(self._execution_graph, atom, - # Whether the desired atom - # can execute is dependent on its - # predecessors outcomes (thus why - # we look backwards). - traversal.Direction.BACKWARD) + + decider_fetcher = lambda: deciders.IgnoreDecider( + atom, self._runtime.fetch_edge_deciders(atom) + ) + connected_fetcher = lambda: traversal.depth_first_iterate( + self._execution_graph, + atom, + # Whether the desired atom + # can execute is dependent on its + # predecessors outcomes (thus why + # we look backwards). + traversal.Direction.BACKWARD, + ) # If this atoms current state is able to be transitioned to RUNNING # and its intention is to EXECUTE and all of its predecessors executed # successfully or were ignored then this atom is ready to execute. 
LOG.trace("Checking if '%s' is ready to execute", atom) - return self._get_maybe_ready(atom, st.RUNNING, [st.EXECUTE], - connected_fetcher, ready_checker, - decider_fetcher, for_what='execute') + return self._get_maybe_ready( + atom, + st.RUNNING, + [st.EXECUTE], + connected_fetcher, + ready_checker, + decider_fetcher, + for_what='execute', + ) def _get_maybe_ready_for_revert(self, atom): """Returns if an atom is *likely* ready to be reverted.""" + def ready_checker(succ_connected_it): for succ in succ_connected_it: succ_atom, (succ_atom_state, _succ_atom_intention) = succ if succ_atom_state not in (st.PENDING, st.REVERTED, st.IGNORE): - LOG.trace("Unable to begin to revert since successor" - " atom '%s' is in state %s", succ_atom, - succ_atom_state) + LOG.trace( + "Unable to begin to revert since successor" + " atom '%s' is in state %s", + succ_atom, + succ_atom_state, + ) return False LOG.trace("Able to let '%s' revert", atom) return True + noop_decider = deciders.NoOpDecider() - connected_fetcher = lambda: \ - traversal.depth_first_iterate(self._execution_graph, atom, - # Whether the desired atom - # can revert is dependent on its - # successors states (thus why we - # look forwards). - traversal.Direction.FORWARD) + connected_fetcher = lambda: traversal.depth_first_iterate( + self._execution_graph, + atom, + # Whether the desired atom + # can revert is dependent on its + # successors states (thus why we + # look forwards). + traversal.Direction.FORWARD, + ) decider_fetcher = lambda: noop_decider # If this atoms current state is able to be transitioned to REVERTING # and its intention is either REVERT or RETRY and all of its # successors are either PENDING or REVERTED then this atom is ready # to revert. 
LOG.trace("Checking if '%s' is ready to revert", atom) - return self._get_maybe_ready(atom, st.REVERTING, [st.REVERT, st.RETRY], - connected_fetcher, ready_checker, - decider_fetcher, for_what='revert') + return self._get_maybe_ready( + atom, + st.REVERTING, + [st.REVERT, st.RETRY], + connected_fetcher, + ready_checker, + decider_fetcher, + for_what='revert', + ) diff --git a/taskflow/engines/action_engine/traversal.py b/taskflow/engines/action_engine/traversal.py index 7885d0d36..e88af0457 100644 --- a/taskflow/engines/action_engine/traversal.py +++ b/taskflow/engines/action_engine/traversal.py @@ -28,9 +28,14 @@ class Direction(enum.Enum): BACKWARD = 2 -def _extract_connectors(execution_graph, starting_node, direction, - through_flows=True, through_retries=True, - through_tasks=True): +def _extract_connectors( + execution_graph, + starting_node, + direction, + through_flows=True, + through_retries=True, + through_tasks=True, +): if direction == Direction.FORWARD: connected_iter = execution_graph.successors else: @@ -46,9 +51,14 @@ def _extract_connectors(execution_graph, starting_node, direction, return connected_iter(starting_node), connected_to_functors -def breadth_first_iterate(execution_graph, starting_node, direction, - through_flows=True, through_retries=True, - through_tasks=True): +def breadth_first_iterate( + execution_graph, + starting_node, + direction, + through_flows=True, + through_retries=True, + through_tasks=True, +): """Iterates connected nodes in execution graph (from starting node). Does so in a breadth first manner. @@ -56,9 +66,13 @@ def breadth_first_iterate(execution_graph, starting_node, direction, Jumps over nodes with ``noop`` attribute (does not yield them back). 
""" initial_nodes_iter, connected_to_functors = _extract_connectors( - execution_graph, starting_node, direction, - through_flows=through_flows, through_retries=through_retries, - through_tasks=through_tasks) + execution_graph, + starting_node, + direction, + through_flows=through_flows, + through_retries=through_retries, + through_tasks=through_tasks, + ) q = collections.deque(initial_nodes_iter) visited_nodes = set() while q: @@ -79,9 +93,14 @@ def breadth_first_iterate(execution_graph, starting_node, direction, q.extend(connected_to_functor(node)) -def depth_first_iterate(execution_graph, starting_node, direction, - through_flows=True, through_retries=True, - through_tasks=True): +def depth_first_iterate( + execution_graph, + starting_node, + direction, + through_flows=True, + through_retries=True, + through_tasks=True, +): """Iterates connected nodes in execution graph (from starting node). Does so in a depth first manner. @@ -89,9 +108,13 @@ def depth_first_iterate(execution_graph, starting_node, direction, Jumps over nodes with ``noop`` attribute (does not yield them back). 
""" initial_nodes_iter, connected_to_functors = _extract_connectors( - execution_graph, starting_node, direction, - through_flows=through_flows, through_retries=through_retries, - through_tasks=through_tasks) + execution_graph, + starting_node, + direction, + through_flows=through_flows, + through_retries=through_retries, + through_tasks=through_tasks, + ) stack = list(initial_nodes_iter) visited_nodes = set() while stack: diff --git a/taskflow/engines/helpers.py b/taskflow/engines/helpers.py index 417192a30..907d5af57 100644 --- a/taskflow/engines/helpers.py +++ b/taskflow/engines/helpers.py @@ -60,8 +60,9 @@ def _fetch_factory(factory_name): try: return importutils.import_class(factory_name) except (ImportError, ValueError) as e: - raise ImportError("Could not import factory %r: %s" - % (factory_name, e)) + raise ImportError( + "Could not import factory %r: %s" % (factory_name, e) + ) def _fetch_validate_factory(flow_factory): @@ -73,16 +74,25 @@ def _fetch_validate_factory(flow_factory): factory_name = reflection.get_callable_name(flow_factory) try: reimported = _fetch_factory(factory_name) - assert reimported == factory_fun + assert reimported == factory_fun # noqa: S101 except (ImportError, AssertionError): - raise ValueError('Flow factory %r is not reimportable by name %s' - % (factory_fun, factory_name)) + raise ValueError( + 'Flow factory %r is not reimportable by name %s' + % (factory_fun, factory_name) + ) return (factory_name, factory_fun) -def load(flow, store=None, flow_detail=None, book=None, - backend=None, namespace=ENGINES_NAMESPACE, - engine=ENGINE_DEFAULT, **kwargs): +def load( + flow, + store=None, + flow_detail=None, + book=None, + backend=None, + namespace=ENGINES_NAMESPACE, + engine=ENGINE_DEFAULT, + **kwargs, +): """Load a flow into an engine. This function creates and prepares an engine to run the provided flow. 
All @@ -122,15 +132,18 @@ def load(flow, store=None, flow_detail=None, book=None, backend = p_backends.fetch(backend) if flow_detail is None: - flow_detail = p_utils.create_flow_detail(flow, book=book, - backend=backend) + flow_detail = p_utils.create_flow_detail( + flow, book=book, backend=backend + ) LOG.debug('Looking for %r engine driver in %r', kind, namespace) try: mgr = stevedore.driver.DriverManager( - namespace, kind, + namespace, + kind, invoke_on_load=True, - invoke_args=(flow, flow_detail, backend, options)) + invoke_args=(flow, flow_detail, backend, options), + ) engine = mgr.driver except RuntimeError as e: raise exc.NotFound("Could not find engine '%s'" % (kind), e) @@ -140,9 +153,16 @@ def load(flow, store=None, flow_detail=None, book=None, return engine -def run(flow, store=None, flow_detail=None, book=None, - backend=None, namespace=ENGINES_NAMESPACE, - engine=ENGINE_DEFAULT, **kwargs): +def run( + flow, + store=None, + flow_detail=None, + book=None, + backend=None, + namespace=ENGINES_NAMESPACE, + engine=ENGINE_DEFAULT, + **kwargs, +): """Run the flow. This function loads the flow into an engine (with the :func:`load() ` @@ -153,16 +173,23 @@ def run(flow, store=None, flow_detail=None, book=None, :returns: dictionary of all named results (see :py:meth:`~.taskflow.storage.Storage.fetch_all`) """ - engine = load(flow, store=store, flow_detail=flow_detail, book=book, - backend=backend, namespace=namespace, - engine=engine, **kwargs) + engine = load( + flow, + store=store, + flow_detail=flow_detail, + book=book, + backend=backend, + namespace=namespace, + engine=engine, + **kwargs, + ) engine.run() return engine.storage.fetch_all() -def save_factory_details(flow_detail, - flow_factory, factory_args, factory_kwargs, - backend=None): +def save_factory_details( + flow_detail, flow_factory, factory_args, factory_kwargs, backend=None +): """Saves the given factories reimportable attributes into the flow detail. 
This function saves the factory name, arguments, and keyword arguments @@ -198,10 +225,17 @@ def save_factory_details(flow_detail, conn.update_flow_details(flow_detail) -def load_from_factory(flow_factory, factory_args=None, factory_kwargs=None, - store=None, book=None, backend=None, - namespace=ENGINES_NAMESPACE, engine=ENGINE_DEFAULT, - **kwargs): +def load_from_factory( + flow_factory, + factory_args=None, + factory_kwargs=None, + store=None, + book=None, + backend=None, + namespace=ENGINES_NAMESPACE, + engine=ENGINE_DEFAULT, + **kwargs, +): """Loads a flow from a factory function into an engine. Gets flow factory function (or name of it) and creates flow with @@ -227,12 +261,23 @@ def load_from_factory(flow_factory, factory_args=None, factory_kwargs=None, if isinstance(backend, dict): backend = p_backends.fetch(backend) flow_detail = p_utils.create_flow_detail(flow, book=book, backend=backend) - save_factory_details(flow_detail, - flow_factory, factory_args, factory_kwargs, - backend=backend) - return load(flow=flow, store=store, flow_detail=flow_detail, book=book, - backend=backend, namespace=namespace, - engine=engine, **kwargs) + save_factory_details( + flow_detail, + flow_factory, + factory_args, + factory_kwargs, + backend=backend, + ) + return load( + flow=flow, + store=store, + flow_detail=flow_detail, + book=book, + backend=backend, + namespace=namespace, + engine=engine, + **kwargs, + ) def flow_from_detail(flow_detail): @@ -247,24 +292,33 @@ def flow_from_detail(flow_detail): try: factory_data = flow_detail.meta['factory'] except (KeyError, AttributeError, TypeError): - raise ValueError('Cannot reconstruct flow %s %s: ' - 'no factory information saved.' - % (flow_detail.name, flow_detail.uuid)) + raise ValueError( + 'Cannot reconstruct flow %s %s: ' + 'no factory information saved.' 
+ % (flow_detail.name, flow_detail.uuid) + ) try: factory_fun = _fetch_factory(factory_data['name']) except (KeyError, ImportError): - raise ImportError('Could not import factory for flow %s %s' - % (flow_detail.name, flow_detail.uuid)) + raise ImportError( + 'Could not import factory for flow %s %s' + % (flow_detail.name, flow_detail.uuid) + ) args = factory_data.get('args', ()) kwargs = factory_data.get('kwargs', {}) return factory_fun(*args, **kwargs) -def load_from_detail(flow_detail, store=None, backend=None, - namespace=ENGINES_NAMESPACE, engine=ENGINE_DEFAULT, - **kwargs): +def load_from_detail( + flow_detail, + store=None, + backend=None, + namespace=ENGINES_NAMESPACE, + engine=ENGINE_DEFAULT, + **kwargs, +): """Reloads an engine previously saved. This reloads the flow using the @@ -278,6 +332,12 @@ def load_from_detail(flow_detail, store=None, backend=None, :returns: engine """ flow = flow_from_detail(flow_detail) - return load(flow, flow_detail=flow_detail, - store=store, backend=backend, - namespace=namespace, engine=engine, **kwargs) + return load( + flow, + flow_detail=flow_detail, + store=store, + backend=backend, + namespace=namespace, + engine=engine, + **kwargs, + ) diff --git a/taskflow/engines/worker_based/dispatcher.py b/taskflow/engines/worker_based/dispatcher.py index 8f47cf53b..9b5d8209a 100644 --- a/taskflow/engines/worker_based/dispatcher.py +++ b/taskflow/engines/worker_based/dispatcher.py @@ -100,9 +100,13 @@ class TypeDispatcher: if cb(data, message): requeue_votes += 1 except Exception: - LOG.exception("Failed calling requeue filter %s '%s' to" - " determine if message %r should be requeued.", - i + 1, cb, message.delivery_tag) + LOG.exception( + "Failed calling requeue filter %s '%s' to" + " determine if message %r should be requeued.", + i + 1, + cb, + message.delivery_tag, + ) return requeue_votes def _requeue_log_error(self, message, errors): @@ -114,52 +118,72 @@ class TypeDispatcher: # This was taken from how kombu is formatting 
its messages # when its reject_log_error or ack_log_error functions are # used so that we have a similar error format for requeuing. - LOG.critical("Couldn't requeue %r, reason:%r", - message.delivery_tag, exc, exc_info=True) + LOG.critical( + "Couldn't requeue %r, reason:%r", + message.delivery_tag, + exc, + exc_info=True, + ) else: LOG.debug("Message '%s' was requeued.", ku.DelayedPretty(message)) def _process_message(self, data, message, message_type): handler = self._type_handlers.get(message_type) if handler is None: - message.reject_log_error(logger=LOG, - errors=(kombu_exc.MessageStateError,)) - LOG.warning("Unexpected message type: '%s' in message" - " '%s'", message_type, ku.DelayedPretty(message)) + message.reject_log_error( + logger=LOG, errors=(kombu_exc.MessageStateError,) + ) + LOG.warning( + "Unexpected message type: '%s' in message '%s'", + message_type, + ku.DelayedPretty(message), + ) else: if handler.validator is not None: try: handler.validator(data) except excp.InvalidFormat as e: message.reject_log_error( - logger=LOG, errors=(kombu_exc.MessageStateError,)) - LOG.warning("Message '%s' (%s) was rejected due to it" - " being in an invalid format: %s", - ku.DelayedPretty(message), message_type, e) + logger=LOG, errors=(kombu_exc.MessageStateError,) + ) + LOG.warning( + "Message '%s' (%s) was rejected due to it" + " being in an invalid format: %s", + ku.DelayedPretty(message), + message_type, + e, + ) return - message.ack_log_error(logger=LOG, - errors=(kombu_exc.MessageStateError,)) + message.ack_log_error( + logger=LOG, errors=(kombu_exc.MessageStateError,) + ) if message.acknowledged: - LOG.debug("Message '%s' was acknowledged.", - ku.DelayedPretty(message)) + LOG.debug( + "Message '%s' was acknowledged.", ku.DelayedPretty(message) + ) handler.process_message(data, message) else: - message.reject_log_error(logger=LOG, - errors=(kombu_exc.MessageStateError,)) + message.reject_log_error( + logger=LOG, errors=(kombu_exc.MessageStateError,) + ) def 
on_message(self, data, message): """This method is called on incoming messages.""" LOG.debug("Received message '%s'", ku.DelayedPretty(message)) if self._collect_requeue_votes(data, message): - self._requeue_log_error(message, - errors=(kombu_exc.MessageStateError,)) + self._requeue_log_error( + message, errors=(kombu_exc.MessageStateError,) + ) else: try: message_type = message.properties['type'] except KeyError: message.reject_log_error( - logger=LOG, errors=(kombu_exc.MessageStateError,)) - LOG.warning("The 'type' message property is missing" - " in message '%s'", ku.DelayedPretty(message)) + logger=LOG, errors=(kombu_exc.MessageStateError,) + ) + LOG.warning( + "The 'type' message property is missing in message '%s'", + ku.DelayedPretty(message), + ) else: self._process_message(data, message, message_type) diff --git a/taskflow/engines/worker_based/engine.py b/taskflow/engines/worker_based/engine.py index 743a9fc7b..579dea28b 100644 --- a/taskflow/engines/worker_based/engine.py +++ b/taskflow/engines/worker_based/engine.py @@ -54,17 +54,20 @@ class WorkerBasedActionEngine(engine.ActionEngine): super().__init__(flow, flow_detail, backend, options) # This ensures that any provided executor will be validated before # we get to far in the compilation/execution pipeline... 
- self._task_executor = self._fetch_task_executor(self._options, - self._flow_detail) + self._task_executor = self._fetch_task_executor( + self._options, self._flow_detail + ) @classmethod def _fetch_task_executor(cls, options, flow_detail): try: e = options['executor'] if not isinstance(e, executor.WorkerTaskExecutor): - raise TypeError("Expected an instance of type '%s' instead of" - " type '%s' for 'executor' option" - % (executor.WorkerTaskExecutor, type(e))) + raise TypeError( + "Expected an instance of type '%s' instead of" + " type '%s' for 'executor' option" + % (executor.WorkerTaskExecutor, type(e)) + ) return e except KeyError: return executor.WorkerTaskExecutor( @@ -75,8 +78,8 @@ class WorkerBasedActionEngine(engine.ActionEngine): topics=options.get('topics', []), transport=options.get('transport'), transport_options=options.get('transport_options'), - transition_timeout=options.get('transition_timeout', - pr.REQUEST_TIMEOUT), - worker_expiry=options.get('worker_expiry', - pr.EXPIRES_AFTER), + transition_timeout=options.get( + 'transition_timeout', pr.REQUEST_TIMEOUT + ), + worker_expiry=options.get('worker_expiry', pr.EXPIRES_AFTER), ) diff --git a/taskflow/engines/worker_based/executor.py b/taskflow/engines/worker_based/executor.py index c16625848..f31457706 100644 --- a/taskflow/engines/worker_based/executor.py +++ b/taskflow/engines/worker_based/executor.py @@ -35,35 +35,53 @@ LOG = logging.getLogger(__name__) class WorkerTaskExecutor(executor.TaskExecutor): """Executes tasks on remote workers.""" - def __init__(self, uuid, exchange, topics, - transition_timeout=pr.REQUEST_TIMEOUT, - url=None, transport=None, transport_options=None, - retry_options=None, worker_expiry=pr.EXPIRES_AFTER): + def __init__( + self, + uuid, + exchange, + topics, + transition_timeout=pr.REQUEST_TIMEOUT, + url=None, + transport=None, + transport_options=None, + retry_options=None, + worker_expiry=pr.EXPIRES_AFTER, + ): self._uuid = uuid self._ongoing_requests = {} 
self._ongoing_requests_lock = threading.RLock() self._transition_timeout = transition_timeout - self._proxy = proxy.Proxy(uuid, exchange, - on_wait=self._on_wait, url=url, - transport=transport, - transport_options=transport_options, - retry_options=retry_options) + self._proxy = proxy.Proxy( + uuid, + exchange, + on_wait=self._on_wait, + url=url, + transport=transport, + transport_options=transport_options, + retry_options=retry_options, + ) # NOTE(harlowja): This is the most simplest finder impl. that # doesn't have external dependencies (outside of what this engine # already requires); it though does create periodic 'polling' traffic # to workers to 'learn' of the tasks they can perform (and requires # pre-existing knowledge of the topics those workers are on to gather # and update this information). - self._finder = wt.ProxyWorkerFinder(uuid, self._proxy, topics, - worker_expiry=worker_expiry) - self._proxy.dispatcher.type_handlers.update({ - pr.RESPONSE: dispatcher.Handler(self._process_response, - validator=pr.Response.validate), - pr.NOTIFY: dispatcher.Handler( - self._finder.process_response, - validator=functools.partial(pr.Notify.validate, - response=True)), - }) + self._finder = wt.ProxyWorkerFinder( + uuid, self._proxy, topics, worker_expiry=worker_expiry + ) + self._proxy.dispatcher.type_handlers.update( + { + pr.RESPONSE: dispatcher.Handler( + self._process_response, validator=pr.Response.validate + ), + pr.NOTIFY: dispatcher.Handler( + self._finder.process_response, + validator=functools.partial( + pr.Notify.validate, response=True + ), + ), + } + ) # Thread that will run the message dispatching (and periodically # call the on_wait callback to do various things) loop... 
self._helper = None @@ -73,20 +91,27 @@ class WorkerTaskExecutor(executor.TaskExecutor): def _process_response(self, response, message): """Process response from remote side.""" - LOG.debug("Started processing response message '%s'", - ku.DelayedPretty(message)) + LOG.debug( + "Started processing response message '%s'", + ku.DelayedPretty(message), + ) try: request_uuid = message.properties['correlation_id'] except KeyError: - LOG.warning("The 'correlation_id' message property is" - " missing in message '%s'", - ku.DelayedPretty(message)) + LOG.warning( + "The 'correlation_id' message property is" + " missing in message '%s'", + ku.DelayedPretty(message), + ) else: request = self._ongoing_requests.get(request_uuid) if request is not None: response = pr.Response.from_dict(response) - LOG.debug("Extracted response '%s' and matched it to" - " request '%s'", response, request) + LOG.debug( + "Extracted response '%s' and matched it to request '%s'", + response, + request, + ) if response.state == pr.RUNNING: request.transition_and_log_error(pr.RUNNING, logger=LOG) elif response.state == pr.EVENT: @@ -98,14 +123,16 @@ class WorkerTaskExecutor(executor.TaskExecutor): details = response.data['details'] request.task.notifier.notify(event_type, details) elif response.state in (pr.FAILURE, pr.SUCCESS): - if request.transition_and_log_error(response.state, - logger=LOG): + if request.transition_and_log_error( + response.state, logger=LOG + ): with self._ongoing_requests_lock: del self._ongoing_requests[request.uuid] request.set_result(result=response.data['result']) else: - LOG.warning("Unexpected response status '%s'", - response.state) + LOG.warning( + "Unexpected response status '%s'", response.state + ) else: LOG.debug("Request with id='%s' not found", request_uuid) @@ -126,7 +153,8 @@ class WorkerTaskExecutor(executor.TaskExecutor): raise exc.RequestTimeout( "Request '%s' has expired after waiting for %0.2f" " seconds for it to transition out of (%s) states" - % (request, 
request_age, ", ".join(pr.WAITING_STATES))) + % (request, request_age, ", ".join(pr.WAITING_STATES)) + ) except exc.RequestTimeout: with misc.capture_failure() as failure: LOG.debug(failure.exception_str) @@ -169,9 +197,9 @@ class WorkerTaskExecutor(executor.TaskExecutor): while waiting_requests: _request_uuid, request = waiting_requests.popitem() worker = finder.get_worker_for_task(request.task) - if (worker is not None and - request.transition_and_log_error(pr.PENDING, - logger=LOG)): + if worker is not None and request.transition_and_log_error( + pr.PENDING, logger=LOG + ): self._publish_request(request, worker) self._messages_processed['finder'] = new_messages_processed @@ -187,20 +215,36 @@ class WorkerTaskExecutor(executor.TaskExecutor): # a worker located). self._clean() - def _submit_task(self, task, task_uuid, action, arguments, - progress_callback=None, result=pr.NO_RESULT, - failures=None): + def _submit_task( + self, + task, + task_uuid, + action, + arguments, + progress_callback=None, + result=pr.NO_RESULT, + failures=None, + ): """Submit task request to a worker.""" - request = pr.Request(task, task_uuid, action, arguments, - timeout=self._transition_timeout, - result=result, failures=failures) + request = pr.Request( + task, + task_uuid, + action, + arguments, + timeout=self._transition_timeout, + result=result, + failures=failures, + ) # Register the callback, so that we can proxy the progress correctly. - if (progress_callback is not None and - task.notifier.can_be_registered(EVENT_UPDATE_PROGRESS)): + if progress_callback is not None and task.notifier.can_be_registered( + EVENT_UPDATE_PROGRESS + ): task.notifier.register(EVENT_UPDATE_PROGRESS, progress_callback) request.future.add_done_callback( - lambda _fut: task.notifier.deregister(EVENT_UPDATE_PROGRESS, - progress_callback)) + lambda _fut: task.notifier.deregister( + EVENT_UPDATE_PROGRESS, progress_callback + ) + ) # Get task's worker and publish request if worker was found. 
worker = self._finder.get_worker_for_task(task) if worker is not None: @@ -209,42 +253,75 @@ class WorkerTaskExecutor(executor.TaskExecutor): self._ongoing_requests[request.uuid] = request self._publish_request(request, worker) else: - LOG.debug("Delaying submission of '%s', no currently known" - " worker/s available to process it", request) + LOG.debug( + "Delaying submission of '%s', no currently known" + " worker/s available to process it", + request, + ) with self._ongoing_requests_lock: self._ongoing_requests[request.uuid] = request return request.future def _publish_request(self, request, worker): """Publish request to a given topic.""" - LOG.debug("Submitting execution of '%s' to worker '%s' (expecting" - " response identified by reply_to=%s and" - " correlation_id=%s) - waited %0.3f seconds to" - " get published", request, worker, self._uuid, - request.uuid, timeutils.now() - request.created_on) + LOG.debug( + "Submitting execution of '%s' to worker '%s' (expecting" + " response identified by reply_to=%s and" + " correlation_id=%s) - waited %0.3f seconds to" + " get published", + request, + worker, + self._uuid, + request.uuid, + timeutils.now() - request.created_on, + ) try: - self._proxy.publish(request, worker.topic, - reply_to=self._uuid, - correlation_id=request.uuid) + self._proxy.publish( + request, + worker.topic, + reply_to=self._uuid, + correlation_id=request.uuid, + ) except Exception: with misc.capture_failure() as failure: - LOG.critical("Failed to submit '%s' (transitioning it to" - " %s)", request, pr.FAILURE, exc_info=True) + LOG.critical( + "Failed to submit '%s' (transitioning it to %s)", + request, + pr.FAILURE, + exc_info=True, + ) if request.transition_and_log_error(pr.FAILURE, logger=LOG): with self._ongoing_requests_lock: del self._ongoing_requests[request.uuid] request.set_result(failure) - def execute_task(self, task, task_uuid, arguments, - progress_callback=None): - return self._submit_task(task, task_uuid, pr.EXECUTE, arguments, 
- progress_callback=progress_callback) + def execute_task(self, task, task_uuid, arguments, progress_callback=None): + return self._submit_task( + task, + task_uuid, + pr.EXECUTE, + arguments, + progress_callback=progress_callback, + ) - def revert_task(self, task, task_uuid, arguments, result, failures, - progress_callback=None): - return self._submit_task(task, task_uuid, pr.REVERT, arguments, - result=result, failures=failures, - progress_callback=progress_callback) + def revert_task( + self, + task, + task_uuid, + arguments, + result, + failures, + progress_callback=None, + ): + return self._submit_task( + task, + task_uuid, + pr.REVERT, + arguments, + result=result, + failures=failures, + progress_callback=progress_callback, + ) def wait_for_workers(self, workers=1, timeout=None): """Waits for geq workers to notify they are ready to do work. @@ -255,14 +332,14 @@ class WorkerTaskExecutor(executor.TaskExecutor): return how many workers are still needed, otherwise it will return zero. """ - return self._finder.wait_for_workers(workers=workers, - timeout=timeout) + return self._finder.wait_for_workers(workers=workers, timeout=timeout) def start(self): """Starts message processing thread.""" if self._helper is not None: - raise RuntimeError("Worker executor must be stopped before" - " it can be started") + raise RuntimeError( + "Worker executor must be stopped before it can be started" + ) self._helper = tu.daemon_thread(self._proxy.start) self._helper.start() self._proxy.wait() diff --git a/taskflow/engines/worker_based/protocol.py b/taskflow/engines/worker_based/protocol.py index fa85462d2..12b94708e 100644 --- a/taskflow/engines/worker_based/protocol.py +++ b/taskflow/engines/worker_based/protocol.py @@ -53,10 +53,7 @@ EXECUTE = 'execute' REVERT = 'revert' # Remote task action to event map. 
-ACTION_TO_EVENT = { - EXECUTE: executor.EXECUTED, - REVERT: executor.REVERTED -} +ACTION_TO_EVENT = {EXECUTE: executor.EXECUTED, REVERT: executor.REVERTED} # NOTE(skudriashev): A timeout which specifies request expiration period. REQUEST_TIMEOUT = 60 @@ -149,9 +146,11 @@ class Message(metaclass=abc.ABCMeta): """Base class for all message types.""" def __repr__(self): - return ("<%s object at 0x%x with contents %s>" - % (reflection.get_class_name(self, fully_qualified=False), - id(self), self.to_dict())) + return "<%s object at 0x%x with contents %s>" % ( + reflection.get_class_name(self, fully_qualified=False), + id(self), + self.to_dict(), + ) @abc.abstractmethod def to_dict(self): @@ -180,7 +179,7 @@ class Notify(Message): "items": { "type": "string", }, - } + }, }, "required": ["topic", 'tasks'], "additionalProperties": False, @@ -217,21 +216,24 @@ class Notify(Message): except su.ValidationError as e: cls_name = reflection.get_class_name(cls, fully_qualified=False) if response: - excp.raise_with_cause(excp.InvalidFormat, - "%s message response data not of the" - " expected format: %s" % (cls_name, - e.message), - cause=e) + excp.raise_with_cause( + excp.InvalidFormat, + "%s message response data not of the" + " expected format: %s" % (cls_name, e.message), + cause=e, + ) else: - excp.raise_with_cause(excp.InvalidFormat, - "%s message sender data not of the" - " expected format: %s" % (cls_name, - e.message), - cause=e) + excp.raise_with_cause( + excp.InvalidFormat, + "%s message sender data not of the" + " expected format: %s" % (cls_name, e.message), + cause=e, + ) -_WorkUnit = collections.namedtuple('_WorkUnit', ['task_cls', 'task_name', - 'action', 'arguments']) +_WorkUnit = collections.namedtuple( + '_WorkUnit', ['task_cls', 'task_name', 'action', 'arguments'] +) class Request(Message): @@ -299,9 +301,16 @@ class Request(Message): 'required': ['task_cls', 'task_name', 'task_version', 'action'], } - def __init__(self, task, uuid, action, - arguments, 
timeout=REQUEST_TIMEOUT, result=NO_RESULT, - failures=None): + def __init__( + self, + task, + uuid, + action, + arguments, + timeout=REQUEST_TIMEOUT, + result=NO_RESULT, + failures=None, + ): self._action = action self._event = ACTION_TO_EVENT[action] self._arguments = arguments @@ -383,8 +392,12 @@ class Request(Message): try: moved = self.transition(new_state) except excp.InvalidState: - logger.warnng("Failed to transition '%s' to %s state.", self, - new_state, exc_info=True) + logger.warning( + "Failed to transition '%s' to %s state.", + self, + new_state, + exc_info=True, + ) return moved @fasteners.locked @@ -402,14 +415,19 @@ class Request(Message): try: self._machine.process_event(make_an_event(new_state)) except (machine_excp.NotFound, machine_excp.InvalidState) as e: - raise excp.InvalidState("Request transition from %s to %s is" - " not allowed: %s" % (old_state, - new_state, e)) + raise excp.InvalidState( + "Request transition from %s to %s is" + " not allowed: %s" % (old_state, new_state, e) + ) else: if new_state in STOP_TIMER_STATES: self._watch.stop() - LOG.debug("Transitioned '%s' from %s state to %s state", self, - old_state, new_state) + LOG.debug( + "Transitioned '%s' from %s state to %s state", + self, + old_state, + new_state, + ) return True @classmethod @@ -418,11 +436,12 @@ class Request(Message): su.schema_validate(data, cls.SCHEMA) except su.ValidationError as e: cls_name = reflection.get_class_name(cls, fully_qualified=False) - excp.raise_with_cause(excp.InvalidFormat, - "%s message response data not of the" - " expected format: %s" % (cls_name, - e.message), - cause=e) + excp.raise_with_cause( + excp.InvalidFormat, + "%s message response data not of the" + " expected format: %s" % (cls_name, e.message), + cause=e, + ) else: # Validate all failure dictionaries that *may* be present... 
failures = [] @@ -556,11 +575,12 @@ class Response(Message): su.schema_validate(data, cls.SCHEMA) except su.ValidationError as e: cls_name = reflection.get_class_name(cls, fully_qualified=False) - excp.raise_with_cause(excp.InvalidFormat, - "%s message response data not of the" - " expected format: %s" % (cls_name, - e.message), - cause=e) + excp.raise_with_cause( + excp.InvalidFormat, + "%s message response data not of the" + " expected format: %s" % (cls_name, e.message), + cause=e, + ) else: state = data['state'] if state == FAILURE and 'result' in data: diff --git a/taskflow/engines/worker_based/proxy.py b/taskflow/engines/worker_based/proxy.py index 311e4065d..c393a2e89 100644 --- a/taskflow/engines/worker_based/proxy.py +++ b/taskflow/engines/worker_based/proxy.py @@ -31,11 +31,13 @@ DRAIN_EVENTS_PERIOD = 1 # instead of returning the raw results from the kombu connection objects # themselves so that a person can not mutate those objects (which would be # bad). -_ConnectionDetails = collections.namedtuple('_ConnectionDetails', - ['uri', 'transport']) -_TransportDetails = collections.namedtuple('_TransportDetails', - ['options', 'driver_type', - 'driver_name', 'driver_version']) +_ConnectionDetails = collections.namedtuple( + '_ConnectionDetails', ['uri', 'transport'] +) +_TransportDetails = collections.namedtuple( + '_TransportDetails', + ['options', 'driver_type', 'driver_name', 'driver_version'], +) class Proxy: @@ -65,10 +67,17 @@ class Proxy: # value is valid... 
_RETRY_INT_OPTS = frozenset(['max_retries']) - def __init__(self, topic, exchange, - type_handlers=None, on_wait=None, url=None, - transport=None, transport_options=None, - retry_options=None): + def __init__( + self, + topic, + exchange, + type_handlers=None, + on_wait=None, + url=None, + transport=None, + transport_options=None, + retry_options=None, + ): self._topic = topic self._exchange_name = exchange self._on_wait = on_wait @@ -77,7 +86,8 @@ class Proxy: # NOTE(skudriashev): Process all incoming messages only if proxy is # running, otherwise requeue them. requeue_filters=[lambda data, message: not self.is_running], - type_handlers=type_handlers) + type_handlers=type_handlers, + ) ensure_options = self.DEFAULT_RETRY_OPTIONS.copy() if retry_options is not None: @@ -91,9 +101,11 @@ class Proxy: else: tmp_val = float(val) if tmp_val < 0: - raise ValueError("Expected value greater or equal to" - " zero for 'retry_options' %s; got" - " %s instead" % (k, val)) + raise ValueError( + "Expected value greater or equal to" + " zero for 'retry_options' %s; got" + " %s instead" % (k, val) + ) ensure_options[k] = tmp_val self._ensure_options = ensure_options @@ -104,12 +116,14 @@ class Proxy: self._drain_events_timeout = polling_interval # create connection - self._conn = kombu.Connection(url, transport=transport, - transport_options=transport_options) + self._conn = kombu.Connection( + url, transport=transport, transport_options=transport_options + ) # create exchange - self._exchange = kombu.Exchange(name=self._exchange_name, - durable=False, auto_delete=True) + self._exchange = kombu.Exchange( + name=self._exchange_name, durable=False, auto_delete=True + ) @property def dispatcher(self): @@ -131,10 +145,11 @@ class Proxy: options=transport_options, driver_type=self._conn.transport.driver_type, driver_name=self._conn.transport.driver_name, - driver_version=driver_version) + driver_version=driver_version, + ) return _ConnectionDetails( - 
uri=self._conn.as_uri(include_password=False), - transport=transport) + uri=self._conn.as_uri(include_password=False), transport=transport + ) @property def is_running(self): @@ -144,10 +159,14 @@ class Proxy: def _make_queue(self, routing_key, exchange, channel=None): """Make a named queue for the given exchange.""" queue_name = f"{self._exchange_name}_{routing_key}" - return kombu.Queue(name=queue_name, - routing_key=routing_key, durable=False, - exchange=exchange, auto_delete=True, - channel=channel) + return kombu.Queue( + name=queue_name, + routing_key=routing_key, + durable=False, + exchange=exchange, + auto_delete=True, + channel=channel, + ) def publish(self, msg, routing_key, reply_to=None, correlation_id=None): """Publish message to the named exchange with given routing key.""" @@ -159,27 +178,33 @@ class Proxy: # Filter out any empty keys... routing_keys = [r_k for r_k in routing_keys if r_k] if not routing_keys: - LOG.warning("No routing key/s specified; unable to send '%s'" - " to any target queue on exchange '%s'", msg, - self._exchange_name) + LOG.warning( + "No routing key/s specified; unable to send '%s'" + " to any target queue on exchange '%s'", + msg, + self._exchange_name, + ) return def _publish(producer, routing_key): queue = self._make_queue(routing_key, self._exchange) - producer.publish(body=msg.to_dict(), - routing_key=routing_key, - exchange=self._exchange, - declare=[queue], - type=msg.TYPE, - reply_to=reply_to, - correlation_id=correlation_id) + producer.publish( + body=msg.to_dict(), + routing_key=routing_key, + exchange=self._exchange, + declare=[queue], + type=msg.TYPE, + reply_to=reply_to, + correlation_id=correlation_id, + ) def _publish_errback(exc, interval): LOG.exception('Publishing error: %s', exc) LOG.info('Retry triggering in %s seconds', interval) - LOG.debug("Sending '%s' message using routing keys %s", - msg, routing_keys) + LOG.debug( + "Sending '%s' message using routing keys %s", msg, routing_keys + ) with 
kombu.connections[self._conn].acquire(block=True) as conn: with conn.Producer() as producer: ensure_kwargs = self._ensure_options.copy() @@ -201,8 +226,9 @@ class Proxy: LOG.exception('Draining error: %s', exc) LOG.info('Retry triggering in %s seconds', interval) - LOG.info("Starting to consume from the '%s' exchange.", - self._exchange_name) + LOG.info( + "Starting to consume from the '%s' exchange.", self._exchange_name + ) with kombu.connections[self._conn].acquire(block=True) as conn: queue = self._make_queue(self._topic, self._exchange, channel=conn) callbacks = [self._dispatcher.on_message] diff --git a/taskflow/engines/worker_based/server.py b/taskflow/engines/worker_based/server.py index ed69b0c35..1afdd60f9 100644 --- a/taskflow/engines/worker_based/server.py +++ b/taskflow/engines/worker_based/server.py @@ -32,27 +32,41 @@ LOG = logging.getLogger(__name__) class Server: """Server implementation that waits for incoming tasks requests.""" - def __init__(self, topic, exchange, executor, endpoints, - url=None, transport=None, transport_options=None, - retry_options=None): + def __init__( + self, + topic, + exchange, + executor, + endpoints, + url=None, + transport=None, + transport_options=None, + retry_options=None, + ): type_handlers = { pr.NOTIFY: dispatcher.Handler( self._delayed_process(self._process_notify), - validator=functools.partial(pr.Notify.validate, - response=False)), + validator=functools.partial( + pr.Notify.validate, response=False + ), + ), pr.REQUEST: dispatcher.Handler( self._delayed_process(self._process_request), - validator=pr.Request.validate), + validator=pr.Request.validate, + ), } self._executor = executor - self._proxy = proxy.Proxy(topic, exchange, - type_handlers=type_handlers, - url=url, transport=transport, - transport_options=transport_options, - retry_options=retry_options) + self._proxy = proxy.Proxy( + topic, + exchange, + type_handlers=type_handlers, + url=url, + transport=transport, + transport_options=transport_options, 
+ retry_options=retry_options, + ) self._topic = topic - self._endpoints = {endpoint.name: endpoint - for endpoint in endpoints} + self._endpoints = {endpoint.name: endpoint for endpoint in endpoints} def _delayed_process(self, func): """Runs the function using the instances executor (eventually). @@ -65,25 +79,34 @@ class Server: func_name = reflection.get_callable_name(func) def _on_run(watch, content, message): - LOG.trace("It took %s seconds to get around to running" - " function/method '%s' with" - " message '%s'", watch.elapsed(), func_name, - ku.DelayedPretty(message)) + LOG.trace( + "It took %s seconds to get around to running" + " function/method '%s' with" + " message '%s'", + watch.elapsed(), + func_name, + ku.DelayedPretty(message), + ) return func(content, message) def _on_receive(content, message): - LOG.debug("Submitting message '%s' for execution in the" - " future to '%s'", ku.DelayedPretty(message), func_name) + LOG.debug( + "Submitting message '%s' for execution in the future to '%s'", + ku.DelayedPretty(message), + func_name, + ) watch = timeutils.StopWatch() watch.start() try: self._executor.submit(_on_run, watch, content, message) except RuntimeError: - LOG.error("Unable to continue processing message '%s'," - " submission to instance executor (with later" - " execution by '%s') was unsuccessful", - ku.DelayedPretty(message), func_name, - exc_info=True) + LOG.exception( + "Unable to continue processing message '%s'," + " submission to instance executor (with later" + " execution by '%s') was unsuccessful", + ku.DelayedPretty(message), + func_name, + ) return _on_receive @@ -103,8 +126,7 @@ class Server: try: properties.append(message.properties[prop]) except KeyError: - raise ValueError("The '%s' message property is missing" % - prop) + raise ValueError("The '%s' message property is missing" % prop) return properties def _reply(self, capture, reply_to, task_uuid, state=pr.FAILURE, **kwargs): @@ -122,9 +144,13 @@ class Server: except Exception: 
if not capture: raise - LOG.critical("Failed to send reply to '%s' for task '%s' with" - " response %s", reply_to, task_uuid, response, - exc_info=True) + LOG.critical( + "Failed to send reply to '%s' for task '%s' with response %s", + reply_to, + task_uuid, + response, + exc_info=True, + ) return published def _on_event(self, reply_to, task_uuid, event_type, details): @@ -132,26 +158,39 @@ class Server: # NOTE(harlowja): the executor that will trigger this using the # task notification/listener mechanism will handle logging if this # fails, so thats why capture is 'False' is used here. - self._reply(False, reply_to, task_uuid, pr.EVENT, - event_type=event_type, details=details) + self._reply( + False, + reply_to, + task_uuid, + pr.EVENT, + event_type=event_type, + details=details, + ) def _process_notify(self, notify, message): """Process notify message and reply back.""" try: reply_to = message.properties['reply_to'] except KeyError: - LOG.warning("The 'reply_to' message property is missing" - " in received notify message '%s'", - ku.DelayedPretty(message), exc_info=True) + LOG.warning( + "The 'reply_to' message property is missing" + " in received notify message '%s'", + ku.DelayedPretty(message), + exc_info=True, + ) else: - response = pr.Notify(topic=self._topic, - tasks=list(self._endpoints.keys())) + response = pr.Notify( + topic=self._topic, tasks=list(self._endpoints.keys()) + ) try: self._proxy.publish(response, routing_key=reply_to) except Exception: - LOG.critical("Failed to send reply to '%s' with notify" - " response '%s'", reply_to, response, - exc_info=True) + LOG.critical( + "Failed to send reply to '%s' with notify response '%s'", + reply_to, + response, + exc_info=True, + ) def _process_request(self, request, message): """Process request message and reply back.""" @@ -162,22 +201,28 @@ class Server: # in the first place...). 
reply_to, task_uuid = self._parse_message(message) except ValueError: - LOG.warning("Failed to parse request attributes from message '%s'", - ku.DelayedPretty(message), exc_info=True) + LOG.warning( + "Failed to parse request attributes from message '%s'", + ku.DelayedPretty(message), + exc_info=True, + ) return else: # prepare reply callback - reply_callback = functools.partial(self._reply, True, reply_to, - task_uuid) + reply_callback = functools.partial( + self._reply, True, reply_to, task_uuid + ) # Parse the request to get the activity/work to perform. try: work = pr.Request.from_dict(request, task_uuid=task_uuid) except ValueError: with misc.capture_failure() as failure: - LOG.warning("Failed to parse request contents" - " from message '%s'", - ku.DelayedPretty(message), exc_info=True) + LOG.warning( + "Failed to parse request contents from message '%s'", + ku.DelayedPretty(message), + exc_info=True, + ) reply_callback(result=pr.failure_to_dict(failure)) return @@ -186,10 +231,13 @@ class Server: endpoint = self._endpoints[work.task_cls] except KeyError: with misc.capture_failure() as failure: - LOG.warning("The '%s' task endpoint does not exist, unable" - " to continue processing request message '%s'", - work.task_cls, ku.DelayedPretty(message), - exc_info=True) + LOG.warning( + "The '%s' task endpoint does not exist, unable" + " to continue processing request message '%s'", + work.task_cls, + ku.DelayedPretty(message), + exc_info=True, + ) reply_callback(result=pr.failure_to_dict(failure)) return else: @@ -197,10 +245,15 @@ class Server: handler = getattr(endpoint, work.action) except AttributeError: with misc.capture_failure() as failure: - LOG.warning("The '%s' handler does not exist on task" - " endpoint '%s', unable to continue processing" - " request message '%s'", work.action, endpoint, - ku.DelayedPretty(message), exc_info=True) + LOG.warning( + "The '%s' handler does not exist on task" + " endpoint '%s', unable to continue processing" + " request 
message '%s'", + work.action, + endpoint, + ku.DelayedPretty(message), + exc_info=True, + ) reply_callback(result=pr.failure_to_dict(failure)) return else: @@ -208,10 +261,14 @@ class Server: task = endpoint.generate(name=work.task_name) except Exception: with misc.capture_failure() as failure: - LOG.warning("The '%s' task '%s' generation for request" - " message '%s' failed", endpoint, - work.action, ku.DelayedPretty(message), - exc_info=True) + LOG.warning( + "The '%s' task '%s' generation for request" + " message '%s' failed", + endpoint, + work.action, + ku.DelayedPretty(message), + exc_info=True, + ) reply_callback(result=pr.failure_to_dict(failure)) return else: @@ -222,24 +279,31 @@ class Server: # emit them back to the engine... for handling at the engine side # of things... if task.notifier.can_be_registered(nt.Notifier.ANY): - task.notifier.register(nt.Notifier.ANY, - functools.partial(self._on_event, - reply_to, task_uuid)) + task.notifier.register( + nt.Notifier.ANY, + functools.partial(self._on_event, reply_to, task_uuid), + ) elif isinstance(task.notifier, nt.RestrictedNotifier): # Only proxy the allowable events then... for event_type in task.notifier.events_iter(): - task.notifier.register(event_type, - functools.partial(self._on_event, - reply_to, task_uuid)) + task.notifier.register( + event_type, + functools.partial(self._on_event, reply_to, task_uuid), + ) # Perform the task action. try: result = handler(task, **work.arguments) except Exception: with misc.capture_failure() as failure: - LOG.warning("The '%s' endpoint '%s' execution for request" - " message '%s' failed", endpoint, work.action, - ku.DelayedPretty(message), exc_info=True) + LOG.warning( + "The '%s' endpoint '%s' execution for request" + " message '%s' failed", + endpoint, + work.action, + ku.DelayedPretty(message), + exc_info=True, + ) reply_callback(result=pr.failure_to_dict(failure)) else: # And be done with it! 
diff --git a/taskflow/engines/worker_based/types.py b/taskflow/engines/worker_based/types.py index 1017f4105..5f5791784 100644 --- a/taskflow/engines/worker_based/types.py +++ b/taskflow/engines/worker_based/types.py @@ -71,19 +71,26 @@ class TopicWorker: r = reflection.get_class_name(self, fully_qualified=False) if self.identity is not self._NO_IDENTITY: r += "(identity={}, tasks={}, topic={})".format( - self.identity, self.tasks, self.topic) + self.identity, self.tasks, self.topic + ) else: r += "(identity=*, tasks={}, topic={})".format( - self.tasks, self.topic) + self.tasks, self.topic + ) return r class ProxyWorkerFinder: """Requests and receives responses about workers topic+task details.""" - def __init__(self, uuid, proxy, topics, - beat_periodicity=pr.NOTIFY_PERIOD, - worker_expiry=pr.EXPIRES_AFTER): + def __init__( + self, + uuid, + proxy, + topics, + beat_periodicity=pr.NOTIFY_PERIOD, + worker_expiry=pr.EXPIRES_AFTER, + ): self._cond = threading.Condition() self._proxy = proxy self._topics = topics @@ -134,7 +141,7 @@ class ProxyWorkerFinder: if len(available_workers) == 1: return available_workers[0] else: - return random.choice(available_workers) + return random.choice(available_workers) # noqa: S311 @property def messages_processed(self): @@ -157,14 +164,14 @@ class ProxyWorkerFinder: match workers to tasks to run). 
""" if self._messages_published == 0: - self._proxy.publish(pr.Notify(), - self._topics, reply_to=self._uuid) + self._proxy.publish(pr.Notify(), self._topics, reply_to=self._uuid) self._messages_published += 1 self._watch.restart() else: if self._watch.expired(): - self._proxy.publish(pr.Notify(), - self._topics, reply_to=self._uuid) + self._proxy.publish( + pr.Notify(), self._topics, reply_to=self._uuid + ) self._messages_published += 1 self._watch.restart() @@ -188,16 +195,21 @@ class ProxyWorkerFinder: def process_response(self, data, message): """Process notify message sent from remote side.""" - LOG.debug("Started processing notify response message '%s'", - ku.DelayedPretty(message)) + LOG.debug( + "Started processing notify response message '%s'", + ku.DelayedPretty(message), + ) response = pr.Notify(**data) LOG.debug("Extracted notify response '%s'", response) with self._cond: - worker, new_or_updated = self._add(response.topic, - response.tasks) + worker, new_or_updated = self._add(response.topic, response.tasks) if new_or_updated: - LOG.debug("Updated worker '%s' (%s total workers are" - " currently known)", worker, self.total_workers) + LOG.debug( + "Updated worker '%s' (%s total workers are" + " currently known)", + worker, + self.total_workers, + ) self._cond.notify_all() worker.last_seen = timeutils.now() self._messages_processed += 1 @@ -207,8 +219,9 @@ class ProxyWorkerFinder: Returns how many workers were removed. 
""" - if (not self._workers or - (self._worker_expiry is None or self._worker_expiry <= 0)): + if not self._workers or ( + self._worker_expiry is None or self._worker_expiry <= 0 + ): return 0 dead_workers = {} with self._cond: @@ -225,9 +238,12 @@ class ProxyWorkerFinder: self._cond.notify_all() if dead_workers and LOG.isEnabledFor(logging.INFO): for worker, secs_since_last_seen in dead_workers.values(): - LOG.info("Removed worker '%s' as it has not responded to" - " notification requests in %0.3f seconds", - worker, secs_since_last_seen) + LOG.info( + "Removed worker '%s' as it has not responded to" + " notification requests in %0.3f seconds", + worker, + secs_since_last_seen, + ) return len(dead_workers) def reset(self): diff --git a/taskflow/engines/worker_based/worker.py b/taskflow/engines/worker_based/worker.py index cb7cf8d7d..b2038d522 100644 --- a/taskflow/engines/worker_based/worker.py +++ b/taskflow/engines/worker_based/worker.py @@ -55,24 +55,38 @@ class Worker: (see: :py:attr:`~.proxy.Proxy.DEFAULT_RETRY_OPTIONS`) """ - def __init__(self, exchange, topic, tasks, - executor=None, threads_count=None, url=None, - transport=None, transport_options=None, - retry_options=None): + def __init__( + self, + exchange, + topic, + tasks, + executor=None, + threads_count=None, + url=None, + transport=None, + transport_options=None, + retry_options=None, + ): self._topic = topic self._executor = executor self._owns_executor = False if self._executor is None: self._executor = futurist.ThreadPoolExecutor( - max_workers=threads_count) + max_workers=threads_count + ) self._owns_executor = True self._endpoints = self._derive_endpoints(tasks) self._exchange = exchange - self._server = server.Server(topic, exchange, self._executor, - self._endpoints, url=url, - transport=transport, - transport_options=transport_options, - retry_options=retry_options) + self._server = server.Server( + topic, + exchange, + self._executor, + self._endpoints, + url=url, + transport=transport, + 
transport_options=transport_options, + retry_options=retry_options, + ) @staticmethod def _derive_endpoints(tasks): @@ -86,8 +100,9 @@ class Worker: connection_details = self._server.connection_details transport = connection_details.transport if transport.driver_version: - transport_driver = "{} v{}".format(transport.driver_name, - transport.driver_version) + transport_driver = "{} v{}".format( + transport.driver_name, transport.driver_version + ) else: transport_driver = transport.driver_name try: @@ -145,12 +160,12 @@ class Worker: if __name__ == '__main__': import argparse import logging as log + parser = argparse.ArgumentParser() parser.add_argument("--exchange", required=True) parser.add_argument("--connection-url", required=True) parser.add_argument("--topic", required=True) - parser.add_argument("--task", action='append', - metavar="TASK", default=[]) + parser.add_argument("--task", action='append', metavar="TASK", default=[]) parser.add_argument("-v", "--verbose", action='store_true') args = parser.parse_args() if args.verbose: diff --git a/taskflow/examples/99_bottles.py b/taskflow/examples/99_bottles.py index f73e01524..f47349547 100644 --- a/taskflow/examples/99_bottles.py +++ b/taskflow/examples/99_bottles.py @@ -22,9 +22,9 @@ import traceback from kazoo import client -top_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), - os.pardir, - os.pardir)) +top_dir = os.path.abspath( + os.path.join(os.path.dirname(__file__), os.pardir, os.pardir) +) sys.path.insert(0, top_dir) from taskflow.conductors import backends as conductor_backends @@ -90,16 +90,19 @@ def make_bottles(count): s = lf.Flow("bottle-song") - take_bottle = TakeABottleDown("take-bottle-%s" % count, - inject={'bottles_left': count}, - provides='bottles_left') + take_bottle = TakeABottleDown( + "take-bottle-%s" % count, + inject={'bottles_left': count}, + provides='bottles_left', + ) pass_it = PassItAround("pass-%s-around" % count) next_bottles = Conclusion("next-bottles-%s" % (count 
- 1)) s.add(take_bottle, pass_it, next_bottles) for bottle in reversed(list(range(1, count))): - take_bottle = TakeABottleDown("take-bottle-%s" % bottle, - provides='bottles_left') + take_bottle = TakeABottleDown( + "take-bottle-%s" % bottle, provides='bottles_left' + ) pass_it = PassItAround("pass-%s-around" % bottle) next_bottles = Conclusion("next-bottles-%s" % (bottle - 1)) s.add(take_bottle, pass_it, next_bottles) @@ -122,15 +125,17 @@ def run_conductor(only_run_once=False): if event.endswith("_start"): w = timeutils.StopWatch() w.start() - base_event = event[0:-len("_start")] + base_event = event[0 : -len("_start")] event_watches[base_event] = w if event.endswith("_end"): - base_event = event[0:-len("_end")] + base_event = event[0 : -len("_end")] try: w = event_watches.pop(base_event) w.stop() - print("It took %0.3f seconds for event '%s' to finish" - % (w.elapsed(), base_event)) + print( + "It took %0.3f seconds for event '%s' to finish" + % (w.elapsed(), base_event) + ) except KeyError: pass if event == 'running_end' and only_run_once: @@ -142,12 +147,14 @@ def run_conductor(only_run_once=False): with contextlib.closing(persist_backend): with contextlib.closing(persist_backend.get_connection()) as conn: conn.upgrade() - job_backend = job_backends.fetch(my_name, JB_CONF, - persistence=persist_backend) + job_backend = job_backends.fetch( + my_name, JB_CONF, persistence=persist_backend + ) job_backend.connect() with contextlib.closing(job_backend): - cond = conductor_backends.fetch('blocking', my_name, job_backend, - persistence=persist_backend) + cond = conductor_backends.fetch( + 'blocking', my_name, job_backend, persistence=persist_backend + ) on_conductor_event = functools.partial(on_conductor_event, cond) cond.notifier.register(cond.notifier.ANY, on_conductor_event) # Run forever, and kill -9 or ctrl-c me... 
@@ -166,8 +173,9 @@ def run_poster(): with contextlib.closing(persist_backend): with contextlib.closing(persist_backend.get_connection()) as conn: conn.upgrade() - job_backend = job_backends.fetch(my_name, JB_CONF, - persistence=persist_backend) + job_backend = job_backends.fetch( + my_name, JB_CONF, persistence=persist_backend + ) job_backend.connect() with contextlib.closing(job_backend): # Create information in the persistence backend about the @@ -175,14 +183,19 @@ def run_poster(): # can be called to create the tasks that the work unit needs # to be done. lb = models.LogBook("post-from-%s" % my_name) - fd = models.FlowDetail("song-from-%s" % my_name, - uuidutils.generate_uuid()) + fd = models.FlowDetail( + "song-from-%s" % my_name, uuidutils.generate_uuid() + ) lb.add(fd) with contextlib.closing(persist_backend.get_connection()) as conn: conn.save_logbook(lb) - engines.save_factory_details(fd, make_bottles, - [HOW_MANY_BOTTLES], {}, - backend=persist_backend) + engines.save_factory_details( + fd, + make_bottles, + [HOW_MANY_BOTTLES], + {}, + backend=persist_backend, + ) # Post, and be done with it! 
jb = job_backend.post("song-from-%s" % my_name, book=lb) print("Posted: %s" % jb) diff --git a/taskflow/examples/alphabet_soup.py b/taskflow/examples/alphabet_soup.py index cc6d3fd6d..56d3f723f 100644 --- a/taskflow/examples/alphabet_soup.py +++ b/taskflow/examples/alphabet_soup.py @@ -23,9 +23,9 @@ import time logging.basicConfig(level=logging.ERROR) self_dir = os.path.abspath(os.path.dirname(__file__)) -top_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), - os.pardir, - os.pardir)) +top_dir = os.path.abspath( + os.path.join(os.path.dirname(__file__), os.pardir, os.pardir) +) sys.path.insert(0, top_dir) sys.path.insert(0, self_dir) @@ -74,8 +74,9 @@ print("Constructing...") soup = linear_flow.Flow("alphabet-soup") for letter in string.ascii_lowercase: abc = AlphabetTask(letter) - abc.notifier.register(task.EVENT_UPDATE_PROGRESS, - functools.partial(progress_printer, abc)) + abc.notifier.register( + task.EVENT_UPDATE_PROGRESS, functools.partial(progress_printer, abc) + ) soup.add(abc) try: print("Loading...") diff --git a/taskflow/examples/build_a_car.py b/taskflow/examples/build_a_car.py index e849784a2..352383a15 100644 --- a/taskflow/examples/build_a_car.py +++ b/taskflow/examples/build_a_car.py @@ -19,9 +19,9 @@ import sys logging.basicConfig(level=logging.ERROR) -top_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), - os.pardir, - os.pardir)) +top_dir = os.path.abspath( + os.path.join(os.path.dirname(__file__), os.pardir, os.pardir) +) sys.path.insert(0, top_dir) @@ -64,6 +64,7 @@ def build_wheels(): # These just return true to indiciate success, they would in the real work # do more than just that. 
+ def install_engine(frame, engine): return True @@ -130,15 +131,22 @@ flow = lf.Flow("make-auto").add( task.FunctorTask(install_engine, provides='engine_installed'), task.FunctorTask(install_doors, provides='doors_installed'), task.FunctorTask(install_windows, provides='windows_installed'), - task.FunctorTask(install_wheels, provides='wheels_installed')), - task.FunctorTask(verify, requires=['frame', - 'engine', - 'doors', - 'wheels', - 'engine_installed', - 'doors_installed', - 'windows_installed', - 'wheels_installed'])) + task.FunctorTask(install_wheels, provides='wheels_installed'), + ), + task.FunctorTask( + verify, + requires=[ + 'frame', + 'engine', + 'doors', + 'wheels', + 'engine_installed', + 'doors_installed', + 'windows_installed', + 'wheels_installed', + ], + ), +) # This dictionary will be provided to the tasks as a specification for what # the tasks should produce, in this example this specification will influence diff --git a/taskflow/examples/buildsystem.py b/taskflow/examples/buildsystem.py index a0ffe149a..8b050085c 100644 --- a/taskflow/examples/buildsystem.py +++ b/taskflow/examples/buildsystem.py @@ -18,9 +18,9 @@ import sys logging.basicConfig(level=logging.ERROR) -top_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), - os.pardir, - os.pardir)) +top_dir = os.path.abspath( + os.path.join(os.path.dirname(__file__), os.pardir, os.pardir) +) sys.path.insert(0, top_dir) import taskflow.engines @@ -40,17 +40,18 @@ import example_utils as eu # noqa class CompileTask(task.Task): """Pretends to take a source and make object file.""" + default_provides = 'object_filename' def execute(self, source_filename): object_filename = '%s.o' % os.path.splitext(source_filename)[0] - print('Compiling %s into %s' - % (source_filename, object_filename)) + print('Compiling %s into %s' % (source_filename, object_filename)) return object_filename class LinkTask(task.Task): """Pretends to link executable form several object files.""" + default_provides = 
'executable' def __init__(self, executable_path, *args, **kwargs): @@ -59,14 +60,16 @@ class LinkTask(task.Task): def execute(self, **kwargs): object_filenames = list(kwargs.values()) - print('Linking executable %s from files %s' - % (self._executable_path, - ', '.join(object_filenames))) + print( + 'Linking executable %s from files %s' + % (self._executable_path, ', '.join(object_filenames)) + ) return self._executable_path class BuildDocsTask(task.Task): """Pretends to build docs from sources.""" + default_provides = 'docs' def execute(self, **kwargs): @@ -84,9 +87,13 @@ def make_flow_and_store(source_files, executable_only=False): object_stored = '%s-object' % source store[source_stored] = source object_targets.append(object_stored) - flow.add(CompileTask(name='compile-%s' % source, - rebind={'source_filename': source_stored}, - provides=object_stored)) + flow.add( + CompileTask( + name='compile-%s' % source, + rebind={'source_filename': source_stored}, + provides=object_stored, + ) + ) flow.add(BuildDocsTask(requires=list(store.keys()))) # Try this to see executable_only switch broken: diff --git a/taskflow/examples/calculate_in_parallel.py b/taskflow/examples/calculate_in_parallel.py index efdd89107..8eb62059c 100644 --- a/taskflow/examples/calculate_in_parallel.py +++ b/taskflow/examples/calculate_in_parallel.py @@ -18,9 +18,9 @@ import sys logging.basicConfig(level=logging.ERROR) -top_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), - os.pardir, - os.pardir)) +top_dir = os.path.abspath( + os.path.join(os.path.dirname(__file__), os.pardir, os.pardir) +) sys.path.insert(0, top_dir) import taskflow.engines @@ -65,8 +65,7 @@ flow = lf.Flow('root').add( # Provide the initial values for other tasks to depend on. 
# # x1 = 2, y1 = 3, x2 = 5, x3 = 8 - Provider("provide-adder", 2, 3, 5, 8, - provides=('x1', 'y1', 'x2', 'y2')), + Provider("provide-adder", 2, 3, 5, 8, provides=('x1', 'y1', 'x2', 'y2')), # Note here that we define the flow that contains the 2 adders to be an # unordered flow since the order in which these execute does not matter, # another way to solve this would be to use a graph_flow pattern, which @@ -85,7 +84,8 @@ flow = lf.Flow('root').add( Adder(name="add-2", provides='z2', rebind=['x2', 'y2']), ), # r = z1+z2 = 18 - Adder(name="sum-1", provides='r', rebind=['z1', 'z2'])) + Adder(name="sum-1", provides='r', rebind=['z1', 'z2']), +) # The result here will be all results (from all tasks) which is stored in an diff --git a/taskflow/examples/calculate_linear.py b/taskflow/examples/calculate_linear.py index d68cdc63f..9adee4b23 100644 --- a/taskflow/examples/calculate_linear.py +++ b/taskflow/examples/calculate_linear.py @@ -18,9 +18,9 @@ import sys logging.basicConfig(level=logging.ERROR) -top_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), - os.pardir, - os.pardir)) +top_dir = os.path.abspath( + os.path.join(os.path.dirname(__file__), os.pardir, os.pardir) +) sys.path.insert(0, top_dir) import taskflow.engines @@ -50,7 +50,6 @@ from taskflow import task # storage backend before your tasks are ran (which accomplishes a similar goal # in a more uniform manner). class Provider(task.Task): - def __init__(self, name, *args, **kwargs): super().__init__(name=name, **kwargs) self._provide = args @@ -77,8 +76,7 @@ class Adder(task.Task): # this function needs to undo if some later operation fails. 
class Multiplier(task.Task): def __init__(self, name, multiplier, provides=None, rebind=None): - super().__init__(name=name, provides=provides, - rebind=rebind) + super().__init__(name=name, provides=provides, rebind=rebind) self._multiplier = multiplier def execute(self, z): @@ -104,7 +102,7 @@ flow = lf.Flow('root').add( # bound to the 'z' variable provided from the above 'provider' object but # instead the 'z' argument will be taken from the 'a' variable provided # by the second add-2 listed above. - Multiplier("multi", 3, provides='r', rebind={'z': 'a'}) + Multiplier("multi", 3, provides='r', rebind={'z': 'a'}), ) # The result here will be all results (from all tasks) which is stored in an diff --git a/taskflow/examples/create_parallel_volume.py b/taskflow/examples/create_parallel_volume.py index add1baaf0..6df762965 100644 --- a/taskflow/examples/create_parallel_volume.py +++ b/taskflow/examples/create_parallel_volume.py @@ -21,9 +21,9 @@ import time logging.basicConfig(level=logging.ERROR) -top_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), - os.pardir, - os.pardir)) +top_dir = os.path.abspath( + os.path.join(os.path.dirname(__file__), os.pardir, os.pardir) +) sys.path.insert(0, top_dir) from oslo_utils import reflection diff --git a/taskflow/examples/delayed_return.py b/taskflow/examples/delayed_return.py index 04d840cb1..52a776cfa 100644 --- a/taskflow/examples/delayed_return.py +++ b/taskflow/examples/delayed_return.py @@ -21,9 +21,9 @@ from concurrent import futures logging.basicConfig(level=logging.ERROR) self_dir = os.path.abspath(os.path.dirname(__file__)) -top_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), - os.pardir, - os.pardir)) +top_dir = os.path.abspath( + os.path.join(os.path.dirname(__file__), os.pardir, os.pardir) +) sys.path.insert(0, top_dir) sys.path.insert(0, self_dir) @@ -45,7 +45,8 @@ class PokeFutureListener(base.Listener): super().__init__( engine, task_listen_for=(notifier.Notifier.ANY,), - 
flow_listen_for=[]) + flow_listen_for=[], + ) self._future = future self._task_name = task_name diff --git a/taskflow/examples/distance_calculator.py b/taskflow/examples/distance_calculator.py index f820c1071..8e951bd0e 100644 --- a/taskflow/examples/distance_calculator.py +++ b/taskflow/examples/distance_calculator.py @@ -17,9 +17,9 @@ import math import os import sys -top_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), - os.pardir, - os.pardir)) +top_dir = os.path.abspath( + os.path.join(os.path.dirname(__file__), os.pardir, os.pardir) +) sys.path.insert(0, top_dir) from taskflow import engines @@ -60,24 +60,39 @@ if __name__ == '__main__': any_distance = linear_flow.Flow("origin").add(DistanceTask()) results = engines.run(any_distance) print(results) - print("{} is near-enough to {}: {}".format( - results['distance'], 0.0, is_near(results['distance'], 0.0))) + print( + "{} is near-enough to {}: {}".format( + results['distance'], 0.0, is_near(results['distance'], 0.0) + ) + ) results = engines.run(any_distance, store={'a': Point(1, 1)}) print(results) - print("{} is near-enough to {}: {}".format( - results['distance'], 1.4142, is_near(results['distance'], 1.4142))) + print( + "{} is near-enough to {}: {}".format( + results['distance'], 1.4142, is_near(results['distance'], 1.4142) + ) + ) results = engines.run(any_distance, store={'a': Point(10, 10)}) print(results) - print("{} is near-enough to {}: {}".format( - results['distance'], 14.14199, is_near(results['distance'], 14.14199))) + print( + "{} is near-enough to {}: {}".format( + results['distance'], + 14.14199, + is_near(results['distance'], 14.14199), + ) + ) - results = engines.run(any_distance, - store={'a': Point(5, 5), 'b': Point(10, 10)}) + results = engines.run( + any_distance, store={'a': Point(5, 5), 'b': Point(10, 10)} + ) print(results) - print("{} is near-enough to {}: {}".format( - results['distance'], 7.07106, is_near(results['distance'], 7.07106))) + print( + "{} is near-enough to 
{}: {}".format( + results['distance'], 7.07106, is_near(results['distance'], 7.07106) + ) + ) # For this we use the ability to override at task creation time the # optional arguments so that we don't need to continue to send them @@ -88,10 +103,18 @@ if __name__ == '__main__': ten_distance.add(DistanceTask(inject={'a': Point(10, 10)})) results = engines.run(ten_distance, store={'b': Point(10, 10)}) print(results) - print("{} is near-enough to {}: {}".format( - results['distance'], 0.0, is_near(results['distance'], 0.0))) + print( + "{} is near-enough to {}: {}".format( + results['distance'], 0.0, is_near(results['distance'], 0.0) + ) + ) results = engines.run(ten_distance) print(results) - print("{} is near-enough to {}: {}".format( - results['distance'], 14.14199, is_near(results['distance'], 14.14199))) + print( + "{} is near-enough to {}: {}".format( + results['distance'], + 14.14199, + is_near(results['distance'], 14.14199), + ) + ) diff --git a/taskflow/examples/dump_memory_backend.py b/taskflow/examples/dump_memory_backend.py index 798b15c62..4daf8ec52 100644 --- a/taskflow/examples/dump_memory_backend.py +++ b/taskflow/examples/dump_memory_backend.py @@ -19,9 +19,9 @@ import sys logging.basicConfig(level=logging.ERROR) self_dir = os.path.abspath(os.path.dirname(__file__)) -top_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), - os.pardir, - os.pardir)) +top_dir = os.path.abspath( + os.path.join(os.path.dirname(__file__), os.pardir, os.pardir) +) sys.path.insert(0, top_dir) sys.path.insert(0, self_dir) @@ -39,6 +39,7 @@ class PrintTask(task.Task): def execute(self): print("Running '%s'" % self.name) + # Make a little flow and run it... 
f = lf.Flow('root') for alpha in ['a', 'b', 'c']: diff --git a/taskflow/examples/echo_listener.py b/taskflow/examples/echo_listener.py index 8e62a7726..8bbb33ceb 100644 --- a/taskflow/examples/echo_listener.py +++ b/taskflow/examples/echo_listener.py @@ -18,9 +18,9 @@ import sys logging.basicConfig(level=logging.DEBUG) -top_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), - os.pardir, - os.pardir)) +top_dir = os.path.abspath( + os.path.join(os.path.dirname(__file__), os.pardir, os.pardir) +) sys.path.insert(0, top_dir) from taskflow import engines diff --git a/taskflow/examples/example_utils.py b/taskflow/examples/example_utils.py index 45fc3da7c..a958f0ed4 100644 --- a/taskflow/examples/example_utils.py +++ b/taskflow/examples/example_utils.py @@ -28,6 +28,7 @@ LOG = logging.getLogger(__name__) try: import sqlalchemy as _sa # noqa + SQLALCHEMY_AVAILABLE = True except ImportError: SQLALCHEMY_AVAILABLE = False @@ -93,8 +94,11 @@ def get_backend(backend_uri=None): if not tmp_dir: tmp_dir = tempfile.mkdtemp() backend_uri = "file:///%s" % tmp_dir - LOG.exception("Falling back to file backend using temporary" - " directory located at: %s", tmp_dir) + LOG.exception( + "Falling back to file backend using temporary" + " directory located at: %s", + tmp_dir, + ) backend = backends.fetch(_make_conf(backend_uri)) else: raise e diff --git a/taskflow/examples/fake_billing.py b/taskflow/examples/fake_billing.py index 435e3a395..4abed7e17 100644 --- a/taskflow/examples/fake_billing.py +++ b/taskflow/examples/fake_billing.py @@ -20,9 +20,9 @@ import time logging.basicConfig(level=logging.ERROR) -top_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), - os.pardir, - os.pardir)) +top_dir = os.path.abspath( + os.path.join(os.path.dirname(__file__), os.pardir, os.pardir) +) sys.path.insert(0, top_dir) from oslo_utils import uuidutils @@ -130,8 +130,11 @@ class ActivateDriver(task.Task): # that the url sending helper class uses. 
This allows the task progress # to be tied to the url sending progress, which is very useful for # downstream systems to be aware of what a task is doing at any time. - url_sender.send(self._url, json.dumps(parsed_request), - status_cb=self.update_progress) + url_sender.send( + self._url, + json.dumps(parsed_request), + status_cb=self.update_progress, + ) return self._url def update_progress(self, progress, **kwargs): diff --git a/taskflow/examples/graph_flow.py b/taskflow/examples/graph_flow.py index e90d8158c..6ccebf475 100644 --- a/taskflow/examples/graph_flow.py +++ b/taskflow/examples/graph_flow.py @@ -18,9 +18,9 @@ import sys logging.basicConfig(level=logging.ERROR) -top_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), - os.pardir, - os.pardir)) +top_dir = os.path.abspath( + os.path.join(os.path.dirname(__file__), os.pardir, os.pardir) +) sys.path.insert(0, top_dir) import taskflow.engines @@ -46,7 +46,6 @@ from taskflow import task class Adder(task.Task): - def execute(self, x, y): return x + y @@ -56,7 +55,7 @@ flow = gf.Flow('root').add( # x2 = y3+y4 = 12 Adder("add2", provides='x2', rebind=['y3', 'y4']), # x1 = y1+y2 = 4 - Adder("add1", provides='x1', rebind=['y1', 'y2']) + Adder("add1", provides='x1', rebind=['y1', 'y2']), ), # x5 = x1+x3 = 20 Adder("add5", provides='x5', rebind=['x1', 'x3']), @@ -67,7 +66,8 @@ flow = gf.Flow('root').add( # x6 = x5+x4 = 41 Adder("add6", provides='x6', rebind=['x5', 'x4']), # x7 = x6+x6 = 82 - Adder("add7", provides='x7', rebind=['x6', 'x6'])) + Adder("add7", provides='x7', rebind=['x6', 'x6']), +) # Provide the initial variable inputs using a storage dictionary. 
store = { @@ -90,21 +90,19 @@ expected = [ ('x7', 82), ] -result = taskflow.engines.run( - flow, engine='serial', store=store) +result = taskflow.engines.run(flow, engine='serial', store=store) print("Single threaded engine result %s" % result) -for (name, value) in expected: +for name, value in expected: actual = result.get(name) if actual != value: sys.stderr.write(f"{actual} != {value}\n") unexpected += 1 -result = taskflow.engines.run( - flow, engine='parallel', store=store) +result = taskflow.engines.run(flow, engine='parallel', store=store) print("Multi threaded engine result %s" % result) -for (name, value) in expected: +for name, value in expected: actual = result.get(name) if actual != value: sys.stderr.write(f"{actual} != {value}\n") diff --git a/taskflow/examples/hello_world.py b/taskflow/examples/hello_world.py index e048d8f6a..675eb904d 100644 --- a/taskflow/examples/hello_world.py +++ b/taskflow/examples/hello_world.py @@ -18,9 +18,9 @@ import sys logging.basicConfig(level=logging.ERROR) -top_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), - os.pardir, - os.pardir)) +top_dir = os.path.abspath( + os.path.join(os.path.dirname(__file__), os.pardir, os.pardir) +) sys.path.insert(0, top_dir) from taskflow import engines @@ -34,6 +34,7 @@ from taskflow import task # engines using different styles of execution (all can be used to run in # parallel if a workflow is provided that is parallelizable). + class PrinterTask(task.Task): def __init__(self, name, show_name=True, inject=None): super().__init__(name, inject=inject) @@ -55,26 +56,33 @@ song = lf.Flow("beats") # singing at once of course! 
hi_chorus = uf.Flow('hello') world_chorus = uf.Flow('world') -for (name, hello, world) in [('bob', 'hello', 'world'), - ('joe', 'hellooo', 'worllllld'), - ('sue', "helloooooo!", 'wooorllld!')]: - hi_chorus.add(PrinterTask("%s@hello" % name, - # This will show up to the execute() method of - # the task as the argument named 'output' (which - # will allow us to print the character we want). - inject={'output': hello})) - world_chorus.add(PrinterTask("%s@world" % name, - inject={'output': world})) +for name, hello, world in [ + ('bob', 'hello', 'world'), + ('joe', 'hellooo', 'worllllld'), + ('sue', "helloooooo!", 'wooorllld!'), +]: + hi_chorus.add( + PrinterTask( + "%s@hello" % name, + # This will show up to the execute() method of + # the task as the argument named 'output' (which + # will allow us to print the character we want). + inject={'output': hello}, + ) + ) + world_chorus.add(PrinterTask("%s@world" % name, inject={'output': world})) # The composition starts with the conductor and then runs in sequence with # the chorus running in parallel, but no matter what the 'hello' chorus must # always run before the 'world' chorus (otherwise the world will fall apart). -song.add(PrinterTask("conductor@begin", - show_name=False, inject={'output': "*ding*"}), - hi_chorus, - world_chorus, - PrinterTask("conductor@end", - show_name=False, inject={'output': "*dong*"})) +song.add( + PrinterTask( + "conductor@begin", show_name=False, inject={'output': "*ding*"} + ), + hi_chorus, + world_chorus, + PrinterTask("conductor@end", show_name=False, inject={'output': "*dong*"}), +) # Run in parallel using eventlet green threads... try: @@ -84,22 +92,21 @@ except ImportError: pass else: print("-- Running in parallel using eventlet --") - e = engines.load(song, executor='greenthreaded', engine='parallel', - max_workers=1) + e = engines.load( + song, executor='greenthreaded', engine='parallel', max_workers=1 + ) e.run() # Run in parallel using real threads... 
print("-- Running in parallel using threads --") -e = engines.load(song, executor='threaded', engine='parallel', - max_workers=1) +e = engines.load(song, executor='threaded', engine='parallel', max_workers=1) e.run() # Run in parallel using external processes... print("-- Running in parallel using processes --") -e = engines.load(song, executor='processes', engine='parallel', - max_workers=1) +e = engines.load(song, executor='processes', engine='parallel', max_workers=1) e.run() diff --git a/taskflow/examples/jobboard_produce_consume_colors.py b/taskflow/examples/jobboard_produce_consume_colors.py index b7f7b374e..36c0db783 100644 --- a/taskflow/examples/jobboard_produce_consume_colors.py +++ b/taskflow/examples/jobboard_produce_consume_colors.py @@ -23,9 +23,9 @@ import time logging.basicConfig(level=logging.ERROR) -top_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), - os.pardir, - os.pardir)) +top_dir = os.path.abspath( + os.path.join(os.path.dirname(__file__), os.pardir, os.pardir) +) sys.path.insert(0, top_dir) from taskflow import exceptions as excp @@ -121,10 +121,12 @@ def worker(ident, client, consumed): abandoned_jobs += 1 safe_print(name, "'%s' [abandoned]" % (job)) time.sleep(WORKER_DELAY) - safe_print(name, - "finished (claimed %s jobs, consumed %s jobs," - " abandoned %s jobs)" % (claimed_jobs, consumed_jobs, - abandoned_jobs), prefix=">>>") + safe_print( + name, + "finished (claimed %s jobs, consumed %s jobs," + " abandoned %s jobs)" % (claimed_jobs, consumed_jobs, abandoned_jobs), + prefix=">>>", + ) def producer(ident, client): @@ -149,6 +151,7 @@ def main(): # TODO(harlowja): Hack to make eventlet work right, remove when the # following is fixed: https://github.com/eventlet/eventlet/issues/230 from taskflow.utils import eventlet_utils as _eu # noqa + try: import eventlet as _eventlet # noqa except ImportError: diff --git a/taskflow/examples/parallel_table_multiply.py b/taskflow/examples/parallel_table_multiply.py index 
61da4566b..7bd87fe27 100644 --- a/taskflow/examples/parallel_table_multiply.py +++ b/taskflow/examples/parallel_table_multiply.py @@ -20,9 +20,9 @@ import sys logging.basicConfig(level=logging.ERROR) -top_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), - os.pardir, - os.pardir)) +top_dir = os.path.abspath( + os.path.join(os.path.dirname(__file__), os.pardir, os.pardir) +) sys.path.insert(0, top_dir) import futurist diff --git a/taskflow/examples/persistence_example.py b/taskflow/examples/persistence_example.py index f7045e058..34b4e17e8 100644 --- a/taskflow/examples/persistence_example.py +++ b/taskflow/examples/persistence_example.py @@ -21,9 +21,9 @@ import traceback logging.basicConfig(level=logging.ERROR) self_dir = os.path.abspath(os.path.dirname(__file__)) -top_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), - os.pardir, - os.pardir)) +top_dir = os.path.abspath( + os.path.join(os.path.dirname(__file__), os.pardir, os.pardir) +) sys.path.insert(0, top_dir) sys.path.insert(0, self_dir) @@ -95,8 +95,7 @@ with eu.get_backend(backend_uri) as backend: flow = make_flow(blowup=blowup) eu.print_wrapped("Running") try: - eng = engines.load(flow, engine='serial', - backend=backend, book=book) + eng = engines.load(flow, engine='serial', backend=backend, book=book) eng.run() if not blowup: eu.rm_path(persist_path) diff --git a/taskflow/examples/pseudo_scoping.py b/taskflow/examples/pseudo_scoping.py index 736d050ab..3cd052324 100644 --- a/taskflow/examples/pseudo_scoping.py +++ b/taskflow/examples/pseudo_scoping.py @@ -18,9 +18,9 @@ import sys logging.basicConfig(level=logging.ERROR) -top_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), - os.pardir, - os.pardir)) +top_dir = os.path.abspath( + os.path.join(os.path.dirname(__file__), os.pardir, os.pardir) +) sys.path.insert(0, top_dir) import taskflow.engines @@ -43,12 +43,7 @@ from taskflow import task # his or her phone number from phone book and call. 
-PHONE_BOOK = { - 'jim': '444', - 'joe': '555', - 'iv_m': '666', - 'josh': '777' -} +PHONE_BOOK = {'jim': '444', 'joe': '555', 'iv_m': '666', 'josh': '777'} class FetchNumberTask(task.Task): @@ -67,11 +62,10 @@ class CallTask(task.Task): def execute(self, person, number): print(f'Calling {person} {number}.') + # This is how it works for one person: -simple_flow = lf.Flow('simple one').add( - FetchNumberTask(), - CallTask()) +simple_flow = lf.Flow('simple one').add(FetchNumberTask(), CallTask()) print('Running simple flow:') taskflow.engines.run(simple_flow, store={'person': 'Josh'}) @@ -85,11 +79,10 @@ def subflow_factory(prefix): return f'{prefix}-{what}' return lf.Flow(pr('flow')).add( - FetchNumberTask(pr('fetch'), - provides=pr('number'), - rebind=[pr('person')]), - CallTask(pr('call'), - rebind=[pr('person'), pr('number')]) + FetchNumberTask( + pr('fetch'), provides=pr('number'), rebind=[pr('person')] + ), + CallTask(pr('call'), rebind=[pr('person'), pr('number')]), ) @@ -107,5 +100,6 @@ def call_them_all(): flow.add(subflow_factory(prefix)) taskflow.engines.run(flow, store=persons) + print('\nCalling many people using prefixed factory:') call_them_all() diff --git a/taskflow/examples/resume_from_backend.py b/taskflow/examples/resume_from_backend.py index 47fcdbb7d..538133eb0 100644 --- a/taskflow/examples/resume_from_backend.py +++ b/taskflow/examples/resume_from_backend.py @@ -20,9 +20,9 @@ import sys logging.basicConfig(level=logging.ERROR) self_dir = os.path.abspath(os.path.dirname(__file__)) -top_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), - os.pardir, - os.pardir)) +top_dir = os.path.abspath( + os.path.join(os.path.dirname(__file__), os.pardir, os.pardir) +) sys.path.insert(0, top_dir) sys.path.insert(0, self_dir) @@ -63,8 +63,9 @@ def print_task_states(flowdetail, msg): print(f"Flow '{flowdetail.name}' state: {flowdetail.state}") # Sort by these so that our test validation doesn't get confused by the # order in which the items in the 
flow detail can be in. - items = sorted((td.name, td.version, td.state, td.results) - for td in flowdetail) + items = sorted( + (td.name, td.version, td.state, td.results) for td in flowdetail + ) for item in items: print(" %s==%s: %s, result=%s" % item) @@ -94,17 +95,18 @@ def flow_factory(): return lf.Flow('resume from backend example').add( TestTask(name='first'), InterruptTask(name='boom'), - TestTask(name='second')) + TestTask(name='second'), + ) # INITIALIZE PERSISTENCE #################################### with eu.get_backend() as backend: - # Create a place where the persistence information will be stored. book = models.LogBook("example") - flow_detail = models.FlowDetail("resume from backend example", - uuid=uuidutils.generate_uuid()) + flow_detail = models.FlowDetail( + "resume from backend example", uuid=uuidutils.generate_uuid() + ) book.add(flow_detail) with contextlib.closing(backend.get_connection()) as conn: conn.save_logbook(book) @@ -112,8 +114,9 @@ with eu.get_backend() as backend: # CREATE AND RUN THE FLOW: FIRST ATTEMPT #################### flow = flow_factory() - engine = taskflow.engines.load(flow, flow_detail=flow_detail, - book=book, backend=backend) + engine = taskflow.engines.load( + flow, flow_detail=flow_detail, book=book, backend=backend + ) print_task_states(flow_detail, "At the beginning, there is no state") eu.print_wrapped("Running") @@ -135,8 +138,8 @@ with eu.get_backend() as backend: # running the above flow crashes). 
flow2 = flow_factory() flow_detail_2 = find_flow_detail(backend, book.uuid, flow_detail.uuid) - engine2 = taskflow.engines.load(flow2, - flow_detail=flow_detail_2, - backend=backend, book=book) + engine2 = taskflow.engines.load( + flow2, flow_detail=flow_detail_2, backend=backend, book=book + ) engine2.run() print_task_states(flow_detail_2, "At the end") diff --git a/taskflow/examples/resume_many_flows.py b/taskflow/examples/resume_many_flows.py index 0f1fce045..07f47fa18 100644 --- a/taskflow/examples/resume_many_flows.py +++ b/taskflow/examples/resume_many_flows.py @@ -42,9 +42,9 @@ def _exec(cmd, add_env=None): env = os.environ.copy() env.update(add_env) - proc = subprocess.Popen(cmd, env=env, stdin=None, - stdout=subprocess.PIPE, - stderr=sys.stderr) + proc = subprocess.Popen( + cmd, env=env, stdin=None, stdout=subprocess.PIPE, stderr=sys.stderr + ) stdout, _stderr = proc.communicate() rc = proc.returncode @@ -54,8 +54,9 @@ def _exec(cmd, add_env=None): def _path_to(name): - return os.path.abspath(os.path.join(os.path.dirname(__file__), - 'resume_many_flows', name)) + return os.path.abspath( + os.path.join(os.path.dirname(__file__), 'resume_many_flows', name) + ) def main(): @@ -87,5 +88,6 @@ def main(): if tmp_path: example_utils.rm_path(tmp_path) + if __name__ == '__main__': main() diff --git a/taskflow/examples/resume_many_flows/my_flows.py b/taskflow/examples/resume_many_flows/my_flows.py index 284fb7dc7..ecea82d36 100644 --- a/taskflow/examples/resume_many_flows/my_flows.py +++ b/taskflow/examples/resume_many_flows/my_flows.py @@ -38,4 +38,5 @@ def flow_factory(): return lf.Flow('example').add( TestTask(name='first'), UnfortunateTask(name='boom'), - TestTask(name='second')) + TestTask(name='second'), + ) diff --git a/taskflow/examples/resume_many_flows/resume_all.py b/taskflow/examples/resume_many_flows/resume_all.py index cc8098b67..37c340854 100644 --- a/taskflow/examples/resume_many_flows/resume_all.py +++ 
b/taskflow/examples/resume_many_flows/resume_all.py @@ -20,7 +20,8 @@ logging.basicConfig(level=logging.ERROR) self_dir = os.path.abspath(os.path.dirname(__file__)) top_dir = os.path.abspath( - os.path.join(self_dir, os.pardir, os.pardir, os.pardir)) + os.path.join(self_dir, os.pardir, os.pardir, os.pardir) +) example_dir = os.path.abspath(os.path.join(self_dir, os.pardir)) sys.path.insert(0, top_dir) @@ -38,8 +39,9 @@ FINISHED_STATES = (states.SUCCESS, states.FAILURE, states.REVERTED) def resume(flowdetail, backend): print(f'Resuming flow {flowdetail.name} {flowdetail.uuid}') - engine = taskflow.engines.load_from_detail(flow_detail=flowdetail, - backend=backend) + engine = taskflow.engines.load_from_detail( + flow_detail=flowdetail, backend=backend + ) engine.run() diff --git a/taskflow/examples/resume_many_flows/run_flow.py b/taskflow/examples/resume_many_flows/run_flow.py index 5cbf9731f..2b21fa44c 100644 --- a/taskflow/examples/resume_many_flows/run_flow.py +++ b/taskflow/examples/resume_many_flows/run_flow.py @@ -20,7 +20,8 @@ logging.basicConfig(level=logging.ERROR) self_dir = os.path.abspath(os.path.dirname(__file__)) top_dir = os.path.abspath( - os.path.join(self_dir, os.pardir, os.pardir, os.pardir)) + os.path.join(self_dir, os.pardir, os.pardir, os.pardir) +) example_dir = os.path.abspath(os.path.join(self_dir, os.pardir)) sys.path.insert(0, top_dir) @@ -34,8 +35,12 @@ import my_flows # noqa with example_utils.get_backend() as backend: - engine = taskflow.engines.load_from_factory(my_flows.flow_factory, - backend=backend) - print('Running flow {} {}'.format(engine.storage.flow_name, - engine.storage.flow_uuid)) + engine = taskflow.engines.load_from_factory( + my_flows.flow_factory, backend=backend + ) + print( + 'Running flow {} {}'.format( + engine.storage.flow_name, engine.storage.flow_uuid + ) + ) engine.run() diff --git a/taskflow/examples/resume_vm_boot.py b/taskflow/examples/resume_vm_boot.py index 8ba8ac187..dc71ee078 100644 --- 
a/taskflow/examples/resume_vm_boot.py +++ b/taskflow/examples/resume_vm_boot.py @@ -23,9 +23,9 @@ import time logging.basicConfig(level=logging.ERROR) self_dir = os.path.abspath(os.path.dirname(__file__)) -top_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), - os.pardir, - os.pardir)) +top_dir = os.path.abspath( + os.path.join(os.path.dirname(__file__), os.pardir, os.pardir) +) sys.path.insert(0, top_dir) sys.path.insert(0, self_dir) @@ -59,6 +59,7 @@ def slow_down(how_long=0.5): class PrintText(task.Task): """Just inserts some text print outs in a workflow.""" + def __init__(self, print_what, no_slow=False): content_hash = hashlib.md5(print_what.encode('utf-8')).hexdigest()[0:8] super().__init__(name="Print: %s" % (content_hash)) @@ -75,6 +76,7 @@ class PrintText(task.Task): class DefineVMSpec(task.Task): """Defines a vm specification to be.""" + def __init__(self, name): super().__init__(provides='vm_spec', name=name) @@ -90,6 +92,7 @@ class DefineVMSpec(task.Task): class LocateImages(task.Task): """Locates where the vm images are.""" + def __init__(self, name): super().__init__(provides='image_locations', name=name) @@ -103,9 +106,9 @@ class LocateImages(task.Task): class DownloadImages(task.Task): """Downloads all the vm images.""" + def __init__(self, name): - super().__init__(provides='download_paths', - name=name) + super().__init__(provides='download_paths', name=name) def execute(self, image_locations): for src, loc in image_locations.items(): @@ -116,14 +119,14 @@ class DownloadImages(task.Task): class CreateNetworkTpl(task.Task): """Generates the network settings file to be placed in the images.""" + SYSCONFIG_CONTENTS = """DEVICE=eth%s BOOTPROTO=static IPADDR=%s ONBOOT=yes""" def __init__(self, name): - super().__init__(provides='network_settings', - name=name) + super().__init__(provides='network_settings', name=name) def execute(self, ips): settings = [] @@ -134,6 +137,7 @@ ONBOOT=yes""" class AllocateIP(task.Task): """Allocates the ips 
for the given vm.""" + def __init__(self, name): super().__init__(provides='ips', name=name) @@ -146,13 +150,15 @@ class AllocateIP(task.Task): class WriteNetworkSettings(task.Task): """Writes all the network settings into the downloaded images.""" + def execute(self, download_paths, network_settings): for j, path in enumerate(download_paths): with slow_down(1): print(f"Mounting {path} to /tmp/{j}") for i, setting in enumerate(network_settings): - filename = ("/tmp/etc/sysconfig/network-scripts/" - "ifcfg-eth%s" % (i)) + filename = "/tmp/etc/sysconfig/network-scripts/ifcfg-eth%s" % ( + i + ) with slow_down(1): print("Writing to %s" % (filename)) print(setting) @@ -160,6 +166,7 @@ class WriteNetworkSettings(task.Task): class BootVM(task.Task): """Fires off the vm boot operation.""" + def execute(self, vm_spec): print("Starting vm!") with slow_down(1): @@ -168,6 +175,7 @@ class BootVM(task.Task): class AllocateVolumes(task.Task): """Allocates the volumes for the vm.""" + def execute(self, vm_spec): volumes = [] for i in range(0, vm_spec['volumes']): @@ -179,6 +187,7 @@ class AllocateVolumes(task.Task): class FormatVolumes(task.Task): """Formats the volumes for the vm.""" + def execute(self, volumes): for v in volumes: print("Formatting volume %s" % v) @@ -215,14 +224,15 @@ def create_flow(): ), # Ya it worked! PrintText("Finished vm create.", no_slow=True), - PrintText("Instance is running!", no_slow=True)) + PrintText("Instance is running!", no_slow=True), + ) return flow + eu.print_wrapped("Initializing") # Setup the persistence & resumption layer. with eu.get_backend() as backend: - # Try to find a previously passed in tracking id... 
try: book_id, flow_id = sys.argv[2].split("+", 1) @@ -256,17 +266,24 @@ with eu.get_backend() as backend: book = models.LogBook("vm-boot") with contextlib.closing(backend.get_connection()) as conn: conn.save_logbook(book) - engine = engines.load_from_factory(create_flow, - backend=backend, book=book, - engine='parallel', - executor=executor) - print("!! Your tracking id is: '{}+{}'".format( - book.uuid, engine.storage.flow_uuid)) + engine = engines.load_from_factory( + create_flow, + backend=backend, + book=book, + engine='parallel', + executor=executor, + ) + print( + "!! Your tracking id is: '{}+{}'".format( + book.uuid, engine.storage.flow_uuid + ) + ) print("!! Please submit this on later runs for tracking purposes") else: # Attempt to load from a previously partially completed flow. - engine = engines.load_from_detail(flow_detail, backend=backend, - engine='parallel', executor=executor) + engine = engines.load_from_detail( + flow_detail, backend=backend, engine='parallel', executor=executor + ) # Make me my vm please! 
eu.print_wrapped('Running') diff --git a/taskflow/examples/resume_volume_create.py b/taskflow/examples/resume_volume_create.py index 6ea46259c..9521feacf 100644 --- a/taskflow/examples/resume_volume_create.py +++ b/taskflow/examples/resume_volume_create.py @@ -23,9 +23,9 @@ import time logging.basicConfig(level=logging.ERROR) self_dir = os.path.abspath(os.path.dirname(__file__)) -top_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), - os.pardir, - os.pardir)) +top_dir = os.path.abspath( + os.path.join(os.path.dirname(__file__), os.pardir, os.pardir) +) sys.path.insert(0, top_dir) sys.path.insert(0, self_dir) @@ -90,10 +90,12 @@ class CreateSpecForVolumes(task.Task): def execute(self): volumes = [] for i in range(0, random.randint(1, 10)): - volumes.append({ - 'type': 'disk', - 'location': "/dev/vda%s" % (i + 1), - }) + volumes.append( + { + 'type': 'disk', + 'location': "/dev/vda%s" % (i + 1), + } + ) return volumes @@ -115,7 +117,8 @@ flow = lf.Flow("root").add( PrintText("I need a nap, it took me a while to build those specs."), PrepareVolumes(), ), - PrintText("Finished volume create", no_slow=True)) + PrintText("Finished volume create", no_slow=True), +) # Setup the persistence & resumption layer. with example_utils.get_backend() as backend: @@ -139,16 +142,19 @@ with example_utils.get_backend() as backend: book.add(flow_detail) with contextlib.closing(backend.get_connection()) as conn: conn.save_logbook(book) - print("!! Your tracking id is: '{}+{}'".format(book.uuid, - flow_detail.uuid)) + print( + "!! Your tracking id is: '{}+{}'".format( + book.uuid, flow_detail.uuid + ) + ) print("!! Please submit this on later runs for tracking purposes") else: flow_detail = find_flow_detail(backend, book_id, flow_id) # Load and run. - engine = engines.load(flow, - flow_detail=flow_detail, - backend=backend, engine='serial') + engine = engines.load( + flow, flow_detail=flow_detail, backend=backend, engine='serial' + ) engine.run() # How to use. 
diff --git a/taskflow/examples/retry_flow.py b/taskflow/examples/retry_flow.py index bbdecf105..24d7daf93 100644 --- a/taskflow/examples/retry_flow.py +++ b/taskflow/examples/retry_flow.py @@ -18,9 +18,9 @@ import sys logging.basicConfig(level=logging.ERROR) -top_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), - os.pardir, - os.pardir)) +top_dir = os.path.abspath( + os.path.join(os.path.dirname(__file__), os.pardir, os.pardir) +) sys.path.insert(0, top_dir) import taskflow.engines @@ -52,10 +52,12 @@ class CallJim(task.Task): # Create your flow and associated tasks (the work to be done). -flow = lf.Flow('retrying-linear', - retry=retry.ParameterizedForEach( - rebind=['phone_directory'], - provides='jim_number')).add(CallJim()) +flow = lf.Flow( + 'retrying-linear', + retry=retry.ParameterizedForEach( + rebind=['phone_directory'], provides='jim_number' + ), +).add(CallJim()) # Now run that flow using the provided initial data (store below). taskflow.engines.run(flow, store={'phone_directory': [333, 444, 555, 666]}) diff --git a/taskflow/examples/reverting_linear.py b/taskflow/examples/reverting_linear.py index cb3e965ff..fe609b42d 100644 --- a/taskflow/examples/reverting_linear.py +++ b/taskflow/examples/reverting_linear.py @@ -18,9 +18,9 @@ import sys logging.basicConfig(level=logging.ERROR) -top_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), - os.pardir, - os.pardir)) +top_dir = os.path.abspath( + os.path.join(os.path.dirname(__file__), os.pardir, os.pardir) +) sys.path.insert(0, top_dir) import taskflow.engines @@ -61,17 +61,13 @@ class CallSuzzie(task.Task): # Create your flow and associated tasks (the work to be done). -flow = lf.Flow('simple-linear').add( - CallJim(), - CallJoe(), - CallSuzzie() -) +flow = lf.Flow('simple-linear').add(CallJim(), CallJoe(), CallSuzzie()) try: # Now run that flow using the provided initial data (store below). 
- taskflow.engines.run(flow, store=dict(joe_number=444, - jim_number=555, - suzzie_number=666)) + taskflow.engines.run( + flow, store=dict(joe_number=444, jim_number=555, suzzie_number=666) + ) except Exception as e: # NOTE(harlowja): This exception will be the exception that came out of the # 'CallSuzzie' task instead of a different exception, this is useful since diff --git a/taskflow/examples/run_by_iter.py b/taskflow/examples/run_by_iter.py index 70956fcf9..4f4cd26e0 100644 --- a/taskflow/examples/run_by_iter.py +++ b/taskflow/examples/run_by_iter.py @@ -19,9 +19,9 @@ import sys logging.basicConfig(level=logging.ERROR) self_dir = os.path.abspath(os.path.dirname(__file__)) -top_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), - os.pardir, - os.pardir)) +top_dir = os.path.abspath( + os.path.join(os.path.dirname(__file__), os.pardir, os.pardir) +) sys.path.insert(0, top_dir) sys.path.insert(0, self_dir) @@ -51,12 +51,19 @@ def make_alphabet_flow(i): while ord(curr_value) <= ord(end_value): next_value = chr(ord(curr_value) + 1) if curr_value != end_value: - f.add(EchoTask(name="echoer_%s" % curr_value, - rebind={'value': curr_value}, - provides=next_value)) + f.add( + EchoTask( + name="echoer_%s" % curr_value, + rebind={'value': curr_value}, + provides=next_value, + ) + ) else: - f.add(EchoTask(name="echoer_%s" % curr_value, - rebind={'value': curr_value})) + f.add( + EchoTask( + name="echoer_%s" % curr_value, rebind={'value': curr_value} + ) + ) curr_value = next_value return f diff --git a/taskflow/examples/run_by_iter_enumerate.py b/taskflow/examples/run_by_iter_enumerate.py index 91c613019..211ea4dbc 100644 --- a/taskflow/examples/run_by_iter_enumerate.py +++ b/taskflow/examples/run_by_iter_enumerate.py @@ -19,9 +19,9 @@ import sys logging.basicConfig(level=logging.ERROR) self_dir = os.path.abspath(os.path.dirname(__file__)) -top_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), - os.pardir, - os.pardir)) +top_dir = os.path.abspath( + 
os.path.join(os.path.dirname(__file__), os.pardir, os.pardir) +) sys.path.insert(0, top_dir) sys.path.insert(0, self_dir) diff --git a/taskflow/examples/share_engine_thread.py b/taskflow/examples/share_engine_thread.py index 7984e6fff..611c8fa01 100644 --- a/taskflow/examples/share_engine_thread.py +++ b/taskflow/examples/share_engine_thread.py @@ -20,9 +20,9 @@ import time logging.basicConfig(level=logging.ERROR) -top_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), - os.pardir, - os.pardir)) +top_dir = os.path.abspath( + os.path.join(os.path.dirname(__file__), os.pardir, os.pardir) +) sys.path.insert(0, top_dir) import futurist diff --git a/taskflow/examples/simple_linear.py b/taskflow/examples/simple_linear.py index 989a62227..26c2cf079 100644 --- a/taskflow/examples/simple_linear.py +++ b/taskflow/examples/simple_linear.py @@ -18,9 +18,9 @@ import sys logging.basicConfig(level=logging.ERROR) -top_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), - os.pardir, - os.pardir)) +top_dir = os.path.abspath( + os.path.join(os.path.dirname(__file__), os.pardir, os.pardir) +) sys.path.insert(0, top_dir) import taskflow.engines @@ -54,11 +54,7 @@ class CallJoe(task.Task): # Create your flow and associated tasks (the work to be done). -flow = lf.Flow('simple-linear').add( - CallJim(), - CallJoe() -) +flow = lf.Flow('simple-linear').add(CallJim(), CallJoe()) # Now run that flow using the provided initial data (store below). 
-taskflow.engines.run(flow, store=dict(joe_number=444, - jim_number=555)) +taskflow.engines.run(flow, store=dict(joe_number=444, jim_number=555)) diff --git a/taskflow/examples/simple_linear_listening.py b/taskflow/examples/simple_linear_listening.py index 3a93569b9..bd3ede1b0 100644 --- a/taskflow/examples/simple_linear_listening.py +++ b/taskflow/examples/simple_linear_listening.py @@ -18,9 +18,9 @@ import sys logging.basicConfig(level=logging.ERROR) -top_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), - os.pardir, - os.pardir)) +top_dir = os.path.abspath( + os.path.join(os.path.dirname(__file__), os.pardir, os.pardir) +) sys.path.insert(0, top_dir) import taskflow.engines @@ -82,12 +82,15 @@ flow.add(task.FunctorTask(execute=call_jim)) flow.add(task.FunctorTask(execute=call_joe)) # Now load (but do not run) the flow using the provided initial data. -engine = taskflow.engines.load(flow, store={ - 'context': { - "joe_number": 444, - "jim_number": 555, - } -}) +engine = taskflow.engines.load( + flow, + store={ + 'context': { + "joe_number": 444, + "jim_number": 555, + } + }, +) # This is where we attach our callback functions to the 2 different # notification objects that an engine exposes. 
The usage of a ANY (kleene star) diff --git a/taskflow/examples/simple_linear_pass.py b/taskflow/examples/simple_linear_pass.py index 68aa9ecde..91b9375f2 100644 --- a/taskflow/examples/simple_linear_pass.py +++ b/taskflow/examples/simple_linear_pass.py @@ -19,9 +19,9 @@ import sys logging.basicConfig(level=logging.ERROR) self_dir = os.path.abspath(os.path.dirname(__file__)) -top_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), - os.pardir, - os.pardir)) +top_dir = os.path.abspath( + os.path.join(os.path.dirname(__file__), os.pardir, os.pardir) +) sys.path.insert(0, top_dir) sys.path.insert(0, self_dir) diff --git a/taskflow/examples/simple_map_reduce.py b/taskflow/examples/simple_map_reduce.py index 80a1e8c9e..6fd73c00a 100644 --- a/taskflow/examples/simple_map_reduce.py +++ b/taskflow/examples/simple_map_reduce.py @@ -19,9 +19,9 @@ import sys logging.basicConfig(level=logging.ERROR) self_dir = os.path.abspath(os.path.dirname(__file__)) -top_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), - os.pardir, - os.pardir)) +top_dir = os.path.abspath( + os.path.join(os.path.dirname(__file__), os.pardir, os.pardir) +) sys.path.insert(0, top_dir) sys.path.insert(0, self_dir) @@ -47,7 +47,7 @@ class TotalReducer(task.Task): def execute(self, *args, **kwargs): # Reduces all mapped summed outputs into a single value. total = 0 - for (k, v) in kwargs.items(): + for k, v in kwargs.items(): # If any other kwargs was passed in, we don't want to use those # in the calculation of the total... if k.startswith('reduction_'): @@ -88,9 +88,13 @@ for i, chunk in enumerate(chunk_iter(CHUNK_SIZE, UPPER_BOUND)): # The reducer uses all of the outputs of the mappers, so it needs # to be recorded that it needs access to them (under a specific name). 
provided.append("reduction_%s" % i) - mappers.add(SumMapper(name=mapper_name, - rebind={'inputs': mapper_name}, - provides=provided[-1])) + mappers.add( + SumMapper( + name=mapper_name, + rebind={'inputs': mapper_name}, + provides=provided[-1], + ) + ) w.add(mappers) # The reducer will run last (after all the mappers). diff --git a/taskflow/examples/switch_graph_flow.py b/taskflow/examples/switch_graph_flow.py index fb4c19313..5a6c7a99c 100644 --- a/taskflow/examples/switch_graph_flow.py +++ b/taskflow/examples/switch_graph_flow.py @@ -18,9 +18,9 @@ import sys logging.basicConfig(level=logging.ERROR) -top_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), - os.pardir, - os.pardir)) +top_dir = os.path.abspath( + os.path.join(os.path.dirname(__file__), os.pardir, os.pardir) +) sys.path.insert(0, top_dir) from taskflow import engines @@ -56,8 +56,10 @@ print("---------") print("After run") print("---------") backend = e.storage.backend -entries = [os.path.join(backend.memory.root_path, child) - for child in backend.memory.ls(backend.memory.root_path)] +entries = [ + os.path.join(backend.memory.root_path, child) + for child in backend.memory.ls(backend.memory.root_path) +] while entries: path = entries.pop() value = backend.memory[path] @@ -65,5 +67,6 @@ while entries: print(f"{path} -> {value}") else: print("%s" % (path)) - entries.extend(os.path.join(path, child) - for child in backend.memory.ls(path)) + entries.extend( + os.path.join(path, child) for child in backend.memory.ls(path) + ) diff --git a/taskflow/examples/timing_listener.py b/taskflow/examples/timing_listener.py index 902a4a1fc..d9e67133c 100644 --- a/taskflow/examples/timing_listener.py +++ b/taskflow/examples/timing_listener.py @@ -20,9 +20,9 @@ import time logging.basicConfig(level=logging.ERROR) -top_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), - os.pardir, - os.pardir)) +top_dir = os.path.abspath( + os.path.join(os.path.dirname(__file__), os.pardir, os.pardir) +) 
sys.path.insert(0, top_dir) from taskflow import engines diff --git a/taskflow/examples/tox_conductor.py b/taskflow/examples/tox_conductor.py index 159ee8b12..6b9d6fc92 100644 --- a/taskflow/examples/tox_conductor.py +++ b/taskflow/examples/tox_conductor.py @@ -25,9 +25,9 @@ import time logging.basicConfig(level=logging.ERROR) -top_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), - os.pardir, - os.pardir)) +top_dir = os.path.abspath( + os.path.join(os.path.dirname(__file__), os.pardir, os.pardir) +) sys.path.insert(0, top_dir) from oslo_utils import timeutils @@ -128,7 +128,7 @@ def create_review_workflow(): f.add( MakeTempDir(name="maker"), RunReview(name="runner"), - CleanResources(name="cleaner") + CleanResources(name="cleaner"), ) return f @@ -137,14 +137,16 @@ def generate_reviewer(client, saver, name=NAME): """Creates a review producer thread with the given name prefix.""" real_name = "%s_reviewer" % name no_more = threading.Event() - jb = boards.fetch(real_name, JOBBOARD_CONF, - client=client, persistence=saver) + jb = boards.fetch( + real_name, JOBBOARD_CONF, client=client, persistence=saver + ) def make_save_book(saver, review_id): # Record what we want to happen (sometime in the future). book = models.LogBook("book_%s" % review_id) - detail = models.FlowDetail("flow_%s" % review_id, - uuidutils.generate_uuid()) + detail = models.FlowDetail( + "flow_%s" % review_id, uuidutils.generate_uuid() + ) book.add(detail) # Associate the factory method we want to be called (in the future) # with the book, so that the conductor will be able to call into @@ -157,8 +159,9 @@ def generate_reviewer(client, saver, name=NAME): # workflow that represents this review). 
factory_args = () factory_kwargs = {} - engines.save_factory_details(detail, create_review_workflow, - factory_args, factory_kwargs) + engines.save_factory_details( + detail, create_review_workflow, factory_args, factory_kwargs + ) with contextlib.closing(saver.get_connection()) as conn: conn.save_logbook(book) return book @@ -177,9 +180,11 @@ def generate_reviewer(client, saver, name=NAME): } job_name = "{}_{}".format(real_name, review['id']) print("Posting review '%s'" % review['id']) - jb.post(job_name, - book=make_save_book(saver, review['id']), - details=details) + jb.post( + job_name, + book=make_save_book(saver, review['id']), + details=details, + ) time.sleep(REVIEW_CREATION_DELAY) # Return the unstarted thread, and a callback that can be used @@ -190,10 +195,10 @@ def generate_reviewer(client, saver, name=NAME): def generate_conductor(client, saver, name=NAME): """Creates a conductor thread with the given name prefix.""" real_name = "%s_conductor" % name - jb = boards.fetch(name, JOBBOARD_CONF, - client=client, persistence=saver) - conductor = conductors.fetch("blocking", real_name, jb, - engine='parallel', wait_timeout=SCAN_DELAY) + jb = boards.fetch(name, JOBBOARD_CONF, client=client, persistence=saver) + conductor = conductors.fetch( + "blocking", real_name, jb, engine='parallel', wait_timeout=SCAN_DELAY + ) def run(): jb.connect() diff --git a/taskflow/examples/wbe_event_sender.py b/taskflow/examples/wbe_event_sender.py index 35ebf7138..3b311fab3 100644 --- a/taskflow/examples/wbe_event_sender.py +++ b/taskflow/examples/wbe_event_sender.py @@ -18,9 +18,9 @@ import string import sys import time -top_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), - os.pardir, - os.pardir)) +top_dir = os.path.abspath( + os.path.join(os.path.dirname(__file__), os.pardir, os.pardir) +) sys.path.insert(0, top_dir) from taskflow import engines diff --git a/taskflow/examples/wbe_mandelbrot.py b/taskflow/examples/wbe_mandelbrot.py index 205997ce2..eccc80f58 
100644 --- a/taskflow/examples/wbe_mandelbrot.py +++ b/taskflow/examples/wbe_mandelbrot.py @@ -17,9 +17,9 @@ import math import os import sys -top_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), - os.pardir, - os.pardir)) +top_dir = os.path.abspath( + os.path.join(os.path.dirname(__file__), os.pardir, os.pardir) +) sys.path.insert(0, top_dir) from taskflow import engines @@ -118,7 +118,7 @@ def calculate(engine_conf): 'mandelbrot_config': [-2.0, 1.0, -1.0, 1.0, MAX_ITERATIONS], 'image_config': { 'size': IMAGE_SIZE, - } + }, } # We need the task names to be in the right order so that we can extract @@ -135,13 +135,16 @@ def calculate(engine_conf): # Break the calculation up into chunk size pieces. rows = [i * chunk_size, i * chunk_size + chunk_size] flow.add( - MandelCalculator(task_name, - # This ensures the storage symbol with name - # 'chunk_name' is sent into the tasks local - # symbol 'chunk'. This is how we give each - # calculator its own correct sequence of rows - # to work on. - rebind={'chunk': chunk_name})) + MandelCalculator( + task_name, + # This ensures the storage symbol with name + # 'chunk_name' is sent into the tasks local + # symbol 'chunk'. This is how we give each + # calculator its own correct sequence of rows + # to work on. + rebind={'chunk': chunk_name}, + ) + ) store[chunk_name] = rows task_names.append(task_name) @@ -161,9 +164,11 @@ def calculate(engine_conf): def write_image(results, output_filename=None): - print("Gathered %s results that represents a mandelbrot" - " image (using %s chunks that are computed jointly" - " by %s workers)." % (len(results), CHUNK_COUNT, WORKERS)) + print( + "Gathered %s results that represents a mandelbrot" + " image (using %s chunks that are computed jointly" + " by %s workers)." 
% (len(results), CHUNK_COUNT, WORKERS) + ) if not output_filename: return @@ -198,12 +203,14 @@ def create_fractal(): # Setup our transport configuration and merge it into the worker and # engine configuration so that both of those use it correctly. shared_conf = dict(BASE_SHARED_CONF) - shared_conf.update({ - 'transport': 'memory', - 'transport_options': { - 'polling_interval': 0.1, - }, - }) + shared_conf.update( + { + 'transport': 'memory', + 'transport_options': { + 'polling_interval': 0.1, + }, + } + ) if len(sys.argv) >= 2: output_filename = sys.argv[1] diff --git a/taskflow/examples/wbe_simple_linear.py b/taskflow/examples/wbe_simple_linear.py index 3a0b7b098..f3a109aec 100644 --- a/taskflow/examples/wbe_simple_linear.py +++ b/taskflow/examples/wbe_simple_linear.py @@ -18,9 +18,9 @@ import os import sys import tempfile -top_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), - os.pardir, - os.pardir)) +top_dir = os.path.abspath( + os.path.join(os.path.dirname(__file__), os.pardir, os.pardir) +) sys.path.insert(0, top_dir) from taskflow import engines @@ -64,7 +64,7 @@ WORKER_CONF = { # not want to allow all python code to be executed). 
'tasks': [ 'taskflow.tests.utils:TaskOneArgOneReturn', - 'taskflow.tests.utils:TaskMultiArgOneReturn' + 'taskflow.tests.utils:TaskMultiArgOneReturn', ], } @@ -72,11 +72,14 @@ WORKER_CONF = { def run(engine_options): flow = lf.Flow('simple-linear').add( utils.TaskOneArgOneReturn(provides='result1'), - utils.TaskMultiArgOneReturn(provides='result2') + utils.TaskMultiArgOneReturn(provides='result2'), + ) + eng = engines.load( + flow, + store=dict(x=111, y=222, z=333), + engine='worker-based', + **engine_options, ) - eng = engines.load(flow, - store=dict(x=111, y=222, z=333), - engine='worker-based', **engine_options) eng.run() return eng.storage.fetch_all() @@ -92,22 +95,26 @@ if __name__ == "__main__": if USE_FILESYSTEM: worker_count = FILE_WORKERS tmp_path = tempfile.mkdtemp(prefix='wbe-example-') - shared_conf.update({ - 'transport': 'filesystem', - 'transport_options': { - 'data_folder_in': tmp_path, - 'data_folder_out': tmp_path, - 'polling_interval': 0.1, - }, - }) + shared_conf.update( + { + 'transport': 'filesystem', + 'transport_options': { + 'data_folder_in': tmp_path, + 'data_folder_out': tmp_path, + 'polling_interval': 0.1, + }, + } + ) else: worker_count = MEMORY_WORKERS - shared_conf.update({ - 'transport': 'memory', - 'transport_options': { - 'polling_interval': 0.1, - }, - }) + shared_conf.update( + { + 'transport': 'memory', + 'transport_options': { + 'polling_interval': 0.1, + }, + } + ) worker_conf = dict(WORKER_CONF) worker_conf.update(shared_conf) engine_options = dict(shared_conf) diff --git a/taskflow/examples/wrapped_exception.py b/taskflow/examples/wrapped_exception.py index 472d0d6d8..e51d8ddb5 100644 --- a/taskflow/examples/wrapped_exception.py +++ b/taskflow/examples/wrapped_exception.py @@ -20,9 +20,9 @@ import time logging.basicConfig(level=logging.ERROR) -top_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), - os.pardir, - os.pardir)) +top_dir = os.path.abspath( + os.path.join(os.path.dirname(__file__), os.pardir, os.pardir) 
+) sys.path.insert(0, top_dir) @@ -84,14 +84,10 @@ def run(**store): # here and based on those kwargs it will behave in a different manner # while executing; this allows for the calling code (see below) to show # different usages of the failure catching and handling mechanism. - flow = uf.Flow('flow').add( - FirstTask(), - SecondTask() - ) + flow = uf.Flow('flow').add(FirstTask(), SecondTask()) try: with utils.wrap_all_failures(): - taskflow.engines.run(flow, store=store, - engine='parallel') + taskflow.engines.run(flow, store=store, engine='parallel') except exceptions.WrappedFailure as ex: unknown_failures = [] for a_failure in ex: @@ -106,20 +102,17 @@ def run(**store): eu.print_wrapped("Raise and catch first exception only") -run(sleep1=0.0, raise1=True, - sleep2=0.0, raise2=False) +run(sleep1=0.0, raise1=True, sleep2=0.0, raise2=False) # NOTE(imelnikov): in general, sleeping does not guarantee that we'll have both # task running before one of them fails, but with current implementation this # works most of times, which is enough for our purposes here (as an example). eu.print_wrapped("Raise and catch both exceptions") -run(sleep1=1.0, raise1=True, - sleep2=1.0, raise2=True) +run(sleep1=1.0, raise1=True, sleep2=1.0, raise2=True) eu.print_wrapped("Handle one exception, and re-raise another") try: - run(sleep1=1.0, raise1=True, - sleep2=1.0, raise2='boom') + run(sleep1=1.0, raise1=True, sleep2=1.0, raise2='boom') except TypeError as ex: print("As expected, TypeError is here: %s" % ex) else: diff --git a/taskflow/exceptions.py b/taskflow/exceptions.py index 29b01d5d2..ba9f7ed43 100644 --- a/taskflow/exceptions.py +++ b/taskflow/exceptions.py @@ -64,6 +64,7 @@ class TaskFlowException(Exception): creating a chain of exceptions for versions of python where this is not yet implemented/supported natively. 
""" + def __init__(self, message, cause=None): super().__init__(message) self._cause = cause @@ -84,8 +85,10 @@ class TaskFlowException(Exception): def pformat(self, indent=2, indent_text=" ", show_root_class=False): """Pretty formats a taskflow exception + any connected causes.""" if indent < 0: - raise ValueError("Provided 'indent' must be greater than" - " or equal to zero instead of %s" % indent) + raise ValueError( + "Provided 'indent' must be greater than" + " or equal to zero instead of %s" % indent + ) buf = io.StringIO() if show_root_class: buf.write(reflection.get_class_name(self, fully_qualified=False)) @@ -99,8 +102,9 @@ class TaskFlowException(Exception): buf.write(os.linesep) if isinstance(next_up, TaskFlowException): buf.write(indent_text * active_indent) - buf.write(reflection.get_class_name(next_up, - fully_qualified=False)) + buf.write( + reflection.get_class_name(next_up, fully_qualified=False) + ) buf.write(": ") buf.write(next_up._get_message()) else: @@ -125,18 +129,21 @@ class TaskFlowException(Exception): # Errors related to storage or operations on storage units. + class StorageFailure(TaskFlowException): """Raised when storage backends can not be read/saved/deleted.""" # Conductor related errors. + class ConductorFailure(TaskFlowException): """Errors related to conducting activities.""" # Job related errors. + class JobFailure(TaskFlowException): """Errors related to jobs or operations on jobs.""" @@ -147,6 +154,7 @@ class UnclaimableJob(JobFailure): # Engine/ during execution related errors. + class ExecutionFailure(TaskFlowException): """Errors related to engine execution.""" @@ -181,8 +189,10 @@ class MissingDependencies(DependencyFailure): """ #: Exception message template used when creating an actual message. 
- MESSAGE_TPL = ("'%(who)s' requires %(requirements)s but no other entity" - " produces said requirements") + MESSAGE_TPL = ( + "'%(who)s' requires %(requirements)s but no other entity" + " produces said requirements" + ) METHOD_TPL = "'%(method)s' method on " @@ -232,6 +242,7 @@ class DisallowedAccess(TaskFlowException): # Others. + class NotImplementedError(NotImplementedError): """Exception for when some functionality really isn't implemented. diff --git a/taskflow/formatters.py b/taskflow/formatters.py index b9add1cbf..ff2d101e6 100644 --- a/taskflow/formatters.py +++ b/taskflow/formatters.py @@ -63,8 +63,13 @@ class FailureFormatter: states.EXECUTE: (_fetch_predecessor_tree, 'predecessors'), } - def __init__(self, engine, hide_inputs_outputs_of=(), - mask_inputs_keys=(), mask_outputs_keys=()): + def __init__( + self, + engine, + hide_inputs_outputs_of=(), + mask_inputs_keys=(), + mask_outputs_keys=(), + ): self._hide_inputs_outputs_of = hide_inputs_outputs_of self._mask_inputs_keys = mask_inputs_keys self._mask_outputs_keys = mask_outputs_keys @@ -95,13 +100,17 @@ class FailureFormatter: atom_name = atom.name atom_attrs = {} intention, intention_found = _cached_get( - cache, 'intentions', atom_name, storage.get_atom_intention, - atom_name) + cache, + 'intentions', + atom_name, + storage.get_atom_intention, + atom_name, + ) if intention_found: atom_attrs['intention'] = intention - state, state_found = _cached_get(cache, 'states', atom_name, - storage.get_atom_state, - atom_name) + state, state_found = _cached_get( + cache, 'states', atom_name, storage.get_atom_state, atom_name + ) if state_found: atom_attrs['state'] = state if atom_name not in self._hide_inputs_outputs_of: @@ -109,27 +118,38 @@ class FailureFormatter: # will be called with the rest of these arguments # used to populate the cache. 
fetch_mapped_args = functools.partial( - storage.fetch_mapped_args, atom.rebind, - atom_name=atom_name, optional_args=atom.optional) - requires, requires_found = _cached_get(cache, 'requires', - atom_name, - fetch_mapped_args) + storage.fetch_mapped_args, + atom.rebind, + atom_name=atom_name, + optional_args=atom.optional, + ) + requires, requires_found = _cached_get( + cache, 'requires', atom_name, fetch_mapped_args + ) if requires_found: atom_attrs['requires'] = self._mask_keys( - requires, self._mask_inputs_keys) + requires, self._mask_inputs_keys + ) provides, provides_found = _cached_get( - cache, 'provides', atom_name, - storage.get_execute_result, atom_name) + cache, + 'provides', + atom_name, + storage.get_execute_result, + atom_name, + ) if provides_found: atom_attrs['provides'] = self._mask_keys( - provides, self._mask_outputs_keys) + provides, self._mask_outputs_keys + ) if atom_attrs: return f"Atom '{atom_name}' {atom_attrs}" else: return "Atom '%s'" % (atom_name) else: - raise TypeError("Unable to format node, unknown node" - " kind '%s' encountered" % node.metadata['kind']) + raise TypeError( + "Unable to format node, unknown node" + " kind '%s' encountered" % node.metadata['kind'] + ) def format(self, fail, atom_matcher): """Returns a (exc_info, details) tuple about the failure. 
@@ -173,15 +193,20 @@ class FailureFormatter: builder, kind = self._BUILDERS[atom_intention] rooted_tree = builder(graph, atom) child_count = rooted_tree.child_count(only_direct=False) - buff.write_nl( - f'{child_count} {kind} (most recent first):') + buff.write_nl(f'{child_count} {kind} (most recent first):') formatter = functools.partial(self._format_node, storage, cache) direct_child_count = rooted_tree.child_count(only_direct=True) for i, child in enumerate(rooted_tree, 1): if i == direct_child_count: - buff.write(child.pformat(stringify_node=formatter, - starting_prefix=" ")) + buff.write( + child.pformat( + stringify_node=formatter, starting_prefix=" " + ) + ) else: - buff.write_nl(child.pformat(stringify_node=formatter, - starting_prefix=" ")) + buff.write_nl( + child.pformat( + stringify_node=formatter, starting_prefix=" " + ) + ) return (fail.exc_info, buff.getvalue()) diff --git a/taskflow/jobs/backends/__init__.py b/taskflow/jobs/backends/__init__.py index 5afede756..ba7fec322 100644 --- a/taskflow/jobs/backends/__init__.py +++ b/taskflow/jobs/backends/__init__.py @@ -51,10 +51,13 @@ def fetch(name, conf, namespace=BACKEND_NAMESPACE, **kwargs): board, conf = misc.extract_driver_and_conf(conf, 'board') LOG.debug('Looking for %r jobboard driver in %r', board, namespace) try: - mgr = driver.DriverManager(namespace, board, - invoke_on_load=True, - invoke_args=(name, conf), - invoke_kwds=kwargs) + mgr = driver.DriverManager( + namespace, + board, + invoke_on_load=True, + invoke_args=(name, conf), + invoke_kwds=kwargs, + ) return mgr.driver except RuntimeError as e: raise exc.NotFound("Could not find jobboard %s" % (board), e) diff --git a/taskflow/jobs/backends/impl_etcd.py b/taskflow/jobs/backends/impl_etcd.py index 2a0e500d4..dd367fe73 100644 --- a/taskflow/jobs/backends/impl_etcd.py +++ b/taskflow/jobs/backends/impl_etcd.py @@ -26,6 +26,7 @@ from taskflow.jobs import base from taskflow import logging from taskflow import states from taskflow.utils import 
misc + if typing.TYPE_CHECKING: from taskflow.types import entity @@ -37,13 +38,30 @@ class EtcdJob(base.Job): board: 'EtcdJobBoard' - def __init__(self, board: 'EtcdJobBoard', name, client, key, - uuid=None, details=None, backend=None, - book=None, book_data=None, - priority=base.JobPriority.NORMAL, - sequence=None, created_on=None): - super().__init__(board, name, uuid=uuid, details=details, - backend=backend, book=book, book_data=book_data) + def __init__( + self, + board: 'EtcdJobBoard', + name, + client, + key, + uuid=None, + details=None, + backend=None, + book=None, + book_data=None, + priority=base.JobPriority.NORMAL, + sequence=None, + created_on=None, + ): + super().__init__( + board, + name, + uuid=uuid, + details=details, + backend=backend, + book=book, + book_data=book_data, + ) self._client = client self._key = key @@ -79,8 +97,11 @@ class EtcdJob(base.Job): owner, data = self.board.get_owner_and_data(self) if not data: if owner is not None: - LOG.info(f"Owner key was found for job {self.uuid}, " - f"but the key {self.key} is missing") + LOG.info( + "Owner key was found for job %s but the key %s is missing", + self.uuid, + self.key, + ) return states.COMPLETE if not owner: return states.UNCLAIMED @@ -101,8 +122,7 @@ class EtcdJob(base.Job): if 'lease_id' not in owner_data: return None lease_id = owner_data['lease_id'] - self._lease = etcd3gw.Lease(id=lease_id, - client=self._client) + self._lease = etcd3gw.Lease(id=lease_id, client=self._client) return self._lease def expires_in(self): @@ -120,7 +140,7 @@ class EtcdJob(base.Job): if self.lease is None: return False ret = self.lease.refresh() - return (ret > 0) + return ret > 0 @property def root(self): @@ -134,7 +154,8 @@ class EtcdJob(base.Job): return self.sequence < other.sequence else: ordered = base.JobPriority.reorder( - (self.priority, self), (other.priority, other)) + (self.priority, self), (other.priority, other) + ) if ordered[0] is self: return False return True @@ -145,8 +166,11 @@ class 
EtcdJob(base.Job): def __eq__(self, other): if not isinstance(other, EtcdJob): return NotImplemented - return ((self.root, self.sequence, self.priority) == - (other.root, other.sequence, other.priority)) + return (self.root, self.sequence, self.priority) == ( + other.root, + other.sequence, + other.priority, + ) def __ne__(self, other): return not self.__eq__(other) @@ -184,6 +208,7 @@ class EtcdJobBoard(base.JobBoard): .. _etcd: https://etcd.io/ """ + ROOT_PATH = "/taskflow/jobs" TRASH_PATH = "/taskflow/.trash" @@ -224,8 +249,10 @@ class EtcdJobBoard(base.JobBoard): self._persistence = persistence self._state = self.INIT_STATE - path_elems = [self.ROOT_PATH, - self._conf.get("path", self.DEFAULT_PATH)] + path_elems = [ + self.ROOT_PATH, + self._conf.get("path", self.DEFAULT_PATH), + ] self._root_path = self._create_path(*path_elems) self._job_cache = {} @@ -301,8 +328,7 @@ class EtcdJobBoard(base.JobBoard): try: job_data = jsonutils.loads(data) except jsonutils.json.JSONDecodeError: - msg = ("Incorrectly formatted job data found at " - f"key: {key}") + msg = f"Incorrectly formatted job data found at key: {key}" LOG.warning(msg, exc_info=True) LOG.info("Deleting invalid job data at key: %s", key) self._client.delete(key) @@ -311,16 +337,18 @@ class EtcdJobBoard(base.JobBoard): with self._job_cond: if key not in self._job_cache: job_priority = base.JobPriority.convert(job_data["priority"]) - new_job = EtcdJob(self, - job_data["name"], - self._client, - key, - uuid=job_data["uuid"], - details=job_data.get("details", {}), - backend=self._persistence, - book_data=job_data.get("book"), - priority=job_priority, - sequence=job_data["sequence"]) + new_job = EtcdJob( + self, + job_data["name"], + self._client, + key, + uuid=job_data["uuid"], + details=job_data.get("details", {}), + backend=self._persistence, + book_data=job_data.get("book"), + priority=job_priority, + sequence=job_data["sequence"], + ) self._job_cache[key] = new_job self._job_cond.notify_all() @@ -335,15 
+363,18 @@ class EtcdJobBoard(base.JobBoard): self._remove_job_from_cache(job.key) self._client.delete_prefix(job.key) except Exception: - LOG.exception(f"Failed to delete prefix {job.key}") + LOG.exception("Failed to delete prefix %s", job.key) def iterjobs(self, only_unclaimed=False, ensure_fresh=False): """Returns an iterator of jobs that are currently on this board.""" return base.JobBoardIterator( - self, LOG, only_unclaimed=only_unclaimed, + self, + LOG, + only_unclaimed=only_unclaimed, ensure_fresh=ensure_fresh, board_fetch_func=self._fetch_jobs, - board_removal_func=self._board_removal_func) + board_removal_func=self._board_removal_func, + ) def wait(self, timeout=None): """Waits a given amount of time for **any** jobs to be posted.""" @@ -354,9 +385,10 @@ class EtcdJobBoard(base.JobBoard): while True: if not self._job_cache: if watch.expired(): - raise exc.NotFound("Expired waiting for jobs to" - " arrive; waited %s seconds" - % watch.elapsed()) + raise exc.NotFound( + "Expired waiting for jobs to" + " arrive; waited %s seconds" % watch.elapsed() + ) # This is done since the given timeout can not be provided # to the condition variable, since we can not ensure that # when we acquire the condition that there will actually @@ -367,10 +399,14 @@ class EtcdJobBoard(base.JobBoard): curr_jobs = self._fetch_jobs() fetch_func = lambda ensure_fresh: curr_jobs removal_func = lambda a_job: self._remove_job_from_cache( - a_job.key) + a_job.key + ) return base.JobBoardIterator( - self, LOG, board_fetch_func=fetch_func, - board_removal_func=removal_func) + self, + LOG, + board_fetch_func=fetch_func, + board_removal_func=removal_func, + ) @property def job_count(self): @@ -395,11 +431,11 @@ class EtcdJobBoard(base.JobBoard): key = job.key + self.DATA_POSTFIX return self.get_one(key) - def get_owner_and_data(self, job: EtcdJob) -> tuple[ - str | None, bytes | None]: + def get_owner_and_data( + self, job: EtcdJob + ) -> tuple[str | None, bytes | None]: if self._client is 
None: - raise exc.JobFailure("Cannot retrieve information, " - "not connected") + raise exc.JobFailure("Cannot retrieve information, not connected") job_data = None job_owner = None @@ -426,15 +462,20 @@ class EtcdJobBoard(base.JobBoard): return self.get_one(key) - def post(self, name, book=None, details=None, - priority=base.JobPriority.NORMAL) -> EtcdJob: + def post( + self, name, book=None, details=None, priority=base.JobPriority.NORMAL + ) -> EtcdJob: """Atomically creates and posts a job to the jobboard.""" job_priority = base.JobPriority.convert(priority) job_uuid = uuidutils.generate_uuid() - job_posting = base.format_posting(job_uuid, name, - created_on=timeutils.utcnow(), - book=book, details=details, - priority=job_priority) + job_posting = base.format_posting( + job_uuid, + name, + created_on=timeutils.utcnow(), + book=book, + details=details, + priority=job_priority, + ) seq = self.incr(self._create_path(self._root_path, self.SEQUENCE_KEY)) key = self._create_path(self._root_path, f"{self.JOB_PREFIX}{seq}") @@ -444,14 +485,19 @@ class EtcdJobBoard(base.JobBoard): data_key = key + self.DATA_POSTFIX self._client.create(data_key, raw_job_posting) - job = EtcdJob(self, name, self._client, key, - uuid=job_uuid, - details=details, - backend=self._persistence, - book=book, - book_data=job_posting.get('book'), - priority=job_priority, - sequence=seq) + job = EtcdJob( + self, + name, + self._client, + key, + uuid=job_uuid, + details=details, + backend=self._persistence, + book=book, + book_data=job_posting.get('book'), + priority=job_priority, + sequence=seq, + ) with self._job_cond: self._job_cache[key] = job self._job_cond.notify_all() @@ -511,8 +557,9 @@ class EtcdJobBoard(base.JobBoard): if data is None or owner is None: raise exc.NotFound(f"Cannot find job {job.uuid}") if owner != who: - raise exc.JobFailure(f"Cannot consume a job {job.uuid}" - f" which is not owned by {who}") + raise exc.JobFailure( + f"Cannot consume a job {job.uuid} which is not owned by 
{who}" + ) self._client.delete_prefix(job.key + ".") self._remove_job_from_cache(job.key) @@ -524,8 +571,9 @@ class EtcdJobBoard(base.JobBoard): if data is None or owner is None: raise exc.NotFound(f"Cannot find job {job.uuid}") if owner != who: - raise exc.JobFailure(f"Cannot abandon a job {job.uuid}" - f" which is not owned by {who}") + raise exc.JobFailure( + f"Cannot abandon a job {job.uuid} which is not owned by {who}" + ) owner_key = job.key + self.LOCK_POSTFIX self._client.delete(owner_key) @@ -537,8 +585,9 @@ class EtcdJobBoard(base.JobBoard): if data is None or owner is None: raise exc.NotFound(f"Cannot find job {job.uuid}") if owner != who: - raise exc.JobFailure(f"Cannot trash a job {job.uuid} " - f"which is not owned by {who}") + raise exc.JobFailure( + f"Cannot trash a job {job.uuid} which is not owned by {who}" + ) trash_key = job.key.replace(self.ROOT_PATH, self.TRASH_PATH) self._client.create(trash_key, data) @@ -570,11 +619,13 @@ class EtcdJobBoard(base.JobBoard): watch_url = self._create_path(self._root_path, self.JOB_PREFIX) self._thread_cancel = threading.Event() try: - (self._watcher, - self._watcher_cancel) = self._client.watch_prefix(watch_url) + (self._watcher, self._watcher_cancel) = ( + self._client.watch_prefix(watch_url) + ) except etcd3gw.exceptions.ConnectionFailedError: - exc.raise_with_cause(exc.JobFailure, - "Failed to connect to Etcd") + exc.raise_with_cause( + exc.JobFailure, "Failed to connect to Etcd" + ) self._watcher_thd = threading.Thread(target=self._watcher_thread) self._watcher_thd.start() diff --git a/taskflow/jobs/backends/impl_redis.py b/taskflow/jobs/backends/impl_redis.py index 617cad5b4..cbf00e2de 100644 --- a/taskflow/jobs/backends/impl_redis.py +++ b/taskflow/jobs/backends/impl_redis.py @@ -48,28 +48,43 @@ def _translate_failures(): except redis_exceptions.ConnectionError: exc.raise_with_cause(exc.JobFailure, "Failed to connect to redis") except redis_exceptions.TimeoutError: - exc.raise_with_cause(exc.JobFailure, 
- "Failed to communicate with redis, connection" - " timed out") + exc.raise_with_cause( + exc.JobFailure, + "Failed to communicate with redis, connection timed out", + ) except redis_exceptions.RedisError: - exc.raise_with_cause(exc.JobFailure, - "Failed to communicate with redis," - " internal error") + exc.raise_with_cause( + exc.JobFailure, "Failed to communicate with redis, internal error" + ) @functools.total_ordering class RedisJob(base.Job): """A redis job.""" - def __init__(self, board, name, sequence, key, - uuid=None, details=None, - created_on=None, backend=None, - book=None, book_data=None, - priority=base.JobPriority.NORMAL): - super().__init__(board, name, - uuid=uuid, details=details, - backend=backend, - book=book, book_data=book_data) + def __init__( + self, + board, + name, + sequence, + key, + uuid=None, + details=None, + created_on=None, + backend=None, + book=None, + book_data=None, + priority=base.JobPriority.NORMAL, + ): + super().__init__( + board, + name, + uuid=uuid, + details=details, + backend=backend, + book=book, + book_data=book_data, + ) self._created_on = created_on self._client = board._client self._redis_version = board._redis_version @@ -113,8 +128,11 @@ class RedisJob(base.Job): :attr:`.owner_key` expired at/before time of inquiry?). """ with _translate_failures(): - return ru.get_expiry(self._client, self._owner_key, - prior_version=self._redis_version) + return ru.get_expiry( + self._client, + self._owner_key, + prior_version=self._redis_version, + ) def extend_expiry(self, expiry): """Extends the owner key (aka the claim) expiry for this job. @@ -128,8 +146,12 @@ class RedisJob(base.Job): otherwise ``False``. 
""" with _translate_failures(): - return ru.apply_expiry(self._client, self._owner_key, expiry, - prior_version=self._redis_version) + return ru.apply_expiry( + self._client, + self._owner_key, + expiry, + prior_version=self._redis_version, + ) def __lt__(self, other): if not isinstance(other, RedisJob): @@ -139,7 +161,8 @@ class RedisJob(base.Job): return self.sequence < other.sequence else: ordered = base.JobPriority.reorder( - (self.priority, self), (other.priority, other)) + (self.priority, self), (other.priority, other) + ) if ordered[0] is self: return False return True @@ -150,8 +173,11 @@ class RedisJob(base.Job): def __eq__(self, other): if not isinstance(other, RedisJob): return NotImplemented - return ((self.board.listings_key, self.priority, self.sequence) == - (other.board.listings_key, other.priority, other.sequence)) + return (self.board.listings_key, self.priority, self.sequence) == ( + other.board.listings_key, + other.priority, + other.sequence, + ) def __ne__(self, other): return not self.__eq__(other) @@ -170,7 +196,8 @@ class RedisJob(base.Job): last_modified = None if raw_last_modified: last_modified = self._board._loads( - raw_last_modified, root_types=(datetime.datetime,)) + raw_last_modified, root_types=(datetime.datetime,) + ) # NOTE(harlowja): just incase this is somehow busted (due to time # sync issues/other), give back the most recent one (since redis # does not maintain clock information; we could have this happen @@ -199,9 +226,13 @@ class RedisJob(base.Job): # This should **not** be possible due to lua code ordering # but let's log an INFO statement if it does happen (so # that it can be investigated)... 
- LOG.info("Unexpected owner key found at '%s' when job" - " key '%s[%s]' was not found", owner_key, - listings_key, listings_sub_key) + LOG.info( + "Unexpected owner key found at '%s' when job" + " key '%s[%s]' was not found", + owner_key, + listings_key, + listings_sub_key, + ) return states.COMPLETE else: if owner_exists: @@ -210,9 +241,9 @@ class RedisJob(base.Job): return states.UNCLAIMED with _translate_failures(): - return self._client.transaction(_do_fetch, - listings_key, owner_key, - value_from_callable=True) + return self._client.transaction( + _do_fetch, listings_key, owner_key, value_from_callable=True + ) class RedisJobBoard(base.JobBoard): @@ -255,37 +286,33 @@ class RedisJobBoard(base.JobBoard): .. _hash: https://redis.io/topics/data-types#hashes """ - CLIENT_CONF_TRANSFERS = tuple([ - # Host config... - ('host', str), - ('port', int), - - # See: http://redis.io/commands/auth - ('username', str), - ('password', str), - - # Data encoding/decoding + error handling - ('encoding', str), - ('encoding_errors', str), - - # Connection settings. - ('socket_timeout', float), - ('socket_connect_timeout', float), - - # This one negates the usage of host, port, socket connection - # settings as it doesn't use the same kind of underlying socket... - ('unix_socket_path', str), - - # Do u want ssl??? - ('ssl', strutils.bool_from_string), - ('ssl_keyfile', str), - ('ssl_certfile', str), - ('ssl_cert_reqs', str), - ('ssl_ca_certs', str), - - # See: http://www.rediscookbook.org/multiple_databases.html - ('db', int), - ]) + CLIENT_CONF_TRANSFERS = tuple( + [ + # Host config... + ('host', str), + ('port', int), + # See: http://redis.io/commands/auth + ('username', str), + ('password', str), + # Data encoding/decoding + error handling + ('encoding', str), + ('encoding_errors', str), + # Connection settings. 
+ ('socket_timeout', float), + ('socket_connect_timeout', float), + # This one negates the usage of host, port, socket connection + # settings as it doesn't use the same kind of underlying socket... + ('unix_socket_path', str), + # Do u want ssl??? + ('ssl', strutils.bool_from_string), + ('ssl_keyfile', str), + ('ssl_certfile', str), + ('ssl_cert_reqs', str), + ('ssl_ca_certs', str), + # See: http://www.rediscookbook.org/multiple_databases.html + ('db', int), + ] + ) """ Keys (and value type converters) that we allow to proxy from the jobboard configuration into the redis client (used to configure the redis client @@ -566,8 +593,9 @@ return cmsgpack.pack(result) @classmethod def _filter_ssl_options(cls, opts): if not opts.get('ssl', False): - return {k: v for (k, v) in opts.items() - if not k.startswith('ssl_')} + return { + k: v for (k, v) in opts.items() if not k.startswith('ssl_') + } return opts @classmethod @@ -587,15 +615,14 @@ return cmsgpack.pack(result) sentinel_kwargs = conf.get('sentinel_kwargs') if sentinel_kwargs is not None: sentinel_kwargs = cls._filter_ssl_options(sentinel_kwargs) - s = sentinel.Sentinel(sentinels, - sentinel_kwargs=sentinel_kwargs, - **client_conf) + s = sentinel.Sentinel( + sentinels, sentinel_kwargs=sentinel_kwargs, **client_conf + ) return s.master_for(conf['sentinel']) else: return ru.RedisClient(**client_conf) - def __init__(self, name, conf, - client=None, persistence=None): + def __init__(self, name, conf, client=None, persistence=None): super().__init__(name, conf) self._closed = True if client is not None: @@ -682,25 +709,29 @@ return cmsgpack.pack(result) # op occurs). 
self._client.ping() is_new_enough, redis_version = ru.is_server_new_enough( - self._client, self.MIN_REDIS_VERSION) + self._client, self.MIN_REDIS_VERSION + ) if not is_new_enough: - wanted_version = ".".join([str(p) - for p in self.MIN_REDIS_VERSION]) + wanted_version = ".".join( + [str(p) for p in self.MIN_REDIS_VERSION] + ) if redis_version: - raise exc.JobFailure("Redis version %s or greater is" - " required (version %s is to" - " old)" % (wanted_version, - redis_version)) + raise exc.JobFailure( + "Redis version %s or greater is" + " required (version %s is to" + " old)" % (wanted_version, redis_version) + ) else: - raise exc.JobFailure("Redis version %s or greater is" - " required" % (wanted_version)) + raise exc.JobFailure( + "Redis version %s or greater is" + " required" % (wanted_version) + ) else: self._redis_version = redis_version script_params = { # Status field values. 'ok': self.SCRIPT_STATUS_OK, 'error': self.SCRIPT_STATUS_ERROR, - # Known error reasons (when status field is error). 
'not_expected_owner': self.SCRIPT_NOT_EXPECTED_OWNER, 'unknown_owner': self.SCRIPT_UNKNOWN_OWNER, @@ -729,18 +760,20 @@ return cmsgpack.pack(result) try: return msgpackutils.dumps(obj) except Exception: - exc.raise_with_cause(exc.JobFailure, - "Failed to serialize object to" - " msgpack blob") + exc.raise_with_cause( + exc.JobFailure, "Failed to serialize object to msgpack blob" + ) @staticmethod def _loads(blob, root_types=(dict,)): try: return misc.decode_msgpack(blob, root_types=root_types) except ValueError: - exc.raise_with_cause(exc.JobFailure, - "Failed to deserialize object from" - " msgpack blob (of length %s)" % len(blob)) + exc.raise_with_cause( + exc.JobFailure, + "Failed to deserialize object from" + " msgpack blob (of length %s)" % len(blob), + ) _decode_owner = staticmethod(misc.binary_decode) @@ -752,42 +785,66 @@ return cmsgpack.pack(result) raw_owner = self._client.get(owner_key) return self._decode_owner(raw_owner) - def post(self, name, book=None, details=None, - priority=base.JobPriority.NORMAL): + def post( + self, name, book=None, details=None, priority=base.JobPriority.NORMAL + ): job_uuid = uuidutils.generate_uuid() job_priority = base.JobPriority.convert(priority) - posting = base.format_posting(job_uuid, name, - created_on=timeutils.utcnow(), - book=book, details=details, - priority=job_priority) + posting = base.format_posting( + job_uuid, + name, + created_on=timeutils.utcnow(), + book=book, + details=details, + priority=job_priority, + ) with _translate_failures(): sequence = self._client.incr(self.sequence_key) - posting.update({ - 'sequence': sequence, - }) + posting.update( + { + 'sequence': sequence, + } + ) with _translate_failures(): raw_posting = self._dumps(posting) raw_job_uuid = job_uuid.encode('latin-1') - was_posted = bool(self._client.hsetnx(self.listings_key, - raw_job_uuid, raw_posting)) + was_posted = bool( + self._client.hsetnx( + self.listings_key, raw_job_uuid, raw_posting + ) + ) if not was_posted: - raise 
exc.JobFailure("New job located at '%s[%s]' could not" - " be posted" % (self.listings_key, - raw_job_uuid)) + raise exc.JobFailure( + "New job located at '%s[%s]' could not" + " be posted" % (self.listings_key, raw_job_uuid) + ) else: - return RedisJob(self, name, sequence, raw_job_uuid, - uuid=job_uuid, details=details, - created_on=posting['created_on'], - book=book, book_data=posting.get('book'), - backend=self._persistence, - priority=job_priority) + return RedisJob( + self, + name, + sequence, + raw_job_uuid, + uuid=job_uuid, + details=details, + created_on=posting['created_on'], + book=book, + book_data=posting.get('book'), + backend=self._persistence, + priority=job_priority, + ) - def wait(self, timeout=None, initial_delay=0.005, - max_delay=1.0, sleep_func=time.sleep): + def wait( + self, + timeout=None, + initial_delay=0.005, + max_delay=1.0, + sleep_func=time.sleep, + ): if initial_delay > max_delay: - raise ValueError("Initial delay %s must be less than or equal" - " to the provided max delay %s" - % (initial_delay, max_delay)) + raise ValueError( + "Initial delay %s must be less than or equal" + " to the provided max delay %s" % (initial_delay, max_delay) + ) # This does a spin-loop that backs off by doubling the delay # up to the provided max-delay. 
In the future we could try having # a secondary client connected into redis pubsub and use that @@ -801,12 +858,15 @@ return cmsgpack.pack(result) curr_jobs = self._fetch_jobs() if curr_jobs: return base.JobBoardIterator( - self, LOG, - board_fetch_func=lambda ensure_fresh: curr_jobs) + self, + LOG, + board_fetch_func=lambda ensure_fresh: curr_jobs, + ) if w.expired(): - raise exc.NotFound("Expired waiting for jobs to" - " arrive; waited %s seconds" - % w.elapsed()) + raise exc.NotFound( + "Expired waiting for jobs to" + " arrive; waited %s seconds" % w.elapsed() + ) else: remaining = w.leftover(return_none=True) if remaining is not None: @@ -834,27 +894,43 @@ return cmsgpack.pack(result) job_details = job_data.get('details', {}) except (ValueError, TypeError, KeyError, exc.JobFailure): with excutils.save_and_reraise_exception(): - LOG.warning("Incorrectly formatted job data found at" - " key: %s[%s]", self.listings_key, - raw_job_key, exc_info=True) - LOG.info("Deleting invalid job data at key: %s[%s]", - self.listings_key, raw_job_key) + LOG.warning( + "Incorrectly formatted job data found at key: %s[%s]", + self.listings_key, + raw_job_key, + exc_info=True, + ) + LOG.info( + "Deleting invalid job data at key: %s[%s]", + self.listings_key, + raw_job_key, + ) self._client.hdel(self.listings_key, raw_job_key) else: - postings.append(RedisJob(self, job_name, job_sequence_id, - raw_job_key, uuid=job_uuid, - details=job_details, - created_on=job_created_on, - book_data=job_data.get('book'), - backend=self._persistence, - priority=job_priority)) + postings.append( + RedisJob( + self, + job_name, + job_sequence_id, + raw_job_key, + uuid=job_uuid, + details=job_details, + created_on=job_created_on, + book_data=job_data.get('book'), + backend=self._persistence, + priority=job_priority, + ) + ) return sorted(postings, reverse=True) def iterjobs(self, only_unclaimed=False, ensure_fresh=False): return base.JobBoardIterator( - self, LOG, only_unclaimed=only_unclaimed, + self, 
+ LOG, + only_unclaimed=only_unclaimed, ensure_fresh=ensure_fresh, - board_fetch_func=lambda ensure_fresh: self._fetch_jobs()) + board_fetch_func=lambda ensure_fresh: self._fetch_jobs(), + ) def register_entity(self, entity): # Will implement a redis jobboard conductor register later @@ -865,36 +941,43 @@ return cmsgpack.pack(result) script = self._get_script('consume') with _translate_failures(): raw_who = self._encode_owner(who) - raw_result = script(keys=[job.owner_key, self.listings_key, - job.last_modified_key], - args=[raw_who, job.key]) + raw_result = script( + keys=[job.owner_key, self.listings_key, job.last_modified_key], + args=[raw_who, job.key], + ) result = self._loads(raw_result) status = result['status'] if status != self.SCRIPT_STATUS_OK: reason = result.get('reason') if reason == self.SCRIPT_UNKNOWN_JOB: - raise exc.NotFound("Job %s not found to be" - " consumed" % (job.uuid)) + raise exc.NotFound( + "Job %s not found to be consumed" % (job.uuid) + ) elif reason == self.SCRIPT_UNKNOWN_OWNER: - raise exc.NotFound("Can not consume job %s" - " which we can not determine" - " the owner of" % (job.uuid)) + raise exc.NotFound( + "Can not consume job %s" + " which we can not determine" + " the owner of" % (job.uuid) + ) elif reason == self.SCRIPT_NOT_EXPECTED_OWNER: raw_owner = result.get('owner') if raw_owner: owner = self._decode_owner(raw_owner) - raise exc.JobFailure("Can not consume job %s" - " which is not owned by %s (it is" - " actively owned by %s)" - % (job.uuid, who, owner)) + raise exc.JobFailure( + "Can not consume job %s" + " which is not owned by %s (it is" + " actively owned by %s)" % (job.uuid, who, owner) + ) else: - raise exc.JobFailure("Can not consume job %s" - " which is not owned by %s" - % (job.uuid, who)) + raise exc.JobFailure( + "Can not consume job %s" + " which is not owned by %s" % (job.uuid, who) + ) else: - raise exc.JobFailure("Failure to consume job %s," - " unknown internal error (reason=%s)" - % (job.uuid, reason)) + 
raise exc.JobFailure( + "Failure to consume job %s," + " unknown internal error (reason=%s)" % (job.uuid, reason) + ) @base.check_who def claim(self, job, who, expiry=None): @@ -906,122 +989,151 @@ return cmsgpack.pack(result) else: ms_expiry = int(expiry * 1000.0) if ms_expiry <= 0: - raise ValueError("Provided expiry (when converted to" - " milliseconds) must be greater" - " than zero instead of %s" % (expiry)) + raise ValueError( + "Provided expiry (when converted to" + " milliseconds) must be greater" + " than zero instead of %s" % (expiry) + ) script = self._get_script('claim') with _translate_failures(): raw_who = self._encode_owner(who) - raw_result = script(keys=[job.owner_key, self.listings_key, - job.last_modified_key], - args=[raw_who, job.key, - # NOTE(harlowja): we need to send this - # in as a blob (even if it's not - # set/used), since the format can not - # currently be created in lua... - self._dumps(timeutils.utcnow()), - ms_expiry]) + raw_result = script( + keys=[job.owner_key, self.listings_key, job.last_modified_key], + args=[ + raw_who, + job.key, + # NOTE(harlowja): we need to send this + # in as a blob (even if it's not + # set/used), since the format can not + # currently be created in lua... 
+ self._dumps(timeutils.utcnow()), + ms_expiry, + ], + ) result = self._loads(raw_result) status = result['status'] if status != self.SCRIPT_STATUS_OK: reason = result.get('reason') if reason == self.SCRIPT_UNKNOWN_JOB: - raise exc.NotFound("Job %s not found to be" - " claimed" % (job.uuid)) + raise exc.NotFound( + "Job %s not found to be claimed" % (job.uuid) + ) elif reason == self.SCRIPT_ALREADY_CLAIMED: raw_owner = result.get('owner') if raw_owner: owner = self._decode_owner(raw_owner) - raise exc.UnclaimableJob("Job %s already" - " claimed by %s" - % (job.uuid, owner)) + raise exc.UnclaimableJob( + "Job %s already claimed by %s" % (job.uuid, owner) + ) else: - raise exc.UnclaimableJob("Job %s already" - " claimed" % (job.uuid)) + raise exc.UnclaimableJob( + "Job %s already claimed" % (job.uuid) + ) else: - raise exc.JobFailure("Failure to claim job %s," - " unknown internal error (reason=%s)" - % (job.uuid, reason)) + raise exc.JobFailure( + "Failure to claim job %s," + " unknown internal error (reason=%s)" % (job.uuid, reason) + ) @base.check_who def abandon(self, job, who): script = self._get_script('abandon') with _translate_failures(): raw_who = self._encode_owner(who) - raw_result = script(keys=[job.owner_key, self.listings_key, - job.last_modified_key], - args=[raw_who, job.key, - self._dumps(timeutils.utcnow())]) + raw_result = script( + keys=[job.owner_key, self.listings_key, job.last_modified_key], + args=[raw_who, job.key, self._dumps(timeutils.utcnow())], + ) result = self._loads(raw_result) status = result.get('status') if status != self.SCRIPT_STATUS_OK: reason = result.get('reason') if reason == self.SCRIPT_UNKNOWN_JOB: - raise exc.NotFound("Job %s not found to be" - " abandoned" % (job.uuid)) + raise exc.NotFound( + "Job %s not found to be abandoned" % (job.uuid) + ) elif reason == self.SCRIPT_UNKNOWN_OWNER: - raise exc.NotFound("Can not abandon job %s" - " which we can not determine" - " the owner of" % (job.uuid)) + raise exc.NotFound( + "Can 
not abandon job %s" + " which we can not determine" + " the owner of" % (job.uuid) + ) elif reason == self.SCRIPT_NOT_EXPECTED_OWNER: raw_owner = result.get('owner') if raw_owner: owner = self._decode_owner(raw_owner) - raise exc.JobFailure("Can not abandon job %s" - " which is not owned by %s (it is" - " actively owned by %s)" - % (job.uuid, who, owner)) + raise exc.JobFailure( + "Can not abandon job %s" + " which is not owned by %s (it is" + " actively owned by %s)" % (job.uuid, who, owner) + ) else: - raise exc.JobFailure("Can not abandon job %s" - " which is not owned by %s" - % (job.uuid, who)) + raise exc.JobFailure( + "Can not abandon job %s" + " which is not owned by %s" % (job.uuid, who) + ) else: - raise exc.JobFailure("Failure to abandon job %s," - " unknown internal" - " error (status=%s, reason=%s)" - % (job.uuid, status, reason)) + raise exc.JobFailure( + "Failure to abandon job %s," + " unknown internal" + " error (status=%s, reason=%s)" + % (job.uuid, status, reason) + ) def _get_script(self, name): try: return self._scripts[name] except KeyError: - exc.raise_with_cause(exc.NotFound, - "Can not access %s script (has this" - " board been connected?)" % name) + exc.raise_with_cause( + exc.NotFound, + "Can not access %s script (has this" + " board been connected?)" % name, + ) @base.check_who def trash(self, job, who): script = self._get_script('trash') with _translate_failures(): raw_who = self._encode_owner(who) - raw_result = script(keys=[job.owner_key, self.listings_key, - job.last_modified_key, self.trash_key], - args=[raw_who, job.key, - self._dumps(timeutils.utcnow())]) + raw_result = script( + keys=[ + job.owner_key, + self.listings_key, + job.last_modified_key, + self.trash_key, + ], + args=[raw_who, job.key, self._dumps(timeutils.utcnow())], + ) result = self._loads(raw_result) status = result['status'] if status != self.SCRIPT_STATUS_OK: reason = result.get('reason') if reason == self.SCRIPT_UNKNOWN_JOB: - raise exc.NotFound("Job %s not 
found to be" - " trashed" % (job.uuid)) + raise exc.NotFound( + "Job %s not found to be trashed" % (job.uuid) + ) elif reason == self.SCRIPT_UNKNOWN_OWNER: - raise exc.NotFound("Can not trash job %s" - " which we can not determine" - " the owner of" % (job.uuid)) + raise exc.NotFound( + "Can not trash job %s" + " which we can not determine" + " the owner of" % (job.uuid) + ) elif reason == self.SCRIPT_NOT_EXPECTED_OWNER: raw_owner = result.get('owner') if raw_owner: owner = self._decode_owner(raw_owner) - raise exc.JobFailure("Can not trash job %s" - " which is not owned by %s (it is" - " actively owned by %s)" - % (job.uuid, who, owner)) + raise exc.JobFailure( + "Can not trash job %s" + " which is not owned by %s (it is" + " actively owned by %s)" % (job.uuid, who, owner) + ) else: - raise exc.JobFailure("Can not trash job %s" - " which is not owned by %s" - % (job.uuid, who)) + raise exc.JobFailure( + "Can not trash job %s" + " which is not owned by %s" % (job.uuid, who) + ) else: - raise exc.JobFailure("Failure to trash job %s," - " unknown internal error (reason=%s)" - % (job.uuid, reason)) + raise exc.JobFailure( + "Failure to trash job %s," + " unknown internal error (reason=%s)" % (job.uuid, reason) + ) diff --git a/taskflow/jobs/backends/impl_zookeeper.py b/taskflow/jobs/backends/impl_zookeeper.py index 582e186ab..7d5da1c5a 100644 --- a/taskflow/jobs/backends/impl_zookeeper.py +++ b/taskflow/jobs/backends/impl_zookeeper.py @@ -45,22 +45,37 @@ LOG = logging.getLogger(__name__) class ZookeeperJob(base.Job): """A zookeeper job.""" - def __init__(self, board, name, client, path, - uuid=None, details=None, book=None, book_data=None, - created_on=None, backend=None, - priority=base.JobPriority.NORMAL): - super().__init__(board, name, - uuid=uuid, details=details, - backend=backend, - book=book, book_data=book_data) + def __init__( + self, + board, + name, + client, + path, + uuid=None, + details=None, + book=None, + book_data=None, + created_on=None, + 
backend=None, + priority=base.JobPriority.NORMAL, + ): + super().__init__( + board, + name, + uuid=uuid, + details=details, + backend=backend, + book=book, + book_data=book_data, + ) self._client = client self._path = k_paths.normpath(path) self._lock_path = self._path + board.LOCK_POSTFIX self._created_on = created_on self._node_not_found = False basename = k_paths.basename(self._path) - self._root = self._path[0:-len(basename)] - self._sequence = int(basename[len(board.JOB_PREFIX):]) + self._root = self._path[0 : -len(basename)] + self._sequence = int(basename[len(board.JOB_PREFIX) :]) self._priority = priority @property @@ -99,23 +114,26 @@ class ZookeeperJob(base.Job): excp.raise_with_cause( excp.NotFound, "Can not fetch the %r attribute of job %s (%s)," - " path %s not found" % (attr_name, self.uuid, - self.path, path)) + " path %s not found" % (attr_name, self.uuid, self.path, path), + ) except self._client.handler.timeout_exception: excp.raise_with_cause( excp.JobFailure, "Can not fetch the %r attribute of job %s (%s)," - " operation timed out" % (attr_name, self.uuid, self.path)) + " operation timed out" % (attr_name, self.uuid, self.path), + ) except k_exceptions.SessionExpiredError: excp.raise_with_cause( excp.JobFailure, "Can not fetch the %r attribute of job %s (%s)," - " session expired" % (attr_name, self.uuid, self.path)) + " session expired" % (attr_name, self.uuid, self.path), + ) except (AttributeError, k_exceptions.KazooException): excp.raise_with_cause( excp.JobFailure, "Can not fetch the %r attribute of job %s (%s)," - " internal error" % (attr_name, self.uuid, self.path)) + " internal error" % (attr_name, self.uuid, self.path), + ) @property def last_modified(self): @@ -123,8 +141,8 @@ class ZookeeperJob(base.Job): try: if not self._node_not_found: modified_on = self._get_node_attr( - self.path, 'mtime', - trans_func=misc.millis_to_datetime) + self.path, 'mtime', trans_func=misc.millis_to_datetime + ) except excp.NotFound: self._node_not_found 
= True return modified_on @@ -137,8 +155,8 @@ class ZookeeperJob(base.Job): if self._created_on is None: try: self._created_on = self._get_node_attr( - self.path, 'ctime', - trans_func=misc.millis_to_datetime) + self.path, 'ctime', trans_func=misc.millis_to_datetime + ) except excp.NotFound: self._node_not_found = True return self._created_on @@ -155,18 +173,19 @@ class ZookeeperJob(base.Job): except k_exceptions.SessionExpiredError: excp.raise_with_cause( excp.JobFailure, - "Can not fetch the state of %s," - " session expired" % (self.uuid)) + "Can not fetch the state of %s, session expired" % (self.uuid), + ) except self._client.handler.timeout_exception: excp.raise_with_cause( excp.JobFailure, "Can not fetch the state of %s," - " operation timed out" % (self.uuid)) + " operation timed out" % (self.uuid), + ) except k_exceptions.KazooException: excp.raise_with_cause( excp.JobFailure, - "Can not fetch the state of %s," - " internal error" % (self.uuid)) + "Can not fetch the state of %s, internal error" % (self.uuid), + ) if not job_data: # No data this job has been completed (the owner that we might have # fetched will not be able to be fetched again, since the job node @@ -185,7 +204,8 @@ class ZookeeperJob(base.Job): return self.sequence < other.sequence else: ordered = base.JobPriority.reorder( - (self.priority, self), (other.priority, other)) + (self.priority, self), (other.priority, other) + ) if ordered[0] is self: return False return True @@ -196,8 +216,11 @@ class ZookeeperJob(base.Job): def __eq__(self, other): if not isinstance(other, ZookeeperJob): return NotImplemented - return ((self.root, self.sequence, self.priority) == - (other.root, other.sequence, other.priority)) + return (self.root, self.sequence, self.priority) == ( + other.root, + other.sequence, + other.priority, + ) def __ne__(self, other): return not self.__eq__(other) @@ -277,8 +300,14 @@ class ZookeeperJobBoard(base.NotifyingJobBoard): or may be recovered (aka, it has not full 
disconnected). """ - def __init__(self, name, conf, - client=None, persistence=None, emit_notifications=True): + def __init__( + self, + name, + conf, + client=None, + persistence=None, + emit_notifications=True, + ): super().__init__(name, conf) if client is not None: self._client = client @@ -292,11 +321,12 @@ class ZookeeperJobBoard(base.NotifyingJobBoard): if not k_paths.isabs(path): raise ValueError("Zookeeper path must be absolute") self._path = path - self._trash_path = self._path.replace(k_paths.basename(self._path), - self.TRASH_FOLDER) + self._trash_path = self._path.replace( + k_paths.basename(self._path), self.TRASH_FOLDER + ) self._entity_path = self._path.replace( - k_paths.basename(self._path), - self.ENTITY_FOLDER) + k_paths.basename(self._path), self.ENTITY_FOLDER + ) # The backend to load the full logbooks from, since what is sent over # the data connection is only the logbook uuid and name, and not the # full logbook. @@ -378,23 +408,30 @@ class ZookeeperJobBoard(base.NotifyingJobBoard): maybe_children = self._client.get_children(self.path) self._on_job_posting(maybe_children, delayed=False) except self._client.handler.timeout_exception: - excp.raise_with_cause(excp.JobFailure, - "Refreshing failure, operation timed out") + excp.raise_with_cause( + excp.JobFailure, "Refreshing failure, operation timed out" + ) except k_exceptions.SessionExpiredError: - excp.raise_with_cause(excp.JobFailure, - "Refreshing failure, session expired") + excp.raise_with_cause( + excp.JobFailure, "Refreshing failure, session expired" + ) except k_exceptions.NoNodeError: pass except k_exceptions.KazooException: - excp.raise_with_cause(excp.JobFailure, - "Refreshing failure, internal error") + excp.raise_with_cause( + excp.JobFailure, "Refreshing failure, internal error" + ) def iterjobs(self, only_unclaimed=False, ensure_fresh=False): board_removal_func = lambda job: self._remove_job(job.path) return base.JobBoardIterator( - self, LOG, only_unclaimed=only_unclaimed, - 
ensure_fresh=ensure_fresh, board_fetch_func=self._fetch_jobs, - board_removal_func=board_removal_func) + self, + LOG, + only_unclaimed=only_unclaimed, + ensure_fresh=ensure_fresh, + board_fetch_func=self._fetch_jobs, + board_removal_func=board_removal_func, + ) def _remove_job(self, path): if path not in self._known_jobs: @@ -424,38 +461,56 @@ class ZookeeperJobBoard(base.NotifyingJobBoard): job_name = job_data['name'] except (ValueError, TypeError, KeyError): with excutils.save_and_reraise_exception(reraise=not quiet): - LOG.warning("Incorrectly formatted job data found at path: %s", - path, exc_info=True) + LOG.warning( + "Incorrectly formatted job data found at path: %s", + path, + exc_info=True, + ) except self._client.handler.timeout_exception: with excutils.save_and_reraise_exception(reraise=not quiet): - LOG.warning("Operation timed out fetching job data from" - " from path: %s", - path, exc_info=True) + LOG.warning( + "Operation timed out fetching job data from from path: %s", + path, + exc_info=True, + ) except k_exceptions.SessionExpiredError: with excutils.save_and_reraise_exception(reraise=not quiet): - LOG.warning("Session expired fetching job data from path: %s", - path, exc_info=True) + LOG.warning( + "Session expired fetching job data from path: %s", + path, + exc_info=True, + ) except k_exceptions.NoNodeError: - LOG.debug("No job node found at path: %s, it must have" - " disappeared or was removed", path) + LOG.debug( + "No job node found at path: %s, it must have" + " disappeared or was removed", + path, + ) except k_exceptions.KazooException: with excutils.save_and_reraise_exception(reraise=not quiet): - LOG.warning("Internal error fetching job data from path: %s", - path, exc_info=True) + LOG.warning( + "Internal error fetching job data from path: %s", + path, + exc_info=True, + ) else: with self._job_cond: # Now we can officially check if someone already placed this # jobs information into the known job set (if it's already # existing then just 
leave it alone). if path not in self._known_jobs: - job = ZookeeperJob(self, job_name, - self._client, path, - backend=self._persistence, - uuid=job_uuid, - book_data=job_data.get("book"), - details=job_data.get("details", {}), - created_on=job_created_on, - priority=job_priority) + job = ZookeeperJob( + self, + job_name, + self._client, + path, + backend=self._persistence, + uuid=job_uuid, + book_data=job_data.get("book"), + details=job_data.get("details", {}), + created_on=job_created_on, + priority=job_priority, + ) self._known_jobs[path] = job self._job_cond.notify_all() if job is not None: @@ -465,8 +520,9 @@ class ZookeeperJobBoard(base.NotifyingJobBoard): LOG.debug("Got children %s under path %s", children, self.path) child_paths = [] for c in children: - if (c.endswith(self.LOCK_POSTFIX) or - not c.startswith(self.JOB_PREFIX)): + if c.endswith(self.LOCK_POSTFIX) or not c.startswith( + self.JOB_PREFIX + ): # Skip lock paths or non-job-paths (these are not valid jobs) continue child_paths.append(k_paths.join(self.path, c)) @@ -513,29 +569,42 @@ class ZookeeperJobBoard(base.NotifyingJobBoard): else: self._process_child(path, request, quiet=False) - def post(self, name, book=None, details=None, - priority=base.JobPriority.NORMAL): + def post( + self, name, book=None, details=None, priority=base.JobPriority.NORMAL + ): # NOTE(harlowja): Jobs are not ephemeral, they will persist until they # are consumed (this may change later, but seems safer to do this until # further notice). 
job_priority = base.JobPriority.convert(priority) job_uuid = uuidutils.generate_uuid() - job_posting = base.format_posting(job_uuid, name, - book=book, details=details, - priority=job_priority) + job_posting = base.format_posting( + job_uuid, name, book=book, details=details, priority=job_priority + ) raw_job_posting = misc.binary_encode(jsonutils.dumps(job_posting)) - with self._wrap(job_uuid, None, - fail_msg_tpl="Posting failure: %s", - ensure_known=False): - job_path = self._client.create(self._job_base, - value=raw_job_posting, - sequence=True, - ephemeral=False) - job = ZookeeperJob(self, name, self._client, job_path, - backend=self._persistence, - book=book, details=details, uuid=job_uuid, - book_data=job_posting.get('book'), - priority=job_priority) + with self._wrap( + job_uuid, + None, + fail_msg_tpl="Posting failure: %s", + ensure_known=False, + ): + job_path = self._client.create( + self._job_base, + value=raw_job_posting, + sequence=True, + ephemeral=False, + ) + job = ZookeeperJob( + self, + name, + self._client, + job_path, + backend=self._persistence, + book=book, + details=details, + uuid=job_uuid, + book_data=job_posting.get('book'), + priority=job_priority, + ) with self._job_cond: self._known_jobs[job_path] = job self._job_cond.notify_all() @@ -551,19 +620,22 @@ class ZookeeperJobBoard(base.NotifyingJobBoard): owner = None if owner: message = "Job {} already claimed by '{}'".format( - job.uuid, owner) + job.uuid, owner + ) else: message = "Job %s already claimed" % (job.uuid) - excp.raise_with_cause(excp.UnclaimableJob, - message, cause=cause) + excp.raise_with_cause(excp.UnclaimableJob, message, cause=cause) - with self._wrap(job.uuid, job.path, - fail_msg_tpl="Claiming failure: %s"): + with self._wrap( + job.uuid, job.path, fail_msg_tpl="Claiming failure: %s" + ): # NOTE(harlowja): post as json which will allow for future changes # more easily than a raw string/text. 
- value = jsonutils.dumps({ - 'owner': who, - }) + value = jsonutils.dumps( + { + 'owner': who, + } + ) # Ensure the target job is still existent (at the right version). job_data, job_stat = self._client.get(job.path) txn = self._client.transaction() @@ -571,8 +643,9 @@ class ZookeeperJobBoard(base.NotifyingJobBoard): # removed (somehow...) or updated by someone else to a different # version... txn.check(job.path, version=job_stat.version) - txn.create(job.lock_path, value=misc.binary_encode(value), - ephemeral=True) + txn.create( + job.lock_path, value=misc.binary_encode(value), ephemeral=True + ) try: kazoo_utils.checked_commit(txn) except k_exceptions.NodeExistsError as e: @@ -585,24 +658,29 @@ class ZookeeperJobBoard(base.NotifyingJobBoard): excp.raise_with_cause( excp.NotFound, "Job %s not found to be claimed" % job.uuid, - cause=e.failures[0]) + cause=e.failures[0], + ) if isinstance(e.failures[1], k_exceptions.NodeExistsError): _unclaimable_try_find_owner(e.failures[1]) else: excp.raise_with_cause( excp.UnclaimableJob, "Job %s claim failed due to transaction" - " not succeeding" % (job.uuid), cause=e) + " not succeeding" % (job.uuid), + cause=e, + ) @contextlib.contextmanager - def _wrap(self, job_uuid, job_path, - fail_msg_tpl="Failure: %s", ensure_known=True): + def _wrap( + self, job_uuid, job_path, fail_msg_tpl="Failure: %s", ensure_known=True + ): if job_path: fail_msg_tpl += " (%s)" % (job_path) if ensure_known: if not job_path: - raise ValueError("Unable to check if %r is a known path" - % (job_path)) + raise ValueError( + "Unable to check if %r is a known path" % (job_path) + ) if job_path not in self._known_jobs: fail_msg_tpl += ", unknown job" raise excp.NotFound(fail_msg_tpl % (job_uuid)) @@ -622,9 +700,12 @@ class ZookeeperJobBoard(base.NotifyingJobBoard): excp.raise_with_cause(excp.JobFailure, fail_msg_tpl % (job_uuid)) def find_owner(self, job): - with self._wrap(job.uuid, job.path, - fail_msg_tpl="Owner query failure: %s", - 
ensure_known=False): + with self._wrap( + job.uuid, + job.path, + fail_msg_tpl="Owner query failure: %s", + ensure_known=False, + ): try: self._client.sync(job.lock_path) raw_data, _lock_stat = self._client.get(job.lock_path) @@ -637,8 +718,12 @@ class ZookeeperJobBoard(base.NotifyingJobBoard): def _get_owner_and_data(self, job): lock_data, lock_stat = self._client.get(job.lock_path) job_data, job_stat = self._client.get(job.path) - return (misc.decode_json(lock_data), lock_stat, - misc.decode_json(job_data), job_stat) + return ( + misc.decode_json(lock_data), + lock_stat, + misc.decode_json(job_data), + job_stat, + ) def register_entity(self, entity): entity_type = entity.kind @@ -646,47 +731,58 @@ class ZookeeperJobBoard(base.NotifyingJobBoard): entity_path = k_paths.join(self.entity_path, entity_type) try: self._client.ensure_path(entity_path) - self._client.create(k_paths.join(entity_path, entity.name), - value=misc.binary_encode( - jsonutils.dumps(entity.to_dict())), - ephemeral=True) + self._client.create( + k_paths.join(entity_path, entity.name), + value=misc.binary_encode( + jsonutils.dumps(entity.to_dict()) + ), + ephemeral=True, + ) except k_exceptions.NodeExistsError: pass except self._client.handler.timeout_exception: excp.raise_with_cause( excp.JobFailure, "Can not register entity %s under %s, operation" - " timed out" % (entity.name, entity_path)) + " timed out" % (entity.name, entity_path), + ) except k_exceptions.SessionExpiredError: excp.raise_with_cause( excp.JobFailure, "Can not register entity %s under %s, session" - " expired" % (entity.name, entity_path)) + " expired" % (entity.name, entity_path), + ) except k_exceptions.KazooException: excp.raise_with_cause( excp.JobFailure, "Can not register entity %s under %s, internal" - " error" % (entity.name, entity_path)) + " error" % (entity.name, entity_path), + ) else: raise excp.NotImplementedError( - "Not implemented for other entity type '%s'" % entity_type) + "Not implemented for other entity 
type '%s'" % entity_type + ) @base.check_who def consume(self, job, who): - with self._wrap(job.uuid, job.path, - fail_msg_tpl="Consumption failure: %s"): + with self._wrap( + job.uuid, job.path, fail_msg_tpl="Consumption failure: %s" + ): try: owner_data = self._get_owner_and_data(job) lock_data, lock_stat, data, data_stat = owner_data except k_exceptions.NoNodeError: - excp.raise_with_cause(excp.NotFound, - "Can not consume a job %s" - " which we can not determine" - " the owner of" % (job.uuid)) + excp.raise_with_cause( + excp.NotFound, + "Can not consume a job %s" + " which we can not determine" + " the owner of" % (job.uuid), + ) if lock_data.get("owner") != who: - raise excp.JobFailure("Can not consume a job %s" - " which is not owned by %s" - % (job.uuid, who)) + raise excp.JobFailure( + "Can not consume a job %s" + " which is not owned by %s" % (job.uuid, who) + ) txn = self._client.transaction() txn.delete(job.lock_path, version=lock_stat.version) txn.delete(job.path, version=data_stat.version) @@ -695,40 +791,46 @@ class ZookeeperJobBoard(base.NotifyingJobBoard): @base.check_who def abandon(self, job, who): - with self._wrap(job.uuid, job.path, - fail_msg_tpl="Abandonment failure: %s"): + with self._wrap( + job.uuid, job.path, fail_msg_tpl="Abandonment failure: %s" + ): try: owner_data = self._get_owner_and_data(job) lock_data, lock_stat, data, data_stat = owner_data except k_exceptions.NoNodeError: - excp.raise_with_cause(excp.NotFound, - "Can not abandon a job %s" - " which we can not determine" - " the owner of" % (job.uuid)) + excp.raise_with_cause( + excp.NotFound, + "Can not abandon a job %s" + " which we can not determine" + " the owner of" % (job.uuid), + ) if lock_data.get("owner") != who: - raise excp.JobFailure("Can not abandon a job %s" - " which is not owned by %s" - % (job.uuid, who)) + raise excp.JobFailure( + "Can not abandon a job %s" + " which is not owned by %s" % (job.uuid, who) + ) txn = self._client.transaction() 
txn.delete(job.lock_path, version=lock_stat.version) kazoo_utils.checked_commit(txn) @base.check_who def trash(self, job, who): - with self._wrap(job.uuid, job.path, - fail_msg_tpl="Trash failure: %s"): + with self._wrap(job.uuid, job.path, fail_msg_tpl="Trash failure: %s"): try: owner_data = self._get_owner_and_data(job) lock_data, lock_stat, data, data_stat = owner_data except k_exceptions.NoNodeError: - excp.raise_with_cause(excp.NotFound, - "Can not trash a job %s" - " which we can not determine" - " the owner of" % (job.uuid)) + excp.raise_with_cause( + excp.NotFound, + "Can not trash a job %s" + " which we can not determine" + " the owner of" % (job.uuid), + ) if lock_data.get("owner") != who: - raise excp.JobFailure("Can not trash a job %s" - " which is not owned by %s" - % (job.uuid, who)) + raise excp.JobFailure( + "Can not trash a job %s" + " which is not owned by %s" % (job.uuid, who) + ) trash_path = job.path.replace(self.path, self.trash_path) value = misc.binary_encode(jsonutils.dumps(data)) txn = self._client.transaction() @@ -739,12 +841,18 @@ class ZookeeperJobBoard(base.NotifyingJobBoard): def _state_change_listener(self, state): if self._last_states: - LOG.debug("Kazoo client has changed to" - " state '%s' from prior states '%s'", state, - self._last_states) + LOG.debug( + "Kazoo client has changed to" + " state '%s' from prior states '%s'", + state, + self._last_states, + ) else: - LOG.debug("Kazoo client has changed to state '%s' (from" - " its initial/uninitialized state)", state) + LOG.debug( + "Kazoo client has changed to state '%s' (from" + " its initial/uninitialized state)", + state, + ) self._last_states.appendleft(state) if state == k_states.KazooState.LOST: self._connected = False @@ -769,9 +877,10 @@ class ZookeeperJobBoard(base.NotifyingJobBoard): while True: if not self._known_jobs: if watch.expired(): - raise excp.NotFound("Expired waiting for jobs to" - " arrive; waited %s seconds" - % watch.elapsed()) + raise excp.NotFound( + 
"Expired waiting for jobs to" + " arrive; waited %s seconds" % watch.elapsed() + ) # This is done since the given timeout can not be provided # to the condition variable, since we can not ensure that # when we acquire the condition that there will actually @@ -783,8 +892,11 @@ class ZookeeperJobBoard(base.NotifyingJobBoard): fetch_func = lambda ensure_fresh: curr_jobs removal_func = lambda a_job: self._remove_job(a_job.path) return base.JobBoardIterator( - self, LOG, board_fetch_func=fetch_func, - board_removal_func=removal_func) + self, + LOG, + board_fetch_func=fetch_func, + board_removal_func=removal_func, + ) @property def connected(self): @@ -816,21 +928,27 @@ class ZookeeperJobBoard(base.NotifyingJobBoard): try: self.close() except k_exceptions.KazooException: - LOG.exception("Failed cleaning-up after post-connection" - " initialization failed") + LOG.exception( + "Failed cleaning-up after post-connection" + " initialization failed" + ) try: if timeout is not None: timeout = float(timeout) self._client.start(timeout=timeout) self._closing = False - except (self._client.handler.timeout_exception, - k_exceptions.KazooException): - excp.raise_with_cause(excp.JobFailure, - "Failed to connect to zookeeper") + except ( + self._client.handler.timeout_exception, + k_exceptions.KazooException, + ): + excp.raise_with_cause( + excp.JobFailure, "Failed to connect to zookeeper" + ) try: if strutils.bool_from_string( - self._conf.get('check_compatible'), default=True): + self._conf.get('check_compatible'), default=True + ): kazoo_utils.check_compatible(self._client, self.MIN_ZK_VERSION) if self._worker is None and self._emit_notifications: self._worker = futurist.ThreadPoolExecutor(max_workers=1) @@ -841,18 +959,23 @@ class ZookeeperJobBoard(base.NotifyingJobBoard): self._client, self.path, func=self._on_job_posting, - allow_session_lost=True) + allow_session_lost=True, + ) self._connected = True except excp.IncompatibleVersion: with excutils.save_and_reraise_exception(): 
try_clean() - except (self._client.handler.timeout_exception, - k_exceptions.KazooException): + except ( + self._client.handler.timeout_exception, + k_exceptions.KazooException, + ): exc_type, exc, exc_tb = sys.exc_info() try: try_clean() - excp.raise_with_cause(excp.JobFailure, - "Failed to do post-connection" - " initialization", cause=exc) + excp.raise_with_cause( + excp.JobFailure, + "Failed to do post-connection initialization", + cause=exc, + ) finally: del (exc_type, exc, exc_tb) diff --git a/taskflow/jobs/base.py b/taskflow/jobs/base.py index 160730d85..e73f69fa5 100644 --- a/taskflow/jobs/base.py +++ b/taskflow/jobs/base.py @@ -59,18 +59,24 @@ class JobPriority(enum.Enum): try: return cls(value.upper()) except (ValueError, AttributeError): - valids = [cls.VERY_HIGH, cls.HIGH, cls.NORMAL, - cls.LOW, cls.VERY_LOW] + valids = [ + cls.VERY_HIGH, + cls.HIGH, + cls.NORMAL, + cls.LOW, + cls.VERY_LOW, + ] valids = [p.value for p in valids] - raise ValueError("'%s' is not a valid priority, valid" - " priorities are %s" % (value, valids)) + raise ValueError( + "'%s' is not a valid priority, valid" + " priorities are %s" % (value, valids) + ) @classmethod def reorder(cls, *values): """Reorders (priority, value) tuples -> priority ordered values.""" if len(values) == 0: - raise ValueError("At least one (priority, value) pair is" - " required") + raise ValueError("At least one (priority, value) pair is required") elif len(values) == 1: v1 = values[0] # Even though this isn't used, we do the conversion because @@ -81,8 +87,13 @@ class JobPriority(enum.Enum): return v1[1] else: # Order very very much matters in this tuple... 
- priority_ordering = (cls.VERY_HIGH, cls.HIGH, - cls.NORMAL, cls.LOW, cls.VERY_LOW) + priority_ordering = ( + cls.VERY_HIGH, + cls.HIGH, + cls.NORMAL, + cls.LOW, + cls.VERY_LOW, + ) if len(values) == 2: # It's common to use this in a 2 tuple situation, so # make it avoid all the needed complexity that is done @@ -99,7 +110,7 @@ class JobPriority(enum.Enum): return v2[1], v1[1] else: buckets = collections.defaultdict(list) - for (p, v) in values: + for p, v in values: p = cls.convert(p) buckets[p].append(v) values = [] @@ -127,9 +138,16 @@ class Job(metaclass=abc.ABCMeta): reverting... """ - def __init__(self, board, name, - uuid=None, details=None, backend=None, - book=None, book_data=None): + def __init__( + self, + board, + name, + uuid=None, + details=None, + backend=None, + book=None, + book_data=None, + ): if uuid: self._uuid = uuid else: @@ -170,9 +188,14 @@ class Job(metaclass=abc.ABCMeta): def priority(self): """The :py:class:`~.JobPriority` of this job.""" - def wait(self, timeout=None, - delay=0.01, delay_multiplier=2.0, max_delay=60.0, - sleep_func=time.sleep): + def wait( + self, + timeout=None, + delay=0.01, + delay_multiplier=2.0, + max_delay=60.0, + sleep_func=time.sleep, + ): """Wait for job to enter completion state. 
If the job has not completed in the given timeout, then return false, @@ -194,8 +217,9 @@ class Job(metaclass=abc.ABCMeta): w.start() else: w = None - delay_gen = iter_utils.generate_delays(delay, max_delay, - multiplier=delay_multiplier) + delay_gen = iter_utils.generate_delays( + delay, max_delay, multiplier=delay_multiplier + ) while True: if w is not None and w.expired(): return False @@ -254,11 +278,14 @@ class Job(metaclass=abc.ABCMeta): """The non-uniquely identifying name of this job.""" return self._name - @tenacity.retry(retry=tenacity.retry_if_exception_type( - exception_types=excp.StorageFailure), - stop=tenacity.stop_after_attempt(RETRY_ATTEMPTS), - wait=tenacity.wait_fixed(RETRY_WAIT_TIMEOUT), - reraise=True) + @tenacity.retry( + retry=tenacity.retry_if_exception_type( + exception_types=excp.StorageFailure + ), + stop=tenacity.stop_after_attempt(RETRY_ATTEMPTS), + wait=tenacity.wait_fixed(RETRY_WAIT_TIMEOUT), + reraise=True, + ) def _load_book(self): book_uuid = self.book_uuid if self._backend is not None and book_uuid is not None: @@ -276,8 +303,8 @@ class Job(metaclass=abc.ABCMeta): """Pretty formats the job into something *more* meaningful.""" cls_name = type(self).__name__ return "{}: {} (priority={}, uuid={}, details={})".format( - cls_name, self.name, self.priority, - self.uuid, self.details) + cls_name, self.name, self.priority, self.uuid, self.details + ) class JobBoardIterator: @@ -296,9 +323,15 @@ class JobBoardIterator: _UNCLAIMED_JOB_STATES = (states.UNCLAIMED,) _JOB_STATES = (states.UNCLAIMED, states.COMPLETE, states.CLAIMED) - def __init__(self, board, logger, - board_fetch_func=None, board_removal_func=None, - only_unclaimed=False, ensure_fresh=False): + def __init__( + self, + board, + logger, + board_fetch_func=None, + board_removal_func=None, + only_unclaimed=False, + ensure_fresh=False, + ): self._board = board self._logger = logger self._board_removal_func = board_removal_func @@ -328,8 +361,11 @@ class JobBoardIterator: if 
maybe_job.state in allowed_states: job = maybe_job except excp.JobFailure: - self._logger.warning("Failed determining the state of" - " job '%s'", maybe_job, exc_info=True) + self._logger.warning( + "Failed determining the state of job '%s'", + maybe_job, + exc_info=True, + ) except excp.NotFound: # Attempt to clean this off the board now that we found # it wasn't really there (this **must** gracefully handle @@ -343,8 +379,8 @@ class JobBoardIterator: if not self._fetched: if self._board_fetch_func is not None: self._jobs.extend( - self._board_fetch_func( - ensure_fresh=self.ensure_fresh)) + self._board_fetch_func(ensure_fresh=self.ensure_fresh) + ) self._fetched = True job = self._next_job() if job is None: @@ -562,6 +598,7 @@ class NotifyingJobBoard(JobBoard): separate dedicated thread when they occur, so ensure that all callbacks registered are thread safe (and block for as little time as possible). """ + def __init__(self, name, conf): super().__init__(name, conf) self.notifier = notifier.Notifier() @@ -569,6 +606,7 @@ class NotifyingJobBoard(JobBoard): # Internal helpers for usage by board implementations... + def check_who(meth): @functools.wraps(meth) @@ -582,8 +620,15 @@ def check_who(meth): return wrapper -def format_posting(uuid, name, created_on=None, last_modified=None, - details=None, book=None, priority=JobPriority.NORMAL): +def format_posting( + uuid, + name, + created_on=None, + last_modified=None, + details=None, + book=None, + priority=JobPriority.NORMAL, +): posting = { 'uuid': uuid, 'name': name, diff --git a/taskflow/listeners/base.py b/taskflow/listeners/base.py index 089ec7cf1..dec77a5ce 100644 --- a/taskflow/listeners/base.py +++ b/taskflow/listeners/base.py @@ -24,8 +24,12 @@ from taskflow.types import notifier LOG = logging.getLogger(__name__) #: These states will results be usable, other states do not produce results. 
-FINISH_STATES = (states.FAILURE, states.SUCCESS, - states.REVERTED, states.REVERT_FAILURE) +FINISH_STATES = ( + states.FAILURE, + states.SUCCESS, + states.REVERTED, + states.REVERT_FAILURE, +) #: What is listened for by default... DEFAULT_LISTEN_FOR = (notifier.Notifier.ANY,) @@ -53,8 +57,7 @@ def _bulk_deregister(notifier, registered, details_filter=None): """Bulk deregisters callbacks associated with many states.""" while registered: state, cb = registered.pop() - notifier.deregister(state, cb, - details_filter=details_filter) + notifier.deregister(state, cb, details_filter=details_filter) def _bulk_register(watch_states, notifier, cb, details_filter=None): @@ -62,15 +65,16 @@ def _bulk_register(watch_states, notifier, cb, details_filter=None): registered = [] try: for state in watch_states: - if not notifier.is_registered(state, cb, - details_filter=details_filter): - notifier.register(state, cb, - details_filter=details_filter) + if not notifier.is_registered( + state, cb, details_filter=details_filter + ): + notifier.register(state, cb, details_filter=details_filter) registered.append((state, cb)) except ValueError: with excutils.save_and_reraise_exception(): - _bulk_deregister(notifier, registered, - details_filter=details_filter) + _bulk_deregister( + notifier, registered, details_filter=details_filter + ) else: return registered @@ -88,10 +92,13 @@ class Listener: methods (in this class, they do nothing). 
""" - def __init__(self, engine, - task_listen_for=DEFAULT_LISTEN_FOR, - flow_listen_for=DEFAULT_LISTEN_FOR, - retry_listen_for=DEFAULT_LISTEN_FOR): + def __init__( + self, + engine, + task_listen_for=DEFAULT_LISTEN_FOR, + flow_listen_for=DEFAULT_LISTEN_FOR, + retry_listen_for=DEFAULT_LISTEN_FOR, + ): if not task_listen_for: task_listen_for = [] if not retry_listen_for: @@ -117,33 +124,44 @@ class Listener: def deregister(self): if 'task' in self._registered: - _bulk_deregister(self._engine.atom_notifier, - self._registered['task'], - details_filter=_task_matcher) + _bulk_deregister( + self._engine.atom_notifier, + self._registered['task'], + details_filter=_task_matcher, + ) del self._registered['task'] if 'retry' in self._registered: - _bulk_deregister(self._engine.atom_notifier, - self._registered['retry'], - details_filter=_retry_matcher) + _bulk_deregister( + self._engine.atom_notifier, + self._registered['retry'], + details_filter=_retry_matcher, + ) del self._registered['retry'] if 'flow' in self._registered: - _bulk_deregister(self._engine.notifier, - self._registered['flow']) + _bulk_deregister(self._engine.notifier, self._registered['flow']) del self._registered['flow'] def register(self): if 'task' not in self._registered: self._registered['task'] = _bulk_register( - self._listen_for['task'], self._engine.atom_notifier, - self._task_receiver, details_filter=_task_matcher) + self._listen_for['task'], + self._engine.atom_notifier, + self._task_receiver, + details_filter=_task_matcher, + ) if 'retry' not in self._registered: self._registered['retry'] = _bulk_register( - self._listen_for['retry'], self._engine.atom_notifier, - self._retry_receiver, details_filter=_retry_matcher) + self._listen_for['retry'], + self._engine.atom_notifier, + self._retry_receiver, + details_filter=_retry_matcher, + ) if 'flow' not in self._registered: self._registered['flow'] = _bulk_register( - self._listen_for['flow'], self._engine.notifier, - self._flow_receiver) + 
self._listen_for['flow'], + self._engine.notifier, + self._flow_receiver, + ) def __enter__(self): self.register() @@ -154,8 +172,11 @@ class Listener: self.deregister() except Exception: # Don't let deregistering throw exceptions - LOG.warning("Failed deregistering listeners from engine %s", - self._engine, exc_info=True) + LOG.warning( + "Failed deregistering listeners from engine %s", + self._engine, + exc_info=True, + ) class DumpingListener(Listener, metaclass=abc.ABCMeta): @@ -174,9 +195,14 @@ class DumpingListener(Listener, metaclass=abc.ABCMeta): """Dumps the provided *templated* message to some output.""" def _flow_receiver(self, state, details): - self._dump("%s has moved flow '%s' (%s) into state '%s'" - " from state '%s'", self._engine, details['flow_name'], - details['flow_uuid'], state, details['old_state']) + self._dump( + "%s has moved flow '%s' (%s) into state '%s' from state '%s'", + self._engine, + details['flow_name'], + details['flow_uuid'], + state, + details['old_state'], + ) def _task_receiver(self, state, details): if state in FINISH_STATES: @@ -187,12 +213,24 @@ class DumpingListener(Listener, metaclass=abc.ABCMeta): if result.exc_info: exc_info = tuple(result.exc_info) was_failure = True - self._dump("%s has moved task '%s' (%s) into state '%s'" - " from state '%s' with result '%s' (failure=%s)", - self._engine, details['task_name'], - details['task_uuid'], state, details['old_state'], - result, was_failure, exc_info=exc_info) + self._dump( + "%s has moved task '%s' (%s) into state '%s'" + " from state '%s' with result '%s' (failure=%s)", + self._engine, + details['task_name'], + details['task_uuid'], + state, + details['old_state'], + result, + was_failure, + exc_info=exc_info, + ) else: - self._dump("%s has moved task '%s' (%s) into state '%s'" - " from state '%s'", self._engine, details['task_name'], - details['task_uuid'], state, details['old_state']) + self._dump( + "%s has moved task '%s' (%s) into state '%s' from state '%s'", + 
self._engine, + details['task_name'], + details['task_uuid'], + state, + details['old_state'], + ) diff --git a/taskflow/listeners/capturing.py b/taskflow/listeners/capturing.py index 9c33c5052..535d76d04 100644 --- a/taskflow/listeners/capturing.py +++ b/taskflow/listeners/capturing.py @@ -51,23 +51,31 @@ class CaptureListener(base.Listener): #: Kind that denotes a 'retry' capture. RETRY = 'retry' - def __init__(self, engine, - task_listen_for=base.DEFAULT_LISTEN_FOR, - flow_listen_for=base.DEFAULT_LISTEN_FOR, - retry_listen_for=base.DEFAULT_LISTEN_FOR, - # Easily override what you want captured and where it - # should save into and what should be skipped... - capture_flow=True, capture_task=True, capture_retry=True, - # Skip capturing *all* tasks, all retries, all flows... - skip_tasks=None, skip_retries=None, skip_flows=None, - # Provide your own list (or previous list) to accumulate - # into... - values=None): + def __init__( + self, + engine, + task_listen_for=base.DEFAULT_LISTEN_FOR, + flow_listen_for=base.DEFAULT_LISTEN_FOR, + retry_listen_for=base.DEFAULT_LISTEN_FOR, + # Easily override what you want captured and where it + # should save into and what should be skipped... + capture_flow=True, + capture_task=True, + capture_retry=True, + # Skip capturing *all* tasks, all retries, all flows... + skip_tasks=None, + skip_retries=None, + skip_flows=None, + # Provide your own list (or previous list) to accumulate + # into... 
+ values=None, + ): super().__init__( engine, task_listen_for=task_listen_for, flow_listen_for=flow_listen_for, - retry_listen_for=retry_listen_for) + retry_listen_for=retry_listen_for, + ) self._capture_flow = capture_flow self._capture_task = capture_task self._capture_retry = capture_retry @@ -87,17 +95,20 @@ class CaptureListener(base.Listener): def _task_receiver(self, state, details): if self._capture_task: if details['task_name'] not in self._skip_tasks: - self.values.append(self._format_capture(self.TASK, - state, details)) + self.values.append( + self._format_capture(self.TASK, state, details) + ) def _retry_receiver(self, state, details): if self._capture_retry: if details['retry_name'] not in self._skip_retries: - self.values.append(self._format_capture(self.RETRY, - state, details)) + self.values.append( + self._format_capture(self.RETRY, state, details) + ) def _flow_receiver(self, state, details): if self._capture_flow: if details['flow_name'] not in self._skip_flows: - self.values.append(self._format_capture(self.FLOW, - state, details)) + self.values.append( + self._format_capture(self.FLOW, state, details) + ) diff --git a/taskflow/listeners/claims.py b/taskflow/listeners/claims.py index de03bac2c..afe6f5609 100644 --- a/taskflow/listeners/claims.py +++ b/taskflow/listeners/claims.py @@ -55,8 +55,9 @@ class CheckingClaimListener(base.Listener): self._on_job_loss = self._suspend_engine_on_loss else: if not callable(on_job_loss): - raise ValueError("Custom 'on_job_loss' handler must be" - " callable") + raise ValueError( + "Custom 'on_job_loss' handler must be callable" + ) self._on_job_loss = on_job_loss def _suspend_engine_on_loss(self, engine, state, details): @@ -64,9 +65,14 @@ class CheckingClaimListener(base.Listener): try: engine.suspend() except exceptions.TaskFlowException as e: - LOG.warning("Failed suspending engine '%s', (previously owned by" - " '%s'):%s%s", engine, self._owner, os.linesep, - e.pformat()) + LOG.warning( + "Failed 
suspending engine '%s', (previously owned by" + " '%s'):%s%s", + engine, + self._owner, + os.linesep, + e.pformat(), + ) def _flow_receiver(self, state, details): self._claim_checker(state, details) @@ -88,10 +94,15 @@ class CheckingClaimListener(base.Listener): def _claim_checker(self, state, details): if not self._has_been_lost(): - LOG.debug("Job '%s' is still claimed (actively owned by '%s')", - self._job, self._owner) + LOG.debug( + "Job '%s' is still claimed (actively owned by '%s')", + self._job, + self._owner, + ) else: - LOG.warning("Job '%s' has lost its claim" - " (previously owned by '%s')", - self._job, self._owner) + LOG.warning( + "Job '%s' has lost its claim (previously owned by '%s')", + self._job, + self._owner, + ) self._on_job_loss(self._engine, state, details) diff --git a/taskflow/listeners/logging.py b/taskflow/listeners/logging.py index 615a5a309..165f645d5 100644 --- a/taskflow/listeners/logging.py +++ b/taskflow/listeners/logging.py @@ -38,15 +38,21 @@ class LoggingListener(base.DumpingListener): #: Default logger to use if one is not provided on construction. _LOGGER = None - def __init__(self, engine, - task_listen_for=base.DEFAULT_LISTEN_FOR, - flow_listen_for=base.DEFAULT_LISTEN_FOR, - retry_listen_for=base.DEFAULT_LISTEN_FOR, - log=None, - level=logging.DEBUG): + def __init__( + self, + engine, + task_listen_for=base.DEFAULT_LISTEN_FOR, + flow_listen_for=base.DEFAULT_LISTEN_FOR, + retry_listen_for=base.DEFAULT_LISTEN_FOR, + log=None, + level=logging.DEBUG, + ): super().__init__( - engine, task_listen_for=task_listen_for, - flow_listen_for=flow_listen_for, retry_listen_for=retry_listen_for) + engine, + task_listen_for=task_listen_for, + flow_listen_for=flow_listen_for, + retry_listen_for=retry_listen_for, + ) self._logger = misc.pick_first_not_none(log, self._LOGGER, LOG) self._level = level @@ -102,18 +108,26 @@ class DynamicLoggingListener(base.Listener): #: States which are triggered under some type of failure. 
_FAILURE_STATES = (states.FAILURE, states.REVERT_FAILURE) - def __init__(self, engine, - task_listen_for=base.DEFAULT_LISTEN_FOR, - flow_listen_for=base.DEFAULT_LISTEN_FOR, - retry_listen_for=base.DEFAULT_LISTEN_FOR, - log=None, failure_level=logging.WARNING, - level=logging.DEBUG, hide_inputs_outputs_of=(), - fail_formatter=None, - mask_inputs_keys=(), - mask_outputs_keys=()): + def __init__( + self, + engine, + task_listen_for=base.DEFAULT_LISTEN_FOR, + flow_listen_for=base.DEFAULT_LISTEN_FOR, + retry_listen_for=base.DEFAULT_LISTEN_FOR, + log=None, + failure_level=logging.WARNING, + level=logging.DEBUG, + hide_inputs_outputs_of=(), + fail_formatter=None, + mask_inputs_keys=(), + mask_outputs_keys=(), + ): super().__init__( - engine, task_listen_for=task_listen_for, - flow_listen_for=flow_listen_for, retry_listen_for=retry_listen_for) + engine, + task_listen_for=task_listen_for, + flow_listen_for=flow_listen_for, + retry_listen_for=retry_listen_for, + ) self._failure_level = failure_level self._level = level self._task_log_levels = { @@ -135,16 +149,22 @@ class DynamicLoggingListener(base.Listener): self._engine, hide_inputs_outputs_of=self._hide_inputs_outputs_of, mask_inputs_keys=self._mask_inputs_keys, - mask_outputs_keys=self._mask_outputs_keys) + mask_outputs_keys=self._mask_outputs_keys, + ) else: self._fail_formatter = fail_formatter def _flow_receiver(self, state, details): """Gets called on flow state changes.""" level = self._flow_log_levels.get(state, self._level) - self._logger.log(level, "Flow '%s' (%s) transitioned into state '%s'" - " from state '%s'", details['flow_name'], - details['flow_uuid'], state, details.get('old_state')) + self._logger.log( + level, + "Flow '%s' (%s) transitioned into state '%s' from state '%s'", + details['flow_name'], + details['flow_uuid'], + state, + details.get('old_state'), + ) def _task_receiver(self, state, details): """Gets called on task state changes.""" @@ -156,41 +176,74 @@ class 
DynamicLoggingListener(base.Listener): result = details.get('result') if isinstance(result, failure.Failure): exc_info, fail_details = self._fail_formatter.format( - result, _make_matcher(task_name)) + result, _make_matcher(task_name) + ) if fail_details: - self._logger.log(self._failure_level, - "Task '%s' (%s) transitioned into state" - " '%s' from state '%s'%s%s", - task_name, task_uuid, state, - details['old_state'], os.linesep, - fail_details, exc_info=exc_info) + self._logger.log( + self._failure_level, + "Task '%s' (%s) transitioned into state" + " '%s' from state '%s'%s%s", + task_name, + task_uuid, + state, + details['old_state'], + os.linesep, + fail_details, + exc_info=exc_info, + ) else: - self._logger.log(self._failure_level, - "Task '%s' (%s) transitioned into state" - " '%s' from state '%s'", task_name, - task_uuid, state, details['old_state'], - exc_info=exc_info) + self._logger.log( + self._failure_level, + "Task '%s' (%s) transitioned into state" + " '%s' from state '%s'", + task_name, + task_uuid, + state, + details['old_state'], + exc_info=exc_info, + ) else: # Otherwise, depending on the enabled logging level/state we # will show or hide results that the task may have produced # during execution. 
level = self._task_log_levels.get(state, self._level) - show_result = (self._logger.isEnabledFor(self._level) - or state == states.FAILURE) - if show_result and \ - task_name not in self._hide_inputs_outputs_of: - self._logger.log(level, "Task '%s' (%s) transitioned into" - " state '%s' from state '%s' with" - " result '%s'", task_name, task_uuid, - state, details['old_state'], result) + show_result = ( + self._logger.isEnabledFor(self._level) + or state == states.FAILURE + ) + if ( + show_result + and task_name not in self._hide_inputs_outputs_of + ): + self._logger.log( + level, + "Task '%s' (%s) transitioned into" + " state '%s' from state '%s' with" + " result '%s'", + task_name, + task_uuid, + state, + details['old_state'], + result, + ) else: - self._logger.log(level, "Task '%s' (%s) transitioned into" - " state '%s' from state '%s'", - task_name, task_uuid, state, - details['old_state']) + self._logger.log( + level, + "Task '%s' (%s) transitioned into" + " state '%s' from state '%s'", + task_name, + task_uuid, + state, + details['old_state'], + ) else: # Just a intermediary state, carry on! 
level = self._task_log_levels.get(state, self._level) - self._logger.log(level, "Task '%s' (%s) transitioned into state" - " '%s' from state '%s'", task_name, task_uuid, - state, details['old_state']) + self._logger.log( + level, + "Task '%s' (%s) transitioned into state '%s' from state '%s'", + task_name, + task_uuid, + state, + details['old_state'], + ) diff --git a/taskflow/listeners/printing.py b/taskflow/listeners/printing.py index 03d652777..dfe133d4d 100644 --- a/taskflow/listeners/printing.py +++ b/taskflow/listeners/printing.py @@ -20,14 +20,21 @@ from taskflow.listeners import base class PrintingListener(base.DumpingListener): """Writes the task and flow notifications messages to stdout or stderr.""" - def __init__(self, engine, - task_listen_for=base.DEFAULT_LISTEN_FOR, - flow_listen_for=base.DEFAULT_LISTEN_FOR, - retry_listen_for=base.DEFAULT_LISTEN_FOR, - stderr=False): + + def __init__( + self, + engine, + task_listen_for=base.DEFAULT_LISTEN_FOR, + flow_listen_for=base.DEFAULT_LISTEN_FOR, + retry_listen_for=base.DEFAULT_LISTEN_FOR, + stderr=False, + ): super().__init__( - engine, task_listen_for=task_listen_for, - flow_listen_for=flow_listen_for, retry_listen_for=retry_listen_for) + engine, + task_listen_for=task_listen_for, + flow_listen_for=flow_listen_for, + retry_listen_for=retry_listen_for, + ) if stderr: self._file = sys.stderr else: @@ -37,5 +44,6 @@ class PrintingListener(base.DumpingListener): print(message % args, file=self._file) exc_info = kwargs.get('exc_info') if exc_info is not None: - traceback.print_exception(exc_info[0], exc_info[1], exc_info[2], - file=self._file) + traceback.print_exception( + exc_info[0], exc_info[1], exc_info[2], file=self._file + ) diff --git a/taskflow/listeners/timing.py b/taskflow/listeners/timing.py index 284b4600a..1f18255b2 100644 --- a/taskflow/listeners/timing.py +++ b/taskflow/listeners/timing.py @@ -25,8 +25,9 @@ from taskflow import states STARTING_STATES = frozenset((states.RUNNING, 
states.REVERTING)) FINISHED_STATES = frozenset(base.FINISH_STATES + (states.REVERTED,)) -WATCH_STATES = frozenset(itertools.chain(FINISHED_STATES, STARTING_STATES, - [states.PENDING])) +WATCH_STATES = frozenset( + itertools.chain(FINISHED_STATES, STARTING_STATES, [states.PENDING]) +) LOG = logging.getLogger(__name__) @@ -45,10 +46,11 @@ class DurationListener(base.Listener): to storage. It saves the duration in seconds as float value to task metadata with key ``'duration'``. """ + def __init__(self, engine): - super().__init__(engine, - task_listen_for=WATCH_STATES, - flow_listen_for=WATCH_STATES) + super().__init__( + engine, task_listen_for=WATCH_STATES, flow_listen_for=WATCH_STATES + ) self._timers = {co.TASK: {}, co.FLOW: {}} def deregister(self): @@ -58,9 +60,12 @@ class DurationListener(base.Listener): for item_type, timers in self._timers.items(): leftover_timers = len(timers) if leftover_timers: - LOG.warning("%s %s(s) did not enter %s states", - leftover_timers, - item_type, FINISHED_STATES) + LOG.warning( + "%s %s(s) did not enter %s states", + leftover_timers, + item_type, + FINISHED_STATES, + ) timers.clear() def _record_ending(self, timer, item_type, item_name, state): @@ -76,8 +81,13 @@ class DurationListener(base.Listener): else: storage.update_atom_metadata(item_name, meta_update) except exc.StorageFailure: - LOG.warning("Failure to store duration update %s for %s %s", - meta_update, item_type, item_name, exc_info=True) + LOG.warning( + "Failure to store duration update %s for %s %s", + meta_update, + item_type, + item_name, + exc_info=True, + ) def _task_receiver(self, state, details): task_name = details['task_name'] @@ -110,10 +120,11 @@ class PrintingDurationListener(DurationListener): self._printer = printer def _record_ending(self, timer, item_type, item_name, state): - super()._record_ending( - timer, item_type, item_name, state) - self._printer("It took %s '%s' %0.2f seconds to" - " finish." 
% (item_type, item_name, timer.elapsed())) + super()._record_ending(timer, item_type, item_name, state) + self._printer( + "It took %s '%s' %0.2f seconds to" + " finish." % (item_type, item_name, timer.elapsed()) + ) def _receiver(self, item_type, item_name, state): super()._receiver(item_type, item_name, state) @@ -132,13 +143,19 @@ class EventTimeListener(base.Listener): This information can be later extracted/examined to derive durations... """ - def __init__(self, engine, - task_listen_for=base.DEFAULT_LISTEN_FOR, - flow_listen_for=base.DEFAULT_LISTEN_FOR, - retry_listen_for=base.DEFAULT_LISTEN_FOR): + def __init__( + self, + engine, + task_listen_for=base.DEFAULT_LISTEN_FOR, + flow_listen_for=base.DEFAULT_LISTEN_FOR, + retry_listen_for=base.DEFAULT_LISTEN_FOR, + ): super().__init__( - engine, task_listen_for=task_listen_for, - flow_listen_for=flow_listen_for, retry_listen_for=retry_listen_for) + engine, + task_listen_for=task_listen_for, + flow_listen_for=flow_listen_for, + retry_listen_for=retry_listen_for, + ) def _record_atom_event(self, state, atom_name): meta_update = {'%s-timestamp' % state: time.time()} @@ -146,8 +163,12 @@ class EventTimeListener(base.Listener): # Don't let storage failures throw exceptions in a listener method. self._engine.storage.update_atom_metadata(atom_name, meta_update) except exc.StorageFailure: - LOG.warning("Failure to store timestamp %s for atom %s", - meta_update, atom_name, exc_info=True) + LOG.warning( + "Failure to store timestamp %s for atom %s", + meta_update, + atom_name, + exc_info=True, + ) def _flow_receiver(self, state, details): meta_update = {'%s-timestamp' % state: time.time()} @@ -155,8 +176,12 @@ class EventTimeListener(base.Listener): # Don't let storage failures throw exceptions in a listener method. 
self._engine.storage.update_flow_metadata(meta_update) except exc.StorageFailure: - LOG.warning("Failure to store timestamp %s for flow %s", - meta_update, details['flow_name'], exc_info=True) + LOG.warning( + "Failure to store timestamp %s for flow %s", + meta_update, + details['flow_name'], + exc_info=True, + ) def _task_receiver(self, state, details): self._record_atom_event(state, details['task_name']) diff --git a/taskflow/logging.py b/taskflow/logging.py index f191df706..2ad7ee067 100644 --- a/taskflow/logging.py +++ b/taskflow/logging.py @@ -32,12 +32,11 @@ ERROR = logging.ERROR FATAL = logging.FATAL INFO = logging.INFO NOTSET = logging.NOTSET -WARN = logging.WARN +WARN = logging.WARNING WARNING = logging.WARNING class _TraceLoggerAdapter(logging.LoggerAdapter): - def trace(self, msg, *args, **kwargs): """Delegate a trace call to the underlying logger.""" self.log(TRACE, msg, *args, **kwargs) diff --git a/taskflow/patterns/graph_flow.py b/taskflow/patterns/graph_flow.py index 292d5b826..2378ce2ed 100644 --- a/taskflow/patterns/graph_flow.py +++ b/taskflow/patterns/graph_flow.py @@ -108,13 +108,23 @@ class Flow(flow.Flow): if decider is not None: if not callable(decider): raise ValueError("Decider boolean callback must be callable") - self._swap(self._link(u, v, manual=True, - decider=decider, decider_depth=decider_depth)) + self._swap( + self._link( + u, v, manual=True, decider=decider, decider_depth=decider_depth + ) + ) return self - def _link(self, u, v, graph=None, - reason=None, manual=False, decider=None, - decider_depth=None): + def _link( + self, + u, + v, + graph=None, + reason=None, + manual=False, + decider=None, + decider_depth=None, + ): mutable_graph = True if graph is None: graph = self._graph @@ -133,8 +143,10 @@ class Flow(flow.Flow): pass if decider_depth is not None: if decider is None: - raise ValueError("Decider depth requires a decider to be" - " provided along with it") + raise ValueError( + "Decider depth requires a decider to be" + " 
provided along with it" + ) else: decider_depth = de.Depth.translate(decider_depth) attrs[flow.LINK_DECIDER_DEPTH] = decider_depth @@ -158,10 +170,12 @@ class Flow(flow.Flow): direct access to the underlying graph). """ if not graph.is_directed_acyclic(): - raise exc.DependencyFailure("No path through the node(s) in the" - " graph produces an ordering that" - " will allow for logical" - " edge traversal") + raise exc.DependencyFailure( + "No path through the node(s) in the" + " graph produces an ordering that" + " will allow for logical" + " edge traversal" + ) self._graph = graph.freeze() def add(self, *nodes, **kwargs): @@ -222,8 +236,9 @@ class Flow(flow.Flow): provided[value].append(self._retry) for node in self._graph.nodes: - for value in self._unsatisfied_requires(node, self._graph, - retry_provides): + for value in self._unsatisfied_requires( + node, self._graph, retry_provides + ): required[value].append(node) for value in node.provides: provided[value].append(node) @@ -237,8 +252,9 @@ class Flow(flow.Flow): # Try to find a valid provider. 
if resolve_requires: - for value in self._unsatisfied_requires(node, tmp_graph, - retry_provides): + for value in self._unsatisfied_requires( + node, tmp_graph, retry_provides + ): if value in provided: providers = provided[value] if len(providers) > 1: @@ -248,12 +264,19 @@ class Flow(flow.Flow): " adding '%(node)s', multiple" " providers %(providers)s found for" " required symbol '%(value)s'" - % dict(node=node.name, - providers=sorted(provider_names), - value=value)) + % dict( + node=node.name, + providers=sorted(provider_names), + value=value, + ) + ) else: - self._link(providers[0], node, - graph=tmp_graph, reason=value) + self._link( + providers[0], + node, + graph=tmp_graph, + reason=value, + ) else: required[value].append(node) @@ -266,8 +289,12 @@ class Flow(flow.Flow): if value in required: for requiree in list(required[value]): if requiree is not node: - self._link(node, requiree, - graph=tmp_graph, reason=value) + self._link( + node, + requiree, + graph=tmp_graph, + reason=value, + ) required[value].remove(requiree) self._swap(tmp_graph) @@ -305,8 +332,9 @@ class Flow(flow.Flow): retry_provides.update(self._retry.provides) g = self._get_subgraph() for node in g.nodes: - requires.update(self._unsatisfied_requires(node, g, - retry_provides)) + requires.update( + self._unsatisfied_requires(node, g, retry_provides) + ) return frozenset(requires) @@ -365,6 +393,7 @@ class TargetedFlow(Flow): nodes = [self._target] nodes.extend(self._graph.bfs_predecessors_iter(self._target)) self._subgraph = gr.DiGraph( - incoming_graph_data=self._graph.subgraph(nodes)) + incoming_graph_data=self._graph.subgraph(nodes) + ) self._subgraph.freeze() return self._subgraph diff --git a/taskflow/patterns/linear_flow.py b/taskflow/patterns/linear_flow.py index 38a162987..fef0c8695 100644 --- a/taskflow/patterns/linear_flow.py +++ b/taskflow/patterns/linear_flow.py @@ -44,8 +44,11 @@ class Flow(flow.Flow): if not self._graph.has_node(item): self._graph.add_node(item) if 
self._last_item is not self._no_last_item: - self._graph.add_edge(self._last_item, item, - attr_dict={flow.LINK_INVARIANT: True}) + self._graph.add_edge( + self._last_item, + item, + attr_dict={flow.LINK_INVARIANT: True}, + ) self._last_item = item return self diff --git a/taskflow/persistence/backends/__init__.py b/taskflow/persistence/backends/__init__.py index a4a1a8303..4dc8bf504 100644 --- a/taskflow/persistence/backends/__init__.py +++ b/taskflow/persistence/backends/__init__.py @@ -56,10 +56,13 @@ def fetch(conf, namespace=BACKEND_NAMESPACE, **kwargs): backend = backend.split("+", 1)[0] LOG.debug('Looking for %r backend driver in %r', backend, namespace) try: - mgr = driver.DriverManager(namespace, backend, - invoke_on_load=True, - invoke_args=(conf,), - invoke_kwds=kwargs) + mgr = driver.DriverManager( + namespace, + backend, + invoke_on_load=True, + invoke_args=(conf,), + invoke_kwds=kwargs, + ) return mgr.driver except RuntimeError as e: raise exc.NotFound(f"Could not find backend {backend}: {e}") diff --git a/taskflow/persistence/backends/impl_dir.py b/taskflow/persistence/backends/impl_dir.py index 716a841a5..dc181406c 100644 --- a/taskflow/persistence/backends/impl_dir.py +++ b/taskflow/persistence/backends/impl_dir.py @@ -36,12 +36,13 @@ def _storagefailure_wrapper(): raise except Exception as e: if isinstance(e, (IOError, OSError)) and e.errno == errno.ENOENT: - exc.raise_with_cause(exc.NotFound, - 'Item not found: %s' % e.filename, - cause=e) + exc.raise_with_cause( + exc.NotFound, 'Item not found: %s' % e.filename, cause=e + ) else: - exc.raise_with_cause(exc.StorageFailure, - "Storage backend internal error", cause=e) + exc.raise_with_cause( + exc.StorageFailure, "Storage backend internal error", cause=e + ) class DirBackend(path_based.PathBasedBackend): @@ -71,8 +72,9 @@ class DirBackend(path_based.PathBasedBackend): if max_cache_size is not None: max_cache_size = int(max_cache_size) if max_cache_size < 1: - raise ValueError("Maximum cache size 
must be greater than" - " or equal to one") + raise ValueError( + "Maximum cache size must be greater than or equal to one" + ) self.file_cache = cachetools.LRUCache(max_cache_size) else: self.file_cache = {} @@ -103,8 +105,7 @@ class Connection(path_based.PathBasedConnection): return cache_info['data'] def _write_to(self, filename, contents): - contents = misc.binary_encode(contents, - encoding=self.backend.encoding) + contents = misc.binary_encode(contents, encoding=self.backend.encoding) with open(filename, 'wb') as fp: fp.write(contents) self.backend.file_cache.pop(filename, None) @@ -139,8 +140,11 @@ class Connection(path_based.PathBasedConnection): else: filter_func = os.path.islink with _storagefailure_wrapper(): - return [child for child in os.listdir(path) - if filter_func(self._join_path(path, child))] + return [ + child + for child in os.listdir(path) + if filter_func(self._join_path(path, child)) + ] def _ensure_path(self, path): with _storagefailure_wrapper(): diff --git a/taskflow/persistence/backends/impl_memory.py b/taskflow/persistence/backends/impl_memory.py index c25452872..779ef669f 100644 --- a/taskflow/persistence/backends/impl_memory.py +++ b/taskflow/persistence/backends/impl_memory.py @@ -73,12 +73,15 @@ class FakeFilesystem: def normpath(cls, path): """Return a normalized absolutized version of the pathname path.""" if not path: - raise ValueError("This filesystem can only normalize paths" - " that are not empty") + raise ValueError( + "This filesystem can only normalize paths that are not empty" + ) if not path.startswith(cls.root_path): - raise ValueError("This filesystem can only normalize" - " paths that start with %s: '%s' is not" - " valid" % (cls.root_path, path)) + raise ValueError( + "This filesystem can only normalize" + " paths that start with %s: '%s' is not" + " valid" % (cls.root_path, path) + ) return pp.normpath(path) #: Split a pathname into a tuple of ``(head, tail)``. 
@@ -108,8 +111,7 @@ class FakeFilesystem: return node = self._root for piece in self._iter_pieces(path): - child_node = node.find(piece, only_direct=True, - include_self=False) + child_node = node.find(piece, only_direct=True, include_self=False) if child_node is None: child_node = self._insert_child(node, piece) node = child_node @@ -154,9 +156,10 @@ class FakeFilesystem: if links is None: links = [] if path in links: - raise ValueError("Recursive link following not" - " allowed (loop %s detected)" - % (links + [path])) + raise ValueError( + "Recursive link following not" + " allowed (loop %s detected)" % (links + [path]) + ) else: links.append(path) return self._get_item(path, links=links) @@ -186,8 +189,9 @@ class FakeFilesystem: selector_func = self._metadata_path_selector else: selector_func = self._up_to_root_selector - return [selector_func(node, child_node) - for child_node in node.bfs_iter()] + return [ + selector_func(node, child_node) for child_node in node.bfs_iter() + ] def ls(self, path, absolute=False): """Return list of all children of the given path (not recursive).""" @@ -197,8 +201,9 @@ class FakeFilesystem: else: selector_func = self._up_to_root_selector child_node_it = iter(node) - return [selector_func(node, child_node) - for child_node in child_node_it] + return [ + selector_func(node, child_node) for child_node in child_node_it + ] def clear(self): """Remove all nodes (except the root) from this filesystem.""" @@ -219,8 +224,10 @@ class FakeFilesystem: else: node_child_count = node.child_count() if node_child_count: - raise ValueError("Can not delete '%s', it has %s children" - % (path, node_child_count)) + raise ValueError( + "Can not delete '%s', it has %s children" + % (path, node_child_count) + ) child_paths = [] if node is self._root: # Don't drop/pop the root... 
@@ -307,8 +314,9 @@ class MemoryBackend(path_based.PathBasedBackend): def __init__(self, conf=None): super().__init__(conf) - self.memory = FakeFilesystem(deep_copy=self._conf.get('deep_copy', - True)) + self.memory = FakeFilesystem( + deep_copy=self._conf.get('deep_copy', True) + ) self.lock = fasteners.ReaderWriterLock() def get_connection(self): @@ -335,8 +343,9 @@ class Connection(path_based.PathBasedConnection): except exc.TaskFlowException: raise except Exception: - exc.raise_with_cause(exc.StorageFailure, - "Storage backend internal error") + exc.raise_with_cause( + exc.StorageFailure, "Storage backend internal error" + ) def _join_path(self, *parts): return pp.join(*parts) diff --git a/taskflow/persistence/backends/impl_sqlalchemy.py b/taskflow/persistence/backends/impl_sqlalchemy.py index 627494e01..78e532fbd 100644 --- a/taskflow/persistence/backends/impl_sqlalchemy.py +++ b/taskflow/persistence/backends/impl_sqlalchemy.py @@ -109,8 +109,12 @@ DEFAULT_TXN_ISOLATION_LEVELS = { def _log_statements(log_level, conn, cursor, statement, parameters, *args): if LOG.isEnabledFor(log_level): - LOG.log(log_level, "Running statement '%s' with parameters %s", - statement, parameters) + LOG.log( + log_level, + "Running statement '%s' with parameters %s", + statement, + parameters, + ) def _in_any(reason, err_haystack): @@ -188,6 +192,7 @@ class _Alchemist: NOTE(harlowja): for internal usage only. """ + def __init__(self, tables): self._tables = tables @@ -206,15 +211,17 @@ class _Alchemist: return atom_cls.from_dict(row) def atom_query_iter(self, conn, parent_uuid): - q = (sql.select(self._tables.atomdetails). - where(self._tables.atomdetails.c.parent_uuid == parent_uuid)) + q = sql.select(self._tables.atomdetails).where( + self._tables.atomdetails.c.parent_uuid == parent_uuid + ) for row in conn.execute(q): row = row._mapping yield self.convert_atom_detail(row) def flow_query_iter(self, conn, parent_uuid): - q = (sql.select(self._tables.flowdetails). 
- where(self._tables.flowdetails.c.parent_uuid == parent_uuid)) + q = sql.select(self._tables.flowdetails).where( + self._tables.flowdetails.c.parent_uuid == parent_uuid + ) for row in conn.execute(q): row = row._mapping yield self.convert_flow_detail(row) @@ -238,6 +245,7 @@ class SQLAlchemyBackend(base.Backend): "connection": "sqlite:////tmp/test.db", } """ + def __init__(self, conf, engine=None): super().__init__(conf) if engine is not None: @@ -275,24 +283,27 @@ class SQLAlchemyBackend(base.Backend): engine_args["poolclass"] = sa_pool.StaticPool engine_args["connect_args"] = {'check_same_thread': False} else: - for (k, lookup_key) in [('pool_size', 'max_pool_size'), - ('max_overflow', 'max_overflow'), - ('pool_timeout', 'pool_timeout')]: + for k, lookup_key in [ + ('pool_size', 'max_pool_size'), + ('max_overflow', 'max_overflow'), + ('pool_timeout', 'pool_timeout'), + ]: if lookup_key in conf: engine_args[k] = misc.as_int(conf.pop(lookup_key)) if 'isolation_level' not in conf: # Check driver name exact matches first, then try driver name # partial matches... 
- txn_isolation_levels = conf.pop('isolation_levels', - DEFAULT_TXN_ISOLATION_LEVELS) + txn_isolation_levels = conf.pop( + 'isolation_levels', DEFAULT_TXN_ISOLATION_LEVELS + ) level_applied = False - for (driver, level) in txn_isolation_levels.items(): + for driver, level in txn_isolation_levels.items(): if driver == e_url.drivername: engine_args['isolation_level'] = level level_applied = True break if not level_applied: - for (driver, level) in txn_isolation_levels.items(): + for driver, level in txn_isolation_levels.items(): if e_url.drivername.find(driver) != -1: engine_args['isolation_level'] = level break @@ -304,13 +315,17 @@ class SQLAlchemyBackend(base.Backend): engine = sa.create_engine(sql_connection, **engine_args) log_statements = conf.pop('log_statements', False) if _as_bool(log_statements): - log_statements_level = conf.pop("log_statements_level", - logging.TRACE) - sa.event.listen(engine, "before_cursor_execute", - functools.partial(_log_statements, - log_statements_level)) - checkin_yield = conf.pop('checkin_yield', - eventlet_utils.EVENTLET_AVAILABLE) + log_statements_level = conf.pop( + "log_statements_level", logging.TRACE + ) + sa.event.listen( + engine, + "before_cursor_execute", + functools.partial(_log_statements, log_statements_level), + ) + checkin_yield = conf.pop( + 'checkin_yield', eventlet_utils.EVENTLET_AVAILABLE + ) if _as_bool(checkin_yield): sa.event.listen(engine, 'checkin', _thread_yield) if 'mysql' in e_url.drivername: @@ -320,8 +335,9 @@ class SQLAlchemyBackend(base.Backend): if 'mysql_sql_mode' in conf: mode = conf.pop('mysql_sql_mode') if mode is not None: - sa.event.listen(engine, 'connect', - functools.partial(_set_sql_mode, mode)) + sa.event.listen( + engine, 'connect', functools.partial(_set_sql_mode, mode) + ) return engine @property @@ -362,13 +378,19 @@ class Connection(base.Connection): def _retry_on_exception(exc): LOG.warning("Engine connection (validate) failed due to '%s'", exc) - if isinstance(exc, 
sa_exc.OperationalError) and \ - _is_db_connection_error(str(exc.args[0])): + if isinstance( + exc, sa_exc.OperationalError + ) and _is_db_connection_error(str(exc.args[0])): # We may be able to fix this by retrying... return True - if isinstance(exc, (sa_exc.TimeoutError, - sa_exc.ResourceClosedError, - sa_exc.DisconnectionError)): + if isinstance( + exc, + ( + sa_exc.TimeoutError, + sa_exc.ResourceClosedError, + sa_exc.DisconnectionError, + ), + ): # We may be able to fix this by retrying... return True # Other failures we likely can't fix by retrying... @@ -378,7 +400,7 @@ class Connection(base.Connection): stop=tenacity.stop_after_attempt(max(0, int(max_retries))), wait=tenacity.wait_exponential(), reraise=True, - retry=tenacity.retry_if_exception(_retry_on_exception) + retry=tenacity.retry_if_exception(_retry_on_exception), ) def _try_connect(engine): # See if we can make a connection happen. @@ -408,8 +430,9 @@ class Connection(base.Connection): else: migration.db_sync(conn) except sa_exc.SQLAlchemyError: - exc.raise_with_cause(exc.StorageFailure, - "Failed upgrading database version") + exc.raise_with_cause( + exc.StorageFailure, "Failed upgrading database version" + ) def clear_all(self): try: @@ -417,27 +440,33 @@ class Connection(base.Connection): with self._engine.begin() as conn: conn.execute(logbooks.delete()) except sa_exc.DBAPIError: - exc.raise_with_cause(exc.StorageFailure, - "Failed clearing all entries") + exc.raise_with_cause( + exc.StorageFailure, "Failed clearing all entries" + ) def update_atom_details(self, atom_detail): try: atomdetails = self._tables.atomdetails with self._engine.begin() as conn: - q = (sql.select(atomdetails). 
- where(atomdetails.c.uuid == atom_detail.uuid)) + q = sql.select(atomdetails).where( + atomdetails.c.uuid == atom_detail.uuid + ) row = conn.execute(q).first() if not row: - raise exc.NotFound("No atom details found with uuid" - " '%s'" % atom_detail.uuid) + raise exc.NotFound( + "No atom details found with uuid" + " '%s'" % atom_detail.uuid + ) row = row._mapping e_ad = self._converter.convert_atom_detail(row) self._update_atom_details(conn, atom_detail, e_ad) return e_ad except sa_exc.SQLAlchemyError: - exc.raise_with_cause(exc.StorageFailure, - "Failed updating atom details" - " with uuid '%s'" % atom_detail.uuid) + exc.raise_with_cause( + exc.StorageFailure, + "Failed updating atom details" + " with uuid '%s'" % atom_detail.uuid, + ) def _insert_flow_details(self, conn, fd, parent_uuid): value = fd.to_dict() @@ -454,15 +483,19 @@ class Connection(base.Connection): def _update_atom_details(self, conn, ad, e_ad): e_ad.merge(ad) - conn.execute(sql.update(self._tables.atomdetails) - .where(self._tables.atomdetails.c.uuid == e_ad.uuid) - .values(e_ad.to_dict())) + conn.execute( + sql.update(self._tables.atomdetails) + .where(self._tables.atomdetails.c.uuid == e_ad.uuid) + .values(e_ad.to_dict()) + ) def _update_flow_details(self, conn, fd, e_fd): e_fd.merge(fd) - conn.execute(sql.update(self._tables.flowdetails) - .where(self._tables.flowdetails.c.uuid == e_fd.uuid) - .values(e_fd.to_dict())) + conn.execute( + sql.update(self._tables.flowdetails) + .where(self._tables.flowdetails.c.uuid == e_fd.uuid) + .values(e_fd.to_dict()) + ) for ad in fd: e_ad = e_fd.find(ad.uuid) if e_ad is None: @@ -475,21 +508,26 @@ class Connection(base.Connection): try: flowdetails = self._tables.flowdetails with self._engine.begin() as conn: - q = (sql.select(flowdetails). 
- where(flowdetails.c.uuid == flow_detail.uuid)) + q = sql.select(flowdetails).where( + flowdetails.c.uuid == flow_detail.uuid + ) row = conn.execute(q).first() if not row: - raise exc.NotFound("No flow details found with" - " uuid '%s'" % flow_detail.uuid) + raise exc.NotFound( + "No flow details found with" + " uuid '%s'" % flow_detail.uuid + ) row = row._mapping e_fd = self._converter.convert_flow_detail(row) self._converter.populate_flow_detail(conn, e_fd) self._update_flow_details(conn, flow_detail, e_fd) return e_fd except sa_exc.SQLAlchemyError: - exc.raise_with_cause(exc.StorageFailure, - "Failed updating flow details with" - " uuid '%s'" % flow_detail.uuid) + exc.raise_with_cause( + exc.StorageFailure, + "Failed updating flow details with" + " uuid '%s'" % flow_detail.uuid, + ) def destroy_logbook(self, book_uuid): try: @@ -498,27 +536,31 @@ class Connection(base.Connection): q = logbooks.delete().where(logbooks.c.uuid == book_uuid) r = conn.execute(q) if r.rowcount == 0: - raise exc.NotFound("No logbook found with" - " uuid '%s'" % book_uuid) + raise exc.NotFound( + "No logbook found with uuid '%s'" % book_uuid + ) except sa_exc.DBAPIError: - exc.raise_with_cause(exc.StorageFailure, - "Failed destroying logbook '%s'" % book_uuid) + exc.raise_with_cause( + exc.StorageFailure, + "Failed destroying logbook '%s'" % book_uuid, + ) def save_logbook(self, book): try: logbooks = self._tables.logbooks with self._engine.begin() as conn: - q = (sql.select(logbooks). 
- where(logbooks.c.uuid == book.uuid)) + q = sql.select(logbooks).where(logbooks.c.uuid == book.uuid) row = conn.execute(q).first() if row: row = row._mapping e_lb = self._converter.convert_book(row) self._converter.populate_book(conn, e_lb) e_lb.merge(book) - conn.execute(sql.update(logbooks) - .where(logbooks.c.uuid == e_lb.uuid) - .values(e_lb.to_dict())) + conn.execute( + sql.update(logbooks) + .where(logbooks.c.uuid == e_lb.uuid) + .values(e_lb.to_dict()) + ) for fd in book: e_fd = e_lb.find(fd.uuid) if e_fd is None: @@ -534,27 +576,28 @@ class Connection(base.Connection): return book except sa_exc.DBAPIError: exc.raise_with_cause( - exc.StorageFailure, - "Failed saving logbook '%s'" % book.uuid) + exc.StorageFailure, "Failed saving logbook '%s'" % book.uuid + ) def get_logbook(self, book_uuid, lazy=False): try: logbooks = self._tables.logbooks with self._engine.connect() as conn: - q = (sql.select(logbooks). - where(logbooks.c.uuid == book_uuid)) + q = sql.select(logbooks).where(logbooks.c.uuid == book_uuid) row = conn.execute(q).first() if not row: - raise exc.NotFound("No logbook found with" - " uuid '%s'" % book_uuid) + raise exc.NotFound( + "No logbook found with uuid '%s'" % book_uuid + ) row = row._mapping book = self._converter.convert_book(row) if not lazy: self._converter.populate_book(conn, book) return book except sa_exc.DBAPIError: - exc.raise_with_cause(exc.StorageFailure, - "Failed getting logbook '%s'" % book_uuid) + exc.raise_with_cause( + exc.StorageFailure, "Failed getting logbook '%s'" % book_uuid + ) def get_logbooks(self, lazy=False): gathered = [] @@ -568,8 +611,7 @@ class Connection(base.Connection): self._converter.populate_book(conn, book) gathered.append(book) except sa_exc.DBAPIError: - exc.raise_with_cause(exc.StorageFailure, - "Failed getting logbooks") + exc.raise_with_cause(exc.StorageFailure, "Failed getting logbooks") for book in gathered: yield book @@ -582,47 +624,54 @@ class Connection(base.Connection): 
self._converter.populate_flow_detail(conn, fd) gathered.append(fd) except sa_exc.DBAPIError: - exc.raise_with_cause(exc.StorageFailure, - "Failed getting flow details in" - " logbook '%s'" % book_uuid) + exc.raise_with_cause( + exc.StorageFailure, + "Failed getting flow details in logbook '%s'" % book_uuid, + ) yield from gathered def get_flow_details(self, fd_uuid, lazy=False): try: flowdetails = self._tables.flowdetails with self._engine.begin() as conn: - q = (sql.select(flowdetails). - where(flowdetails.c.uuid == fd_uuid)) + q = sql.select(flowdetails).where( + flowdetails.c.uuid == fd_uuid + ) row = conn.execute(q).first() if not row: - raise exc.NotFound("No flow details found with uuid" - " '%s'" % fd_uuid) + raise exc.NotFound( + "No flow details found with uuid '%s'" % fd_uuid + ) row = row._mapping fd = self._converter.convert_flow_detail(row) if not lazy: self._converter.populate_flow_detail(conn, fd) return fd except sa_exc.SQLAlchemyError: - exc.raise_with_cause(exc.StorageFailure, - "Failed getting flow details with" - " uuid '%s'" % fd_uuid) + exc.raise_with_cause( + exc.StorageFailure, + "Failed getting flow details with uuid '%s'" % fd_uuid, + ) def get_atom_details(self, ad_uuid): try: atomdetails = self._tables.atomdetails with self._engine.begin() as conn: - q = (sql.select(atomdetails). 
- where(atomdetails.c.uuid == ad_uuid)) + q = sql.select(atomdetails).where( + atomdetails.c.uuid == ad_uuid + ) row = conn.execute(q).first() if not row: - raise exc.NotFound("No atom details found with uuid" - " '%s'" % ad_uuid) + raise exc.NotFound( + "No atom details found with uuid '%s'" % ad_uuid + ) row = row._mapping return self._converter.convert_atom_detail(row) except sa_exc.SQLAlchemyError: - exc.raise_with_cause(exc.StorageFailure, - "Failed getting atom details with" - " uuid '%s'" % ad_uuid) + exc.raise_with_cause( + exc.StorageFailure, + "Failed getting atom details with uuid '%s'" % ad_uuid, + ) def get_atoms_for_flow(self, fd_uuid): gathered = [] @@ -631,9 +680,10 @@ class Connection(base.Connection): for ad in self._converter.atom_query_iter(conn, fd_uuid): gathered.append(ad) except sa_exc.DBAPIError: - exc.raise_with_cause(exc.StorageFailure, - "Failed getting atom details in flow" - " detail '%s'" % fd_uuid) + exc.raise_with_cause( + exc.StorageFailure, + "Failed getting atom details in flow detail '%s'" % fd_uuid, + ) yield from gathered def close(self): diff --git a/taskflow/persistence/backends/impl_zookeeper.py b/taskflow/persistence/backends/impl_zookeeper.py index 36c1e2870..31c0e9650 100644 --- a/taskflow/persistence/backends/impl_zookeeper.py +++ b/taskflow/persistence/backends/impl_zookeeper.py @@ -79,8 +79,9 @@ class ZkBackend(path_based.PathBasedBackend): try: k_utils.finalize_client(self._client) except (k_exc.KazooException, k_exc.ZookeeperError): - exc.raise_with_cause(exc.StorageFailure, - "Unable to finalize client") + exc.raise_with_cause( + exc.StorageFailure, "Unable to finalize client" + ) class ZkConnection(path_based.PathBasedConnection): @@ -103,20 +104,23 @@ class ZkConnection(path_based.PathBasedConnection): try: yield except self._client.handler.timeout_exception: - exc.raise_with_cause(exc.StorageFailure, - "Storage backend timeout") + exc.raise_with_cause(exc.StorageFailure, "Storage backend timeout") except 
k_exc.SessionExpiredError: - exc.raise_with_cause(exc.StorageFailure, - "Storage backend session has expired") + exc.raise_with_cause( + exc.StorageFailure, "Storage backend session has expired" + ) except k_exc.NoNodeError: - exc.raise_with_cause(exc.NotFound, - "Storage backend node not found") + exc.raise_with_cause( + exc.NotFound, "Storage backend node not found" + ) except k_exc.NodeExistsError: - exc.raise_with_cause(exc.Duplicate, - "Storage backend duplicate node") + exc.raise_with_cause( + exc.Duplicate, "Storage backend duplicate node" + ) except (k_exc.KazooException, k_exc.ZookeeperError): - exc.raise_with_cause(exc.StorageFailure, - "Storage backend internal error") + exc.raise_with_cause( + exc.StorageFailure, "Storage backend internal error" + ) def _join_path(self, *parts): return paths.join(*parts) @@ -161,8 +165,11 @@ class ZkConnection(path_based.PathBasedConnection): with self._exc_wrapper(): try: if strutils.bool_from_string( - self._conf.get('check_compatible'), default=True): + self._conf.get('check_compatible'), default=True + ): k_utils.check_compatible(self._client, MIN_ZK_VERSION) except exc.IncompatibleVersion: - exc.raise_with_cause(exc.StorageFailure, "Backend storage is" - " not a compatible version") + exc.raise_with_cause( + exc.StorageFailure, + "Backend storage is not a compatible version", + ) diff --git a/taskflow/persistence/backends/sqlalchemy/alembic/env.py b/taskflow/persistence/backends/sqlalchemy/alembic/env.py index 44079d02b..82d2ee675 100644 --- a/taskflow/persistence/backends/sqlalchemy/alembic/env.py +++ b/taskflow/persistence/backends/sqlalchemy/alembic/env.py @@ -60,19 +60,23 @@ def run_migrations_online(): if connectable is None: connectable = engine_from_config( config.get_section(config.config_ini_section), - prefix='sqlalchemy.', poolclass=pool.NullPool) + prefix='sqlalchemy.', + poolclass=pool.NullPool, + ) with connectable.connect() as connection: - context.configure(connection=connection, - 
target_metadata=target_metadata) + context.configure( + connection=connection, target_metadata=target_metadata + ) with context.begin_transaction(): context.run_migrations() else: context.configure( - connection=connectable, - target_metadata=target_metadata) + connection=connectable, target_metadata=target_metadata + ) with context.begin_transaction(): context.run_migrations() + if context.is_offline_mode(): run_migrations_offline() else: diff --git a/taskflow/persistence/backends/sqlalchemy/alembic/versions/00af93df9d77_add_unique_into_all_indexes.py b/taskflow/persistence/backends/sqlalchemy/alembic/versions/00af93df9d77_add_unique_into_all_indexes.py index 69b7e7f86..cee19e2de 100644 --- a/taskflow/persistence/backends/sqlalchemy/alembic/versions/00af93df9d77_add_unique_into_all_indexes.py +++ b/taskflow/persistence/backends/sqlalchemy/alembic/versions/00af93df9d77_add_unique_into_all_indexes.py @@ -32,23 +32,24 @@ def upgrade(): with op.batch_alter_table("logbooks") as batch_op: batch_op.drop_index("logbook_uuid_idx") batch_op.create_index( - index_name="logbook_uuid_idx", - columns=['uuid'], - unique=True) + index_name="logbook_uuid_idx", columns=['uuid'], unique=True + ) with op.batch_alter_table("flowdetails") as batch_op: batch_op.drop_index("flowdetails_uuid_idx") batch_op.create_index( index_name="flowdetails_uuid_idx", columns=['uuid'], - unique=True) + unique=True, + ) with op.batch_alter_table("atomdetails") as batch_op: batch_op.drop_index("taskdetails_uuid_idx") batch_op.create_index( index_name="taskdetails_uuid_idx", columns=['uuid'], - unique=True) + unique=True, + ) def downgrade(): @@ -58,17 +59,17 @@ def downgrade(): with op.batch_alter_table("logbooks") as batch_op: batch_op.drop_index("logbook_uuid_idx") batch_op.create_index( - index_name="logbook_uuid_idx", - columns=['uuid']) + index_name="logbook_uuid_idx", columns=['uuid'] + ) with op.batch_alter_table("flowdetails") as batch_op: batch_op.drop_index("flowdetails_uuid_idx") 
batch_op.create_index( - index_name="flowdetails_uuid_idx", - columns=['uuid']) + index_name="flowdetails_uuid_idx", columns=['uuid'] + ) with op.batch_alter_table("atomdetails") as batch_op: batch_op.drop_index("taskdetails_uuid_idx") batch_op.create_index( - index_name="taskdetails_uuid_idx", - columns=['uuid']) + index_name="taskdetails_uuid_idx", columns=['uuid'] + ) diff --git a/taskflow/persistence/backends/sqlalchemy/alembic/versions/0bc3e1a3c135_set_result_meduimtext_type.py b/taskflow/persistence/backends/sqlalchemy/alembic/versions/0bc3e1a3c135_set_result_meduimtext_type.py index d1a5e6886..2816d4407 100644 --- a/taskflow/persistence/backends/sqlalchemy/alembic/versions/0bc3e1a3c135_set_result_meduimtext_type.py +++ b/taskflow/persistence/backends/sqlalchemy/alembic/versions/0bc3e1a3c135_set_result_meduimtext_type.py @@ -31,13 +31,18 @@ def upgrade(): bind = op.get_bind() engine = bind.engine if engine.name == 'mysql': - op.alter_column('atomdetails', 'results', type_=mysql.LONGTEXT, - existing_nullable=True) + op.alter_column( + 'atomdetails', + 'results', + type_=mysql.LONGTEXT, + existing_nullable=True, + ) def downgrade(): bind = op.get_bind() engine = bind.engine if engine.name == 'mysql': - op.alter_column('atomdetails', 'results', type_=sa.Text(), - existing_nullable=True) + op.alter_column( + 'atomdetails', 'results', type_=sa.Text(), existing_nullable=True + ) diff --git a/taskflow/persistence/backends/sqlalchemy/alembic/versions/14b227d79a87_add_intention_column.py b/taskflow/persistence/backends/sqlalchemy/alembic/versions/14b227d79a87_add_intention_column.py index 2469ae69c..dbf6ad30c 100644 --- a/taskflow/persistence/backends/sqlalchemy/alembic/versions/14b227d79a87_add_intention_column.py +++ b/taskflow/persistence/backends/sqlalchemy/alembic/versions/14b227d79a87_add_intention_column.py @@ -25,8 +25,9 @@ from taskflow import states def upgrade(): bind = op.get_bind() intention_type = sa.Enum(*states.INTENTIONS, name='intention_type') - 
column = sa.Column('intention', intention_type, - server_default=states.EXECUTE) + column = sa.Column( + 'intention', intention_type, server_default=states.EXECUTE + ) impl = intention_type.dialect_impl(bind.dialect) impl.create(bind, checkfirst=True) op.add_column('taskdetails', column) diff --git a/taskflow/persistence/backends/sqlalchemy/alembic/versions/1c783c0c2875_replace_exception_an.py b/taskflow/persistence/backends/sqlalchemy/alembic/versions/1c783c0c2875_replace_exception_an.py index 58ab67e8a..76fb5574c 100644 --- a/taskflow/persistence/backends/sqlalchemy/alembic/versions/1c783c0c2875_replace_exception_an.py +++ b/taskflow/persistence/backends/sqlalchemy/alembic/versions/1c783c0c2875_replace_exception_an.py @@ -29,15 +29,18 @@ import sqlalchemy as sa def upgrade(): - op.add_column('taskdetails', - sa.Column('failure', sa.Text(), nullable=True)) + op.add_column( + 'taskdetails', sa.Column('failure', sa.Text(), nullable=True) + ) op.drop_column('taskdetails', 'exception') op.drop_column('taskdetails', 'stacktrace') def downgrade(): op.drop_column('taskdetails', 'failure') - op.add_column('taskdetails', - sa.Column('stacktrace', sa.Text(), nullable=True)) - op.add_column('taskdetails', - sa.Column('exception', sa.Text(), nullable=True)) + op.add_column( + 'taskdetails', sa.Column('stacktrace', sa.Text(), nullable=True) + ) + op.add_column( + 'taskdetails', sa.Column('exception', sa.Text(), nullable=True) + ) diff --git a/taskflow/persistence/backends/sqlalchemy/alembic/versions/1cea328f0f65_initial_logbook_deta.py b/taskflow/persistence/backends/sqlalchemy/alembic/versions/1cea328f0f65_initial_logbook_deta.py index 7086ab117..df39363a6 100644 --- a/taskflow/persistence/backends/sqlalchemy/alembic/versions/1cea328f0f65_initial_logbook_deta.py +++ b/taskflow/persistence/backends/sqlalchemy/alembic/versions/1cea328f0f65_initial_logbook_deta.py @@ -82,50 +82,65 @@ def _get_foreign_keys(): def upgrade(): - op.create_table('logbooks', - sa.Column('created_at', 
sa.DateTime), - sa.Column('updated_at', sa.DateTime), - sa.Column('meta', sa.Text(), nullable=True), - sa.Column('name', sa.String(length=tables.NAME_LENGTH), - nullable=True), - sa.Column('uuid', sa.String(length=tables.UUID_LENGTH), - primary_key=True, nullable=False), - mysql_engine='InnoDB', - mysql_charset='utf8') - op.create_table('flowdetails', - sa.Column('created_at', sa.DateTime), - sa.Column('updated_at', sa.DateTime), - sa.Column('parent_uuid', - sa.String(length=tables.UUID_LENGTH)), - sa.Column('meta', sa.Text(), nullable=True), - sa.Column('state', sa.String(length=tables.STATE_LENGTH), - nullable=True), - sa.Column('name', sa.String(length=tables.NAME_LENGTH), - nullable=True), - sa.Column('uuid', sa.String(length=tables.UUID_LENGTH), - primary_key=True, nullable=False), - mysql_engine='InnoDB', - mysql_charset='utf8') - op.create_table('taskdetails', - sa.Column('created_at', sa.DateTime), - sa.Column('updated_at', sa.DateTime), - sa.Column('parent_uuid', - sa.String(length=tables.UUID_LENGTH)), - sa.Column('meta', sa.Text(), nullable=True), - sa.Column('name', sa.String(length=tables.NAME_LENGTH), - nullable=True), - sa.Column('results', sa.Text(), nullable=True), - sa.Column('version', - sa.String(length=tables.VERSION_LENGTH), - nullable=True), - sa.Column('stacktrace', sa.Text(), nullable=True), - sa.Column('exception', sa.Text(), nullable=True), - sa.Column('state', sa.String(length=tables.STATE_LENGTH), - nullable=True), - sa.Column('uuid', sa.String(length=tables.UUID_LENGTH), - primary_key=True, nullable=False), - mysql_engine='InnoDB', - mysql_charset='utf8') + op.create_table( + 'logbooks', + sa.Column('created_at', sa.DateTime), + sa.Column('updated_at', sa.DateTime), + sa.Column('meta', sa.Text(), nullable=True), + sa.Column('name', sa.String(length=tables.NAME_LENGTH), nullable=True), + sa.Column( + 'uuid', + sa.String(length=tables.UUID_LENGTH), + primary_key=True, + nullable=False, + ), + mysql_engine='InnoDB', + 
mysql_charset='utf8', + ) + op.create_table( + 'flowdetails', + sa.Column('created_at', sa.DateTime), + sa.Column('updated_at', sa.DateTime), + sa.Column('parent_uuid', sa.String(length=tables.UUID_LENGTH)), + sa.Column('meta', sa.Text(), nullable=True), + sa.Column( + 'state', sa.String(length=tables.STATE_LENGTH), nullable=True + ), + sa.Column('name', sa.String(length=tables.NAME_LENGTH), nullable=True), + sa.Column( + 'uuid', + sa.String(length=tables.UUID_LENGTH), + primary_key=True, + nullable=False, + ), + mysql_engine='InnoDB', + mysql_charset='utf8', + ) + op.create_table( + 'taskdetails', + sa.Column('created_at', sa.DateTime), + sa.Column('updated_at', sa.DateTime), + sa.Column('parent_uuid', sa.String(length=tables.UUID_LENGTH)), + sa.Column('meta', sa.Text(), nullable=True), + sa.Column('name', sa.String(length=tables.NAME_LENGTH), nullable=True), + sa.Column('results', sa.Text(), nullable=True), + sa.Column( + 'version', sa.String(length=tables.VERSION_LENGTH), nullable=True + ), + sa.Column('stacktrace', sa.Text(), nullable=True), + sa.Column('exception', sa.Text(), nullable=True), + sa.Column( + 'state', sa.String(length=tables.STATE_LENGTH), nullable=True + ), + sa.Column( + 'uuid', + sa.String(length=tables.UUID_LENGTH), + primary_key=True, + nullable=False, + ), + mysql_engine='InnoDB', + mysql_charset='utf8', + ) try: for fkey_descriptor in _get_foreign_keys(): op.create_foreign_key(**fkey_descriptor) diff --git a/taskflow/persistence/backends/sqlalchemy/alembic/versions/2ad4984f2864_switch_postgres_to_json_native.py b/taskflow/persistence/backends/sqlalchemy/alembic/versions/2ad4984f2864_switch_postgres_to_json_native.py index aa923ea9a..70fa5af5b 100644 --- a/taskflow/persistence/backends/sqlalchemy/alembic/versions/2ad4984f2864_switch_postgres_to_json_native.py +++ b/taskflow/persistence/backends/sqlalchemy/alembic/versions/2ad4984f2864_switch_postgres_to_json_native.py @@ -28,22 +28,24 @@ from alembic import op _ALTER_TO_JSON_TPL = 'ALTER 
TABLE %s ALTER COLUMN %s TYPE JSON USING %s::JSON' -_TABLES_COLS = tuple([ - ('logbooks', 'meta'), - ('flowdetails', 'meta'), - ('atomdetails', 'meta'), - ('atomdetails', 'failure'), - ('atomdetails', 'revert_failure'), - ('atomdetails', 'results'), - ('atomdetails', 'revert_results'), -]) +_TABLES_COLS = tuple( + [ + ('logbooks', 'meta'), + ('flowdetails', 'meta'), + ('atomdetails', 'meta'), + ('atomdetails', 'failure'), + ('atomdetails', 'revert_failure'), + ('atomdetails', 'results'), + ('atomdetails', 'revert_results'), + ] +) _ALTER_TO_TEXT_TPL = 'ALTER TABLE %s ALTER COLUMN %s TYPE TEXT' def upgrade(): b = op.get_bind() if b.dialect.name.startswith('postgresql'): - for (table_name, col_name) in _TABLES_COLS: + for table_name, col_name in _TABLES_COLS: q = _ALTER_TO_JSON_TPL % (table_name, col_name, col_name) op.execute(q) @@ -51,6 +53,6 @@ def upgrade(): def downgrade(): b = op.get_bind() if b.dialect.name.startswith('postgresql'): - for (table_name, col_name) in _TABLES_COLS: + for table_name, col_name in _TABLES_COLS: q = _ALTER_TO_TEXT_TPL % (table_name, col_name) op.execute(q) diff --git a/taskflow/persistence/backends/sqlalchemy/alembic/versions/3162c0f3f8e4_add_revert_results_and_revert_failure_.py b/taskflow/persistence/backends/sqlalchemy/alembic/versions/3162c0f3f8e4_add_revert_results_and_revert_failure_.py index 14cbc5374..14db30a4e 100644 --- a/taskflow/persistence/backends/sqlalchemy/alembic/versions/3162c0f3f8e4_add_revert_results_and_revert_failure_.py +++ b/taskflow/persistence/backends/sqlalchemy/alembic/versions/3162c0f3f8e4_add_revert_results_and_revert_failure_.py @@ -29,10 +29,12 @@ import sqlalchemy as sa def upgrade(): - op.add_column('atomdetails', - sa.Column('revert_results', sa.Text(), nullable=True)) - op.add_column('atomdetails', - sa.Column('revert_failure', sa.Text(), nullable=True)) + op.add_column( + 'atomdetails', sa.Column('revert_results', sa.Text(), nullable=True) + ) + op.add_column( + 'atomdetails', 
sa.Column('revert_failure', sa.Text(), nullable=True) + ) def downgrade(): diff --git a/taskflow/persistence/backends/sqlalchemy/alembic/versions/40fc8c914bd2_fix_atomdetails_failure_size.py b/taskflow/persistence/backends/sqlalchemy/alembic/versions/40fc8c914bd2_fix_atomdetails_failure_size.py index ec50dd6a2..628a733b8 100644 --- a/taskflow/persistence/backends/sqlalchemy/alembic/versions/40fc8c914bd2_fix_atomdetails_failure_size.py +++ b/taskflow/persistence/backends/sqlalchemy/alembic/versions/40fc8c914bd2_fix_atomdetails_failure_size.py @@ -30,7 +30,15 @@ def upgrade(): bind = op.get_bind() engine = bind.engine if engine.name == 'mysql': - op.alter_column('atomdetails', 'failure', type_=mysql.LONGTEXT, - existing_nullable=True) - op.alter_column('atomdetails', 'revert_failure', type_=mysql.LONGTEXT, - existing_nullable=True) + op.alter_column( + 'atomdetails', + 'failure', + type_=mysql.LONGTEXT, + existing_nullable=True, + ) + op.alter_column( + 'atomdetails', + 'revert_failure', + type_=mysql.LONGTEXT, + existing_nullable=True, + ) diff --git a/taskflow/persistence/backends/sqlalchemy/alembic/versions/6df9422fcb43_fix_flowdetails_meta_size.py b/taskflow/persistence/backends/sqlalchemy/alembic/versions/6df9422fcb43_fix_flowdetails_meta_size.py index 4adb945c3..82f77b96c 100644 --- a/taskflow/persistence/backends/sqlalchemy/alembic/versions/6df9422fcb43_fix_flowdetails_meta_size.py +++ b/taskflow/persistence/backends/sqlalchemy/alembic/versions/6df9422fcb43_fix_flowdetails_meta_size.py @@ -30,5 +30,6 @@ def upgrade(): bind = op.get_bind() engine = bind.engine if engine.name == 'mysql': - op.alter_column('flowdetails', 'meta', type_=mysql.LONGTEXT, - existing_nullable=True) + op.alter_column( + 'flowdetails', 'meta', type_=mysql.LONGTEXT, existing_nullable=True + ) diff --git a/taskflow/persistence/backends/sqlalchemy/tables.py b/taskflow/persistence/backends/sqlalchemy/tables.py index c3aa4c607..dcd51d6c4 100644 --- 
a/taskflow/persistence/backends/sqlalchemy/tables.py +++ b/taskflow/persistence/backends/sqlalchemy/tables.py @@ -23,8 +23,9 @@ from sqlalchemy_utils.types import json as json_type from taskflow.persistence import models from taskflow import states -Tables = collections.namedtuple('Tables', - ['logbooks', 'flowdetails', 'atomdetails']) +Tables = collections.namedtuple( + 'Tables', ['logbooks', 'flowdetails', 'atomdetails'] +) # Column length limits... NAME_LENGTH = 255 @@ -53,51 +54,71 @@ class JSONType(json_type.JSONType): def fetch(metadata): """Returns the master set of table objects (which is also there schema).""" - logbooks = Table('logbooks', metadata, - Column('created_at', DateTime, - default=timeutils.utcnow), - Column('updated_at', DateTime, - onupdate=timeutils.utcnow), - Column('meta', JSONType), - Column('name', String(length=NAME_LENGTH)), - Column('uuid', String(length=UUID_LENGTH), - primary_key=True, nullable=False, unique=True, - default=uuidutils.generate_uuid)) - flowdetails = Table('flowdetails', metadata, - Column('created_at', DateTime, - default=timeutils.utcnow), - Column('updated_at', DateTime, - onupdate=timeutils.utcnow), - Column('parent_uuid', String(length=UUID_LENGTH), - ForeignKey('logbooks.uuid', - ondelete='CASCADE')), - Column('meta', JSONType), - Column('name', String(length=NAME_LENGTH)), - Column('state', String(length=STATE_LENGTH)), - Column('uuid', String(length=UUID_LENGTH), - primary_key=True, nullable=False, unique=True, - default=uuidutils.generate_uuid)) - atomdetails = Table('atomdetails', metadata, - Column('created_at', DateTime, - default=timeutils.utcnow), - Column('updated_at', DateTime, - onupdate=timeutils.utcnow), - Column('meta', JSONType), - Column('parent_uuid', String(length=UUID_LENGTH), - ForeignKey('flowdetails.uuid', - ondelete='CASCADE')), - Column('name', String(length=NAME_LENGTH)), - Column('version', String(length=VERSION_LENGTH)), - Column('state', String(length=STATE_LENGTH)), - Column('uuid', 
String(length=UUID_LENGTH), - primary_key=True, nullable=False, unique=True, - default=uuidutils.generate_uuid), - Column('failure', JSONType), - Column('results', JSONType), - Column('revert_results', JSONType), - Column('revert_failure', JSONType), - Column('atom_type', Enum(*models.ATOM_TYPES, - name='atom_types')), - Column('intention', Enum(*states.INTENTIONS, - name='intentions'))) + logbooks = Table( + 'logbooks', + metadata, + Column('created_at', DateTime, default=timeutils.utcnow), + Column('updated_at', DateTime, onupdate=timeutils.utcnow), + Column('meta', JSONType), + Column('name', String(length=NAME_LENGTH)), + Column( + 'uuid', + String(length=UUID_LENGTH), + primary_key=True, + nullable=False, + unique=True, + default=uuidutils.generate_uuid, + ), + ) + flowdetails = Table( + 'flowdetails', + metadata, + Column('created_at', DateTime, default=timeutils.utcnow), + Column('updated_at', DateTime, onupdate=timeutils.utcnow), + Column( + 'parent_uuid', + String(length=UUID_LENGTH), + ForeignKey('logbooks.uuid', ondelete='CASCADE'), + ), + Column('meta', JSONType), + Column('name', String(length=NAME_LENGTH)), + Column('state', String(length=STATE_LENGTH)), + Column( + 'uuid', + String(length=UUID_LENGTH), + primary_key=True, + nullable=False, + unique=True, + default=uuidutils.generate_uuid, + ), + ) + atomdetails = Table( + 'atomdetails', + metadata, + Column('created_at', DateTime, default=timeutils.utcnow), + Column('updated_at', DateTime, onupdate=timeutils.utcnow), + Column('meta', JSONType), + Column( + 'parent_uuid', + String(length=UUID_LENGTH), + ForeignKey('flowdetails.uuid', ondelete='CASCADE'), + ), + Column('name', String(length=NAME_LENGTH)), + Column('version', String(length=VERSION_LENGTH)), + Column('state', String(length=STATE_LENGTH)), + Column( + 'uuid', + String(length=UUID_LENGTH), + primary_key=True, + nullable=False, + unique=True, + default=uuidutils.generate_uuid, + ), + Column('failure', JSONType), + Column('results', 
JSONType), + Column('revert_results', JSONType), + Column('revert_failure', JSONType), + Column('atom_type', Enum(*models.ATOM_TYPES, name='atom_types')), + Column('intention', Enum(*states.INTENTIONS, name='intentions')), + ) return Tables(logbooks, flowdetails, atomdetails) diff --git a/taskflow/persistence/base.py b/taskflow/persistence/base.py index 136882a1f..f969d9a2b 100644 --- a/taskflow/persistence/base.py +++ b/taskflow/persistence/base.py @@ -24,8 +24,10 @@ class Backend(metaclass=abc.ABCMeta): if not conf: conf = {} if not isinstance(conf, dict): - raise TypeError("Configuration dictionary expected not '%s' (%s)" - % (conf, type(conf))) + raise TypeError( + "Configuration dictionary expected not '%s' (%s)" + % (conf, type(conf)) + ) self._conf = conf @abc.abstractmethod diff --git a/taskflow/persistence/models.py b/taskflow/persistence/models.py index 8edcd4d18..1ed9fa73b 100644 --- a/taskflow/persistence/models.py +++ b/taskflow/persistence/models.py @@ -36,7 +36,7 @@ def _format_meta(metadata, indent): lines = [ '%s- metadata:' % (" " * indent), ] - for (k, v) in metadata.items(): + for k, v in metadata.items(): # Progress for now is a special snowflake and will be formatted # in percent format. if k == 'progress' and isinstance(v, misc.NUMERIC_TYPES): @@ -53,8 +53,11 @@ def _format_shared(obj, indent): for attr_name in ("uuid", "state"): if not hasattr(obj, attr_name): continue - lines.append("{}- {} = {}".format(" " * indent, attr_name, - getattr(obj, attr_name))) + lines.append( + "{}- {} = {}".format( + " " * indent, attr_name, getattr(obj, attr_name) + ) + ) return lines @@ -119,6 +122,7 @@ class LogBook: was last updated at. :ivar meta: A dictionary of meta-data associated with this logbook. 
""" + def __init__(self, name, uuid=None): if uuid: self._uuid = uuid @@ -145,16 +149,19 @@ class LogBook: lines.extend(_format_shared(self, indent=indent + 1)) lines.extend(_format_meta(self.meta, indent=indent + 1)) if self.created_at is not None: - lines.append("%s- created_at = %s" - % (" " * (indent + 1), - self.created_at.isoformat())) + lines.append( + "%s- created_at = %s" + % (" " * (indent + 1), self.created_at.isoformat()) + ) if self.updated_at is not None: - lines.append("%s- updated_at = %s" - % (" " * (indent + 1), - self.updated_at.isoformat())) + lines.append( + "%s- updated_at = %s" + % (" " * (indent + 1), self.updated_at.isoformat()) + ) for flow_detail in self: - lines.append(flow_detail.pformat(indent=indent + 1, - linesep=linesep)) + lines.append( + flow_detail.pformat(indent=indent + 1, linesep=linesep) + ) return linesep.join(lines) def add(self, fd): @@ -299,6 +306,7 @@ class FlowDetail: :ivar meta: A dictionary of meta-data associated with this flow detail. """ + def __init__(self, name, uuid): self._uuid = uuid self._name = name @@ -334,8 +342,9 @@ class FlowDetail: >>> from oslo_utils import uuidutils >>> from taskflow.persistence import models - >>> flow_detail = models.FlowDetail("example", - ... uuid=uuidutils.generate_uuid()) + >>> flow_detail = models.FlowDetail( + ... "example", uuid=uuidutils.generate_uuid() + ... ) >>> print(flow_detail.pformat()) FlowDetail: 'example' - uuid = ... 
@@ -346,8 +355,9 @@ class FlowDetail: lines.extend(_format_shared(self, indent=indent + 1)) lines.extend(_format_meta(self.meta, indent=indent + 1)) for atom_detail in self: - lines.append(atom_detail.pformat(indent=indent + 1, - linesep=linesep)) + lines.append( + atom_detail.pformat(indent=indent + 1, linesep=linesep) + ) return linesep.join(lines) def merge(self, fd, deep_copy=False): @@ -686,12 +696,14 @@ class AtomDetail(metaclass=abc.ABCMeta): cls_name = self.__class__.__name__ lines = ["{}{}: '{}'".format(" " * (indent), cls_name, self.name)] lines.extend(_format_shared(self, indent=indent + 1)) - lines.append("%s- version = %s" - % (" " * (indent + 1), misc.get_version_string(self))) - lines.append("%s- results = %s" - % (" " * (indent + 1), self.results)) - lines.append("{}- failure = {}".format(" " * (indent + 1), - bool(self.failure))) + lines.append( + "%s- version = %s" + % (" " * (indent + 1), misc.get_version_string(self)) + ) + lines.append("%s- results = %s" % (" " * (indent + 1), self.results)) + lines.append( + "{}- failure = {}".format(" " * (indent + 1), bool(self.failure)) + ) lines.extend(_format_meta(self.meta, indent=indent + 1)) return linesep.join(lines) @@ -738,15 +750,17 @@ class TaskDetail(AtomDetail): if self.failure != result: self.failure = result was_altered = True - if not _is_all_none(self.results, self.revert_results, - self.revert_failure): + if not _is_all_none( + self.results, self.revert_results, self.revert_failure + ): self.results = None self.revert_results = None self.revert_failure = None was_altered = True elif state == states.SUCCESS: - if not _is_all_none(self.revert_results, self.revert_failure, - self.failure): + if not _is_all_none( + self.revert_results, self.revert_failure, self.failure + ): self.revert_results = None self.revert_failure = None self.failure = None @@ -785,8 +799,9 @@ class TaskDetail(AtomDetail): :rtype: :py:class:`.TaskDetail` """ if not isinstance(other, TaskDetail): - raise 
exc.NotImplementedError("Can only merge with other" - " task details") + raise exc.NotImplementedError( + "Can only merge with other task details" + ) if other is self: return self super().merge(other, deep_copy=deep_copy) @@ -879,9 +894,9 @@ class RetryDetail(AtomDetail): results = [] # NOTE(imelnikov): we can't just deep copy Failures, as they # contain tracebacks, which are not copyable. - for (data, failures) in self.results: + for data, failures in self.results: copied_failures = {} - for (key, failure) in failures.items(): + for key, failure in failures.items(): copied_failures[key] = failure results.append((data, copied_failures)) clone.results = results @@ -942,8 +957,9 @@ class RetryDetail(AtomDetail): self.revert_failure = None was_altered = True elif state == states.SUCCESS: - if not _is_all_none(self.failure, self.revert_failure, - self.revert_results): + if not _is_all_none( + self.failure, self.revert_failure, self.revert_results + ): self.failure = None self.revert_failure = None self.revert_results = None @@ -972,9 +988,9 @@ class RetryDetail(AtomDetail): if not results: return [] new_results = [] - for (data, failures) in results: + for data, failures in results: new_failures = {} - for (key, data) in failures.items(): + for key, data in failures.items(): new_failures[key] = ft.Failure.from_dict(data) new_results.append((data, new_failures)) return new_results @@ -990,9 +1006,9 @@ class RetryDetail(AtomDetail): if not results: return [] new_results = [] - for (data, failures) in results: + for data, failures in results: new_failures = {} - for (key, failure) in failures.items(): + for key, failure in failures.items(): new_failures[key] = failure.to_dict() new_results.append((data, new_failures)) return new_results @@ -1025,17 +1041,18 @@ class RetryDetail(AtomDetail): :rtype: :py:class:`.RetryDetail` """ if not isinstance(other, RetryDetail): - raise exc.NotImplementedError("Can only merge with other" - " retry details") + raise 
exc.NotImplementedError( + "Can only merge with other retry details" + ) if other is self: return self super().merge(other, deep_copy=deep_copy) results = [] # NOTE(imelnikov): we can't just deep copy Failures, as they # contain tracebacks, which are not copyable. - for (data, failures) in other.results: + for data, failures in other.results: copied_failures = {} - for (key, failure) in failures.items(): + for key, failure in failures.items(): if deep_copy: copied_failures[key] = failure.copy() else: @@ -1064,5 +1081,6 @@ def atom_detail_type(atom_detail): try: return _DETAIL_TO_NAME[type(atom_detail)] except KeyError: - raise TypeError("Unknown atom '%s' (%s)" - % (atom_detail, type(atom_detail))) + raise TypeError( + "Unknown atom '%s' (%s)" % (atom_detail, type(atom_detail)) + ) diff --git a/taskflow/persistence/path_based.py b/taskflow/persistence/path_based.py index 44fa874e2..e308c9bb6 100644 --- a/taskflow/persistence/path_based.py +++ b/taskflow/persistence/path_based.py @@ -167,8 +167,9 @@ class PathBasedConnection(base.Connection, metaclass=abc.ABCMeta): for flow_details in book: flow_path = self._get_obj_path(flow_details) link_path = self._join_path(book_path, flow_details.uuid) - self._do_update_flow_details(flow_details, transaction, - ignore_missing=True) + self._do_update_flow_details( + flow_details, transaction, ignore_missing=True + ) self._create_link(flow_path, link_path, transaction) return book @@ -186,11 +187,13 @@ class PathBasedConnection(base.Connection, metaclass=abc.ABCMeta): flow_details.add(atom_details) return flow_details - def _do_update_flow_details(self, flow_detail, transaction, - ignore_missing=False): + def _do_update_flow_details( + self, flow_detail, transaction, ignore_missing=False + ): flow_path = self._get_obj_path(flow_detail) - self._update_object(flow_detail, transaction, - ignore_missing=ignore_missing) + self._update_object( + flow_detail, transaction, ignore_missing=ignore_missing + ) for atom_details in 
flow_detail: atom_path = self._get_obj_path(atom_details) link_path = self._join_path(flow_path, atom_details.uuid) @@ -200,8 +203,9 @@ class PathBasedConnection(base.Connection, metaclass=abc.ABCMeta): def update_flow_details(self, flow_detail, ignore_missing=False): with self._transaction() as transaction: - return self._do_update_flow_details(flow_detail, transaction, - ignore_missing=ignore_missing) + return self._do_update_flow_details( + flow_detail, transaction, ignore_missing=ignore_missing + ) def get_atoms_for_flow(self, flow_uuid): flow_path = self._join_path(self.flow_path, flow_uuid) @@ -215,8 +219,9 @@ class PathBasedConnection(base.Connection, metaclass=abc.ABCMeta): def update_atom_details(self, atom_detail, ignore_missing=False): with self._transaction() as transaction: - return self._update_object(atom_detail, transaction, - ignore_missing=ignore_missing) + return self._update_object( + atom_detail, transaction, ignore_missing=ignore_missing + ) def _do_destroy_logbook(self, book_uuid, transaction): book_path = self._join_path(self.book_path, book_uuid) diff --git a/taskflow/retry.py b/taskflow/retry.py index c423ef860..b39c475ad 100644 --- a/taskflow/retry.py +++ b/taskflow/retry.py @@ -54,6 +54,7 @@ class Decision(misc.StrEnum): #: Retries the surrounding/associated subflow again. RETRY = "RETRY" + # Retain these aliases for a number of releases... 
REVERT = Decision.REVERT REVERT_ALL = Decision.REVERT_ALL @@ -96,7 +97,7 @@ class History: contents = [ self._contents[index], ] - for (provided, outcomes) in contents: + for provided, outcomes in contents: yield from outcomes.items() def __len__(self): @@ -104,7 +105,7 @@ class History: def provided_iter(self): """Iterates over all the values the retry has attempted (in order).""" - for (provided, outcomes) in self._contents: + for provided, outcomes in self._contents: yield provided def __getitem__(self, index): @@ -119,7 +120,7 @@ class History: to false) will the potential retries own failure be checked against as well. """ - for (name, failure) in self.outcomes_iter(index=index): + for name, failure in self.outcomes_iter(index=index): if failure.check(exception_cls): return True if include_retry and self._failure is not None: @@ -149,12 +150,22 @@ class Retry(atom.Atom, metaclass=abc.ABCMeta): decisions and outcomes that have occurred (if available). """ - def __init__(self, name=None, provides=None, requires=None, - auto_extract=True, rebind=None): - super().__init__(name=name, provides=provides, - requires=requires, rebind=rebind, - auto_extract=auto_extract, - ignore_list=[EXECUTE_REVERT_HISTORY]) + def __init__( + self, + name=None, + provides=None, + requires=None, + auto_extract=True, + rebind=None, + ): + super().__init__( + name=name, + provides=provides, + requires=requires, + rebind=rebind, + auto_extract=auto_extract, + ignore_list=[EXECUTE_REVERT_HISTORY], + ) @property def name(self): @@ -257,8 +268,16 @@ class Times(Retry): :py:class:`~taskflow.atom.Atom` constructor. 
""" - def __init__(self, attempts=1, name=None, provides=None, requires=None, - auto_extract=True, rebind=None, revert_all=False): + def __init__( + self, + attempts=1, + name=None, + provides=None, + requires=None, + auto_extract=True, + rebind=None, + revert_all=False, + ): super().__init__(name, provides, requires, auto_extract, rebind) self._attempts = attempts @@ -279,8 +298,15 @@ class Times(Retry): class ForEachBase(Retry): """Base class for retries that iterate over a given collection.""" - def __init__(self, name=None, provides=None, requires=None, - auto_extract=True, rebind=None, revert_all=False): + def __init__( + self, + name=None, + provides=None, + requires=None, + auto_extract=True, + rebind=None, + revert_all=False, + ): super().__init__(name, provides, requires, auto_extract, rebind) if revert_all: @@ -294,8 +320,10 @@ class ForEachBase(Retry): # resolution strategy remaining. remaining = misc.sequence_minus(values, history.provided_iter()) if not remaining: - raise exc.NotFound("No elements left in collection of iterable " - "retry controller %s" % self.name) + raise exc.NotFound( + "No elements left in collection of iterable " + "retry controller %s" % self.name + ) return remaining[0] def _on_failure(self, values, history): @@ -329,10 +357,19 @@ class ForEach(ForEachBase): :py:class:`~taskflow.atom.Atom` constructor. """ - def __init__(self, values, name=None, provides=None, requires=None, - auto_extract=True, rebind=None, revert_all=False): - super().__init__(name, provides, requires, auto_extract, rebind, - revert_all) + def __init__( + self, + values, + name=None, + provides=None, + requires=None, + auto_extract=True, + rebind=None, + revert_all=False, + ): + super().__init__( + name, provides, requires, auto_extract, rebind, revert_all + ) self._values = values def on_failure(self, history, *args, **kwargs): @@ -361,10 +398,18 @@ class ParameterizedForEach(ForEachBase): :py:class:`~taskflow.atom.Atom` constructor. 
""" - def __init__(self, name=None, provides=None, requires=None, - auto_extract=True, rebind=None, revert_all=False): - super().__init__(name, provides, requires, auto_extract, rebind, - revert_all) + def __init__( + self, + name=None, + provides=None, + requires=None, + auto_extract=True, + rebind=None, + revert_all=False, + ): + super().__init__( + name, provides, requires, auto_extract, rebind, revert_all + ) def on_failure(self, values, history, *args, **kwargs): return self._on_failure(values, history) diff --git a/taskflow/states.py b/taskflow/states.py index 486baa8be..3c442ea83 100644 --- a/taskflow/states.py +++ b/taskflow/states.py @@ -56,16 +56,16 @@ ANALYZING = 'ANALYZING' # Job state transitions # See: https://docs.openstack.org/taskflow/latest/user/states.html -_ALLOWED_JOB_TRANSITIONS = frozenset(( - # Job is being claimed. - (UNCLAIMED, CLAIMED), - - # Job has been lost (or manually unclaimed/abandoned). - (CLAIMED, UNCLAIMED), - - # Job has been finished. - (CLAIMED, COMPLETE), -)) +_ALLOWED_JOB_TRANSITIONS = frozenset( + ( + # Job is being claimed. + (UNCLAIMED, CLAIMED), + # Job has been lost (or manually unclaimed/abandoned). + (CLAIMED, UNCLAIMED), + # Job has been finished. + (CLAIMED, COMPLETE), + ) +) def check_job_transition(old_state, new_state): @@ -80,39 +80,38 @@ def check_job_transition(old_state, new_state): pair = (old_state, new_state) if pair in _ALLOWED_JOB_TRANSITIONS: return True - raise exc.InvalidState("Job transition from '%s' to '%s' is not allowed" - % pair) + raise exc.InvalidState( + "Job transition from '%s' to '%s' is not allowed" % pair + ) # Flow state transitions # See: https://docs.openstack.org/taskflow/latest/user/states.html#flow -_ALLOWED_FLOW_TRANSITIONS = frozenset(( - (PENDING, RUNNING), # run it! 
- - (RUNNING, SUCCESS), # all tasks finished successfully - (RUNNING, FAILURE), # some of task failed - (RUNNING, REVERTED), # some of task failed and flow has been reverted - (RUNNING, SUSPENDING), # engine.suspend was called - (RUNNING, RESUMING), # resuming from a previous running - - (SUCCESS, RUNNING), # see note below - - (FAILURE, RUNNING), # see note below - - (REVERTED, PENDING), # try again - (SUCCESS, PENDING), # run it again - - (SUSPENDING, SUSPENDED), # suspend finished - (SUSPENDING, SUCCESS), # all tasks finished while we were waiting - (SUSPENDING, FAILURE), # some tasks failed while we were waiting - (SUSPENDING, REVERTED), # all tasks were reverted while we were waiting - (SUSPENDING, RESUMING), # resuming from a previous suspending - - (SUSPENDED, RUNNING), # restart from suspended - - (RESUMING, SUSPENDED), # after flow resumed, it is suspended -)) +_ALLOWED_FLOW_TRANSITIONS = frozenset( + ( + (PENDING, RUNNING), # run it! + (RUNNING, SUCCESS), # all tasks finished successfully + (RUNNING, FAILURE), # some of task failed + (RUNNING, REVERTED), # some of task failed and flow has been reverted + (RUNNING, SUSPENDING), # engine.suspend was called + (RUNNING, RESUMING), # resuming from a previous running + (SUCCESS, RUNNING), # see note below + (FAILURE, RUNNING), # see note below + (REVERTED, PENDING), # try again + (SUCCESS, PENDING), # run it again + (SUSPENDING, SUSPENDED), # suspend finished + (SUSPENDING, SUCCESS), # all tasks finished while we were waiting + (SUSPENDING, FAILURE), # some tasks failed while we were waiting + ( + SUSPENDING, + REVERTED, + ), # all tasks were reverted while we were waiting + (SUSPENDING, RESUMING), # resuming from a previous suspending + (SUSPENDED, RUNNING), # restart from suspended + (RESUMING, SUSPENDED), # after flow resumed, it is suspended + ) +) # NOTE(imelnikov) SUCCESS->RUNNING and FAILURE->RUNNING transitions are @@ -152,29 +151,28 @@ def check_flow_transition(old_state, new_state): return True if 
pair in _IGNORED_FLOW_TRANSITIONS: return False - raise exc.InvalidState("Flow transition from '%s' to '%s' is not allowed" - % pair) + raise exc.InvalidState( + "Flow transition from '%s' to '%s' is not allowed" % pair + ) # Task state transitions # See: https://docs.openstack.org/taskflow/latest/user/states.html#task -_ALLOWED_TASK_TRANSITIONS = frozenset(( - (PENDING, RUNNING), # run it! - (PENDING, IGNORE), # skip it! - - (RUNNING, SUCCESS), # the task executed successfully - (RUNNING, FAILURE), # the task execution failed - - (FAILURE, REVERTING), # task execution failed, try reverting... - (SUCCESS, REVERTING), # some other task failed, try reverting... - - (REVERTING, REVERTED), # the task reverted successfully - (REVERTING, REVERT_FAILURE), # the task failed reverting (terminal!) - - (REVERTED, PENDING), # try again - (IGNORE, PENDING), # try again -)) +_ALLOWED_TASK_TRANSITIONS = frozenset( + ( + (PENDING, RUNNING), # run it! + (PENDING, IGNORE), # skip it! + (RUNNING, SUCCESS), # the task executed successfully + (RUNNING, FAILURE), # the task execution failed + (FAILURE, REVERTING), # task execution failed, try reverting... + (SUCCESS, REVERTING), # some other task failed, try reverting... + (REVERTING, REVERTED), # the task reverted successfully + (REVERTING, REVERT_FAILURE), # the task failed reverting (terminal!) 
+ (REVERTED, PENDING), # try again + (IGNORE, PENDING), # try again + ) +) def check_task_transition(old_state, new_state): @@ -192,10 +190,12 @@ def check_task_transition(old_state, new_state): # See: https://docs.openstack.org/taskflow/latest/user/states.html#retry _ALLOWED_RETRY_TRANSITIONS = list(_ALLOWED_TASK_TRANSITIONS) -_ALLOWED_RETRY_TRANSITIONS.extend([ - (SUCCESS, RETRYING), # retrying retry controller - (RETRYING, RUNNING), # run retry controller that has been retrying -]) +_ALLOWED_RETRY_TRANSITIONS.extend( + [ + (SUCCESS, RETRYING), # retrying retry controller + (RETRYING, RUNNING), # run retry controller that has been retrying + ] +) _ALLOWED_RETRY_TRANSITIONS = frozenset(_ALLOWED_RETRY_TRANSITIONS) diff --git a/taskflow/storage.py b/taskflow/storage.py index ddc47f647..7871baecb 100644 --- a/taskflow/storage.py +++ b/taskflow/storage.py @@ -93,14 +93,18 @@ class _ProviderLocator: follow... """ - def __init__(self, transient_results, - providers_fetcher, result_fetcher): + def __init__(self, transient_results, providers_fetcher, result_fetcher): self.result_fetcher = result_fetcher self.providers_fetcher = providers_fetcher self.transient_results = transient_results - def _try_get_results(self, looking_for, provider, - look_into_results=True, find_potentials=False): + def _try_get_results( + self, + looking_for, + provider, + look_into_results=True, + find_potentials=False, + ): if provider.name is _TRANSIENT_PROVIDER: # TODO(harlowja): This 'is' check still sucks, do this # better in the future... 
@@ -120,8 +124,13 @@ class _ProviderLocator: _item_from_single(provider, results, looking_for) return results - def _find(self, looking_for, scope_walker=None, - short_circuit=True, find_potentials=False): + def _find( + self, + looking_for, + scope_walker=None, + short_circuit=True, + find_potentials=False, + ): if scope_walker is None: scope_walker = [] default_providers, atom_providers = self.providers_fetcher(looking_for) @@ -132,12 +141,15 @@ class _ProviderLocator: searched_providers.add(p) try: provider_results = self._try_get_results( - looking_for, p, find_potentials=find_potentials, + looking_for, + p, + find_potentials=find_potentials, # For default providers always look into there # results as default providers are statically setup # and therefore looking into there provided results # should fail early. - look_into_results=True) + look_into_results=True, + ) except exceptions.NotFound: if not find_potentials: raise @@ -153,9 +165,11 @@ class _ProviderLocator: # happen); instead of retaining the possible provider match # order (which isn't that important and may be different from # the scope requested ordering). - maybe_atom_providers = [atom_providers_by_name[atom_name] - for atom_name in accessible_atom_names - if atom_name in atom_providers_by_name] + maybe_atom_providers = [ + atom_providers_by_name[atom_name] + for atom_name in accessible_atom_names + if atom_name in atom_providers_by_name + ] tmp_providers_and_results = [] if find_potentials: for p in maybe_atom_providers: @@ -170,20 +184,28 @@ class _ProviderLocator: # get the result from the *first* provider that # actually provided it (or die). 
provider_results = self._try_get_results( - looking_for, p, find_potentials=find_potentials, - look_into_results=False) + looking_for, + p, + find_potentials=find_potentials, + look_into_results=False, + ) except exceptions.DisallowedAccess as e: if e.state != states.IGNORE: exceptions.raise_with_cause( exceptions.NotFound, "Expected to be able to find output %r" " produced by %s but was unable to get at" - " that providers results" % (looking_for, p)) + " that providers results" % (looking_for, p), + ) else: - LOG.trace("Avoiding using the results of" - " %r (from %s) for name %r because" - " it was ignored", p.name, p, - looking_for) + LOG.trace( + "Avoiding using the results of" + " %r (from %s) for name %r because" + " it was ignored", + p.name, + p, + looking_for, + ) else: tmp_providers_and_results.append((p, provider_results)) if tmp_providers_and_results and short_circuit: @@ -195,15 +217,21 @@ class _ProviderLocator: def find_potentials(self, looking_for, scope_walker=None): """Returns the accessible **potential** providers.""" _searched_providers, providers_and_results = self._find( - looking_for, scope_walker=scope_walker, - short_circuit=False, find_potentials=True) + looking_for, + scope_walker=scope_walker, + short_circuit=False, + find_potentials=True, + ) return {p for (p, _provider_results) in providers_and_results} def find(self, looking_for, scope_walker=None, short_circuit=True): """Returns the accessible providers.""" - return self._find(looking_for, scope_walker=scope_walker, - short_circuit=short_circuit, - find_potentials=False) + return self._find( + looking_for, + scope_walker=scope_walker, + short_circuit=short_circuit, + find_potentials=False, + ) class _Provider: @@ -251,12 +279,13 @@ def _item_from_single(provider, container, looking_for): exceptions.NotFound, "Unable to find result %r, expected to be able to find it" " created by %s but was unable to perform successful" - " extraction" % (looking_for, provider)) + " extraction" % 
(looking_for, provider), + ) def _item_from_first_of(providers, looking_for): """Returns item from the *first* successful container extraction.""" - for (provider, container) in providers: + for provider, container in providers: try: return (provider, _item_from(container, provider.index)) except _EXTRACTION_EXCEPTIONS: @@ -265,7 +294,8 @@ def _item_from_first_of(providers, looking_for): raise exceptions.NotFound( "Unable to find result %r, expected to be able to find it" " created by one of %s but was unable to perform successful" - " extraction" % (looking_for, providers)) + " extraction" % (looking_for, providers) + ) class Storage: @@ -324,17 +354,18 @@ class Storage: fail_cache[states.REVERT] = ad.revert_failure self._failures[ad.name] = fail_cache - self._atom_name_to_uuid = {ad.name: ad.uuid - for ad in self._flowdetail} + self._atom_name_to_uuid = {ad.name: ad.uuid for ad in self._flowdetail} try: source, _clone = self._atomdetail_by_name( - self.injector_name, expected_type=models.TaskDetail) + self.injector_name, expected_type=models.TaskDetail + ) except exceptions.NotFound: pass else: names_iter = source.results.keys() - self._set_result_mapping(source.name, - {name: name for name in names_iter}) + self._set_result_mapping( + source.name, {name: name for name in names_iter} + ) def _with_connection(self, functor, *args, **kwargs): # Run the given functor with a backend connection as its first @@ -344,8 +375,12 @@ class Storage: return functor(conn, *args, **kwargs) @staticmethod - def _create_atom_detail(atom_name, atom_detail_cls, - atom_version=None, atom_state=states.PENDING): + def _create_atom_detail( + atom_name, + atom_detail_cls, + atom_version=None, + atom_state=states.PENDING, + ): ad = atom_detail_cls(atom_name, uuidutils.generate_uuid()) ad.state = atom_state if atom_version is not None: @@ -363,8 +398,10 @@ class Storage: for i, atom in enumerate(atoms): match = misc.match_type(atom, self._ensure_matchers) if not match: - raise 
TypeError("Unknown atom '%s' (%s) requested to ensure" - % (atom, type(atom))) + raise TypeError( + "Unknown atom '%s' (%s) requested to ensure" + % (atom, type(atom)) + ) atom_detail_cls, kind = match atom_name = atom.name if not atom_name: @@ -380,26 +417,29 @@ class Storage: if not isinstance(ad, atom_detail_cls): raise exceptions.Duplicate( "Atom detail '%s' already exists in flow" - " detail '%s'" % (atom_name, self._flowdetail.name)) + " detail '%s'" % (atom_name, self._flowdetail.name) + ) else: atom_ids.append(ad.uuid) self._set_result_mapping(atom_name, atom.save_as) if missing_ads: needs_to_be_created_ads = [] - for (i, atom, atom_detail_cls) in missing_ads: + for i, atom, atom_detail_cls in missing_ads: ad = self._create_atom_detail( - atom.name, atom_detail_cls, - atom_version=misc.get_version_string(atom)) + atom.name, + atom_detail_cls, + atom_version=misc.get_version_string(atom), + ) needs_to_be_created_ads.append((i, atom, ad)) # Add the atom detail(s) to a clone, which upon success will be # updated into the contained flow detail; if it does not get saved # then no update will happen. source, clone = self._fetch_flowdetail(clone=True) - for (_i, _atom, ad) in needs_to_be_created_ads: + for _i, _atom, ad in needs_to_be_created_ads: clone.add(ad) self._with_connection(self._save_flow_detail, source, clone) # Insert the needed data, and get outta here... - for (i, atom, ad) in needs_to_be_created_ads: + for i, atom, ad in needs_to_be_created_ads: atom_name = atom.name atom_ids[i] = ad.uuid self._atom_name_to_uuid[atom_name] = ad.uuid @@ -449,11 +489,14 @@ class Storage: # This never changes (so no read locking needed). 
return self._backend - @tenacity.retry(retry=tenacity.retry_if_exception_type( - exception_types=exceptions.StorageFailure), - stop=tenacity.stop_after_attempt(RETRY_ATTEMPTS), - wait=tenacity.wait_fixed(RETRY_WAIT_TIMEOUT), - reraise=True) + @tenacity.retry( + retry=tenacity.retry_if_exception_type( + exception_types=exceptions.StorageFailure + ), + stop=tenacity.stop_after_attempt(RETRY_ATTEMPTS), + wait=tenacity.wait_fixed(RETRY_WAIT_TIMEOUT), + reraise=True, + ) def _save_flow_detail(self, conn, original_flow_detail, flow_detail): # NOTE(harlowja): we need to update our contained flow detail if # the result of the update actually added more (aka another process @@ -472,26 +515,31 @@ class Storage: try: ad = self._flowdetail.find(self._atom_name_to_uuid[atom_name]) except KeyError: - exceptions.raise_with_cause(exceptions.NotFound, - "Unknown atom name '%s'" % atom_name) + exceptions.raise_with_cause( + exceptions.NotFound, "Unknown atom name '%s'" % atom_name + ) else: # TODO(harlowja): we need to figure out how to get away from doing # these kinds of type checks in general (since they likely mean # we aren't doing something right). 
if expected_type and not isinstance(ad, expected_type): - raise TypeError("Atom '%s' is not of the expected type: %s" - % (atom_name, - reflection.get_class_name(expected_type))) + raise TypeError( + "Atom '%s' is not of the expected type: %s" + % (atom_name, reflection.get_class_name(expected_type)) + ) if clone: return (ad, ad.copy()) else: return (ad, ad) - @tenacity.retry(retry=tenacity.retry_if_exception_type( - exception_types=exceptions.StorageFailure), - stop=tenacity.stop_after_attempt(RETRY_ATTEMPTS), - wait=tenacity.wait_fixed(RETRY_WAIT_TIMEOUT), - reraise=True) + @tenacity.retry( + retry=tenacity.retry_if_exception_type( + exception_types=exceptions.StorageFailure + ), + stop=tenacity.stop_after_attempt(RETRY_ATTEMPTS), + wait=tenacity.wait_fixed(RETRY_WAIT_TIMEOUT), + reraise=True, + ) def _save_atom_detail(self, conn, original_atom_detail, atom_detail): # NOTE(harlowja): we need to update our contained atom detail if # the result of the update actually added more (aka another process @@ -545,11 +593,12 @@ class Storage: return details @fasteners.write_locked - def _update_atom_metadata(self, atom_name, update_with, - expected_type=None): - source, clone = self._atomdetail_by_name(atom_name, - expected_type=expected_type, - clone=True) + def _update_atom_metadata( + self, atom_name, update_with, expected_type=None + ): + source, clone = self._atomdetail_by_name( + atom_name, expected_type=expected_type, clone=True + ) if update_with: clone.meta.update(update_with) self._with_connection(self._save_atom_detail, source, clone) @@ -585,8 +634,9 @@ class Storage: } else: update_with[META_PROGRESS_DETAILS] = None - self._update_atom_metadata(task_name, update_with, - expected_type=models.TaskDetail) + self._update_atom_metadata( + task_name, update_with, expected_type=models.TaskDetail + ) @fasteners.read_locked def get_task_progress(self, task_name): @@ -596,7 +646,8 @@ class Storage: :returns: current task progress value """ source, _clone = 
self._atomdetail_by_name( - task_name, expected_type=models.TaskDetail) + task_name, expected_type=models.TaskDetail + ) try: return source.meta[META_PROGRESS] except KeyError: @@ -611,7 +662,8 @@ class Storage: dict """ source, _clone = self._atomdetail_by_name( - task_name, expected_type=models.TaskDetail) + task_name, expected_type=models.TaskDetail + ) try: return source.meta[META_PROGRESS_DETAILS] except KeyError: @@ -631,9 +683,13 @@ class Storage: try: _item_from(container, index) except _EXTRACTION_EXCEPTIONS: - LOG.warning("Atom '%s' did not supply result " - "with index %r (name '%s')", atom_name, index, - name) + LOG.warning( + "Atom '%s' did not supply result " + "with index %r (name '%s')", + atom_name, + index, + name, + ) @fasteners.write_locked def save(self, atom_name, result, state=states.SUCCESS): @@ -658,14 +714,17 @@ class Storage: def save_retry_failure(self, retry_name, failed_atom_name, failure): """Save subflow failure to retry controller history.""" source, clone = self._atomdetail_by_name( - retry_name, expected_type=models.RetryDetail, clone=True) + retry_name, expected_type=models.RetryDetail, clone=True + ) try: failures = clone.last_failures except exceptions.NotFound: - exceptions.raise_with_cause(exceptions.StorageFailure, - "Unable to fetch most recent retry" - " failures so new retry failure can" - " be inserted") + exceptions.raise_with_cause( + exceptions.StorageFailure, + "Unable to fetch most recent retry" + " failures so new retry failure can" + " be inserted", + ) else: if failed_atom_name not in failures: failures[failed_atom_name] = failure @@ -675,15 +734,21 @@ class Storage: def cleanup_retry_history(self, retry_name, state): """Cleanup history of retry atom with given name.""" source, clone = self._atomdetail_by_name( - retry_name, expected_type=models.RetryDetail, clone=True) + retry_name, expected_type=models.RetryDetail, clone=True + ) clone.state = state clone.results = [] 
self._with_connection(self._save_atom_detail, source, clone) @fasteners.read_locked - def _get(self, atom_name, - results_attr_name, fail_attr_name, - allowed_states, fail_cache_key): + def _get( + self, + atom_name, + results_attr_name, + fail_attr_name, + allowed_states, + fail_cache_key, + ): source, _clone = self._atomdetail_by_name(atom_name) failure = getattr(source, fail_attr_name) if failure is not None: @@ -703,27 +768,35 @@ class Storage: raise exceptions.DisallowedAccess( "Result for atom '%s' is not known/accessible" " due to it being in %s state when result access" - " is restricted to %s states" % (atom_name, - source.state, - allowed_states), - state=source.state) + " is restricted to %s states" + % (atom_name, source.state, allowed_states), + state=source.state, + ) return getattr(source, results_attr_name) def get_execute_result(self, atom_name): """Gets the ``execute`` results for an atom from storage.""" try: - results = self._get(atom_name, 'results', 'failure', - _EXECUTE_STATES_WITH_RESULTS, states.EXECUTE) + results = self._get( + atom_name, + 'results', + 'failure', + _EXECUTE_STATES_WITH_RESULTS, + states.EXECUTE, + ) except exceptions.DisallowedAccess as e: if e.state == states.IGNORE: - exceptions.raise_with_cause(exceptions.NotFound, - "Result for atom '%s' execution" - " is not known (as it was" - " ignored)" % atom_name) + exceptions.raise_with_cause( + exceptions.NotFound, + "Result for atom '%s' execution" + " is not known (as it was" + " ignored)" % atom_name, + ) else: - exceptions.raise_with_cause(exceptions.NotFound, - "Result for atom '%s' execution" - " is not known" % atom_name) + exceptions.raise_with_cause( + exceptions.NotFound, + "Result for atom '%s' execution is not known" % atom_name, + ) else: return results @@ -748,18 +821,26 @@ class Storage: def get_revert_result(self, atom_name): """Gets the ``revert`` results for an atom from storage.""" try: - results = self._get(atom_name, 'revert_results', 'revert_failure', - 
_REVERT_STATES_WITH_RESULTS, states.REVERT) + results = self._get( + atom_name, + 'revert_results', + 'revert_failure', + _REVERT_STATES_WITH_RESULTS, + states.REVERT, + ) except exceptions.DisallowedAccess as e: if e.state == states.IGNORE: - exceptions.raise_with_cause(exceptions.NotFound, - "Result for atom '%s' revert is" - " not known (as it was" - " ignored)" % atom_name) + exceptions.raise_with_cause( + exceptions.NotFound, + "Result for atom '%s' revert is" + " not known (as it was" + " ignored)" % atom_name, + ) else: - exceptions.raise_with_cause(exceptions.NotFound, - "Result for atom '%s' revert is" - " not known" % atom_name) + exceptions.raise_with_cause( + exceptions.NotFound, + "Result for atom '%s' revert is not known" % atom_name, + ) else: return results @@ -876,27 +957,30 @@ class Storage: source, clone = self._atomdetail_by_name( self.injector_name, expected_type=models.TaskDetail, - clone=True) + clone=True, + ) except exceptions.NotFound: # Ensure we have our special task detail... # # TODO(harlowja): get this removed when # https://review.openstack.org/#/c/165645/ merges. 
- source = self._create_atom_detail(self.injector_name, - models.TaskDetail, - atom_state=None) + source = self._create_atom_detail( + self.injector_name, models.TaskDetail, atom_state=None + ) fd_source, fd_clone = self._fetch_flowdetail(clone=True) fd_clone.add(source) - self._with_connection(self._save_flow_detail, fd_source, - fd_clone) + self._with_connection( + self._save_flow_detail, fd_source, fd_clone + ) self._atom_name_to_uuid[source.name] = source.uuid clone = source clone.results = dict(pairs) clone.state = states.SUCCESS else: clone.results.update(pairs) - result = self._with_connection(self._save_atom_detail, - source, clone) + result = self._with_connection( + self._save_atom_detail, source, clone + ) return (self.injector_name, result.results.keys()) def save_transient(): @@ -908,8 +992,7 @@ class Storage: else: provider_name, names = save_persistent() - self._set_result_mapping(provider_name, - {name: name for name in names}) + self._set_result_mapping(provider_name, {name: name for name in names}) def _fetch_providers(self, looking_for, providers=None): """Return pair of (default providers, atom providers).""" @@ -945,44 +1028,58 @@ class Storage: @fasteners.read_locked def fetch(self, name, many_handler=None): """Fetch a named ``execute`` result.""" + def _many_handler(values): # By default we just return the first of many (unless provided # a different callback that can translate many results into # something more meaningful). 
return values[0] + if many_handler is None: many_handler = _many_handler try: maybe_providers = self._reverse_mapping[name] except KeyError: - raise exceptions.NotFound("Name %r is not mapped as a produced" - " output by any providers" % name) + raise exceptions.NotFound( + "Name %r is not mapped as a produced" + " output by any providers" % name + ) locator = _ProviderLocator( self._transients, - functools.partial(self._fetch_providers, - providers=maybe_providers), - lambda atom_name: - self._get(atom_name, 'last_results', 'failure', - _EXECUTE_STATES_WITH_RESULTS, states.EXECUTE)) + functools.partial( + self._fetch_providers, providers=maybe_providers + ), + lambda atom_name: self._get( + atom_name, + 'last_results', + 'failure', + _EXECUTE_STATES_WITH_RESULTS, + states.EXECUTE, + ), + ) values = [] searched_providers, providers = locator.find( - name, short_circuit=False, + name, + short_circuit=False, # NOTE(harlowja): There are no scopes used here (as of now), so # we just return all known providers as if it was one large # scope. - scope_walker=[[p.name for p in maybe_providers]]) + scope_walker=[[p.name for p in maybe_providers]], + ) for provider, results in providers: values.append(_item_from_single(provider, results, name)) if not values: raise exceptions.NotFound( "Unable to find result %r, searched %s providers" - % (name, len(searched_providers))) + % (name, len(searched_providers)) + ) else: return many_handler(values) @fasteners.read_locked - def fetch_unsatisfied_args(self, atom_name, args_mapping, - scope_walker=None, optional_args=None): + def fetch_unsatisfied_args( + self, atom_name, args_mapping, scope_walker=None, optional_args=None + ): """Fetch unsatisfied ``execute`` arguments using an atoms args mapping. 
NOTE(harlowja): this takes into account the provided scope walker @@ -1003,14 +1100,24 @@ class Storage: ] missing = set(args_mapping.keys()) locator = _ProviderLocator( - self._transients, self._fetch_providers, - lambda atom_name: - self._get(atom_name, 'last_results', 'failure', - _EXECUTE_STATES_WITH_RESULTS, states.EXECUTE)) - for (bound_name, name) in args_mapping.items(): + self._transients, + self._fetch_providers, + lambda atom_name: self._get( + atom_name, + 'last_results', + 'failure', + _EXECUTE_STATES_WITH_RESULTS, + states.EXECUTE, + ), + ) + for bound_name, name in args_mapping.items(): if LOG.isEnabledFor(logging.TRACE): - LOG.trace("Looking for %r <= %r for atom '%s'", - bound_name, name, atom_name) + LOG.trace( + "Looking for %r <= %r for atom '%s'", + bound_name, + name, + atom_name, + ) if bound_name in optional_args: LOG.trace("Argument %r is optional, skipping", bound_name) missing.discard(bound_name) @@ -1022,21 +1129,28 @@ class Storage: if name in source: maybe_providers += 1 maybe_providers += len( - locator.find_potentials(name, scope_walker=scope_walker)) + locator.find_potentials(name, scope_walker=scope_walker) + ) if maybe_providers: - LOG.trace("Atom '%s' will have %s potential providers" - " of %r <= %r", atom_name, maybe_providers, - bound_name, name) + LOG.trace( + "Atom '%s' will have %s potential providers of %r <= %r", + atom_name, + maybe_providers, + bound_name, + name, + ) missing.discard(bound_name) return missing @fasteners.read_locked def fetch_all(self, many_handler=None): """Fetch all named ``execute`` results known so far.""" + def _many_handler(values): if len(values) > 1: return values return values[0] + if many_handler is None: many_handler = _many_handler results = {} @@ -1048,10 +1162,15 @@ class Storage: return results @fasteners.read_locked - def fetch_mapped_args(self, args_mapping, - atom_name=None, scope_walker=None, - optional_args=None): + def fetch_mapped_args( + self, + args_mapping, + atom_name=None, + 
scope_walker=None, + optional_args=None, + ): """Fetch ``execute`` arguments for an atom using its args mapping.""" + def _extract_first_from(name, sources): """Extracts/returns first occurrence of key in list of dicts.""" for i, source in enumerate(sources): @@ -1060,6 +1179,7 @@ class Storage: if name in source: return (i, source[name]) raise KeyError(name) + if optional_args is None: optional_args = [] if atom_name: @@ -1074,57 +1194,89 @@ class Storage: injected_sources = [] if not args_mapping: return {} - get_results = lambda atom_name: \ - self._get(atom_name, 'last_results', 'failure', - _EXECUTE_STATES_WITH_RESULTS, states.EXECUTE) + get_results = lambda atom_name: self._get( + atom_name, + 'last_results', + 'failure', + _EXECUTE_STATES_WITH_RESULTS, + states.EXECUTE, + ) mapped_args = {} - for (bound_name, name) in args_mapping.items(): + for bound_name, name in args_mapping.items(): if LOG.isEnabledFor(logging.TRACE): if atom_name: - LOG.trace("Looking for %r <= %r for atom '%s'", - bound_name, name, atom_name) + LOG.trace( + "Looking for %r <= %r for atom '%s'", + bound_name, + name, + atom_name, + ) else: LOG.trace("Looking for %r <= %r", bound_name, name) try: source_index, value = _extract_first_from( - name, injected_sources) + name, injected_sources + ) mapped_args[bound_name] = value if LOG.isEnabledFor(logging.TRACE): if source_index == 0: - LOG.trace("Matched %r <= %r to %r (from injected" - " atom-specific transient" - " values)", bound_name, name, value) + LOG.trace( + "Matched %r <= %r to %r (from injected" + " atom-specific transient" + " values)", + bound_name, + name, + value, + ) else: - LOG.trace("Matched %r <= %r to %r (from injected" - " atom-specific persistent" - " values)", bound_name, name, value) + LOG.trace( + "Matched %r <= %r to %r (from injected" + " atom-specific persistent" + " values)", + bound_name, + name, + value, + ) except KeyError: try: maybe_providers = self._reverse_mapping[name] except KeyError: if bound_name in 
optional_args: - LOG.trace("Argument %r is optional, skipping", - bound_name) + LOG.trace( + "Argument %r is optional, skipping", bound_name + ) continue - raise exceptions.NotFound("Name %r is not mapped as a" - " produced output by any" - " providers" % name) + raise exceptions.NotFound( + "Name %r is not mapped as a" + " produced output by any" + " providers" % name + ) locator = _ProviderLocator( self._transients, - functools.partial(self._fetch_providers, - providers=maybe_providers), get_results) + functools.partial( + self._fetch_providers, providers=maybe_providers + ), + get_results, + ) searched_providers, providers = locator.find( - name, scope_walker=scope_walker) + name, scope_walker=scope_walker + ) if not providers: raise exceptions.NotFound( "Mapped argument %r <= %r was not produced" " by any accessible provider (%s possible" " providers were scanned)" - % (bound_name, name, len(searched_providers))) + % (bound_name, name, len(searched_providers)) + ) provider, value = _item_from_first_of(providers, name) mapped_args[bound_name] = value - LOG.trace("Matched %r <= %r to %r (from %s)", - bound_name, name, value, provider) + LOG.trace( + "Matched %r <= %r to %r (from %s)", + bound_name, + name, + value, + provider, + ) return mapped_args @fasteners.write_locked @@ -1186,7 +1338,8 @@ class Storage: def get_retry_history(self, retry_name): """Fetch a single retrys history.""" source, _clone = self._atomdetail_by_name( - retry_name, expected_type=models.RetryDetail) + retry_name, expected_type=models.RetryDetail + ) return self._translate_into_history(source) @fasteners.read_locked @@ -1195,6 +1348,5 @@ class Storage: histories = [] for ad in self._flowdetail: if isinstance(ad, models.RetryDetail): - histories.append((ad.name, - self._translate_into_history(ad))) + histories.append((ad.name, self._translate_into_history(ad))) return histories diff --git a/taskflow/task.py b/taskflow/task.py index af27903c1..3c8d0c80e 100644 --- a/taskflow/task.py +++ 
b/taskflow/task.py @@ -54,15 +54,30 @@ class Task(atom.Atom, metaclass=abc.ABCMeta): # or existing internal events... TASK_EVENTS = (EVENT_UPDATE_PROGRESS,) - def __init__(self, name=None, provides=None, requires=None, - auto_extract=True, rebind=None, inject=None, - ignore_list=None, revert_rebind=None, revert_requires=None): + def __init__( + self, + name=None, + provides=None, + requires=None, + auto_extract=True, + rebind=None, + inject=None, + ignore_list=None, + revert_rebind=None, + revert_requires=None, + ): if name is None: name = reflection.get_class_name(self) - super().__init__(name, provides=provides, requires=requires, - auto_extract=auto_extract, rebind=rebind, - inject=inject, revert_rebind=revert_rebind, - revert_requires=revert_requires) + super().__init__( + name, + provides=provides, + requires=requires, + auto_extract=auto_extract, + rebind=rebind, + inject=inject, + revert_rebind=revert_rebind, + revert_requires=revert_requires, + ) self._notifier = notifier.RestrictedNotifier(self.TASK_EVENTS) @property @@ -97,14 +112,20 @@ class Task(atom.Atom, metaclass=abc.ABCMeta): :param progress: task progress float value between 0.0 and 1.0 """ + def on_clamped(): - LOG.warning("Progress value must be greater or equal to 0.0 or" - " less than or equal to 1.0 instead of being '%s'", - progress) - cleaned_progress = misc.clamp(progress, 0.0, 1.0, - on_clamped=on_clamped) - self._notifier.notify(EVENT_UPDATE_PROGRESS, - {'progress': cleaned_progress}) + LOG.warning( + "Progress value must be greater or equal to 0.0 or" + " less than or equal to 1.0 instead of being '%s'", + progress, + ) + + cleaned_progress = misc.clamp( + progress, 0.0, 1.0, on_clamped=on_clamped + ) + self._notifier.notify( + EVENT_UPDATE_PROGRESS, {'progress': cleaned_progress} + ) class FunctorTask(Task): @@ -117,16 +138,25 @@ class FunctorTask(Task): the ``revert`` callable is not used). 
""" - def __init__(self, execute, name=None, provides=None, - requires=None, auto_extract=True, rebind=None, revert=None, - version=None, inject=None): + def __init__( + self, + execute, + name=None, + provides=None, + requires=None, + auto_extract=True, + rebind=None, + revert=None, + version=None, + inject=None, + ): if not callable(execute): - raise ValueError("Function to use for executing must be" - " callable") + raise ValueError("Function to use for executing must be callable") if revert is not None: if not callable(revert): - raise ValueError("Function to use for reverting must" - " be callable") + raise ValueError( + "Function to use for reverting must be callable" + ) if name is None: name = reflection.get_callable_name(execute) super().__init__(name, provides=provides, inject=inject) @@ -134,17 +164,20 @@ class FunctorTask(Task): self._revert = revert if version is not None: self.version = version - mapping = self._build_arg_mapping(execute, requires, rebind, - auto_extract) + mapping = self._build_arg_mapping( + execute, requires, rebind, auto_extract + ) self.rebind, exec_requires, self.optional = mapping if revert: - revert_mapping = self._build_arg_mapping(revert, requires, rebind, - auto_extract) + revert_mapping = self._build_arg_mapping( + revert, requires, rebind, auto_extract + ) else: revert_mapping = (self.rebind, exec_requires, self.optional) - (self.revert_rebind, revert_requires, - self.revert_optional) = revert_mapping + (self.revert_rebind, revert_requires, self.revert_optional) = ( + revert_mapping + ) self.requires = exec_requires.union(revert_requires) def execute(self, *args, **kwargs): @@ -166,33 +199,50 @@ class ReduceFunctorTask(Task): task calls ``reduce`` with the functor and list as arguments. The resulting value from the call to ``reduce`` is then returned after execution. 
""" - def __init__(self, functor, requires, name=None, provides=None, - auto_extract=True, rebind=None, inject=None): + + def __init__( + self, + functor, + requires, + name=None, + provides=None, + auto_extract=True, + rebind=None, + inject=None, + ): if not callable(functor): raise ValueError("Function to use for reduce must be callable") f_args = reflection.get_callable_args(functor) if len(f_args) != 2: - raise ValueError("%s arguments were provided. Reduce functor " - "must take exactly 2 arguments." % len(f_args)) + raise ValueError( + "%s arguments were provided. Reduce functor " + "must take exactly 2 arguments." % len(f_args) + ) if not misc.is_iterable(requires): - raise TypeError("%s type was provided for requires. Requires " - "must be an iterable." % type(requires)) + raise TypeError( + "%s type was provided for requires. Requires " + "must be an iterable." % type(requires) + ) if len(requires) < 2: - raise ValueError("%s elements were provided. Requires must have " - "at least 2 elements." % len(requires)) + raise ValueError( + "%s elements were provided. Requires must have " + "at least 2 elements." % len(requires) + ) if name is None: name = reflection.get_callable_name(functor) - super().__init__(name=name, - provides=provides, - inject=inject, - requires=requires, - rebind=rebind, - auto_extract=auto_extract) + super().__init__( + name=name, + provides=provides, + inject=inject, + requires=requires, + rebind=rebind, + auto_extract=auto_extract, + ) self._functor = functor @@ -215,27 +265,43 @@ class MapFunctorTask(Task): preserved in the returned list. 
""" - def __init__(self, functor, requires, name=None, provides=None, - auto_extract=True, rebind=None, inject=None): + def __init__( + self, + functor, + requires, + name=None, + provides=None, + auto_extract=True, + rebind=None, + inject=None, + ): if not callable(functor): raise ValueError("Function to use for map must be callable") f_args = reflection.get_callable_args(functor) if len(f_args) != 1: - raise ValueError("%s arguments were provided. Map functor must " - "take exactly 1 argument." % len(f_args)) + raise ValueError( + "%s arguments were provided. Map functor must " + "take exactly 1 argument." % len(f_args) + ) if not misc.is_iterable(requires): - raise TypeError("%s type was provided for requires. Requires " - "must be an iterable." % type(requires)) + raise TypeError( + "%s type was provided for requires. Requires " + "must be an iterable." % type(requires) + ) if name is None: name = reflection.get_callable_name(functor) - super().__init__(name=name, provides=provides, - inject=inject, requires=requires, - rebind=rebind, - auto_extract=auto_extract) + super().__init__( + name=name, + provides=provides, + inject=inject, + requires=requires, + rebind=rebind, + auto_extract=auto_extract, + ) self._functor = functor diff --git a/taskflow/test.py b/taskflow/test.py index e620aa396..662847887 100644 --- a/taskflow/test.py +++ b/taskflow/test.py @@ -41,10 +41,12 @@ class FailureRegexpMatcher: def match(self, failure): for cause in failure: if cause.check(self.exc_class) is not None: - return matchers.MatchesRegex( - self.pattern).match(cause.exception_str) - return matchers.Mismatch("The `%s` wasn't caused by the `%s`" % - (failure, self.exc_class)) + return matchers.MatchesRegex(self.pattern).match( + cause.exception_str + ) + return matchers.Mismatch( + "The `%s` wasn't caused by the `%s`" % (failure, self.exc_class) + ) class TestCase(base.BaseTestCase): @@ -75,15 +77,18 @@ class TestCase(base.BaseTestCase): except ValueError: # element not found if 
msg is None: - msg = ("%r is not subsequence of %r: " - "element %r not found in tail %r" - % (sub_seq, super_seq, sub_elem, current_tail)) + msg = ( + "%r is not subsequence of %r: " + "element %r not found in tail %r" + % (sub_seq, super_seq, sub_elem, current_tail) + ) self.fail(msg) else: - current_tail = current_tail[super_index + 1:] + current_tail = current_tail[super_index + 1 :] - def assertFailuresRegexp(self, exc_class, pattern, callable_obj, *args, - **kwargs): + def assertFailuresRegexp( + self, exc_class, pattern, callable_obj, *args, **kwargs + ): """Asserts the callable failed with the given exception and message.""" try: with utils.wrap_all_failures(): @@ -93,15 +98,15 @@ class TestCase(base.BaseTestCase): class MockTestCase(TestCase): - def setUp(self): super().setUp() self.master_mock = mock.Mock(name='master_mock') def patch(self, target, autospec=True, **kwargs): """Patch target and attach it to the master mock.""" - f = self.useFixture(fixtures.MockPatch(target, - autospec=autospec, **kwargs)) + f = self.useFixture( + fixtures.MockPatch(target, autospec=autospec, **kwargs) + ) mocked = f.mock attach_as = kwargs.pop('attach_as', None) if attach_as is not None: @@ -120,8 +125,9 @@ class MockTestCase(TestCase): else: instance_mock = mock.Mock() - f = self.useFixture(fixtures.MockPatchObject(module, name, - autospec=autospec)) + f = self.useFixture( + fixtures.MockPatchObject(module, name, autospec=autospec) + ) class_mock = f.mock class_mock.return_value = instance_mock diff --git a/taskflow/tests/fixtures.py b/taskflow/tests/fixtures.py index f57d7c008..686772ca3 100644 --- a/taskflow/tests/fixtures.py +++ b/taskflow/tests/fixtures.py @@ -18,6 +18,7 @@ from sqlalchemy import exc as sqla_exc class WarningsFixture(fixtures.Fixture): """Filters out warnings during test runs.""" + def setUp(self): super().setUp() diff --git a/taskflow/tests/test_examples.py b/taskflow/tests/test_examples.py index 75f07816d..d6c20eb79 100644 --- 
a/taskflow/tests/test_examples.py +++ b/taskflow/tests/test_examples.py @@ -25,7 +25,6 @@ generated. Please note that this will break tests as output for most examples is indeterministic (due to hash randomization for example). """ - import keyword import os import re @@ -39,18 +38,19 @@ from taskflow import test from taskflow.tests import utils as test_utils ROOT_DIR = os.path.abspath( - os.path.dirname( - os.path.dirname( - os.path.dirname(__file__)))) + os.path.dirname(os.path.dirname(os.path.dirname(__file__))) +) # This is used so that any uuid like data being output is removed (since it # will change per test run and will invalidate the deterministic output that # we expect to be able to check). -UUID_RE = re.compile('XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX' - .replace('X', '[0-9a-f]')) +UUID_RE = re.compile( + 'XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX'.replace('X', '[0-9a-f]') +) ZOOKEEPER_AVAILABLE = test_utils.zookeeper_available( - impl_zookeeper.ZookeeperJobBoard.MIN_ZK_VERSION) + impl_zookeeper.ZookeeperJobBoard.MIN_ZK_VERSION +) def safe_filename(filename): @@ -68,20 +68,22 @@ def root_path(*args): def run_example(name): path = root_path('taskflow', 'examples', '%s.py' % name) - obj = subprocess.Popen([sys.executable, path], - stdout=subprocess.PIPE, stderr=subprocess.PIPE) + obj = subprocess.Popen( + [sys.executable, path], stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) output = obj.communicate() stdout = output[0].decode() stderr = output[1].decode() rc = obj.wait() if rc != 0: - raise RuntimeError('Example %s failed, return code=%s\n' - '<<>>\n%s' - '<<>>\n' - '<<>>\n%s' - '<<>>' - % (name, rc, stdout, stderr)) + raise RuntimeError( + 'Example %s failed, return code=%s\n' + '<<>>\n%s' + '<<>>\n' + '<<>>\n%s' + '<<>>' % (name, rc, stdout, stderr) + ) return stdout @@ -112,6 +114,7 @@ class ExampleAdderMeta(type): def generate_test(example_name): def test_example(self): self._check_example(example_name) + return test_example for example_name, safe_name 
in iter_examples(): diff --git a/taskflow/tests/unit/action_engine/test_builder.py b/taskflow/tests/unit/action_engine/test_builder.py index 2ea5f1e11..c8a480c6a 100644 --- a/taskflow/tests/unit/action_engine/test_builder.py +++ b/taskflow/tests/unit/action_engine/test_builder.py @@ -29,7 +29,6 @@ from taskflow.utils import persistence_utils as pu class BuildersTest(test.TestCase): - def _make_runtime(self, flow, initial_state=None): compilation = compiler.PatternCompiler(flow).compile() flow_detail = pu.create_flow_detail(flow) @@ -45,9 +44,9 @@ class BuildersTest(test.TestCase): retry_executor = executor.SerialRetryExecutor() task_executor.start() self.addCleanup(task_executor.stop) - r = runtime.Runtime(compilation, store, - atom_notifier, task_executor, - retry_executor) + r = runtime.Runtime( + compilation, store, atom_notifier, task_executor, retry_executor + ) r.compile() return r @@ -60,11 +59,13 @@ class BuildersTest(test.TestCase): def test_run_iterations(self): flow = lf.Flow("root") tasks = test_utils.make_many( - 1, task_cls=test_utils.TaskNoRequiresNoReturns) + 1, task_cls=test_utils.TaskNoRequiresNoReturns + ) flow.add(*tasks) runtime, machine, memory, machine_runner = self._make_machine( - flow, initial_state=st.RUNNING) + flow, initial_state=st.RUNNING + ) it = machine_runner.run_iter(builder.START) prior_state, new_state = next(it) @@ -94,28 +95,29 @@ class BuildersTest(test.TestCase): def test_run_iterations_reverted(self): flow = lf.Flow("root") - tasks = test_utils.make_many( - 1, task_cls=test_utils.TaskWithFailure) + tasks = test_utils.make_many(1, task_cls=test_utils.TaskWithFailure) flow.add(*tasks) runtime, machine, memory, machine_runner = self._make_machine( - flow, initial_state=st.RUNNING) + flow, initial_state=st.RUNNING + ) transitions = list(machine_runner.run_iter(builder.START)) prior_state, new_state = transitions[-1] self.assertEqual(st.REVERTED, new_state) self.assertEqual([], memory.failures) - self.assertEqual(st.REVERTED, - 
runtime.storage.get_atom_state(tasks[0].name)) + self.assertEqual( + st.REVERTED, runtime.storage.get_atom_state(tasks[0].name) + ) def test_run_iterations_failure(self): flow = lf.Flow("root") - tasks = test_utils.make_many( - 1, task_cls=test_utils.NastyFailingTask) + tasks = test_utils.make_many(1, task_cls=test_utils.NastyFailingTask) flow.add(*tasks) runtime, machine, memory, machine_runner = self._make_machine( - flow, initial_state=st.RUNNING) + flow, initial_state=st.RUNNING + ) transitions = list(machine_runner.run_iter(builder.START)) prior_state, new_state = transitions[-1] @@ -123,17 +125,20 @@ class BuildersTest(test.TestCase): self.assertEqual(1, len(memory.failures)) failure = memory.failures[0] self.assertTrue(failure.check(RuntimeError)) - self.assertEqual(st.REVERT_FAILURE, - runtime.storage.get_atom_state(tasks[0].name)) + self.assertEqual( + st.REVERT_FAILURE, runtime.storage.get_atom_state(tasks[0].name) + ) def test_run_iterations_suspended(self): flow = lf.Flow("root") tasks = test_utils.make_many( - 2, task_cls=test_utils.TaskNoRequiresNoReturns) + 2, task_cls=test_utils.TaskNoRequiresNoReturns + ) flow.add(*tasks) runtime, machine, memory, machine_runner = self._make_machine( - flow, initial_state=st.RUNNING) + flow, initial_state=st.RUNNING + ) transitions = [] for prior_state, new_state in machine_runner.run_iter(builder.START): @@ -144,22 +149,27 @@ class BuildersTest(test.TestCase): self.assertEqual(st.SUSPENDED, state) self.assertEqual([], failures) - self.assertEqual(st.SUCCESS, - runtime.storage.get_atom_state(tasks[0].name)) - self.assertEqual(st.PENDING, - runtime.storage.get_atom_state(tasks[1].name)) + self.assertEqual( + st.SUCCESS, runtime.storage.get_atom_state(tasks[0].name) + ) + self.assertEqual( + st.PENDING, runtime.storage.get_atom_state(tasks[1].name) + ) def test_run_iterations_suspended_failure(self): flow = lf.Flow("root") sad_tasks = test_utils.make_many( - 1, task_cls=test_utils.NastyFailingTask) + 1, 
task_cls=test_utils.NastyFailingTask + ) flow.add(*sad_tasks) happy_tasks = test_utils.make_many( - 1, task_cls=test_utils.TaskNoRequiresNoReturns, offset=1) + 1, task_cls=test_utils.TaskNoRequiresNoReturns, offset=1 + ) flow.add(*happy_tasks) runtime, machine, memory, machine_runner = self._make_machine( - flow, initial_state=st.RUNNING) + flow, initial_state=st.RUNNING + ) transitions = [] for prior_state, new_state in machine_runner.run_iter(builder.START): @@ -170,24 +180,29 @@ class BuildersTest(test.TestCase): self.assertEqual(st.SUSPENDED, state) self.assertEqual([], failures) - self.assertEqual(st.PENDING, - runtime.storage.get_atom_state(happy_tasks[0].name)) - self.assertEqual(st.FAILURE, - runtime.storage.get_atom_state(sad_tasks[0].name)) + self.assertEqual( + st.PENDING, runtime.storage.get_atom_state(happy_tasks[0].name) + ) + self.assertEqual( + st.FAILURE, runtime.storage.get_atom_state(sad_tasks[0].name) + ) def test_builder_manual_process(self): flow = lf.Flow("root") tasks = test_utils.make_many( - 1, task_cls=test_utils.TaskNoRequiresNoReturns) + 1, task_cls=test_utils.TaskNoRequiresNoReturns + ) flow.add(*tasks) runtime, machine, memory, machine_runner = self._make_machine( - flow, initial_state=st.RUNNING) + flow, initial_state=st.RUNNING + ) self.assertRaises(excp.NotInitialized, machine.process_event, 'poke') # Should now be pending... 
- self.assertEqual(st.PENDING, - runtime.storage.get_atom_state(tasks[0].name)) + self.assertEqual( + st.PENDING, runtime.storage.get_atom_state(tasks[0].name) + ) machine.initialize() self.assertEqual(builder.UNDEFINED, machine.current_state) @@ -203,8 +218,9 @@ class BuildersTest(test.TestCase): last_state = machine.current_state cb, args, kwargs = reaction - next_event = cb(last_state, machine.current_state, - builder.START, *args, **kwargs) + next_event = cb( + last_state, machine.current_state, builder.START, *args, **kwargs + ) reaction, terminal = machine.process_event(next_event) self.assertFalse(terminal) self.assertIsNotNone(reaction) @@ -213,21 +229,24 @@ class BuildersTest(test.TestCase): last_state = machine.current_state cb, args, kwargs = reaction - next_event = cb(last_state, machine.current_state, - next_event, *args, **kwargs) + next_event = cb( + last_state, machine.current_state, next_event, *args, **kwargs + ) reaction, terminal = machine.process_event(next_event) self.assertFalse(terminal) self.assertEqual(st.WAITING, machine.current_state) self.assertRaises(excp.NotFound, machine.process_event, 'poke') # Should now be running... 
- self.assertEqual(st.RUNNING, - runtime.storage.get_atom_state(tasks[0].name)) + self.assertEqual( + st.RUNNING, runtime.storage.get_atom_state(tasks[0].name) + ) last_state = machine.current_state cb, args, kwargs = reaction - next_event = cb(last_state, machine.current_state, - next_event, *args, **kwargs) + next_event = cb( + last_state, machine.current_state, next_event, *args, **kwargs + ) reaction, terminal = machine.process_event(next_event) self.assertFalse(terminal) self.assertIsNotNone(reaction) @@ -236,30 +255,35 @@ class BuildersTest(test.TestCase): last_state = machine.current_state cb, args, kwargs = reaction - next_event = cb(last_state, machine.current_state, - next_event, *args, **kwargs) + next_event = cb( + last_state, machine.current_state, next_event, *args, **kwargs + ) reaction, terminal = machine.process_event(next_event) self.assertFalse(terminal) self.assertEqual(builder.GAME_OVER, machine.current_state) # Should now be done... - self.assertEqual(st.SUCCESS, - runtime.storage.get_atom_state(tasks[0].name)) + self.assertEqual( + st.SUCCESS, runtime.storage.get_atom_state(tasks[0].name) + ) def test_builder_automatic_process(self): flow = lf.Flow("root") tasks = test_utils.make_many( - 1, task_cls=test_utils.TaskNoRequiresNoReturns) + 1, task_cls=test_utils.TaskNoRequiresNoReturns + ) flow.add(*tasks) runtime, machine, memory, machine_runner = self._make_machine( - flow, initial_state=st.RUNNING) + flow, initial_state=st.RUNNING + ) transitions = list(machine_runner.run_iter(builder.START)) self.assertEqual((builder.UNDEFINED, st.RESUMING), transitions[0]) self.assertEqual((builder.GAME_OVER, st.SUCCESS), transitions[-1]) - self.assertEqual(st.SUCCESS, - runtime.storage.get_atom_state(tasks[0].name)) + self.assertEqual( + st.SUCCESS, runtime.storage.get_atom_state(tasks[0].name) + ) def test_builder_automatic_process_failure(self): flow = lf.Flow("root") @@ -267,7 +291,8 @@ class BuildersTest(test.TestCase): flow.add(*tasks) runtime, 
machine, memory, machine_runner = self._make_machine( - flow, initial_state=st.RUNNING) + flow, initial_state=st.RUNNING + ) transitions = list(machine_runner.run_iter(builder.START)) self.assertEqual((builder.GAME_OVER, st.FAILURE), transitions[-1]) @@ -279,21 +304,25 @@ class BuildersTest(test.TestCase): flow.add(*tasks) runtime, machine, memory, machine_runner = self._make_machine( - flow, initial_state=st.RUNNING) + flow, initial_state=st.RUNNING + ) transitions = list(machine_runner.run_iter(builder.START)) self.assertEqual((builder.GAME_OVER, st.REVERTED), transitions[-1]) - self.assertEqual(st.REVERTED, - runtime.storage.get_atom_state(tasks[0].name)) + self.assertEqual( + st.REVERTED, runtime.storage.get_atom_state(tasks[0].name) + ) def test_builder_expected_transition_occurrences(self): flow = lf.Flow("root") tasks = test_utils.make_many( - 10, task_cls=test_utils.TaskNoRequiresNoReturns) + 10, task_cls=test_utils.TaskNoRequiresNoReturns + ) flow.add(*tasks) runtime, machine, memory, machine_runner = self._make_machine( - flow, initial_state=st.RUNNING) + flow, initial_state=st.RUNNING + ) transitions = list(machine_runner.run_iter(builder.START)) occurrences = {t: transitions.count(t) for t in transitions} diff --git a/taskflow/tests/unit/action_engine/test_compile.py b/taskflow/tests/unit/action_engine/test_compile.py index de37b1934..81bd6e285 100644 --- a/taskflow/tests/unit/action_engine/test_compile.py +++ b/taskflow/tests/unit/action_engine/test_compile.py @@ -42,7 +42,8 @@ class PatternCompileTest(test.TestCase): def test_task(self): task = test_utils.DummyTask(name='a') g = _replicate_graph_with_names( - compiler.PatternCompiler(task).compile()) + compiler.PatternCompiler(task).compile() + ) self.assertEqual(['a'], list(g.nodes())) self.assertEqual([], list(g.edges())) @@ -52,8 +53,9 @@ class PatternCompileTest(test.TestCase): def test_wrong_object(self): msg_regex = '^Unknown object .* requested to compile' - self.assertRaisesRegex(TypeError, 
msg_regex, - compiler.PatternCompiler(42).compile) + self.assertRaisesRegex( + TypeError, msg_regex, compiler.PatternCompiler(42).compile + ) def test_empty(self): flo = lf.Flow("test") @@ -68,17 +70,18 @@ class PatternCompileTest(test.TestCase): flo.add(inner_flo) g = _replicate_graph_with_names( - compiler.PatternCompiler(flo).compile()) + compiler.PatternCompiler(flo).compile() + ) self.assertEqual(8, len(g)) order = list(g.topological_sort()) - self.assertEqual(['test', 'a', 'b', 'c', - "sub-test", 'd', "sub-test[$]", - 'test[$]'], order) + self.assertEqual( + ['test', 'a', 'b', 'c', "sub-test", 'd', "sub-test[$]", 'test[$]'], + order, + ) self.assertTrue(g.has_edge('c', "sub-test")) self.assertTrue(g.has_edge("sub-test", 'd')) - self.assertEqual({'invariant': True}, - g.get_edge_data("sub-test", 'd')) + self.assertEqual({'invariant': True}, g.get_edge_data("sub-test", 'd')) self.assertEqual(['test[$]'], list(g.no_successors_iter())) self.assertEqual(['test'], list(g.no_predecessors_iter())) @@ -87,8 +90,7 @@ class PatternCompileTest(test.TestCase): flo = lf.Flow("test") flo.add(a, b, c) flo.add(flo) - self.assertRaises(ValueError, - compiler.PatternCompiler(flo).compile) + self.assertRaises(ValueError, compiler.PatternCompiler(flo).compile) def test_unordered(self): a, b, c, d = test_utils.make_many(4) @@ -96,18 +98,22 @@ class PatternCompileTest(test.TestCase): flo.add(a, b, c, d) g = _replicate_graph_with_names( - compiler.PatternCompiler(flo).compile()) + compiler.PatternCompiler(flo).compile() + ) self.assertEqual(6, len(g)) - self.assertCountEqual(g.edges(), [ - ('test', 'a'), - ('test', 'b'), - ('test', 'c'), - ('test', 'd'), - ('a', 'test[$]'), - ('b', 'test[$]'), - ('c', 'test[$]'), - ('d', 'test[$]'), - ]) + self.assertCountEqual( + g.edges(), + [ + ('test', 'a'), + ('test', 'b'), + ('test', 'c'), + ('test', 'd'), + ('a', 'test[$]'), + ('b', 'test[$]'), + ('c', 'test[$]'), + ('d', 'test[$]'), + ], + ) self.assertEqual({'test'}, 
set(g.no_predecessors_iter())) def test_linear_nested(self): @@ -119,7 +125,8 @@ class PatternCompileTest(test.TestCase): flo.add(inner_flo) g = _replicate_graph_with_names( - compiler.PatternCompiler(flo).compile()) + compiler.PatternCompiler(flo).compile() + ) self.assertEqual(8, len(g)) sub_g = g.subgraph(['a', 'b']) @@ -144,19 +151,23 @@ class PatternCompileTest(test.TestCase): flo.add(flo2) g = _replicate_graph_with_names( - compiler.PatternCompiler(flo).compile()) + compiler.PatternCompiler(flo).compile() + ) self.assertEqual(8, len(g)) - self.assertCountEqual(g.edges(), [ - ('test', 'a'), - ('test', 'b'), - ('test', 'test2'), - ('test2', 'c'), - ('c', 'd'), - ('d', 'test2[$]'), - ('test2[$]', 'test[$]'), - ('a', 'test[$]'), - ('b', 'test[$]'), - ]) + self.assertCountEqual( + g.edges(), + [ + ('test', 'a'), + ('test', 'b'), + ('test', 'test2'), + ('test2', 'c'), + ('c', 'd'), + ('d', 'test2[$]'), + ('test2[$]', 'test[$]'), + ('a', 'test[$]'), + ('b', 'test[$]'), + ], + ) def test_unordered_nested_in_linear(self): a, b, c, d = test_utils.make_many(4) @@ -164,18 +175,22 @@ class PatternCompileTest(test.TestCase): flo = lf.Flow('lt').add(a, inner_flo, d) g = _replicate_graph_with_names( - compiler.PatternCompiler(flo).compile()) + compiler.PatternCompiler(flo).compile() + ) self.assertEqual(8, len(g)) - self.assertCountEqual(g.edges(), [ - ('lt', 'a'), - ('a', 'ut'), - ('ut', 'b'), - ('ut', 'c'), - ('b', 'ut[$]'), - ('c', 'ut[$]'), - ('ut[$]', 'd'), - ('d', 'lt[$]'), - ]) + self.assertCountEqual( + g.edges(), + [ + ('lt', 'a'), + ('a', 'ut'), + ('ut', 'b'), + ('ut', 'c'), + ('b', 'ut[$]'), + ('c', 'ut[$]'), + ('ut[$]', 'd'), + ('d', 'lt[$]'), + ], + ) def test_graph(self): a, b, c, d = test_utils.make_many(4) @@ -195,26 +210,28 @@ class PatternCompileTest(test.TestCase): flo.add(flo2) g = _replicate_graph_with_names( - compiler.PatternCompiler(flo).compile()) + compiler.PatternCompiler(flo).compile() + ) self.assertEqual(11, len(g)) - 
self.assertCountEqual(g.edges(), [ - ('test', 'a'), - ('test', 'b'), - ('test', 'c'), - ('test', 'd'), - ('a', 'test[$]'), - ('b', 'test[$]'), - ('c', 'test[$]'), - ('d', 'test[$]'), - - ('test', 'test2'), - ('test2', 'e'), - ('e', 'f'), - ('f', 'g'), - - ('g', 'test2[$]'), - ('test2[$]', 'test[$]'), - ]) + self.assertCountEqual( + g.edges(), + [ + ('test', 'a'), + ('test', 'b'), + ('test', 'c'), + ('test', 'd'), + ('a', 'test[$]'), + ('b', 'test[$]'), + ('c', 'test[$]'), + ('d', 'test[$]'), + ('test', 'test2'), + ('test2', 'e'), + ('e', 'f'), + ('f', 'g'), + ('g', 'test2[$]'), + ('test2[$]', 'test[$]'), + ], + ) def test_graph_nested_graph(self): a, b, c, d, e, f, g = test_utils.make_many(7) @@ -226,29 +243,30 @@ class PatternCompileTest(test.TestCase): flo.add(flo2) g = _replicate_graph_with_names( - compiler.PatternCompiler(flo).compile()) + compiler.PatternCompiler(flo).compile() + ) self.assertEqual(11, len(g)) - self.assertCountEqual(g.edges(), [ - ('test', 'a'), - ('test', 'b'), - ('test', 'c'), - ('test', 'd'), - ('test', 'test2'), - - ('test2', 'e'), - ('test2', 'f'), - ('test2', 'g'), - - ('e', 'test2[$]'), - ('f', 'test2[$]'), - ('g', 'test2[$]'), - - ('test2[$]', 'test[$]'), - ('a', 'test[$]'), - ('b', 'test[$]'), - ('c', 'test[$]'), - ('d', 'test[$]'), - ]) + self.assertCountEqual( + g.edges(), + [ + ('test', 'a'), + ('test', 'b'), + ('test', 'c'), + ('test', 'd'), + ('test', 'test2'), + ('test2', 'e'), + ('test2', 'f'), + ('test2', 'g'), + ('e', 'test2[$]'), + ('f', 'test2[$]'), + ('g', 'test2[$]'), + ('test2[$]', 'test[$]'), + ('a', 'test[$]'), + ('b', 'test[$]'), + ('c', 'test[$]'), + ('d', 'test[$]'), + ], + ) def test_graph_links(self): a, b, c, d = test_utils.make_many(4) @@ -259,15 +277,19 @@ class PatternCompileTest(test.TestCase): flo.link(c, d) g = _replicate_graph_with_names( - compiler.PatternCompiler(flo).compile()) + compiler.PatternCompiler(flo).compile() + ) self.assertEqual(6, len(g)) - self.assertCountEqual(g.edges(data=True), [ - 
('test', 'a', {'invariant': True}), - ('a', 'b', {'manual': True}), - ('b', 'c', {'manual': True}), - ('c', 'd', {'manual': True}), - ('d', 'test[$]', {'invariant': True}), - ]) + self.assertCountEqual( + g.edges(data=True), + [ + ('test', 'a', {'invariant': True}), + ('a', 'b', {'manual': True}), + ('b', 'c', {'manual': True}), + ('c', 'd', {'manual': True}), + ('d', 'test[$]', {'invariant': True}), + ], + ) self.assertCountEqual(['test'], g.no_predecessors_iter()) self.assertCountEqual(['test[$]'], g.no_successors_iter()) @@ -277,13 +299,17 @@ class PatternCompileTest(test.TestCase): flo = gf.Flow("test").add(a, b) g = _replicate_graph_with_names( - compiler.PatternCompiler(flo).compile()) + compiler.PatternCompiler(flo).compile() + ) self.assertEqual(4, len(g)) - self.assertCountEqual(g.edges(data=True), [ - ('test', 'a', {'invariant': True}), - ('a', 'b', {'reasons': {'x'}}), - ('b', 'test[$]', {'invariant': True}), - ]) + self.assertCountEqual( + g.edges(data=True), + [ + ('test', 'a', {'invariant': True}), + ('a', 'b', {'reasons': {'x'}}), + ('b', 'test[$]', {'invariant': True}), + ], + ) self.assertCountEqual(['test'], g.no_predecessors_iter()) self.assertCountEqual(['test[$]'], g.no_successors_iter()) @@ -295,16 +321,20 @@ class PatternCompileTest(test.TestCase): flo = gf.Flow("test").add(a, inner_flo) g = _replicate_graph_with_names( - compiler.PatternCompiler(flo).compile()) + compiler.PatternCompiler(flo).compile() + ) self.assertEqual(7, len(g)) - self.assertCountEqual(g.edges(data=True), [ - ('test', 'a', {'invariant': True}), - ('test2', 'b', {'invariant': True}), - ('a', 'test2', {'reasons': {'x'}}), - ('b', 'c', {'invariant': True}), - ('c', 'test2[$]', {'invariant': True}), - ('test2[$]', 'test[$]', {'invariant': True}), - ]) + self.assertCountEqual( + g.edges(data=True), + [ + ('test', 'a', {'invariant': True}), + ('test2', 'b', {'invariant': True}), + ('a', 'test2', {'reasons': {'x'}}), + ('b', 'c', {'invariant': True}), + ('c', 'test2[$]', 
{'invariant': True}), + ('test2[$]', 'test[$]', {'invariant': True}), + ], + ) self.assertCountEqual(['test'], list(g.no_predecessors_iter())) self.assertCountEqual(['test[$]'], list(g.no_successors_iter())) @@ -316,19 +346,21 @@ class PatternCompileTest(test.TestCase): flo = gf.Flow("test").add(a, inner_flo) g = _replicate_graph_with_names( - compiler.PatternCompiler(flo).compile()) + compiler.PatternCompiler(flo).compile() + ) self.assertEqual(7, len(g)) - self.assertCountEqual(g.edges(data=True), [ - ('test', 'test2', {'invariant': True}), - ('a', 'test[$]', {'invariant': True}), - - # The 'x' requirement is produced out of test2... - ('test2[$]', 'a', {'reasons': {'x'}}), - - ('test2', 'b', {'invariant': True}), - ('b', 'c', {'invariant': True}), - ('c', 'test2[$]', {'invariant': True}), - ]) + self.assertCountEqual( + g.edges(data=True), + [ + ('test', 'test2', {'invariant': True}), + ('a', 'test[$]', {'invariant': True}), + # The 'x' requirement is produced out of test2... + ('test2[$]', 'a', {'reasons': {'x'}}), + ('test2', 'b', {'invariant': True}), + ('b', 'c', {'invariant': True}), + ('c', 'test2[$]', {'invariant': True}), + ], + ) self.assertCountEqual(['test'], g.no_predecessors_iter()) self.assertCountEqual(['test[$]'], g.no_successors_iter()) @@ -340,14 +372,18 @@ class PatternCompileTest(test.TestCase): flo.add(a, empty_flo, b) g = _replicate_graph_with_names( - compiler.PatternCompiler(flo).compile()) - self.assertCountEqual(g.edges(), [ - ("lf", "a"), - ("a", "empty"), - ("empty", "empty[$]"), - ("empty[$]", "b"), - ("b", "lf[$]"), - ]) + compiler.PatternCompiler(flo).compile() + ) + self.assertCountEqual( + g.edges(), + [ + ("lf", "a"), + ("a", "empty"), + ("empty", "empty[$]"), + ("empty[$]", "b"), + ("b", "lf[$]"), + ], + ) def test_many_empty_in_graph_flow(self): flo = gf.Flow('root') @@ -378,7 +414,8 @@ class PatternCompileTest(test.TestCase): flo.link(c, d) g = _replicate_graph_with_names( - compiler.PatternCompiler(flo).compile()) + 
compiler.PatternCompiler(flo).compile() + ) self.assertTrue(g.has_edge('root', 'a')) self.assertTrue(g.has_edge('root', 'b')) @@ -409,11 +446,18 @@ class PatternCompileTest(test.TestCase): flow.add(a, flow2, b) g = _replicate_graph_with_names( - compiler.PatternCompiler(flow).compile()) - for u, v in [('lf', 'a'), ('a', 'lf-2'), - ('lf-2', 'c'), ('c', 'empty'), - ('empty[$]', 'd'), ('d', 'lf-2[$]'), - ('lf-2[$]', 'b'), ('b', 'lf[$]')]: + compiler.PatternCompiler(flow).compile() + ) + for u, v in [ + ('lf', 'a'), + ('a', 'lf-2'), + ('lf-2', 'c'), + ('c', 'empty'), + ('empty[$]', 'd'), + ('d', 'lf-2[$]'), + ('lf-2[$]', 'b'), + ('b', 'lf[$]'), + ]: self.assertTrue(g.has_edge(u, v)) def test_empty_flow_in_graph_flow(self): @@ -432,8 +476,9 @@ class PatternCompileTest(test.TestCase): self.assertEqual(1, len(empty_flow_successors)) empty_flow_terminal = empty_flow_successors[0] self.assertIs(empty_flow, empty_flow_terminal.flow) - self.assertEqual(compiler.FLOW_END, - g.nodes[empty_flow_terminal]['kind']) + self.assertEqual( + compiler.FLOW_END, g.nodes[empty_flow_terminal]['kind'] + ) self.assertTrue(g.has_edge(empty_flow_terminal, b)) def test_empty_flow_in_graph_flow_linkage(self): @@ -452,22 +497,22 @@ class PatternCompileTest(test.TestCase): def test_checks_for_dups(self): flo = gf.Flow("test").add( - test_utils.DummyTask(name="a"), - test_utils.DummyTask(name="a") + test_utils.DummyTask(name="a"), test_utils.DummyTask(name="a") ) e = engines.load(flo) - self.assertRaisesRegex(exc.Duplicate, - '^Atoms with duplicate names', - e.compile) + self.assertRaisesRegex( + exc.Duplicate, '^Atoms with duplicate names', e.compile + ) def test_checks_for_dups_globally(self): flo = gf.Flow("test").add( gf.Flow("int1").add(test_utils.DummyTask(name="a")), - gf.Flow("int2").add(test_utils.DummyTask(name="a"))) + gf.Flow("int2").add(test_utils.DummyTask(name="a")), + ) e = engines.load(flo) - self.assertRaisesRegex(exc.Duplicate, - '^Atoms with duplicate names', - e.compile) + 
self.assertRaisesRegex( + exc.Duplicate, '^Atoms with duplicate names', e.compile + ) def test_retry_in_linear_flow(self): flo = lf.Flow("test", retry.AlwaysRevert("c")) @@ -495,15 +540,19 @@ class PatternCompileTest(test.TestCase): flo = lf.Flow("test", c1).add(inner_flo) g = _replicate_graph_with_names( - compiler.PatternCompiler(flo).compile()) + compiler.PatternCompiler(flo).compile() + ) self.assertEqual(6, len(g)) - self.assertCountEqual(g.edges(data=True), [ - ('test', 'c1', {'invariant': True}), - ('c1', 'test2', {'invariant': True, 'retry': True}), - ('test2', 'c2', {'invariant': True}), - ('c2', 'test2[$]', {'invariant': True}), - ('test2[$]', 'test[$]', {'invariant': True}), - ]) + self.assertCountEqual( + g.edges(data=True), + [ + ('test', 'c1', {'invariant': True}), + ('c1', 'test2', {'invariant': True, 'retry': True}), + ('test2', 'c2', {'invariant': True}), + ('c2', 'test2[$]', {'invariant': True}), + ('test2[$]', 'test[$]', {'invariant': True}), + ], + ) self.assertIs(c1, g.nodes['c2']['retry']) self.assertCountEqual(['test'], list(g.no_predecessors_iter())) self.assertCountEqual(['test[$]'], list(g.no_successors_iter())) @@ -514,14 +563,18 @@ class PatternCompileTest(test.TestCase): flo = lf.Flow("test", c).add(a, b) g = _replicate_graph_with_names( - compiler.PatternCompiler(flo).compile()) + compiler.PatternCompiler(flo).compile() + ) self.assertEqual(5, len(g)) - self.assertCountEqual(g.edges(data=True), [ - ('test', 'c', {'invariant': True}), - ('a', 'b', {'invariant': True}), - ('c', 'a', {'invariant': True, 'retry': True}), - ('b', 'test[$]', {'invariant': True}), - ]) + self.assertCountEqual( + g.edges(data=True), + [ + ('test', 'c', {'invariant': True}), + ('a', 'b', {'invariant': True}), + ('c', 'a', {'invariant': True, 'retry': True}), + ('b', 'test[$]', {'invariant': True}), + ], + ) self.assertCountEqual(['test'], g.no_predecessors_iter()) self.assertCountEqual(['test[$]'], g.no_successors_iter()) @@ -534,15 +587,19 @@ class 
PatternCompileTest(test.TestCase): flo = uf.Flow("test", c).add(a, b) g = _replicate_graph_with_names( - compiler.PatternCompiler(flo).compile()) + compiler.PatternCompiler(flo).compile() + ) self.assertEqual(5, len(g)) - self.assertCountEqual(g.edges(data=True), [ - ('test', 'c', {'invariant': True}), - ('c', 'a', {'invariant': True, 'retry': True}), - ('c', 'b', {'invariant': True, 'retry': True}), - ('b', 'test[$]', {'invariant': True}), - ('a', 'test[$]', {'invariant': True}), - ]) + self.assertCountEqual( + g.edges(data=True), + [ + ('test', 'c', {'invariant': True}), + ('c', 'a', {'invariant': True, 'retry': True}), + ('c', 'b', {'invariant': True, 'retry': True}), + ('b', 'test[$]', {'invariant': True}), + ('a', 'test[$]', {'invariant': True}), + ], + ) self.assertCountEqual(['test'], list(g.no_predecessors_iter())) self.assertCountEqual(['test[$]'], list(g.no_successors_iter())) @@ -555,15 +612,19 @@ class PatternCompileTest(test.TestCase): flo = gf.Flow("test", r).add(a, b, c).link(b, c) g = _replicate_graph_with_names( - compiler.PatternCompiler(flo).compile()) - self.assertCountEqual(g.edges(data=True), [ - ('test', 'r', {'invariant': True}), - ('r', 'a', {'invariant': True, 'retry': True}), - ('r', 'b', {'invariant': True, 'retry': True}), - ('b', 'c', {'manual': True}), - ('a', 'test[$]', {'invariant': True}), - ('c', 'test[$]', {'invariant': True}), - ]) + compiler.PatternCompiler(flo).compile() + ) + self.assertCountEqual( + g.edges(data=True), + [ + ('test', 'r', {'invariant': True}), + ('r', 'a', {'invariant': True, 'retry': True}), + ('r', 'b', {'invariant': True, 'retry': True}), + ('b', 'c', {'manual': True}), + ('a', 'test[$]', {'invariant': True}), + ('c', 'test[$]', {'invariant': True}), + ], + ) self.assertCountEqual(['test'], g.no_predecessors_iter()) self.assertCountEqual(['test[$]'], g.no_successors_iter()) @@ -579,19 +640,23 @@ class PatternCompileTest(test.TestCase): flo = lf.Flow("test", c1).add(a, inner_flo, d) g = 
_replicate_graph_with_names( - compiler.PatternCompiler(flo).compile()) + compiler.PatternCompiler(flo).compile() + ) self.assertEqual(10, len(g)) - self.assertCountEqual(g.edges(data=True), [ - ('test', 'c1', {'invariant': True}), - ('c1', 'a', {'invariant': True, 'retry': True}), - ('a', 'test2', {'invariant': True}), - ('test2', 'c2', {'invariant': True}), - ('c2', 'b', {'invariant': True, 'retry': True}), - ('b', 'c', {'invariant': True}), - ('c', 'test2[$]', {'invariant': True}), - ('test2[$]', 'd', {'invariant': True}), - ('d', 'test[$]', {'invariant': True}), - ]) + self.assertCountEqual( + g.edges(data=True), + [ + ('test', 'c1', {'invariant': True}), + ('c1', 'a', {'invariant': True, 'retry': True}), + ('a', 'test2', {'invariant': True}), + ('test2', 'c2', {'invariant': True}), + ('c2', 'b', {'invariant': True, 'retry': True}), + ('b', 'c', {'invariant': True}), + ('c', 'test2[$]', {'invariant': True}), + ('test2[$]', 'd', {'invariant': True}), + ('d', 'test[$]', {'invariant': True}), + ], + ) self.assertIs(c1, g.nodes['a']['retry']) self.assertIs(c1, g.nodes['d']['retry']) self.assertIs(c2, g.nodes['b']['retry']) @@ -606,18 +671,22 @@ class PatternCompileTest(test.TestCase): flo = lf.Flow("test", c1).add(a, inner_flo, d) g = _replicate_graph_with_names( - compiler.PatternCompiler(flo).compile()) + compiler.PatternCompiler(flo).compile() + ) self.assertEqual(9, len(g)) - self.assertCountEqual(g.edges(data=True), [ - ('test', 'c1', {'invariant': True}), - ('c1', 'a', {'invariant': True, 'retry': True}), - ('a', 'test2', {'invariant': True}), - ('test2', 'b', {'invariant': True}), - ('b', 'c', {'invariant': True}), - ('c', 'test2[$]', {'invariant': True}), - ('test2[$]', 'd', {'invariant': True}), - ('d', 'test[$]', {'invariant': True}), - ]) + self.assertCountEqual( + g.edges(data=True), + [ + ('test', 'c1', {'invariant': True}), + ('c1', 'a', {'invariant': True, 'retry': True}), + ('a', 'test2', {'invariant': True}), + ('test2', 'b', {'invariant': True}), 
+ ('b', 'c', {'invariant': True}), + ('c', 'test2[$]', {'invariant': True}), + ('test2[$]', 'd', {'invariant': True}), + ('d', 'test[$]', {'invariant': True}), + ], + ) self.assertIs(c1, g.nodes['a']['retry']) self.assertIs(c1, g.nodes['d']['retry']) self.assertIs(c1, g.nodes['b']['retry']) diff --git a/taskflow/tests/unit/action_engine/test_creation.py b/taskflow/tests/unit/action_engine/test_creation.py index b9aba0ab8..d91b00881 100644 --- a/taskflow/tests/unit/action_engine/test_creation.py +++ b/taskflow/tests/unit/action_engine/test_creation.py @@ -32,33 +32,36 @@ class ParallelCreationTest(test.TestCase): backend = backends.fetch({'connection': 'memory'}) flow_detail = pu.create_flow_detail(flow, backend=backend) options = kwargs.copy() - return engine.ParallelActionEngine(flow, flow_detail, - backend, options) + return engine.ParallelActionEngine(flow, flow_detail, backend, options) def test_thread_string_creation(self): for s in ['threads', 'threaded', 'thread']: eng = self._create_engine(executor=s) - self.assertIsInstance(eng._task_executor, - executor.ParallelThreadTaskExecutor) + self.assertIsInstance( + eng._task_executor, executor.ParallelThreadTaskExecutor + ) def test_thread_executor_creation(self): with futurist.ThreadPoolExecutor(1) as e: eng = self._create_engine(executor=e) - self.assertIsInstance(eng._task_executor, - executor.ParallelThreadTaskExecutor) + self.assertIsInstance( + eng._task_executor, executor.ParallelThreadTaskExecutor + ) @testtools.skipIf(not eu.EVENTLET_AVAILABLE, 'eventlet is not available') def test_green_executor_creation(self): with futurist.GreenThreadPoolExecutor(1) as e: eng = self._create_engine(executor=e) - self.assertIsInstance(eng._task_executor, - executor.ParallelThreadTaskExecutor) + self.assertIsInstance( + eng._task_executor, executor.ParallelThreadTaskExecutor + ) def test_sync_executor_creation(self): with futurist.SynchronousExecutor() as e: eng = self._create_engine(executor=e) - 
self.assertIsInstance(eng._task_executor, - executor.ParallelThreadTaskExecutor) + self.assertIsInstance( + eng._task_executor, executor.ParallelThreadTaskExecutor + ) def test_invalid_creation(self): self.assertRaises(ValueError, self._create_engine, executor='crap') diff --git a/taskflow/tests/unit/action_engine/test_scoping.py b/taskflow/tests/unit/action_engine/test_scoping.py index 76a8b1748..5030f90a2 100644 --- a/taskflow/tests/unit/action_engine/test_scoping.py +++ b/taskflow/tests/unit/action_engine/test_scoping.py @@ -65,8 +65,10 @@ class LinearScopingTest(test.TestCase): def test_nested_prior_linear(self): r = lf.Flow("root") - r.add(test_utils.TaskOneReturn("root.1"), - test_utils.TaskOneReturn("root.2")) + r.add( + test_utils.TaskOneReturn("root.1"), + test_utils.TaskOneReturn("root.2"), + ) sub_r = lf.Flow("subroot") sub_r_1 = test_utils.TaskOneReturn("subroot.1") sub_r.add(sub_r_1) @@ -82,8 +84,10 @@ class LinearScopingTest(test.TestCase): middle_r = test_utils.TaskOneReturn("root.3") r.add(middle_r) sub_r = lf.Flow("subroot") - sub_r.add(test_utils.TaskOneReturn("subroot.1"), - test_utils.TaskOneReturn("subroot.2")) + sub_r.add( + test_utils.TaskOneReturn("subroot.1"), + test_utils.TaskOneReturn("subroot.2"), + ) r.add(sub_r) end_r = test_utils.TaskOneReturn("root.4") r.add(end_r) @@ -92,28 +96,31 @@ class LinearScopingTest(test.TestCase): self.assertEqual([], _get_scopes(c, begin_r)) self.assertEqual([['root.2', 'root.1']], _get_scopes(c, middle_r)) - self.assertEqual([['subroot.2', 'subroot.1', 'root.3', 'root.2', - 'root.1']], _get_scopes(c, end_r)) + self.assertEqual( + [['subroot.2', 'subroot.1', 'root.3', 'root.2', 'root.1']], + _get_scopes(c, end_r), + ) class GraphScopingTest(test.TestCase): def test_dependent(self): r = gf.Flow("root") - customer = test_utils.ProvidesRequiresTask("customer", - provides=['dog'], - requires=[]) - washer = test_utils.ProvidesRequiresTask("washer", - requires=['dog'], - provides=['wash']) - dryer = 
test_utils.ProvidesRequiresTask("dryer", - requires=['dog', 'wash'], - provides=['dry_dog']) - shaved = test_utils.ProvidesRequiresTask("shaver", - requires=['dry_dog'], - provides=['shaved_dog']) + customer = test_utils.ProvidesRequiresTask( + "customer", provides=['dog'], requires=[] + ) + washer = test_utils.ProvidesRequiresTask( + "washer", requires=['dog'], provides=['wash'] + ) + dryer = test_utils.ProvidesRequiresTask( + "dryer", requires=['dog', 'wash'], provides=['dry_dog'] + ) + shaved = test_utils.ProvidesRequiresTask( + "shaver", requires=['dry_dog'], provides=['shaved_dog'] + ) happy_customer = test_utils.ProvidesRequiresTask( - "happy_customer", requires=['shaved_dog'], provides=['happiness']) + "happy_customer", requires=['shaved_dog'], provides=['happiness'] + ) r.add(customer, washer, dryer, shaved, happy_customer) @@ -121,8 +128,10 @@ class GraphScopingTest(test.TestCase): self.assertEqual([], _get_scopes(c, customer)) self.assertEqual([['washer', 'customer']], _get_scopes(c, dryer)) - self.assertEqual([['shaver', 'dryer', 'washer', 'customer']], - _get_scopes(c, happy_customer)) + self.assertEqual( + [['shaver', 'dryer', 'washer', 'customer']], + _get_scopes(c, happy_customer), + ) def test_no_visible(self): r = gf.Flow("root") @@ -202,10 +211,10 @@ class MixedPatternScopingTest(test.TestCase): self.assertEqual([['root.1']], _get_scopes(c, r_2)) self.assertEqual([], _get_scopes(c, s_1)) self.assertEqual([['subroot.1']], _get_scopes(c, s_2)) - self.assertEqual([[], ['subroot.2', 'subroot.1']], - _get_scopes(c, t_1)) - self.assertEqual([["subroot2.1"], ['subroot.2', 'subroot.1']], - _get_scopes(c, t_2)) + self.assertEqual([[], ['subroot.2', 'subroot.1']], _get_scopes(c, t_1)) + self.assertEqual( + [["subroot2.1"], ['subroot.2', 'subroot.1']], _get_scopes(c, t_2) + ) def test_linear_unordered_scope(self): r = lf.Flow("root") @@ -247,15 +256,15 @@ class MixedPatternScopingTest(test.TestCase): def test_shadow_graph(self): r = gf.Flow("root") - 
customer = test_utils.ProvidesRequiresTask("customer", - provides=['dog'], - requires=[]) - customer2 = test_utils.ProvidesRequiresTask("customer2", - provides=['dog'], - requires=[]) - washer = test_utils.ProvidesRequiresTask("washer", - requires=['dog'], - provides=['wash']) + customer = test_utils.ProvidesRequiresTask( + "customer", provides=['dog'], requires=[] + ) + customer2 = test_utils.ProvidesRequiresTask( + "customer2", provides=['dog'], requires=[] + ) + washer = test_utils.ProvidesRequiresTask( + "washer", requires=['dog'], provides=['wash'] + ) r.add(customer, washer) r.add(customer2, resolve_requires=False) r.link(customer2, washer) @@ -270,23 +279,24 @@ class MixedPatternScopingTest(test.TestCase): # This may be different after/if the following is resolved: # # https://github.com/networkx/networkx/issues/1181 (and a few others) - self.assertEqual({'customer', 'customer2'}, - set(_get_scopes(c, washer)[0])) + self.assertEqual( + {'customer', 'customer2'}, set(_get_scopes(c, washer)[0]) + ) self.assertEqual([], _get_scopes(c, customer2)) self.assertEqual([], _get_scopes(c, customer)) def test_shadow_linear(self): r = lf.Flow("root") - customer = test_utils.ProvidesRequiresTask("customer", - provides=['dog'], - requires=[]) - customer2 = test_utils.ProvidesRequiresTask("customer2", - provides=['dog'], - requires=[]) - washer = test_utils.ProvidesRequiresTask("washer", - requires=['dog'], - provides=['wash']) + customer = test_utils.ProvidesRequiresTask( + "customer", provides=['dog'], requires=[] + ) + customer2 = test_utils.ProvidesRequiresTask( + "customer2", provides=['dog'], requires=[] + ) + washer = test_utils.ProvidesRequiresTask( + "washer", requires=['dog'], provides=['wash'] + ) r.add(customer, customer2, washer) c = compiler.PatternCompiler(r).compile() diff --git a/taskflow/tests/unit/jobs/base.py b/taskflow/tests/unit/jobs/base.py index 93115d54f..2bf947ce6 100644 --- a/taskflow/tests/unit/jobs/base.py +++ b/taskflow/tests/unit/jobs/base.py 
@@ -36,7 +36,6 @@ def connect_close(*args): class BoardTestMixin: - @contextlib.contextmanager def flush(self, client): yield @@ -71,8 +70,10 @@ class BoardTestMixin: def poster(wait_post=0.2): if not ev.wait(test_utils.WAIT_TIMEOUT): - raise RuntimeError("Waiter did not appear ready" - " in %s seconds" % test_utils.WAIT_TIMEOUT) + raise RuntimeError( + "Waiter did not appear ready" + " in %s seconds" % test_utils.WAIT_TIMEOUT + ) time.sleep(wait_post) self.board.post('test', p_utils.temporary_log_book()) @@ -133,8 +134,9 @@ class BoardTestMixin: self.board.consume(j, self.board.name) self.assertEqual(0, len(list(self.board.iterjobs()))) - self.assertRaises(excp.NotFound, - self.board.consume, j, self.board.name) + self.assertRaises( + excp.NotFound, self.board.consume, j, self.board.name + ) def test_posting_claim_abandon(self): @@ -169,8 +171,12 @@ class BoardTestMixin: possible_jobs = list(self.board.iterjobs()) self.assertEqual(1, len(possible_jobs)) - self.assertRaises(excp.UnclaimableJob, self.board.claim, - possible_jobs[0], self.board.name + "-1") + self.assertRaises( + excp.UnclaimableJob, + self.board.claim, + possible_jobs[0], + self.board.name + "-1", + ) possible_jobs = list(self.board.iterjobs(only_unclaimed=True)) self.assertEqual(0, len(possible_jobs)) @@ -188,9 +194,11 @@ class BoardTestMixin: self.assertFalse(jb.wait(0.1)) def test_posting_with_book(self): - backend = impl_dir.DirBackend(conf={ - 'path': self.makeTmpDir(), - }) + backend = impl_dir.DirBackend( + conf={ + 'path': self.makeTmpDir(), + } + ) backend.get_connection().upgrade() book, flow_detail = p_utils.temporary_flow_detail(backend) self.assertEqual(1, len(book)) @@ -223,5 +231,4 @@ class BoardTestMixin: possible_jobs = list(self.board.iterjobs(only_unclaimed=True)) self.assertEqual(1, len(possible_jobs)) j = possible_jobs[0] - self.assertRaises(excp.NotFound, self.board.abandon, - j, j.name) + self.assertRaises(excp.NotFound, self.board.abandon, j, j.name) diff --git 
a/taskflow/tests/unit/jobs/test_etcd_job.py b/taskflow/tests/unit/jobs/test_etcd_job.py index f43647dc5..6f4619ea7 100644 --- a/taskflow/tests/unit/jobs/test_etcd_job.py +++ b/taskflow/tests/unit/jobs/test_etcd_job.py @@ -38,13 +38,13 @@ class EtcdJobBoardMixin: } if conf: board_conf.update(conf) - board = impl_etcd.EtcdJobBoard("etcd", board_conf, - persistence=persistence) + board = impl_etcd.EtcdJobBoard( + "etcd", board_conf, persistence=persistence + ) return board._client, board class MockedEtcdJobBoard(test.TestCase, EtcdJobBoardMixin): - def test_create_board(self): _, jobboard = self.create_board() self.assertEqual(f"/taskflow/jobs/{self.path}", jobboard._root_path) @@ -56,11 +56,13 @@ class MockedEtcdJobBoard(test.TestCase, EtcdJobBoardMixin): @mock.patch("threading.Condition") @mock.patch("oslo_utils.uuidutils.generate_uuid") @mock.patch("oslo_utils.timeutils.utcnow") - def test_post(self, - mock_utcnow: mock.Mock, - mock_generated_uuid: mock.Mock, - mock_cond: mock.Mock, - mock_incr: mock.Mock): + def test_post( + self, + mock_utcnow: mock.Mock, + mock_generated_uuid: mock.Mock, + mock_cond: mock.Mock, + mock_incr: mock.Mock, + ): mock_incr.return_value = 12 mock_generated_uuid.return_value = "uuid1" mock_utcnow.return_value = "utcnow1" @@ -72,17 +74,16 @@ class MockedEtcdJobBoard(test.TestCase, EtcdJobBoardMixin): _, jobboard = self.create_board() jobboard._client = mock.Mock() - job = jobboard.post("post1", book=mock_book, - details=mock_details, - priority=jobs_base.JobPriority.NORMAL) + job = jobboard.post( + "post1", + book=mock_book, + details=mock_details, + priority=jobs_base.JobPriority.NORMAL, + ) - expected_key = ( - f"/taskflow/jobs/{self.path}/job12") + expected_key = f"/taskflow/jobs/{self.path}/job12" expected_data_key = expected_key + jobboard.DATA_POSTFIX - expected_book_data = { - "name": "book1_name", - "uuid": "book1_uuid" - } + expected_book_data = {"name": "book1_name", "uuid": "book1_uuid"} expected_job_posting = { "uuid": 
"uuid1", "name": "post1", @@ -96,7 +97,8 @@ class MockedEtcdJobBoard(test.TestCase, EtcdJobBoardMixin): mock_incr.assert_called_with(f"/taskflow/jobs/{self.path}/sequence") jobboard._client.create.assert_called_with( - expected_data_key, jsonutils.dumps(expected_job_posting)) + expected_data_key, jsonutils.dumps(expected_job_posting) + ) self.assertEqual("post1", job.name) self.assertEqual(expected_key, job.key) @@ -109,8 +111,9 @@ class MockedEtcdJobBoard(test.TestCase, EtcdJobBoardMixin): self.assertEqual(1, len(jobboard._job_cache)) self.assertEqual(job, jobboard._job_cache[expected_key]) - @mock.patch("taskflow.jobs.backends.impl_etcd.EtcdJobBoard." - "set_last_modified") + @mock.patch( + "taskflow.jobs.backends.impl_etcd.EtcdJobBoard.set_last_modified" + ) def test_claim(self, mock_set_last_modified): who = "owner1" lease_id = uuidutils.generate_uuid() @@ -123,18 +126,20 @@ class MockedEtcdJobBoard(test.TestCase, EtcdJobBoardMixin): jobboard._client.create.return_value = True jobboard._client.get.return_value = [mock.Mock()] - job = impl_etcd.EtcdJob(jobboard, - "job7", - jobboard._client, - f"/taskflow/jobs/{self.path}/job7", - uuid=uuidutils.generate_uuid(), - details=mock.Mock(), - backend="etcd", - book=mock.Mock(), - book_data=mock.Mock(), - priority=jobs_base.JobPriority.NORMAL, - sequence=7, - created_on="date") + job = impl_etcd.EtcdJob( + jobboard, + "job7", + jobboard._client, + f"/taskflow/jobs/{self.path}/job7", + uuid=uuidutils.generate_uuid(), + details=mock.Mock(), + backend="etcd", + book=mock.Mock(), + book_data=mock.Mock(), + priority=jobs_base.JobPriority.NORMAL, + sequence=7, + created_on="date", + ) jobboard.claim(job, who) @@ -142,22 +147,24 @@ class MockedEtcdJobBoard(test.TestCase, EtcdJobBoardMixin): jobboard._client.create.assert_called_once_with( f"{job.key}{jobboard.LOCK_POSTFIX}", - jsonutils.dumps({"owner": who, - "lease_id": lease_id}), - lease=mock_lease) + jsonutils.dumps({"owner": who, "lease_id": lease_id}), + 
lease=mock_lease, + ) jobboard._client.get.assert_called_once_with( - job.key + jobboard.DATA_POSTFIX) + job.key + jobboard.DATA_POSTFIX + ) mock_lease.revoke.assert_not_called() mock_set_last_modified.assert_called_once_with(job) - @mock.patch("taskflow.jobs.backends.impl_etcd.EtcdJobBoard." - "set_last_modified") - @mock.patch("taskflow.jobs.backends.impl_etcd.EtcdJobBoard." - "find_owner") - def test_claim_already_claimed(self, mock_find_owner, - mock_set_last_modified): + @mock.patch( + "taskflow.jobs.backends.impl_etcd.EtcdJobBoard.set_last_modified" + ) + @mock.patch("taskflow.jobs.backends.impl_etcd.EtcdJobBoard.find_owner") + def test_claim_already_claimed( + self, mock_find_owner, mock_set_last_modified + ): who = "owner1" lease_id = uuidutils.generate_uuid() @@ -171,36 +178,40 @@ class MockedEtcdJobBoard(test.TestCase, EtcdJobBoardMixin): jobboard._client.create.return_value = False jobboard._client.get.return_value = [] - job = impl_etcd.EtcdJob(jobboard, - "job7", - jobboard._client, - f"/taskflow/jobs/{self.path}/job7", - uuid=uuidutils.generate_uuid(), - details=mock.Mock(), - backend="etcd", - book=mock.Mock(), - book_data=mock.Mock(), - priority=jobs_base.JobPriority.NORMAL, - sequence=7, - created_on="date") + job = impl_etcd.EtcdJob( + jobboard, + "job7", + jobboard._client, + f"/taskflow/jobs/{self.path}/job7", + uuid=uuidutils.generate_uuid(), + details=mock.Mock(), + backend="etcd", + book=mock.Mock(), + book_data=mock.Mock(), + priority=jobs_base.JobPriority.NORMAL, + sequence=7, + created_on="date", + ) - self.assertRaisesRegex(exc.UnclaimableJob, "already claimed by", - jobboard.claim, job, who) + self.assertRaisesRegex( + exc.UnclaimableJob, "already claimed by", jobboard.claim, job, who + ) jobboard._client.lease.assert_called_once_with(ttl=37) jobboard._client.create.assert_called_once_with( f"{job.key}{jobboard.LOCK_POSTFIX}", - jsonutils.dumps({"owner": who, - "lease_id": lease_id}), - lease=mock_lease) + jsonutils.dumps({"owner": who, 
"lease_id": lease_id}), + lease=mock_lease, + ) mock_lease.revoke.assert_called_once() mock_set_last_modified.assert_not_called() - @mock.patch("taskflow.jobs.backends.impl_etcd.EtcdJobBoard." - "set_last_modified") + @mock.patch( + "taskflow.jobs.backends.impl_etcd.EtcdJobBoard.set_last_modified" + ) def test_claim_deleted(self, mock_set_last_modified): who = "owner1" lease_id = uuidutils.generate_uuid() @@ -213,168 +224,213 @@ class MockedEtcdJobBoard(test.TestCase, EtcdJobBoardMixin): jobboard._client.create.return_value = True jobboard._client.get.return_value = [] - job = impl_etcd.EtcdJob(jobboard, - "job7", - jobboard._client, - f"/taskflow/jobs/{self.path}/job7", - uuid=uuidutils.generate_uuid(), - details=mock.Mock(), - backend="etcd", - book=mock.Mock(), - book_data=mock.Mock(), - priority=jobs_base.JobPriority.NORMAL, - sequence=7, - created_on="date") + job = impl_etcd.EtcdJob( + jobboard, + "job7", + jobboard._client, + f"/taskflow/jobs/{self.path}/job7", + uuid=uuidutils.generate_uuid(), + details=mock.Mock(), + backend="etcd", + book=mock.Mock(), + book_data=mock.Mock(), + priority=jobs_base.JobPriority.NORMAL, + sequence=7, + created_on="date", + ) - self.assertRaisesRegex(exc.UnclaimableJob, "already deleted", - jobboard.claim, job, who) + self.assertRaisesRegex( + exc.UnclaimableJob, "already deleted", jobboard.claim, job, who + ) jobboard._client.lease.assert_called_once_with(ttl=37) jobboard._client.create.assert_called_once_with( f"{job.key}{jobboard.LOCK_POSTFIX}", - jsonutils.dumps({"owner": who, - "lease_id": lease_id}), - lease=mock_lease) + jsonutils.dumps({"owner": who, "lease_id": lease_id}), + lease=mock_lease, + ) jobboard._client.get.assert_called_once_with( - job.key + jobboard.DATA_POSTFIX) + job.key + jobboard.DATA_POSTFIX + ) mock_lease.revoke.assert_called_once() mock_set_last_modified.assert_not_called() - @mock.patch("taskflow.jobs.backends.impl_etcd.EtcdJobBoard." 
- "get_owner_and_data") - @mock.patch("taskflow.jobs.backends.impl_etcd.EtcdJobBoard." - "_remove_job_from_cache") - def test_consume(self, mock__remove_job_from_cache, - mock_get_owner_and_data): + @mock.patch( + "taskflow.jobs.backends.impl_etcd.EtcdJobBoard.get_owner_and_data" + ) + @mock.patch( + "taskflow.jobs.backends.impl_etcd.EtcdJobBoard._remove_job_from_cache" + ) + def test_consume( + self, mock__remove_job_from_cache, mock_get_owner_and_data + ): mock_get_owner_and_data.return_value = ["owner1", mock.Mock()] _, jobboard = self.create_board() jobboard._client = mock.Mock() - job = impl_etcd.EtcdJob(jobboard, - "job7", - jobboard._client, - f"/taskflow/jobs/{self.path}/job7") + job = impl_etcd.EtcdJob( + jobboard, + "job7", + jobboard._client, + f"/taskflow/jobs/{self.path}/job7", + ) jobboard.consume(job, "owner1") jobboard._client.delete_prefix.assert_called_once_with(job.key + ".") mock__remove_job_from_cache.assert_called_once_with(job.key) - @mock.patch("taskflow.jobs.backends.impl_etcd.EtcdJobBoard." - "get_owner_and_data") + @mock.patch( + "taskflow.jobs.backends.impl_etcd.EtcdJobBoard.get_owner_and_data" + ) def test_consume_bad_owner(self, mock_get_owner_and_data): mock_get_owner_and_data.return_value = ["owner2", mock.Mock()] _, jobboard = self.create_board() jobboard._client = mock.Mock() - job = impl_etcd.EtcdJob(jobboard, - "job7", - jobboard._client, - f"/taskflow/jobs/{self.path}/job7") - self.assertRaisesRegex(exc.JobFailure, "which is not owned", - jobboard.consume, job, "owner1") + job = impl_etcd.EtcdJob( + jobboard, + "job7", + jobboard._client, + f"/taskflow/jobs/{self.path}/job7", + ) + self.assertRaisesRegex( + exc.JobFailure, + "which is not owned", + jobboard.consume, + job, + "owner1", + ) jobboard._client.delete_prefix.assert_not_called() - @mock.patch("taskflow.jobs.backends.impl_etcd.EtcdJobBoard." 
- "get_owner_and_data") + @mock.patch( + "taskflow.jobs.backends.impl_etcd.EtcdJobBoard.get_owner_and_data" + ) def test_abandon(self, mock_get_owner_and_data): mock_get_owner_and_data.return_value = ["owner1", mock.Mock()] _, jobboard = self.create_board() jobboard._client = mock.Mock() - job = impl_etcd.EtcdJob(jobboard, - "job7", - jobboard._client, - f"/taskflow/jobs/{self.path}/job7") + job = impl_etcd.EtcdJob( + jobboard, + "job7", + jobboard._client, + f"/taskflow/jobs/{self.path}/job7", + ) jobboard.abandon(job, "owner1") jobboard._client.delete.assert_called_once_with( - f"{job.key}{jobboard.LOCK_POSTFIX}") + f"{job.key}{jobboard.LOCK_POSTFIX}" + ) - @mock.patch("taskflow.jobs.backends.impl_etcd.EtcdJobBoard." - "get_owner_and_data") + @mock.patch( + "taskflow.jobs.backends.impl_etcd.EtcdJobBoard.get_owner_and_data" + ) def test_abandon_bad_owner(self, mock_get_owner_and_data): mock_get_owner_and_data.return_value = ["owner2", mock.Mock()] _, jobboard = self.create_board() jobboard._client = mock.Mock() - job = impl_etcd.EtcdJob(jobboard, - "job7", - jobboard._client, - f"/taskflow/jobs/{self.path}/job7") - self.assertRaisesRegex(exc.JobFailure, "which is not owned", - jobboard.abandon, job, "owner1") + job = impl_etcd.EtcdJob( + jobboard, + "job7", + jobboard._client, + f"/taskflow/jobs/{self.path}/job7", + ) + self.assertRaisesRegex( + exc.JobFailure, + "which is not owned", + jobboard.abandon, + job, + "owner1", + ) jobboard._client.delete.assert_not_called() - @mock.patch("taskflow.jobs.backends.impl_etcd.EtcdJobBoard." - "get_owner_and_data") - @mock.patch("taskflow.jobs.backends.impl_etcd.EtcdJobBoard." 
- "_remove_job_from_cache") - def test_trash(self, mock__remove_job_from_cache, - mock_get_owner_and_data): + @mock.patch( + "taskflow.jobs.backends.impl_etcd.EtcdJobBoard.get_owner_and_data" + ) + @mock.patch( + "taskflow.jobs.backends.impl_etcd.EtcdJobBoard._remove_job_from_cache" + ) + def test_trash(self, mock__remove_job_from_cache, mock_get_owner_and_data): mock_get_owner_and_data.return_value = ["owner1", mock.Mock()] _, jobboard = self.create_board() jobboard._client = mock.Mock() - job = impl_etcd.EtcdJob(jobboard, - "job7", - jobboard._client, - f"/taskflow/jobs/{self.path}/job7") + job = impl_etcd.EtcdJob( + jobboard, + "job7", + jobboard._client, + f"/taskflow/jobs/{self.path}/job7", + ) jobboard.trash(job, "owner1") jobboard._client.create.assert_called_once_with( - f"/taskflow/.trash/{self.path}/job7", mock.ANY) + f"/taskflow/.trash/{self.path}/job7", mock.ANY + ) jobboard._client.delete_prefix.assert_called_once_with(job.key + ".") mock__remove_job_from_cache.assert_called_once_with(job.key) - @mock.patch("taskflow.jobs.backends.impl_etcd.EtcdJobBoard." - "get_owner_and_data") - @mock.patch("taskflow.jobs.backends.impl_etcd.EtcdJobBoard." 
- "_remove_job_from_cache") - def test_trash_bad_owner(self, mock__remove_job_from_cache, - mock_get_owner_and_data): + @mock.patch( + "taskflow.jobs.backends.impl_etcd.EtcdJobBoard.get_owner_and_data" + ) + @mock.patch( + "taskflow.jobs.backends.impl_etcd.EtcdJobBoard._remove_job_from_cache" + ) + def test_trash_bad_owner( + self, mock__remove_job_from_cache, mock_get_owner_and_data + ): mock_get_owner_and_data.return_value = ["owner2", mock.Mock()] _, jobboard = self.create_board() jobboard._client = mock.Mock() - job = impl_etcd.EtcdJob(jobboard, - "job7", - jobboard._client, - f"/taskflow/jobs/{self.path}/job7") - self.assertRaisesRegex(exc.JobFailure, "which is not owned", - jobboard.trash, job, "owner1") + job = impl_etcd.EtcdJob( + jobboard, + "job7", + jobboard._client, + f"/taskflow/jobs/{self.path}/job7", + ) + self.assertRaisesRegex( + exc.JobFailure, "which is not owned", jobboard.trash, job, "owner1" + ) jobboard._client.create.assert_not_called() jobboard._client.delete_prefix.assert_not_called() mock__remove_job_from_cache.assert_not_called() - @mock.patch("taskflow.jobs.backends.impl_etcd.EtcdJobBoard." - "get_owner_and_data") - @mock.patch("taskflow.jobs.backends.impl_etcd.EtcdJobBoard." 
- "_remove_job_from_cache") - def test_trash_deleted_job(self, mock__remove_job_from_cache, - mock_get_owner_and_data): + @mock.patch( + "taskflow.jobs.backends.impl_etcd.EtcdJobBoard.get_owner_and_data" + ) + @mock.patch( + "taskflow.jobs.backends.impl_etcd.EtcdJobBoard._remove_job_from_cache" + ) + def test_trash_deleted_job( + self, mock__remove_job_from_cache, mock_get_owner_and_data + ): mock_get_owner_and_data.return_value = ["owner1", None] _, jobboard = self.create_board() jobboard._client = mock.Mock() - job = impl_etcd.EtcdJob(jobboard, - "job7", - jobboard._client, - f"/taskflow/jobs/{self.path}/job7") - self.assertRaisesRegex(exc.NotFound, "Cannot find job", - jobboard.trash, job, "owner1") + job = impl_etcd.EtcdJob( + jobboard, + "job7", + jobboard._client, + f"/taskflow/jobs/{self.path}/job7", + ) + self.assertRaisesRegex( + exc.NotFound, "Cannot find job", jobboard.trash, job, "owner1" + ) jobboard._client.create.assert_not_called() jobboard._client.delete_prefix.assert_not_called() diff --git a/taskflow/tests/unit/jobs/test_redis_job.py b/taskflow/tests/unit/jobs/test_redis_job.py index e0e19c721..8fe6304da 100644 --- a/taskflow/tests/unit/jobs/test_redis_job.py +++ b/taskflow/tests/unit/jobs/test_redis_job.py @@ -29,7 +29,8 @@ from taskflow.utils import redis_utils as ru REDIS_AVAILABLE = test_utils.redis_available( - impl_redis.RedisJobBoard.MIN_REDIS_VERSION) + impl_redis.RedisJobBoard.MIN_REDIS_VERSION +) REDIS_PORT = test_utils.REDIS_PORT @@ -93,8 +94,12 @@ class RedisJobboardTest(test.TestCase, base.BoardTestMixin): possible_jobs = list(self.board.iterjobs()) self.assertEqual(1, len(possible_jobs)) with self.flush(self.client): - self.assertRaises(excp.UnclaimableJob, self.board.claim, - possible_jobs[0], self.board.name) + self.assertRaises( + excp.UnclaimableJob, + self.board.claim, + possible_jobs[0], + self.board.name, + ) possible_jobs = list(self.board.iterjobs(only_unclaimed=True)) self.assertEqual(0, len(possible_jobs)) @@ -103,12 
+108,13 @@ class RedisJobboardTest(test.TestCase, base.BoardTestMixin): self.client, self.board = self.create_board() def test__make_client(self): - conf = {'host': '127.0.0.1', - 'port': 6379, - 'username': 'default', - 'password': 'secret', - 'namespace': 'test' - } + conf = { + 'host': '127.0.0.1', + 'port': 6379, + 'username': 'default', + 'password': 'secret', + 'namespace': 'test', + } test_conf = { 'host': '127.0.0.1', 'port': 6379, @@ -120,16 +126,18 @@ class RedisJobboardTest(test.TestCase, base.BoardTestMixin): mock_ru.assert_called_once_with(**test_conf) def test__make_client_sentinel(self): - conf = {'host': '127.0.0.1', - 'port': 26379, + conf = { + 'host': '127.0.0.1', + 'port': 26379, + 'username': 'default', + 'password': 'secret', + 'namespace': 'test', + 'sentinel': 'mymaster', + 'sentinel_kwargs': { 'username': 'default', - 'password': 'secret', - 'namespace': 'test', - 'sentinel': 'mymaster', - 'sentinel_kwargs': { - 'username': 'default', - 'password': 'senitelsecret' - }} + 'password': 'senitelsecret', + }, + } with mock.patch('redis.sentinel.Sentinel') as mock_sentinel: impl_redis.RedisJobBoard('test-board', conf) test_conf = { @@ -140,21 +148,26 @@ class RedisJobboardTest(test.TestCase, base.BoardTestMixin): [('127.0.0.1', 26379)], sentinel_kwargs={ 'username': 'default', - 'password': 'senitelsecret' + 'password': 'senitelsecret', }, - **test_conf) + **test_conf, + ) mock_sentinel().master_for.assert_called_once_with('mymaster') def test__make_client_sentinel_fallbacks(self): - conf = {'host': '127.0.0.1', - 'port': 26379, - 'username': 'default', - 'password': 'secret', - 'namespace': 'test', - 'sentinel': 'mymaster', - 'sentinel_fallbacks': [ - '[::1]:26379', '127.0.0.2:26379', 'localhost:26379' - ]} + conf = { + 'host': '127.0.0.1', + 'port': 26379, + 'username': 'default', + 'password': 'secret', + 'namespace': 'test', + 'sentinel': 'mymaster', + 'sentinel_fallbacks': [ + '[::1]:26379', + '127.0.0.2:26379', + 'localhost:26379', + ], + } 
with mock.patch('redis.sentinel.Sentinel') as mock_sentinel: impl_redis.RedisJobBoard('test-board', conf) test_conf = { @@ -163,21 +176,28 @@ class RedisJobboardTest(test.TestCase, base.BoardTestMixin): 'sentinel_kwargs': None, } mock_sentinel.assert_called_once_with( - [('127.0.0.1', 26379), ('::1', 26379), - ('127.0.0.2', 26379), ('localhost', 26379)], - **test_conf) + [ + ('127.0.0.1', 26379), + ('::1', 26379), + ('127.0.0.2', 26379), + ('localhost', 26379), + ], + **test_conf, + ) mock_sentinel().master_for.assert_called_once_with('mymaster') def test__make_client_sentinel_ssl(self): - conf = {'host': '127.0.0.1', - 'port': 26379, - 'username': 'default', - 'password': 'secret', - 'namespace': 'test', - 'sentinel': 'mymaster', - 'sentinel_kwargs': None, - 'ssl': True, - 'ssl_ca_certs': '/etc/ssl/certs'} + conf = { + 'host': '127.0.0.1', + 'port': 26379, + 'username': 'default', + 'password': 'secret', + 'namespace': 'test', + 'sentinel': 'mymaster', + 'sentinel_kwargs': None, + 'ssl': True, + 'ssl_ca_certs': '/etc/ssl/certs', + } with mock.patch('redis.sentinel.Sentinel') as mock_sentinel: impl_redis.RedisJobBoard('test-board', conf) test_conf = { @@ -187,7 +207,6 @@ class RedisJobboardTest(test.TestCase, base.BoardTestMixin): 'ssl_ca_certs': '/etc/ssl/certs', } mock_sentinel.assert_called_once_with( - [('127.0.0.1', 26379)], - sentinel_kwargs=None, - **test_conf) + [('127.0.0.1', 26379)], sentinel_kwargs=None, **test_conf + ) mock_sentinel().master_for.assert_called_once_with('mymaster') diff --git a/taskflow/tests/unit/jobs/test_zk_job.py b/taskflow/tests/unit/jobs/test_zk_job.py index 19455b64e..0276144b1 100644 --- a/taskflow/tests/unit/jobs/test_zk_job.py +++ b/taskflow/tests/unit/jobs/test_zk_job.py @@ -36,7 +36,8 @@ from taskflow.utils import persistence_utils as p_utils FLUSH_PATH_TPL = '/taskflow/flush-test/%s' TEST_PATH_TPL = '/taskflow/board-test/%s' ZOOKEEPER_AVAILABLE = test_utils.zookeeper_available( - 
impl_zookeeper.ZookeeperJobBoard.MIN_ZK_VERSION) + impl_zookeeper.ZookeeperJobBoard.MIN_ZK_VERSION +) LOCK_POSTFIX = impl_zookeeper.ZookeeperJobBoard.LOCK_POSTFIX @@ -70,25 +71,33 @@ class ZookeeperBoardTestMixin(base.BoardTestMixin): watchers.DataWatch(client, path, func=on_created) client.create(path, makepath=True) if not created.wait(test_utils.WAIT_TIMEOUT): - raise RuntimeError("Could not receive creation of %s in" - " the alloted timeout of %s seconds" - % (path, test_utils.WAIT_TIMEOUT)) + raise RuntimeError( + "Could not receive creation of %s in" + " the alloted timeout of %s seconds" + % (path, test_utils.WAIT_TIMEOUT) + ) try: yield finally: watchers.DataWatch(client, path, func=on_deleted) client.delete(path, recursive=True) if not deleted.wait(test_utils.WAIT_TIMEOUT): - raise RuntimeError("Could not receive deletion of %s in" - " the alloted timeout of %s seconds" - % (path, test_utils.WAIT_TIMEOUT)) + raise RuntimeError( + "Could not receive deletion of %s in" + " the alloted timeout of %s seconds" + % (path, test_utils.WAIT_TIMEOUT) + ) def test_posting_no_post(self): with base.connect_close(self.board): with mock.patch.object(self.client, 'create') as create_func: create_func.side_effect = IOError("Unable to post") - self.assertRaises(IOError, self.board.post, - 'test', p_utils.temporary_log_book()) + self.assertRaises( + IOError, + self.board.post, + 'test', + p_utils.temporary_log_book(), + ) self.assertEqual(0, self.board.job_count) def test_board_iter(self): @@ -98,8 +107,9 @@ class ZookeeperBoardTestMixin(base.BoardTestMixin): self.assertFalse(it.only_unclaimed) self.assertFalse(it.ensure_fresh) - @mock.patch("taskflow.jobs.backends.impl_zookeeper.misc." 
- "millis_to_datetime") + @mock.patch( + "taskflow.jobs.backends.impl_zookeeper.misc.millis_to_datetime" + ) def test_posting_dates(self, mock_dt): epoch = misc.millis_to_datetime(0) mock_dt.return_value = epoch @@ -123,9 +133,11 @@ class ZookeeperJobboardTest(test.TestCase, ZookeeperBoardTestMixin): client = kazoo_utils.make_client(test_utils.ZK_TEST_CONFIG.copy()) self.path = TEST_PATH_TPL % uuidutils.generate_uuid() board = impl_zookeeper.ZookeeperJobBoard( - 'test-board', {'path': self.path}, + 'test-board', + {'path': self.path}, client=client, - persistence=persistence) + persistence=persistence, + ) self.addCleanup(cleanup_path, client, self.path) self.addCleanup(board.close) self.addCleanup(self.close_client, client) @@ -151,8 +163,10 @@ class ZookeeperJobboardTest(test.TestCase, ZookeeperBoardTestMixin): children = self.client.get_children(self.path) for p in children: if p.endswith(LOCK_POSTFIX): - self.client.set(k_paths.join(self.path, p), - misc.binary_encode(jsonutils.dumps({}))) + self.client.set( + k_paths.join(self.path, p), + misc.binary_encode(jsonutils.dumps({})), + ) self.assertEqual(states.UNCLAIMED, j.state) def test_posting_state_lock_lost(self): @@ -203,8 +217,9 @@ class ZookeeperJobboardTest(test.TestCase, ZookeeperBoardTestMixin): self.assertEqual(self.board, posted_job.board) self.assertEqual(1, self.board.job_count) - self.assertIn(posted_job.uuid, [j.uuid - for j in self.board.iterjobs()]) + self.assertIn( + posted_job.uuid, [j.uuid for j in self.board.iterjobs()] + ) # Remove paths that got created due to the running process that we are # not interested in... 
@@ -213,42 +228,46 @@ class ZookeeperJobboardTest(test.TestCase, ZookeeperBoardTestMixin): self.assertEqual(1, len(children)) child = self.client.get(k_paths.join(self.path, children[0])) self.assertGreater(len(child[0]), 0) - self.assertEqual({ - 'uuid': posted_job.uuid, - 'name': posted_job.name, - 'book': { - 'name': book.name, - 'uuid': book.uuid, + self.assertEqual( + { + 'uuid': posted_job.uuid, + 'name': posted_job.name, + 'book': { + 'name': book.name, + 'uuid': book.uuid, + }, + 'priority': 'NORMAL', + 'details': {}, }, - 'priority': 'NORMAL', - 'details': {}, - }, jsonutils.loads(misc.binary_decode(child[0]))) + jsonutils.loads(misc.binary_decode(child[0])), + ) def test_register_entity(self): conductor_name = "conductor-abc@localhost:4123" - entity_instance = entity.Entity("conductor", - conductor_name, - {}) + entity_instance = entity.Entity("conductor", conductor_name, {}) with base.connect_close(self.board): self.board.register_entity(entity_instance) # Check '.entity' node has been created self.client.get_children(self.board.entity_path) - conductor_entity_path = k_paths.join(self.board.entity_path, - 'conductor', - conductor_name) + conductor_entity_path = k_paths.join( + self.board.entity_path, 'conductor', conductor_name + ) conductor_data = self.client.get(conductor_entity_path)[0] self.assertGreater(len(conductor_data), 0) - self.assertEqual({ - 'name': conductor_name, - 'kind': 'conductor', - 'metadata': {}, - }, jsonutils.loads(misc.binary_decode(conductor_data))) + self.assertEqual( + { + 'name': conductor_name, + 'kind': 'conductor', + 'metadata': {}, + }, + jsonutils.loads(misc.binary_decode(conductor_data)), + ) - entity_instance_2 = entity.Entity("non-sense", - "other_name", - {}) + entity_instance_2 = entity.Entity("non-sense", "other_name", {}) with base.connect_close(self.board): - self.assertRaises(excp.NotImplementedError, - self.board.register_entity, - entity_instance_2) + self.assertRaises( + excp.NotImplementedError, + 
self.board.register_entity, + entity_instance_2, + ) diff --git a/taskflow/tests/unit/patterns/test_graph_flow.py b/taskflow/tests/unit/patterns/test_graph_flow.py index ff3372954..b52877573 100644 --- a/taskflow/tests/unit/patterns/test_graph_flow.py +++ b/taskflow/tests/unit/patterns/test_graph_flow.py @@ -30,10 +30,14 @@ class GraphFlowTest(test.TestCase): for not_a_depth in ['not-a-depth', object(), 2, 3.4, False]: flow = gf.Flow('g') flow.add(g_1, g_2) - self.assertRaises((ValueError, TypeError), - flow.link, g_1, g_2, - decider=lambda history: False, - decider_depth=not_a_depth) + self.assertRaises( + (ValueError, TypeError), + flow.link, + g_1, + g_2, + decider=lambda history: False, + decider_depth=not_a_depth, + ) def test_graph_flow_stringy(self): f = gf.Flow('test') @@ -93,8 +97,9 @@ class GraphFlowTest(test.TestCase): self.assertEqual(2, len(f)) self.assertCountEqual(f, [task1, task2]) - self.assertEqual([(task1, task2, {'reasons': {'a'}})], - list(f.iter_links())) + self.assertEqual( + [(task1, task2, {'reasons': {'a'}})], list(f.iter_links()) + ) self.assertEqual(set(), f.requires) self.assertEqual({'a'}, f.provides) @@ -106,8 +111,9 @@ class GraphFlowTest(test.TestCase): self.assertEqual(2, len(f)) self.assertCountEqual(f, [task1, task2]) - self.assertEqual([(task1, task2, {'reasons': {'a'}})], - list(f.iter_links())) + self.assertEqual( + [(task1, task2, {'reasons': {'a'}})], list(f.iter_links()) + ) def test_graph_flow_two_task_same_provide(self): task1 = _task(name='task1', provides=['a', 'b']) @@ -165,10 +171,13 @@ class GraphFlowTest(test.TestCase): self.assertEqual(3, len(f)) - self.assertCountEqual(list(f.iter_links()), [ - (task1, task2, {'reasons': {'a', 'b'}}), - (task2, task3, {'reasons': {'c'}}) - ]) + self.assertCountEqual( + list(f.iter_links()), + [ + (task1, task2, {'reasons': {'a', 'b'}}), + (task2, task3, {'reasons': {'c'}}), + ], + ) def test_graph_flow_links(self): task1 = _task('task1') @@ -176,9 +185,9 @@ class 
GraphFlowTest(test.TestCase): f = gf.Flow('test').add(task1, task2) linked = f.link(task1, task2) self.assertIs(linked, f) - self.assertCountEqual(list(f.iter_links()), [ - (task1, task2, {'manual': True}) - ]) + self.assertCountEqual( + list(f.iter_links()), [(task1, task2, {'manual': True})] + ) def test_graph_flow_links_and_dependencies(self): task1 = _task('task1', provides=['a']) @@ -186,27 +195,26 @@ class GraphFlowTest(test.TestCase): f = gf.Flow('test').add(task1, task2) linked = f.link(task1, task2) self.assertIs(linked, f) - expected_meta = { - 'manual': True, - 'reasons': {'a'} - } - self.assertCountEqual(list(f.iter_links()), [ - (task1, task2, expected_meta) - ]) + expected_meta = {'manual': True, 'reasons': {'a'}} + self.assertCountEqual( + list(f.iter_links()), [(task1, task2, expected_meta)] + ) def test_graph_flow_link_from_unknown_node(self): task1 = _task('task1') task2 = _task('task2') f = gf.Flow('test').add(task2) - self.assertRaisesRegex(ValueError, 'Node .* not found to link from', - f.link, task1, task2) + self.assertRaisesRegex( + ValueError, 'Node .* not found to link from', f.link, task1, task2 + ) def test_graph_flow_link_to_unknown_node(self): task1 = _task('task1') task2 = _task('task2') f = gf.Flow('test').add(task1) - self.assertRaisesRegex(ValueError, 'Node .* not found to link to', - f.link, task1, task2) + self.assertRaisesRegex( + ValueError, 'Node .* not found to link to', f.link, task1, task2 + ) def test_graph_flow_link_raises_on_cycle(self): task1 = _task('task1', provides=['a']) @@ -236,7 +244,7 @@ class GraphFlowTest(test.TestCase): f1.add(task3) tasks = {task1, task2, f1} f = gf.Flow('test').add(task1, task2, f1) - for (n, data) in f.iter_nodes(): + for n, data in f.iter_nodes(): self.assertIn(n, tasks) self.assertEqual({}, data) @@ -248,14 +256,13 @@ class GraphFlowTest(test.TestCase): f1.add(task3) tasks = {task1, task2, f1} f = gf.Flow('test').add(task1, task2, f1) - for (u, v, data) in f.iter_links(): + for u, v, data 
in f.iter_links(): self.assertIn(u, tasks) self.assertIn(v, tasks) self.assertEqual({}, data) class TargetedGraphFlowTest(test.TestCase): - def test_targeted_flow_restricts(self): f = gf.TargetedFlow("test") task1 = _task('task1', provides=['a'], requires=[]) @@ -286,8 +293,9 @@ class TargetedGraphFlowTest(test.TestCase): task1 = _task('task1', provides=['a'], requires=[]) task2 = _task('task2', provides=['b'], requires=['a']) f.add(task1) - self.assertRaisesRegex(ValueError, '^Node .* not found', - f.set_target, task2) + self.assertRaisesRegex( + ValueError, '^Node .* not found', f.set_target, task2 + ) def test_targeted_flow_one_node(self): f = gf.TargetedFlow("test") @@ -327,5 +335,7 @@ class TargetedGraphFlowTest(test.TestCase): f.link(task2, task1) self.assertEqual(2, len(f)) - self.assertEqual([(task2, task1, {'manual': True})], - list(f.iter_links()), ) + self.assertEqual( + [(task2, task1, {'manual': True})], + list(f.iter_links()), + ) diff --git a/taskflow/tests/unit/patterns/test_linear_flow.py b/taskflow/tests/unit/patterns/test_linear_flow.py index e88da69a9..0aaea467b 100644 --- a/taskflow/tests/unit/patterns/test_linear_flow.py +++ b/taskflow/tests/unit/patterns/test_linear_flow.py @@ -23,7 +23,6 @@ def _task(name, provides=None, requires=None): class LinearFlowTest(test.TestCase): - def test_linear_flow_stringy(self): f = lf.Flow('test') expected = '"linear_flow.Flow: test(len=0)"' @@ -73,8 +72,9 @@ class LinearFlowTest(test.TestCase): self.assertEqual(2, len(f)) self.assertEqual([task1, task2], list(f)) - self.assertEqual([(task1, task2, {'invariant': True})], - list(f.iter_links())) + self.assertEqual( + [(task1, task2, {'invariant': True})], list(f.iter_links()) + ) def test_linear_flow_two_dependent_tasks(self): task1 = _task(name='task1', provides=['a']) @@ -83,8 +83,9 @@ class LinearFlowTest(test.TestCase): self.assertEqual(2, len(f)) self.assertEqual([task1, task2], list(f)) - self.assertEqual([(task1, task2, {'invariant': True})], - 
list(f.iter_links())) + self.assertEqual( + [(task1, task2, {'invariant': True})], list(f.iter_links()) + ) self.assertEqual(set(), f.requires) self.assertEqual({'a'}, f.provides) @@ -96,8 +97,10 @@ class LinearFlowTest(test.TestCase): self.assertEqual(2, len(f)) self.assertEqual([task1, task2], list(f)) - self.assertEqual([(task1, task2, {'invariant': True})], - list(f.iter_links()), ) + self.assertEqual( + [(task1, task2, {'invariant': True})], + list(f.iter_links()), + ) def test_linear_flow_three_tasks(self): task1 = _task(name='task1') @@ -107,10 +110,13 @@ class LinearFlowTest(test.TestCase): self.assertEqual(3, len(f)) self.assertEqual([task1, task2, task3], list(f)) - self.assertEqual([ - (task1, task2, {'invariant': True}), - (task2, task3, {'invariant': True}) - ], list(f.iter_links())) + self.assertEqual( + [ + (task1, task2, {'invariant': True}), + (task2, task3, {'invariant': True}), + ], + list(f.iter_links()), + ) def test_linear_flow_with_retry(self): ret = retry.AlwaysRevert(requires=['a'], provides=['b']) @@ -127,7 +133,7 @@ class LinearFlowTest(test.TestCase): task3 = _task(name='task3') f = lf.Flow('test').add(task1, task2, task3) tasks = {task1, task2, task3} - for (node, data) in f.iter_nodes(): + for node, data in f.iter_nodes(): self.assertIn(node, tasks) self.assertEqual({}, data) @@ -137,7 +143,7 @@ class LinearFlowTest(test.TestCase): task3 = _task(name='task3') f = lf.Flow('test').add(task1, task2, task3) tasks = {task1, task2, task3} - for (u, v, data) in f.iter_links(): + for u, v, data in f.iter_links(): self.assertIn(u, tasks) self.assertIn(v, tasks) self.assertEqual({'invariant': True}, data) diff --git a/taskflow/tests/unit/patterns/test_unordered_flow.py b/taskflow/tests/unit/patterns/test_unordered_flow.py index 033159b98..f76577fa5 100644 --- a/taskflow/tests/unit/patterns/test_unordered_flow.py +++ b/taskflow/tests/unit/patterns/test_unordered_flow.py @@ -23,7 +23,6 @@ def _task(name, provides=None, requires=None): class 
UnorderedFlowTest(test.TestCase): - def test_unordered_flow_stringy(self): f = uf.Flow('test') expected = '"unordered_flow.Flow: test(len=0)"' @@ -123,7 +122,7 @@ class UnorderedFlowTest(test.TestCase): tasks = {task1, task2} f = uf.Flow('test') f.add(task2, task1) - for (node, data) in f.iter_nodes(): + for node, data in f.iter_nodes(): self.assertIn(node, tasks) self.assertEqual({}, data) @@ -132,5 +131,5 @@ class UnorderedFlowTest(test.TestCase): task2 = _task(name='task2', provides=['a', 'c']) f = uf.Flow('test') f.add(task2, task1) - for (u, v, data) in f.iter_links(): + for u, v, data in f.iter_links(): raise AssertionError('links iterator should be empty') diff --git a/taskflow/tests/unit/persistence/test_dir_persistence.py b/taskflow/tests/unit/persistence/test_dir_persistence.py index 0aef5fa19..33c46c2f8 100644 --- a/taskflow/tests/unit/persistence/test_dir_persistence.py +++ b/taskflow/tests/unit/persistence/test_dir_persistence.py @@ -28,9 +28,9 @@ from taskflow import test from taskflow.tests.unit.persistence import base -class DirPersistenceTest(testscenarios.TestWithScenarios, - test.TestCase, base.PersistenceTestMixin): - +class DirPersistenceTest( + testscenarios.TestWithScenarios, test.TestCase, base.PersistenceTestMixin +): scenarios = [ ('no_cache', {'max_cache_size': None}), ('one', {'max_cache_size': 1}), @@ -45,10 +45,12 @@ class DirPersistenceTest(testscenarios.TestWithScenarios, def setUp(self): super().setUp() self.path = tempfile.mkdtemp() - self.backend = impl_dir.DirBackend({ - 'path': self.path, - 'max_cache_size': self.max_cache_size, - }) + self.backend = impl_dir.DirBackend( + { + 'path': self.path, + 'max_cache_size': self.max_cache_size, + } + ) with contextlib.closing(self._get_connection()) as conn: conn.upgrade() @@ -83,8 +85,9 @@ class DirPersistenceTest(testscenarios.TestWithScenarios, self.assertRaises(exc.NotFound, conn.get_logbook, lb_id) conn.save_logbook(lb) books_ids_made.append(lb_id) - 
self.assertLessEqual(self.backend.file_cache.currsize, - self.max_cache_size) + self.assertLessEqual( + self.backend.file_cache.currsize, self.max_cache_size + ) # Also ensure that we can still read all created books... with contextlib.closing(self._get_connection()) as conn: for lb_id in books_ids_made: @@ -95,8 +98,12 @@ class DirPersistenceTest(testscenarios.TestWithScenarios, self._check_backend(dict(connection='dir:', path=self.path)) def test_dir_backend_name(self): - self._check_backend(dict(connection='dir', # no colon - path=self.path)) + self._check_backend( + dict( + connection='dir', # no colon + path=self.path, + ) + ) def test_file_backend_entry_point(self): self._check_backend(dict(connection='file:', path=self.path)) diff --git a/taskflow/tests/unit/persistence/test_memory_persistence.py b/taskflow/tests/unit/persistence/test_memory_persistence.py index a16fde4e8..7bc48a492 100644 --- a/taskflow/tests/unit/persistence/test_memory_persistence.py +++ b/taskflow/tests/unit/persistence/test_memory_persistence.py @@ -47,7 +47,6 @@ class MemoryPersistenceTest(test.TestCase, base.PersistenceTestMixin): class MemoryFilesystemTest(test.TestCase): - @staticmethod def _get_item_path(fs, path): # TODO(harlowja): is there a better way to do this?? 
@@ -76,18 +75,21 @@ class MemoryFilesystemTest(test.TestCase): fs.ensure_path("/b/c/d") fs.ensure_path("/a/b/c/d") contents = fs.ls_r("/", absolute=False) - self.assertEqual([ - 'a', - 'b', - 'c', - 'd', - 'a/b', - 'b/c', - 'c/d', - 'a/b/c', - 'b/c/d', - 'a/b/c/d', - ], contents) + self.assertEqual( + [ + 'a', + 'b', + 'c', + 'd', + 'a/b', + 'b/c', + 'c/d', + 'a/b/c', + 'b/c/d', + 'a/b/c/d', + ], + contents, + ) def test_ls_recursive_absolute(self): fs = impl_memory.FakeFilesystem() @@ -96,18 +98,21 @@ class MemoryFilesystemTest(test.TestCase): fs.ensure_path("/b/c/d") fs.ensure_path("/a/b/c/d") contents = fs.ls_r("/", absolute=True) - self.assertEqual([ - '/a', - '/b', - '/c', - '/d', - '/a/b', - '/b/c', - '/c/d', - '/a/b/c', - '/b/c/d', - '/a/b/c/d', - ], contents) + self.assertEqual( + [ + '/a', + '/b', + '/c', + '/d', + '/a/b', + '/b/c', + '/c/d', + '/a/b/c', + '/b/c/d', + '/a/b/c/d', + ], + contents, + ) def test_ls_recursive_targeted(self): fs = impl_memory.FakeFilesystem() diff --git a/taskflow/tests/unit/persistence/test_sql_persistence.py b/taskflow/tests/unit/persistence/test_sql_persistence.py index b292683ad..74933f61d 100644 --- a/taskflow/tests/unit/persistence/test_sql_persistence.py +++ b/taskflow/tests/unit/persistence/test_sql_persistence.py @@ -34,8 +34,7 @@ import testtools USER = "openstack_citest" PASSWD = "openstack_citest" -DATABASE = "tftest_" + ''.join(random.choice('0123456789') - for _ in range(12)) +DATABASE = "tftest_" + ''.join(random.choice('0123456789') for _ in range(12)) import sqlalchemy as sa @@ -99,6 +98,7 @@ def _postgres_exists(): class SqlitePersistenceTest(test.TestCase, base.PersistenceTestMixin): """Inherits from the base test and sets up a sqlite temporary db.""" + def _get_connection(self): conf = { 'connection': self.db_uri, @@ -120,8 +120,9 @@ class SqlitePersistenceTest(test.TestCase, base.PersistenceTestMixin): self.db_location = None -class BackendPersistenceTestMixin(base.PersistenceTestMixin, - 
metaclass=abc.ABCMeta): +class BackendPersistenceTestMixin( + base.PersistenceTestMixin, metaclass=abc.ABCMeta +): """Specifies a backend type and does required setup and teardown.""" def _get_connection(self): @@ -148,15 +149,15 @@ class BackendPersistenceTestMixin(base.PersistenceTestMixin, self.backend = None try: self.db_uri = self._init_db() - self.db_conf = { - 'connection': self.db_uri - } + self.db_conf = {'connection': self.db_uri} # Since we are using random database names, we need to make sure # and remove our random database when we are done testing. self.addCleanup(self._remove_db) except Exception as e: - self.skipTest("Failed to create temporary database;" - " testing being skipped due to: %s" % (e)) + self.skipTest( + "Failed to create temporary database;" + " testing being skipped due to: %s" % (e) + ) else: self.backend = impl_sqlalchemy.SQLAlchemyBackend(self.db_conf) self.addCleanup(self.backend.close) @@ -166,7 +167,6 @@ class BackendPersistenceTestMixin(base.PersistenceTestMixin, @testtools.skipIf(not _mysql_exists(), 'mysql is not available') class MysqlPersistenceTest(BackendPersistenceTestMixin, test.TestCase): - def _init_db(self): engine = None try: @@ -182,8 +182,7 @@ class MysqlPersistenceTest(BackendPersistenceTestMixin, test.TestCase): engine.dispose() except Exception: pass - return _get_connect_string('mysql', USER, PASSWD, - database=DATABASE) + return _get_connect_string('mysql', USER, PASSWD, database=DATABASE) def _remove_db(self): engine = None @@ -203,15 +202,15 @@ class MysqlPersistenceTest(BackendPersistenceTestMixin, test.TestCase): @testtools.skipIf(not _postgres_exists(), 'postgres is not available') class PostgresPersistenceTest(BackendPersistenceTestMixin, test.TestCase): - def _init_db(self): engine = None try: # Postgres can't operate on the database it's connected to, that's # why we connect to the database 'postgres' and then create the # desired database. 
- db_uri = _get_connect_string('postgres', USER, PASSWD, - database='postgres') + db_uri = _get_connect_string( + 'postgres', USER, PASSWD, database='postgres' + ) engine = sa.create_engine(db_uri) with contextlib.closing(engine.connect()) as conn: conn.connection.set_isolation_level(0) @@ -225,8 +224,7 @@ class PostgresPersistenceTest(BackendPersistenceTestMixin, test.TestCase): engine.dispose() except Exception: pass - return _get_connect_string('postgres', USER, PASSWD, - database=DATABASE) + return _get_connect_string('postgres', USER, PASSWD, database=DATABASE) def _remove_db(self): engine = None @@ -234,8 +232,9 @@ class PostgresPersistenceTest(BackendPersistenceTestMixin, test.TestCase): # Postgres can't operate on the database it's connected to, that's # why we connect to the 'postgres' database and then drop the # database. - db_uri = _get_connect_string('postgres', USER, PASSWD, - database='postgres') + db_uri = _get_connect_string( + 'postgres', USER, PASSWD, database='postgres' + ) engine = sa.create_engine(db_uri) with contextlib.closing(engine.connect()) as conn: conn.connection.set_isolation_level(0) @@ -252,7 +251,6 @@ class PostgresPersistenceTest(BackendPersistenceTestMixin, test.TestCase): class SQLBackendFetchingTest(test.TestCase): - def test_sqlite_persistence_entry_point(self): conf = {'connection': 'sqlite:///'} with contextlib.closing(backends.fetch(conf)) as be: diff --git a/taskflow/tests/unit/persistence/test_zk_persistence.py b/taskflow/tests/unit/persistence/test_zk_persistence.py index 2be1c3343..c2bd030a2 100644 --- a/taskflow/tests/unit/persistence/test_zk_persistence.py +++ b/taskflow/tests/unit/persistence/test_zk_persistence.py @@ -28,7 +28,8 @@ from taskflow.utils import kazoo_utils TEST_PATH_TPL = '/taskflow/persistence-test/%s' _ZOOKEEPER_AVAILABLE = test_utils.zookeeper_available( - impl_zookeeper.MIN_ZK_VERSION) + impl_zookeeper.MIN_ZK_VERSION +) def clean_backend(backend, conf): diff --git 
a/taskflow/tests/unit/test_arguments_passing.py b/taskflow/tests/unit/test_arguments_passing.py index cf113b82d..6330df4cb 100644 --- a/taskflow/tests/unit/test_arguments_passing.py +++ b/taskflow/tests/unit/test_arguments_passing.py @@ -23,7 +23,6 @@ from taskflow.utils import eventlet_utils as eu class ArgumentsPassingTest(utils.EngineTestBase): - def test_save_as(self): flow = utils.TaskOneReturn(name='task1', provides='first_data') engine = self._make_engine(flow) @@ -34,45 +33,50 @@ class ArgumentsPassingTest(utils.EngineTestBase): flow = utils.TaskMultiReturn(provides='all_data') engine = self._make_engine(flow) engine.run() - self.assertEqual({'all_data': (1, 3, 5)}, - engine.storage.fetch_all()) + self.assertEqual({'all_data': (1, 3, 5)}, engine.storage.fetch_all()) def test_save_several_values(self): flow = utils.TaskMultiReturn(provides=('badger', 'mushroom', 'snake')) engine = self._make_engine(flow) engine.run() - self.assertEqual({ - 'badger': 1, - 'mushroom': 3, - 'snake': 5 - }, engine.storage.fetch_all()) + self.assertEqual( + {'badger': 1, 'mushroom': 3, 'snake': 5}, + engine.storage.fetch_all(), + ) def test_save_dict(self): - flow = utils.TaskMultiDict(provides={'badger', - 'mushroom', - 'snake'}) + flow = utils.TaskMultiDict(provides={'badger', 'mushroom', 'snake'}) engine = self._make_engine(flow) engine.run() - self.assertEqual({ - 'badger': 0, - 'mushroom': 1, - 'snake': 2, - }, engine.storage.fetch_all()) + self.assertEqual( + { + 'badger': 0, + 'mushroom': 1, + 'snake': 2, + }, + engine.storage.fetch_all(), + ) def test_bad_save_as_value(self): - self.assertRaises(TypeError, - utils.TaskOneReturn, - name='task1', provides=object()) + self.assertRaises( + TypeError, utils.TaskOneReturn, name='task1', provides=object() + ) def test_arguments_passing(self): flow = utils.TaskMultiArgOneReturn(provides='result') engine = self._make_engine(flow) engine.storage.inject({'x': 1, 'y': 4, 'z': 9, 'a': 17}) engine.run() - self.assertEqual({ - 'x': 1, 
'y': 4, 'z': 9, 'a': 17, - 'result': 14, - }, engine.storage.fetch_all()) + self.assertEqual( + { + 'x': 1, + 'y': 4, + 'z': 9, + 'a': 17, + 'result': 14, + }, + engine.storage.fetch_all(), + ) def test_arguments_missing(self): flow = utils.TaskMultiArg() @@ -81,58 +85,85 @@ class ArgumentsPassingTest(utils.EngineTestBase): self.assertRaises(exc.MissingDependencies, engine.run) def test_partial_arguments_mapping(self): - flow = utils.TaskMultiArgOneReturn(provides='result', - rebind={'x': 'a'}) + flow = utils.TaskMultiArgOneReturn( + provides='result', rebind={'x': 'a'} + ) engine = self._make_engine(flow) engine.storage.inject({'x': 1, 'y': 4, 'z': 9, 'a': 17}) engine.run() - self.assertEqual({ - 'x': 1, 'y': 4, 'z': 9, 'a': 17, - 'result': 30, - }, engine.storage.fetch_all()) + self.assertEqual( + { + 'x': 1, + 'y': 4, + 'z': 9, + 'a': 17, + 'result': 30, + }, + engine.storage.fetch_all(), + ) def test_argument_injection(self): - flow = utils.TaskMultiArgOneReturn(provides='result', - inject={'x': 1, 'y': 4, 'z': 9}) + flow = utils.TaskMultiArgOneReturn( + provides='result', inject={'x': 1, 'y': 4, 'z': 9} + ) engine = self._make_engine(flow) engine.run() - self.assertEqual({ - 'result': 14, - }, engine.storage.fetch_all()) + self.assertEqual( + { + 'result': 14, + }, + engine.storage.fetch_all(), + ) def test_argument_injection_rebind(self): - flow = utils.TaskMultiArgOneReturn(provides='result', - rebind=['a', 'b', 'c'], - inject={'a': 1, 'b': 4, 'c': 9}) + flow = utils.TaskMultiArgOneReturn( + provides='result', + rebind=['a', 'b', 'c'], + inject={'a': 1, 'b': 4, 'c': 9}, + ) engine = self._make_engine(flow) engine.run() - self.assertEqual({ - 'result': 14, - }, engine.storage.fetch_all()) + self.assertEqual( + { + 'result': 14, + }, + engine.storage.fetch_all(), + ) def test_argument_injection_required(self): - flow = utils.TaskMultiArgOneReturn(provides='result', - requires=['a', 'b', 'c'], - inject={'x': 1, 'y': 4, 'z': 9, - 'a': 0, 'b': 0, 'c': 0}) + flow 
= utils.TaskMultiArgOneReturn( + provides='result', + requires=['a', 'b', 'c'], + inject={'x': 1, 'y': 4, 'z': 9, 'a': 0, 'b': 0, 'c': 0}, + ) engine = self._make_engine(flow) engine.run() - self.assertEqual({ - 'result': 14, - }, engine.storage.fetch_all()) + self.assertEqual( + { + 'result': 14, + }, + engine.storage.fetch_all(), + ) def test_all_arguments_mapping(self): - flow = utils.TaskMultiArgOneReturn(provides='result', - rebind=['a', 'b', 'c']) + flow = utils.TaskMultiArgOneReturn( + provides='result', rebind=['a', 'b', 'c'] + ) engine = self._make_engine(flow) - engine.storage.inject({ - 'a': 1, 'b': 2, 'c': 3, 'x': 4, 'y': 5, 'z': 6 - }) + engine.storage.inject({'a': 1, 'b': 2, 'c': 3, 'x': 4, 'y': 5, 'z': 6}) engine.run() - self.assertEqual({ - 'a': 1, 'b': 2, 'c': 3, 'x': 4, 'y': 5, 'z': 6, - 'result': 6, - }, engine.storage.fetch_all()) + self.assertEqual( + { + 'a': 1, + 'b': 2, + 'c': 3, + 'x': 4, + 'y': 5, + 'z': 6, + 'result': 6, + }, + engine.storage.fetch_all(), + ) def test_invalid_argument_name_map(self): flow = utils.TaskMultiArg(rebind={'z': 'b'}) @@ -147,19 +178,18 @@ class ArgumentsPassingTest(utils.EngineTestBase): self.assertRaises(exc.MissingDependencies, engine.run) def test_bad_rebind_args_value(self): - self.assertRaises(TypeError, - utils.TaskOneArg, - rebind=object()) + self.assertRaises(TypeError, utils.TaskOneArg, rebind=object()) def test_long_arg_name(self): - flow = utils.LongArgNameTask(requires='long_arg_name', - provides='result') + flow = utils.LongArgNameTask( + requires='long_arg_name', provides='result' + ) engine = self._make_engine(flow) engine.storage.inject({'long_arg_name': 1}) engine.run() - self.assertEqual({ - 'long_arg_name': 1, 'result': 1 - }, engine.storage.fetch_all()) + self.assertEqual( + {'long_arg_name': 1, 'result': 1}, engine.storage.fetch_all() + ) def test_revert_rebound_args_required(self): flow = utils.TaskMultiArg(revert_rebind={'z': 'b'}) @@ -183,12 +213,13 @@ class 
ArgumentsPassingTest(utils.EngineTestBase): class SerialEngineTest(ArgumentsPassingTest, test.TestCase): - def _make_engine(self, flow, flow_detail=None): - return taskflow.engines.load(flow, - flow_detail=flow_detail, - engine='serial', - backend=self.backend) + return taskflow.engines.load( + flow, + flow_detail=flow_detail, + engine='serial', + backend=self.backend, + ) class ParallelEngineWithThreadsTest(ArgumentsPassingTest, test.TestCase): @@ -197,23 +228,26 @@ class ParallelEngineWithThreadsTest(ArgumentsPassingTest, test.TestCase): def _make_engine(self, flow, flow_detail=None, executor=None): if executor is None: executor = 'threads' - return taskflow.engines.load(flow, - flow_detail=flow_detail, - engine='parallel', - backend=self.backend, - executor=executor, - max_workers=self._EXECUTOR_WORKERS) + return taskflow.engines.load( + flow, + flow_detail=flow_detail, + engine='parallel', + backend=self.backend, + executor=executor, + max_workers=self._EXECUTOR_WORKERS, + ) @testtools.skipIf(not eu.EVENTLET_AVAILABLE, 'eventlet is not available') class ParallelEngineWithEventletTest(ArgumentsPassingTest, test.TestCase): - def _make_engine(self, flow, flow_detail=None, executor=None): if executor is None: executor = futurist.GreenThreadPoolExecutor() self.addCleanup(executor.shutdown) - return taskflow.engines.load(flow, - flow_detail=flow_detail, - backend=self.backend, - engine='parallel', - executor=executor) + return taskflow.engines.load( + flow, + flow_detail=flow_detail, + backend=self.backend, + engine='parallel', + executor=executor, + ) diff --git a/taskflow/tests/unit/test_check_transition.py b/taskflow/tests/unit/test_check_transition.py index 12da5af41..8d5f65bf2 100644 --- a/taskflow/tests/unit/test_check_transition.py +++ b/taskflow/tests/unit/test_check_transition.py @@ -18,7 +18,6 @@ from taskflow import test class TransitionTest(test.TestCase): - _DISALLOWED_TPL = "Transition from '%s' to '%s' was found to be disallowed" _NOT_IGNORED_TPL = 
"Transition from '%s' to '%s' was not ignored" @@ -31,12 +30,17 @@ class TransitionTest(test.TestCase): self.assertFalse(self.check_transition(from_state, to_state), msg=msg) def assertTransitionForbidden(self, from_state, to_state): - self.assertRaisesRegex(exc.InvalidState, - self.transition_exc_regexp, - self.check_transition, from_state, to_state) + self.assertRaisesRegex( + exc.InvalidState, + self.transition_exc_regexp, + self.check_transition, + from_state, + to_state, + ) - def assertTransitions(self, from_state, allowed=None, ignored=None, - forbidden=None): + def assertTransitions( + self, from_state, allowed=None, ignored=None, forbidden=None + ): for a in allowed or []: self.assertTransitionAllowed(from_state, a) for i in ignored or []: @@ -46,7 +50,6 @@ class TransitionTest(test.TestCase): class CheckFlowTransitionTest(TransitionTest): - def setUp(self): super().setUp() self.check_transition = states.check_flow_transition @@ -69,71 +72,119 @@ class CheckFlowTransitionTest(TransitionTest): class CheckTaskTransitionTest(TransitionTest): - def setUp(self): super().setUp() self.check_transition = states.check_task_transition self.transition_exc_regexp = '^Task transition.*not allowed' def test_from_pending_state(self): - self.assertTransitions(from_state=states.PENDING, - allowed=(states.RUNNING,), - ignored=(states.PENDING, states.REVERTING, - states.SUCCESS, states.FAILURE, - states.REVERTED)) + self.assertTransitions( + from_state=states.PENDING, + allowed=(states.RUNNING,), + ignored=( + states.PENDING, + states.REVERTING, + states.SUCCESS, + states.FAILURE, + states.REVERTED, + ), + ) def test_from_running_state(self): - self.assertTransitions(from_state=states.RUNNING, - allowed=(states.SUCCESS, states.FAILURE,), - ignored=(states.REVERTING, states.RUNNING, - states.PENDING, states.REVERTED)) + self.assertTransitions( + from_state=states.RUNNING, + allowed=( + states.SUCCESS, + states.FAILURE, + ), + ignored=( + states.REVERTING, + states.RUNNING, + 
states.PENDING, + states.REVERTED, + ), + ) def test_from_success_state(self): - self.assertTransitions(from_state=states.SUCCESS, - allowed=(states.REVERTING,), - ignored=(states.RUNNING, states.SUCCESS, - states.PENDING, states.FAILURE, - states.REVERTED)) + self.assertTransitions( + from_state=states.SUCCESS, + allowed=(states.REVERTING,), + ignored=( + states.RUNNING, + states.SUCCESS, + states.PENDING, + states.FAILURE, + states.REVERTED, + ), + ) def test_from_failure_state(self): - self.assertTransitions(from_state=states.FAILURE, - allowed=(states.REVERTING,), - ignored=(states.FAILURE, states.RUNNING, - states.PENDING, - states.SUCCESS, states.REVERTED)) + self.assertTransitions( + from_state=states.FAILURE, + allowed=(states.REVERTING,), + ignored=( + states.FAILURE, + states.RUNNING, + states.PENDING, + states.SUCCESS, + states.REVERTED, + ), + ) def test_from_reverting_state(self): - self.assertTransitions(from_state=states.REVERTING, - allowed=(states.REVERT_FAILURE, - states.REVERTED), - ignored=(states.RUNNING, states.REVERTING, - states.PENDING, states.SUCCESS)) + self.assertTransitions( + from_state=states.REVERTING, + allowed=(states.REVERT_FAILURE, states.REVERTED), + ignored=( + states.RUNNING, + states.REVERTING, + states.PENDING, + states.SUCCESS, + ), + ) def test_from_reverted_state(self): - self.assertTransitions(from_state=states.REVERTED, - allowed=(states.PENDING,), - ignored=(states.REVERTING, states.REVERTED, - states.RUNNING, - states.SUCCESS, states.FAILURE)) + self.assertTransitions( + from_state=states.REVERTED, + allowed=(states.PENDING,), + ignored=( + states.REVERTING, + states.REVERTED, + states.RUNNING, + states.SUCCESS, + states.FAILURE, + ), + ) class CheckRetryTransitionTest(CheckTaskTransitionTest): - def setUp(self): super().setUp() self.check_transition = states.check_retry_transition self.transition_exc_regexp = '^Retry transition.*not allowed' def test_from_success_state(self): - 
self.assertTransitions(from_state=states.SUCCESS, - allowed=(states.REVERTING, states.RETRYING), - ignored=(states.RUNNING, states.SUCCESS, - states.PENDING, states.FAILURE, - states.REVERTED)) + self.assertTransitions( + from_state=states.SUCCESS, + allowed=(states.REVERTING, states.RETRYING), + ignored=( + states.RUNNING, + states.SUCCESS, + states.PENDING, + states.FAILURE, + states.REVERTED, + ), + ) def test_from_retrying_state(self): - self.assertTransitions(from_state=states.RETRYING, - allowed=(states.RUNNING,), - ignored=(states.RETRYING, states.SUCCESS, - states.PENDING, states.FAILURE, - states.REVERTED)) + self.assertTransitions( + from_state=states.RETRYING, + allowed=(states.RUNNING,), + ignored=( + states.RETRYING, + states.SUCCESS, + states.PENDING, + states.FAILURE, + states.REVERTED, + ), + ) diff --git a/taskflow/tests/unit/test_conductors.py b/taskflow/tests/unit/test_conductors.py index 23acdd209..1c6d41cb3 100644 --- a/taskflow/tests/unit/test_conductors.py +++ b/taskflow/tests/unit/test_conductors.py @@ -37,7 +37,8 @@ from taskflow.utils import threading_utils TEST_PATH_TPL = '/taskflow/conductor-test/%s' ZOOKEEPER_AVAILABLE = test_utils.zookeeper_available( - impl_zookeeper.ZookeeperJobBoard.MIN_ZK_VERSION) + impl_zookeeper.ZookeeperJobBoard.MIN_ZK_VERSION +) def test_factory(blowup): @@ -66,23 +67,34 @@ def single_factory(): return futurist.ThreadPoolExecutor(max_workers=1) -ComponentBundle = collections.namedtuple('ComponentBundle', - ['board', 'persistence', 'conductor']) +ComponentBundle = collections.namedtuple( + 'ComponentBundle', ['board', 'persistence', 'conductor'] +) @testtools.skipIf(not ZOOKEEPER_AVAILABLE, 'zookeeper is not available') -class ManyConductorTest(testscenarios.TestWithScenarios, - test_utils.EngineTestBase, test.TestCase): +class ManyConductorTest( + testscenarios.TestWithScenarios, test_utils.EngineTestBase, test.TestCase +): scenarios = [ - ('blocking', {'kind': 'blocking', - 'conductor_kwargs': {'wait_timeout': 
0.1}}), - ('nonblocking_many_thread', - {'kind': 'nonblocking', 'conductor_kwargs': {'wait_timeout': 0.1}}), - ('nonblocking_one_thread', {'kind': 'nonblocking', - 'conductor_kwargs': { - 'executor_factory': single_factory, - 'wait_timeout': 0.1, - }}) + ( + 'blocking', + {'kind': 'blocking', 'conductor_kwargs': {'wait_timeout': 0.1}}, + ), + ( + 'nonblocking_many_thread', + {'kind': 'nonblocking', 'conductor_kwargs': {'wait_timeout': 0.1}}, + ), + ( + 'nonblocking_one_thread', + { + 'kind': 'nonblocking', + 'conductor_kwargs': { + 'executor_factory': single_factory, + 'wait_timeout': 0.1, + }, + }, + ), ] def make_components(self): @@ -95,18 +107,17 @@ class ManyConductorTest(testscenarios.TestWithScenarios, path = TEST_PATH_TPL % uuidutils.generate_uuid() persistence = impl_memory.MemoryBackend() board = impl_zookeeper.ZookeeperJobBoard( - 'testing', - {'path': path}, - client=client, - persistence=persistence) + 'testing', {'path': path}, client=client, persistence=persistence + ) self.addCleanup(cleanup_path, client, path) self.addCleanup(kazoo_utils.finalize_client, client) conductor_kwargs = self.conductor_kwargs.copy() conductor_kwargs['persistence'] = persistence - conductor = backends.fetch(self.kind, 'testing', board, - **conductor_kwargs) + conductor = backends.fetch( + self.kind, 'testing', board, **conductor_kwargs + ) return ComponentBundle(board, persistence, conductor) def test_connection(self): @@ -123,8 +134,7 @@ class ManyConductorTest(testscenarios.TestWithScenarios, t = threading_utils.daemon_thread(components.conductor.run) t.start() components.conductor.stop() - self.assertTrue( - components.conductor.wait(test_utils.WAIT_TIMEOUT)) + self.assertTrue(components.conductor.wait(test_utils.WAIT_TIMEOUT)) self.assertFalse(components.conductor.dispatching) t.join() @@ -147,19 +157,18 @@ class ManyConductorTest(testscenarios.TestWithScenarios, job_abandoned_event.set() components.board.notifier.register(base.REMOVAL, on_consume) - 
components.conductor.notifier.register("job_consumed", - on_job_consumed) - components.conductor.notifier.register("job_abandoned", - on_job_abandoned) + components.conductor.notifier.register("job_consumed", on_job_consumed) + components.conductor.notifier.register( + "job_abandoned", on_job_abandoned + ) with contextlib.closing(components.conductor): t = threading_utils.daemon_thread(components.conductor.run) t.start() lb, fd = pu.temporary_flow_detail(components.persistence) - engines.save_factory_details(fd, test_factory, - [False], {}, - backend=components.persistence) - components.board.post('poke', lb, - details={'flow_uuid': fd.uuid}) + engines.save_factory_details( + fd, test_factory, [False], {}, backend=components.persistence + ) + components.board.post('poke', lb, details={'flow_uuid': fd.uuid}) self.assertTrue(consumed_event.wait(test_utils.WAIT_TIMEOUT)) self.assertTrue(job_consumed_event.wait(test_utils.WAIT_TIMEOUT)) self.assertFalse(job_abandoned_event.wait(1)) @@ -185,19 +194,19 @@ class ManyConductorTest(testscenarios.TestWithScenarios, components.board.notifier.register(base.REMOVAL, on_consume) with contextlib.closing(components.conductor): t = threading_utils.daemon_thread( - lambda: components.conductor.run(max_dispatches=5)) + lambda: components.conductor.run(max_dispatches=5) + ) t.start() lb, fd = pu.temporary_flow_detail(components.persistence) - engines.save_factory_details(fd, test_factory, - [False], {}, - backend=components.persistence) + engines.save_factory_details( + fd, test_factory, [False], {}, backend=components.persistence + ) for _ in range(5): - components.board.post('poke', lb, - details={'flow_uuid': fd.uuid}) - self.assertTrue(consumed_event.wait( - test_utils.WAIT_TIMEOUT)) - components.board.post('poke', lb, - details={'flow_uuid': fd.uuid}) + components.board.post( + 'poke', lb, details={'flow_uuid': fd.uuid} + ) + self.assertTrue(consumed_event.wait(test_utils.WAIT_TIMEOUT)) + components.board.post('poke', lb, 
details={'flow_uuid': fd.uuid}) components.conductor.stop() self.assertTrue(components.conductor.wait(test_utils.WAIT_TIMEOUT)) self.assertFalse(components.conductor.dispatching) @@ -221,19 +230,18 @@ class ManyConductorTest(testscenarios.TestWithScenarios, job_abandoned_event.set() components.board.notifier.register(base.REMOVAL, on_consume) - components.conductor.notifier.register("job_consumed", - on_job_consumed) - components.conductor.notifier.register("job_abandoned", - on_job_abandoned) + components.conductor.notifier.register("job_consumed", on_job_consumed) + components.conductor.notifier.register( + "job_abandoned", on_job_abandoned + ) with contextlib.closing(components.conductor): t = threading_utils.daemon_thread(components.conductor.run) t.start() lb, fd = pu.temporary_flow_detail(components.persistence) - engines.save_factory_details(fd, test_factory, - [True], {}, - backend=components.persistence) - components.board.post('poke', lb, - details={'flow_uuid': fd.uuid}) + engines.save_factory_details( + fd, test_factory, [True], {}, backend=components.persistence + ) + components.board.post('poke', lb, details={'flow_uuid': fd.uuid}) self.assertTrue(consumed_event.wait(test_utils.WAIT_TIMEOUT)) self.assertTrue(job_consumed_event.wait(test_utils.WAIT_TIMEOUT)) self.assertFalse(job_abandoned_event.wait(1)) @@ -261,11 +269,10 @@ class ManyConductorTest(testscenarios.TestWithScenarios, t = threading_utils.daemon_thread(components.conductor.run) t.start() lb, fd = pu.temporary_flow_detail(components.persistence) - engines.save_factory_details(fd, test_store_factory, - [], {}, - backend=components.persistence) - components.board.post('poke', lb, - details={'flow_uuid': fd.uuid}) + engines.save_factory_details( + fd, test_store_factory, [], {}, backend=components.persistence + ) + components.board.post('poke', lb, details={'flow_uuid': fd.uuid}) self.assertTrue(consumed_event.wait(test_utils.WAIT_TIMEOUT)) components.conductor.stop() 
self.assertTrue(components.conductor.wait(test_utils.WAIT_TIMEOUT)) @@ -293,12 +300,12 @@ class ManyConductorTest(testscenarios.TestWithScenarios, t = threading_utils.daemon_thread(components.conductor.run) t.start() lb, fd = pu.temporary_flow_detail(components.persistence) - engines.save_factory_details(fd, test_store_factory, - [], {}, - backend=components.persistence) - components.board.post('poke', lb, - details={'flow_uuid': fd.uuid, - 'store': store}) + engines.save_factory_details( + fd, test_store_factory, [], {}, backend=components.persistence + ) + components.board.post( + 'poke', lb, details={'flow_uuid': fd.uuid, 'store': store} + ) self.assertTrue(consumed_event.wait(test_utils.WAIT_TIMEOUT)) components.conductor.stop() self.assertTrue(components.conductor.wait(test_utils.WAIT_TIMEOUT)) @@ -325,13 +332,13 @@ class ManyConductorTest(testscenarios.TestWithScenarios, with contextlib.closing(components.conductor): t = threading_utils.daemon_thread(components.conductor.run) t.start() - lb, fd = pu.temporary_flow_detail(components.persistence, - meta={'store': store}) - engines.save_factory_details(fd, test_store_factory, - [], {}, - backend=components.persistence) - components.board.post('poke', lb, - details={'flow_uuid': fd.uuid}) + lb, fd = pu.temporary_flow_detail( + components.persistence, meta={'store': store} + ) + engines.save_factory_details( + fd, test_store_factory, [], {}, backend=components.persistence + ) + components.board.post('poke', lb, details={'flow_uuid': fd.uuid}) self.assertTrue(consumed_event.wait(test_utils.WAIT_TIMEOUT)) components.conductor.stop() self.assertTrue(components.conductor.wait(test_utils.WAIT_TIMEOUT)) @@ -359,14 +366,15 @@ class ManyConductorTest(testscenarios.TestWithScenarios, with contextlib.closing(components.conductor): t = threading_utils.daemon_thread(components.conductor.run) t.start() - lb, fd = pu.temporary_flow_detail(components.persistence, - meta={'store': flow_store}) - engines.save_factory_details(fd, 
test_store_factory, - [], {}, - backend=components.persistence) - components.board.post('poke', lb, - details={'flow_uuid': fd.uuid, - 'store': job_store}) + lb, fd = pu.temporary_flow_detail( + components.persistence, meta={'store': flow_store} + ) + engines.save_factory_details( + fd, test_store_factory, [], {}, backend=components.persistence + ) + components.board.post( + 'poke', lb, details={'flow_uuid': fd.uuid, 'store': job_store} + ) self.assertTrue(consumed_event.wait(test_utils.WAIT_TIMEOUT)) components.conductor.stop() self.assertTrue(components.conductor.wait(test_utils.WAIT_TIMEOUT)) @@ -402,22 +410,25 @@ class ManyConductorTest(testscenarios.TestWithScenarios, job_abandoned_event.set() components.board.notifier.register(base.REMOVAL, on_consume) - components.conductor.notifier.register("job_consumed", - on_job_consumed) - components.conductor.notifier.register("job_abandoned", - on_job_abandoned) - components.conductor.notifier.register("running_start", - on_running_start) + components.conductor.notifier.register("job_consumed", on_job_consumed) + components.conductor.notifier.register( + "job_abandoned", on_job_abandoned + ) + components.conductor.notifier.register( + "running_start", on_running_start + ) with contextlib.closing(components.conductor): t = threading_utils.daemon_thread(components.conductor.run) t.start() lb, fd = pu.temporary_flow_detail(components.persistence) - engines.save_factory_details(fd, sleep_factory, - [], {}, - backend=components.persistence) - components.board.post('poke', lb, - details={'flow_uuid': fd.uuid, - 'store': {'duration': 2}}) + engines.save_factory_details( + fd, sleep_factory, [], {}, backend=components.persistence + ) + components.board.post( + 'poke', + lb, + details={'flow_uuid': fd.uuid, 'store': {'duration': 2}}, + ) running_start_event.wait(test_utils.WAIT_TIMEOUT) components.conductor.stop() job_abandoned_event.wait(test_utils.WAIT_TIMEOUT) @@ -433,21 +444,31 @@ class 
NonBlockingExecutorTest(test.TestCase): board = impl_zookeeper.ZookeeperJobBoard( 'testing', test_utils.ZK_TEST_CONFIG.copy(), - persistence=persistence) - self.assertRaises(ValueError, - backends.fetch, - 'nonblocking', 'testing', board, - persistence=persistence, - wait_timeout='testing') + persistence=persistence, + ) + self.assertRaises( + ValueError, + backends.fetch, + 'nonblocking', + 'testing', + board, + persistence=persistence, + wait_timeout='testing', + ) def test_bad_factory(self): persistence = impl_memory.MemoryBackend() board = impl_zookeeper.ZookeeperJobBoard( 'testing', test_utils.ZK_TEST_CONFIG.copy(), - persistence=persistence) - self.assertRaises(ValueError, - backends.fetch, - 'nonblocking', 'testing', board, - persistence=persistence, - executor_factory='testing') + persistence=persistence, + ) + self.assertRaises( + ValueError, + backends.fetch, + 'nonblocking', + 'testing', + board, + persistence=persistence, + executor_factory='testing', + ) diff --git a/taskflow/tests/unit/test_deciders.py b/taskflow/tests/unit/test_deciders.py index 0ac2775bb..617be0379 100644 --- a/taskflow/tests/unit/test_deciders.py +++ b/taskflow/tests/unit/test_deciders.py @@ -19,18 +19,24 @@ from taskflow import test class TestDeciders(test.TestCase): def test_translate(self): for val in ['all', 'ALL', 'aLL', deciders.Depth.ALL]: - self.assertEqual(deciders.Depth.ALL, - deciders.Depth.translate(val)) + self.assertEqual(deciders.Depth.ALL, deciders.Depth.translate(val)) for val in ['atom', 'ATOM', 'atOM', deciders.Depth.ATOM]: - self.assertEqual(deciders.Depth.ATOM, - deciders.Depth.translate(val)) - for val in ['neighbors', 'Neighbors', - 'NEIGHBORS', deciders.Depth.NEIGHBORS]: - self.assertEqual(deciders.Depth.NEIGHBORS, - deciders.Depth.translate(val)) + self.assertEqual( + deciders.Depth.ATOM, deciders.Depth.translate(val) + ) + for val in [ + 'neighbors', + 'Neighbors', + 'NEIGHBORS', + deciders.Depth.NEIGHBORS, + ]: + self.assertEqual( + 
deciders.Depth.NEIGHBORS, deciders.Depth.translate(val) + ) for val in ['flow', 'FLOW', 'flOW', deciders.Depth.FLOW]: - self.assertEqual(deciders.Depth.FLOW, - deciders.Depth.translate(val)) + self.assertEqual( + deciders.Depth.FLOW, deciders.Depth.translate(val) + ) def test_bad_translate(self): self.assertRaises(TypeError, deciders.Depth.translate, 3) @@ -40,15 +46,23 @@ class TestDeciders(test.TestCase): def test_pick_widest(self): choices = [deciders.Depth.ATOM, deciders.Depth.FLOW] self.assertEqual(deciders.Depth.FLOW, deciders.pick_widest(choices)) - choices = [deciders.Depth.ATOM, deciders.Depth.FLOW, - deciders.Depth.ALL] + choices = [ + deciders.Depth.ATOM, + deciders.Depth.FLOW, + deciders.Depth.ALL, + ] self.assertEqual(deciders.Depth.ALL, deciders.pick_widest(choices)) - choices = [deciders.Depth.ATOM, deciders.Depth.FLOW, - deciders.Depth.ALL, deciders.Depth.NEIGHBORS] + choices = [ + deciders.Depth.ATOM, + deciders.Depth.FLOW, + deciders.Depth.ALL, + deciders.Depth.NEIGHBORS, + ] self.assertEqual(deciders.Depth.ALL, deciders.pick_widest(choices)) choices = [deciders.Depth.ATOM, deciders.Depth.NEIGHBORS] - self.assertEqual(deciders.Depth.NEIGHBORS, - deciders.pick_widest(choices)) + self.assertEqual( + deciders.Depth.NEIGHBORS, deciders.pick_widest(choices) + ) def test_bad_pick_widest(self): self.assertRaises(ValueError, deciders.pick_widest, []) diff --git a/taskflow/tests/unit/test_engine_helpers.py b/taskflow/tests/unit/test_engine_helpers.py index 7997f75f8..2c8e58a57 100644 --- a/taskflow/tests/unit/test_engine_helpers.py +++ b/taskflow/tests/unit/test_engine_helpers.py @@ -34,8 +34,12 @@ class EngineLoadingTestCase(test.TestCase): def test_unknown_load(self): f = self._make_dummy_flow() - self.assertRaises(exc.NotFound, taskflow.engines.load, f, - engine='not_really_any_engine') + self.assertRaises( + exc.NotFound, + taskflow.engines.load, + f, + engine='not_really_any_engine', + ) def test_options_empty(self): f = self._make_dummy_flow() @@ 
-52,35 +56,43 @@ class FlowFromDetailTestCase(test.TestCase): def test_no_meta(self): _lb, flow_detail = p_utils.temporary_flow_detail() self.assertEqual({}, flow_detail.meta) - self.assertRaisesRegex(ValueError, - '^Cannot .* no factory information saved.$', - taskflow.engines.flow_from_detail, - flow_detail) + self.assertRaisesRegex( + ValueError, + '^Cannot .* no factory information saved.$', + taskflow.engines.flow_from_detail, + flow_detail, + ) def test_no_factory_in_meta(self): _lb, flow_detail = p_utils.temporary_flow_detail() - self.assertRaisesRegex(ValueError, - '^Cannot .* no factory information saved.$', - taskflow.engines.flow_from_detail, - flow_detail) + self.assertRaisesRegex( + ValueError, + '^Cannot .* no factory information saved.$', + taskflow.engines.flow_from_detail, + flow_detail, + ) def test_no_importable_function(self): _lb, flow_detail = p_utils.temporary_flow_detail() - flow_detail.meta = dict(factory=dict( - name='you can not import me, i contain spaces' - )) - self.assertRaisesRegex(ImportError, - '^Could not import factory', - taskflow.engines.flow_from_detail, - flow_detail) + flow_detail.meta = dict( + factory=dict(name='you can not import me, i contain spaces') + ) + self.assertRaisesRegex( + ImportError, + '^Could not import factory', + taskflow.engines.flow_from_detail, + flow_detail, + ) def test_no_arg_factory(self): name = 'some.test.factory' _lb, flow_detail = p_utils.temporary_flow_detail() flow_detail.meta = dict(factory=dict(name=name)) - with mock.patch('oslo_utils.importutils.import_class', - return_value=lambda: 'RESULT') as mock_import: + with mock.patch( + 'oslo_utils.importutils.import_class', + return_value=lambda: 'RESULT', + ) as mock_import: result = taskflow.engines.flow_from_detail(flow_detail) mock_import.assert_called_once_with(name) self.assertEqual('RESULT', result) @@ -90,8 +102,10 @@ class FlowFromDetailTestCase(test.TestCase): _lb, flow_detail = p_utils.temporary_flow_detail() flow_detail.meta = 
dict(factory=dict(name=name, args=['foo'])) - with mock.patch('oslo_utils.importutils.import_class', - return_value=lambda x: 'RESULT %s' % x) as mock_import: + with mock.patch( + 'oslo_utils.importutils.import_class', + return_value=lambda x: 'RESULT %s' % x, + ) as mock_import: result = taskflow.engines.flow_from_detail(flow_detail) mock_import.assert_called_once_with(name) self.assertEqual('RESULT foo', result) @@ -102,40 +116,49 @@ def my_flow_factory(task_name): class LoadFromFactoryTestCase(test.TestCase): - def test_non_reimportable(self): def factory(): pass - self.assertRaisesRegex(ValueError, - 'Flow factory .* is not reimportable', - taskflow.engines.load_from_factory, - factory) + self.assertRaisesRegex( + ValueError, + 'Flow factory .* is not reimportable', + taskflow.engines.load_from_factory, + factory, + ) def test_it_works(self): engine = taskflow.engines.load_from_factory( - my_flow_factory, factory_kwargs={'task_name': 'test1'}) + my_flow_factory, factory_kwargs={'task_name': 'test1'} + ) self.assertIsInstance(engine._flow, test_utils.DummyTask) fd = engine.storage._flowdetail self.assertEqual('test1', fd.name) - self.assertEqual({ - 'name': '%s.my_flow_factory' % __name__, - 'args': [], - 'kwargs': {'task_name': 'test1'}, - }, fd.meta.get('factory')) + self.assertEqual( + { + 'name': '%s.my_flow_factory' % __name__, + 'args': [], + 'kwargs': {'task_name': 'test1'}, + }, + fd.meta.get('factory'), + ) def test_it_works_by_name(self): factory_name = '%s.my_flow_factory' % __name__ engine = taskflow.engines.load_from_factory( - factory_name, factory_kwargs={'task_name': 'test1'}) + factory_name, factory_kwargs={'task_name': 'test1'} + ) self.assertIsInstance(engine._flow, test_utils.DummyTask) fd = engine.storage._flowdetail self.assertEqual('test1', fd.name) - self.assertEqual({ - 'name': factory_name, - 'args': [], - 'kwargs': {'task_name': 'test1'}, - }, fd.meta.get('factory')) + self.assertEqual( + { + 'name': factory_name, + 'args': [], + 
'kwargs': {'task_name': 'test1'}, + }, + fd.meta.get('factory'), + ) diff --git a/taskflow/tests/unit/test_engines.py b/taskflow/tests/unit/test_engines.py index 1644ea435..c5b462baa 100644 --- a/taskflow/tests/unit/test_engines.py +++ b/taskflow/tests/unit/test_engines.py @@ -42,8 +42,11 @@ from taskflow.utils import threading_utils as tu # Expected engine transitions when empty workflows are ran... _EMPTY_TRANSITIONS = [ - states.RESUMING, states.SCHEDULING, states.WAITING, - states.ANALYZING, states.SUCCESS, + states.RESUMING, + states.SCHEDULING, + states.WAITING, + states.ANALYZING, + states.SUCCESS, ] @@ -57,18 +60,21 @@ class EngineTaskNotificationsTest: flow = lf.Flow("flow") work_1 = utils.MultiProgressingTask('work-1') - work_1.notifier.register(task.EVENT_UPDATE_PROGRESS, - functools.partial(do_capture, 'work-1')) + work_1.notifier.register( + task.EVENT_UPDATE_PROGRESS, functools.partial(do_capture, 'work-1') + ) work_2 = utils.MultiProgressingTask('work-2') - work_2.notifier.register(task.EVENT_UPDATE_PROGRESS, - functools.partial(do_capture, 'work-2')) + work_2.notifier.register( + task.EVENT_UPDATE_PROGRESS, functools.partial(do_capture, 'work-2') + ) flow.add(work_1, work_2) # NOTE(harlowja): These were selected so that float comparison will # work vs not work... 
progress_chunks = tuple([0.2, 0.5, 0.8]) engine = self._make_engine( - flow, store={'progress_chunks': progress_chunks}) + flow, store={'progress_chunks': progress_chunks} + ) engine.run() expected = [ @@ -83,7 +89,6 @@ class EngineTaskNotificationsTest: class EngineTaskTest: - def test_run_task_as_flow(self): flow = utils.ProgressingTask(name='task1') engine = self._make_engine(flow) @@ -97,18 +102,26 @@ class EngineTaskTest: engine = self._make_engine(flow) with utils.CaptureListener(engine) as capturer: engine.run() - expected = ['task1.f RUNNING', 'task1.t RUNNING', - 'task1.t SUCCESS(5)', 'task1.f SUCCESS'] + expected = [ + 'task1.f RUNNING', + 'task1.t RUNNING', + 'task1.t SUCCESS(5)', + 'task1.f SUCCESS', + ] self.assertEqual(expected, capturer.values) def test_failing_task_with_flow_notifications(self): values = [] flow = utils.FailingTask('fail') engine = self._make_engine(flow) - expected = ['fail.f RUNNING', 'fail.t RUNNING', - 'fail.t FAILURE(Failure: RuntimeError: Woot!)', - 'fail.t REVERTING', 'fail.t REVERTED(None)', - 'fail.f REVERTED'] + expected = [ + 'fail.f RUNNING', + 'fail.t RUNNING', + 'fail.t FAILURE(Failure: RuntimeError: Woot!)', + 'fail.t REVERTING', + 'fail.t REVERTED(None)', + 'fail.f REVERTED', + ] with utils.CaptureListener(engine, values=values) as capturer: self.assertFailuresRegexp(RuntimeError, '^Woot', engine.run) self.assertEqual(expected, capturer.values) @@ -153,20 +166,21 @@ class EngineOptionalRequirementsTest(utils.EngineTestBase): flow_no_inject.add(utils.OptionalTask(provides='result')) flow_inject_a = lf.Flow("flow") - flow_inject_a.add(utils.OptionalTask(provides='result', - inject={'a': 10})) + flow_inject_a.add( + utils.OptionalTask(provides='result', inject={'a': 10}) + ) flow_inject_b = lf.Flow("flow") - flow_inject_b.add(utils.OptionalTask(provides='result', - inject={'b': 1000})) + flow_inject_b.add( + utils.OptionalTask(provides='result', inject={'b': 1000}) + ) engine = self._make_engine(flow_no_inject, 
store={'a': 3}) engine.run() result = engine.storage.fetch_all() self.assertEqual({'a': 3, 'result': 15}, result) - engine = self._make_engine(flow_no_inject, - store={'a': 3, 'b': 7}) + engine = self._make_engine(flow_no_inject, store={'a': 3, 'b': 7}) engine.run() result = engine.storage.fetch_all() self.assertEqual({'a': 3, 'b': 7, 'result': 21}, result) @@ -204,8 +218,9 @@ class EngineMultipleResultsTest(utils.EngineTestBase): def test_many_results_visible_to(self): flow = lf.Flow("flow") - flow.add(utils.AddOneSameProvidesRequires( - 'a', rebind={'value': 'source'})) + flow.add( + utils.AddOneSameProvidesRequires('a', rebind={'value': 'source'}) + ) flow.add(utils.AddOneSameProvidesRequires('b')) flow.add(utils.AddOneSameProvidesRequires('c')) engine = self._make_engine(flow, store={'source': 0}) @@ -214,18 +229,15 @@ class EngineMultipleResultsTest(utils.EngineTestBase): # Check what each task in the prior should be seeing... atoms = list(flow) a = atoms[0] - a_kwargs = engine.storage.fetch_mapped_args(a.rebind, - atom_name='a') + a_kwargs = engine.storage.fetch_mapped_args(a.rebind, atom_name='a') self.assertEqual({'value': 0}, a_kwargs) b = atoms[1] - b_kwargs = engine.storage.fetch_mapped_args(b.rebind, - atom_name='b') + b_kwargs = engine.storage.fetch_mapped_args(b.rebind, atom_name='b') self.assertEqual({'value': 1}, b_kwargs) c = atoms[2] - c_kwargs = engine.storage.fetch_mapped_args(c.rebind, - atom_name='c') + c_kwargs = engine.storage.fetch_mapped_args(c.rebind, atom_name='c') self.assertEqual({'value': 2}, c_kwargs) def test_many_results_storage_provided_visible_to(self): @@ -242,18 +254,15 @@ class EngineMultipleResultsTest(utils.EngineTestBase): # Check what each task in the prior should be seeing... 
atoms = list(flow) a = atoms[0] - a_kwargs = engine.storage.fetch_mapped_args(a.rebind, - atom_name='a') + a_kwargs = engine.storage.fetch_mapped_args(a.rebind, atom_name='a') self.assertEqual({'value': 0}, a_kwargs) b = atoms[1] - b_kwargs = engine.storage.fetch_mapped_args(b.rebind, - atom_name='b') + b_kwargs = engine.storage.fetch_mapped_args(b.rebind, atom_name='b') self.assertEqual({'value': 0}, b_kwargs) c = atoms[2] - c_kwargs = engine.storage.fetch_mapped_args(c.rebind, - atom_name='c') + c_kwargs = engine.storage.fetch_mapped_args(c.rebind, atom_name='c') self.assertEqual({'value': 0}, c_kwargs) def test_fetch_with_two_results(self): @@ -294,7 +303,6 @@ class EngineMultipleResultsTest(utils.EngineTestBase): class EngineLinearFlowTest(utils.EngineTestBase): - def test_run_empty_linear_flow(self): flow = lf.Flow('flow-1') engine = self._make_engine(flow) @@ -365,9 +373,7 @@ class EngineLinearFlowTest(utils.EngineTestBase): self.assertEqual(states.FAILURE, engine.storage.get_flow_state()) def test_sequential_flow_one_task(self): - flow = lf.Flow('flow-1').add( - utils.ProgressingTask(name='task1') - ) + flow = lf.Flow('flow-1').add(utils.ProgressingTask(name='task1')) engine = self._make_engine(flow) with utils.CaptureListener(engine, capture_flow=False) as capturer: engine.run() @@ -377,34 +383,42 @@ class EngineLinearFlowTest(utils.EngineTestBase): def test_sequential_flow_two_tasks(self): flow = lf.Flow('flow-2').add( utils.ProgressingTask(name='task1'), - utils.ProgressingTask(name='task2') + utils.ProgressingTask(name='task2'), ) engine = self._make_engine(flow) with utils.CaptureListener(engine, capture_flow=False) as capturer: engine.run() - expected = ['task1.t RUNNING', 'task1.t SUCCESS(5)', - 'task2.t RUNNING', 'task2.t SUCCESS(5)'] + expected = [ + 'task1.t RUNNING', + 'task1.t SUCCESS(5)', + 'task2.t RUNNING', + 'task2.t SUCCESS(5)', + ] self.assertEqual(expected, capturer.values) self.assertEqual(2, len(flow)) def 
test_sequential_flow_two_tasks_iter(self): flow = lf.Flow('flow-2').add( utils.ProgressingTask(name='task1'), - utils.ProgressingTask(name='task2') + utils.ProgressingTask(name='task2'), ) engine = self._make_engine(flow) with utils.CaptureListener(engine, capture_flow=False) as capturer: gathered_states = list(engine.run_iter()) self.assertGreater(len(gathered_states), 0) - expected = ['task1.t RUNNING', 'task1.t SUCCESS(5)', - 'task2.t RUNNING', 'task2.t SUCCESS(5)'] + expected = [ + 'task1.t RUNNING', + 'task1.t SUCCESS(5)', + 'task2.t RUNNING', + 'task2.t SUCCESS(5)', + ] self.assertEqual(expected, capturer.values) self.assertEqual(2, len(flow)) def test_sequential_flow_iter_suspend_resume(self): flow = lf.Flow('flow-2').add( utils.ProgressingTask(name='task1'), - utils.ProgressingTask(name='task2') + utils.ProgressingTask(name='task2'), ) lb, fd = p_utils.temporary_flow_detail(self.backend) @@ -439,7 +453,7 @@ class EngineLinearFlowTest(utils.EngineTestBase): flow = lf.Flow('revert-removes').add( utils.TaskOneReturn(provides='one'), utils.TaskMultiReturn(provides=('a', 'b', 'c')), - utils.FailingTask(name='fail') + utils.FailingTask(name='fail'), ) engine = self._make_engine(flow) self.assertFailuresRegexp(RuntimeError, '^Woot', engine.run) @@ -447,8 +461,7 @@ class EngineLinearFlowTest(utils.EngineTestBase): def test_revert_provided(self): flow = lf.Flow('revert').add( - utils.GiveBackRevert('giver'), - utils.FailingTask(name='fail') + utils.GiveBackRevert('giver'), utils.FailingTask(name='fail') ) engine = self._make_engine(flow, store={'value': 0}) self.assertFailuresRegexp(RuntimeError, '^Woot', engine.run) @@ -456,8 +469,7 @@ class EngineLinearFlowTest(utils.EngineTestBase): def test_nasty_revert(self): flow = lf.Flow('revert').add( - utils.NastyTask('nasty'), - utils.FailingTask(name='fail') + utils.NastyTask('nasty'), utils.FailingTask(name='fail') ) engine = self._make_engine(flow) self.assertFailuresRegexp(RuntimeError, '^Gotcha', engine.run) @@ 
-471,21 +483,22 @@ class EngineLinearFlowTest(utils.EngineTestBase): def test_sequential_flow_nested_blocks(self): flow = lf.Flow('nested-1').add( utils.ProgressingTask('task1'), - lf.Flow('inner-1').add( - utils.ProgressingTask('task2') - ) + lf.Flow('inner-1').add(utils.ProgressingTask('task2')), ) engine = self._make_engine(flow) with utils.CaptureListener(engine, capture_flow=False) as capturer: engine.run() - expected = ['task1.t RUNNING', 'task1.t SUCCESS(5)', - 'task2.t RUNNING', 'task2.t SUCCESS(5)'] + expected = [ + 'task1.t RUNNING', + 'task1.t SUCCESS(5)', + 'task2.t RUNNING', + 'task2.t SUCCESS(5)', + ] self.assertEqual(expected, capturer.values) def test_revert_exception_is_reraised(self): flow = lf.Flow('revert-1').add( - utils.NastyTask(), - utils.FailingTask(name='fail') + utils.NastyTask(), utils.FailingTask(name='fail') ) engine = self._make_engine(flow) self.assertFailuresRegexp(RuntimeError, '^Gotcha', engine.run) @@ -498,34 +511,42 @@ class EngineLinearFlowTest(utils.EngineTestBase): engine = self._make_engine(flow) with utils.CaptureListener(engine, capture_flow=False) as capturer: self.assertFailuresRegexp(RuntimeError, '^Woot', engine.run) - expected = ['fail.t RUNNING', - 'fail.t FAILURE(Failure: RuntimeError: Woot!)', - 'fail.t REVERTING', 'fail.t REVERTED(None)'] + expected = [ + 'fail.t RUNNING', + 'fail.t FAILURE(Failure: RuntimeError: Woot!)', + 'fail.t REVERTING', + 'fail.t REVERTED(None)', + ] self.assertEqual(expected, capturer.values) def test_correctly_reverts_children(self): flow = lf.Flow('root-1').add( utils.ProgressingTask('task1'), lf.Flow('child-1').add( - utils.ProgressingTask('task2'), - utils.FailingTask('fail') - ) + utils.ProgressingTask('task2'), utils.FailingTask('fail') + ), ) engine = self._make_engine(flow) with utils.CaptureListener(engine, capture_flow=False) as capturer: self.assertFailuresRegexp(RuntimeError, '^Woot', engine.run) - expected = ['task1.t RUNNING', 'task1.t SUCCESS(5)', - 'task2.t RUNNING', 
'task2.t SUCCESS(5)', - 'fail.t RUNNING', - 'fail.t FAILURE(Failure: RuntimeError: Woot!)', - 'fail.t REVERTING', 'fail.t REVERTED(None)', - 'task2.t REVERTING', 'task2.t REVERTED(None)', - 'task1.t REVERTING', 'task1.t REVERTED(None)'] + expected = [ + 'task1.t RUNNING', + 'task1.t SUCCESS(5)', + 'task2.t RUNNING', + 'task2.t SUCCESS(5)', + 'fail.t RUNNING', + 'fail.t FAILURE(Failure: RuntimeError: Woot!)', + 'fail.t REVERTING', + 'fail.t REVERTED(None)', + 'task2.t REVERTING', + 'task2.t REVERTED(None)', + 'task1.t REVERTING', + 'task1.t REVERTED(None)', + ] self.assertEqual(expected, capturer.values) class EngineParallelFlowTest(utils.EngineTestBase): - def test_run_empty_unordered_flow(self): flow = uf.Flow('p-1') engine = self._make_engine(flow) @@ -571,26 +592,31 @@ class EngineParallelFlowTest(utils.EngineTestBase): def test_parallel_flow_two_tasks(self): flow = uf.Flow('p-2').add( utils.ProgressingTask(name='task1'), - utils.ProgressingTask(name='task2') + utils.ProgressingTask(name='task2'), ) engine = self._make_engine(flow) with utils.CaptureListener(engine, capture_flow=False) as capturer: engine.run() - expected = {'task2.t SUCCESS(5)', 'task2.t RUNNING', - 'task1.t RUNNING', 'task1.t SUCCESS(5)'} + expected = { + 'task2.t SUCCESS(5)', + 'task2.t RUNNING', + 'task1.t RUNNING', + 'task1.t SUCCESS(5)', + } self.assertEqual(expected, set(capturer.values)) def test_parallel_revert(self): flow = uf.Flow('p-r-3').add( utils.TaskNoRequiresNoReturns(name='task1'), utils.FailingTask(name='fail'), - utils.TaskNoRequiresNoReturns(name='task2') + utils.TaskNoRequiresNoReturns(name='task2'), ) engine = self._make_engine(flow) with utils.CaptureListener(engine, capture_flow=False) as capturer: self.assertFailuresRegexp(RuntimeError, '^Woot', engine.run) - self.assertIn('fail.t FAILURE(Failure: RuntimeError: Woot!)', - capturer.values) + self.assertIn( + 'fail.t FAILURE(Failure: RuntimeError: Woot!)', capturer.values + ) def 
test_parallel_revert_exception_is_reraised(self): # NOTE(imelnikov): if we put NastyTask and FailingTask @@ -599,10 +625,9 @@ class EngineParallelFlowTest(utils.EngineTestBase): # FailingTask fails. flow = lf.Flow('p-r-r-l').add( uf.Flow('p-r-r').add( - utils.TaskNoRequiresNoReturns(name='task1'), - utils.NastyTask() + utils.TaskNoRequiresNoReturns(name='task1'), utils.NastyTask() ), - utils.FailingTask() + utils.FailingTask(), ) engine = self._make_engine(flow) self.assertFailuresRegexp(RuntimeError, '^Gotcha', engine.run) @@ -610,7 +635,7 @@ class EngineParallelFlowTest(utils.EngineTestBase): def test_sequential_flow_two_tasks_with_resumption(self): flow = lf.Flow('lf-2-r').add( utils.ProgressingTask(name='task1', provides='x1'), - utils.ProgressingTask(name='task2', provides='x2') + utils.ProgressingTask(name='task2', provides='x2'), ) # Create FlowDetail as if we already run task1 @@ -629,8 +654,7 @@ class EngineParallelFlowTest(utils.EngineTestBase): engine.run() expected = ['task2.t RUNNING', 'task2.t SUCCESS(5)'] self.assertEqual(expected, capturer.values) - self.assertEqual({'x1': 17, 'x2': 5}, - engine.storage.fetch_all()) + self.assertEqual({'x1': 17, 'x2': 5}, engine.storage.fetch_all()) # Reproducer for #2139228 and #2086453 def test_many_unordered_flows_in_linear_flow(self): @@ -648,15 +672,13 @@ class EngineParallelFlowTest(utils.EngineTestBase): class EngineLinearAndUnorderedExceptionsTest(utils.EngineTestBase): - def test_revert_ok_for_unordered_in_linear(self): flow = lf.Flow('p-root').add( utils.ProgressingTask(name='task1'), utils.ProgressingTask(name='task2'), uf.Flow('p-inner').add( - utils.ProgressingTask(name='task3'), - utils.FailingTask('fail') - ) + utils.ProgressingTask(name='task3'), utils.FailingTask('fail') + ), ) engine = self._make_engine(flow) with utils.CaptureListener(engine, capture_flow=False) as capturer: @@ -665,20 +687,27 @@ class EngineLinearAndUnorderedExceptionsTest(utils.EngineTestBase): # NOTE(imelnikov): we don't know 
if task 3 was run, but if it was, # it should have been REVERTED(None) in correct order. possible_values_no_task3 = [ - 'task1.t RUNNING', 'task2.t RUNNING', + 'task1.t RUNNING', + 'task2.t RUNNING', 'fail.t FAILURE(Failure: RuntimeError: Woot!)', - 'task2.t REVERTED(None)', 'task1.t REVERTED(None)' + 'task2.t REVERTED(None)', + 'task1.t REVERTED(None)', ] - self.assertIsSuperAndSubsequence(capturer.values, - possible_values_no_task3) + self.assertIsSuperAndSubsequence( + capturer.values, possible_values_no_task3 + ) if 'task3' in capturer.values: possible_values_task3 = [ - 'task1.t RUNNING', 'task2.t RUNNING', 'task3.t RUNNING', - 'task3.t REVERTED(None)', 'task2.t REVERTED(None)', - 'task1.t REVERTED(None)' + 'task1.t RUNNING', + 'task2.t RUNNING', + 'task3.t RUNNING', + 'task3.t REVERTED(None)', + 'task2.t REVERTED(None)', + 'task1.t REVERTED(None)', ] - self.assertIsSuperAndSubsequence(capturer.values, - possible_values_task3) + self.assertIsSuperAndSubsequence( + capturer.values, possible_values_task3 + ) def test_revert_raises_for_unordered_in_linear(self): flow = lf.Flow('p-root').add( @@ -686,59 +715,70 @@ class EngineLinearAndUnorderedExceptionsTest(utils.EngineTestBase): utils.ProgressingTask(name='task2'), uf.Flow('p-inner').add( utils.ProgressingTask(name='task3'), - utils.NastyFailingTask(name='nasty') - ) + utils.NastyFailingTask(name='nasty'), + ), ) engine = self._make_engine(flow) - with utils.CaptureListener(engine, - capture_flow=False, - skip_tasks=['nasty']) as capturer: + with utils.CaptureListener( + engine, capture_flow=False, skip_tasks=['nasty'] + ) as capturer: self.assertFailuresRegexp(RuntimeError, '^Gotcha', engine.run) # NOTE(imelnikov): we don't know if task 3 was run, but if it was, # it should have been REVERTED(None) in correct order. 
- possible_values = ['task1.t RUNNING', 'task1.t SUCCESS(5)', - 'task2.t RUNNING', 'task2.t SUCCESS(5)', - 'task3.t RUNNING', 'task3.t SUCCESS(5)', - 'task3.t REVERTING', - 'task3.t REVERTED(None)'] + possible_values = [ + 'task1.t RUNNING', + 'task1.t SUCCESS(5)', + 'task2.t RUNNING', + 'task2.t SUCCESS(5)', + 'task3.t RUNNING', + 'task3.t SUCCESS(5)', + 'task3.t REVERTING', + 'task3.t REVERTED(None)', + ] self.assertIsSuperAndSubsequence(possible_values, capturer.values) possible_values_no_task3 = ['task1.t RUNNING', 'task2.t RUNNING'] - self.assertIsSuperAndSubsequence(capturer.values, - possible_values_no_task3) + self.assertIsSuperAndSubsequence( + capturer.values, possible_values_no_task3 + ) def test_revert_ok_for_linear_in_unordered(self): flow = uf.Flow('p-root').add( utils.ProgressingTask(name='task1'), lf.Flow('p-inner').add( - utils.ProgressingTask(name='task2'), - utils.FailingTask('fail') - ) + utils.ProgressingTask(name='task2'), utils.FailingTask('fail') + ), ) engine = self._make_engine(flow) with utils.CaptureListener(engine, capture_flow=False) as capturer: self.assertFailuresRegexp(RuntimeError, '^Woot', engine.run) - self.assertIn('fail.t FAILURE(Failure: RuntimeError: Woot!)', - capturer.values) + self.assertIn( + 'fail.t FAILURE(Failure: RuntimeError: Woot!)', capturer.values + ) # NOTE(imelnikov): if task1 was run, it should have been reverted. 
if 'task1' in capturer.values: - task1_story = ['task1.t RUNNING', 'task1.t SUCCESS(5)', - 'task1.t REVERTED(None)'] + task1_story = [ + 'task1.t RUNNING', + 'task1.t SUCCESS(5)', + 'task1.t REVERTED(None)', + ] self.assertIsSuperAndSubsequence(capturer.values, task1_story) # NOTE(imelnikov): task2 should have been run and reverted - task2_story = ['task2.t RUNNING', 'task2.t SUCCESS(5)', - 'task2.t REVERTED(None)'] + task2_story = [ + 'task2.t RUNNING', + 'task2.t SUCCESS(5)', + 'task2.t REVERTED(None)', + ] self.assertIsSuperAndSubsequence(capturer.values, task2_story) def test_revert_raises_for_linear_in_unordered(self): flow = uf.Flow('p-root').add( utils.ProgressingTask(name='task1'), lf.Flow('p-inner').add( - utils.ProgressingTask(name='task2'), - utils.NastyFailingTask() - ) + utils.ProgressingTask(name='task2'), utils.NastyFailingTask() + ), ) engine = self._make_engine(flow) with utils.CaptureListener(engine, capture_flow=False) as capturer: @@ -747,7 +787,6 @@ class EngineLinearAndUnorderedExceptionsTest(utils.EngineTestBase): class EngineDeciderDepthTest(utils.EngineTestBase): - def test_run_graph_flow_decider_various_depths(self): sub_flow_1 = gf.Flow('g_1') g_1_1 = utils.ProgressingTask(name='g_1-1') @@ -756,21 +795,26 @@ class EngineDeciderDepthTest(utils.EngineTestBase): g_2 = utils.ProgressingTask(name='g-2') g_3 = utils.ProgressingTask(name='g-3') g_4 = utils.ProgressingTask(name='g-4') - for a_depth, ran_how_many in [('all', 1), - ('atom', 4), - ('flow', 2), - ('neighbors', 3)]: + for a_depth, ran_how_many in [ + ('all', 1), + ('atom', 4), + ('flow', 2), + ('neighbors', 3), + ]: flow = gf.Flow('g') flow.add(g_1, g_2, sub_flow_1, g_3, g_4) - flow.link(g_1, g_2, - decider=lambda history: False, - decider_depth=a_depth) + flow.link( + g_1, g_2, decider=lambda history: False, decider_depth=a_depth + ) flow.link(g_2, sub_flow_1) flow.link(g_2, g_3) flow.link(g_3, g_4) - flow.link(g_1, sub_flow_1, - decider=lambda history: True, - decider_depth=a_depth) 
+ flow.link( + g_1, + sub_flow_1, + decider=lambda history: True, + decider_depth=a_depth, + ) e = self._make_engine(flow) with utils.CaptureListener(e, capture_flow=False) as capturer: e.run() @@ -786,8 +830,7 @@ class EngineDeciderDepthTest(utils.EngineTestBase): b = utils.AddOneSameProvidesRequires("b") c = utils.AddOneSameProvidesRequires("c") flow.add(a, b, c, resolve_requires=False) - flow.link(a, b, decider=lambda history: False, - decider_depth='atom') + flow.link(a, b, decider=lambda history: False, decider_depth='atom') flow.link(b, c) e = self._make_engine(flow) e.run() @@ -800,8 +843,7 @@ class EngineDeciderDepthTest(utils.EngineTestBase): b = utils.FailingTask("b") c = utils.NoopTask("c") flow.add(a, b, c) - flow.link(a, b, decider=lambda history: False, - decider_depth='atom') + flow.link(a, b, decider=lambda history: False, decider_depth='atom') flow.link(b, c) e = self._make_engine(flow) e.run() @@ -812,8 +854,7 @@ class EngineDeciderDepthTest(utils.EngineTestBase): b = utils.NoopTask("b") c = utils.FailingTask("c") flow.add(a, b, c) - flow.link(a, b, decider=lambda history: False, - decider_depth='atom') + flow.link(a, b, decider=lambda history: False, decider_depth='atom') flow.link(b, c) e = self._make_engine(flow) with utils.CaptureListener(e, capture_flow=False) as capturer: @@ -835,22 +876,18 @@ class EngineDeciderDepthTest(utils.EngineTestBase): class EngineGraphFlowTest(utils.EngineTestBase): - def test_run_empty_graph_flow(self): flow = gf.Flow('g-1') engine = self._make_engine(flow) self.assertEqual(_EMPTY_TRANSITIONS, list(engine.run_iter())) def test_run_empty_nested_graph_flows(self): - flow = gf.Flow('g-1').add(lf.Flow('l-1'), - gf.Flow('g-2')) + flow = gf.Flow('g-1').add(lf.Flow('l-1'), gf.Flow('g-2')) engine = self._make_engine(flow) self.assertEqual(_EMPTY_TRANSITIONS, list(engine.run_iter())) def test_graph_flow_one_task(self): - flow = gf.Flow('g-1').add( - utils.ProgressingTask(name='task1') - ) + flow = 
gf.Flow('g-1').add(utils.ProgressingTask(name='task1')) engine = self._make_engine(flow) with utils.CaptureListener(engine, capture_flow=False) as capturer: engine.run() @@ -860,71 +897,96 @@ class EngineGraphFlowTest(utils.EngineTestBase): def test_graph_flow_two_independent_tasks(self): flow = gf.Flow('g-2').add( utils.ProgressingTask(name='task1'), - utils.ProgressingTask(name='task2') + utils.ProgressingTask(name='task2'), ) engine = self._make_engine(flow) with utils.CaptureListener(engine, capture_flow=False) as capturer: engine.run() - expected = {'task2.t SUCCESS(5)', 'task2.t RUNNING', - 'task1.t RUNNING', 'task1.t SUCCESS(5)'} + expected = { + 'task2.t SUCCESS(5)', + 'task2.t RUNNING', + 'task1.t RUNNING', + 'task1.t SUCCESS(5)', + } self.assertEqual(expected, set(capturer.values)) self.assertEqual(2, len(flow)) def test_graph_flow_two_tasks(self): flow = gf.Flow('g-1-1').add( utils.ProgressingTask(name='task2', requires=['a']), - utils.ProgressingTask(name='task1', provides='a') + utils.ProgressingTask(name='task1', provides='a'), ) engine = self._make_engine(flow) with utils.CaptureListener(engine, capture_flow=False) as capturer: engine.run() - expected = ['task1.t RUNNING', 'task1.t SUCCESS(5)', - 'task2.t RUNNING', 'task2.t SUCCESS(5)'] + expected = [ + 'task1.t RUNNING', + 'task1.t SUCCESS(5)', + 'task2.t RUNNING', + 'task2.t SUCCESS(5)', + ] self.assertEqual(expected, capturer.values) def test_graph_flow_four_tasks_added_separately(self): - flow = (gf.Flow('g-4') - .add(utils.ProgressingTask(name='task4', - provides='d', requires=['c'])) - .add(utils.ProgressingTask(name='task2', - provides='b', requires=['a'])) - .add(utils.ProgressingTask(name='task3', - provides='c', requires=['b'])) - .add(utils.ProgressingTask(name='task1', - provides='a')) + flow = ( + gf.Flow('g-4') + .add( + utils.ProgressingTask( + name='task4', provides='d', requires=['c'] ) + ) + .add( + utils.ProgressingTask( + name='task2', provides='b', requires=['a'] + ) + ) + .add( 
+ utils.ProgressingTask( + name='task3', provides='c', requires=['b'] + ) + ) + .add(utils.ProgressingTask(name='task1', provides='a')) + ) engine = self._make_engine(flow) with utils.CaptureListener(engine, capture_flow=False) as capturer: engine.run() - expected = ['task1.t RUNNING', 'task1.t SUCCESS(5)', - 'task2.t RUNNING', 'task2.t SUCCESS(5)', - 'task3.t RUNNING', 'task3.t SUCCESS(5)', - 'task4.t RUNNING', 'task4.t SUCCESS(5)'] + expected = [ + 'task1.t RUNNING', + 'task1.t SUCCESS(5)', + 'task2.t RUNNING', + 'task2.t SUCCESS(5)', + 'task3.t RUNNING', + 'task3.t SUCCESS(5)', + 'task4.t RUNNING', + 'task4.t SUCCESS(5)', + ] self.assertEqual(expected, capturer.values) def test_graph_flow_four_tasks_revert(self): flow = gf.Flow('g-4-failing').add( - utils.ProgressingTask(name='task4', - provides='d', requires=['c']), - utils.ProgressingTask(name='task2', - provides='b', requires=['a']), - utils.FailingTask(name='task3', - provides='c', requires=['b']), - utils.ProgressingTask(name='task1', provides='a')) + utils.ProgressingTask(name='task4', provides='d', requires=['c']), + utils.ProgressingTask(name='task2', provides='b', requires=['a']), + utils.FailingTask(name='task3', provides='c', requires=['b']), + utils.ProgressingTask(name='task1', provides='a'), + ) engine = self._make_engine(flow) with utils.CaptureListener(engine, capture_flow=False) as capturer: self.assertFailuresRegexp(RuntimeError, '^Woot', engine.run) - expected = ['task1.t RUNNING', 'task1.t SUCCESS(5)', - 'task2.t RUNNING', 'task2.t SUCCESS(5)', - 'task3.t RUNNING', - 'task3.t FAILURE(Failure: RuntimeError: Woot!)', - 'task3.t REVERTING', - 'task3.t REVERTED(None)', - 'task2.t REVERTING', - 'task2.t REVERTED(None)', - 'task1.t REVERTING', - 'task1.t REVERTED(None)'] + expected = [ + 'task1.t RUNNING', + 'task1.t SUCCESS(5)', + 'task2.t RUNNING', + 'task2.t SUCCESS(5)', + 'task3.t RUNNING', + 'task3.t FAILURE(Failure: RuntimeError: Woot!)', + 'task3.t REVERTING', + 'task3.t REVERTED(None)', + 
'task2.t REVERTING', + 'task2.t REVERTED(None)', + 'task1.t REVERTING', + 'task1.t REVERTED(None)', + ] self.assertEqual(expected, capturer.values) self.assertEqual(states.REVERTED, engine.storage.get_flow_state()) @@ -932,7 +994,8 @@ class EngineGraphFlowTest(utils.EngineTestBase): flow = gf.Flow('g-3-nasty').add( utils.NastyTask(name='task2', provides='b', requires=['a']), utils.FailingTask(name='task3', requires=['b']), - utils.ProgressingTask(name='task1', provides='a')) + utils.ProgressingTask(name='task1', provides='a'), + ) engine = self._make_engine(flow) self.assertFailuresRegexp(RuntimeError, '^Gotcha', engine.run) @@ -940,28 +1003,28 @@ class EngineGraphFlowTest(utils.EngineTestBase): def test_graph_flow_with_multireturn_and_multiargs_tasks(self): flow = gf.Flow('g-3-multi').add( - utils.TaskMultiArgOneReturn(name='task1', - rebind=['a', 'b', 'y'], provides='z'), + utils.TaskMultiArgOneReturn( + name='task1', rebind=['a', 'b', 'y'], provides='z' + ), utils.TaskMultiReturn(name='task2', provides=['a', 'b', 'c']), - utils.TaskMultiArgOneReturn(name='task3', - rebind=['c', 'b', 'x'], provides='y')) + utils.TaskMultiArgOneReturn( + name='task3', rebind=['c', 'b', 'x'], provides='y' + ), + ) engine = self._make_engine(flow) engine.storage.inject({'x': 30}) engine.run() - self.assertEqual({ - 'a': 1, - 'b': 3, - 'c': 5, - 'x': 30, - 'y': 38, - 'z': 42 - }, engine.storage.fetch_all()) + self.assertEqual( + {'a': 1, 'b': 3, 'c': 5, 'x': 30, 'y': 38, 'z': 42}, + engine.storage.fetch_all(), + ) def test_task_graph_property(self): flow = gf.Flow('test').add( utils.TaskNoRequiresNoReturns(name='task1'), - utils.TaskNoRequiresNoReturns(name='task2')) + utils.TaskNoRequiresNoReturns(name='task2'), + ) engine = self._make_engine(flow) engine.compile() @@ -980,10 +1043,9 @@ class EngineGraphFlowTest(utils.EngineTestBase): class EngineMissingDepsTest(utils.EngineTestBase): def test_missing_deps_deep(self): flow = gf.Flow('missing-many').add( - 
utils.TaskOneReturn(name='task1', - requires=['a', 'b', 'c']), - utils.TaskMultiArgOneReturn(name='task2', - rebind=['e', 'f', 'g'])) + utils.TaskOneReturn(name='task1', requires=['a', 'b', 'c']), + utils.TaskMultiArgOneReturn(name='task2', rebind=['e', 'f', 'g']), + ) engine = self._make_engine(flow) engine.compile() engine.prepare() @@ -1012,10 +1074,8 @@ class EngineResetTests(utils.EngineTestBase): expected = [ 'task1.t RUNNING', 'task1.t SUCCESS(5)', - 'task2.t RUNNING', 'task2.t SUCCESS(5)', - 'task3.t RUNNING', 'task3.t SUCCESS(5)', ] @@ -1051,9 +1111,7 @@ class EngineResetTests(utils.EngineTestBase): 'task2.t RUNNING', 'task2.t SUCCESS(5)', 'task3.t RUNNING', - 'task3.t FAILURE(Failure: RuntimeError: Woot!)', - 'task3.t REVERTING', 'task3.t REVERTED(None)', 'task2.t REVERTING', @@ -1099,7 +1157,6 @@ class EngineResetTests(utils.EngineTestBase): expected = [ 'task1.t RUNNING', 'task1.t SUCCESS(5)', - 'task2.t RUNNING', 'task2.t SUCCESS(5)', ] @@ -1119,10 +1176,8 @@ class EngineResetTests(utils.EngineTestBase): expected = [ 'task1.t RUNNING', 'task1.t SUCCESS(5)', - 'task2.t RUNNING', 'task2.t SUCCESS(5)', - 'task3.t RUNNING', 'task3.t SUCCESS(5)', ] @@ -1130,7 +1185,6 @@ class EngineResetTests(utils.EngineTestBase): class EngineGraphConditionalFlowTest(utils.EngineTestBase): - def test_graph_flow_conditional_jumps_across_2(self): histories = [] @@ -1161,10 +1215,8 @@ class EngineGraphConditionalFlowTest(utils.EngineTestBase): expected = [ 'task1.t RUNNING', 'task1.t SUCCESS(5)', - 'task2.t RUNNING', 'task2.t SUCCESS(5)', - 'task3.t IGNORE', 'task4.t IGNORE', ] @@ -1200,10 +1252,8 @@ class EngineGraphConditionalFlowTest(utils.EngineTestBase): expected = [ 'task1.t RUNNING', 'task1.t SUCCESS(5)', - 'task2.t RUNNING', 'task2.t SUCCESS(5)', - 'task3.t IGNORE', 'task4.t IGNORE', ] @@ -1233,10 +1283,8 @@ class EngineGraphConditionalFlowTest(utils.EngineTestBase): expected = { 'task1.t RUNNING', 'task1.t SUCCESS(5)', - 'task2.t IGNORE', 'task2_2.t IGNORE', - 
'task3.t RUNNING', 'task3.t SUCCESS(5)', } @@ -1261,17 +1309,15 @@ class EngineGraphConditionalFlowTest(utils.EngineTestBase): expected = { 'task1.t RUNNING', 'task1.t SUCCESS(5)', - 'task2.t RUNNING', 'task2.t SUCCESS(5)', - 'task3.t IGNORE', } self.assertEqual(expected, set(capturer.values)) - self.assertEqual(states.IGNORE, - engine.storage.get_atom_state('task3')) - self.assertEqual(states.IGNORE, - engine.storage.get_atom_intention('task3')) + self.assertEqual(states.IGNORE, engine.storage.get_atom_state('task3')) + self.assertEqual( + states.IGNORE, engine.storage.get_atom_intention('task3') + ) engine.reset() allow_execute.set() @@ -1281,10 +1327,8 @@ class EngineGraphConditionalFlowTest(utils.EngineTestBase): expected = { 'task1.t RUNNING', 'task1.t SUCCESS(5)', - 'task2.t RUNNING', 'task2.t SUCCESS(5)', - 'task3.t RUNNING', 'task3.t SUCCESS(5)', } @@ -1311,20 +1355,17 @@ class EngineGraphConditionalFlowTest(utils.EngineTestBase): expected = { 'task1.t RUNNING', 'task1.t SUCCESS(5)', - 'task2.t RUNNING', 'task2.t SUCCESS(5)', - 'task3.t RUNNING', 'task3.t SUCCESS(5)', - 'task4.t IGNORE', } self.assertEqual(expected, set(capturer.values)) - self.assertEqual(states.IGNORE, - engine.storage.get_atom_state('task4')) - self.assertEqual(states.IGNORE, - engine.storage.get_atom_intention('task4')) + self.assertEqual(states.IGNORE, engine.storage.get_atom_state('task4')) + self.assertEqual( + states.IGNORE, engine.storage.get_atom_intention('task4') + ) def test_graph_flow_conditional_history(self): @@ -1343,12 +1384,18 @@ class EngineGraphConditionalFlowTest(utils.EngineTestBase): task3_3 = utils.ProgressingTask(name='task3_3') flow.add(task1, task2, task2_2, task3, task3_3) - flow.link(task1, task2, - decider=functools.partial(even_odd_decider, allowed=2)) + flow.link( + task1, + task2, + decider=functools.partial(even_odd_decider, allowed=2), + ) flow.link(task2, task2_2) - flow.link(task1, task3, - decider=functools.partial(even_odd_decider, allowed=1)) + 
flow.link( + task1, + task3, + decider=functools.partial(even_odd_decider, allowed=1), + ) flow.link(task3, task3_3) engine = self._make_engine(flow) @@ -1358,10 +1405,14 @@ class EngineGraphConditionalFlowTest(utils.EngineTestBase): engine.run() expected = { - 'task1.t RUNNING', 'task1.t SUCCESS(2)', - 'task3.t IGNORE', 'task3_3.t IGNORE', - 'task2.t RUNNING', 'task2.t SUCCESS(5)', - 'task2_2.t RUNNING', 'task2_2.t SUCCESS(5)', + 'task1.t RUNNING', + 'task1.t SUCCESS(2)', + 'task3.t IGNORE', + 'task3_3.t IGNORE', + 'task2.t RUNNING', + 'task2.t SUCCESS(5)', + 'task2_2.t RUNNING', + 'task2_2.t SUCCESS(5)', } self.assertEqual(expected, set(capturer.values)) @@ -1371,10 +1422,14 @@ class EngineGraphConditionalFlowTest(utils.EngineTestBase): engine.run() expected = { - 'task1.t RUNNING', 'task1.t SUCCESS(1)', - 'task2.t IGNORE', 'task2_2.t IGNORE', - 'task3.t RUNNING', 'task3.t SUCCESS(5)', - 'task3_3.t RUNNING', 'task3_3.t SUCCESS(5)', + 'task1.t RUNNING', + 'task1.t SUCCESS(1)', + 'task2.t IGNORE', + 'task2_2.t IGNORE', + 'task3.t RUNNING', + 'task3.t SUCCESS(5)', + 'task3_3.t RUNNING', + 'task3_3.t SUCCESS(5)', } self.assertEqual(expected, set(capturer.values)) @@ -1396,35 +1451,36 @@ class EngineCheckingTaskTest(utils.EngineTestBase): self.assertIsInstance(fail, failure.Failure) self.assertEqual('Failure: RuntimeError: Woot!', str(fail)) - flow = lf.Flow('test').add( - CheckingTask(), - utils.FailingTask('fail1') - ) + flow = lf.Flow('test').add(CheckingTask(), utils.FailingTask('fail1')) engine = self._make_engine(flow) self.assertRaisesRegex(RuntimeError, '^Woot', engine.run) -class SerialEngineTest(EngineTaskTest, - EngineMultipleResultsTest, - EngineLinearFlowTest, - EngineParallelFlowTest, - EngineLinearAndUnorderedExceptionsTest, - EngineOptionalRequirementsTest, - EngineGraphFlowTest, - EngineMissingDepsTest, - EngineResetTests, - EngineGraphConditionalFlowTest, - EngineCheckingTaskTest, - EngineDeciderDepthTest, - EngineTaskNotificationsTest, - 
test.TestCase): - def _make_engine(self, flow, - flow_detail=None, store=None, **kwargs): - return taskflow.engines.load(flow, - flow_detail=flow_detail, - engine='serial', - backend=self.backend, - store=store, **kwargs) +class SerialEngineTest( + EngineTaskTest, + EngineMultipleResultsTest, + EngineLinearFlowTest, + EngineParallelFlowTest, + EngineLinearAndUnorderedExceptionsTest, + EngineOptionalRequirementsTest, + EngineGraphFlowTest, + EngineMissingDepsTest, + EngineResetTests, + EngineGraphConditionalFlowTest, + EngineCheckingTaskTest, + EngineDeciderDepthTest, + EngineTaskNotificationsTest, + test.TestCase, +): + def _make_engine(self, flow, flow_detail=None, store=None, **kwargs): + return taskflow.engines.load( + flow, + flow_detail=flow_detail, + engine='serial', + backend=self.backend, + store=store, + **kwargs, + ) def test_correct_load(self): engine = self._make_engine(utils.TaskNoRequiresNoReturns) @@ -1435,34 +1491,39 @@ class SerialEngineTest(EngineTaskTest, self.assertIsInstance(engine, eng.SerialActionEngine) -class ParallelEngineWithThreadsTest(EngineTaskTest, - EngineMultipleResultsTest, - EngineLinearFlowTest, - EngineParallelFlowTest, - EngineLinearAndUnorderedExceptionsTest, - EngineOptionalRequirementsTest, - EngineGraphFlowTest, - EngineResetTests, - EngineMissingDepsTest, - EngineGraphConditionalFlowTest, - EngineCheckingTaskTest, - EngineDeciderDepthTest, - EngineTaskNotificationsTest, - test.TestCase): +class ParallelEngineWithThreadsTest( + EngineTaskTest, + EngineMultipleResultsTest, + EngineLinearFlowTest, + EngineParallelFlowTest, + EngineLinearAndUnorderedExceptionsTest, + EngineOptionalRequirementsTest, + EngineGraphFlowTest, + EngineResetTests, + EngineMissingDepsTest, + EngineGraphConditionalFlowTest, + EngineCheckingTaskTest, + EngineDeciderDepthTest, + EngineTaskNotificationsTest, + test.TestCase, +): _EXECUTOR_WORKERS = 2 - def _make_engine(self, flow, - flow_detail=None, executor=None, store=None, - **kwargs): + def 
_make_engine( + self, flow, flow_detail=None, executor=None, store=None, **kwargs + ): if executor is None: executor = 'threads' - return taskflow.engines.load(flow, flow_detail=flow_detail, - backend=self.backend, - executor=executor, - engine='parallel', - store=store, - max_workers=self._EXECUTOR_WORKERS, - **kwargs) + return taskflow.engines.load( + flow, + flow_detail=flow_detail, + backend=self.backend, + executor=executor, + engine='parallel', + store=store, + max_workers=self._EXECUTOR_WORKERS, + **kwargs, + ) def test_correct_load(self): engine = self._make_engine(utils.TaskNoRequiresNoReturns) @@ -1480,45 +1541,53 @@ class ParallelEngineWithThreadsTest(EngineTaskTest, @testtools.skipIf(not eu.EVENTLET_AVAILABLE, 'eventlet is not available') -class ParallelEngineWithEventletTest(EngineTaskTest, - EngineMultipleResultsTest, - EngineLinearFlowTest, - EngineParallelFlowTest, - EngineLinearAndUnorderedExceptionsTest, - EngineOptionalRequirementsTest, - EngineGraphFlowTest, - EngineResetTests, - EngineMissingDepsTest, - EngineGraphConditionalFlowTest, - EngineCheckingTaskTest, - EngineDeciderDepthTest, - EngineTaskNotificationsTest, - test.TestCase): - - def _make_engine(self, flow, - flow_detail=None, executor=None, store=None, - **kwargs): +class ParallelEngineWithEventletTest( + EngineTaskTest, + EngineMultipleResultsTest, + EngineLinearFlowTest, + EngineParallelFlowTest, + EngineLinearAndUnorderedExceptionsTest, + EngineOptionalRequirementsTest, + EngineGraphFlowTest, + EngineResetTests, + EngineMissingDepsTest, + EngineGraphConditionalFlowTest, + EngineCheckingTaskTest, + EngineDeciderDepthTest, + EngineTaskNotificationsTest, + test.TestCase, +): + def _make_engine( + self, flow, flow_detail=None, executor=None, store=None, **kwargs + ): if executor is None: executor = 'greenthreads' - return taskflow.engines.load(flow, flow_detail=flow_detail, - backend=self.backend, engine='parallel', - executor=executor, - store=store, **kwargs) + return 
taskflow.engines.load( + flow, + flow_detail=flow_detail, + backend=self.backend, + engine='parallel', + executor=executor, + store=store, + **kwargs, + ) -class WorkerBasedEngineTest(EngineTaskTest, - EngineMultipleResultsTest, - EngineLinearFlowTest, - EngineParallelFlowTest, - EngineLinearAndUnorderedExceptionsTest, - EngineOptionalRequirementsTest, - EngineGraphFlowTest, - EngineResetTests, - EngineMissingDepsTest, - EngineGraphConditionalFlowTest, - EngineDeciderDepthTest, - EngineTaskNotificationsTest, - test.TestCase): +class WorkerBasedEngineTest( + EngineTaskTest, + EngineMultipleResultsTest, + EngineLinearFlowTest, + EngineParallelFlowTest, + EngineLinearAndUnorderedExceptionsTest, + EngineOptionalRequirementsTest, + EngineGraphFlowTest, + EngineResetTests, + EngineMissingDepsTest, + EngineGraphConditionalFlowTest, + EngineDeciderDepthTest, + EngineTaskNotificationsTest, + test.TestCase, +): def setUp(self): super().setUp() shared_conf = { @@ -1535,20 +1604,24 @@ class WorkerBasedEngineTest(EngineTaskTest, }, } worker_conf = shared_conf.copy() - worker_conf.update({ - 'topic': 'my-topic', - 'tasks': [ - # This makes it possible for the worker to run/find any atoms - # that are defined in the test.utils module (which are all - # the task/atom types that this test uses)... - utils.__name__, - ], - }) + worker_conf.update( + { + 'topic': 'my-topic', + 'tasks': [ + # This makes it possible for the worker to run/find any + # atoms that are defined in the test.utils module (which + # are all the task/atom types that this test uses)... 
+ utils.__name__, + ], + } + ) self.engine_conf = shared_conf.copy() - self.engine_conf.update({ - 'engine': 'worker-based', - 'topics': tuple([worker_conf['topic']]), - }) + self.engine_conf.update( + { + 'engine': 'worker-based', + 'topics': tuple([worker_conf['topic']]), + } + ) self.worker = wkr.Worker(**worker_conf) self.worker_thread = tu.daemon_thread(self.worker.run) self.worker_thread.start() @@ -1562,12 +1635,15 @@ class WorkerBasedEngineTest(EngineTaskTest, # Make sure the worker is started before we can continue... self.worker.wait() - def _make_engine(self, flow, - flow_detail=None, store=None, **kwargs): + def _make_engine(self, flow, flow_detail=None, store=None, **kwargs): kwargs.update(self.engine_conf) - return taskflow.engines.load(flow, flow_detail=flow_detail, - backend=self.backend, - store=store, **kwargs) + return taskflow.engines.load( + flow, + flow_detail=flow_detail, + backend=self.backend, + store=store, + **kwargs, + ) def test_correct_load(self): engine = self._make_engine(utils.TaskNoRequiresNoReturns) diff --git a/taskflow/tests/unit/test_exceptions.py b/taskflow/tests/unit/test_exceptions.py index ba837ff19..b9b6cc886 100644 --- a/taskflow/tests/unit/test_exceptions.py +++ b/taskflow/tests/unit/test_exceptions.py @@ -73,8 +73,9 @@ class TestExceptions(test.TestCase): try: raise OSError("Didn't work") except OSError: - exc.raise_with_cause(exc.TaskFlowException, - "It didn't go so well") + exc.raise_with_cause( + exc.TaskFlowException, "It didn't go so well" + ) except exc.TaskFlowException: exc.raise_with_cause(exc.TaskFlowException, "I Failed") except exc.TaskFlowException as e: @@ -93,12 +94,11 @@ class TestExceptions(test.TestCase): def test_pformat_root_class(self): ex = exc.TaskFlowException("Broken") - self.assertIn("TaskFlowException", - ex.pformat(show_root_class=True)) - self.assertNotIn("TaskFlowException", - ex.pformat(show_root_class=False)) - self.assertIn("Broken", - ex.pformat(show_root_class=True)) + 
self.assertIn("TaskFlowException", ex.pformat(show_root_class=True)) + self.assertNotIn( + "TaskFlowException", ex.pformat(show_root_class=False) + ) + self.assertIn("Broken", ex.pformat(show_root_class=True)) def test_invalid_pformat_indent(self): ex = exc.TaskFlowException("Broken") diff --git a/taskflow/tests/unit/test_failure.py b/taskflow/tests/unit/test_failure.py index e8eac8de2..914b99a8b 100644 --- a/taskflow/tests/unit/test_failure.py +++ b/taskflow/tests/unit/test_failure.py @@ -36,17 +36,16 @@ def _make_exc_info(msg): class GeneralFailureObjTestsMixin: - def test_captures_message(self): self.assertEqual('Woot!', self.fail_obj.exception_str) def test_str(self): - self.assertEqual('Failure: RuntimeError: Woot!', - str(self.fail_obj)) + self.assertEqual('Failure: RuntimeError: Woot!', str(self.fail_obj)) def test_exception_types(self): - self.assertEqual(test_utils.RUNTIME_ERROR_CLASSES[:-2], - list(self.fail_obj)) + self.assertEqual( + test_utils.RUNTIME_ERROR_CLASSES[:-2], list(self.fail_obj) + ) def test_pformat_no_traceback(self): text = self.fail_obj.pformat() @@ -68,7 +67,6 @@ class GeneralFailureObjTestsMixin: class CaptureFailureTestCase(test.TestCase, GeneralFailureObjTestsMixin): - def setUp(self): super().setUp() self.fail_obj = _captured_failure('Woot!') @@ -87,13 +85,14 @@ class CaptureFailureTestCase(test.TestCase, GeneralFailureObjTestsMixin): class ReCreatedFailureTestCase(test.TestCase, GeneralFailureObjTestsMixin): - def setUp(self): super().setUp() fail_obj = _captured_failure('Woot!') - self.fail_obj = failure.Failure(exception_str=fail_obj.exception_str, - traceback_str=fail_obj.traceback_str, - exc_type_names=list(fail_obj)) + self.fail_obj = failure.Failure( + exception_str=fail_obj.exception_str, + traceback_str=fail_obj.traceback_str, + exc_type_names=list(fail_obj), + ) def test_value_lost(self): self.assertIsNone(self.fail_obj.exception) @@ -106,21 +105,23 @@ class ReCreatedFailureTestCase(test.TestCase, 
GeneralFailureObjTestsMixin): self.assertIn("Traceback (most recent call last):", text) def test_reraises(self): - exc = self.assertRaises(exceptions.WrappedFailure, - self.fail_obj.reraise) + exc = self.assertRaises( + exceptions.WrappedFailure, self.fail_obj.reraise + ) self.assertIs(exc.check(RuntimeError), RuntimeError) def test_no_type_names(self): fail_obj = _captured_failure('Woot!') - fail_obj = failure.Failure(exception_str=fail_obj.exception_str, - traceback_str=fail_obj.traceback_str, - exc_type_names=[]) + fail_obj = failure.Failure( + exception_str=fail_obj.exception_str, + traceback_str=fail_obj.traceback_str, + exc_type_names=[], + ) self.assertEqual([], list(fail_obj)) self.assertEqual("Failure: Woot!", fail_obj.pformat()) class FromExceptionTestCase(test.TestCase, GeneralFailureObjTestsMixin): - def setUp(self): super().setUp() self.fail_obj = failure.Failure.from_exception(RuntimeError('Woot!')) @@ -131,29 +132,31 @@ class FromExceptionTestCase(test.TestCase, GeneralFailureObjTestsMixin): class FailureObjectTestCase(test.TestCase): - def test_invalids(self): f = { 'exception_str': 'blah', 'traceback_str': 'blah', 'exc_type_names': [], } - self.assertRaises(exceptions.InvalidFormat, - failure.Failure.validate, f) + self.assertRaises( + exceptions.InvalidFormat, failure.Failure.validate, f + ) f = { 'exception_str': 'blah', 'exc_type_names': ['Exception'], } - self.assertRaises(exceptions.InvalidFormat, - failure.Failure.validate, f) + self.assertRaises( + exceptions.InvalidFormat, failure.Failure.validate, f + ) f = { 'exception_str': 'blah', 'traceback_str': 'blah', 'exc_type_names': ['Exception'], 'version': -1, } - self.assertRaises(exceptions.InvalidFormat, - failure.Failure.validate, f) + self.assertRaises( + exceptions.InvalidFormat, failure.Failure.validate, f + ) def test_valid_from_dict_to_dict(self): f = _captured_failure('Woot!') @@ -166,8 +169,9 @@ class FailureObjectTestCase(test.TestCase): f = _captured_failure('Woot!') d_f = 
f.to_dict() d_f['exc_type_names'] = ['Junk'] - self.assertRaises(exceptions.InvalidFormat, - failure.Failure.validate, d_f) + self.assertRaises( + exceptions.InvalidFormat, failure.Failure.validate, d_f + ) def test_valid_from_dict_to_dict_2(self): f = _captured_failure('Woot!') @@ -190,11 +194,14 @@ class FailureObjectTestCase(test.TestCase): self.assertRaises(TypeError, failure.Failure) def test_unknown_argument(self): - exc = self.assertRaises(TypeError, failure.Failure, - exception_str='Woot!', - traceback_str=None, - exc_type_names=['Exception'], - hi='hi there') + exc = self.assertRaises( + TypeError, + failure.Failure, + exception_str='Woot!', + traceback_str=None, + exc_type_names=['Exception'], + hi='hi there', + ) expected = "Failure.__init__ got unexpected keyword argument(s): hi" self.assertEqual(expected, str(exc)) @@ -203,16 +210,15 @@ class FailureObjectTestCase(test.TestCase): def test_reraises_one(self): fls = [_captured_failure('Woot!')] - self.assertRaisesRegex(RuntimeError, '^Woot!$', - failure.Failure.reraise_if_any, fls) + self.assertRaisesRegex( + RuntimeError, '^Woot!$', failure.Failure.reraise_if_any, fls + ) def test_reraises_several(self): - fls = [ - _captured_failure('Woot!'), - _captured_failure('Oh, not again!') - ] - exc = self.assertRaises(exceptions.WrappedFailure, - failure.Failure.reraise_if_any, fls) + fls = [_captured_failure('Woot!'), _captured_failure('Oh, not again!')] + exc = self.assertRaises( + exceptions.WrappedFailure, failure.Failure.reraise_if_any, fls + ) self.assertEqual(fls, list(exc)) def test_failure_copy(self): @@ -225,9 +231,11 @@ class FailureObjectTestCase(test.TestCase): def test_failure_copy_recaptured(self): captured = _captured_failure('Woot!') - fail_obj = failure.Failure(exception_str=captured.exception_str, - traceback_str=captured.traceback_str, - exc_type_names=list(captured)) + fail_obj = failure.Failure( + exception_str=captured.exception_str, + traceback_str=captured.traceback_str, + 
exc_type_names=list(captured), + ) copied = fail_obj.copy() self.assertIsNot(fail_obj, copied) self.assertEqual(fail_obj, copied) @@ -235,10 +243,12 @@ class FailureObjectTestCase(test.TestCase): def test_recaptured_not_eq(self): captured = _captured_failure('Woot!') - fail_obj = failure.Failure(exception_str=captured.exception_str, - traceback_str=captured.traceback_str, - exc_type_names=list(captured), - exc_args=list(captured.exception_args)) + fail_obj = failure.Failure( + exception_str=captured.exception_str, + traceback_str=captured.traceback_str, + exc_type_names=list(captured), + exc_args=list(captured.exception_args), + ) self.assertNotEqual(fail_obj, captured) self.assertTrue(fail_obj.matches(captured)) @@ -249,13 +259,17 @@ class FailureObjectTestCase(test.TestCase): def test_two_recaptured_neq(self): captured = _captured_failure('Woot!') - fail_obj = failure.Failure(exception_str=captured.exception_str, - traceback_str=captured.traceback_str, - exc_type_names=list(captured)) + fail_obj = failure.Failure( + exception_str=captured.exception_str, + traceback_str=captured.traceback_str, + exc_type_names=list(captured), + ) new_exc_str = captured.exception_str.replace('Woot', 'w00t') - fail_obj2 = failure.Failure(exception_str=new_exc_str, - traceback_str=captured.traceback_str, - exc_type_names=list(captured)) + fail_obj2 = failure.Failure( + exception_str=new_exc_str, + traceback_str=captured.traceback_str, + exc_type_names=list(captured), + ) self.assertNotEqual(fail_obj, fail_obj2) self.assertFalse(fail_obj2.matches(fail_obj)) @@ -277,17 +291,18 @@ class FailureObjectTestCase(test.TestCase): def test_no_capture_exc_args(self): captured = _captured_failure(Exception("I am not valid JSON")) - fail_obj = failure.Failure(exception_str=captured.exception_str, - traceback_str=captured.traceback_str, - exc_type_names=list(captured), - exc_args=list(captured.exception_args)) + fail_obj = failure.Failure( + exception_str=captured.exception_str, + 
traceback_str=captured.traceback_str, + exc_type_names=list(captured), + exc_args=list(captured.exception_args), + ) fail_json = fail_obj.to_dict(include_args=False) self.assertNotEqual(fail_obj.exception_args, fail_json['exc_args']) self.assertEqual(fail_json['exc_args'], tuple()) class WrappedFailureTestCase(test.TestCase): - def test_simple_iter(self): fail_obj = _captured_failure('Woot!') wf = exceptions.WrappedFailure([fail_obj]) @@ -301,10 +316,7 @@ class WrappedFailureTestCase(test.TestCase): self.assertIsNone(wf.check(ValueError)) def test_two_failures(self): - fls = [ - _captured_failure('Woot!'), - _captured_failure('Oh, not again!') - ] + fls = [_captured_failure('Woot!'), _captured_failure('Oh, not again!')] wf = exceptions.WrappedFailure(fls) self.assertEqual(2, len(wf)) self.assertEqual(fls, list(wf)) @@ -323,7 +335,6 @@ class WrappedFailureTestCase(test.TestCase): class NonAsciiExceptionsTestCase(test.TestCase): - def test_exception_with_non_ascii_str(self): bad_string = chr(200) excp = ValueError(bad_string) @@ -337,8 +348,7 @@ class NonAsciiExceptionsTestCase(test.TestCase): fail = failure.Failure.from_exception(ValueError(hi_ru)) self.assertEqual(hi_ru, fail.exception_str) self.assertIsInstance(fail.exception_str, str) - self.assertEqual('Failure: ValueError: %s' % hi_ru, - str(fail)) + self.assertEqual('Failure: ValueError: %s' % hi_ru, str(fail)) def test_wrapped_failure_non_ascii_unicode(self): hi_cn = '嗨' @@ -346,8 +356,7 @@ class NonAsciiExceptionsTestCase(test.TestCase): self.assertEqual(hi_cn, str(fail)) fail = failure.Failure.from_exception(fail) wrapped_fail = exceptions.WrappedFailure([fail]) - expected_result = ("WrappedFailure: " - "[Failure: ValueError: %s]" % (hi_cn)) + expected_result = "WrappedFailure: [Failure: ValueError: %s]" % (hi_cn) self.assertEqual(expected_result, str(wrapped_fail)) def test_failure_equality_with_non_ascii_str(self): @@ -364,7 +373,6 @@ class NonAsciiExceptionsTestCase(test.TestCase): class 
FailureCausesTest(test.TestCase): - @classmethod def _raise_many(cls, messages): if not messages: @@ -380,8 +388,9 @@ class FailureCausesTest(test.TestCase): def test_causes(self): f = None try: - self._raise_many(["Still still not working", - "Still not working", "Not working"]) + self._raise_many( + ["Still still not working", "Still not working", "Not working"] + ) except RuntimeError: f = failure.Failure() @@ -400,8 +409,9 @@ class FailureCausesTest(test.TestCase): def test_causes_to_from_dict(self): f = None try: - self._raise_many(["Still still not working", - "Still not working", "Not working"]) + self._raise_many( + ["Still still not working", "Still not working", "Not working"] + ) except RuntimeError: f = failure.Failure() @@ -423,8 +433,9 @@ class FailureCausesTest(test.TestCase): def test_causes_pickle(self): f = None try: - self._raise_many(["Still still not working", - "Still not working", "Not working"]) + self._raise_many( + ["Still still not working", "Still not working", "Not working"] + ) except RuntimeError: f = failure.Failure() @@ -447,8 +458,13 @@ class FailureCausesTest(test.TestCase): f = None try: try: - self._raise_many(["Still still not working", - "Still not working", "Not working"]) + self._raise_many( + [ + "Still still not working", + "Still not working", + "Not working", + ] + ) except RuntimeError as e: raise e from None except RuntimeError: diff --git a/taskflow/tests/unit/test_flow_dependencies.py b/taskflow/tests/unit/test_flow_dependencies.py index 583fd8002..f4494762a 100644 --- a/taskflow/tests/unit/test_flow_dependencies.py +++ b/taskflow/tests/unit/test_flow_dependencies.py @@ -22,7 +22,6 @@ from taskflow.tests import utils class FlowDependenciesTest(test.TestCase): - def test_task_without_dependencies(self): flow = utils.TaskNoRequiresNoReturns() self.assertEqual(set(), flow.requires) @@ -31,7 +30,10 @@ class FlowDependenciesTest(test.TestCase): def test_task_requires_default_values(self): flow = utils.TaskMultiArg() 
self.assertEqual({'x', 'y', 'z'}, flow.requires) - self.assertEqual(set(), flow.provides, ) + self.assertEqual( + set(), + flow.provides, + ) def test_task_requires_rebinded_mapped(self): flow = utils.TaskMultiArg(rebind={'x': 'a', 'y': 'b', 'z': 'c'}) @@ -56,82 +58,93 @@ class FlowDependenciesTest(test.TestCase): def test_linear_flow_without_dependencies(self): flow = lf.Flow('lf').add( utils.TaskNoRequiresNoReturns('task1'), - utils.TaskNoRequiresNoReturns('task2')) + utils.TaskNoRequiresNoReturns('task2'), + ) self.assertEqual(set(), flow.requires) self.assertEqual(set(), flow.provides) def test_linear_flow_requires_values(self): flow = lf.Flow('lf').add( - utils.TaskOneArg('task1'), - utils.TaskMultiArg('task2')) + utils.TaskOneArg('task1'), utils.TaskMultiArg('task2') + ) self.assertEqual({'x', 'y', 'z'}, flow.requires) self.assertEqual(set(), flow.provides) def test_linear_flow_requires_rebind_values(self): flow = lf.Flow('lf').add( utils.TaskOneArg('task1', rebind=['q']), - utils.TaskMultiArg('task2')) + utils.TaskMultiArg('task2'), + ) self.assertEqual({'x', 'y', 'z', 'q'}, flow.requires) self.assertEqual(set(), flow.provides) def test_linear_flow_provides_values(self): flow = lf.Flow('lf').add( utils.TaskOneReturn('task1', provides='x'), - utils.TaskMultiReturn('task2', provides=['a', 'b', 'c'])) + utils.TaskMultiReturn('task2', provides=['a', 'b', 'c']), + ) self.assertEqual(set(), flow.requires) self.assertEqual({'x', 'a', 'b', 'c'}, flow.provides) def test_linear_flow_provides_required_values(self): flow = lf.Flow('lf').add( utils.TaskOneReturn('task1', provides='x'), - utils.TaskOneArg('task2')) + utils.TaskOneArg('task2'), + ) self.assertEqual(set(), flow.requires) self.assertEqual({'x'}, flow.provides) def test_linear_flow_multi_provides_and_requires_values(self): flow = lf.Flow('lf').add( - utils.TaskMultiArgMultiReturn('task1', - rebind=['a', 'b', 'c'], - provides=['x', 'y', 'q']), - utils.TaskMultiArgMultiReturn('task2', - provides=['i', 'j', 
'k'])) + utils.TaskMultiArgMultiReturn( + 'task1', rebind=['a', 'b', 'c'], provides=['x', 'y', 'q'] + ), + utils.TaskMultiArgMultiReturn('task2', provides=['i', 'j', 'k']), + ) self.assertEqual({'a', 'b', 'c', 'z'}, flow.requires) self.assertEqual({'x', 'y', 'q', 'i', 'j', 'k'}, flow.provides) def test_unordered_flow_without_dependencies(self): flow = uf.Flow('uf').add( utils.TaskNoRequiresNoReturns('task1'), - utils.TaskNoRequiresNoReturns('task2')) + utils.TaskNoRequiresNoReturns('task2'), + ) self.assertEqual(set(), flow.requires) self.assertEqual(set(), flow.provides) def test_unordered_flow_requires_values(self): flow = uf.Flow('uf').add( - utils.TaskOneArg('task1'), - utils.TaskMultiArg('task2')) + utils.TaskOneArg('task1'), utils.TaskMultiArg('task2') + ) self.assertEqual({'x', 'y', 'z'}, flow.requires) self.assertEqual(set(), flow.provides) def test_unordered_flow_requires_rebind_values(self): flow = uf.Flow('uf').add( utils.TaskOneArg('task1', rebind=['q']), - utils.TaskMultiArg('task2')) + utils.TaskMultiArg('task2'), + ) self.assertEqual({'x', 'y', 'z', 'q'}, flow.requires) self.assertEqual(set(), flow.provides) def test_unordered_flow_provides_values(self): flow = uf.Flow('uf').add( utils.TaskOneReturn('task1', provides='x'), - utils.TaskMultiReturn('task2', provides=['a', 'b', 'c'])) + utils.TaskMultiReturn('task2', provides=['a', 'b', 'c']), + ) self.assertEqual(set(), flow.requires) self.assertEqual({'x', 'a', 'b', 'c'}, flow.provides) def test_unordered_flow_provides_required_values(self): flow = uf.Flow('uf') - flow.add(utils.TaskOneReturn('task1', provides='x'), - utils.TaskOneArg('task2')) - flow.add(utils.TaskOneReturn('task1', provides='x'), - utils.TaskOneArg('task2')) + flow.add( + utils.TaskOneReturn('task1', provides='x'), + utils.TaskOneArg('task2'), + ) + flow.add( + utils.TaskOneReturn('task1', provides='x'), + utils.TaskOneArg('task2'), + ) self.assertEqual({'x'}, flow.provides) self.assertEqual({'x'}, flow.requires) @@ -152,11 +165,11 
@@ class FlowDependenciesTest(test.TestCase): def test_unordered_flow_multi_provides_and_requires_values(self): flow = uf.Flow('uf').add( - utils.TaskMultiArgMultiReturn('task1', - rebind=['a', 'b', 'c'], - provides=['d', 'e', 'f']), - utils.TaskMultiArgMultiReturn('task2', - provides=['i', 'j', 'k'])) + utils.TaskMultiArgMultiReturn( + 'task1', rebind=['a', 'b', 'c'], provides=['d', 'e', 'f'] + ), + utils.TaskMultiArgMultiReturn('task2', provides=['i', 'j', 'k']), + ) self.assertEqual({'a', 'b', 'c', 'x', 'y', 'z'}, flow.requires) self.assertEqual({'d', 'e', 'f', 'i', 'j', 'k'}, flow.provides) @@ -167,49 +180,60 @@ class FlowDependenciesTest(test.TestCase): def test_unordered_flow_provides_same_values_one_add(self): flow = uf.Flow('uf') - flow.add(utils.TaskOneReturn(provides='x'), - utils.TaskOneReturn(provides='x')) + flow.add( + utils.TaskOneReturn(provides='x'), + utils.TaskOneReturn(provides='x'), + ) self.assertEqual({'x'}, flow.provides) def test_nested_flows_requirements(self): flow = uf.Flow('uf').add( lf.Flow('lf').add( - utils.TaskOneArgOneReturn('task1', - rebind=['a'], provides=['x']), - utils.TaskOneArgOneReturn('task2', provides=['y'])), + utils.TaskOneArgOneReturn( + 'task1', rebind=['a'], provides=['x'] + ), + utils.TaskOneArgOneReturn('task2', provides=['y']), + ), uf.Flow('uf').add( - utils.TaskOneArgOneReturn('task3', - rebind=['b'], provides=['z']), - utils.TaskOneArgOneReturn('task4', rebind=['c'], - provides=['q']))) + utils.TaskOneArgOneReturn( + 'task3', rebind=['b'], provides=['z'] + ), + utils.TaskOneArgOneReturn( + 'task4', rebind=['c'], provides=['q'] + ), + ), + ) self.assertEqual({'a', 'b', 'c'}, flow.requires) self.assertEqual({'x', 'y', 'z', 'q'}, flow.provides) def test_graph_flow_requires_values(self): flow = gf.Flow('gf').add( - utils.TaskOneArg('task1'), - utils.TaskMultiArg('task2')) + utils.TaskOneArg('task1'), utils.TaskMultiArg('task2') + ) self.assertEqual({'x', 'y', 'z'}, flow.requires) self.assertEqual(set(), 
flow.provides) def test_graph_flow_requires_rebind_values(self): flow = gf.Flow('gf').add( utils.TaskOneArg('task1', rebind=['q']), - utils.TaskMultiArg('task2')) + utils.TaskMultiArg('task2'), + ) self.assertEqual({'x', 'y', 'z', 'q'}, flow.requires) self.assertEqual(set(), flow.provides) def test_graph_flow_provides_values(self): flow = gf.Flow('gf').add( utils.TaskOneReturn('task1', provides='x'), - utils.TaskMultiReturn('task2', provides=['a', 'b', 'c'])) + utils.TaskMultiReturn('task2', provides=['a', 'b', 'c']), + ) self.assertEqual(set(), flow.requires) self.assertEqual({'x', 'a', 'b', 'c'}, flow.provides) def test_graph_flow_provides_required_values(self): flow = gf.Flow('gf').add( utils.TaskOneReturn('task1', provides='x'), - utils.TaskOneArg('task2')) + utils.TaskOneArg('task2'), + ) self.assertEqual(set(), flow.requires) self.assertEqual({'x'}, flow.provides) @@ -221,28 +245,29 @@ class FlowDependenciesTest(test.TestCase): def test_graph_flow_multi_provides_and_requires_values(self): flow = gf.Flow('gf').add( - utils.TaskMultiArgMultiReturn('task1', - rebind=['a', 'b', 'c'], - provides=['d', 'e', 'f']), - utils.TaskMultiArgMultiReturn('task2', - provides=['i', 'j', 'k'])) + utils.TaskMultiArgMultiReturn( + 'task1', rebind=['a', 'b', 'c'], provides=['d', 'e', 'f'] + ), + utils.TaskMultiArgMultiReturn('task2', provides=['i', 'j', 'k']), + ) self.assertEqual({'a', 'b', 'c', 'x', 'y', 'z'}, flow.requires) self.assertEqual({'d', 'e', 'f', 'i', 'j', 'k'}, flow.provides) def test_graph_cyclic_dependency(self): flow = gf.Flow('g-3-cyclic') - self.assertRaisesRegex(exceptions.DependencyFailure, '^No path', - flow.add, - utils.TaskOneArgOneReturn(provides='a', - requires=['b']), - utils.TaskOneArgOneReturn(provides='b', - requires=['c']), - utils.TaskOneArgOneReturn(provides='c', - requires=['a'])) + self.assertRaisesRegex( + exceptions.DependencyFailure, + '^No path', + flow.add, + utils.TaskOneArgOneReturn(provides='a', requires=['b']), + 
utils.TaskOneArgOneReturn(provides='b', requires=['c']), + utils.TaskOneArgOneReturn(provides='c', requires=['a']), + ) def test_task_requires_and_provides_same_values(self): - flow = lf.Flow('lf', utils.TaskOneArgOneReturn('rt', requires='x', - provides='x')) + flow = lf.Flow( + 'lf', utils.TaskOneArgOneReturn('rt', requires='x', provides='x') + ) self.assertEqual(set('x'), flow.requires) self.assertEqual(set('x'), flow.provides) @@ -262,16 +287,18 @@ class FlowDependenciesTest(test.TestCase): self.assertEqual({'x', 'y'}, flow.provides) def test_retry_in_linear_flow_requires_and_provides(self): - flow = lf.Flow('lf', retry.AlwaysRevert('rt', - requires=['x', 'y'], - provides=['a', 'b'])) + flow = lf.Flow( + 'lf', + retry.AlwaysRevert('rt', requires=['x', 'y'], provides=['a', 'b']), + ) self.assertEqual({'x', 'y'}, flow.requires) self.assertEqual({'a', 'b'}, flow.provides) def test_retry_requires_and_provides_same_value(self): - flow = lf.Flow('lf', retry.AlwaysRevert('rt', - requires=['x', 'y'], - provides=['x', 'y'])) + flow = lf.Flow( + 'lf', + retry.AlwaysRevert('rt', requires=['x', 'y'], provides=['x', 'y']), + ) self.assertEqual({'x', 'y'}, flow.requires) self.assertEqual({'x', 'y'}, flow.provides) @@ -291,9 +318,10 @@ class FlowDependenciesTest(test.TestCase): self.assertEqual({'x', 'y'}, flow.provides) def test_retry_in_unordered_flow_requires_and_provides(self): - flow = uf.Flow('uf', retry.AlwaysRevert('rt', - requires=['x', 'y'], - provides=['a', 'b'])) + flow = uf.Flow( + 'uf', + retry.AlwaysRevert('rt', requires=['x', 'y'], provides=['a', 'b']), + ) self.assertEqual({'x', 'y'}, flow.requires) self.assertEqual({'a', 'b'}, flow.provides) @@ -313,28 +341,33 @@ class FlowDependenciesTest(test.TestCase): self.assertEqual({'x', 'y'}, flow.provides) def test_retry_in_graph_flow_requires_and_provides(self): - flow = gf.Flow('gf', retry.AlwaysRevert('rt', - requires=['x', 'y'], - provides=['a', 'b'])) + flow = gf.Flow( + 'gf', + retry.AlwaysRevert('rt', 
requires=['x', 'y'], provides=['a', 'b']), + ) self.assertEqual({'x', 'y'}, flow.requires) self.assertEqual({'a', 'b'}, flow.provides) def test_linear_flow_retry_and_task(self): - flow = lf.Flow('lf', retry.AlwaysRevert('rt', - requires=['x', 'y'], - provides=['a', 'b'])) - flow.add(utils.TaskMultiArgOneReturn(rebind=['a', 'x', 'c'], - provides=['z'])) + flow = lf.Flow( + 'lf', + retry.AlwaysRevert('rt', requires=['x', 'y'], provides=['a', 'b']), + ) + flow.add( + utils.TaskMultiArgOneReturn(rebind=['a', 'x', 'c'], provides=['z']) + ) self.assertEqual({'x', 'y', 'c'}, flow.requires) self.assertEqual({'a', 'b', 'z'}, flow.provides) def test_unordered_flow_retry_and_task(self): - flow = uf.Flow('uf', retry.AlwaysRevert('rt', - requires=['x', 'y'], - provides=['a', 'b'])) - flow.add(utils.TaskMultiArgOneReturn(rebind=['a', 'x', 'c'], - provides=['z'])) + flow = uf.Flow( + 'uf', + retry.AlwaysRevert('rt', requires=['x', 'y'], provides=['a', 'b']), + ) + flow.add( + utils.TaskMultiArgOneReturn(rebind=['a', 'x', 'c'], provides=['z']) + ) self.assertEqual({'x', 'y', 'c'}, flow.requires) self.assertEqual({'a', 'b', 'z'}, flow.provides) @@ -352,16 +385,20 @@ class FlowDependenciesTest(test.TestCase): def test_unordered_flow_retry_two_tasks_provide_same_value(self): flow = uf.Flow('uf', retry.AlwaysRevert('rt', provides=['y'])) - flow.add(utils.TaskOneReturn('t1', provides=['x']), - utils.TaskOneReturn('t2', provides=['x'])) + flow.add( + utils.TaskOneReturn('t1', provides=['x']), + utils.TaskOneReturn('t2', provides=['x']), + ) self.assertEqual({'x', 'y'}, flow.provides) def test_graph_flow_retry_and_task(self): - flow = gf.Flow('gf', retry.AlwaysRevert('rt', - requires=['x', 'y'], - provides=['a', 'b'])) - flow.add(utils.TaskMultiArgOneReturn(rebind=['a', 'x', 'c'], - provides=['z'])) + flow = gf.Flow( + 'gf', + retry.AlwaysRevert('rt', requires=['x', 'y'], provides=['a', 'b']), + ) + flow.add( + utils.TaskMultiArgOneReturn(rebind=['a', 'x', 'c'], provides=['z']) + ) 
self.assertEqual({'x', 'y', 'c'}, flow.requires) self.assertEqual({'a', 'b', 'z'}, flow.provides) diff --git a/taskflow/tests/unit/test_formatters.py b/taskflow/tests/unit/test_formatters.py index 0f1285536..0e540aaff 100644 --- a/taskflow/tests/unit/test_formatters.py +++ b/taskflow/tests/unit/test_formatters.py @@ -23,7 +23,6 @@ from taskflow.test import utils as test_utils class FormattersTest(test.TestCase): - @staticmethod def _broken_atom_matcher(node): return node.item.name == 'Broken' @@ -84,7 +83,8 @@ class FormattersTest(test.TestCase): e.storage.set_atom_intention("Broken", states.EXECUTE) hide_inputs_outputs_of = ['Broken', "Happy-1", "Happy-2"] f = formatters.FailureFormatter( - e, hide_inputs_outputs_of=hide_inputs_outputs_of) + e, hide_inputs_outputs_of=hide_inputs_outputs_of + ) (exc_info, details) = f.format(fail, self._broken_atom_matcher) self.assertEqual(3, len(exc_info)) self.assertFalse(mock_get_execute.called) @@ -111,19 +111,19 @@ class FormattersTest(test.TestCase): self.assertEqual(f._mask_keys([], FILTER_KEYS), []) self.assertEqual( f._mask_keys({'a': 1, 'b': 'hello'}, FILTER_KEYS), - {'a': 1, 'b': 'hello'} + {'a': 1, 'b': 'hello'}, ) self.assertEqual( f._mask_keys({'certificate': 'secret'}, FILTER_KEYS), - {'certificate': '***'} + {'certificate': '***'}, ) self.assertEqual( f._mask_keys([{'certificate': 'secret'}, {'a': 'b'}], FILTER_KEYS), - [{'certificate': '***'}, {'a': 'b'}] + [{'certificate': '***'}, {'a': 'b'}], ) self.assertEqual( f._mask_keys({'certificate': None}, FILTER_KEYS), - {'certificate': '***'} + {'certificate': '***'}, ) data = { 'listeners': [ @@ -143,10 +143,11 @@ class FormattersTest(test.TestCase): m = f._mask_keys(data, FILTER_KEYS) self.assertEqual( m['listeners'][0]['default_tls_container_data']['certificate'], - '***' + '***', + ) + self.assertEqual( + "some string", f._mask_keys("some string", FILTER_KEYS) ) - self.assertEqual("some string", - f._mask_keys("some string", FILTER_KEYS)) self.assertEqual(12345, 
f._mask_keys(12345, FILTER_KEYS)) self.assertIsNone(f._mask_keys(None, FILTER_KEYS)) self.assertIs(False, f._mask_keys(False, FILTER_KEYS)) diff --git a/taskflow/tests/unit/test_functor_task.py b/taskflow/tests/unit/test_functor_task.py index 430aacb86..d09440bb5 100644 --- a/taskflow/tests/unit/test_functor_task.py +++ b/taskflow/tests/unit/test_functor_task.py @@ -23,7 +23,6 @@ def add(a, b): class BunchOfFunctions: - def __init__(self, values): self.values = values @@ -44,7 +43,6 @@ multiply = lambda x, y: x * y class FunctorTaskTest(test.TestCase): - def test_simple(self): task = base.FunctorTask(add) self.assertEqual(__name__ + '.add', task.name) @@ -59,12 +57,10 @@ class FunctorTaskTest(test.TestCase): t = base.FunctorTask flow = linear_flow.Flow('test') - flow.add( - t(bof.run_one, revert=bof.revert_one), - t(bof.run_fail) + flow.add(t(bof.run_one, revert=bof.revert_one), t(bof.run_fail)) + self.assertRaisesRegex( + RuntimeError, '^Woot', taskflow.engines.run, flow ) - self.assertRaisesRegex(RuntimeError, '^Woot', - taskflow.engines.run, flow) self.assertEqual(['one', 'fail', 'revert one'], values) def test_lambda_functors(self): @@ -73,20 +69,14 @@ class FunctorTaskTest(test.TestCase): flow = linear_flow.Flow('test') flow.add( t(five, provides='five', name='five'), - t(multiply, provides='product', name='product') + t(multiply, provides='product', name='product'), ) - flow_store = { - 'x': 2, - 'y': 3 - } + flow_store = {'x': 2, 'y': 3} result = taskflow.engines.run(flow, store=flow_store) expected = flow_store.copy() - expected.update({ - 'five': 5, - 'product': 6 - }) + expected.update({'five': 5, 'product': 6}) self.assertEqual(expected, result) diff --git a/taskflow/tests/unit/test_listeners.py b/taskflow/tests/unit/test_listeners.py index ac11910af..38f527f0d 100644 --- a/taskflow/tests/unit/test_listeners.py +++ b/taskflow/tests/unit/test_listeners.py @@ -41,17 +41,20 @@ from taskflow.utils import persistence_utils ZOOKEEPER_AVAILABLE = 
test_utils.zookeeper_available( - impl_zookeeper.ZookeeperJobBoard.MIN_ZK_VERSION) + impl_zookeeper.ZookeeperJobBoard.MIN_ZK_VERSION +) -_LOG_LEVELS = frozenset([ - logging.CRITICAL, - logging.DEBUG, - logging.ERROR, - logging.INFO, - logging.NOTSET, - logging.WARNING, -]) +_LOG_LEVELS = frozenset( + [ + logging.CRITICAL, + logging.DEBUG, + logging.ERROR, + logging.INFO, + logging.NOTSET, + logging.WARNING, + ] +) class SleepyTask(task.Task): @@ -68,9 +71,9 @@ class SleepyTask(task.Task): class EngineMakerMixin: def _make_engine(self, flow, flow_detail=None, backend=None): - e = taskflow.engines.load(flow, - flow_detail=flow_detail, - backend=backend) + e = taskflow.engines.load( + flow, flow_detail=flow_detail, backend=backend + ) e.compile() e.prepare() return e @@ -125,8 +128,10 @@ class TestClaimListener(test.TestCase, EngineMakerMixin): altered = 0 for p in children: if p.endswith(".lock"): - self.client.set("/taskflow/jobs/" + p, misc.binary_encode( - jsonutils.dumps({'owner': new_owner}))) + self.client.set( + "/taskflow/jobs/" + p, + misc.binary_encode(jsonutils.dumps({'owner': new_owner})), + ) altered += 1 return altered @@ -134,9 +139,15 @@ class TestClaimListener(test.TestCase, EngineMakerMixin): job = self._post_claim_job('test') f = self._make_dummy_flow(10) e = self._make_engine(f) - self.assertRaises(ValueError, claims.CheckingClaimListener, - e, job, self.board, self.board.name, - on_job_loss=1) + self.assertRaises( + ValueError, + claims.CheckingClaimListener, + e, + job, + self.board, + self.board.name, + on_job_loss=1, + ) def test_claim_lost_suspended(self): job = self._post_claim_job('test') @@ -145,8 +156,7 @@ class TestClaimListener(test.TestCase, EngineMakerMixin): try_destroy = True ran_states = [] - with claims.CheckingClaimListener(e, job, - self.board, self.board.name): + with claims.CheckingClaimListener(e, job, self.board, self.board.name): for state in e.run_iter(): ran_states.append(state) if state == states.SCHEDULING and 
try_destroy: @@ -166,9 +176,9 @@ class TestClaimListener(test.TestCase, EngineMakerMixin): ran_states = [] try_destroy = True destroyed_at = -1 - with claims.CheckingClaimListener(e, job, self.board, - self.board.name, - on_job_loss=handler): + with claims.CheckingClaimListener( + e, job, self.board, self.board.name, on_job_loss=handler + ): for i, state in enumerate(e.run_iter()): ran_states.append(state) if state == states.SCHEDULING and try_destroy: @@ -191,8 +201,7 @@ class TestClaimListener(test.TestCase, EngineMakerMixin): change_owner = True ran_states = [] - with claims.CheckingClaimListener(e, job, - self.board, self.board.name): + with claims.CheckingClaimListener(e, job, self.board, self.board.name): for state in e.run_iter(): ran_states.append(state) if state == states.SCHEDULING and change_owner: @@ -252,13 +261,15 @@ class TestDurationListener(test.TestCase, EngineMakerMixin): (lb, fd) = persistence_utils.temporary_flow_detail(be) e = self._make_engine(flow, fd, be) duration_listener = timing.DurationListener(e) - with mock.patch.object(duration_listener._engine.storage, - 'update_atom_metadata') as mocked_uam: + with mock.patch.object( + duration_listener._engine.storage, 'update_atom_metadata' + ) as mocked_uam: mocked_uam.side_effect = exc.StorageFailure('Woot!') with duration_listener: e.run() - mocked_warning.assert_called_once_with(mock.ANY, mock.ANY, 'task', - 'test-1', exc_info=True) + mocked_warning.assert_called_once_with( + mock.ANY, mock.ANY, 'task', 'test-1', exc_info=True + ) class TestEventTimeListener(test.TestCase, EngineMakerMixin): @@ -291,15 +302,15 @@ class TestCapturingListeners(test.TestCase, EngineMakerMixin): e = self._make_engine(flow) with test_utils.CaptureListener(e, capture_task=False) as capturer: e.run() - expected = ['test.f RUNNING', - 'test.f SUCCESS'] + expected = ['test.f RUNNING', 'test.f SUCCESS'] self.assertEqual(expected, capturer.values) class TestLoggingListeners(test.TestCase, EngineMakerMixin): def 
_make_logger(self, level=logging.DEBUG): log = logging.getLogger( - reflection.get_callable_name(self._get_test_method())) + reflection.get_callable_name(self._get_test_method()) + ) log.propagate = False for handler in reversed(log.handlers): log.removeHandler(handler) @@ -328,7 +339,8 @@ class TestLoggingListeners(test.TestCase, EngineMakerMixin): e = self._make_engine(flow) log, handler = self._make_logger() listener = logging_listeners.LoggingListener( - e, log=log, level=logging.INFO) + e, log=log, level=logging.INFO + ) with listener: e.run() self.assertGreater(handler.counts[logging.INFO], 0) @@ -379,7 +391,8 @@ class TestLoggingListeners(test.TestCase, EngineMakerMixin): e = self._make_engine(flow) log, handler = self._make_logger() listener = logging_listeners.DynamicLoggingListener( - e, log=log, failure_level=logging.ERROR) + e, log=log, failure_level=logging.ERROR + ) with listener: self.assertRaises(RuntimeError, e.run) self.assertGreater(handler.counts[logging.ERROR], 0) diff --git a/taskflow/tests/unit/test_mapfunctor_task.py b/taskflow/tests/unit/test_mapfunctor_task.py index 495b1502a..a610639b3 100644 --- a/taskflow/tests/unit/test_mapfunctor_task.py +++ b/taskflow/tests/unit/test_mapfunctor_task.py @@ -21,11 +21,11 @@ from taskflow import test def double(x): return x * 2 + square = lambda x: x * x class MapFunctorTaskTest(test.TestCase): - def setUp(self): super().setUp() @@ -39,40 +39,46 @@ class MapFunctorTaskTest(test.TestCase): def test_double_array(self): expected = self.flow_store.copy() - expected.update({ - 'double_a': 2, - 'double_b': 4, - 'double_c': 6, - 'double_d': 8, - 'double_e': 10, - }) + expected.update( + { + 'double_a': 2, + 'double_b': 4, + 'double_c': 6, + 'double_d': 8, + 'double_e': 10, + } + ) requires = self.flow_store.keys() provides = ["double_%s" % k for k in requires] flow = linear_flow.Flow("double array flow") - flow.add(base.MapFunctorTask(double, requires=requires, - provides=provides)) + flow.add( + 
base.MapFunctorTask(double, requires=requires, provides=provides) + ) result = engines.run(flow, store=self.flow_store) self.assertEqual(expected, result) def test_square_array(self): expected = self.flow_store.copy() - expected.update({ - 'square_a': 1, - 'square_b': 4, - 'square_c': 9, - 'square_d': 16, - 'square_e': 25, - }) + expected.update( + { + 'square_a': 1, + 'square_b': 4, + 'square_c': 9, + 'square_d': 16, + 'square_e': 25, + } + ) requires = self.flow_store.keys() provides = ["square_%s" % k for k in requires] flow = linear_flow.Flow("square array flow") - flow.add(base.MapFunctorTask(square, requires=requires, - provides=provides)) + flow.add( + base.MapFunctorTask(square, requires=requires, provides=provides) + ) result = engines.run(flow, store=self.flow_store) self.assertEqual(expected, result) diff --git a/taskflow/tests/unit/test_notifier.py b/taskflow/tests/unit/test_notifier.py index b266461a8..5ea7a7ae3 100644 --- a/taskflow/tests/unit/test_notifier.py +++ b/taskflow/tests/unit/test_notifier.py @@ -21,7 +21,6 @@ from taskflow.types import notifier as nt class NotifierTest(test.TestCase): - def test_notify_called(self): call_collector = [] @@ -87,37 +86,46 @@ class NotifierTest(test.TestCase): pass notifier = nt.Notifier() - self.assertRaises(KeyError, notifier.register, - nt.Notifier.ANY, call_me, - kwargs={'details': 5}) + self.assertRaises( + KeyError, + notifier.register, + nt.Notifier.ANY, + call_me, + kwargs={'details': 5}, + ) def test_not_callable(self): notifier = nt.Notifier() - self.assertRaises(ValueError, notifier.register, - nt.Notifier.ANY, 2) + self.assertRaises(ValueError, notifier.register, nt.Notifier.ANY, 2) def test_restricted_notifier(self): notifier = nt.RestrictedNotifier(['a', 'b']) - self.assertRaises(ValueError, notifier.register, - 'c', lambda *args, **kargs: None) + self.assertRaises( + ValueError, notifier.register, 'c', lambda *args, **kargs: None + ) notifier.register('b', lambda *args, **kargs: None) 
self.assertEqual(1, len(notifier)) def test_restricted_notifier_any(self): notifier = nt.RestrictedNotifier(['a', 'b']) - self.assertRaises(ValueError, notifier.register, - 'c', lambda *args, **kargs: None) + self.assertRaises( + ValueError, notifier.register, 'c', lambda *args, **kargs: None + ) notifier.register('b', lambda *args, **kargs: None) self.assertEqual(1, len(notifier)) - notifier.register(nt.RestrictedNotifier.ANY, - lambda *args, **kargs: None) + notifier.register( + nt.RestrictedNotifier.ANY, lambda *args, **kargs: None + ) self.assertEqual(2, len(notifier)) def test_restricted_notifier_no_any(self): notifier = nt.RestrictedNotifier(['a', 'b'], allow_any=False) - self.assertRaises(ValueError, notifier.register, - nt.RestrictedNotifier.ANY, - lambda *args, **kargs: None) + self.assertRaises( + ValueError, + notifier.register, + nt.RestrictedNotifier.ANY, + lambda *args, **kargs: None, + ) notifier.register('b', lambda *args, **kargs: None) self.assertEqual(1, len(notifier)) @@ -131,13 +139,15 @@ class NotifierTest(test.TestCase): call_me_on_success = functools.partial(call_me_on, states.SUCCESS) notifier.register(states.SUCCESS, call_me_on_success) - self.assertTrue(notifier.is_registered(states.SUCCESS, - call_me_on_success)) + self.assertTrue( + notifier.is_registered(states.SUCCESS, call_me_on_success) + ) call_me_on_any = functools.partial(call_me_on, nt.Notifier.ANY) notifier.register(nt.Notifier.ANY, call_me_on_any) - self.assertTrue(notifier.is_registered(nt.Notifier.ANY, - call_me_on_any)) + self.assertTrue( + notifier.is_registered(nt.Notifier.ANY, call_me_on_any) + ) self.assertEqual(2, len(notifier)) notifier.notify(states.SUCCESS, {}) @@ -162,11 +172,15 @@ class NotifierTest(test.TestCase): notifier = nt.Notifier() call_me_on_success = functools.partial(call_me_on, states.SUCCESS) - notifier.register(states.SUCCESS, call_me_on_success, - details_filter=when_red) + notifier.register( + states.SUCCESS, call_me_on_success, 
details_filter=when_red + ) self.assertEqual(1, len(notifier)) - self.assertTrue(notifier.is_registered( - states.SUCCESS, call_me_on_success, details_filter=when_red)) + self.assertTrue( + notifier.is_registered( + states.SUCCESS, call_me_on_success, details_filter=when_red + ) + ) notifier.notify(states.SUCCESS, {}) self.assertEqual(0, len(call_counts[states.SUCCESS])) @@ -190,15 +204,23 @@ class NotifierTest(test.TestCase): notifier = nt.Notifier() call_me_on_success = functools.partial(call_me_on, states.SUCCESS) - notifier.register(states.SUCCESS, call_me_on_success, - details_filter=when_red) - notifier.register(states.SUCCESS, call_me_on_success, - details_filter=when_blue) + notifier.register( + states.SUCCESS, call_me_on_success, details_filter=when_red + ) + notifier.register( + states.SUCCESS, call_me_on_success, details_filter=when_blue + ) self.assertEqual(2, len(notifier)) - self.assertTrue(notifier.is_registered( - states.SUCCESS, call_me_on_success, details_filter=when_blue)) - self.assertTrue(notifier.is_registered( - states.SUCCESS, call_me_on_success, details_filter=when_red)) + self.assertTrue( + notifier.is_registered( + states.SUCCESS, call_me_on_success, details_filter=when_blue + ) + ) + self.assertTrue( + notifier.is_registered( + states.SUCCESS, call_me_on_success, details_filter=when_red + ) + ) notifier.notify(states.SUCCESS, {}) self.assertEqual(0, len(call_counts[states.SUCCESS])) diff --git a/taskflow/tests/unit/test_progress.py b/taskflow/tests/unit/test_progress.py index bc2043e5e..611b8dcdb 100644 --- a/taskflow/tests/unit/test_progress.py +++ b/taskflow/tests/unit/test_progress.py @@ -47,9 +47,9 @@ class ProgressTaskWithDetails(task.Task): class TestProgress(test.TestCase): def _make_engine(self, flow, flow_detail=None, backend=None): - e = taskflow.engines.load(flow, - flow_detail=flow_detail, - backend=backend) + e = taskflow.engines.load( + flow, flow_detail=flow_detail, backend=backend + ) e.compile() e.prepare() return e @@ 
-116,10 +116,9 @@ class TestProgress(test.TestCase): self.assertEqual(1.0, end_progress) end_details = e.storage.get_task_progress_details("test") self.assertEqual(0.5, end_details.get('at_progress')) - self.assertEqual({ - 'test': 'test data', - 'foo': 'bar' - }, end_details.get('details')) + self.assertEqual( + {'test': 'test data', 'foo': 'bar'}, end_details.get('details') + ) def test_dual_storage_progress(self): fired_events = [] diff --git a/taskflow/tests/unit/test_reducefunctor_task.py b/taskflow/tests/unit/test_reducefunctor_task.py index 47cbac5bd..208b2ede3 100644 --- a/taskflow/tests/unit/test_reducefunctor_task.py +++ b/taskflow/tests/unit/test_reducefunctor_task.py @@ -21,11 +21,11 @@ from taskflow import test def sum(x, y): return x + y + multiply = lambda x, y: x * y class ReduceFunctorTaskTest(test.TestCase): - def setUp(self): super().setUp() @@ -39,32 +39,32 @@ class ReduceFunctorTaskTest(test.TestCase): def test_sum_array(self): expected = self.flow_store.copy() - expected.update({ - 'sum': 15 - }) + expected.update({'sum': 15}) requires = self.flow_store.keys() provides = 'sum' flow = linear_flow.Flow("sum array flow") - flow.add(base.ReduceFunctorTask(sum, requires=requires, - provides=provides)) + flow.add( + base.ReduceFunctorTask(sum, requires=requires, provides=provides) + ) result = engines.run(flow, store=self.flow_store) self.assertEqual(expected, result) def test_multiply_array(self): expected = self.flow_store.copy() - expected.update({ - 'product': 120 - }) + expected.update({'product': 120}) requires = self.flow_store.keys() provides = 'product' flow = linear_flow.Flow("square array flow") - flow.add(base.ReduceFunctorTask(multiply, requires=requires, - provides=provides)) + flow.add( + base.ReduceFunctorTask( + multiply, requires=requires, provides=provides + ) + ) result = engines.run(flow, store=self.flow_store) self.assertEqual(expected, result) diff --git a/taskflow/tests/unit/test_retries.py 
b/taskflow/tests/unit/test_retries.py index c680ed357..829788df7 100644 --- a/taskflow/tests/unit/test_retries.py +++ b/taskflow/tests/unit/test_retries.py @@ -30,7 +30,6 @@ from taskflow.utils import eventlet_utils as eu class FailingRetry(retry.Retry): - def execute(self, **kwargs): raise ValueError('OMG I FAILED') @@ -47,7 +46,6 @@ class NastyFailingRetry(FailingRetry): class RetryTest(utils.EngineTestBase): - def test_run_empty_linear_flow(self): flow = lf.Flow('flow-1', utils.OneReturnRetry(provides='x')) engine = self._make_engine(flow) @@ -68,94 +66,101 @@ class RetryTest(utils.EngineTestBase): def test_states_retry_success_linear_flow(self): flow = lf.Flow('flow-1', retry.Times(4, 'r1', provides='x')).add( - utils.ProgressingTask("task1"), - utils.ConditionalTask("task2") + utils.ProgressingTask("task1"), utils.ConditionalTask("task2") ) engine = self._make_engine(flow) engine.storage.inject({'y': 2}) with utils.CaptureListener(engine) as capturer: engine.run() self.assertEqual({'y': 2, 'x': 2}, engine.storage.fetch_all()) - expected = ['flow-1.f RUNNING', - 'r1.r RUNNING', 'r1.r SUCCESS(1)', - 'task1.t RUNNING', 'task1.t SUCCESS(5)', - 'task2.t RUNNING', - 'task2.t FAILURE(Failure: RuntimeError: Woot!)', - 'task2.t REVERTING', 'task2.t REVERTED(None)', - 'task1.t REVERTING', 'task1.t REVERTED(None)', - 'r1.r RETRYING', - 'task1.t PENDING', - 'task2.t PENDING', - 'r1.r RUNNING', - 'r1.r SUCCESS(2)', - 'task1.t RUNNING', - 'task1.t SUCCESS(5)', - 'task2.t RUNNING', - 'task2.t SUCCESS(None)', - 'flow-1.f SUCCESS'] + expected = [ + 'flow-1.f RUNNING', + 'r1.r RUNNING', + 'r1.r SUCCESS(1)', + 'task1.t RUNNING', + 'task1.t SUCCESS(5)', + 'task2.t RUNNING', + 'task2.t FAILURE(Failure: RuntimeError: Woot!)', + 'task2.t REVERTING', + 'task2.t REVERTED(None)', + 'task1.t REVERTING', + 'task1.t REVERTED(None)', + 'r1.r RETRYING', + 'task1.t PENDING', + 'task2.t PENDING', + 'r1.r RUNNING', + 'r1.r SUCCESS(2)', + 'task1.t RUNNING', + 'task1.t SUCCESS(5)', + 'task2.t 
RUNNING', + 'task2.t SUCCESS(None)', + 'flow-1.f SUCCESS', + ] self.assertEqual(expected, capturer.values) def test_states_retry_reverted_linear_flow(self): flow = lf.Flow('flow-1', retry.Times(2, 'r1', provides='x')).add( - utils.ProgressingTask("task1"), - utils.ConditionalTask("task2") + utils.ProgressingTask("task1"), utils.ConditionalTask("task2") ) engine = self._make_engine(flow) engine.storage.inject({'y': 4}) with utils.CaptureListener(engine) as capturer: self.assertRaisesRegex(RuntimeError, '^Woot', engine.run) self.assertEqual({'y': 4}, engine.storage.fetch_all()) - expected = ['flow-1.f RUNNING', - 'r1.r RUNNING', - 'r1.r SUCCESS(1)', - 'task1.t RUNNING', - 'task1.t SUCCESS(5)', - 'task2.t RUNNING', - 'task2.t FAILURE(Failure: RuntimeError: Woot!)', - 'task2.t REVERTING', - 'task2.t REVERTED(None)', - 'task1.t REVERTING', - 'task1.t REVERTED(None)', - 'r1.r RETRYING', - 'task1.t PENDING', - 'task2.t PENDING', - 'r1.r RUNNING', - 'r1.r SUCCESS(2)', - 'task1.t RUNNING', - 'task1.t SUCCESS(5)', - 'task2.t RUNNING', - 'task2.t FAILURE(Failure: RuntimeError: Woot!)', - 'task2.t REVERTING', - 'task2.t REVERTED(None)', - 'task1.t REVERTING', - 'task1.t REVERTED(None)', - 'r1.r REVERTING', - 'r1.r REVERTED(None)', - 'flow-1.f REVERTED'] + expected = [ + 'flow-1.f RUNNING', + 'r1.r RUNNING', + 'r1.r SUCCESS(1)', + 'task1.t RUNNING', + 'task1.t SUCCESS(5)', + 'task2.t RUNNING', + 'task2.t FAILURE(Failure: RuntimeError: Woot!)', + 'task2.t REVERTING', + 'task2.t REVERTED(None)', + 'task1.t REVERTING', + 'task1.t REVERTED(None)', + 'r1.r RETRYING', + 'task1.t PENDING', + 'task2.t PENDING', + 'r1.r RUNNING', + 'r1.r SUCCESS(2)', + 'task1.t RUNNING', + 'task1.t SUCCESS(5)', + 'task2.t RUNNING', + 'task2.t FAILURE(Failure: RuntimeError: Woot!)', + 'task2.t REVERTING', + 'task2.t REVERTED(None)', + 'task1.t REVERTING', + 'task1.t REVERTED(None)', + 'r1.r REVERTING', + 'r1.r REVERTED(None)', + 'flow-1.f REVERTED', + ] self.assertEqual(expected, capturer.values) def 
test_states_retry_failure_linear_flow(self): flow = lf.Flow('flow-1', retry.Times(2, 'r1', provides='x')).add( - utils.NastyTask("task1"), - utils.ConditionalTask("task2") + utils.NastyTask("task1"), utils.ConditionalTask("task2") ) engine = self._make_engine(flow) engine.storage.inject({'y': 4}) with utils.CaptureListener(engine) as capturer: self.assertRaisesRegex(RuntimeError, '^Gotcha', engine.run) self.assertEqual({'y': 4, 'x': 1}, engine.storage.fetch_all()) - expected = ['flow-1.f RUNNING', - 'r1.r RUNNING', - 'r1.r SUCCESS(1)', - 'task1.t RUNNING', - 'task1.t SUCCESS(None)', - 'task2.t RUNNING', - 'task2.t FAILURE(Failure: RuntimeError: Woot!)', - 'task2.t REVERTING', - 'task2.t REVERTED(None)', - 'task1.t REVERTING', - 'task1.t REVERT_FAILURE(Failure: RuntimeError: Gotcha!)', - 'flow-1.f FAILURE'] + expected = [ + 'flow-1.f RUNNING', + 'r1.r RUNNING', + 'r1.r SUCCESS(1)', + 'task1.t RUNNING', + 'task1.t SUCCESS(None)', + 'task2.t RUNNING', + 'task2.t FAILURE(Failure: RuntimeError: Woot!)', + 'task2.t REVERTING', + 'task2.t REVERTED(None)', + 'task1.t REVERTING', + 'task1.t REVERT_FAILURE(Failure: RuntimeError: Gotcha!)', + 'flow-1.f FAILURE', + ] self.assertEqual(expected, capturer.values) def test_states_retry_failure_nested_flow_fails(self): @@ -163,42 +168,44 @@ class RetryTest(utils.EngineTestBase): utils.TaskNoRequiresNoReturns("task1"), lf.Flow('flow-2', retry.Times(3, 'r2', provides='x')).add( utils.TaskNoRequiresNoReturns("task2"), - utils.ConditionalTask("task3") + utils.ConditionalTask("task3"), ), - utils.TaskNoRequiresNoReturns("task4") + utils.TaskNoRequiresNoReturns("task4"), ) engine = self._make_engine(flow) engine.storage.inject({'y': 2}) with utils.CaptureListener(engine) as capturer: engine.run() self.assertEqual({'y': 2, 'x': 2}, engine.storage.fetch_all()) - expected = ['flow-1.f RUNNING', - 'r1.r RUNNING', - 'r1.r SUCCESS(None)', - 'task1.t RUNNING', - 'task1.t SUCCESS(None)', - 'r2.r RUNNING', - 'r2.r SUCCESS(1)', - 'task2.t 
RUNNING', - 'task2.t SUCCESS(None)', - 'task3.t RUNNING', - 'task3.t FAILURE(Failure: RuntimeError: Woot!)', - 'task3.t REVERTING', - 'task3.t REVERTED(None)', - 'task2.t REVERTING', - 'task2.t REVERTED(None)', - 'r2.r RETRYING', - 'task2.t PENDING', - 'task3.t PENDING', - 'r2.r RUNNING', - 'r2.r SUCCESS(2)', - 'task2.t RUNNING', - 'task2.t SUCCESS(None)', - 'task3.t RUNNING', - 'task3.t SUCCESS(None)', - 'task4.t RUNNING', - 'task4.t SUCCESS(None)', - 'flow-1.f SUCCESS'] + expected = [ + 'flow-1.f RUNNING', + 'r1.r RUNNING', + 'r1.r SUCCESS(None)', + 'task1.t RUNNING', + 'task1.t SUCCESS(None)', + 'r2.r RUNNING', + 'r2.r SUCCESS(1)', + 'task2.t RUNNING', + 'task2.t SUCCESS(None)', + 'task3.t RUNNING', + 'task3.t FAILURE(Failure: RuntimeError: Woot!)', + 'task3.t REVERTING', + 'task3.t REVERTED(None)', + 'task2.t REVERTING', + 'task2.t REVERTED(None)', + 'r2.r RETRYING', + 'task2.t PENDING', + 'task3.t PENDING', + 'r2.r RUNNING', + 'r2.r SUCCESS(2)', + 'task2.t RUNNING', + 'task2.t SUCCESS(None)', + 'task3.t RUNNING', + 'task3.t SUCCESS(None)', + 'task4.t RUNNING', + 'task4.t SUCCESS(None)', + 'flow-1.f SUCCESS', + ] self.assertEqual(expected, capturer.values) def test_new_revert_vs_old(self): @@ -206,9 +213,9 @@ class RetryTest(utils.EngineTestBase): utils.TaskNoRequiresNoReturns("task1"), lf.Flow('flow-2', retry.Times(1, 'r1', provides='x')).add( utils.TaskNoRequiresNoReturns("task2"), - utils.ConditionalTask("task3") + utils.ConditionalTask("task3"), ), - utils.TaskNoRequiresNoReturns("task4") + utils.TaskNoRequiresNoReturns("task4"), ) engine = self._make_engine(flow) engine.storage.inject({'y': 2}) @@ -218,22 +225,24 @@ class RetryTest(utils.EngineTestBase): except Exception: pass - expected = ['flow-1.f RUNNING', - 'task1.t RUNNING', - 'task1.t SUCCESS(None)', - 'r1.r RUNNING', - 'r1.r SUCCESS(1)', - 'task2.t RUNNING', - 'task2.t SUCCESS(None)', - 'task3.t RUNNING', - 'task3.t FAILURE(Failure: RuntimeError: Woot!)', - 'task3.t REVERTING', - 'task3.t 
REVERTED(None)', - 'task2.t REVERTING', - 'task2.t REVERTED(None)', - 'r1.r REVERTING', - 'r1.r REVERTED(None)', - 'flow-1.f REVERTED'] + expected = [ + 'flow-1.f RUNNING', + 'task1.t RUNNING', + 'task1.t SUCCESS(None)', + 'r1.r RUNNING', + 'r1.r SUCCESS(1)', + 'task2.t RUNNING', + 'task2.t SUCCESS(None)', + 'task3.t RUNNING', + 'task3.t FAILURE(Failure: RuntimeError: Woot!)', + 'task3.t REVERTING', + 'task3.t REVERTED(None)', + 'task2.t REVERTING', + 'task2.t REVERTED(None)', + 'r1.r REVERTING', + 'r1.r REVERTED(None)', + 'flow-1.f REVERTED', + ] self.assertEqual(expected, capturer.values) engine = self._make_engine(flow, defer_reverts=True) @@ -244,24 +253,26 @@ class RetryTest(utils.EngineTestBase): except Exception: pass - expected = ['flow-1.f RUNNING', - 'task1.t RUNNING', - 'task1.t SUCCESS(None)', - 'r1.r RUNNING', - 'r1.r SUCCESS(1)', - 'task2.t RUNNING', - 'task2.t SUCCESS(None)', - 'task3.t RUNNING', - 'task3.t FAILURE(Failure: RuntimeError: Woot!)', - 'task3.t REVERTING', - 'task3.t REVERTED(None)', - 'task2.t REVERTING', - 'task2.t REVERTED(None)', - 'r1.r REVERTING', - 'r1.r REVERTED(None)', - 'task1.t REVERTING', - 'task1.t REVERTED(None)', - 'flow-1.f REVERTED'] + expected = [ + 'flow-1.f RUNNING', + 'task1.t RUNNING', + 'task1.t SUCCESS(None)', + 'r1.r RUNNING', + 'r1.r SUCCESS(1)', + 'task2.t RUNNING', + 'task2.t SUCCESS(None)', + 'task3.t RUNNING', + 'task3.t FAILURE(Failure: RuntimeError: Woot!)', + 'task3.t REVERTING', + 'task3.t REVERTED(None)', + 'task2.t REVERTING', + 'task2.t REVERTED(None)', + 'r1.r REVERTING', + 'r1.r REVERTED(None)', + 'task1.t REVERTING', + 'task1.t REVERTED(None)', + 'flow-1.f REVERTED', + ] self.assertEqual(expected, capturer.values) def test_states_retry_failure_parent_flow_fails(self): @@ -269,92 +280,95 @@ class RetryTest(utils.EngineTestBase): utils.TaskNoRequiresNoReturns("task1"), lf.Flow('flow-2', retry.Times(3, 'r2', provides='x2')).add( utils.TaskNoRequiresNoReturns("task2"), - 
utils.TaskNoRequiresNoReturns("task3") + utils.TaskNoRequiresNoReturns("task3"), ), - utils.ConditionalTask("task4", rebind={'x': 'x1'}) + utils.ConditionalTask("task4", rebind={'x': 'x1'}), ) engine = self._make_engine(flow) engine.storage.inject({'y': 2}) with utils.CaptureListener(engine) as capturer: engine.run() - self.assertEqual({'y': 2, 'x1': 2, - 'x2': 1}, - engine.storage.fetch_all()) - expected = ['flow-1.f RUNNING', - 'r1.r RUNNING', - 'r1.r SUCCESS(1)', - 'task1.t RUNNING', - 'task1.t SUCCESS(None)', - 'r2.r RUNNING', - 'r2.r SUCCESS(1)', - 'task2.t RUNNING', - 'task2.t SUCCESS(None)', - 'task3.t RUNNING', - 'task3.t SUCCESS(None)', - 'task4.t RUNNING', - 'task4.t FAILURE(Failure: RuntimeError: Woot!)', - 'task4.t REVERTING', - 'task4.t REVERTED(None)', - 'task3.t REVERTING', - 'task3.t REVERTED(None)', - 'task2.t REVERTING', - 'task2.t REVERTED(None)', - 'r2.r REVERTING', - 'r2.r REVERTED(None)', - 'task1.t REVERTING', - 'task1.t REVERTED(None)', - 'r1.r RETRYING', - 'task1.t PENDING', - 'r2.r PENDING', - 'task2.t PENDING', - 'task3.t PENDING', - 'task4.t PENDING', - 'r1.r RUNNING', - 'r1.r SUCCESS(2)', - 'task1.t RUNNING', - 'task1.t SUCCESS(None)', - 'r2.r RUNNING', - 'r2.r SUCCESS(1)', - 'task2.t RUNNING', - 'task2.t SUCCESS(None)', - 'task3.t RUNNING', - 'task3.t SUCCESS(None)', - 'task4.t RUNNING', - 'task4.t SUCCESS(None)', - 'flow-1.f SUCCESS'] + self.assertEqual( + {'y': 2, 'x1': 2, 'x2': 1}, engine.storage.fetch_all() + ) + expected = [ + 'flow-1.f RUNNING', + 'r1.r RUNNING', + 'r1.r SUCCESS(1)', + 'task1.t RUNNING', + 'task1.t SUCCESS(None)', + 'r2.r RUNNING', + 'r2.r SUCCESS(1)', + 'task2.t RUNNING', + 'task2.t SUCCESS(None)', + 'task3.t RUNNING', + 'task3.t SUCCESS(None)', + 'task4.t RUNNING', + 'task4.t FAILURE(Failure: RuntimeError: Woot!)', + 'task4.t REVERTING', + 'task4.t REVERTED(None)', + 'task3.t REVERTING', + 'task3.t REVERTED(None)', + 'task2.t REVERTING', + 'task2.t REVERTED(None)', + 'r2.r REVERTING', + 'r2.r REVERTED(None)', + 
'task1.t REVERTING', + 'task1.t REVERTED(None)', + 'r1.r RETRYING', + 'task1.t PENDING', + 'r2.r PENDING', + 'task2.t PENDING', + 'task3.t PENDING', + 'task4.t PENDING', + 'r1.r RUNNING', + 'r1.r SUCCESS(2)', + 'task1.t RUNNING', + 'task1.t SUCCESS(None)', + 'r2.r RUNNING', + 'r2.r SUCCESS(1)', + 'task2.t RUNNING', + 'task2.t SUCCESS(None)', + 'task3.t RUNNING', + 'task3.t SUCCESS(None)', + 'task4.t RUNNING', + 'task4.t SUCCESS(None)', + 'flow-1.f SUCCESS', + ] self.assertEqual(expected, capturer.values) def test_unordered_flow_task_fails_parallel_tasks_should_be_reverted(self): flow = uf.Flow('flow-1', retry.Times(3, 'r', provides='x')).add( - utils.ProgressingTask("task1"), - utils.ConditionalTask("task2") + utils.ProgressingTask("task1"), utils.ConditionalTask("task2") ) engine = self._make_engine(flow) engine.storage.inject({'y': 2}) with utils.CaptureListener(engine) as capturer: engine.run() self.assertEqual({'y': 2, 'x': 2}, engine.storage.fetch_all()) - expected = ['flow-1.f RUNNING', - 'r.r RUNNING', - 'r.r SUCCESS(1)', - 'task1.t RUNNING', - 'task2.t RUNNING', - 'task1.t SUCCESS(5)', - 'task2.t FAILURE(Failure: RuntimeError: Woot!)', - 'task2.t REVERTING', - 'task1.t REVERTING', - 'task2.t REVERTED(None)', - 'task1.t REVERTED(None)', - 'r.r RETRYING', - 'task1.t PENDING', - 'task2.t PENDING', - 'r.r RUNNING', - 'r.r SUCCESS(2)', - 'task1.t RUNNING', - 'task2.t RUNNING', - 'task1.t SUCCESS(5)', - 'task2.t SUCCESS(None)', - 'flow-1.f SUCCESS'] + expected = [ + 'flow-1.f RUNNING', + 'r.r RUNNING', + 'r.r SUCCESS(1)', + 'task1.t RUNNING', + 'task2.t RUNNING', + 'task1.t SUCCESS(5)', + 'task2.t FAILURE(Failure: RuntimeError: Woot!)', + 'task2.t REVERTING', + 'task1.t REVERTING', + 'task2.t REVERTED(None)', + 'task1.t REVERTED(None)', + 'r.r RETRYING', + 'task1.t PENDING', + 'task2.t PENDING', + 'r.r RUNNING', + 'r.r SUCCESS(2)', + 'task1.t RUNNING', + 'task2.t RUNNING', + 'task1.t SUCCESS(5)', + 'task2.t SUCCESS(None)', + 'flow-1.f SUCCESS', + ] 
self.assertCountEqual(capturer.values, expected) def test_nested_flow_reverts_parent_retries(self): @@ -362,41 +376,43 @@ class RetryTest(utils.EngineTestBase): retry2 = retry.Times(0, 'r2', provides='x2') flow = lf.Flow('flow-1', retry1).add( utils.ProgressingTask("task1"), - lf.Flow('flow-2', retry2).add(utils.ConditionalTask("task2")) + lf.Flow('flow-2', retry2).add(utils.ConditionalTask("task2")), ) engine = self._make_engine(flow) engine.storage.inject({'y': 2}) with utils.CaptureListener(engine) as capturer: engine.run() self.assertEqual({'y': 2, 'x': 2, 'x2': 1}, engine.storage.fetch_all()) - expected = ['flow-1.f RUNNING', - 'r1.r RUNNING', - 'r1.r SUCCESS(1)', - 'task1.t RUNNING', - 'task1.t SUCCESS(5)', - 'r2.r RUNNING', - 'r2.r SUCCESS(1)', - 'task2.t RUNNING', - 'task2.t FAILURE(Failure: RuntimeError: Woot!)', - 'task2.t REVERTING', - 'task2.t REVERTED(None)', - 'r2.r REVERTING', - 'r2.r REVERTED(None)', - 'task1.t REVERTING', - 'task1.t REVERTED(None)', - 'r1.r RETRYING', - 'task1.t PENDING', - 'r2.r PENDING', - 'task2.t PENDING', - 'r1.r RUNNING', - 'r1.r SUCCESS(2)', - 'task1.t RUNNING', - 'task1.t SUCCESS(5)', - 'r2.r RUNNING', - 'r2.r SUCCESS(1)', - 'task2.t RUNNING', - 'task2.t SUCCESS(None)', - 'flow-1.f SUCCESS'] + expected = [ + 'flow-1.f RUNNING', + 'r1.r RUNNING', + 'r1.r SUCCESS(1)', + 'task1.t RUNNING', + 'task1.t SUCCESS(5)', + 'r2.r RUNNING', + 'r2.r SUCCESS(1)', + 'task2.t RUNNING', + 'task2.t FAILURE(Failure: RuntimeError: Woot!)', + 'task2.t REVERTING', + 'task2.t REVERTED(None)', + 'r2.r REVERTING', + 'r2.r REVERTED(None)', + 'task1.t REVERTING', + 'task1.t REVERTED(None)', + 'r1.r RETRYING', + 'task1.t PENDING', + 'r2.r PENDING', + 'task2.t PENDING', + 'r1.r RUNNING', + 'r1.r SUCCESS(2)', + 'task1.t RUNNING', + 'task1.t SUCCESS(5)', + 'r2.r RUNNING', + 'r2.r SUCCESS(1)', + 'task2.t RUNNING', + 'task2.t SUCCESS(None)', + 'flow-1.f SUCCESS', + ] self.assertEqual(expected, capturer.values) def test_nested_flow_with_retry_revert(self): 
@@ -404,7 +420,8 @@ class RetryTest(utils.EngineTestBase): flow = lf.Flow('flow-1').add( utils.ProgressingTask("task1"), lf.Flow('flow-2', retry1).add( - utils.ConditionalTask("task2", inject={'x': 1})) + utils.ConditionalTask("task2", inject={'x': 1}) + ), ) engine = self._make_engine(flow) engine.storage.inject({'y': 2}) @@ -414,18 +431,20 @@ class RetryTest(utils.EngineTestBase): except Exception: pass self.assertEqual({'y': 2}, engine.storage.fetch_all()) - expected = ['flow-1.f RUNNING', - 'task1.t RUNNING', - 'task1.t SUCCESS(5)', - 'r1.r RUNNING', - 'r1.r SUCCESS(1)', - 'task2.t RUNNING', - 'task2.t FAILURE(Failure: RuntimeError: Woot!)', - 'task2.t REVERTING', - 'task2.t REVERTED(None)', - 'r1.r REVERTING', - 'r1.r REVERTED(None)', - 'flow-1.f REVERTED'] + expected = [ + 'flow-1.f RUNNING', + 'task1.t RUNNING', + 'task1.t SUCCESS(5)', + 'r1.r RUNNING', + 'r1.r SUCCESS(1)', + 'task2.t RUNNING', + 'task2.t FAILURE(Failure: RuntimeError: Woot!)', + 'task2.t REVERTING', + 'task2.t REVERTED(None)', + 'r1.r REVERTING', + 'r1.r REVERTED(None)', + 'flow-1.f REVERTED', + ] self.assertEqual(expected, capturer.values) def test_nested_flow_with_retry_revert_all(self): @@ -433,7 +452,8 @@ class RetryTest(utils.EngineTestBase): flow = lf.Flow('flow-1').add( utils.ProgressingTask("task1"), lf.Flow('flow-2', retry1).add( - utils.ConditionalTask("task2", inject={'x': 1})) + utils.ConditionalTask("task2", inject={'x': 1}) + ), ) engine = self._make_engine(flow) engine.storage.inject({'y': 2}) @@ -443,56 +463,62 @@ class RetryTest(utils.EngineTestBase): except Exception: pass self.assertEqual({'y': 2}, engine.storage.fetch_all()) - expected = ['flow-1.f RUNNING', - 'task1.t RUNNING', - 'task1.t SUCCESS(5)', - 'r1.r RUNNING', - 'r1.r SUCCESS(1)', - 'task2.t RUNNING', - 'task2.t FAILURE(Failure: RuntimeError: Woot!)', - 'task2.t REVERTING', - 'task2.t REVERTED(None)', - 'r1.r REVERTING', - 'r1.r REVERTED(None)', - 'task1.t REVERTING', - 'task1.t REVERTED(None)', - 'flow-1.f 
REVERTED'] + expected = [ + 'flow-1.f RUNNING', + 'task1.t RUNNING', + 'task1.t SUCCESS(5)', + 'r1.r RUNNING', + 'r1.r SUCCESS(1)', + 'task2.t RUNNING', + 'task2.t FAILURE(Failure: RuntimeError: Woot!)', + 'task2.t REVERTING', + 'task2.t REVERTED(None)', + 'r1.r REVERTING', + 'r1.r REVERTED(None)', + 'task1.t REVERTING', + 'task1.t REVERTED(None)', + 'flow-1.f REVERTED', + ] self.assertEqual(expected, capturer.values) def test_revert_all_retry(self): flow = lf.Flow('flow-1', retry.Times(3, 'r1', provides='x')).add( utils.ProgressingTask("task1"), lf.Flow('flow-2', retry.AlwaysRevertAll('r2')).add( - utils.ConditionalTask("task2")) + utils.ConditionalTask("task2") + ), ) engine = self._make_engine(flow) engine.storage.inject({'y': 2}) with utils.CaptureListener(engine) as capturer: self.assertRaisesRegex(RuntimeError, '^Woot', engine.run) self.assertEqual({'y': 2}, engine.storage.fetch_all()) - expected = ['flow-1.f RUNNING', - 'r1.r RUNNING', - 'r1.r SUCCESS(1)', - 'task1.t RUNNING', - 'task1.t SUCCESS(5)', - 'r2.r RUNNING', - 'r2.r SUCCESS(None)', - 'task2.t RUNNING', - 'task2.t FAILURE(Failure: RuntimeError: Woot!)', - 'task2.t REVERTING', - 'task2.t REVERTED(None)', - 'r2.r REVERTING', - 'r2.r REVERTED(None)', - 'task1.t REVERTING', - 'task1.t REVERTED(None)', - 'r1.r REVERTING', - 'r1.r REVERTED(None)', - 'flow-1.f REVERTED'] + expected = [ + 'flow-1.f RUNNING', + 'r1.r RUNNING', + 'r1.r SUCCESS(1)', + 'task1.t RUNNING', + 'task1.t SUCCESS(5)', + 'r2.r RUNNING', + 'r2.r SUCCESS(None)', + 'task2.t RUNNING', + 'task2.t FAILURE(Failure: RuntimeError: Woot!)', + 'task2.t REVERTING', + 'task2.t REVERTED(None)', + 'r2.r REVERTING', + 'r2.r REVERTED(None)', + 'task1.t REVERTING', + 'task1.t REVERTED(None)', + 'r1.r REVERTING', + 'r1.r REVERTED(None)', + 'flow-1.f REVERTED', + ] self.assertEqual(expected, capturer.values) def test_restart_reverted_flow_with_retry(self): flow = lf.Flow('test', retry=utils.OneReturnRetry(provides='x')).add( - utils.FailingTask('fail')) + 
utils.FailingTask('fail') + ) engine = self._make_engine(flow) self.assertRaisesRegex(RuntimeError, '^Woot', engine.run) self.assertRaisesRegex(RuntimeError, '^Woot', engine.run) @@ -504,15 +530,20 @@ class RetryTest(utils.EngineTestBase): subflow1 = lf.Flow('subflow1') # * a task that completes in 3 sec with a few retries - subsubflow1 = lf.Flow('subflow1.subsubflow1', - retry=utils.RetryFiveTimes()) - subsubflow1.add(utils.SuccessAfter3Sec('subflow1.fail1', - inject={'start_time': now})) + subsubflow1 = lf.Flow( + 'subflow1.subsubflow1', retry=utils.RetryFiveTimes() + ) + subsubflow1.add( + utils.SuccessAfter3Sec( + 'subflow1.fail1', inject={'start_time': now} + ) + ) subflow1.add(subsubflow1) # * a task that fails and triggers a revert after 5 retries - subsubflow2 = lf.Flow('subflow1.subsubflow2', - retry=utils.RetryFiveTimes()) + subsubflow2 = lf.Flow( + 'subflow1.subsubflow2', retry=utils.RetryFiveTimes() + ) subsubflow2.add(utils.FailingTask('subflow1.fail2')) subflow1.add(subsubflow2) @@ -520,8 +551,9 @@ class RetryTest(utils.EngineTestBase): subflow2 = lf.Flow('subflow2') # * a task that always fails and retries - subsubflow1 = lf.Flow('subflow2.subsubflow1', - retry=utils.AlwaysRetry()) + subsubflow1 = lf.Flow( + 'subflow2.subsubflow1', retry=utils.AlwaysRetry() + ) subsubflow1.add(utils.FailingTask('subflow2.fail1')) subflow2.add(subsubflow1) @@ -536,14 +568,15 @@ class RetryTest(utils.EngineTestBase): engine = self._make_engine(flow) # This test fails when using Green threads, skipping it for now - if isinstance(engine._task_executor, - executor.ParallelGreenThreadTaskExecutor): + if isinstance( + engine._task_executor, executor.ParallelGreenThreadTaskExecutor + ): self.skipTest("Skipping this test when using green threads.") with utils.CaptureListener(engine) as capturer: - self.assertRaisesRegex(exc.WrappedFailure, - '.*RuntimeError: Woot!', - engine.run) + self.assertRaisesRegex( + exc.WrappedFailure, '.*RuntimeError: Woot!', engine.run + ) # task1 
should have been reverted self.assertIn('task1.t REVERTED(None)', capturer.values) @@ -561,7 +594,7 @@ class RetryTest(utils.EngineTestBase): flow = lf.Flow('flow-1', retry.Times(3, 'r1')).add( utils.ProgressingTask('t1'), utils.ProgressingTask('t2'), - utils.ProgressingTask('t3') + utils.ProgressingTask('t3'), ) engine = self._make_engine(flow) engine.compile() @@ -572,24 +605,25 @@ class RetryTest(utils.EngineTestBase): engine.storage.set_atom_state('t2', st.REVERTED) engine.storage.set_atom_state('t3', st.REVERTED) engine.run() - expected = ['flow-1.f RUNNING', - 't2.t PENDING', - 't3.t PENDING', - 'r1.r RUNNING', - 'r1.r SUCCESS(1)', - 't1.t RUNNING', - 't1.t SUCCESS(5)', - 't2.t RUNNING', - 't2.t SUCCESS(5)', - 't3.t RUNNING', - 't3.t SUCCESS(5)', - 'flow-1.f SUCCESS'] + expected = [ + 'flow-1.f RUNNING', + 't2.t PENDING', + 't3.t PENDING', + 'r1.r RUNNING', + 'r1.r SUCCESS(1)', + 't1.t RUNNING', + 't1.t SUCCESS(5)', + 't2.t RUNNING', + 't2.t SUCCESS(5)', + 't3.t RUNNING', + 't3.t SUCCESS(5)', + 'flow-1.f SUCCESS', + ] self.assertEqual(expected, capturer.values) def test_resume_flow_that_should_be_retried(self): flow = lf.Flow('flow-1', retry.Times(3, 'r1')).add( - utils.ProgressingTask('t1'), - utils.ProgressingTask('t2') + utils.ProgressingTask('t1'), utils.ProgressingTask('t2') ) engine = self._make_engine(flow) engine.compile() @@ -600,93 +634,98 @@ class RetryTest(utils.EngineTestBase): engine.storage.set_atom_state('t1', st.REVERTED) engine.storage.set_atom_state('t2', st.REVERTED) engine.run() - expected = ['flow-1.f RUNNING', - 'r1.r RETRYING', - 't1.t PENDING', - 't2.t PENDING', - 'r1.r RUNNING', - 'r1.r SUCCESS(1)', - 't1.t RUNNING', - 't1.t SUCCESS(5)', - 't2.t RUNNING', - 't2.t SUCCESS(5)', - 'flow-1.f SUCCESS'] + expected = [ + 'flow-1.f RUNNING', + 'r1.r RETRYING', + 't1.t PENDING', + 't2.t PENDING', + 'r1.r RUNNING', + 'r1.r SUCCESS(1)', + 't1.t RUNNING', + 't1.t SUCCESS(5)', + 't2.t RUNNING', + 't2.t SUCCESS(5)', + 'flow-1.f SUCCESS', + ] 
self.assertEqual(expected, capturer.values) def test_retry_tasks_that_has_not_been_reverted(self): flow = lf.Flow('flow-1', retry.Times(3, 'r1', provides='x')).add( - utils.ConditionalTask('c'), - utils.ProgressingTask('t1') + utils.ConditionalTask('c'), utils.ProgressingTask('t1') ) engine = self._make_engine(flow) engine.storage.inject({'y': 2}) with utils.CaptureListener(engine) as capturer: engine.run() - expected = ['flow-1.f RUNNING', - 'r1.r RUNNING', - 'r1.r SUCCESS(1)', - 'c.t RUNNING', - 'c.t FAILURE(Failure: RuntimeError: Woot!)', - 'c.t REVERTING', - 'c.t REVERTED(None)', - 'r1.r RETRYING', - 'c.t PENDING', - 'r1.r RUNNING', - 'r1.r SUCCESS(2)', - 'c.t RUNNING', - 'c.t SUCCESS(None)', - 't1.t RUNNING', - 't1.t SUCCESS(5)', - 'flow-1.f SUCCESS'] + expected = [ + 'flow-1.f RUNNING', + 'r1.r RUNNING', + 'r1.r SUCCESS(1)', + 'c.t RUNNING', + 'c.t FAILURE(Failure: RuntimeError: Woot!)', + 'c.t REVERTING', + 'c.t REVERTED(None)', + 'r1.r RETRYING', + 'c.t PENDING', + 'r1.r RUNNING', + 'r1.r SUCCESS(2)', + 'c.t RUNNING', + 'c.t SUCCESS(None)', + 't1.t RUNNING', + 't1.t SUCCESS(5)', + 'flow-1.f SUCCESS', + ] self.assertEqual(expected, capturer.values) def test_default_times_retry(self): flow = lf.Flow('flow-1', retry.Times(3, 'r1')).add( - utils.ProgressingTask('t1'), - utils.FailingTask('t2')) + utils.ProgressingTask('t1'), utils.FailingTask('t2') + ) engine = self._make_engine(flow) with utils.CaptureListener(engine) as capturer: self.assertRaisesRegex(RuntimeError, '^Woot', engine.run) - expected = ['flow-1.f RUNNING', - 'r1.r RUNNING', - 'r1.r SUCCESS(1)', - 't1.t RUNNING', - 't1.t SUCCESS(5)', - 't2.t RUNNING', - 't2.t FAILURE(Failure: RuntimeError: Woot!)', - 't2.t REVERTING', - 't2.t REVERTED(None)', - 't1.t REVERTING', - 't1.t REVERTED(None)', - 'r1.r RETRYING', - 't1.t PENDING', - 't2.t PENDING', - 'r1.r RUNNING', - 'r1.r SUCCESS(2)', - 't1.t RUNNING', - 't1.t SUCCESS(5)', - 't2.t RUNNING', - 't2.t FAILURE(Failure: RuntimeError: Woot!)', - 't2.t 
REVERTING', - 't2.t REVERTED(None)', - 't1.t REVERTING', - 't1.t REVERTED(None)', - 'r1.r RETRYING', - 't1.t PENDING', - 't2.t PENDING', - 'r1.r RUNNING', - 'r1.r SUCCESS(3)', - 't1.t RUNNING', - 't1.t SUCCESS(5)', - 't2.t RUNNING', - 't2.t FAILURE(Failure: RuntimeError: Woot!)', - 't2.t REVERTING', - 't2.t REVERTED(None)', - 't1.t REVERTING', - 't1.t REVERTED(None)', - 'r1.r REVERTING', - 'r1.r REVERTED(None)', - 'flow-1.f REVERTED'] + expected = [ + 'flow-1.f RUNNING', + 'r1.r RUNNING', + 'r1.r SUCCESS(1)', + 't1.t RUNNING', + 't1.t SUCCESS(5)', + 't2.t RUNNING', + 't2.t FAILURE(Failure: RuntimeError: Woot!)', + 't2.t REVERTING', + 't2.t REVERTED(None)', + 't1.t REVERTING', + 't1.t REVERTED(None)', + 'r1.r RETRYING', + 't1.t PENDING', + 't2.t PENDING', + 'r1.r RUNNING', + 'r1.r SUCCESS(2)', + 't1.t RUNNING', + 't1.t SUCCESS(5)', + 't2.t RUNNING', + 't2.t FAILURE(Failure: RuntimeError: Woot!)', + 't2.t REVERTING', + 't2.t REVERTED(None)', + 't1.t REVERTING', + 't1.t REVERTED(None)', + 'r1.r RETRYING', + 't1.t PENDING', + 't2.t PENDING', + 'r1.r RUNNING', + 'r1.r SUCCESS(3)', + 't1.t RUNNING', + 't1.t SUCCESS(5)', + 't2.t RUNNING', + 't2.t FAILURE(Failure: RuntimeError: Woot!)', + 't2.t REVERTING', + 't2.t REVERTED(None)', + 't1.t REVERTING', + 't1.t REVERTED(None)', + 'r1.r REVERTING', + 'r1.r REVERTED(None)', + 'flow-1.f REVERTED', + ] self.assertEqual(expected, capturer.values) def test_for_each_with_list(self): @@ -696,40 +735,42 @@ class RetryTest(utils.EngineTestBase): engine = self._make_engine(flow) with utils.CaptureListener(engine) as capturer: self.assertRaisesRegex(RuntimeError, '^Woot', engine.run) - expected = ['flow-1.f RUNNING', - 'r1.r RUNNING', - 'r1.r SUCCESS(3)', - 't1.t RUNNING', - 't1.t FAILURE(Failure: RuntimeError: Woot with 3)', - 't1.t REVERTING', - 't1.t REVERTED(None)', - 'r1.r RETRYING', - 't1.t PENDING', - 'r1.r RUNNING', - 'r1.r SUCCESS(2)', - 't1.t RUNNING', - 't1.t FAILURE(Failure: RuntimeError: Woot with 2)', - 't1.t REVERTING', - 
't1.t REVERTED(None)', - 'r1.r RETRYING', - 't1.t PENDING', - 'r1.r RUNNING', - 'r1.r SUCCESS(3)', - 't1.t RUNNING', - 't1.t FAILURE(Failure: RuntimeError: Woot with 3)', - 't1.t REVERTING', - 't1.t REVERTED(None)', - 'r1.r RETRYING', - 't1.t PENDING', - 'r1.r RUNNING', - 'r1.r SUCCESS(5)', - 't1.t RUNNING', - 't1.t FAILURE(Failure: RuntimeError: Woot with 5)', - 't1.t REVERTING', - 't1.t REVERTED(None)', - 'r1.r REVERTING', - 'r1.r REVERTED(None)', - 'flow-1.f REVERTED'] + expected = [ + 'flow-1.f RUNNING', + 'r1.r RUNNING', + 'r1.r SUCCESS(3)', + 't1.t RUNNING', + 't1.t FAILURE(Failure: RuntimeError: Woot with 3)', + 't1.t REVERTING', + 't1.t REVERTED(None)', + 'r1.r RETRYING', + 't1.t PENDING', + 'r1.r RUNNING', + 'r1.r SUCCESS(2)', + 't1.t RUNNING', + 't1.t FAILURE(Failure: RuntimeError: Woot with 2)', + 't1.t REVERTING', + 't1.t REVERTED(None)', + 'r1.r RETRYING', + 't1.t PENDING', + 'r1.r RUNNING', + 'r1.r SUCCESS(3)', + 't1.t RUNNING', + 't1.t FAILURE(Failure: RuntimeError: Woot with 3)', + 't1.t REVERTING', + 't1.t REVERTED(None)', + 'r1.r RETRYING', + 't1.t PENDING', + 'r1.r RUNNING', + 'r1.r SUCCESS(5)', + 't1.t RUNNING', + 't1.t FAILURE(Failure: RuntimeError: Woot with 5)', + 't1.t REVERTING', + 't1.t REVERTED(None)', + 'r1.r REVERTING', + 'r1.r REVERTED(None)', + 'flow-1.f REVERTED', + ] self.assertEqual(expected, capturer.values) def test_for_each_with_set(self): @@ -739,32 +780,34 @@ class RetryTest(utils.EngineTestBase): engine = self._make_engine(flow) with utils.CaptureListener(engine) as capturer: self.assertRaisesRegex(RuntimeError, '^Woot', engine.run) - expected = ['flow-1.f RUNNING', - 'r1.r RUNNING', - 'r1.r SUCCESS(2)', - 't1.t RUNNING', - 't1.t FAILURE(Failure: RuntimeError: Woot with 2)', - 't1.t REVERTING', - 't1.t REVERTED(None)', - 'r1.r RETRYING', - 't1.t PENDING', - 'r1.r RUNNING', - 'r1.r SUCCESS(3)', - 't1.t RUNNING', - 't1.t FAILURE(Failure: RuntimeError: Woot with 3)', - 't1.t REVERTING', - 't1.t REVERTED(None)', - 'r1.r 
RETRYING', - 't1.t PENDING', - 'r1.r RUNNING', - 'r1.r SUCCESS(5)', - 't1.t RUNNING', - 't1.t FAILURE(Failure: RuntimeError: Woot with 5)', - 't1.t REVERTING', - 't1.t REVERTED(None)', - 'r1.r REVERTING', - 'r1.r REVERTED(None)', - 'flow-1.f REVERTED'] + expected = [ + 'flow-1.f RUNNING', + 'r1.r RUNNING', + 'r1.r SUCCESS(2)', + 't1.t RUNNING', + 't1.t FAILURE(Failure: RuntimeError: Woot with 2)', + 't1.t REVERTING', + 't1.t REVERTED(None)', + 'r1.r RETRYING', + 't1.t PENDING', + 'r1.r RUNNING', + 'r1.r SUCCESS(3)', + 't1.t RUNNING', + 't1.t FAILURE(Failure: RuntimeError: Woot with 3)', + 't1.t REVERTING', + 't1.t REVERTED(None)', + 'r1.r RETRYING', + 't1.t PENDING', + 'r1.r RUNNING', + 'r1.r SUCCESS(5)', + 't1.t RUNNING', + 't1.t FAILURE(Failure: RuntimeError: Woot with 5)', + 't1.t REVERTING', + 't1.t REVERTED(None)', + 'r1.r REVERTING', + 'r1.r REVERTED(None)', + 'flow-1.f REVERTED', + ] self.assertCountEqual(capturer.values, expected) def test_nested_for_each_revert(self): @@ -774,47 +817,49 @@ class RetryTest(utils.EngineTestBase): utils.ProgressingTask("task1"), lf.Flow('flow-2', retry1).add( utils.FailingTaskWithOneArg('task2') - ) + ), ) engine = self._make_engine(flow) with utils.CaptureListener(engine) as capturer: self.assertRaisesRegex(RuntimeError, '^Woot', engine.run) - expected = ['flow-1.f RUNNING', - 'task1.t RUNNING', - 'task1.t SUCCESS(5)', - 'r1.r RUNNING', - 'r1.r SUCCESS(3)', - 'task2.t RUNNING', - 'task2.t FAILURE(Failure: RuntimeError: Woot with 3)', - 'task2.t REVERTING', - 'task2.t REVERTED(None)', - 'r1.r RETRYING', - 'task2.t PENDING', - 'r1.r RUNNING', - 'r1.r SUCCESS(2)', - 'task2.t RUNNING', - 'task2.t FAILURE(Failure: RuntimeError: Woot with 2)', - 'task2.t REVERTING', - 'task2.t REVERTED(None)', - 'r1.r RETRYING', - 'task2.t PENDING', - 'r1.r RUNNING', - 'r1.r SUCCESS(3)', - 'task2.t RUNNING', - 'task2.t FAILURE(Failure: RuntimeError: Woot with 3)', - 'task2.t REVERTING', - 'task2.t REVERTED(None)', - 'r1.r RETRYING', - 'task2.t 
PENDING', - 'r1.r RUNNING', - 'r1.r SUCCESS(5)', - 'task2.t RUNNING', - 'task2.t FAILURE(Failure: RuntimeError: Woot with 5)', - 'task2.t REVERTING', - 'task2.t REVERTED(None)', - 'r1.r REVERTING', - 'r1.r REVERTED(None)', - 'flow-1.f REVERTED'] + expected = [ + 'flow-1.f RUNNING', + 'task1.t RUNNING', + 'task1.t SUCCESS(5)', + 'r1.r RUNNING', + 'r1.r SUCCESS(3)', + 'task2.t RUNNING', + 'task2.t FAILURE(Failure: RuntimeError: Woot with 3)', + 'task2.t REVERTING', + 'task2.t REVERTED(None)', + 'r1.r RETRYING', + 'task2.t PENDING', + 'r1.r RUNNING', + 'r1.r SUCCESS(2)', + 'task2.t RUNNING', + 'task2.t FAILURE(Failure: RuntimeError: Woot with 2)', + 'task2.t REVERTING', + 'task2.t REVERTED(None)', + 'r1.r RETRYING', + 'task2.t PENDING', + 'r1.r RUNNING', + 'r1.r SUCCESS(3)', + 'task2.t RUNNING', + 'task2.t FAILURE(Failure: RuntimeError: Woot with 3)', + 'task2.t REVERTING', + 'task2.t REVERTED(None)', + 'r1.r RETRYING', + 'task2.t PENDING', + 'r1.r RUNNING', + 'r1.r SUCCESS(5)', + 'task2.t RUNNING', + 'task2.t FAILURE(Failure: RuntimeError: Woot with 5)', + 'task2.t REVERTING', + 'task2.t REVERTED(None)', + 'r1.r REVERTING', + 'r1.r REVERTED(None)', + 'flow-1.f REVERTED', + ] self.assertEqual(expected, capturer.values) def test_nested_for_each_revert_all(self): @@ -824,49 +869,51 @@ class RetryTest(utils.EngineTestBase): utils.ProgressingTask("task1"), lf.Flow('flow-2', retry1).add( utils.FailingTaskWithOneArg('task2') - ) + ), ) engine = self._make_engine(flow) with utils.CaptureListener(engine) as capturer: self.assertRaisesRegex(RuntimeError, '^Woot', engine.run) - expected = ['flow-1.f RUNNING', - 'task1.t RUNNING', - 'task1.t SUCCESS(5)', - 'r1.r RUNNING', - 'r1.r SUCCESS(3)', - 'task2.t RUNNING', - 'task2.t FAILURE(Failure: RuntimeError: Woot with 3)', - 'task2.t REVERTING', - 'task2.t REVERTED(None)', - 'r1.r RETRYING', - 'task2.t PENDING', - 'r1.r RUNNING', - 'r1.r SUCCESS(2)', - 'task2.t RUNNING', - 'task2.t FAILURE(Failure: RuntimeError: Woot with 2)', - 
'task2.t REVERTING', - 'task2.t REVERTED(None)', - 'r1.r RETRYING', - 'task2.t PENDING', - 'r1.r RUNNING', - 'r1.r SUCCESS(3)', - 'task2.t RUNNING', - 'task2.t FAILURE(Failure: RuntimeError: Woot with 3)', - 'task2.t REVERTING', - 'task2.t REVERTED(None)', - 'r1.r RETRYING', - 'task2.t PENDING', - 'r1.r RUNNING', - 'r1.r SUCCESS(5)', - 'task2.t RUNNING', - 'task2.t FAILURE(Failure: RuntimeError: Woot with 5)', - 'task2.t REVERTING', - 'task2.t REVERTED(None)', - 'r1.r REVERTING', - 'r1.r REVERTED(None)', - 'task1.t REVERTING', - 'task1.t REVERTED(None)', - 'flow-1.f REVERTED'] + expected = [ + 'flow-1.f RUNNING', + 'task1.t RUNNING', + 'task1.t SUCCESS(5)', + 'r1.r RUNNING', + 'r1.r SUCCESS(3)', + 'task2.t RUNNING', + 'task2.t FAILURE(Failure: RuntimeError: Woot with 3)', + 'task2.t REVERTING', + 'task2.t REVERTED(None)', + 'r1.r RETRYING', + 'task2.t PENDING', + 'r1.r RUNNING', + 'r1.r SUCCESS(2)', + 'task2.t RUNNING', + 'task2.t FAILURE(Failure: RuntimeError: Woot with 2)', + 'task2.t REVERTING', + 'task2.t REVERTED(None)', + 'r1.r RETRYING', + 'task2.t PENDING', + 'r1.r RUNNING', + 'r1.r SUCCESS(3)', + 'task2.t RUNNING', + 'task2.t FAILURE(Failure: RuntimeError: Woot with 3)', + 'task2.t REVERTING', + 'task2.t REVERTED(None)', + 'r1.r RETRYING', + 'task2.t PENDING', + 'r1.r RUNNING', + 'r1.r SUCCESS(5)', + 'task2.t RUNNING', + 'task2.t FAILURE(Failure: RuntimeError: Woot with 5)', + 'task2.t REVERTING', + 'task2.t REVERTED(None)', + 'r1.r REVERTING', + 'r1.r REVERTED(None)', + 'task1.t REVERTING', + 'task1.t REVERTED(None)', + 'flow-1.f REVERTED', + ] self.assertEqual(expected, capturer.values) def test_for_each_empty_collection(self): @@ -885,68 +932,72 @@ class RetryTest(utils.EngineTestBase): engine.storage.inject({'values': values, 'y': 1}) with utils.CaptureListener(engine) as capturer: self.assertRaisesRegex(RuntimeError, '^Woot', engine.run) - expected = ['flow-1.f RUNNING', - 'r1.r RUNNING', - 'r1.r SUCCESS(3)', - 't1.t RUNNING', - 't1.t FAILURE(Failure: 
RuntimeError: Woot with 3)', - 't1.t REVERTING', - 't1.t REVERTED(None)', - 'r1.r RETRYING', - 't1.t PENDING', - 'r1.r RUNNING', - 'r1.r SUCCESS(2)', - 't1.t RUNNING', - 't1.t FAILURE(Failure: RuntimeError: Woot with 2)', - 't1.t REVERTING', - 't1.t REVERTED(None)', - 'r1.r RETRYING', - 't1.t PENDING', - 'r1.r RUNNING', - 'r1.r SUCCESS(5)', - 't1.t RUNNING', - 't1.t FAILURE(Failure: RuntimeError: Woot with 5)', - 't1.t REVERTING', - 't1.t REVERTED(None)', - 'r1.r REVERTING', - 'r1.r REVERTED(None)', - 'flow-1.f REVERTED'] + expected = [ + 'flow-1.f RUNNING', + 'r1.r RUNNING', + 'r1.r SUCCESS(3)', + 't1.t RUNNING', + 't1.t FAILURE(Failure: RuntimeError: Woot with 3)', + 't1.t REVERTING', + 't1.t REVERTED(None)', + 'r1.r RETRYING', + 't1.t PENDING', + 'r1.r RUNNING', + 'r1.r SUCCESS(2)', + 't1.t RUNNING', + 't1.t FAILURE(Failure: RuntimeError: Woot with 2)', + 't1.t REVERTING', + 't1.t REVERTED(None)', + 'r1.r RETRYING', + 't1.t PENDING', + 'r1.r RUNNING', + 'r1.r SUCCESS(5)', + 't1.t RUNNING', + 't1.t FAILURE(Failure: RuntimeError: Woot with 5)', + 't1.t REVERTING', + 't1.t REVERTED(None)', + 'r1.r REVERTING', + 'r1.r REVERTED(None)', + 'flow-1.f REVERTED', + ] self.assertEqual(expected, capturer.values) def test_parameterized_for_each_with_set(self): - values = ([3, 2, 5]) + values = [3, 2, 5] retry1 = retry.ParameterizedForEach('r1', provides='x') flow = lf.Flow('flow-1', retry1).add(utils.FailingTaskWithOneArg('t1')) engine = self._make_engine(flow) engine.storage.inject({'values': values, 'y': 1}) with utils.CaptureListener(engine) as capturer: self.assertRaisesRegex(RuntimeError, '^Woot', engine.run) - expected = ['flow-1.f RUNNING', - 'r1.r RUNNING', - 'r1.r SUCCESS(3)', - 't1.t RUNNING', - 't1.t FAILURE(Failure: RuntimeError: Woot with 3)', - 't1.t REVERTING', - 't1.t REVERTED(None)', - 'r1.r RETRYING', - 't1.t PENDING', - 'r1.r RUNNING', - 'r1.r SUCCESS(2)', - 't1.t RUNNING', - 't1.t FAILURE(Failure: RuntimeError: Woot with 2)', - 't1.t REVERTING', - 't1.t 
REVERTED(None)', - 'r1.r RETRYING', - 't1.t PENDING', - 'r1.r RUNNING', - 'r1.r SUCCESS(5)', - 't1.t RUNNING', - 't1.t FAILURE(Failure: RuntimeError: Woot with 5)', - 't1.t REVERTING', - 't1.t REVERTED(None)', - 'r1.r REVERTING', - 'r1.r REVERTED(None)', - 'flow-1.f REVERTED'] + expected = [ + 'flow-1.f RUNNING', + 'r1.r RUNNING', + 'r1.r SUCCESS(3)', + 't1.t RUNNING', + 't1.t FAILURE(Failure: RuntimeError: Woot with 3)', + 't1.t REVERTING', + 't1.t REVERTED(None)', + 'r1.r RETRYING', + 't1.t PENDING', + 'r1.r RUNNING', + 'r1.r SUCCESS(2)', + 't1.t RUNNING', + 't1.t FAILURE(Failure: RuntimeError: Woot with 2)', + 't1.t REVERTING', + 't1.t REVERTED(None)', + 'r1.r RETRYING', + 't1.t PENDING', + 'r1.r RUNNING', + 'r1.r SUCCESS(5)', + 't1.t RUNNING', + 't1.t FAILURE(Failure: RuntimeError: Woot with 5)', + 't1.t REVERTING', + 't1.t REVERTED(None)', + 'r1.r REVERTING', + 'r1.r REVERTED(None)', + 'flow-1.f REVERTED', + ] self.assertCountEqual(capturer.values, expected) def test_nested_parameterized_for_each_revert(self): @@ -956,86 +1007,91 @@ class RetryTest(utils.EngineTestBase): utils.ProgressingTask('task-1'), lf.Flow('flow-2', retry1).add( utils.FailingTaskWithOneArg('task-2') - ) + ), ) engine = self._make_engine(flow) engine.storage.inject({'values': values, 'y': 1}) with utils.CaptureListener(engine) as capturer: self.assertRaisesRegex(RuntimeError, '^Woot', engine.run) - expected = ['flow-1.f RUNNING', - 'task-1.t RUNNING', - 'task-1.t SUCCESS(5)', - 'r1.r RUNNING', - 'r1.r SUCCESS(3)', - 'task-2.t RUNNING', - 'task-2.t FAILURE(Failure: RuntimeError: Woot with 3)', - 'task-2.t REVERTING', - 'task-2.t REVERTED(None)', - 'r1.r RETRYING', - 'task-2.t PENDING', - 'r1.r RUNNING', - 'r1.r SUCCESS(2)', - 'task-2.t RUNNING', - 'task-2.t FAILURE(Failure: RuntimeError: Woot with 2)', - 'task-2.t REVERTING', - 'task-2.t REVERTED(None)', - 'r1.r RETRYING', - 'task-2.t PENDING', - 'r1.r RUNNING', - 'r1.r SUCCESS(5)', - 'task-2.t RUNNING', - 'task-2.t FAILURE(Failure: 
RuntimeError: Woot with 5)', - 'task-2.t REVERTING', - 'task-2.t REVERTED(None)', - 'r1.r REVERTING', - 'r1.r REVERTED(None)', - 'flow-1.f REVERTED'] + expected = [ + 'flow-1.f RUNNING', + 'task-1.t RUNNING', + 'task-1.t SUCCESS(5)', + 'r1.r RUNNING', + 'r1.r SUCCESS(3)', + 'task-2.t RUNNING', + 'task-2.t FAILURE(Failure: RuntimeError: Woot with 3)', + 'task-2.t REVERTING', + 'task-2.t REVERTED(None)', + 'r1.r RETRYING', + 'task-2.t PENDING', + 'r1.r RUNNING', + 'r1.r SUCCESS(2)', + 'task-2.t RUNNING', + 'task-2.t FAILURE(Failure: RuntimeError: Woot with 2)', + 'task-2.t REVERTING', + 'task-2.t REVERTED(None)', + 'r1.r RETRYING', + 'task-2.t PENDING', + 'r1.r RUNNING', + 'r1.r SUCCESS(5)', + 'task-2.t RUNNING', + 'task-2.t FAILURE(Failure: RuntimeError: Woot with 5)', + 'task-2.t REVERTING', + 'task-2.t REVERTED(None)', + 'r1.r REVERTING', + 'r1.r REVERTED(None)', + 'flow-1.f REVERTED', + ] self.assertEqual(expected, capturer.values) def test_nested_parameterized_for_each_revert_all(self): values = [3, 2, 5] - retry1 = retry.ParameterizedForEach('r1', provides='x', - revert_all=True) + retry1 = retry.ParameterizedForEach( + 'r1', provides='x', revert_all=True + ) flow = lf.Flow('flow-1').add( utils.ProgressingTask('task-1'), lf.Flow('flow-2', retry1).add( utils.FailingTaskWithOneArg('task-2') - ) + ), ) engine = self._make_engine(flow) engine.storage.inject({'values': values, 'y': 1}) with utils.CaptureListener(engine) as capturer: self.assertRaisesRegex(RuntimeError, '^Woot', engine.run) - expected = ['flow-1.f RUNNING', - 'task-1.t RUNNING', - 'task-1.t SUCCESS(5)', - 'r1.r RUNNING', - 'r1.r SUCCESS(3)', - 'task-2.t RUNNING', - 'task-2.t FAILURE(Failure: RuntimeError: Woot with 3)', - 'task-2.t REVERTING', - 'task-2.t REVERTED(None)', - 'r1.r RETRYING', - 'task-2.t PENDING', - 'r1.r RUNNING', - 'r1.r SUCCESS(2)', - 'task-2.t RUNNING', - 'task-2.t FAILURE(Failure: RuntimeError: Woot with 2)', - 'task-2.t REVERTING', - 'task-2.t REVERTED(None)', - 'r1.r RETRYING', 
- 'task-2.t PENDING', - 'r1.r RUNNING', - 'r1.r SUCCESS(5)', - 'task-2.t RUNNING', - 'task-2.t FAILURE(Failure: RuntimeError: Woot with 5)', - 'task-2.t REVERTING', - 'task-2.t REVERTED(None)', - 'r1.r REVERTING', - 'r1.r REVERTED(None)', - 'task-1.t REVERTING', - 'task-1.t REVERTED(None)', - 'flow-1.f REVERTED'] + expected = [ + 'flow-1.f RUNNING', + 'task-1.t RUNNING', + 'task-1.t SUCCESS(5)', + 'r1.r RUNNING', + 'r1.r SUCCESS(3)', + 'task-2.t RUNNING', + 'task-2.t FAILURE(Failure: RuntimeError: Woot with 3)', + 'task-2.t REVERTING', + 'task-2.t REVERTED(None)', + 'r1.r RETRYING', + 'task-2.t PENDING', + 'r1.r RUNNING', + 'r1.r SUCCESS(2)', + 'task-2.t RUNNING', + 'task-2.t FAILURE(Failure: RuntimeError: Woot with 2)', + 'task-2.t REVERTING', + 'task-2.t REVERTED(None)', + 'r1.r RETRYING', + 'task-2.t PENDING', + 'r1.r RUNNING', + 'r1.r SUCCESS(5)', + 'task-2.t RUNNING', + 'task-2.t FAILURE(Failure: RuntimeError: Woot with 5)', + 'task-2.t REVERTING', + 'task-2.t REVERTED(None)', + 'r1.r REVERTING', + 'r1.r REVERTED(None)', + 'task-1.t REVERTING', + 'task-1.t REVERTED(None)', + 'flow-1.f REVERTED', + ] self.assertEqual(expected, capturer.values) def test_parameterized_for_each_empty_collection(self): @@ -1048,7 +1104,8 @@ class RetryTest(utils.EngineTestBase): def _pretend_to_run_a_flow_and_crash(self, when): flow = uf.Flow('flow-1', retry.Times(3, provides='x')).add( - utils.ProgressingTask('task1')) + utils.ProgressingTask('task1') + ) engine = self._make_engine(flow) engine.compile() engine.prepare() @@ -1085,74 +1142,84 @@ class RetryTest(utils.EngineTestBase): engine = self._pretend_to_run_a_flow_and_crash('task fails') with utils.CaptureListener(engine) as capturer: engine.run() - expected = ['task1.t REVERTING', - 'task1.t REVERTED(None)', - 'flow-1_retry.r RETRYING', - 'task1.t PENDING', - 'flow-1_retry.r RUNNING', - 'flow-1_retry.r SUCCESS(2)', - 'task1.t RUNNING', - 'task1.t SUCCESS(5)', - 'flow-1.f SUCCESS'] + expected = [ + 'task1.t REVERTING', + 
'task1.t REVERTED(None)', + 'flow-1_retry.r RETRYING', + 'task1.t PENDING', + 'flow-1_retry.r RUNNING', + 'flow-1_retry.r SUCCESS(2)', + 'task1.t RUNNING', + 'task1.t SUCCESS(5)', + 'flow-1.f SUCCESS', + ] self.assertEqual(expected, capturer.values) def test_resumption_on_crash_after_retry_queried(self): engine = self._pretend_to_run_a_flow_and_crash('retry queried') with utils.CaptureListener(engine) as capturer: engine.run() - expected = ['task1.t REVERTING', - 'task1.t REVERTED(None)', - 'flow-1_retry.r RETRYING', - 'task1.t PENDING', - 'flow-1_retry.r RUNNING', - 'flow-1_retry.r SUCCESS(2)', - 'task1.t RUNNING', - 'task1.t SUCCESS(5)', - 'flow-1.f SUCCESS'] + expected = [ + 'task1.t REVERTING', + 'task1.t REVERTED(None)', + 'flow-1_retry.r RETRYING', + 'task1.t PENDING', + 'flow-1_retry.r RUNNING', + 'flow-1_retry.r SUCCESS(2)', + 'task1.t RUNNING', + 'task1.t SUCCESS(5)', + 'flow-1.f SUCCESS', + ] self.assertEqual(expected, capturer.values) def test_resumption_on_crash_after_retry_updated(self): engine = self._pretend_to_run_a_flow_and_crash('retry updated') with utils.CaptureListener(engine) as capturer: engine.run() - expected = ['task1.t REVERTING', - 'task1.t REVERTED(None)', - 'flow-1_retry.r RETRYING', - 'task1.t PENDING', - 'flow-1_retry.r RUNNING', - 'flow-1_retry.r SUCCESS(2)', - 'task1.t RUNNING', - 'task1.t SUCCESS(5)', - 'flow-1.f SUCCESS'] + expected = [ + 'task1.t REVERTING', + 'task1.t REVERTED(None)', + 'flow-1_retry.r RETRYING', + 'task1.t PENDING', + 'flow-1_retry.r RUNNING', + 'flow-1_retry.r SUCCESS(2)', + 'task1.t RUNNING', + 'task1.t SUCCESS(5)', + 'flow-1.f SUCCESS', + ] self.assertEqual(expected, capturer.values) def test_resumption_on_crash_after_task_updated(self): engine = self._pretend_to_run_a_flow_and_crash('task updated') with utils.CaptureListener(engine) as capturer: engine.run() - expected = ['task1.t REVERTING', - 'task1.t REVERTED(None)', - 'flow-1_retry.r RETRYING', - 'task1.t PENDING', - 'flow-1_retry.r RUNNING', - 
'flow-1_retry.r SUCCESS(2)', - 'task1.t RUNNING', - 'task1.t SUCCESS(5)', - 'flow-1.f SUCCESS'] + expected = [ + 'task1.t REVERTING', + 'task1.t REVERTED(None)', + 'flow-1_retry.r RETRYING', + 'task1.t PENDING', + 'flow-1_retry.r RUNNING', + 'flow-1_retry.r SUCCESS(2)', + 'task1.t RUNNING', + 'task1.t SUCCESS(5)', + 'flow-1.f SUCCESS', + ] self.assertEqual(expected, capturer.values) def test_resumption_on_crash_after_revert_scheduled(self): engine = self._pretend_to_run_a_flow_and_crash('revert scheduled') with utils.CaptureListener(engine) as capturer: engine.run() - expected = ['task1.t REVERTED(None)', - 'flow-1_retry.r RETRYING', - 'task1.t PENDING', - 'flow-1_retry.r RUNNING', - 'flow-1_retry.r SUCCESS(2)', - 'task1.t RUNNING', - 'task1.t SUCCESS(5)', - 'flow-1.f SUCCESS'] + expected = [ + 'task1.t REVERTED(None)', + 'flow-1_retry.r RETRYING', + 'task1.t PENDING', + 'flow-1_retry.r RUNNING', + 'flow-1_retry.r SUCCESS(2)', + 'task1.t RUNNING', + 'task1.t SUCCESS(5)', + 'flow-1.f SUCCESS', + ] self.assertEqual(expected, capturer.values) def test_retry_fails(self): @@ -1177,7 +1244,9 @@ class RetryTest(utils.EngineTestBase): utils.ProgressingTask('a', requires=['x']), lf.Flow("test2", retry=retry.Times(2)).add( utils.ProgressingTask('b', provides='x'), - utils.FailingTask('c'))) + utils.FailingTask('c'), + ), + ) engine = self._make_engine(flow) engine.compile() engine.prepare() @@ -1186,21 +1255,26 @@ class RetryTest(utils.EngineTestBase): engine.storage.save('a', 10) with utils.CaptureListener(engine, capture_flow=False) as capturer: self.assertRaisesRegex(RuntimeError, '^Woot', engine.run) - expected = ['c.t RUNNING', - 'c.t FAILURE(Failure: RuntimeError: Woot!)', - 'a.t REVERTING', - 'c.t REVERTING', - 'a.t REVERTED(None)', - 'c.t REVERTED(None)', - 'b.t REVERTING', - 'b.t REVERTED(None)'] + expected = [ + 'c.t RUNNING', + 'c.t FAILURE(Failure: RuntimeError: Woot!)', + 'a.t REVERTING', + 'c.t REVERTING', + 'a.t REVERTED(None)', + 'c.t REVERTED(None)', + 'b.t 
REVERTING', + 'b.t REVERTED(None)', + ] self.assertCountEqual(capturer.values[:8], expected) # Task 'a' was or was not executed again, both cases are ok. - self.assertIsSuperAndSubsequence(capturer.values[8:], [ - 'b.t RUNNING', - 'c.t FAILURE(Failure: RuntimeError: Woot!)', - 'b.t REVERTED(None)', - ]) + self.assertIsSuperAndSubsequence( + capturer.values[8:], + [ + 'b.t RUNNING', + 'c.t FAILURE(Failure: RuntimeError: Woot!)', + 'b.t REVERTED(None)', + ], + ) self.assertEqual(st.REVERTED, engine.storage.get_flow_state()) def test_nested_provides_graph_retried_correctly(self): @@ -1208,7 +1282,9 @@ class RetryTest(utils.EngineTestBase): utils.ProgressingTask('a', requires=['x']), lf.Flow("test2", retry=retry.Times(2)).add( utils.ProgressingTask('b', provides='x'), - utils.ProgressingTask('c'))) + utils.ProgressingTask('c'), + ), + ) engine = self._make_engine(flow) engine.compile() engine.prepare() @@ -1219,22 +1295,26 @@ class RetryTest(utils.EngineTestBase): engine.storage.save('c', fail, st.FAILURE) with utils.CaptureListener(engine, capture_flow=False) as capturer: engine.run() - expected = ['c.t REVERTING', - 'c.t REVERTED(None)', - 'b.t REVERTING', - 'b.t REVERTED(None)'] + expected = [ + 'c.t REVERTING', + 'c.t REVERTED(None)', + 'b.t REVERTING', + 'b.t REVERTED(None)', + ] self.assertCountEqual(capturer.values[:4], expected) - expected = ['test2_retry.r RETRYING', - 'b.t PENDING', - 'c.t PENDING', - 'test2_retry.r RUNNING', - 'test2_retry.r SUCCESS(2)', - 'b.t RUNNING', - 'b.t SUCCESS(5)', - 'a.t RUNNING', - 'c.t RUNNING', - 'a.t SUCCESS(5)', - 'c.t SUCCESS(5)'] + expected = [ + 'test2_retry.r RETRYING', + 'b.t PENDING', + 'c.t PENDING', + 'test2_retry.r RUNNING', + 'test2_retry.r SUCCESS(2)', + 'b.t RUNNING', + 'b.t SUCCESS(5)', + 'a.t RUNNING', + 'c.t RUNNING', + 'a.t SUCCESS(5)', + 'c.t SUCCESS(5)', + ] self.assertCountEqual(expected, capturer.values[4:]) self.assertEqual(st.SUCCESS, engine.storage.get_flow_state()) @@ -1244,11 +1324,11 @@ class 
RetryParallelExecutionTest(utils.EngineTestBase): # them in a way that works with more executors... def test_when_subflow_fails_revert_running_tasks(self): - waiting_task = utils.WaitForOneFromTask('task1', 'task2', - [st.SUCCESS, st.FAILURE]) + waiting_task = utils.WaitForOneFromTask( + 'task1', 'task2', [st.SUCCESS, st.FAILURE] + ) flow = uf.Flow('flow-1', retry.Times(3, 'r', provides='x')).add( - waiting_task, - utils.ConditionalTask('task2') + waiting_task, utils.ConditionalTask('task2') ) engine = self._make_engine(flow) engine.atom_notifier.register('*', waiting_task.callback) @@ -1256,35 +1336,38 @@ class RetryParallelExecutionTest(utils.EngineTestBase): with utils.CaptureListener(engine, capture_flow=False) as capturer: engine.run() self.assertEqual({'y': 2, 'x': 2}, engine.storage.fetch_all()) - expected = ['r.r RUNNING', - 'r.r SUCCESS(1)', - 'task1.t RUNNING', - 'task2.t RUNNING', - 'task2.t FAILURE(Failure: RuntimeError: Woot!)', - 'task2.t REVERTING', - 'task2.t REVERTED(None)', - 'task1.t SUCCESS(5)', - 'task1.t REVERTING', - 'task1.t REVERTED(None)', - 'r.r RETRYING', - 'task1.t PENDING', - 'task2.t PENDING', - 'r.r RUNNING', - 'r.r SUCCESS(2)', - 'task1.t RUNNING', - 'task2.t RUNNING', - 'task2.t SUCCESS(None)', - 'task1.t SUCCESS(5)'] + expected = [ + 'r.r RUNNING', + 'r.r SUCCESS(1)', + 'task1.t RUNNING', + 'task2.t RUNNING', + 'task2.t FAILURE(Failure: RuntimeError: Woot!)', + 'task2.t REVERTING', + 'task2.t REVERTED(None)', + 'task1.t SUCCESS(5)', + 'task1.t REVERTING', + 'task1.t REVERTED(None)', + 'r.r RETRYING', + 'task1.t PENDING', + 'task2.t PENDING', + 'r.r RUNNING', + 'r.r SUCCESS(2)', + 'task1.t RUNNING', + 'task2.t RUNNING', + 'task2.t SUCCESS(None)', + 'task1.t SUCCESS(5)', + ] self.assertCountEqual(capturer.values, expected) def test_when_subflow_fails_revert_success_tasks(self): - waiting_task = utils.WaitForOneFromTask('task2', 'task1', - [st.SUCCESS, st.FAILURE]) + waiting_task = utils.WaitForOneFromTask( + 'task2', 'task1', 
[st.SUCCESS, st.FAILURE] + ) flow = uf.Flow('flow-1', retry.Times(3, 'r', provides='x')).add( utils.ProgressingTask('task1'), lf.Flow('flow-2').add( - waiting_task, - utils.ConditionalTask('task3')) + waiting_task, utils.ConditionalTask('task3') + ), ) engine = self._make_engine(flow) engine.atom_notifier.register('*', waiting_task.callback) @@ -1292,72 +1375,81 @@ class RetryParallelExecutionTest(utils.EngineTestBase): with utils.CaptureListener(engine, capture_flow=False) as capturer: engine.run() self.assertEqual({'y': 2, 'x': 2}, engine.storage.fetch_all()) - expected = ['r.r RUNNING', - 'r.r SUCCESS(1)', - 'task1.t RUNNING', - 'task2.t RUNNING', - 'task1.t SUCCESS(5)', - 'task2.t SUCCESS(5)', - 'task3.t RUNNING', - 'task3.t FAILURE(Failure: RuntimeError: Woot!)', - 'task3.t REVERTING', - 'task1.t REVERTING', - 'task3.t REVERTED(None)', - 'task1.t REVERTED(None)', - 'task2.t REVERTING', - 'task2.t REVERTED(None)', - 'r.r RETRYING', - 'task1.t PENDING', - 'task2.t PENDING', - 'task3.t PENDING', - 'r.r RUNNING', - 'r.r SUCCESS(2)', - 'task1.t RUNNING', - 'task2.t RUNNING', - 'task1.t SUCCESS(5)', - 'task2.t SUCCESS(5)', - 'task3.t RUNNING', - 'task3.t SUCCESS(None)'] + expected = [ + 'r.r RUNNING', + 'r.r SUCCESS(1)', + 'task1.t RUNNING', + 'task2.t RUNNING', + 'task1.t SUCCESS(5)', + 'task2.t SUCCESS(5)', + 'task3.t RUNNING', + 'task3.t FAILURE(Failure: RuntimeError: Woot!)', + 'task3.t REVERTING', + 'task1.t REVERTING', + 'task3.t REVERTED(None)', + 'task1.t REVERTED(None)', + 'task2.t REVERTING', + 'task2.t REVERTED(None)', + 'r.r RETRYING', + 'task1.t PENDING', + 'task2.t PENDING', + 'task3.t PENDING', + 'r.r RUNNING', + 'r.r SUCCESS(2)', + 'task1.t RUNNING', + 'task2.t RUNNING', + 'task1.t SUCCESS(5)', + 'task2.t SUCCESS(5)', + 'task3.t RUNNING', + 'task3.t SUCCESS(None)', + ] self.assertCountEqual(capturer.values, expected) class SerialEngineTest(RetryTest, test.TestCase): def _make_engine(self, flow, defer_reverts=None, flow_detail=None): - return 
taskflow.engines.load(flow, - flow_detail=flow_detail, - engine='serial', - backend=self.backend, - defer_reverts=defer_reverts) + return taskflow.engines.load( + flow, + flow_detail=flow_detail, + engine='serial', + backend=self.backend, + defer_reverts=defer_reverts, + ) -class ParallelEngineWithThreadsTest(RetryTest, - RetryParallelExecutionTest, - test.TestCase): +class ParallelEngineWithThreadsTest( + RetryTest, RetryParallelExecutionTest, test.TestCase +): _EXECUTOR_WORKERS = 2 - def _make_engine(self, flow, defer_reverts=None, flow_detail=None, - executor=None): + def _make_engine( + self, flow, defer_reverts=None, flow_detail=None, executor=None + ): if executor is None: executor = 'threads' - return taskflow.engines.load(flow, - flow_detail=flow_detail, - engine='parallel', - backend=self.backend, - executor=executor, - max_workers=self._EXECUTOR_WORKERS, - defer_reverts=defer_reverts) + return taskflow.engines.load( + flow, + flow_detail=flow_detail, + engine='parallel', + backend=self.backend, + executor=executor, + max_workers=self._EXECUTOR_WORKERS, + defer_reverts=defer_reverts, + ) @testtools.skipIf(not eu.EVENTLET_AVAILABLE, 'eventlet is not available') class ParallelEngineWithEventletTest(RetryTest, test.TestCase): - - def _make_engine(self, flow, defer_reverts=None, flow_detail=None, - executor=None): + def _make_engine( + self, flow, defer_reverts=None, flow_detail=None, executor=None + ): if executor is None: executor = 'greenthreads' - return taskflow.engines.load(flow, - flow_detail=flow_detail, - backend=self.backend, - engine='parallel', - executor=executor, - defer_reverts=defer_reverts) + return taskflow.engines.load( + flow, + flow_detail=flow_detail, + backend=self.backend, + engine='parallel', + executor=executor, + defer_reverts=defer_reverts, + ) diff --git a/taskflow/tests/unit/test_states.py b/taskflow/tests/unit/test_states.py index 21e8cec1a..bb0cc1fda 100644 --- a/taskflow/tests/unit/test_states.py +++ 
b/taskflow/tests/unit/test_states.py @@ -20,13 +20,15 @@ from taskflow import test class TestStates(test.TestCase): def test_valid_flow_states(self): for start_state, end_state in states._ALLOWED_FLOW_TRANSITIONS: - self.assertTrue(states.check_flow_transition(start_state, - end_state)) + self.assertTrue( + states.check_flow_transition(start_state, end_state) + ) def test_ignored_flow_states(self): for start_state, end_state in states._IGNORED_FLOW_TRANSITIONS: - self.assertFalse(states.check_flow_transition(start_state, - end_state)) + self.assertFalse( + states.check_flow_transition(start_state, end_state) + ) def test_invalid_flow_states(self): invalids = [ @@ -36,14 +38,18 @@ class TestStates(test.TestCase): (states.RESUMING, states.RUNNING), ] for start_state, end_state in invalids: - self.assertRaises(excp.InvalidState, - states.check_flow_transition, - start_state, end_state) + self.assertRaises( + excp.InvalidState, + states.check_flow_transition, + start_state, + end_state, + ) def test_valid_job_states(self): for start_state, end_state in states._ALLOWED_JOB_TRANSITIONS: - self.assertTrue(states.check_job_transition(start_state, - end_state)) + self.assertTrue( + states.check_job_transition(start_state, end_state) + ) def test_ignored_job_states(self): ignored = [] @@ -51,8 +57,9 @@ class TestStates(test.TestCase): ignored.append((start_state, start_state)) ignored.append((end_state, end_state)) for start_state, end_state in ignored: - self.assertFalse(states.check_job_transition(start_state, - end_state)) + self.assertFalse( + states.check_job_transition(start_state, end_state) + ) def test_invalid_job_states(self): invalids = [ @@ -60,14 +67,18 @@ class TestStates(test.TestCase): (states.UNCLAIMED, states.COMPLETE), ] for start_state, end_state in invalids: - self.assertRaises(excp.InvalidState, - states.check_job_transition, - start_state, end_state) + self.assertRaises( + excp.InvalidState, + states.check_job_transition, + start_state, + end_state, + 
) def test_valid_task_states(self): for start_state, end_state in states._ALLOWED_TASK_TRANSITIONS: - self.assertTrue(states.check_task_transition(start_state, - end_state)) + self.assertTrue( + states.check_task_transition(start_state, end_state) + ) def test_invalid_task_states(self): invalids = [ @@ -82,4 +93,5 @@ class TestStates(test.TestCase): # TODO(harlowja): fix this so that it raises instead of # returning false... self.assertFalse( - states.check_task_transition(start_state, end_state)) + states.check_task_transition(start_state, end_state) + ) diff --git a/taskflow/tests/unit/test_storage.py b/taskflow/tests/unit/test_storage.py index dd5890f6a..d3eb03cb8 100644 --- a/taskflow/tests/unit/test_storage.py +++ b/taskflow/tests/unit/test_storage.py @@ -189,9 +189,9 @@ class StorageTestMixin: def test_fetch_unknown_name(self): s = self._get_storage() - self.assertRaisesRegex(exceptions.NotFound, - "^Name 'xxx' is not mapped", - s.fetch, 'xxx') + self.assertRaisesRegex( + exceptions.NotFound, "^Name 'xxx' is not mapped", s.fetch, 'xxx' + ) def test_flow_metadata_update(self): s = self._get_storage() @@ -221,24 +221,24 @@ class StorageTestMixin: s.set_task_progress('my task', 0.5, {'test_data': 11}) self.assertEqual(0.5, s.get_task_progress('my task')) - self.assertEqual({ - 'at_progress': 0.5, - 'details': {'test_data': 11} - }, s.get_task_progress_details('my task')) + self.assertEqual( + {'at_progress': 0.5, 'details': {'test_data': 11}}, + s.get_task_progress_details('my task'), + ) s.set_task_progress('my task', 0.7, {'test_data': 17}) self.assertEqual(0.7, s.get_task_progress('my task')) - self.assertEqual({ - 'at_progress': 0.7, - 'details': {'test_data': 17} - }, s.get_task_progress_details('my task')) + self.assertEqual( + {'at_progress': 0.7, 'details': {'test_data': 17}}, + s.get_task_progress_details('my task'), + ) s.set_task_progress('my task', 0.99) self.assertEqual(0.99, s.get_task_progress('my task')) - self.assertEqual({ - 'at_progress': 0.7, 
- 'details': {'test_data': 17} - }, s.get_task_progress_details('my task')) + self.assertEqual( + {'at_progress': 0.7, 'details': {'test_data': 17}}, + s.get_task_progress_details('my task'), + ) def test_task_progress_erase(self): s = self._get_storage() @@ -259,10 +259,13 @@ class StorageTestMixin: s = self._get_storage() s.ensure_atom(test_utils.NoopTask('my task', provides=['foo', 'bar'])) s.save('my task', ('spam', 'eggs')) - self.assertEqual({ - 'foo': 'spam', - 'bar': 'eggs', - }, s.fetch_all()) + self.assertEqual( + { + 'foo': 'spam', + 'bar': 'eggs', + }, + s.fetch_all(), + ) def test_mapping_none(self): s = self._get_storage() @@ -274,37 +277,49 @@ class StorageTestMixin: s = self._get_storage() s.inject({'foo': 'bar', 'spam': 'eggs'}) self.assertEqual('eggs', s.fetch('spam')) - self.assertEqual({ - 'foo': 'bar', - 'spam': 'eggs', - }, s.fetch_all()) + self.assertEqual( + { + 'foo': 'bar', + 'spam': 'eggs', + }, + s.fetch_all(), + ) def test_inject_twice(self): s = self._get_storage() s.inject({'foo': 'bar'}) self.assertEqual({'foo': 'bar'}, s.fetch_all()) s.inject({'spam': 'eggs'}) - self.assertEqual({ - 'foo': 'bar', - 'spam': 'eggs', - }, s.fetch_all()) + self.assertEqual( + { + 'foo': 'bar', + 'spam': 'eggs', + }, + s.fetch_all(), + ) def test_inject_resumed(self): s = self._get_storage() s.inject({'foo': 'bar', 'spam': 'eggs'}) # verify it's there - self.assertEqual({ - 'foo': 'bar', - 'spam': 'eggs', - }, s.fetch_all()) + self.assertEqual( + { + 'foo': 'bar', + 'spam': 'eggs', + }, + s.fetch_all(), + ) # imagine we are resuming, so we need to make new # storage from same flow details s2 = self._get_storage(s._flowdetail) # injected data should still be there: - self.assertEqual({ - 'foo': 'bar', - 'spam': 'eggs', - }, s2.fetch_all()) + self.assertEqual( + { + 'foo': 'bar', + 'spam': 'eggs', + }, + s2.fetch_all(), + ) def test_many_thread_ensure_same_task(self): s = self._get_storage() @@ -331,8 +346,9 @@ class StorageTestMixin: values = { str(i): 
str(i), } - threads.append(threading.Thread(target=inject_values, - args=[values])) + threads.append( + threading.Thread(target=inject_values, args=[values]) + ) self._run_many_threads(threads) self.assertEqual(self.thread_count, len(s.fetch_all())) @@ -341,28 +357,34 @@ class StorageTestMixin: def test_fetch_mapped_args(self): s = self._get_storage() s.inject({'foo': 'bar', 'spam': 'eggs'}) - self.assertEqual({'viking': 'eggs'}, - s.fetch_mapped_args({'viking': 'spam'})) + self.assertEqual( + {'viking': 'eggs'}, s.fetch_mapped_args({'viking': 'spam'}) + ) def test_fetch_not_found_args(self): s = self._get_storage() s.inject({'foo': 'bar', 'spam': 'eggs'}) - self.assertRaises(exceptions.NotFound, - s.fetch_mapped_args, {'viking': 'helmet'}) + self.assertRaises( + exceptions.NotFound, s.fetch_mapped_args, {'viking': 'helmet'} + ) def test_fetch_optional_args_found(self): s = self._get_storage() s.inject({'foo': 'bar', 'spam': 'eggs'}) - self.assertEqual({'viking': 'eggs'}, - s.fetch_mapped_args({'viking': 'spam'}, - optional_args={'viking'})) + self.assertEqual( + {'viking': 'eggs'}, + s.fetch_mapped_args({'viking': 'spam'}, optional_args={'viking'}), + ) def test_fetch_optional_args_not_found(self): s = self._get_storage() s.inject({'foo': 'bar', 'spam': 'eggs'}) - self.assertEqual({}, - s.fetch_mapped_args({'viking': 'helmet'}, - optional_args={'viking'})) + self.assertEqual( + {}, + s.fetch_mapped_args( + {'viking': 'helmet'}, optional_args={'viking'} + ), + ) def test_set_and_get_task_state(self): s = self._get_storage() @@ -373,8 +395,9 @@ class StorageTestMixin: def test_get_state_of_unknown_task(self): s = self._get_storage() - self.assertRaisesRegex(exceptions.NotFound, '^Unknown', - s.get_atom_state, 'my task') + self.assertRaisesRegex( + exceptions.NotFound, '^Unknown', s.get_atom_state, 'my task' + ) def test_task_by_name(self): s = self._get_storage() @@ -412,9 +435,9 @@ class StorageTestMixin: def test_unknown_task_by_name(self): s = self._get_storage() 
- self.assertRaisesRegex(exceptions.NotFound, - '^Unknown atom', - s.get_atom_uuid, '42') + self.assertRaisesRegex( + exceptions.NotFound, '^Unknown atom', s.get_atom_uuid, '42' + ) def test_initial_flow_state(self): s = self._get_storage() @@ -437,23 +460,26 @@ class StorageTestMixin: s = self._get_storage() s.ensure_atom(test_utils.NoopTask('my task', provides={'result'})) s.save('my task', {}) - self.assertRaisesRegex(exceptions.NotFound, - '^Unable to find result', s.fetch, 'result') + self.assertRaisesRegex( + exceptions.NotFound, '^Unable to find result', s.fetch, 'result' + ) def test_empty_result_is_checked(self): s = self._get_storage() s.ensure_atom(test_utils.NoopTask('my task', provides=['a'])) s.save('my task', ()) - self.assertRaisesRegex(exceptions.NotFound, - '^Unable to find result', s.fetch, 'a') + self.assertRaisesRegex( + exceptions.NotFound, '^Unable to find result', s.fetch, 'a' + ) def test_short_result_is_checked(self): s = self._get_storage() s.ensure_atom(test_utils.NoopTask('my task', provides=['a', 'b'])) s.save('my task', ['result']) self.assertEqual('result', s.fetch('a')) - self.assertRaisesRegex(exceptions.NotFound, - '^Unable to find result', s.fetch, 'b') + self.assertRaisesRegex( + exceptions.NotFound, '^Unable to find result', s.fetch, 'b' + ) def test_ensure_retry(self): s = self._get_storage() @@ -464,9 +490,12 @@ class StorageTestMixin: def test_ensure_retry_and_task_with_same_name(self): s = self._get_storage() s.ensure_atom(test_utils.NoopTask('my retry')) - self.assertRaisesRegex(exceptions.Duplicate, - '^Atom detail', s.ensure_atom, - test_utils.NoopRetry('my retry')) + self.assertRaisesRegex( + exceptions.Duplicate, + '^Atom detail', + s.ensure_atom, + test_utils.NoopRetry('my retry'), + ) def test_save_retry_results(self): s = self._get_storage() @@ -514,9 +543,9 @@ class StorageTestMixin: self.assertEqual({'my retry': a_failure}, s.get_failures()) def test_logbook_get_unknown_atom_type(self): - 
self.assertRaisesRegex(TypeError, - 'Unknown atom', - models.atom_detail_class, 'some_detail') + self.assertRaisesRegex( + TypeError, 'Unknown atom', models.atom_detail_class, 'some_detail' + ) def test_save_task_intention(self): s = self._get_storage() @@ -563,8 +592,7 @@ class StorageTestMixin: s.ensure_atom(t) s.save('my task', 2) self.assertEqual(2, s.get('my task')) - self.assertRaises(exceptions.NotFound, - s.get_revert_result, 'my task') + self.assertRaises(exceptions.NotFound, s.get_revert_result, 'my task') def test_save_fetch_revert(self): t = test_utils.GiveBackRevert('my task') diff --git a/taskflow/tests/unit/test_suspend.py b/taskflow/tests/unit/test_suspend.py index 3f266455c..22817c8db 100644 --- a/taskflow/tests/unit/test_suspend.py +++ b/taskflow/tests/unit/test_suspend.py @@ -25,12 +25,8 @@ from taskflow.utils import eventlet_utils as eu class SuspendingListener(utils.CaptureListener): - - def __init__(self, engine, - task_name, task_state, capture_flow=False): - super().__init__( - engine, - capture_flow=capture_flow) + def __init__(self, engine, task_name, task_state, capture_flow=False): + super().__init__(engine, capture_flow=capture_flow) self._revert_match = (task_name, task_state) def _task_receiver(self, state, details): @@ -40,18 +36,19 @@ class SuspendingListener(utils.CaptureListener): class SuspendTest(utils.EngineTestBase): - def test_suspend_one_task(self): flow = utils.ProgressingTask('a') engine = self._make_engine(flow) - with SuspendingListener(engine, task_name='b', - task_state=states.SUCCESS) as capturer: + with SuspendingListener( + engine, task_name='b', task_state=states.SUCCESS + ) as capturer: engine.run() self.assertEqual(states.SUCCESS, engine.storage.get_flow_state()) expected = ['a.t RUNNING', 'a.t SUCCESS(5)'] self.assertEqual(expected, capturer.values) - with SuspendingListener(engine, task_name='b', - task_state=states.SUCCESS) as capturer: + with SuspendingListener( + engine, task_name='b', 
task_state=states.SUCCESS + ) as capturer: engine.run() self.assertEqual(states.SUCCESS, engine.storage.get_flow_state()) expected = [] @@ -61,15 +58,20 @@ class SuspendTest(utils.EngineTestBase): flow = lf.Flow('linear').add( utils.ProgressingTask('a'), utils.ProgressingTask('b'), - utils.ProgressingTask('c') + utils.ProgressingTask('c'), ) engine = self._make_engine(flow) - with SuspendingListener(engine, task_name='b', - task_state=states.SUCCESS) as capturer: + with SuspendingListener( + engine, task_name='b', task_state=states.SUCCESS + ) as capturer: engine.run() self.assertEqual(states.SUSPENDED, engine.storage.get_flow_state()) - expected = ['a.t RUNNING', 'a.t SUCCESS(5)', - 'b.t RUNNING', 'b.t SUCCESS(5)'] + expected = [ + 'a.t RUNNING', + 'a.t SUCCESS(5)', + 'b.t RUNNING', + 'b.t SUCCESS(5)', + ] self.assertEqual(expected, capturer.values) with utils.CaptureListener(engine, capture_flow=False) as capturer: engine.run() @@ -81,23 +83,26 @@ class SuspendTest(utils.EngineTestBase): flow = lf.Flow('linear').add( utils.ProgressingTask('a'), utils.ProgressingTask('b'), - utils.FailingTask('c') + utils.FailingTask('c'), ) engine = self._make_engine(flow) - with SuspendingListener(engine, task_name='b', - task_state=states.REVERTED) as capturer: + with SuspendingListener( + engine, task_name='b', task_state=states.REVERTED + ) as capturer: engine.run() self.assertEqual(states.SUSPENDED, engine.storage.get_flow_state()) - expected = ['a.t RUNNING', - 'a.t SUCCESS(5)', - 'b.t RUNNING', - 'b.t SUCCESS(5)', - 'c.t RUNNING', - 'c.t FAILURE(Failure: RuntimeError: Woot!)', - 'c.t REVERTING', - 'c.t REVERTED(None)', - 'b.t REVERTING', - 'b.t REVERTED(None)'] + expected = [ + 'a.t RUNNING', + 'a.t SUCCESS(5)', + 'b.t RUNNING', + 'b.t SUCCESS(5)', + 'c.t RUNNING', + 'c.t FAILURE(Failure: RuntimeError: Woot!)', + 'c.t REVERTING', + 'c.t REVERTED(None)', + 'b.t REVERTING', + 'b.t REVERTED(None)', + ] self.assertEqual(expected, capturer.values) with 
utils.CaptureListener(engine, capture_flow=False) as capturer: self.assertRaisesRegex(RuntimeError, '^Woot', engine.run) @@ -109,22 +114,25 @@ class SuspendTest(utils.EngineTestBase): flow = lf.Flow('linear').add( utils.ProgressingTask('a'), utils.ProgressingTask('b'), - utils.FailingTask('c') + utils.FailingTask('c'), ) engine = self._make_engine(flow) - with SuspendingListener(engine, task_name='b', - task_state=states.REVERTED) as capturer: + with SuspendingListener( + engine, task_name='b', task_state=states.REVERTED + ) as capturer: engine.run() - expected = ['a.t RUNNING', - 'a.t SUCCESS(5)', - 'b.t RUNNING', - 'b.t SUCCESS(5)', - 'c.t RUNNING', - 'c.t FAILURE(Failure: RuntimeError: Woot!)', - 'c.t REVERTING', - 'c.t REVERTED(None)', - 'b.t REVERTING', - 'b.t REVERTED(None)'] + expected = [ + 'a.t RUNNING', + 'a.t SUCCESS(5)', + 'b.t RUNNING', + 'b.t SUCCESS(5)', + 'c.t RUNNING', + 'c.t FAILURE(Failure: RuntimeError: Woot!)', + 'c.t REVERTING', + 'c.t REVERTED(None)', + 'b.t REVERTING', + 'b.t REVERTED(None)', + ] self.assertEqual(expected, capturer.values) # pretend we are resuming @@ -132,32 +140,34 @@ class SuspendTest(utils.EngineTestBase): with utils.CaptureListener(engine2, capture_flow=False) as capturer2: self.assertRaisesRegex(RuntimeError, '^Woot', engine2.run) self.assertEqual(states.REVERTED, engine2.storage.get_flow_state()) - expected = ['a.t REVERTING', - 'a.t REVERTED(None)'] + expected = ['a.t REVERTING', 'a.t REVERTED(None)'] self.assertEqual(expected, capturer2.values) def test_suspend_and_revert_even_if_task_is_gone(self): flow = lf.Flow('linear').add( utils.ProgressingTask('a'), utils.ProgressingTask('b'), - utils.FailingTask('c') + utils.FailingTask('c'), ) engine = self._make_engine(flow) - with SuspendingListener(engine, task_name='b', - task_state=states.REVERTED) as capturer: + with SuspendingListener( + engine, task_name='b', task_state=states.REVERTED + ) as capturer: engine.run() - expected = ['a.t RUNNING', - 'a.t SUCCESS(5)', - 
'b.t RUNNING', - 'b.t SUCCESS(5)', - 'c.t RUNNING', - 'c.t FAILURE(Failure: RuntimeError: Woot!)', - 'c.t REVERTING', - 'c.t REVERTED(None)', - 'b.t REVERTING', - 'b.t REVERTED(None)'] + expected = [ + 'a.t RUNNING', + 'a.t SUCCESS(5)', + 'b.t RUNNING', + 'b.t SUCCESS(5)', + 'c.t RUNNING', + 'c.t FAILURE(Failure: RuntimeError: Woot!)', + 'c.t REVERTING', + 'c.t REVERTED(None)', + 'b.t REVERTING', + 'b.t REVERTED(None)', + ] self.assertEqual(expected, capturer.values) # pretend we are resuming, but task 'c' gone when flow got updated @@ -175,26 +185,28 @@ class SuspendTest(utils.EngineTestBase): def test_storage_is_rechecked(self): flow = lf.Flow('linear').add( utils.ProgressingTask('b', requires=['foo']), - utils.ProgressingTask('c') + utils.ProgressingTask('c'), ) engine = self._make_engine(flow) engine.storage.inject({'foo': 'bar'}) - with SuspendingListener(engine, task_name='b', - task_state=states.SUCCESS): + with SuspendingListener( + engine, task_name='b', task_state=states.SUCCESS + ): engine.run() self.assertEqual(states.SUSPENDED, engine.storage.get_flow_state()) # uninject everything: - engine.storage.save(engine.storage.injector_name, - {}, states.SUCCESS) + engine.storage.save(engine.storage.injector_name, {}, states.SUCCESS) self.assertRaises(exc.MissingDependencies, engine.run) class SerialEngineTest(SuspendTest, test.TestCase): def _make_engine(self, flow, flow_detail=None): - return taskflow.engines.load(flow, - flow_detail=flow_detail, - engine='serial', - backend=self.backend) + return taskflow.engines.load( + flow, + flow_detail=flow_detail, + engine='serial', + backend=self.backend, + ) class ParallelEngineWithThreadsTest(SuspendTest, test.TestCase): @@ -203,20 +215,26 @@ class ParallelEngineWithThreadsTest(SuspendTest, test.TestCase): def _make_engine(self, flow, flow_detail=None, executor=None): if executor is None: executor = 'threads' - return taskflow.engines.load(flow, flow_detail=flow_detail, - engine='parallel', - backend=self.backend, 
- executor=executor, - max_workers=self._EXECUTOR_WORKERS) + return taskflow.engines.load( + flow, + flow_detail=flow_detail, + engine='parallel', + backend=self.backend, + executor=executor, + max_workers=self._EXECUTOR_WORKERS, + ) @testtools.skipIf(not eu.EVENTLET_AVAILABLE, 'eventlet is not available') class ParallelEngineWithEventletTest(SuspendTest, test.TestCase): - def _make_engine(self, flow, flow_detail=None, executor=None): if executor is None: executor = futurist.GreenThreadPoolExecutor() self.addCleanup(executor.shutdown) - return taskflow.engines.load(flow, flow_detail=flow_detail, - backend=self.backend, engine='parallel', - executor=executor) + return taskflow.engines.load( + flow, + flow_detail=flow_detail, + backend=self.backend, + engine='parallel', + executor=executor, + ) diff --git a/taskflow/tests/unit/test_task.py b/taskflow/tests/unit/test_task.py index fde6b8a6c..b6a2909ed 100644 --- a/taskflow/tests/unit/test_task.py +++ b/taskflow/tests/unit/test_task.py @@ -72,15 +72,13 @@ class RevertKwargsTask(task.Task): class TaskTest(test.TestCase): - def test_passed_name(self): my_task = MyTask(name='my name') self.assertEqual('my name', my_task.name) def test_generated_name(self): my_task = MyTask() - self.assertEqual('{}.{}'.format(__name__, 'MyTask'), - my_task.name) + self.assertEqual('{}.{}'.format(__name__, 'MyTask'), my_task.name) def test_task_str(self): my_task = MyTask(name='my') @@ -107,44 +105,36 @@ class TaskTest(test.TestCase): self.assertEqual({'food': 0}, my_task.save_as) def test_bad_provides(self): - self.assertRaisesRegex(TypeError, '^Atom provides', - MyTask, provides=object()) + self.assertRaisesRegex( + TypeError, '^Atom provides', MyTask, provides=object() + ) def test_requires_by_default(self): my_task = MyTask() - expected = { - 'spam': 'spam', - 'eggs': 'eggs', - 'context': 'context' - } - self.assertEqual(expected, - my_task.rebind) - self.assertEqual({'spam', 'eggs', 'context'}, - my_task.requires) + expected = {'spam': 
'spam', 'eggs': 'eggs', 'context': 'context'} + self.assertEqual(expected, my_task.rebind) + self.assertEqual({'spam', 'eggs', 'context'}, my_task.requires) def test_requires_amended(self): my_task = MyTask(requires=('spam', 'eggs')) - expected = { - 'spam': 'spam', - 'eggs': 'eggs', - 'context': 'context' - } + expected = {'spam': 'spam', 'eggs': 'eggs', 'context': 'context'} self.assertEqual(expected, my_task.rebind) def test_requires_explicit(self): - my_task = MyTask(auto_extract=False, - requires=('spam', 'eggs', 'context')) - expected = { - 'spam': 'spam', - 'eggs': 'eggs', - 'context': 'context' - } + my_task = MyTask( + auto_extract=False, requires=('spam', 'eggs', 'context') + ) + expected = {'spam': 'spam', 'eggs': 'eggs', 'context': 'context'} self.assertEqual(expected, my_task.rebind) def test_requires_explicit_not_enough(self): - self.assertRaisesRegex(ValueError, '^Missing arguments', - MyTask, - auto_extract=False, requires=('spam', 'eggs')) + self.assertRaisesRegex( + ValueError, + '^Missing arguments', + MyTask, + auto_extract=False, + requires=('spam', 'eggs'), + ) def test_requires_ignores_optional(self): my_task = DefaultArgTask() @@ -166,78 +156,53 @@ class TaskTest(test.TestCase): def test_rebind_all_args(self): my_task = MyTask(rebind={'spam': 'a', 'eggs': 'b', 'context': 'c'}) - expected = { - 'spam': 'a', - 'eggs': 'b', - 'context': 'c' - } + expected = {'spam': 'a', 'eggs': 'b', 'context': 'c'} self.assertEqual(expected, my_task.rebind) - self.assertEqual({'a', 'b', 'c'}, - my_task.requires) + self.assertEqual({'a', 'b', 'c'}, my_task.requires) def test_rebind_partial(self): my_task = MyTask(rebind={'spam': 'a', 'eggs': 'b'}) - expected = { - 'spam': 'a', - 'eggs': 'b', - 'context': 'context' - } + expected = {'spam': 'a', 'eggs': 'b', 'context': 'context'} self.assertEqual(expected, my_task.rebind) - self.assertEqual({'a', 'b', 'context'}, - my_task.requires) + self.assertEqual({'a', 'b', 'context'}, my_task.requires) def 
test_rebind_unknown(self): - self.assertRaisesRegex(ValueError, '^Extra arguments', - MyTask, rebind={'foo': 'bar'}) + self.assertRaisesRegex( + ValueError, '^Extra arguments', MyTask, rebind={'foo': 'bar'} + ) def test_rebind_unknown_kwargs(self): my_task = KwargsTask(rebind={'foo': 'bar'}) - expected = { - 'foo': 'bar', - 'spam': 'spam' - } + expected = {'foo': 'bar', 'spam': 'spam'} self.assertEqual(expected, my_task.rebind) def test_rebind_list_all(self): my_task = MyTask(rebind=('a', 'b', 'c')) - expected = { - 'context': 'a', - 'spam': 'b', - 'eggs': 'c' - } + expected = {'context': 'a', 'spam': 'b', 'eggs': 'c'} self.assertEqual(expected, my_task.rebind) - self.assertEqual({'a', 'b', 'c'}, - my_task.requires) + self.assertEqual({'a', 'b', 'c'}, my_task.requires) def test_rebind_list_partial(self): my_task = MyTask(rebind=('a', 'b')) - expected = { - 'context': 'a', - 'spam': 'b', - 'eggs': 'eggs' - } + expected = {'context': 'a', 'spam': 'b', 'eggs': 'eggs'} self.assertEqual(expected, my_task.rebind) - self.assertEqual({'a', 'b', 'eggs'}, - my_task.requires) + self.assertEqual({'a', 'b', 'eggs'}, my_task.requires) def test_rebind_list_more(self): - self.assertRaisesRegex(ValueError, '^Extra arguments', - MyTask, rebind=('a', 'b', 'c', 'd')) + self.assertRaisesRegex( + ValueError, '^Extra arguments', MyTask, rebind=('a', 'b', 'c', 'd') + ) def test_rebind_list_more_kwargs(self): my_task = KwargsTask(rebind=('a', 'b', 'c')) - expected = { - 'spam': 'a', - 'b': 'b', - 'c': 'c' - } + expected = {'spam': 'a', 'b': 'b', 'c': 'c'} self.assertEqual(expected, my_task.rebind) - self.assertEqual({'a', 'b', 'c'}, - my_task.requires) + self.assertEqual({'a', 'b', 'c'}, my_task.requires) def test_rebind_list_bad_value(self): - self.assertRaisesRegex(TypeError, '^Invalid rebind value', - MyTask, rebind=object()) + self.assertRaisesRegex( + TypeError, '^Invalid rebind value', MyTask, rebind=object() + ) def test_default_provides(self): my_task = DefaultProvidesTask() @@ 
-300,15 +265,20 @@ class TaskTest(test.TestCase): def test_register_handler_is_none(self): a_task = MyTask() - self.assertRaises(ValueError, a_task.notifier.register, - task.EVENT_UPDATE_PROGRESS, None) + self.assertRaises( + ValueError, + a_task.notifier.register, + task.EVENT_UPDATE_PROGRESS, + None, + ) self.assertEqual(0, len(a_task.notifier)) def test_deregister_any_handler(self): a_task = MyTask() self.assertEqual(0, len(a_task.notifier)) - a_task.notifier.register(task.EVENT_UPDATE_PROGRESS, - lambda event_type, details: None) + a_task.notifier.register( + task.EVENT_UPDATE_PROGRESS, lambda event_type, details: None + ) self.assertEqual(1, len(a_task.notifier)) a_task.notifier.deregister_event(task.EVENT_UPDATE_PROGRESS) self.assertEqual(0, len(a_task.notifier)) @@ -316,8 +286,9 @@ class TaskTest(test.TestCase): def test_deregister_any_handler_empty_listeners(self): a_task = MyTask() self.assertEqual(0, len(a_task.notifier)) - self.assertFalse(a_task.notifier.deregister_event( - task.EVENT_UPDATE_PROGRESS)) + self.assertFalse( + a_task.notifier.deregister_event(task.EVENT_UPDATE_PROGRESS) + ) self.assertEqual(0, len(a_task.notifier)) def test_deregister_non_existent_listener(self): @@ -333,8 +304,9 @@ class TaskTest(test.TestCase): def test_bind_not_callable(self): a_task = MyTask() - self.assertRaises(ValueError, a_task.notifier.register, - task.EVENT_UPDATE_PROGRESS, 2) + self.assertRaises( + ValueError, a_task.notifier.register, task.EVENT_UPDATE_PROGRESS, 2 + ) def test_copy_no_listeners(self): handler1 = lambda event_type, details: None @@ -351,8 +323,9 @@ class TaskTest(test.TestCase): a_task.notifier.register(task.EVENT_UPDATE_PROGRESS, handler1) b_task = a_task.copy() self.assertEqual(1, len(b_task.notifier)) - self.assertTrue(a_task.notifier.deregister_event( - task.EVENT_UPDATE_PROGRESS)) + self.assertTrue( + a_task.notifier.deregister_event(task.EVENT_UPDATE_PROGRESS) + ) self.assertEqual(0, len(a_task.notifier)) self.assertEqual(1, 
len(b_task.notifier)) b_task.notifier.register(task.EVENT_UPDATE_PROGRESS, handler2) @@ -364,16 +337,15 @@ class TaskTest(test.TestCase): my_task = SeparateRevertTask(rebind=('a',), revert_rebind=('b',)) self.assertEqual({'execute_arg': 'a'}, my_task.rebind) self.assertEqual({'revert_arg': 'b'}, my_task.revert_rebind) - self.assertEqual({'a', 'b'}, - my_task.requires) + self.assertEqual({'a', 'b'}, my_task.requires) - my_task = SeparateRevertTask(requires='execute_arg', - revert_requires='revert_arg') + my_task = SeparateRevertTask( + requires='execute_arg', revert_requires='revert_arg' + ) self.assertEqual({'execute_arg': 'execute_arg'}, my_task.rebind) self.assertEqual({'revert_arg': 'revert_arg'}, my_task.revert_rebind) - self.assertEqual({'execute_arg', 'revert_arg'}, - my_task.requires) + self.assertEqual({'execute_arg', 'revert_arg'}, my_task.requires) def test_separate_revert_optional_args(self): my_task = SeparateRevertOptionalTask() @@ -382,17 +354,17 @@ class TaskTest(test.TestCase): def test_revert_kwargs(self): my_task = RevertKwargsTask() - expected_rebind = {'execute_arg1': 'execute_arg1', - 'execute_arg2': 'execute_arg2'} + expected_rebind = { + 'execute_arg1': 'execute_arg1', + 'execute_arg2': 'execute_arg2', + } self.assertEqual(expected_rebind, my_task.rebind) expected_rebind = {'execute_arg1': 'execute_arg1'} self.assertEqual(expected_rebind, my_task.revert_rebind) - self.assertEqual({'execute_arg1', 'execute_arg2'}, - my_task.requires) + self.assertEqual({'execute_arg1', 'execute_arg2'}, my_task.requires) class FunctorTaskTest(test.TestCase): - def test_creation_with_version(self): version = (2, 0) f_task = task.FunctorTask(lambda: None, version=version) @@ -402,49 +374,53 @@ class FunctorTaskTest(test.TestCase): self.assertRaises(ValueError, task.FunctorTask, 2) def test_revert_not_callable(self): - self.assertRaises(ValueError, task.FunctorTask, lambda: None, - revert=2) + self.assertRaises(ValueError, task.FunctorTask, lambda: None, revert=2) 
class ReduceFunctorTaskTest(test.TestCase): - def test_invalid_functor(self): # Functor not callable self.assertRaises(ValueError, task.ReduceFunctorTask, 2, requires=5) # Functor takes no arguments - self.assertRaises(ValueError, task.ReduceFunctorTask, lambda: None, - requires=5) + self.assertRaises( + ValueError, task.ReduceFunctorTask, lambda: None, requires=5 + ) # Functor takes too few arguments - self.assertRaises(ValueError, task.ReduceFunctorTask, lambda x: None, - requires=5) + self.assertRaises( + ValueError, task.ReduceFunctorTask, lambda x: None, requires=5 + ) def test_functor_invalid_requires(self): # Invalid type, requires is not iterable - self.assertRaises(TypeError, task.ReduceFunctorTask, - lambda x, y: None, requires=1) + self.assertRaises( + TypeError, task.ReduceFunctorTask, lambda x, y: None, requires=1 + ) # Too few elements in requires - self.assertRaises(ValueError, task.ReduceFunctorTask, - lambda x, y: None, requires=[1]) + self.assertRaises( + ValueError, task.ReduceFunctorTask, lambda x, y: None, requires=[1] + ) class MapFunctorTaskTest(test.TestCase): - def test_invalid_functor(self): # Functor not callable self.assertRaises(ValueError, task.MapFunctorTask, 2, requires=5) # Functor takes no arguments - self.assertRaises(ValueError, task.MapFunctorTask, lambda: None, - requires=5) + self.assertRaises( + ValueError, task.MapFunctorTask, lambda: None, requires=5 + ) # Functor takes too many arguments - self.assertRaises(ValueError, task.MapFunctorTask, lambda x, y: None, - requires=5) + self.assertRaises( + ValueError, task.MapFunctorTask, lambda x, y: None, requires=5 + ) def test_functor_invalid_requires(self): # Invalid type, requires is not iterable - self.assertRaises(TypeError, task.MapFunctorTask, lambda x: None, - requires=1) + self.assertRaises( + TypeError, task.MapFunctorTask, lambda x: None, requires=1 + ) diff --git a/taskflow/tests/unit/test_types.py b/taskflow/tests/unit/test_types.py index a605d223d..be01018fe 100644 
--- a/taskflow/tests/unit/test_types.py +++ b/taskflow/tests/unit/test_types.py @@ -25,8 +25,7 @@ from taskflow.types import tree class TimingTest(test.TestCase): def test_convert_fail(self): for baddie in ["abc123", "-1", "", object()]: - self.assertRaises(ValueError, - timing.convert_to_timeout, baddie) + self.assertRaises(ValueError, timing.convert_to_timeout, baddie) def test_convert_noop(self): t = timing.convert_to_timeout(1.0) @@ -47,14 +46,12 @@ class TimingTest(test.TestCase): self.assertFalse(t.is_stopped()) def test_values(self): - for v, e_v in [("1.0", 1.0), (1, 1.0), - ("2.0", 2.0)]: + for v, e_v in [("1.0", 1.0), (1, 1.0), ("2.0", 2.0)]: t = timing.convert_to_timeout(v) self.assertEqual(e_v, t.value) def test_fail(self): - self.assertRaises(ValueError, - timing.Timeout, -1) + self.assertRaises(ValueError, timing.Timeout, -1) class GraphTest(test.TestCase): @@ -64,10 +61,8 @@ class GraphTest(test.TestCase): g.add_node("b") g.add_node("c") g.add_edge("b", "c") - self.assertEqual({'a', 'b'}, - set(g.no_predecessors_iter())) - self.assertEqual({'a', 'c'}, - set(g.no_successors_iter())) + self.assertEqual({'a', 'b'}, set(g.no_predecessors_iter())) + self.assertEqual({'a', 'c'}, set(g.no_successors_iter())) def test_directed(self): g = graph.DiGraph() @@ -100,8 +95,10 @@ class GraphTest(test.TestCase): # NOTE(harlowja): ensure we use the ordered types here, otherwise # the expected output will vary based on randomized hashing and then # the test will fail randomly... 
- for graph_cls, kind, edge in [(graph.OrderedDiGraph, 'digraph', '->'), - (graph.OrderedGraph, 'graph', '--')]: + for graph_cls, kind, edge in [ + (graph.OrderedDiGraph, 'digraph', '->'), + (graph.OrderedGraph, 'graph', '--'), + ]: g = graph_cls(name='test') g.add_node("a") g.add_node("b") @@ -146,15 +143,18 @@ b %(edge)s c; g2.add_node('d') g2.add_edge('a', 'd') - self.assertRaises(ValueError, - graph.merge_graphs, g, g2) + self.assertRaises(ValueError, graph.merge_graphs, g, g2) def occurrence_detector(to_graph, from_graph): return sum(1 for node in from_graph.nodes if node in to_graph) - self.assertRaises(ValueError, - graph.merge_graphs, g, g2, - overlap_detector=occurrence_detector) + self.assertRaises( + ValueError, + graph.merge_graphs, + g, + g2, + overlap_detector=occurrence_detector, + ) g3 = graph.merge_graphs(g, g2, allow_overlaps=True) self.assertEqual(3, len(g3)) @@ -168,9 +168,9 @@ b %(edge)s c; g2 = graph.DiGraph() g2.add_node('c') - self.assertRaises(ValueError, - graph.merge_graphs, g, g2, - overlap_detector='b') + self.assertRaises( + ValueError, graph.merge_graphs, g, g2, overlap_detector='b' + ) class TreeTest(test.TestCase): @@ -451,12 +451,11 @@ CEO root.add(tree.Node("josh.1")) root.freeze() self.assertTrue( - all(n.frozen for n in root.dfs_iter(include_self=True))) - self.assertRaises(tree.FrozenNode, - root.remove, "josh.1") + all(n.frozen for n in root.dfs_iter(include_self=True)) + ) + self.assertRaises(tree.FrozenNode, root.remove, "josh.1") self.assertRaises(tree.FrozenNode, root.disassociate) - self.assertRaises(tree.FrozenNode, root.add, - tree.Node("josh.2")) + self.assertRaises(tree.FrozenNode, root.add, tree.Node("josh.2")) def test_removal(self): root = self._make_species() @@ -466,8 +465,7 @@ CEO def test_removal_direct(self): root = self._make_species() - self.assertRaises(ValueError, root.remove, 'human', - only_direct=True) + self.assertRaises(ValueError, root.remove, 'human', only_direct=True) def test_removal_self(self): 
root = self._make_species() @@ -526,46 +524,75 @@ CEO self.assertIsNotNone(root.find('animal', only_direct=True)) self.assertIsNotNone(root.find('reptile', only_direct=True)) self.assertIsNone(root.find('animal', include_self=False)) - self.assertIsNone(root.find('animal', - include_self=False, only_direct=True)) + self.assertIsNone( + root.find('animal', include_self=False, only_direct=True) + ) def test_dfs_itr(self): root = self._make_species() things = list([n.item for n in root.dfs_iter(include_self=True)]) - self.assertEqual({'animal', 'reptile', 'mammal', 'horse', - 'primate', 'monkey', 'human'}, set(things)) + self.assertEqual( + { + 'animal', + 'reptile', + 'mammal', + 'horse', + 'primate', + 'monkey', + 'human', + }, + set(things), + ) def test_dfs_itr_left_to_right(self): root = self._make_species() it = root.dfs_iter(include_self=False, right_to_left=False) things = list([n.item for n in it]) - self.assertEqual(['reptile', 'mammal', 'primate', - 'human', 'monkey', 'horse'], things) + self.assertEqual( + ['reptile', 'mammal', 'primate', 'human', 'monkey', 'horse'], + things, + ) def test_dfs_itr_no_self(self): root = self._make_species() things = list([n.item for n in root.dfs_iter(include_self=False)]) - self.assertEqual(['mammal', 'horse', 'primate', - 'monkey', 'human', 'reptile'], things) + self.assertEqual( + ['mammal', 'horse', 'primate', 'monkey', 'human', 'reptile'], + things, + ) def test_bfs_itr(self): root = self._make_species() things = list([n.item for n in root.bfs_iter(include_self=True)]) - self.assertEqual(['animal', 'reptile', 'mammal', 'primate', - 'horse', 'human', 'monkey'], things) + self.assertEqual( + [ + 'animal', + 'reptile', + 'mammal', + 'primate', + 'horse', + 'human', + 'monkey', + ], + things, + ) def test_bfs_itr_no_self(self): root = self._make_species() things = list([n.item for n in root.bfs_iter(include_self=False)]) - self.assertEqual(['reptile', 'mammal', 'primate', - 'horse', 'human', 'monkey'], things) + 
self.assertEqual( + ['reptile', 'mammal', 'primate', 'horse', 'human', 'monkey'], + things, + ) def test_bfs_itr_right_to_left(self): root = self._make_species() it = root.bfs_iter(include_self=False, right_to_left=True) things = list([n.item for n in it]) - self.assertEqual(['mammal', 'reptile', 'horse', - 'primate', 'monkey', 'human'], things) + self.assertEqual( + ['mammal', 'reptile', 'horse', 'primate', 'monkey', 'human'], + things, + ) def test_to_diagraph(self): root = self._make_species() @@ -590,7 +617,6 @@ CEO class OrderedSetTest(test.TestCase): - def test_pickleable(self): items = [10, 9, 8, 7] s = sets.OrderedSet(items) diff --git a/taskflow/tests/unit/test_utils.py b/taskflow/tests/unit/test_utils.py index 9befadfe9..fb3b79046 100644 --- a/taskflow/tests/unit/test_utils.py +++ b/taskflow/tests/unit/test_utils.py @@ -173,7 +173,6 @@ class UriParseTest(test.TestCase): class TestSequenceMinus(test.TestCase): - def test_simple_case(self): result = misc.sequence_minus([1, 2, 3, 4], [2, 3]) self.assertEqual([1, 4], result) diff --git a/taskflow/tests/unit/test_utils_async_utils.py b/taskflow/tests/unit/test_utils_async_utils.py index 2a5a5fefb..3cc78780c 100644 --- a/taskflow/tests/unit/test_utils_async_utils.py +++ b/taskflow/tests/unit/test_utils_async_utils.py @@ -17,7 +17,6 @@ from taskflow.utils import async_utils as au class MakeCompletedFutureTest(test.TestCase): - def test_make_completed_future(self): result = object() future = au.make_completed_future(result) diff --git a/taskflow/tests/unit/test_utils_binary.py b/taskflow/tests/unit/test_utils_binary.py index 1e531ee93..a078fee22 100644 --- a/taskflow/tests/unit/test_utils_binary.py +++ b/taskflow/tests/unit/test_utils_binary.py @@ -21,7 +21,6 @@ def _bytes(data): class BinaryEncodeTest(test.TestCase): - def _check(self, data, expected_result): result = misc.binary_encode(data) self.assertIsInstance(result, bytes) @@ -48,7 +47,6 @@ class BinaryEncodeTest(test.TestCase): class 
BinaryDecodeTest(test.TestCase): - def _check(self, data, expected_result): result = misc.binary_decode(data) self.assertIsInstance(result, str) @@ -76,23 +74,18 @@ class BinaryDecodeTest(test.TestCase): class DecodeJsonTest(test.TestCase): - def test_it_works(self): - self.assertEqual({"foo": 1}, - misc.decode_json(_bytes('{"foo": 1}'))) + self.assertEqual({"foo": 1}, misc.decode_json(_bytes('{"foo": 1}'))) def test_it_works_with_unicode(self): data = _bytes('{"foo": "фуу"}') self.assertEqual({"foo": 'фуу'}, misc.decode_json(data)) def test_handles_invalid_unicode(self): - self.assertRaises(ValueError, misc.decode_json, - b'{"\xf1": 1}') + self.assertRaises(ValueError, misc.decode_json, b'{"\xf1": 1}') def test_handles_bad_json(self): - self.assertRaises(ValueError, misc.decode_json, - _bytes('{"foo":')) + self.assertRaises(ValueError, misc.decode_json, _bytes('{"foo":')) def test_handles_wrong_types(self): - self.assertRaises(ValueError, misc.decode_json, - _bytes('42')) + self.assertRaises(ValueError, misc.decode_json, _bytes('42')) diff --git a/taskflow/tests/unit/test_utils_iter_utils.py b/taskflow/tests/unit/test_utils_iter_utils.py index 012446c49..36dd403aa 100644 --- a/taskflow/tests/unit/test_utils_iter_utils.py +++ b/taskflow/tests/unit/test_utils_iter_utils.py @@ -36,8 +36,7 @@ class IterUtilsTest(test.TestCase): None, object(), ] - self.assertRaises(ValueError, - iter_utils.unique_seen, iters) + self.assertRaises(ValueError, iter_utils.unique_seen, iters) def test_generate_delays(self): it = iter_utils.generate_delays(1, 60) @@ -62,8 +61,9 @@ class IterUtilsTest(test.TestCase): self.assertRaises(ValueError, iter_utils.generate_delays, -1, -1) self.assertRaises(ValueError, iter_utils.generate_delays, -1, 2) self.assertRaises(ValueError, iter_utils.generate_delays, 2, -1) - self.assertRaises(ValueError, iter_utils.generate_delays, 1, 1, - multiplier=0.5) + self.assertRaises( + ValueError, iter_utils.generate_delays, 1, 1, multiplier=0.5 + ) def 
test_unique_seen(self): iters = [ @@ -72,8 +72,10 @@ class IterUtilsTest(test.TestCase): ['a', 'e', 'f'], ['f', 'm', 'n'], ] - self.assertEqual(['a', 'b', 'c', 'd', 'e', 'f', 'm', 'n'], - list(iter_utils.unique_seen(iters))) + self.assertEqual( + ['a', 'b', 'c', 'd', 'e', 'f', 'm', 'n'], + list(iter_utils.unique_seen(iters)), + ) def test_unique_seen_empty(self): iters = [] @@ -86,8 +88,9 @@ class IterUtilsTest(test.TestCase): [(3, 'c')], [(1, 'a'), (3, 'c')], ] - it = iter_utils.unique_seen(iters, - seen_selector=lambda value: value[0]) + it = iter_utils.unique_seen( + iters, seen_selector=lambda value: value[0] + ) self.assertEqual([(1, 'a'), (2, 'b'), (3, 'c')], list(it)) def test_bad_fill(self): @@ -99,8 +102,9 @@ class IterUtilsTest(test.TestCase): self.assertEqual(50, sum(1 for x in result if x is not None)) def test_fill_custom_filler(self): - self.assertEqual("abcd", - "".join(iter_utils.fill("abc", 4, filler='d'))) + self.assertEqual( + "abcd", "".join(iter_utils.fill("abc", 4, filler='d')) + ) def test_fill_less_needed(self): self.assertEqual("ab", "".join(iter_utils.fill("abc", 2))) @@ -110,18 +114,19 @@ class IterUtilsTest(test.TestCase): self.assertEqual((None, None), tuple(iter_utils.fill([], 2))) def test_bad_find_first_match(self): - self.assertRaises(ValueError, - iter_utils.find_first_match, 2, lambda v: False) + self.assertRaises( + ValueError, iter_utils.find_first_match, 2, lambda v: False + ) def test_find_first_match(self): it = forever_it() - self.assertEqual(100, iter_utils.find_first_match(it, - lambda v: v == 100)) + self.assertEqual( + 100, iter_utils.find_first_match(it, lambda v: v == 100) + ) def test_find_first_match_not_found(self): it = iter(string.ascii_lowercase) - self.assertIsNone(iter_utils.find_first_match(it, - lambda v: v == '')) + self.assertIsNone(iter_utils.find_first_match(it, lambda v: v == '')) def test_bad_count(self): self.assertRaises(ValueError, iter_utils.count, 2) @@ -141,18 +146,22 @@ class 
IterUtilsTest(test.TestCase): class Dummy: def __init__(self, char): self.char = char - dummy_list = [Dummy(a) - for a in string.ascii_lowercase] + + dummy_list = [Dummy(a) for a in string.ascii_lowercase] it = iter(dummy_list) - self.assertEqual([dummy_list[0]], - list(iter_utils.while_is_not(it, dummy_list[0]))) + self.assertEqual( + [dummy_list[0]], list(iter_utils.while_is_not(it, dummy_list[0])) + ) it = iter(dummy_list) - self.assertEqual(dummy_list[0:2], - list(iter_utils.while_is_not(it, dummy_list[1]))) - self.assertEqual(dummy_list[2:], - list(iter_utils.while_is_not(it, Dummy('zzz')))) + self.assertEqual( + dummy_list[0:2], list(iter_utils.while_is_not(it, dummy_list[1])) + ) + self.assertEqual( + dummy_list[2:], list(iter_utils.while_is_not(it, Dummy('zzz'))) + ) it = iter(dummy_list) - self.assertEqual(dummy_list, - list(iter_utils.while_is_not(it, Dummy('')))) + self.assertEqual( + dummy_list, list(iter_utils.while_is_not(it, Dummy(''))) + ) diff --git a/taskflow/tests/unit/test_utils_kazoo_utils.py b/taskflow/tests/unit/test_utils_kazoo_utils.py index 81bd37146..dcd5c7e26 100644 --- a/taskflow/tests/unit/test_utils_kazoo_utils.py +++ b/taskflow/tests/unit/test_utils_kazoo_utils.py @@ -19,7 +19,6 @@ from taskflow.utils import kazoo_utils class MakeClientTest(test.TestCase): - @mock.patch("kazoo.client.KazooClient") def test_make_client_config(self, mock_kazoo_client): conf = {} @@ -32,7 +31,7 @@ class MakeClientTest(test.TestCase): 'keyfile_password': None, 'certfile': None, 'use_ssl': False, - 'verify_certs': True + 'verify_certs': True, } kazoo_utils.make_client(conf) @@ -42,10 +41,7 @@ class MakeClientTest(test.TestCase): mock_kazoo_client.reset_mock() # With boolean passed as strings - conf = { - 'use_ssl': 'True', - 'verify_certs': 'False' - } + conf = {'use_ssl': 'True', 'verify_certs': 'False'} expected = { 'hosts': 'localhost:2181', 'logger': mock.ANY, @@ -55,7 +51,7 @@ class MakeClientTest(test.TestCase): 'keyfile_password': None, 'certfile': 
None, 'use_ssl': True, - 'verify_certs': False + 'verify_certs': False, } kazoo_utils.make_client(conf) diff --git a/taskflow/tests/unit/test_utils_threading_utils.py b/taskflow/tests/unit/test_utils_threading_utils.py index 3ccc822c7..ef8d8787f 100644 --- a/taskflow/tests/unit/test_utils_threading_utils.py +++ b/taskflow/tests/unit/test_utils_threading_utils.py @@ -59,14 +59,16 @@ class TestThreadBundle(test.TestCase): def test_bind_invalid(self): self.assertRaises(ValueError, self.bundle.bind, 1) - for k in ['after_start', 'before_start', - 'before_join', 'after_join']: + for k in ['after_start', 'before_start', 'before_join', 'after_join']: kwargs = { k: 1, } - self.assertRaises(ValueError, self.bundle.bind, - lambda: tu.daemon_thread(_spinner, self.death), - **kwargs) + self.assertRaises( + ValueError, + self.bundle.bind, + lambda: tu.daemon_thread(_spinner, self.death), + **kwargs, + ) def test_bundle_length(self): self.assertEqual(0, len(self.bundle)) @@ -96,11 +98,13 @@ class TestThreadBundle(test.TestCase): death_events.append((i, 'aj')) for i in range(0, self.thread_count): - self.bundle.bind(lambda: tu.daemon_thread(_spinner, self.death), - before_join=functools.partial(before_join, i), - after_join=functools.partial(after_join, i), - before_start=functools.partial(before_start, i), - after_start=functools.partial(after_start, i)) + self.bundle.bind( + lambda: tu.daemon_thread(_spinner, self.death), + before_join=functools.partial(before_join, i), + after_join=functools.partial(after_join, i), + before_start=functools.partial(before_start, i), + after_start=functools.partial(after_start, i), + ) self.assertEqual(self.thread_count, self.bundle.start()) self.assertEqual(self.thread_count, len(self.bundle)) self.assertEqual(self.thread_count, self.bundle.stop()) @@ -109,17 +113,23 @@ class TestThreadBundle(test.TestCase): expected_start_events = [] for i in range(0, self.thread_count): - expected_start_events.extend([ - (i, 'bs'), (i, 'as'), - ]) + 
expected_start_events.extend( + [ + (i, 'bs'), + (i, 'as'), + ] + ) self.assertEqual(expected_start_events, list(start_events)) expected_death_events = [] j = self.thread_count - 1 for _i in range(0, self.thread_count): - expected_death_events.extend([ - (j, 'bj'), (j, 'aj'), - ]) + expected_death_events.extend( + [ + (j, 'bj'), + (j, 'aj'), + ] + ) j -= 1 self.assertEqual(expected_death_events, list(death_events)) @@ -140,16 +150,19 @@ class TestThreadBundle(test.TestCase): events.append('aj') for _i in range(0, self.thread_count): - self.bundle.bind(lambda: tu.daemon_thread(_spinner, self.death), - before_join=before_join, - after_join=after_join, - before_start=before_start, - after_start=after_start) + self.bundle.bind( + lambda: tu.daemon_thread(_spinner, self.death), + before_join=before_join, + after_join=after_join, + before_start=before_start, + after_start=after_start, + ) self.assertEqual(self.thread_count, self.bundle.start()) self.assertEqual(self.thread_count, len(self.bundle)) self.assertEqual(self.thread_count, self.bundle.stop()) for event in ['as', 'bs', 'bj', 'aj']: - self.assertEqual(self.thread_count, - len([e for e in events if e == event])) + self.assertEqual( + self.thread_count, len([e for e in events if e == event]) + ) self.assertEqual(0, self.bundle.stop()) self.assertTrue(self.death.is_set()) diff --git a/taskflow/tests/unit/worker_based/test_creation.py b/taskflow/tests/unit/worker_based/test_creation.py index f0e4ff64c..d87c2e3c8 100644 --- a/taskflow/tests/unit/worker_based/test_creation.py +++ b/taskflow/tests/unit/worker_based/test_creation.py @@ -29,27 +29,31 @@ class TestWorkerBasedActionEngine(test.MockTestCase): backend = backends.fetch({'connection': 'memory'}) flow_detail = pu.create_flow_detail(flow, backend=backend) options = kwargs.copy() - return engine.WorkerBasedActionEngine(flow, flow_detail, - backend, options) + return engine.WorkerBasedActionEngine( + flow, flow_detail, backend, options + ) def 
_patch_in_executor(self): executor_mock, executor_inst_mock = self.patchClass( - engine.executor, 'WorkerTaskExecutor', attach_as='executor') + engine.executor, 'WorkerTaskExecutor', attach_as='executor' + ) return executor_mock, executor_inst_mock def test_creation_default(self): executor_mock, executor_inst_mock = self._patch_in_executor() eng = self._create_engine() expected_calls = [ - mock.call.executor_class(uuid=eng.storage.flow_uuid, - url=None, - exchange='default', - topics=[], - transport=None, - transport_options=None, - transition_timeout=mock.ANY, - retry_options=None, - worker_expiry=mock.ANY) + mock.call.executor_class( + uuid=eng.storage.flow_uuid, + url=None, + exchange='default', + topics=[], + transport=None, + transport_options=None, + transition_timeout=mock.ANY, + retry_options=None, + worker_expiry=mock.ANY, + ) ] self.assertEqual(expected_calls, self.master_mock.mock_calls) @@ -66,17 +70,20 @@ class TestWorkerBasedActionEngine(test.MockTestCase): transition_timeout=200, topics=topics, retry_options={}, - worker_expiry=1) + worker_expiry=1, + ) expected_calls = [ - mock.call.executor_class(uuid=eng.storage.flow_uuid, - url=broker_url, - exchange=exchange, - topics=topics, - transport='memory', - transport_options={}, - transition_timeout=200, - retry_options={}, - worker_expiry=1) + mock.call.executor_class( + uuid=eng.storage.flow_uuid, + url=broker_url, + exchange=exchange, + topics=topics, + transport='memory', + transport_options={}, + transition_timeout=200, + retry_options={}, + worker_expiry=1, + ) ] self.assertEqual(expected_calls, self.master_mock.mock_calls) diff --git a/taskflow/tests/unit/worker_based/test_dispatcher.py b/taskflow/tests/unit/worker_based/test_dispatcher.py index c03b42978..9f1d7e3ff 100644 --- a/taskflow/tests/unit/worker_based/test_dispatcher.py +++ b/taskflow/tests/unit/worker_based/test_dispatcher.py @@ -23,8 +23,9 @@ from taskflow.test import mock def mock_acked_message(ack_ok=True, **kwargs): - msg = 
mock.create_autospec(message.Message, spec_set=True, instance=True, - channel=None, **kwargs) + msg = mock.create_autospec( + message.Message, spec_set=True, instance=True, channel=None, **kwargs + ) def ack_side_effect(*args, **kwargs): msg.acknowledged = True @@ -70,8 +71,7 @@ class TestDispatcher(test.TestCase): on_hello = mock.MagicMock() handlers = {'hello': dispatcher.Handler(on_hello)} d = dispatcher.TypeDispatcher(type_handlers=handlers) - msg = mock_acked_message(ack_ok=False, - properties={'type': 'hello'}) + msg = mock_acked_message(ack_ok=False, properties={'type': 'hello'}) d.on_message("", msg) self.assertTrue(msg.ack_log_error.called) self.assertFalse(msg.acknowledged) diff --git a/taskflow/tests/unit/worker_based/test_endpoint.py b/taskflow/tests/unit/worker_based/test_endpoint.py index 11fc89952..d72f73bed 100644 --- a/taskflow/tests/unit/worker_based/test_endpoint.py +++ b/taskflow/tests/unit/worker_based/test_endpoint.py @@ -21,7 +21,6 @@ from taskflow.tests import utils class Task(task.Task): - def __init__(self, a, *args, **kwargs): super().__init__(*args, **kwargs) @@ -30,7 +29,6 @@ class Task(task.Task): class TestEndpoint(test.TestCase): - def setUp(self): super().setUp() self.task_cls = utils.TaskOneReturn @@ -64,18 +62,22 @@ class TestEndpoint(test.TestCase): def test_execute(self): task = self.task_ep.generate(self.task_cls_name) - result = self.task_ep.execute(task, - task_uuid=self.task_uuid, - arguments=self.task_args, - progress_callback=None) + result = self.task_ep.execute( + task, + task_uuid=self.task_uuid, + arguments=self.task_args, + progress_callback=None, + ) self.assertEqual(self.task_result, result) def test_revert(self): task = self.task_ep.generate(self.task_cls_name) - result = self.task_ep.revert(task, - task_uuid=self.task_uuid, - arguments=self.task_args, - progress_callback=None, - result=self.task_result, - failures={}) + result = self.task_ep.revert( + task, + task_uuid=self.task_uuid, + arguments=self.task_args, + 
progress_callback=None, + result=self.task_result, + failures={}, + ) self.assertIsNone(result) diff --git a/taskflow/tests/unit/worker_based/test_executor.py b/taskflow/tests/unit/worker_based/test_executor.py index fa53cb652..83cf2cb97 100644 --- a/taskflow/tests/unit/worker_based/test_executor.py +++ b/taskflow/tests/unit/worker_based/test_executor.py @@ -25,7 +25,6 @@ from taskflow.types import failure class TestWorkerTaskExecutor(test.MockTestCase): - def setUp(self): super().setUp() self.task = test_utils.DummyTask() @@ -42,9 +41,11 @@ class TestWorkerTaskExecutor(test.MockTestCase): # patch classes self.proxy_mock, self.proxy_inst_mock = self.patchClass( - executor.proxy, 'Proxy') + executor.proxy, 'Proxy' + ) self.request_mock, self.request_inst_mock = self.patchClass( - executor.pr, 'Request', autospec=False) + executor.pr, 'Request', autospec=False + ) # other mocking self.proxy_inst_mock.start.side_effect = self._fake_proxy_start @@ -54,8 +55,10 @@ class TestWorkerTaskExecutor(test.MockTestCase): self.request_inst_mock.created_on = 0 self.request_inst_mock.task_cls = self.task.name self.message_mock = mock.MagicMock(name='message') - self.message_mock.properties = {'correlation_id': self.task_uuid, - 'type': pr.RESPONSE} + self.message_mock.properties = { + 'correlation_id': self.task_uuid, + 'type': pr.RESPONSE, + } def _fake_proxy_start(self): self.proxy_started_event.set() @@ -66,10 +69,12 @@ class TestWorkerTaskExecutor(test.MockTestCase): self.proxy_started_event.clear() def executor(self, reset_master_mock=True, **kwargs): - executor_kwargs = dict(uuid=self.executor_uuid, - exchange=self.executor_exchange, - topics=[self.executor_topic], - url=self.broker_url) + executor_kwargs = dict( + uuid=self.executor_uuid, + exchange=self.executor_exchange, + topics=[self.executor_topic], + url=self.broker_url, + ) executor_kwargs.update(kwargs) ex = executor.WorkerTaskExecutor(**executor_kwargs) if reset_master_mock: @@ -79,11 +84,15 @@ class 
TestWorkerTaskExecutor(test.MockTestCase): def test_creation(self): ex = self.executor(reset_master_mock=False) master_mock_calls = [ - mock.call.Proxy(self.executor_uuid, self.executor_exchange, - on_wait=ex._on_wait, - url=self.broker_url, transport=mock.ANY, - transport_options=mock.ANY, - retry_options=mock.ANY), + mock.call.Proxy( + self.executor_uuid, + self.executor_exchange, + on_wait=ex._on_wait, + url=self.broker_url, + transport=mock.ANY, + transport_options=mock.ANY, + retry_options=mock.ANY, + ), mock.call.proxy.dispatcher.type_handlers.update(mock.ANY), ] self.assertEqual(master_mock_calls, self.master_mock.mock_calls) @@ -100,16 +109,19 @@ class TestWorkerTaskExecutor(test.MockTestCase): self.assertEqual(expected_calls, self.request_inst_mock.mock_calls) def test_on_message_response_state_progress(self): - response = pr.Response(pr.EVENT, - event_type=task_atom.EVENT_UPDATE_PROGRESS, - details={'progress': 1.0}) + response = pr.Response( + pr.EVENT, + event_type=task_atom.EVENT_UPDATE_PROGRESS, + details={'progress': 1.0}, + ) ex = self.executor() ex._ongoing_requests[self.task_uuid] = self.request_inst_mock ex._process_response(response.to_dict(), self.message_mock) expected_calls = [ - mock.call.task.notifier.notify(task_atom.EVENT_UPDATE_PROGRESS, - {'progress': 1.0}), + mock.call.task.notifier.notify( + task_atom.EVENT_UPDATE_PROGRESS, {'progress': 1.0} + ), ] self.assertEqual(expected_calls, self.request_inst_mock.mock_calls) @@ -124,20 +136,21 @@ class TestWorkerTaskExecutor(test.MockTestCase): self.assertEqual(0, len(ex._ongoing_requests)) expected_calls = [ mock.call.transition_and_log_error(pr.FAILURE, logger=mock.ANY), - mock.call.set_result(result=test_utils.FailureMatcher(a_failure)) + mock.call.set_result(result=test_utils.FailureMatcher(a_failure)), ] self.assertEqual(expected_calls, self.request_inst_mock.mock_calls) def test_on_message_response_state_success(self): - response = pr.Response(pr.SUCCESS, result=self.task_result, - 
event='executed') + response = pr.Response( + pr.SUCCESS, result=self.task_result, event='executed' + ) ex = self.executor() ex._ongoing_requests[self.task_uuid] = self.request_inst_mock ex._process_response(response.to_dict(), self.message_mock) expected_calls = [ mock.call.transition_and_log_error(pr.SUCCESS, logger=mock.ANY), - mock.call.set_result(result=self.task_result) + mock.call.set_result(result=self.task_result), ] self.assertEqual(expected_calls, self.request_inst_mock.mock_calls) @@ -195,35 +208,57 @@ class TestWorkerTaskExecutor(test.MockTestCase): ex.execute_task(self.task, self.task_uuid, self.task_args) expected_calls = [ - mock.call.Request(self.task, self.task_uuid, 'execute', - self.task_args, timeout=self.timeout, - result=mock.ANY, failures=mock.ANY), - mock.call.request.transition_and_log_error(pr.PENDING, - logger=mock.ANY), - mock.call.proxy.publish(self.request_inst_mock, - self.executor_topic, - reply_to=self.executor_uuid, - correlation_id=self.task_uuid) + mock.call.Request( + self.task, + self.task_uuid, + 'execute', + self.task_args, + timeout=self.timeout, + result=mock.ANY, + failures=mock.ANY, + ), + mock.call.request.transition_and_log_error( + pr.PENDING, logger=mock.ANY + ), + mock.call.proxy.publish( + self.request_inst_mock, + self.executor_topic, + reply_to=self.executor_uuid, + correlation_id=self.task_uuid, + ), ] self.assertEqual(expected_calls, self.master_mock.mock_calls) def test_revert_task(self): ex = self.executor() ex._finder._add(self.executor_topic, [self.task.name]) - ex.revert_task(self.task, self.task_uuid, self.task_args, - self.task_result, self.task_failures) + ex.revert_task( + self.task, + self.task_uuid, + self.task_args, + self.task_result, + self.task_failures, + ) expected_calls = [ - mock.call.Request(self.task, self.task_uuid, 'revert', - self.task_args, timeout=self.timeout, - failures=self.task_failures, - result=self.task_result), - mock.call.request.transition_and_log_error(pr.PENDING, - 
logger=mock.ANY), - mock.call.proxy.publish(self.request_inst_mock, - self.executor_topic, - reply_to=self.executor_uuid, - correlation_id=self.task_uuid) + mock.call.Request( + self.task, + self.task_uuid, + 'revert', + self.task_args, + timeout=self.timeout, + failures=self.task_failures, + result=self.task_result, + ), + mock.call.request.transition_and_log_error( + pr.PENDING, logger=mock.ANY + ), + mock.call.proxy.publish( + self.request_inst_mock, + self.executor_topic, + reply_to=self.executor_uuid, + correlation_id=self.task_uuid, + ), ] self.assertEqual(expected_calls, self.master_mock.mock_calls) @@ -232,9 +267,15 @@ class TestWorkerTaskExecutor(test.MockTestCase): ex.execute_task(self.task, self.task_uuid, self.task_args) expected_calls = [ - mock.call.Request(self.task, self.task_uuid, 'execute', - self.task_args, timeout=self.timeout, - result=mock.ANY, failures=mock.ANY), + mock.call.Request( + self.task, + self.task_uuid, + 'execute', + self.task_args, + timeout=self.timeout, + result=mock.ANY, + failures=mock.ANY, + ), ] self.assertEqual(expected_calls, self.master_mock.mock_calls) @@ -245,18 +286,28 @@ class TestWorkerTaskExecutor(test.MockTestCase): ex.execute_task(self.task, self.task_uuid, self.task_args) expected_calls = [ - mock.call.Request(self.task, self.task_uuid, 'execute', - self.task_args, timeout=self.timeout, - result=mock.ANY, failures=mock.ANY), - mock.call.request.transition_and_log_error(pr.PENDING, - logger=mock.ANY), - mock.call.proxy.publish(self.request_inst_mock, - self.executor_topic, - reply_to=self.executor_uuid, - correlation_id=self.task_uuid), - mock.call.request.transition_and_log_error(pr.FAILURE, - logger=mock.ANY), - mock.call.request.set_result(mock.ANY) + mock.call.Request( + self.task, + self.task_uuid, + 'execute', + self.task_args, + timeout=self.timeout, + result=mock.ANY, + failures=mock.ANY, + ), + mock.call.request.transition_and_log_error( + pr.PENDING, logger=mock.ANY + ), + mock.call.proxy.publish( + 
self.request_inst_mock, + self.executor_topic, + reply_to=self.executor_uuid, + correlation_id=self.task_uuid, + ), + mock.call.request.transition_and_log_error( + pr.FAILURE, logger=mock.ANY + ), + mock.call.request.set_result(mock.ANY), ] self.assertEqual(expected_calls, self.master_mock.mock_calls) @@ -270,11 +321,14 @@ class TestWorkerTaskExecutor(test.MockTestCase): # stop executor ex.stop() - self.master_mock.assert_has_calls([ - mock.call.proxy.start(), - mock.call.proxy.wait(), - mock.call.proxy.stop() - ], any_order=True) + self.master_mock.assert_has_calls( + [ + mock.call.proxy.start(), + mock.call.proxy.wait(), + mock.call.proxy.stop(), + ], + any_order=True, + ) def test_start_already_running(self): ex = self.executor() @@ -289,11 +343,14 @@ class TestWorkerTaskExecutor(test.MockTestCase): # stop executor ex.stop() - self.master_mock.assert_has_calls([ - mock.call.proxy.start(), - mock.call.proxy.wait(), - mock.call.proxy.stop() - ], any_order=True) + self.master_mock.assert_has_calls( + [ + mock.call.proxy.start(), + mock.call.proxy.wait(), + mock.call.proxy.stop(), + ], + any_order=True, + ) def test_stop_not_running(self): self.executor().stop() @@ -311,10 +368,9 @@ class TestWorkerTaskExecutor(test.MockTestCase): ex.stop() # since proxy thread is already done - stop is not called - self.master_mock.assert_has_calls([ - mock.call.proxy.start(), - mock.call.proxy.wait() - ], any_order=True) + self.master_mock.assert_has_calls( + [mock.call.proxy.start(), mock.call.proxy.wait()], any_order=True + ) def test_restart(self): ex = self.executor() @@ -333,11 +389,14 @@ class TestWorkerTaskExecutor(test.MockTestCase): # stop executor ex.stop() - self.master_mock.assert_has_calls([ - mock.call.proxy.start(), - mock.call.proxy.wait(), - mock.call.proxy.stop(), - mock.call.proxy.start(), - mock.call.proxy.wait(), - mock.call.proxy.stop() - ], any_order=True) + self.master_mock.assert_has_calls( + [ + mock.call.proxy.start(), + mock.call.proxy.wait(), + 
mock.call.proxy.stop(), + mock.call.proxy.start(), + mock.call.proxy.wait(), + mock.call.proxy.stop(), + ], + any_order=True, + ) diff --git a/taskflow/tests/unit/worker_based/test_message_pump.py b/taskflow/tests/unit/worker_based/test_message_pump.py index 12e29fba1..988f7f975 100644 --- a/taskflow/tests/unit/worker_based/test_message_pump.py +++ b/taskflow/tests/unit/worker_based/test_message_pump.py @@ -37,11 +37,15 @@ class TestMessagePump(test.TestCase): on_notify.side_effect = lambda *args, **kwargs: barrier.set() handlers = {pr.NOTIFY: dispatcher.Handler(on_notify)} - p = proxy.Proxy(TEST_TOPIC, TEST_EXCHANGE, handlers, - transport='memory', - transport_options={ - 'polling_interval': POLLING_INTERVAL, - }) + p = proxy.Proxy( + TEST_TOPIC, + TEST_EXCHANGE, + handlers, + transport='memory', + transport_options={ + 'polling_interval': POLLING_INTERVAL, + }, + ) t = threading_utils.daemon_thread(p.start) t.start() @@ -62,11 +66,15 @@ class TestMessagePump(test.TestCase): on_response.side_effect = lambda *args, **kwargs: barrier.set() handlers = {pr.RESPONSE: dispatcher.Handler(on_response)} - p = proxy.Proxy(TEST_TOPIC, TEST_EXCHANGE, handlers, - transport='memory', - transport_options={ - 'polling_interval': POLLING_INTERVAL, - }) + p = proxy.Proxy( + TEST_TOPIC, + TEST_EXCHANGE, + handlers, + transport='memory', + transport_options={ + 'polling_interval': POLLING_INTERVAL, + }, + ) t = threading_utils.daemon_thread(p.start) t.start() @@ -101,11 +109,15 @@ class TestMessagePump(test.TestCase): pr.RESPONSE: dispatcher.Handler(on_response), pr.REQUEST: dispatcher.Handler(on_request), } - p = proxy.Proxy(TEST_TOPIC, TEST_EXCHANGE, handlers, - transport='memory', - transport_options={ - 'polling_interval': POLLING_INTERVAL, - }) + p = proxy.Proxy( + TEST_TOPIC, + TEST_EXCHANGE, + handlers, + transport='memory', + transport_options={ + 'polling_interval': POLLING_INTERVAL, + }, + ) t = threading_utils.daemon_thread(p.start) t.start() @@ -118,9 +130,16 @@ class 
TestMessagePump(test.TestCase): elif j == 1: p.publish(pr.Response(pr.RUNNING), TEST_TOPIC) else: - p.publish(pr.Request(test_utils.DummyTask("dummy_%s" % i), - uuidutils.generate_uuid(), - pr.EXECUTE, [], None), TEST_TOPIC) + p.publish( + pr.Request( + test_utils.DummyTask("dummy_%s" % i), + uuidutils.generate_uuid(), + pr.EXECUTE, + [], + None, + ), + TEST_TOPIC, + ) self.assertTrue(barrier.wait(test_utils.WAIT_TIMEOUT)) self.assertEqual(0, barrier.needed) @@ -135,9 +154,11 @@ class TestMessagePump(test.TestCase): self.assertEqual(10, on_response.call_count) self.assertEqual(10, on_request.call_count) - call_count = sum([ - on_notify.call_count, - on_response.call_count, - on_request.call_count, - ]) + call_count = sum( + [ + on_notify.call_count, + on_response.call_count, + on_request.call_count, + ] + ) self.assertEqual(message_count, call_count) diff --git a/taskflow/tests/unit/worker_based/test_pipeline.py b/taskflow/tests/unit/worker_based/test_pipeline.py index 167809ee6..f8c34e24a 100644 --- a/taskflow/tests/unit/worker_based/test_pipeline.py +++ b/taskflow/tests/unit/worker_based/test_pipeline.py @@ -37,12 +37,15 @@ class TestPipeline(test.TestCase): for cls in task_classes: endpoints.append(endpoint.Endpoint(cls)) server = worker_server.Server( - TEST_TOPIC, TEST_EXCHANGE, - futurist.ThreadPoolExecutor(max_workers=1), endpoints, + TEST_TOPIC, + TEST_EXCHANGE, + futurist.ThreadPoolExecutor(max_workers=1), + endpoints, transport='memory', transport_options={ 'polling_interval': POLLING_INTERVAL, - }) + }, + ) server_thread = threading_utils.daemon_thread(server.start) return (server, server_thread) @@ -54,7 +57,8 @@ class TestPipeline(test.TestCase): transport='memory', transport_options={ 'polling_interval': POLLING_INTERVAL, - }) + }, + ) return executor def _start_components(self, task_classes): @@ -74,8 +78,12 @@ class TestPipeline(test.TestCase): t = test_utils.TaskOneReturn() progress_callback = lambda *args, **kwargs: None - f = 
executor.execute_task(t, uuidutils.generate_uuid(), {}, - progress_callback=progress_callback) + f = executor.execute_task( + t, + uuidutils.generate_uuid(), + {}, + progress_callback=progress_callback, + ) waiters.wait_for_any([f]) event, result = f.result() @@ -90,8 +98,12 @@ class TestPipeline(test.TestCase): t = test_utils.TaskWithFailure() progress_callback = lambda *args, **kwargs: None - f = executor.execute_task(t, uuidutils.generate_uuid(), {}, - progress_callback=progress_callback) + f = executor.execute_task( + t, + uuidutils.generate_uuid(), + {}, + progress_callback=progress_callback, + ) waiters.wait_for_any([f]) action, result = f.result() diff --git a/taskflow/tests/unit/worker_based/test_protocol.py b/taskflow/tests/unit/worker_based/test_protocol.py index bec76ab3f..aca8a384c 100644 --- a/taskflow/tests/unit/worker_based/test_protocol.py +++ b/taskflow/tests/unit/worker_based/test_protocol.py @@ -36,8 +36,7 @@ class TestProtocolValidation(test.TestCase): msg = { 'all your base': 'are belong to us', } - self.assertRaises(excp.InvalidFormat, - pr.Notify.validate, msg, False) + self.assertRaises(excp.InvalidFormat, pr.Notify.validate, msg, False) def test_reply_notify(self): msg = pr.Notify(topic="bob", tasks=['a', 'b', 'c']) @@ -48,13 +47,16 @@ class TestProtocolValidation(test.TestCase): 'topic': {}, 'tasks': 'not yours', } - self.assertRaises(excp.InvalidFormat, - pr.Notify.validate, msg, True) + self.assertRaises(excp.InvalidFormat, pr.Notify.validate, msg, True) def test_request(self): - request = pr.Request(utils.DummyTask("hi"), - uuidutils.generate_uuid(), - pr.EXECUTE, {}, 1.0) + request = pr.Request( + utils.DummyTask("hi"), + uuidutils.generate_uuid(), + pr.EXECUTE, + {}, + 1.0, + ) pr.Request.validate(request.to_dict()) def test_request_invalid(self): @@ -66,16 +68,21 @@ class TestProtocolValidation(test.TestCase): self.assertRaises(excp.InvalidFormat, pr.Request.validate, msg) def test_request_invalid_action(self): - request = 
pr.Request(utils.DummyTask("hi"), - uuidutils.generate_uuid(), - pr.EXECUTE, {}, 1.0) + request = pr.Request( + utils.DummyTask("hi"), + uuidutils.generate_uuid(), + pr.EXECUTE, + {}, + 1.0, + ) request = request.to_dict() request['action'] = 'NOTHING' self.assertRaises(excp.InvalidFormat, pr.Request.validate, request) def test_response_progress(self): - msg = pr.Response(pr.EVENT, details={'progress': 0.5}, - event_type='blah') + msg = pr.Response( + pr.EVENT, details={'progress': 0.5}, event_type='blah' + ) pr.Response.validate(msg.to_dict()) def test_response_completion(self): @@ -83,9 +90,9 @@ class TestProtocolValidation(test.TestCase): pr.Response.validate(msg.to_dict()) def test_response_mixed_invalid(self): - msg = pr.Response(pr.EVENT, - details={'progress': 0.5}, - event_type='blah', result=1) + msg = pr.Response( + pr.EVENT, details={'progress': 0.5}, event_type='blah', result=1 + ) self.assertRaises(excp.InvalidFormat, pr.Response.validate, msg) def test_response_bad_state(self): @@ -94,7 +101,6 @@ class TestProtocolValidation(test.TestCase): class TestProtocol(test.TestCase): - def setUp(self): super().setUp() self.task = utils.DummyTask() @@ -104,20 +110,24 @@ class TestProtocol(test.TestCase): self.timeout = 60 def request(self, **kwargs): - request_kwargs = dict(task=self.task, - uuid=self.task_uuid, - action=self.task_action, - arguments=self.task_args, - timeout=self.timeout) + request_kwargs = dict( + task=self.task, + uuid=self.task_uuid, + action=self.task_action, + arguments=self.task_args, + timeout=self.timeout, + ) request_kwargs.update(kwargs) return pr.Request(**request_kwargs) def request_to_dict(self, **kwargs): - to_dict = dict(task_cls=self.task.name, - task_name=self.task.name, - task_version=self.task.version, - action=self.task_action, - arguments=self.task_args) + to_dict = dict( + task_cls=self.task.name, + task_name=self.task.name, + task_version=self.task.version, + action=self.task_action, + arguments=self.task_args, + ) 
to_dict.update(kwargs) return to_dict @@ -145,18 +155,21 @@ class TestProtocol(test.TestCase): def test_to_dict_with_result(self): request = self.request(result=333) - self.assertEqual(self.request_to_dict(result=('success', 333)), - request.to_dict()) + self.assertEqual( + self.request_to_dict(result=('success', 333)), request.to_dict() + ) def test_to_dict_with_result_none(self): request = self.request(result=None) - self.assertEqual(self.request_to_dict(result=('success', None)), - request.to_dict()) + self.assertEqual( + self.request_to_dict(result=('success', None)), request.to_dict() + ) def test_to_dict_with_result_failure(self): a_failure = failure.Failure.from_exception(RuntimeError('Woot!')) - expected = self.request_to_dict(result=('failure', - a_failure.to_dict())) + expected = self.request_to_dict( + result=('failure', a_failure.to_dict()) + ) request = self.request(result=a_failure) self.assertEqual(expected, request.to_dict()) @@ -164,7 +177,8 @@ class TestProtocol(test.TestCase): a_failure = failure.Failure.from_exception(RuntimeError('Woot!')) request = self.request(failures={self.task.name: a_failure}) expected = self.request_to_dict( - failures={self.task.name: a_failure.to_dict()}) + failures={self.task.name: a_failure.to_dict()} + ) self.assertEqual(expected, request.to_dict()) def test_to_dict_with_invalid_json_failures(self): @@ -172,7 +186,8 @@ class TestProtocol(test.TestCase): a_failure = failure.Failure.from_exception(exc) request = self.request(failures={self.task.name: a_failure}) expected = self.request_to_dict( - failures={self.task.name: a_failure.to_dict(include_args=False)}) + failures={self.task.name: a_failure.to_dict(include_args=False)} + ) self.assertEqual(expected, request.to_dict()) @mock.patch('oslo_utils.timeutils.now') diff --git a/taskflow/tests/unit/worker_based/test_proxy.py b/taskflow/tests/unit/worker_based/test_proxy.py index 1d77d22be..07b57b604 100644 --- a/taskflow/tests/unit/worker_based/test_proxy.py +++ 
b/taskflow/tests/unit/worker_based/test_proxy.py @@ -21,7 +21,6 @@ from taskflow.utils import threading_utils class TestProxy(test.MockTestCase): - def setUp(self): super().setUp() self.topic = 'test-topic' @@ -32,27 +31,36 @@ class TestProxy(test.MockTestCase): # patch classes self.conn_mock, self.conn_inst_mock = self.patchClass( - proxy.kombu, 'Connection') + proxy.kombu, 'Connection' + ) self.exchange_mock, self.exchange_inst_mock = self.patchClass( - proxy.kombu, 'Exchange') + proxy.kombu, 'Exchange' + ) self.queue_mock, self.queue_inst_mock = self.patchClass( - proxy.kombu, 'Queue') + proxy.kombu, 'Queue' + ) self.producer_mock, self.producer_inst_mock = self.patchClass( - proxy.kombu, 'Producer') + proxy.kombu, 'Producer' + ) # connection mocking def _ensure(obj, func, *args, **kwargs): return func + self.conn_inst_mock.drain_events.side_effect = [ - socket.timeout, socket.timeout, KeyboardInterrupt] + socket.timeout, + socket.timeout, + KeyboardInterrupt, + ] self.conn_inst_mock.ensure = mock.MagicMock(side_effect=_ensure) # connections mocking self.connections_mock = self.patch( "taskflow.engines.worker_based.proxy.kombu.connections", - attach_as='connections') - self.connections_mock.__getitem__().acquire().__enter__.return_value =\ - self.conn_inst_mock + attach_as='connections', + ) + acquire_mock = self.connections_mock.__getitem__().acquire() + acquire_mock.__enter__.return_value = self.conn_inst_mock # producers mocking self.conn_inst_mock.Producer.return_value.__enter__ = mock.MagicMock() @@ -73,53 +81,76 @@ class TestProxy(test.MockTestCase): return f"{self.exchange}_{topic}" def proxy_start_calls(self, calls, exc_type=mock.ANY): - return [ - mock.call.Queue(name=self._queue_name(self.topic), - exchange=self.exchange_inst_mock, - routing_key=self.topic, - durable=False, - auto_delete=True, - channel=self.conn_inst_mock), - mock.call.connection.Consumer(queues=self.queue_inst_mock, - callbacks=[mock.ANY]), - 
mock.call.connection.Consumer().__enter__(), - mock.call.connection.ensure(mock.ANY, mock.ANY, - interval_start=mock.ANY, - interval_max=mock.ANY, - max_retries=mock.ANY, - interval_step=mock.ANY, - errback=mock.ANY), - ] + calls + [ - mock.call.connection.Consumer().__exit__(exc_type, mock.ANY, - mock.ANY) - ] + return ( + [ + mock.call.Queue( + name=self._queue_name(self.topic), + exchange=self.exchange_inst_mock, + routing_key=self.topic, + durable=False, + auto_delete=True, + channel=self.conn_inst_mock, + ), + mock.call.connection.Consumer( + queues=self.queue_inst_mock, callbacks=[mock.ANY] + ), + mock.call.connection.Consumer().__enter__(), + mock.call.connection.ensure( + mock.ANY, + mock.ANY, + interval_start=mock.ANY, + interval_max=mock.ANY, + max_retries=mock.ANY, + interval_step=mock.ANY, + errback=mock.ANY, + ), + ] + + calls + + [ + mock.call.connection.Consumer().__exit__( + exc_type, mock.ANY, mock.ANY + ) + ] + ) def proxy_publish_calls(self, calls, routing_key, exc_type=mock.ANY): - return [ - mock.call.connection.Producer(), - mock.call.connection.Producer().__enter__(), - mock.call.connection.ensure(mock.ANY, mock.ANY, - interval_start=mock.ANY, - interval_max=mock.ANY, - max_retries=mock.ANY, - interval_step=mock.ANY, - errback=mock.ANY), - mock.call.Queue(name=self._queue_name(routing_key), - routing_key=routing_key, - exchange=self.exchange_inst_mock, - durable=False, - auto_delete=True, - channel=None), - ] + calls + [ - mock.call.connection.Producer().__exit__(exc_type, mock.ANY, - mock.ANY) - ] + return ( + [ + mock.call.connection.Producer(), + mock.call.connection.Producer().__enter__(), + mock.call.connection.ensure( + mock.ANY, + mock.ANY, + interval_start=mock.ANY, + interval_max=mock.ANY, + max_retries=mock.ANY, + interval_step=mock.ANY, + errback=mock.ANY, + ), + mock.call.Queue( + name=self._queue_name(routing_key), + routing_key=routing_key, + exchange=self.exchange_inst_mock, + durable=False, + auto_delete=True, + channel=None, 
+ ), + ] + + calls + + [ + mock.call.connection.Producer().__exit__( + exc_type, mock.ANY, mock.ANY + ) + ] + ) def proxy(self, reset_master_mock=False, **kwargs): - proxy_kwargs = dict(topic=self.topic, - exchange=self.exchange, - url=self.broker_url, - type_handlers={}) + proxy_kwargs = dict( + topic=self.topic, + exchange=self.exchange, + url=self.broker_url, + type_handlers={}, + ) proxy_kwargs.update(kwargs) p = proxy.Proxy(**proxy_kwargs) if reset_master_mock: @@ -130,11 +161,12 @@ class TestProxy(test.MockTestCase): self.proxy() master_mock_calls = [ - mock.call.Connection(self.broker_url, transport=None, - transport_options=None), - mock.call.Exchange(name=self.exchange, - durable=False, - auto_delete=True) + mock.call.Connection( + self.broker_url, transport=None, transport_options=None + ), + mock.call.Exchange( + name=self.exchange, durable=False, auto_delete=True + ), ] self.assertEqual(master_mock_calls, self.master_mock.mock_calls) @@ -143,11 +175,14 @@ class TestProxy(test.MockTestCase): self.proxy(transport='memory', transport_options=transport_opts) master_mock_calls = [ - mock.call.Connection(self.broker_url, transport='memory', - transport_options=transport_opts), - mock.call.Exchange(name=self.exchange, - durable=False, - auto_delete=True) + mock.call.Connection( + self.broker_url, + transport='memory', + transport_options=transport_opts, + ), + mock.call.Exchange( + name=self.exchange, durable=False, auto_delete=True + ), ] self.assertEqual(master_mock_calls, self.master_mock.mock_calls) @@ -162,15 +197,20 @@ class TestProxy(test.MockTestCase): p.publish(msg_mock, routing_key, correlation_id=task_uuid) mock_producer = mock.call.connection.Producer() - master_mock_calls = self.proxy_publish_calls([ - mock_producer.__enter__().publish(body=msg_data, - routing_key=routing_key, - exchange=self.exchange_inst_mock, - correlation_id=task_uuid, - declare=[self.queue_inst_mock], - type=msg_mock.TYPE, - reply_to=None) - ], routing_key) + 
master_mock_calls = self.proxy_publish_calls( + [ + mock_producer.__enter__().publish( + body=msg_data, + routing_key=routing_key, + exchange=self.exchange_inst_mock, + correlation_id=task_uuid, + declare=[self.queue_inst_mock], + type=msg_mock.TYPE, + reply_to=None, + ) + ], + routing_key, + ) self.master_mock.assert_has_calls(master_mock_calls) def test_start(self): @@ -180,43 +220,54 @@ class TestProxy(test.MockTestCase): except KeyboardInterrupt: pass - master_calls = self.proxy_start_calls([ - mock.call.connection.drain_events(timeout=self.de_period), - mock.call.connection.drain_events(timeout=self.de_period), - mock.call.connection.drain_events(timeout=self.de_period), - ], exc_type=KeyboardInterrupt) + master_calls = self.proxy_start_calls( + [ + mock.call.connection.drain_events(timeout=self.de_period), + mock.call.connection.drain_events(timeout=self.de_period), + mock.call.connection.drain_events(timeout=self.de_period), + ], + exc_type=KeyboardInterrupt, + ) self.master_mock.assert_has_calls(master_calls) def test_start_with_on_wait(self): try: # KeyboardInterrupt will be raised after two iterations - self.proxy(reset_master_mock=True, - on_wait=self.on_wait_mock).start() + self.proxy( + reset_master_mock=True, on_wait=self.on_wait_mock + ).start() except KeyboardInterrupt: pass - master_calls = self.proxy_start_calls([ - mock.call.connection.drain_events(timeout=self.de_period), - mock.call.on_wait(), - mock.call.connection.drain_events(timeout=self.de_period), - mock.call.on_wait(), - mock.call.connection.drain_events(timeout=self.de_period), - ], exc_type=KeyboardInterrupt) + master_calls = self.proxy_start_calls( + [ + mock.call.connection.drain_events(timeout=self.de_period), + mock.call.on_wait(), + mock.call.connection.drain_events(timeout=self.de_period), + mock.call.on_wait(), + mock.call.connection.drain_events(timeout=self.de_period), + ], + exc_type=KeyboardInterrupt, + ) self.master_mock.assert_has_calls(master_calls) def 
test_start_with_on_wait_raises(self): self.on_wait_mock.side_effect = RuntimeError('Woot!') try: # KeyboardInterrupt will be raised after two iterations - self.proxy(reset_master_mock=True, - on_wait=self.on_wait_mock).start() + self.proxy( + reset_master_mock=True, on_wait=self.on_wait_mock + ).start() except KeyboardInterrupt: pass - master_calls = self.proxy_start_calls([ - mock.call.connection.drain_events(timeout=self.de_period), - mock.call.on_wait(), - ], exc_type=RuntimeError) + master_calls = self.proxy_start_calls( + [ + mock.call.connection.drain_events(timeout=self.de_period), + mock.call.on_wait(), + ], + exc_type=RuntimeError, + ) self.master_mock.assert_has_calls(master_calls) def test_stop(self): diff --git a/taskflow/tests/unit/worker_based/test_server.py b/taskflow/tests/unit/worker_based/test_server.py index 892dda101..6831207a7 100644 --- a/taskflow/tests/unit/worker_based/test_server.py +++ b/taskflow/tests/unit/worker_based/test_server.py @@ -23,7 +23,6 @@ from taskflow.types import failure class TestServer(test.MockTestCase): - def setUp(self): super().setUp() self.server_topic = 'server-topic' @@ -34,32 +33,40 @@ class TestServer(test.MockTestCase): self.task_args = {'x': 1} self.task_action = 'execute' self.reply_to = 'reply-to' - self.endpoints = [ep.Endpoint(task_cls=utils.TaskOneArgOneReturn), - ep.Endpoint(task_cls=utils.TaskWithFailure), - ep.Endpoint(task_cls=utils.ProgressingTask)] + self.endpoints = [ + ep.Endpoint(task_cls=utils.TaskOneArgOneReturn), + ep.Endpoint(task_cls=utils.TaskWithFailure), + ep.Endpoint(task_cls=utils.ProgressingTask), + ] # patch classes self.proxy_mock, self.proxy_inst_mock = self.patchClass( - server.proxy, 'Proxy') + server.proxy, 'Proxy' + ) self.response_mock, self.response_inst_mock = self.patchClass( - server.pr, 'Response') + server.pr, 'Response' + ) # other mocking self.proxy_inst_mock.is_running = True self.executor_mock = mock.MagicMock(name='executor') self.message_mock = 
mock.MagicMock(name='message') - self.message_mock.properties = {'correlation_id': self.task_uuid, - 'reply_to': self.reply_to, - 'type': pr.REQUEST} + self.message_mock.properties = { + 'correlation_id': self.task_uuid, + 'reply_to': self.reply_to, + 'type': pr.REQUEST, + } self.master_mock.attach_mock(self.executor_mock, 'executor') self.master_mock.attach_mock(self.message_mock, 'message') def server(self, reset_master_mock=False, **kwargs): - server_kwargs = dict(topic=self.server_topic, - exchange=self.server_exchange, - executor=self.executor_mock, - endpoints=self.endpoints, - url=self.broker_url) + server_kwargs = dict( + topic=self.server_topic, + exchange=self.server_exchange, + executor=self.executor_mock, + endpoints=self.endpoints, + url=self.broker_url, + ) server_kwargs.update(kwargs) s = server.Server(**server_kwargs) if reset_master_mock: @@ -67,11 +74,13 @@ class TestServer(test.MockTestCase): return s def make_request(self, **kwargs): - request_kwargs = dict(task=self.task, - uuid=self.task_uuid, - action=self.task_action, - arguments=self.task_args, - timeout=60) + request_kwargs = dict( + task=self.task, + uuid=self.task_uuid, + action=self.task_action, + arguments=self.task_args, + timeout=60, + ) request_kwargs.update(kwargs) request = pr.Request(**request_kwargs) return request.to_dict() @@ -81,10 +90,15 @@ class TestServer(test.MockTestCase): # check calls master_mock_calls = [ - mock.call.Proxy(self.server_topic, self.server_exchange, - type_handlers=mock.ANY, url=self.broker_url, - transport=mock.ANY, transport_options=mock.ANY, - retry_options=mock.ANY) + mock.call.Proxy( + self.server_topic, + self.server_exchange, + type_handlers=mock.ANY, + url=self.broker_url, + transport=mock.ANY, + transport_options=mock.ANY, + retry_options=mock.ANY, + ) ] self.master_mock.assert_has_calls(master_mock_calls) self.assertEqual(3, len(s._endpoints)) @@ -94,10 +108,15 @@ class TestServer(test.MockTestCase): # check calls master_mock_calls = [ - 
mock.call.Proxy(self.server_topic, self.server_exchange, - type_handlers=mock.ANY, url=self.broker_url, - transport=mock.ANY, transport_options=mock.ANY, - retry_options=mock.ANY) + mock.call.Proxy( + self.server_topic, + self.server_exchange, + type_handlers=mock.ANY, + url=self.broker_url, + transport=mock.ANY, + transport_options=mock.ANY, + retry_options=mock.ANY, + ) ] self.master_mock.assert_has_calls(master_mock_calls) self.assertEqual(len(self.endpoints), len(s._endpoints)) @@ -106,41 +125,70 @@ class TestServer(test.MockTestCase): request = self.make_request() bundle = pr.Request.from_dict(request) task_cls, task_name, action, task_args = bundle - self.assertEqual((self.task.name, self.task.name, self.task_action, - dict(arguments=self.task_args)), - (task_cls, task_name, action, task_args)) + self.assertEqual( + ( + self.task.name, + self.task.name, + self.task_action, + dict(arguments=self.task_args), + ), + (task_cls, task_name, action, task_args), + ) def test_parse_request_with_success_result(self): request = self.make_request(action='revert', result=1) bundle = pr.Request.from_dict(request) task_cls, task_name, action, task_args = bundle - self.assertEqual((self.task.name, self.task.name, 'revert', - dict(arguments=self.task_args, - result=1)), - (task_cls, task_name, action, task_args)) + self.assertEqual( + ( + self.task.name, + self.task.name, + 'revert', + dict(arguments=self.task_args, result=1), + ), + (task_cls, task_name, action, task_args), + ) def test_parse_request_with_failure_result(self): a_failure = failure.Failure.from_exception(Exception('test')) request = self.make_request(action='revert', result=a_failure) bundle = pr.Request.from_dict(request) task_cls, task_name, action, task_args = bundle - self.assertEqual((self.task.name, self.task.name, 'revert', - dict(arguments=self.task_args, - result=utils.FailureMatcher(a_failure))), - (task_cls, task_name, action, task_args)) + self.assertEqual( + ( + self.task.name, + self.task.name, + 
'revert', + dict( + arguments=self.task_args, + result=utils.FailureMatcher(a_failure), + ), + ), + (task_cls, task_name, action, task_args), + ) def test_parse_request_with_failures(self): - failures = {'0': failure.Failure.from_exception(Exception('test1')), - '1': failure.Failure.from_exception(Exception('test2'))} + failures = { + '0': failure.Failure.from_exception(Exception('test1')), + '1': failure.Failure.from_exception(Exception('test2')), + } request = self.make_request(action='revert', failures=failures) bundle = pr.Request.from_dict(request) task_cls, task_name, action, task_args = bundle self.assertEqual( - (self.task.name, self.task.name, 'revert', - dict(arguments=self.task_args, - failures={i: utils.FailureMatcher(f) - for i, f in failures.items()})), - (task_cls, task_name, action, task_args)) + ( + self.task.name, + self.task.name, + 'revert', + dict( + arguments=self.task_args, + failures={ + i: utils.FailureMatcher(f) for i, f in failures.items() + }, + ), + ), + (task_cls, task_name, action, task_args), + ) @mock.patch("taskflow.engines.worker_based.server.LOG.critical") def test_reply_publish_failure(self, mocked_exception): @@ -150,11 +198,16 @@ class TestServer(test.MockTestCase): s = self.server(reset_master_mock=True) s._reply(True, self.reply_to, self.task_uuid) - self.master_mock.assert_has_calls([ - mock.call.Response(pr.FAILURE), - mock.call.proxy.publish(self.response_inst_mock, self.reply_to, - correlation_id=self.task_uuid) - ]) + self.master_mock.assert_has_calls( + [ + mock.call.Response(pr.FAILURE), + mock.call.proxy.publish( + self.response_inst_mock, + self.reply_to, + correlation_id=self.task_uuid, + ), + ] + ) self.assertTrue(mocked_exception.called) def test_on_run_reply_failure(self): @@ -177,19 +230,37 @@ class TestServer(test.MockTestCase): # check calls master_mock_calls = [ mock.call.Response(pr.RUNNING), - mock.call.proxy.publish(self.response_inst_mock, self.reply_to, - correlation_id=self.task_uuid), - 
mock.call.Response(pr.EVENT, details={'progress': 0.0}, - event_type=task_atom.EVENT_UPDATE_PROGRESS), - mock.call.proxy.publish(self.response_inst_mock, self.reply_to, - correlation_id=self.task_uuid), - mock.call.Response(pr.EVENT, details={'progress': 1.0}, - event_type=task_atom.EVENT_UPDATE_PROGRESS), - mock.call.proxy.publish(self.response_inst_mock, self.reply_to, - correlation_id=self.task_uuid), + mock.call.proxy.publish( + self.response_inst_mock, + self.reply_to, + correlation_id=self.task_uuid, + ), + mock.call.Response( + pr.EVENT, + details={'progress': 0.0}, + event_type=task_atom.EVENT_UPDATE_PROGRESS, + ), + mock.call.proxy.publish( + self.response_inst_mock, + self.reply_to, + correlation_id=self.task_uuid, + ), + mock.call.Response( + pr.EVENT, + details={'progress': 1.0}, + event_type=task_atom.EVENT_UPDATE_PROGRESS, + ), + mock.call.proxy.publish( + self.response_inst_mock, + self.reply_to, + correlation_id=self.task_uuid, + ), mock.call.Response(pr.SUCCESS, result=5), - mock.call.proxy.publish(self.response_inst_mock, self.reply_to, - correlation_id=self.task_uuid) + mock.call.proxy.publish( + self.response_inst_mock, + self.reply_to, + correlation_id=self.task_uuid, + ), ] self.master_mock.assert_has_calls(master_mock_calls) @@ -201,11 +272,17 @@ class TestServer(test.MockTestCase): # check calls master_mock_calls = [ mock.call.Response(pr.RUNNING), - mock.call.proxy.publish(self.response_inst_mock, self.reply_to, - correlation_id=self.task_uuid), + mock.call.proxy.publish( + self.response_inst_mock, + self.reply_to, + correlation_id=self.task_uuid, + ), mock.call.Response(pr.SUCCESS, result=1), - mock.call.proxy.publish(self.response_inst_mock, self.reply_to, - correlation_id=self.task_uuid) + mock.call.proxy.publish( + self.response_inst_mock, + self.reply_to, + correlation_id=self.task_uuid, + ), ] self.master_mock.assert_has_calls(master_mock_calls) @@ -235,9 +312,11 @@ class TestServer(test.MockTestCase): # check calls master_mock_calls 
= [ mock.call.Response(pr.FAILURE, result=failure_dict), - mock.call.proxy.publish(self.response_inst_mock, - self.reply_to, - correlation_id=self.task_uuid) + mock.call.proxy.publish( + self.response_inst_mock, + self.reply_to, + correlation_id=self.task_uuid, + ), ] self.master_mock.assert_has_calls(master_mock_calls) @@ -256,9 +335,11 @@ class TestServer(test.MockTestCase): # check calls master_mock_calls = [ mock.call.Response(pr.FAILURE, result=failure_dict), - mock.call.proxy.publish(self.response_inst_mock, - self.reply_to, - correlation_id=self.task_uuid) + mock.call.proxy.publish( + self.response_inst_mock, + self.reply_to, + correlation_id=self.task_uuid, + ), ] self.master_mock.assert_has_calls(master_mock_calls) @@ -278,9 +359,11 @@ class TestServer(test.MockTestCase): # check calls master_mock_calls = [ mock.call.Response(pr.FAILURE, result=failure_dict), - mock.call.proxy.publish(self.response_inst_mock, - self.reply_to, - correlation_id=self.task_uuid) + mock.call.proxy.publish( + self.response_inst_mock, + self.reply_to, + correlation_id=self.task_uuid, + ), ] self.master_mock.assert_has_calls(master_mock_calls) @@ -299,12 +382,17 @@ class TestServer(test.MockTestCase): # check calls master_mock_calls = [ mock.call.Response(pr.RUNNING), - mock.call.proxy.publish(self.response_inst_mock, self.reply_to, - correlation_id=self.task_uuid), + mock.call.proxy.publish( + self.response_inst_mock, + self.reply_to, + correlation_id=self.task_uuid, + ), mock.call.Response(pr.FAILURE, result=failure_dict), - mock.call.proxy.publish(self.response_inst_mock, - self.reply_to, - correlation_id=self.task_uuid) + mock.call.proxy.publish( + self.response_inst_mock, + self.reply_to, + correlation_id=self.task_uuid, + ), ] self.master_mock.assert_has_calls(master_mock_calls) @@ -312,9 +400,7 @@ class TestServer(test.MockTestCase): self.server(reset_master_mock=True).start() # check calls - master_mock_calls = [ - mock.call.proxy.start() - ] + master_mock_calls = 
[mock.call.proxy.start()] self.master_mock.assert_has_calls(master_mock_calls) def test_wait(self): @@ -323,17 +409,12 @@ class TestServer(test.MockTestCase): server.wait() # check calls - master_mock_calls = [ - mock.call.proxy.start(), - mock.call.proxy.wait() - ] + master_mock_calls = [mock.call.proxy.start(), mock.call.proxy.wait()] self.master_mock.assert_has_calls(master_mock_calls) def test_stop(self): self.server(reset_master_mock=True).stop() # check calls - master_mock_calls = [ - mock.call.proxy.stop() - ] + master_mock_calls = [mock.call.proxy.stop()] self.master_mock.assert_has_calls(master_mock_calls) diff --git a/taskflow/tests/unit/worker_based/test_types.py b/taskflow/tests/unit/worker_based/test_types.py index 03cdbd1e8..79368c0ea 100644 --- a/taskflow/tests/unit/worker_based/test_types.py +++ b/taskflow/tests/unit/worker_based/test_types.py @@ -22,8 +22,9 @@ from taskflow.tests import utils class TestTopicWorker(test.TestCase): def test_topic_worker(self): - worker = worker_types.TopicWorker("dummy-topic", - [utils.DummyTask], identity="dummy") + worker = worker_types.TopicWorker( + "dummy-topic", [utils.DummyTask], identity="dummy" + ) self.assertTrue(worker.performs(utils.DummyTask)) self.assertFalse(worker.performs(utils.NastyTask)) self.assertEqual('dummy', worker.identity) @@ -31,11 +32,11 @@ class TestTopicWorker(test.TestCase): class TestProxyFinder(test.TestCase): - @mock.patch("oslo_utils.timeutils.now") def test_expiry(self, mock_now): - finder = worker_types.ProxyWorkerFinder('me', mock.MagicMock(), [], - worker_expiry=60) + finder = worker_types.ProxyWorkerFinder( + 'me', mock.MagicMock(), [], worker_expiry=60 + ) w, emit = finder._add('dummy-topic', [utils.DummyTask]) w.last_seen = 0 mock_now.side_effect = [120] @@ -61,7 +62,8 @@ class TestProxyFinder(test.TestCase): self.assertIsNotNone(w2) self.assertTrue(emit) w3 = finder.get_worker_for_task( - reflection.get_class_name(utils.DummyTask)) + 
reflection.get_class_name(utils.DummyTask) + ) self.assertIn(w3.identity, [w.identity, w2.identity]) def test_multi_different_topic_workers(self): diff --git a/taskflow/tests/unit/worker_based/test_worker.py b/taskflow/tests/unit/worker_based/test_worker.py index 218c54d5f..8aa17dcd9 100644 --- a/taskflow/tests/unit/worker_based/test_worker.py +++ b/taskflow/tests/unit/worker_based/test_worker.py @@ -23,7 +23,6 @@ from taskflow.tests import utils class TestWorker(test.MockTestCase): - def setUp(self): super().setUp() self.task_cls = utils.DummyTask @@ -34,15 +33,19 @@ class TestWorker(test.MockTestCase): # patch classes self.executor_mock, self.executor_inst_mock = self.patchClass( - worker.futurist, 'ThreadPoolExecutor', attach_as='executor') + worker.futurist, 'ThreadPoolExecutor', attach_as='executor' + ) self.server_mock, self.server_inst_mock = self.patchClass( - worker.server, 'Server') + worker.server, 'Server' + ) def worker(self, reset_master_mock=False, **kwargs): - worker_kwargs = dict(exchange=self.exchange, - topic=self.topic, - tasks=[], - url=self.broker_url) + worker_kwargs = dict( + exchange=self.exchange, + topic=self.topic, + tasks=[], + url=self.broker_url, + ) worker_kwargs.update(kwargs) w = worker.Worker(**worker_kwargs) if reset_master_mock: @@ -54,12 +57,16 @@ class TestWorker(test.MockTestCase): master_mock_calls = [ mock.call.executor_class(max_workers=None), - mock.call.Server(self.topic, self.exchange, - self.executor_inst_mock, [], - url=self.broker_url, - transport_options=mock.ANY, - transport=mock.ANY, - retry_options=mock.ANY) + mock.call.Server( + self.topic, + self.exchange, + self.executor_inst_mock, + [], + url=self.broker_url, + transport_options=mock.ANY, + transport=mock.ANY, + retry_options=mock.ANY, + ), ] self.assertEqual(master_mock_calls, self.master_mock.mock_calls) @@ -76,12 +83,16 @@ class TestWorker(test.MockTestCase): master_mock_calls = [ mock.call.executor_class(max_workers=10), - mock.call.Server(self.topic, 
self.exchange, - self.executor_inst_mock, [], - url=self.broker_url, - transport_options=mock.ANY, - transport=mock.ANY, - retry_options=mock.ANY) + mock.call.Server( + self.topic, + self.exchange, + self.executor_inst_mock, + [], + url=self.broker_url, + transport_options=mock.ANY, + transport=mock.ANY, + retry_options=mock.ANY, + ), ] self.assertEqual(master_mock_calls, self.master_mock.mock_calls) @@ -90,39 +101,38 @@ class TestWorker(test.MockTestCase): self.worker(executor=executor_mock) master_mock_calls = [ - mock.call.Server(self.topic, self.exchange, executor_mock, [], - url=self.broker_url, - transport_options=mock.ANY, - transport=mock.ANY, - retry_options=mock.ANY) + mock.call.Server( + self.topic, + self.exchange, + executor_mock, + [], + url=self.broker_url, + transport_options=mock.ANY, + transport=mock.ANY, + retry_options=mock.ANY, + ) ] self.assertEqual(master_mock_calls, self.master_mock.mock_calls) def test_run_with_no_tasks(self): self.worker(reset_master_mock=True).run() - master_mock_calls = [ - mock.call.server.start() - ] + master_mock_calls = [mock.call.server.start()] self.assertEqual(master_mock_calls, self.master_mock.mock_calls) def test_run_with_tasks(self): - self.worker(reset_master_mock=True, - tasks=['taskflow.tests.utils:DummyTask']).run() + self.worker( + reset_master_mock=True, tasks=['taskflow.tests.utils:DummyTask'] + ).run() - master_mock_calls = [ - mock.call.server.start() - ] + master_mock_calls = [mock.call.server.start()] self.assertEqual(master_mock_calls, self.master_mock.mock_calls) def test_run_with_custom_executor(self): executor_mock = mock.MagicMock(name='executor') - self.worker(reset_master_mock=True, - executor=executor_mock).run() + self.worker(reset_master_mock=True, executor=executor_mock).run() - master_mock_calls = [ - mock.call.server.start() - ] + master_mock_calls = [mock.call.server.start()] self.assertEqual(master_mock_calls, self.master_mock.mock_calls) def test_wait(self): @@ -130,10 +140,7 @@ 
class TestWorker(test.MockTestCase): w.run() w.wait() - master_mock_calls = [ - mock.call.server.start(), - mock.call.server.wait() - ] + master_mock_calls = [mock.call.server.start(), mock.call.server.wait()] self.assertEqual(master_mock_calls, self.master_mock.mock_calls) def test_stop(self): @@ -141,13 +148,14 @@ class TestWorker(test.MockTestCase): master_mock_calls = [ mock.call.server.stop(), - mock.call.executor.shutdown() + mock.call.executor.shutdown(), ] self.assertEqual(master_mock_calls, self.master_mock.mock_calls) def test_derive_endpoints_from_string_tasks(self): endpoints = worker.Worker._derive_endpoints( - ['taskflow.tests.utils:DummyTask']) + ['taskflow.tests.utils:DummyTask'] + ) self.assertEqual(1, len(endpoints)) self.assertIsInstance(endpoints[0], endpoint.Endpoint) @@ -181,8 +189,9 @@ class TestWorker(test.MockTestCase): self.assertEqual(self.task_name, endpoints[0].name) def test_derive_endpoints_from_non_task_class(self): - self.assertRaises(TypeError, worker.Worker._derive_endpoints, - [utils.FakeTask]) + self.assertRaises( + TypeError, worker.Worker._derive_endpoints, [utils.FakeTask] + ) def test_derive_endpoints_from_modules(self): endpoints = worker.Worker._derive_endpoints([utils]) diff --git a/taskflow/tests/utils.py b/taskflow/tests/utils.py index 909f85cc2..796de18d6 100644 --- a/taskflow/tests/utils.py +++ b/taskflow/tests/utils.py @@ -91,8 +91,9 @@ def redis_available(min_version): except Exception: return False else: - ok, redis_version = redis_utils.is_server_new_enough(client, - min_version) + ok, redis_version = redis_utils.is_server_new_enough( + client, min_version + ) return ok @@ -114,13 +115,11 @@ class NoopRetry(retry.AlwaysRevert): class NoopTask(task.Task): - def execute(self): pass class DummyTask(task.Task): - def execute(self, context, *args, **kwargs): pass @@ -129,9 +128,14 @@ class EmittingTask(task.Task): TASK_EVENTS = (task.EVENT_UPDATE_PROGRESS, 'hi') def execute(self, *args, **kwargs): - 
self.notifier.notify('hi', - details={'sent_on': timeutils.utcnow(), - 'args': args, 'kwargs': kwargs}) + self.notifier.notify( + 'hi', + details={ + 'sent_on': timeutils.utcnow(), + 'args': args, + 'kwargs': kwargs, + }, + ) class AddOneSameProvidesRequires(task.Task): @@ -149,7 +153,6 @@ class AddOne(task.Task): class GiveBackRevert(task.Task): - def execute(self, value): return value + 1 @@ -162,19 +165,21 @@ class GiveBackRevert(task.Task): class FakeTask: - def execute(self, **kwargs): pass class LongArgNameTask(task.Task): - def execute(self, long_arg_name): return long_arg_name -RUNTIME_ERROR_CLASSES = ['RuntimeError', 'Exception', 'BaseException', - 'object'] +RUNTIME_ERROR_CLASSES = [ + 'RuntimeError', + 'Exception', + 'BaseException', + 'object', +] class ProvidesRequiresTask(task.Task): @@ -199,7 +204,6 @@ LOOKUP_NAME_POSTFIX = { class CaptureListener(capturing.CaptureListener): - @staticmethod def _format_capture(kind, state, details): name_postfix, name_key = LOOKUP_NAME_POSTFIX[kind] @@ -269,7 +273,6 @@ class OptionalTask(task.Task): class TaskWithFailure(task.Task): - def execute(self, **kwargs): raise RuntimeError('Woot!') @@ -280,7 +283,6 @@ class FailingTaskWithOneArg(ProgressingTask): class NastyTask(task.Task): - def execute(self, **kwargs): pass @@ -294,7 +296,6 @@ class NastyFailingTask(NastyTask): class TaskNoRequiresNoReturns(task.Task): - def execute(self, **kwargs): pass @@ -303,7 +304,6 @@ class TaskNoRequiresNoReturns(task.Task): class TaskOneArg(task.Task): - def execute(self, x, **kwargs): pass @@ -312,7 +312,6 @@ class TaskOneArg(task.Task): class TaskMultiArg(task.Task): - def execute(self, x, y, z, **kwargs): pass @@ -321,7 +320,6 @@ class TaskMultiArg(task.Task): class TaskOneReturn(task.Task): - def execute(self, **kwargs): return 1 @@ -330,7 +328,6 @@ class TaskOneReturn(task.Task): class TaskMultiReturn(task.Task): - def execute(self, **kwargs): return 1, 3, 5 @@ -339,7 +336,6 @@ class TaskMultiReturn(task.Task): class 
TaskOneArgOneReturn(task.Task): - def execute(self, x, **kwargs): return 1 @@ -348,7 +344,6 @@ class TaskOneArgOneReturn(task.Task): class TaskMultiArgOneReturn(task.Task): - def execute(self, x, y, z, **kwargs): return x + y + z @@ -357,7 +352,6 @@ class TaskMultiArgOneReturn(task.Task): class TaskMultiArgMultiReturn(task.Task): - def execute(self, x, y, z, **kwargs): return 1, 3, 5 @@ -366,7 +360,6 @@ class TaskMultiArgMultiReturn(task.Task): class TaskMultiDict(task.Task): - def execute(self): output = {} for i, k in enumerate(sorted(self.provides)): @@ -408,9 +401,9 @@ class EngineTestBase: super().tearDown() def _make_engine(self, flow, **kwargs): - raise exceptions.NotImplementedError("_make_engine() must be" - " overridden if an engine is" - " desired") + raise exceptions.NotImplementedError( + "_make_engine() must be overridden if an engine is desired" + ) class FailureMatcher: @@ -430,7 +423,6 @@ class FailureMatcher: class OneReturnRetry(retry.AlwaysRevert): - def execute(self, **kwargs): return 1 @@ -439,7 +431,6 @@ class OneReturnRetry(retry.AlwaysRevert): class ConditionalTask(ProgressingTask): - def execute(self, x, y): super().execute() if x != y: @@ -447,7 +438,6 @@ class ConditionalTask(ProgressingTask): class WaitForOneFromTask(ProgressingTask): - def __init__(self, name, wait_for, wait_states, **kwargs): super().__init__(name, **kwargs) if isinstance(wait_for, str): @@ -462,10 +452,11 @@ class WaitForOneFromTask(ProgressingTask): def execute(self): if not self.event.wait(WAIT_TIMEOUT): - raise RuntimeError('%s second timeout occurred while waiting ' - 'for %s to change state to %s' - % (WAIT_TIMEOUT, self.wait_for, - self.wait_states)) + raise RuntimeError( + '%s second timeout occurred while waiting ' + 'for %s to change state to %s' + % (WAIT_TIMEOUT, self.wait_for, self.wait_states) + ) return super().execute() def callback(self, state, details): @@ -480,8 +471,10 @@ def make_many(amount, task_cls=DummyTask, offset=0): tasks = [] while amount 
> 0: if offset >= len(name_pool): - raise AssertionError('Name pool size to small (%s < %s)' - % (len(name_pool), offset + 1)) + raise AssertionError( + 'Name pool size to small (%s < %s)' + % (len(name_pool), offset + 1) + ) tasks.append(task_cls(name=name_pool[offset])) offset += 1 amount -= 1 diff --git a/taskflow/types/entity.py b/taskflow/types/entity.py index 699ddd547..b48599d20 100644 --- a/taskflow/types/entity.py +++ b/taskflow/types/entity.py @@ -28,6 +28,7 @@ class Entity: entity) :type metadata: dict """ + def __init__(self, kind, name, metadata): self.kind = kind self.name = name @@ -37,5 +38,5 @@ class Entity: return { 'kind': self.kind, 'name': self.name, - 'metadata': self.metadata + 'metadata': self.metadata, } diff --git a/taskflow/types/failure.py b/taskflow/types/failure.py index 93182867d..66ca4f5e3 100644 --- a/taskflow/types/failure.py +++ b/taskflow/types/failure.py @@ -50,9 +50,13 @@ def _are_equal_exc_info_tuples(ei1, ei2): # NOTE(dhellmann): The flake8/pep8 error E721 does not apply here # because we want the types to be exactly the same, not just have # one be inherited from the other. - if not all((type(ei1[1]) == type(ei2[1]), # noqa: E721 - str(ei1[1]) == str(ei2[1]), - repr(ei1[1]) == repr(ei2[1]))): + if not all( + ( + type(ei1[1]) == type(ei2[1]), # noqa: E721 + str(ei1[1]) == str(ei2[1]), + repr(ei1[1]) == repr(ei2[1]), + ) + ): return False if ei1[2] == ei2[2]: return True @@ -61,7 +65,7 @@ def _are_equal_exc_info_tuples(ei1, ei2): return tb1 == tb2 -class Failure(): +class Failure: """An immutable object that represents failure. Failure objects encapsulate exception information so that they can be @@ -117,6 +121,7 @@ class Failure(): backport at https://pypi.org/project/traceback2/ to (hopefully) simplify the methods and contents of this object... 
""" + DICT_VERSION = 1 BASE_EXCEPTIONS = ('BaseException', 'Exception') @@ -159,7 +164,7 @@ class Failure(): "items": { "$ref": "#/definitions/cause", }, - } + }, }, "required": [ "exception_str", @@ -180,18 +185,23 @@ class Failure(): # either from a prior sys.exc_info() call or from some other # creation... if len(exc_info) != 3: - raise ValueError("Provided 'exc_info' must contain three" - " elements") + raise ValueError( + "Provided 'exc_info' must contain three elements" + ) self._exc_info = exc_info self._exc_args = tuple(getattr(exc_info[1], 'args', [])) self._exc_type_names = tuple( - reflection.get_all_class_names(exc_info[0], up_to=Exception)) + reflection.get_all_class_names(exc_info[0], up_to=Exception) + ) if not self._exc_type_names: - raise TypeError("Invalid exception type '%s' (%s)" - % (exc_info[0], type(exc_info[0]))) + raise TypeError( + "Invalid exception type '%s' (%s)" + % (exc_info[0], type(exc_info[0])) + ) self._exception_str = str(self._exc_info[1]) self._traceback_str = ''.join( - traceback.format_tb(self._exc_info[2])) + traceback.format_tb(self._exc_info[2]) + ) self._causes = kwargs.pop('causes', None) else: self._causes = kwargs.pop('causes', None) @@ -203,7 +213,8 @@ class Failure(): if kwargs: raise TypeError( 'Failure.__init__ got unexpected keyword argument(s): %s' - % ', '.join(kwargs.keys())) + % ', '.join(kwargs.keys()) + ) @classmethod def from_exception(cls, exception): @@ -211,7 +222,7 @@ class Failure(): exc_info = ( type(exception), exception, - getattr(exception, '__traceback__', None) + getattr(exception, '__traceback__', None), ) return cls(exc_info=exc_info) @@ -221,8 +232,9 @@ class Failure(): try: su.schema_validate(data, cls.SCHEMA) except su.ValidationError as e: - raise exc.InvalidFormat("Failure data not of the" - " expected format: %s" % (e.message), e) + raise exc.InvalidFormat( + "Failure data not of the expected format: %s" % (e.message), e + ) else: # Ensure that all 'exc_type_names' originate from one of # 
BASE_EXCEPTIONS, because those are the root exceptions that @@ -236,7 +248,8 @@ class Failure(): "Failure data 'exc_type_names' must" " have an initial exception type that is one" " of %s types: '%s' is not one of those" - " types" % (cls.BASE_EXCEPTIONS, root_exc_type)) + " types" % (cls.BASE_EXCEPTIONS, root_exc_type) + ) sub_causes = cause.get('causes') if sub_causes: causes.extend(sub_causes) @@ -244,11 +257,13 @@ class Failure(): def _matches(self, other): if self is other: return True - return (self._exc_type_names == other._exc_type_names - and self.exception_args == other.exception_args - and self.exception_str == other.exception_str - and self.traceback_str == other.traceback_str - and self.causes == other.causes) + return ( + self._exc_type_names == other._exc_type_names + and self.exception_args == other.exception_args + and self.exception_str == other.exception_str + and self.traceback_str == other.traceback_str + and self.causes == other.causes + ) def matches(self, other): """Checks if another object is equivalent to this object. @@ -266,8 +281,9 @@ class Failure(): def __eq__(self, other): if not isinstance(other, Failure): return NotImplemented - return (self._matches(other) and - _are_equal_exc_info_tuples(self.exc_info, other.exc_info)) + return self._matches(other) and _are_equal_exc_info_tuples( + self.exc_info, other.exc_info + ) def __ne__(self, other): return not (self == other) @@ -379,8 +395,7 @@ class Failure(): # # See: https://www.python.org/dev/peps/pep-0415/ for why/what # the '__suppress_context__' is/means/implies... 
- suppress_context = getattr(exc_val, - '__suppress_context__', False) + suppress_context = getattr(exc_val, '__suppress_context__', False) if suppress_context: attr_lookups = ['__cause__'] else: @@ -431,8 +446,11 @@ class Failure(): if not self._exc_type_names: buf.write('Failure: %s' % (self._exception_str)) else: - buf.write('Failure: {}: {}'.format(self._exc_type_names[0], - self._exception_str)) + buf.write( + 'Failure: {}: {}'.format( + self._exc_type_names[0], self._exception_str + ) + ) if traceback: if self._traceback_str is not None: traceback_str = self._traceback_str.rstrip() @@ -492,8 +510,9 @@ class Failure(): data = dict(data) version = data.pop('version', None) if version != cls.DICT_VERSION: - raise ValueError('Invalid dict version of failure object: %r' - % version) + raise ValueError( + 'Invalid dict version of failure object: %r' % version + ) causes = data.get('causes') if causes is not None: data['causes'] = tuple(cls.from_dict(d) for d in causes) @@ -516,9 +535,11 @@ class Failure(): def copy(self): """Copies this object.""" - return Failure(exc_info=_copy_exc_info(self.exc_info), - exception_str=self.exception_str, - traceback_str=self.traceback_str, - exc_args=self.exception_args, - exc_type_names=self._exc_type_names[:], - causes=self._causes) + return Failure( + exc_info=_copy_exc_info(self.exc_info), + exception_str=self.exception_str, + traceback_str=self.traceback_str, + exc_args=self.exception_args, + exc_type_names=self._exc_type_names[:], + causes=self._causes, + ) diff --git a/taskflow/types/graph.py b/taskflow/types/graph.py index bb38f0cb6..42729abdd 100644 --- a/taskflow/types/graph.py +++ b/taskflow/types/graph.py @@ -33,7 +33,7 @@ def _common_format(g, edge_notation): else: lines.append(" - %s" % n) lines.append("Edges: %s" % g.number_of_edges()) - for (u, v, e_data) in g.edges(data=True): + for u, v, e_data in g.edges(data=True): if e_data: lines.append(f" {u} {edge_notation} {v} ({e_data})") else: @@ -201,6 +201,7 @@ class 
OrderedDiGraph(DiGraph): ordering (so that the iteration order matches the insertion order). """ + node_dict_factory = collections.OrderedDict adjlist_outer_dict_factory = collections.OrderedDict adjlist_inner_dict_factory = collections.OrderedDict @@ -223,6 +224,7 @@ class OrderedGraph(Graph): ordering (so that the iteration order matches the insertion order). """ + node_dict_factory = collections.OrderedDict adjlist_outer_dict_factory = collections.OrderedDict adjlist_inner_dict_factory = collections.OrderedDict @@ -250,8 +252,9 @@ def merge_graphs(graph, *graphs, **kwargs): if overlap_detector is not None and not callable(overlap_detector): raise ValueError("Overlap detection callback expected to be callable") elif overlap_detector is None: - overlap_detector = (lambda to_graph, from_graph: - len(to_graph.subgraph(from_graph.nodes))) + overlap_detector = lambda to_graph, from_graph: len( + to_graph.subgraph(from_graph.nodes) + ) for g in graphs: # This should ensure that the nodes to be merged do not already exist # in the graph that is to be merged into. This could be problematic if @@ -261,10 +264,11 @@ def merge_graphs(graph, *graphs, **kwargs): # and see if any graph results. overlaps = overlap_detector(graph, g) if overlaps: - raise ValueError("Can not merge graph %s into %s since there " - "are %s overlapping nodes (and we do not " - "support merging nodes)" % (g, graph, - overlaps)) + raise ValueError( + "Can not merge graph %s into %s since there " + "are %s overlapping nodes (and we do not " + "support merging nodes)" % (g, graph, overlaps) + ) graph = nx.algorithms.compose(graph, g) # Keep the first graphs name. 
if graphs: diff --git a/taskflow/types/notifier.py b/taskflow/types/notifier.py index 14b22b3e1..a23c49e50 100644 --- a/taskflow/types/notifier.py +++ b/taskflow/types/notifier.py @@ -89,7 +89,9 @@ class Listener: def __repr__(self): repr_msg = "{} object at 0x{:x} calling into '{!r}'".format( reflection.get_class_name(self, fully_qualified=False), - id(self), self._callback) + id(self), + self._callback, + ) if self._details_filter is not None: repr_msg += " using details filter '%r'" % self._details_filter return "<%s>" % repr_msg @@ -108,15 +110,17 @@ class Listener: if self._details_filter is None: return False else: - return reflection.is_same_callback(self._details_filter, - details_filter) + return reflection.is_same_callback( + self._details_filter, details_filter + ) else: return self._details_filter is None def __eq__(self, other): if isinstance(other, Listener): - return self.is_equivalent(other._callback, - details_filter=other._details_filter) + return self.is_equivalent( + other._callback, details_filter=other._details_filter + ) else: return NotImplemented @@ -161,7 +165,7 @@ class Notifier: :rtype: number """ count = 0 - for (_event_type, listeners) in self._topics.items(): + for _event_type, listeners in self._topics.items(): count += len(listeners) return count @@ -196,8 +200,10 @@ class Notifier: :type details: dictionary """ if not self.can_trigger_notification(event_type): - LOG.debug("Event type '%s' is not allowed to trigger" - " notifications", event_type) + LOG.debug( + "Event type '%s' is not allowed to trigger notifications", + event_type, + ) return listeners = list(self._topics.get(self.ANY, [])) listeners.extend(self._topics.get(event_type, [])) @@ -209,12 +215,18 @@ class Notifier: try: listener(event_type, details.copy()) except Exception: - LOG.warning("Failure calling listener %s to notify about event" - " %s, details: %s", listener, event_type, - details, exc_info=True) + LOG.warning( + "Failure calling listener %s to notify about 
event" + " %s, details: %s", + listener, + event_type, + details, + exc_info=True, + ) - def register(self, event_type, callback, - args=None, kwargs=None, details_filter=None): + def register( + self, event_type, callback, args=None, kwargs=None, details_filter=None + ): """Register a callback to be called when event of a given type occurs. Callback will be called with provided ``args`` and ``kwargs`` and @@ -238,21 +250,31 @@ class Notifier: if not callable(details_filter): raise ValueError("Details filter must be callable") if not self.can_be_registered(event_type): - raise ValueError("Disallowed event type '%s' can not have a" - " callback registered" % event_type) - if self.is_registered(event_type, callback, - details_filter=details_filter): - raise ValueError("Event callback already registered with" - " equivalent details filter") + raise ValueError( + "Disallowed event type '%s' can not have a" + " callback registered" % event_type + ) + if self.is_registered( + event_type, callback, details_filter=details_filter + ): + raise ValueError( + "Event callback already registered with" + " equivalent details filter" + ) if kwargs: for k in self.RESERVED_KEYS: if k in kwargs: - raise KeyError("Reserved key '%s' not allowed in " - "kwargs" % k) + raise KeyError( + "Reserved key '%s' not allowed in kwargs" % k + ) self._topics[event_type].append( - Listener(callback, - args=args, kwargs=kwargs, - details_filter=details_filter)) + Listener( + callback, + args=args, + kwargs=kwargs, + details_filter=details_filter, + ) + ) def deregister(self, event_type, callback, details_filter=None): """Remove a single listener bound to event ``event_type``. 
@@ -277,7 +299,7 @@ class Notifier: def copy(self): c = copy.copy(self) c._topics = collections.defaultdict(list) - for (event_type, listeners) in self._topics.items(): + for event_type, listeners in self._topics.items(): c._topics[event_type] = listeners[:] return c @@ -339,13 +361,20 @@ class RestrictedNotifier(Notifier): :returns: whether the event can be registered/subscribed to :rtype: boolean """ - return (event_type in self._watchable_events or - (event_type == self.ANY and self._allow_any)) + return event_type in self._watchable_events or ( + event_type == self.ANY and self._allow_any + ) @contextlib.contextmanager -def register_deregister(notifier, event_type, callback=None, - args=None, kwargs=None, details_filter=None): +def register_deregister( + notifier, + event_type, + callback=None, + args=None, + kwargs=None, + details_filter=None, +): """Context manager that registers a callback, then deregisters on exit. NOTE(harlowja): if the callback is none, then this registers nothing, which @@ -355,11 +384,16 @@ def register_deregister(notifier, event_type, callback=None, if callback is None: yield else: - notifier.register(event_type, callback, - args=args, kwargs=kwargs, - details_filter=details_filter) + notifier.register( + event_type, + callback, + args=args, + kwargs=kwargs, + details_filter=details_filter, + ) try: yield finally: - notifier.deregister(event_type, callback, - details_filter=details_filter) + notifier.deregister( + event_type, callback, details_filter=details_filter + ) diff --git a/taskflow/types/sets.py b/taskflow/types/sets.py index f9c24d429..cf73b9e73 100644 --- a/taskflow/types/sets.py +++ b/taskflow/types/sets.py @@ -75,6 +75,7 @@ class OrderedSet(abc.Set, abc.Hashable): (i.e. elements that are common to all of the sets.) 
""" + def absorb_it(sets): for value in iter(self): matches = 0 @@ -85,6 +86,7 @@ class OrderedSet(abc.Set, abc.Hashable): break if matches == len(sets): yield value + return self._from_iterable(absorb_it(sets)) def issuperset(self, other): @@ -106,6 +108,7 @@ class OrderedSet(abc.Set, abc.Hashable): (i.e. all elements that are in this set but not the others.) """ + def absorb_it(sets): for value in iter(self): seen = False @@ -115,6 +118,7 @@ class OrderedSet(abc.Set, abc.Hashable): break if not seen: yield value + return self._from_iterable(absorb_it(sets)) def union(self, *sets): diff --git a/taskflow/types/timing.py b/taskflow/types/timing.py index cbfa006b7..dd6e7b7b1 100644 --- a/taskflow/types/timing.py +++ b/taskflow/types/timing.py @@ -21,10 +21,13 @@ class Timeout: This object has the ability to be interrupted before the actual timeout is reached. """ + def __init__(self, value, event_factory=threading.Event): if value < 0: - raise ValueError("Timeout value must be greater or" - " equal to zero and not '%s'" % (value)) + raise ValueError( + "Timeout value must be greater or" + " equal to zero and not '%s'" % (value) + ) self._value = value self._event = event_factory() @@ -50,8 +53,9 @@ class Timeout: self._event.clear() -def convert_to_timeout(value=None, default_value=None, - event_factory=threading.Event): +def convert_to_timeout( + value=None, default_value=None, event_factory=threading.Event +): """Converts a given value to a timeout instance (and returns it). Does nothing if the value provided is already a timeout instance. 
diff --git a/taskflow/types/tree.py b/taskflow/types/tree.py index a03a61d85..fe5f27320 100644 --- a/taskflow/types/tree.py +++ b/taskflow/types/tree.py @@ -192,9 +192,11 @@ class Node: :returns: the node that matched provided item (or ``None``) """ - return self.find_first_match(lambda n: n.item == item, - only_direct=only_direct, - include_self=include_self) + return self.find_first_match( + lambda n: n.item == item, + only_direct=only_direct, + include_self=include_self, + ) @misc.disallow_when_frozen(FrozenNode) def disassociate(self): @@ -230,8 +232,9 @@ class Node: search using depth first). :param include_self: include the current node during searching. """ - node = self.find(item, only_direct=only_direct, - include_self=include_self) + node = self.find( + item, only_direct=only_direct, include_self=include_self + ) if node is None: raise ValueError("Item '%s' not found to remove" % item) else: @@ -251,10 +254,15 @@ class Node: # NOTE(harlowja): 0 is the right most index, len - 1 is the left most return self._children[index] - def pformat(self, stringify_node=None, - linesep=LINE_SEP, vertical_conn=VERTICAL_CONN, - horizontal_conn=HORIZONTAL_CONN, empty_space=EMPTY_SPACE_SEP, - starting_prefix=STARTING_PREFIX): + def pformat( + self, + stringify_node=None, + linesep=LINE_SEP, + vertical_conn=VERTICAL_CONN, + horizontal_conn=HORIZONTAL_CONN, + empty_space=EMPTY_SPACE_SEP, + starting_prefix=STARTING_PREFIX, + ): """Formats this node + children into a nice string representation. **Example**:: @@ -290,7 +298,8 @@ class Node: # hit the root node (self) and use that as our nodes prefix # string... 
parent_node_it = iter_utils.while_is_not( - node.path_iter(include_self=True), stop_at_parent) + node.path_iter(include_self=True), stop_at_parent + ) for j, parent_node in enumerate(parent_node_it): if parent_node is stop_at_parent: if j > 0: @@ -303,9 +312,11 @@ class Node: # the right final starting prefix on (which may be # a empty space or another vertical connector)... last_node = self._children[-1] - m = last_node.find_first_match(lambda n: n is node, - include_self=False, - only_direct=False) + m = last_node.find_first_match( + lambda n: n is node, + include_self=False, + only_direct=False, + ) if m is not None: prefix.append(empty_space) else: @@ -365,7 +376,7 @@ class Node: def index(self, item): """Finds the child index of a given item, searches in added order.""" index_at = None - for (i, child) in enumerate(self._children): + for i, child in enumerate(self._children): if child.item == item: index_at = i break @@ -375,15 +386,15 @@ class Node: def dfs_iter(self, include_self=False, right_to_left=True): """Depth first iteration (non-recursive) over the child nodes.""" - return _DFSIter(self, - include_self=include_self, - right_to_left=right_to_left) + return _DFSIter( + self, include_self=include_self, right_to_left=right_to_left + ) def bfs_iter(self, include_self=False, right_to_left=False): """Breadth first iteration (non-recursive) over the child nodes.""" - return _BFSIter(self, - include_self=include_self, - right_to_left=right_to_left) + return _BFSIter( + self, include_self=include_self, right_to_left=right_to_left + ) def to_digraph(self): """Converts this node + its children into a ordered directed graph. 
diff --git a/taskflow/utils/banner.py b/taskflow/utils/banner.py index 9ed2175e6..b447bb66f 100644 --- a/taskflow/utils/banner.py +++ b/taskflow/utils/banner.py @@ -19,11 +19,13 @@ from taskflow.utils import misc from taskflow import version -BANNER_HEADER = string.Template(""" +BANNER_HEADER = string.Template( + """ ___ __ | |_ |ask |low v$version -""".strip()) +""".strip() +) BANNER_HEADER = BANNER_HEADER.substitute(version=version.version_string()) @@ -72,13 +74,17 @@ def make_banner(what, chapters): section_names = sorted(chapter_contents.keys()) for j, section_name in enumerate(section_names): if j + 1 < len(section_names): - buf.write_nl(" {} => {}".format( - section_name, - chapter_contents[section_name])) + buf.write_nl( + " {} => {}".format( + section_name, chapter_contents[section_name] + ) + ) else: - buf.write(" {} => {}".format( - section_name, - chapter_contents[section_name])) + buf.write( + " {} => {}".format( + section_name, chapter_contents[section_name] + ) + ) elif isinstance(chapter_contents, (list, tuple, set)): if isinstance(chapter_contents, set): sections = sorted(chapter_contents) @@ -90,9 +96,11 @@ def make_banner(what, chapters): else: buf.write(f" {j + 1}. 
{section}") else: - raise TypeError("Unsupported chapter contents" - " type: one of dict, list, tuple, set expected" - " and not %s" % type(chapter_contents).__name__) + raise TypeError( + "Unsupported chapter contents" + " type: one of dict, list, tuple, set expected" + " and not %s" % type(chapter_contents).__name__ + ) if i + 1 < len(chapter_names): buf.write_nl("") # NOTE(harlowja): this is needed since the template in this file diff --git a/taskflow/utils/iter_utils.py b/taskflow/utils/iter_utils.py index 3f4038c64..ce827c257 100644 --- a/taskflow/utils/iter_utils.py +++ b/taskflow/utils/iter_utils.py @@ -22,8 +22,9 @@ def _ensure_iterable(func): @functools.wraps(func) def wrapper(it, *args, **kwargs): if not isinstance(it, abc.Iterable): - raise ValueError("Iterable expected, but '%s' is not" - " iterable" % it) + raise ValueError( + "Iterable expected, but '%s' is not iterable" % it + ) return func(it, *args, **kwargs) return wrapper @@ -63,14 +64,17 @@ def generate_delays(delay, max_delay, multiplier=2): stop generating values). 
""" if max_delay < 0: - raise ValueError("Provided delay (max) must be greater" - " than or equal to zero") + raise ValueError( + "Provided delay (max) must be greater than or equal to zero" + ) if delay < 0: - raise ValueError("Provided delay must start off greater" - " than or equal to zero") + raise ValueError( + "Provided delay must start off greater than or equal to zero" + ) if multiplier < 1.0: - raise ValueError("Provided multiplier must be greater than" - " or equal to 1.0") + raise ValueError( + "Provided multiplier must be greater than or equal to 1.0" + ) def _gen_it(): # NOTE(harlowja): Generation is delayed so that validation @@ -106,8 +110,9 @@ def unique_seen(its, seen_selector=None): all_its = list(its) for it in all_its: if not isinstance(it, abc.Iterable): - raise ValueError("Iterable expected, but '%s' is" - " not iterable" % it) + raise ValueError( + "Iterable expected, but '%s' is not iterable" % it + ) return _gen_it(all_its) diff --git a/taskflow/utils/kazoo_utils.py b/taskflow/utils/kazoo_utils.py index becd4d74a..63514f7f8 100644 --- a/taskflow/utils/kazoo_utils.py +++ b/taskflow/utils/kazoo_utils.py @@ -30,7 +30,8 @@ CONF_TRANSFERS = ( ('keyfile_password', None, None), ('certfile', None, None), ('use_ssl', strutils.bool_from_string, False), - ('verify_certs', strutils.bool_from_string, True)) + ('verify_certs', strutils.bool_from_string, True), +) def _parse_hosts(hosts): @@ -38,7 +39,7 @@ def _parse_hosts(hosts): return hosts.strip() if isinstance(hosts, (dict)): host_ports = [] - for (k, v) in hosts.items(): + for k, v in hosts.items(): host_ports.append(f"{k}:{v}") hosts = host_ports if isinstance(hosts, (list, set, tuple)): @@ -49,7 +50,7 @@ def _parse_hosts(hosts): def prettify_failures(failures, limit=-1): """Prettifies a checked commits failures (ignores sensitive data...).""" prettier = [] - for (op, r) in failures: + for op, r in failures: pretty_op = reflection.get_class_name(op, fully_qualified=False) # Pick off a few 
attributes that are meaningful (but one that don't # show actual data, which might not be desired to show...). @@ -102,17 +103,22 @@ def checked_commit(txn): raise KazooTransactionException( "Transaction returned %s results, this is less than" " the number of expected transaction operations %s" - % (len(results), len(txn.operations)), failures) + % (len(results), len(txn.operations)), + failures, + ) if len(results) > len(txn.operations): raise KazooTransactionException( "Transaction returned %s results, this is greater than" " the number of expected transaction operations %s" - % (len(results), len(txn.operations)), failures) + % (len(results), len(txn.operations)), + failures, + ) if failures: raise KazooTransactionException( "Transaction with %s operations failed: %s" - % (len(txn.operations), - prettify_failures(failures, limit=1)), failures) + % (len(txn.operations), prettify_failures(failures, limit=1)), + failures, + ) return results @@ -137,10 +143,11 @@ def check_compatible(client, min_version=None, max_version=None): if server_version < min_version: pretty_server_version = ".".join([str(a) for a in server_version]) min_version = ".".join([str(a) for a in min_version]) - raise exc.IncompatibleVersion("Incompatible zookeeper version" - " %s detected, zookeeper >= %s" - " required" % (pretty_server_version, - min_version)) + raise exc.IncompatibleVersion( + "Incompatible zookeeper version" + " %s detected, zookeeper >= %s" + " required" % (pretty_server_version, min_version) + ) if max_version: if server_version is None: server_version = tuple(int(a) for a in client.server_version()) @@ -148,10 +155,11 @@ def check_compatible(client, min_version=None, max_version=None): if server_version > max_version: pretty_server_version = ".".join([str(a) for a in server_version]) max_version = ".".join([str(a) for a in max_version]) - raise exc.IncompatibleVersion("Incompatible zookeeper version" - " %s detected, zookeeper <= %s" - " required" % (pretty_server_version, 
- max_version)) + raise exc.IncompatibleVersion( + "Incompatible zookeeper version" + " %s detected, zookeeper <= %s" + " required" % (pretty_server_version, max_version) + ) def make_client(conf): @@ -207,8 +215,9 @@ def make_client(conf): for key, value_type_converter, default in CONF_TRANSFERS: if key in conf: if value_type_converter is not None: - client_kwargs[key] = value_type_converter(conf[key], - default=default) + client_kwargs[key] = value_type_converter( + conf[key], default=default + ) else: client_kwargs[key] = conf[key] else: @@ -221,9 +230,10 @@ def make_client(conf): client_kwargs['connection_retry'] = conf['connection_retry'] hosts = _parse_hosts(conf.get("hosts", "localhost:2181")) if not hosts or not isinstance(hosts, str): - raise TypeError("Invalid hosts format, expected " - "non-empty string/list, not '%s' (%s)" - % (hosts, type(hosts))) + raise TypeError( + "Invalid hosts format, expected " + "non-empty string/list, not '%s' (%s)" % (hosts, type(hosts)) + ) client_kwargs['hosts'] = hosts if 'timeout' in conf: client_kwargs['timeout'] = float(conf['timeout']) diff --git a/taskflow/utils/kombu_utils.py b/taskflow/utils/kombu_utils.py index 0c1f27ca3..97ea37233 100644 --- a/taskflow/utils/kombu_utils.py +++ b/taskflow/utils/kombu_utils.py @@ -13,11 +13,13 @@ # under the License. # Keys extracted from the message properties when formatting... 
-_MSG_PROPERTIES = tuple([ - 'correlation_id', - 'delivery_info/routing_key', - 'type', -]) +_MSG_PROPERTIES = tuple( + [ + 'correlation_id', + 'delivery_info/routing_key', + 'type', + ] +) class DelayedPretty: diff --git a/taskflow/utils/misc.py b/taskflow/utils/misc.py index ce175af6d..579d9e0bd 100644 --- a/taskflow/utils/misc.py +++ b/taskflow/utils/misc.py @@ -51,8 +51,10 @@ class StrEnum(str, enum.Enum): def __new__(cls, *args, **kwargs): for a in args: if not isinstance(a, str): - raise TypeError("Enumeration '%s' (%s) is not" - " a string" % (a, type(a).__name__)) + raise TypeError( + "Enumeration '%s' (%s) is not" + " a string" % (a, type(a).__name__) + ) return super().__new__(cls, *args, **kwargs) @@ -93,7 +95,7 @@ def match_type(obj, matchers): Returns the result (the second element of the provided tuple) if a type match occurs, otherwise none if no matches are found. """ - for (match_types, match_result) in matchers: + for match_types, match_result in matchers: if isinstance(obj, match_types): return match_result else: @@ -108,8 +110,9 @@ def countdown_iter(start_at, decr=1): that step parameter does **not** exist and therefore can't be used). 
""" if decr <= 0: - raise ValueError("Decrement value must be greater" - " than zero and not %s" % decr) + raise ValueError( + "Decrement value must be greater than zero and not %s" % decr + ) while start_at > 0: yield start_at start_at -= decr @@ -157,10 +160,10 @@ def merge_uri(uri, conf): if uri_port is not None: hostname += ":%s" % (uri_port) specials.append(('hostname', hostname, lambda v: bool(v))) - for (k, v, is_not_empty_value_func) in specials: + for k, v, is_not_empty_value_func in specials: if is_not_empty_value_func(v): conf.setdefault(k, v) - for (k, v) in uri.params().items(): + for k, v in uri.params().items(): conf.setdefault(k, v) return conf @@ -189,19 +192,22 @@ def find_subclasses(locations, base_cls, exclude_hidden=True): else: obj = importutils.import_class(f'{pkg}.{cls}') if not reflection.is_subclass(obj, base_cls): - raise TypeError("Object '%s' (%s) is not a '%s' subclass" - % (item, type(item), base_cls)) + raise TypeError( + "Object '%s' (%s) is not a '%s' subclass" + % (item, type(item), base_cls) + ) derived.add(obj) elif isinstance(item, types.ModuleType): module = item elif reflection.is_subclass(item, base_cls): derived.add(item) else: - raise TypeError("Object '%s' (%s) is an unexpected type" % - (item, type(item))) + raise TypeError( + "Object '%s' (%s) is an unexpected type" % (item, type(item)) + ) # If it's a module derive objects from it if we can. if module is not None: - for (name, obj) in inspect.getmembers(module): + for name, obj in inspect.getmembers(module): if name.startswith("_") and exclude_hidden: continue if reflection.is_subclass(obj, base_cls): @@ -221,12 +227,15 @@ def parse_uri(uri): """Parses a uri into its components.""" # Do some basic validation before continuing... 
if not isinstance(uri, str): - raise TypeError("Can only parse string types to uri data, " - "and not '%s' (%s)" % (uri, type(uri))) + raise TypeError( + "Can only parse string types to uri data, " + "and not '%s' (%s)" % (uri, type(uri)) + ) match = _SCHEME_REGEX.match(uri) if not match: - raise ValueError("Uri '%s' does not start with a RFC 3986 compliant" - " scheme" % (uri)) + raise ValueError( + "Uri '%s' does not start with a RFC 3986 compliant scheme" % (uri) + ) return netutils.urlsplit(uri) @@ -250,8 +259,10 @@ def disallow_when_frozen(excp_cls): def clamp(value, minimum, maximum, on_clamped=None): """Clamps a value to ensure its >= minimum and <= maximum.""" if minimum > maximum: - raise ValueError("Provided minimum '%s' must be less than or equal to" - " the provided maximum '%s'" % (minimum, maximum)) + raise ValueError( + "Provided minimum '%s' must be less than or equal to" + " the provided maximum '%s'" % (minimum, maximum) + ) if value > maximum: value = maximum if on_clamped is not None: @@ -276,8 +287,7 @@ def binary_encode(text, encoding='utf-8', errors='strict'): if isinstance(text, bytes): return text else: - return encodeutils.safe_encode(text, encoding=encoding, - errors=errors) + return encodeutils.safe_encode(text, encoding=encoding, errors=errors) def binary_decode(data, encoding='utf-8', errors='strict'): @@ -288,8 +298,7 @@ def binary_decode(data, encoding='utf-8', errors='strict'): if isinstance(data, str): return data else: - return encodeutils.safe_decode(data, incoming=encoding, - errors=errors) + return encodeutils.safe_decode(data, incoming=encoding, errors=errors) def _check_decoded_type(data, root_types=(dict,)): @@ -299,11 +308,15 @@ def _check_decoded_type(data, root_types=(dict,)): if not isinstance(data, root_types): if len(root_types) == 1: root_type = root_types[0] - raise ValueError("Expected '%s' root type not '%s'" - % (root_type, type(data))) + raise ValueError( + "Expected '%s' root type not '%s'" + % (root_type, 
type(data)) + ) else: - raise ValueError("Expected %s root types not '%s'" - % (list(root_types), type(data))) + raise ValueError( + "Expected %s root types not '%s'" + % (list(root_types), type(data)) + ) return data @@ -352,6 +365,7 @@ class cachedproperty: cached property would be stored under '_get_thing' in the self object after the first call to 'get_thing' occurs. """ + def __init__(self, fget=None, require_lock=True): if require_lock: self._lock = threading.RLock() @@ -422,8 +436,7 @@ def get_version_string(obj): obj_version = getattr(obj, 'version', None) if isinstance(obj_version, (list, tuple)): obj_version = '.'.join(str(item) for item in obj_version) - if obj_version is not None and not isinstance(obj_version, - str): + if obj_version is not None and not isinstance(obj_version, str): obj_version = str(obj_version) return obj_version @@ -458,8 +471,9 @@ def as_int(obj, quiet=False): pass # Eck, not sure what this is then. if not quiet: - raise TypeError("Can not translate '%s' (%s) to an integer" - % (obj, type(obj))) + raise TypeError( + "Can not translate '%s' (%s) to an integer" % (obj, type(obj)) + ) return obj @@ -521,8 +535,9 @@ def is_iterable(obj): :param obj: object to be tested for iterable :return: True if object is iterable and is not a string """ - return (not isinstance(obj, str) and - isinstance(obj, collections.abc.Iterable)) + return not isinstance(obj, str) and isinstance( + obj, collections.abc.Iterable + ) def safe_copy_dict(obj): diff --git a/taskflow/utils/redis_utils.py b/taskflow/utils/redis_utils.py index 036ca8888..93520bdec 100644 --- a/taskflow/utils/redis_utils.py +++ b/taskflow/utils/redis_utils.py @@ -24,8 +24,9 @@ def _raise_on_closed(meth): @functools.wraps(meth) def wrapper(self, *args, **kwargs): if self.closed: - raise redis_exceptions.ConnectionError("Connection has been" - " closed") + raise redis_exceptions.ConnectionError( + "Connection has been closed" + ) return meth(self, *args, **kwargs) return wrapper @@ 
-75,7 +76,8 @@ _UNKNOWN_EXPIRE_MAPPING = {e.value: e for e in list(UnknownExpire)} def get_expiry(client, key, prior_version=None): """Gets an expiry for a key (using **best** determined ttl method).""" is_new_enough, _prior_version = is_server_new_enough( - client, (2, 6), prior_version=prior_version) + client, (2, 6), prior_version=prior_version + ) if is_new_enough: result = client.pttl(key) try: @@ -93,7 +95,8 @@ def get_expiry(client, key, prior_version=None): def apply_expiry(client, key, expiry, prior_version=None): """Applies an expiry to a key (using **best** determined expiry method).""" is_new_enough, _prior_version = is_server_new_enough( - client, (2, 6), prior_version=prior_version) + client, (2, 6), prior_version=prior_version + ) if is_new_enough: # Use milliseconds (as that is what pexpire uses/expects...) ms_expiry = expiry * 1000.0 @@ -107,8 +110,9 @@ def apply_expiry(client, key, expiry, prior_version=None): return bool(result) -def is_server_new_enough(client, min_version, - default=False, prior_version=None): +def is_server_new_enough( + client, min_version, default=False, prior_version=None +): """Checks if a client is attached to a new enough redis server.""" if not prior_version: try: diff --git a/taskflow/utils/schema_utils.py b/taskflow/utils/schema_utils.py index 5c8e59f49..337f61d7b 100644 --- a/taskflow/utils/schema_utils.py +++ b/taskflow/utils/schema_utils.py @@ -27,7 +27,9 @@ def schema_validate(data, schema): # Special jsonschema validation types/adjustments. 
# See: https://github.com/Julian/jsonschema/issues/148 type_checker = Validator.TYPE_CHECKER.redefine( - "array", lambda checker, data: isinstance(data, (list, tuple))) + "array", lambda checker, data: isinstance(data, (list, tuple)) + ) TupleAllowingValidator = jsonschema.validators.extend( - Validator, type_checker=type_checker) + Validator, type_checker=type_checker + ) TupleAllowingValidator(schema).validate(data) diff --git a/taskflow/utils/threading_utils.py b/taskflow/utils/threading_utils.py index 10622cf38..a7d5a4bb1 100644 --- a/taskflow/utils/threading_utils.py +++ b/taskflow/utils/threading_utils.py @@ -54,17 +54,25 @@ def daemon_thread(target, *args, **kwargs): # Container for thread creator + associated callbacks. -_ThreadBuilder = collections.namedtuple('_ThreadBuilder', - ['thread_factory', - 'before_start', 'after_start', - 'before_join', 'after_join']) -_ThreadBuilder.fields = tuple([ - 'thread_factory', - 'before_start', - 'after_start', - 'before_join', - 'after_join', -]) +_ThreadBuilder = collections.namedtuple( + '_ThreadBuilder', + [ + 'thread_factory', + 'before_start', + 'after_start', + 'before_join', + 'after_join', + ], +) +_ThreadBuilder.fields = tuple( + [ + 'thread_factory', + 'before_start', + 'after_start', + 'before_join', + 'after_join', + ] +) def no_op(*args, **kwargs): @@ -78,9 +86,14 @@ class ThreadBundle: self._threads = [] self._lock = threading.Lock() - def bind(self, thread_factory, - before_start=None, after_start=None, - before_join=None, after_join=None): + def bind( + self, + thread_factory, + before_start=None, + after_start=None, + before_join=None, + after_join=None, + ): """Adds a thread (to-be) into this bundle (with given callbacks). 
NOTE(harlowja): callbacks provided should not attempt to call @@ -97,23 +110,27 @@ class ThreadBundle: before_join = no_op if after_join is None: after_join = no_op - builder = _ThreadBuilder(thread_factory, - before_start, after_start, - before_join, after_join) + builder = _ThreadBuilder( + thread_factory, before_start, after_start, before_join, after_join + ) for attr_name in builder.fields: cb = getattr(builder, attr_name) if not callable(cb): - raise ValueError("Provided callback for argument" - " '%s' must be callable" % attr_name) + raise ValueError( + "Provided callback for argument" + " '%s' must be callable" % attr_name + ) with self._lock: - self._threads.append([ - builder, - # The built thread. - None, - # Whether the built thread was started (and should have - # ran or still be running). - False, - ]) + self._threads.append( + [ + builder, + # The built thread. + None, + # Whether the built thread was started (and should have + # ran or still be running). + False, + ] + ) def start(self): """Creates & starts all associated threads (that are not running).""" diff --git a/tools/schema_generator.py b/tools/schema_generator.py index 8202ab241..d5eb5f159 100755 --- a/tools/schema_generator.py +++ b/tools/schema_generator.py @@ -69,12 +69,16 @@ def main(): # see the basic schema... 
row_type = re.sub(r"\(.*?\)", "", r['type']).strip() if not row_type: - raise ValueError("Row %s of table '%s' was empty after" - " cleaning" % (r['cid'], table_name)) + raise ValueError( + "Row %s of table '%s' was empty after" + " cleaning" % (r['cid'], table_name) + ) rows.append([r['name'], row_type, to_bool_string(r['pk'])]) contents = tabulate.tabulate( - rows, headers=['Name', 'Type', 'Primary Key'], - tablefmt="rst") + rows, + headers=['Name', 'Type', 'Primary Key'], + tablefmt="rst", + ) print("\n%s" % contents.strip()) if i + 1 != len(table_names): print("") diff --git a/tools/speed_test.py b/tools/speed_test.py index 2e834bc76..794ec656f 100644 --- a/tools/speed_test.py +++ b/tools/speed_test.py @@ -37,7 +37,10 @@ def print_header(name): class ProfileIt: - stats_ordering = ('cumulative', 'calls',) + stats_ordering = ( + 'cumulative', + 'calls', + ) def __init__(self, name, args): self.name = name @@ -88,21 +91,34 @@ class DummyTask(task.Task): def main(): parser = argparse.ArgumentParser(description=__doc__) - parser.add_argument('--profile', "-p", - dest='profile', action='store_true', - default=False, - help='profile instead of gather timing' - ' (default: False)') - parser.add_argument('--dummies', "-d", - dest='dummies', action='store', type=int, - default=100, metavar="", - help='how many dummy/no-op tasks to inject' - ' (default: 100)') - parser.add_argument('--limit', '-l', - dest='limit', action='store', type=float, - default=100.0, metavar="", - help='percentage of profiling output to show' - ' (default: 100%%)') + parser.add_argument( + '--profile', + "-p", + dest='profile', + action='store_true', + default=False, + help='profile instead of gather timing (default: False)', + ) + parser.add_argument( + '--dummies', + "-d", + dest='dummies', + action='store', + type=int, + default=100, + metavar="", + help='how many dummy/no-op tasks to inject (default: 100)', + ) + parser.add_argument( + '--limit', + '-l', + dest='limit', + action='store', + 
type=float, + default=100.0, + metavar="", + help='percentage of profiling output to show (default: 100%%)', + ) args = parser.parse_args() if args.profile: ctx_manager = ProfileIt diff --git a/tools/state_graph.py b/tools/state_graph.py index 3535f4b33..b4a5c4711 100755 --- a/tools/state_graph.py +++ b/tools/state_graph.py @@ -20,8 +20,7 @@ import optparse import os import sys -top_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), - os.pardir)) +top_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir)) sys.path.insert(0, top_dir) from automaton.converters import pydot @@ -47,7 +46,7 @@ def make_machine(start_state, transitions, event_name_cb): machine = machines.FiniteMachine() machine.add_state(start_state) machine.default_start_state = start_state - for (start_state, end_state) in transitions: + for start_state, end_state in transitions: if start_state not in machine: machine.add_state(start_state) if end_state not in machine: @@ -59,35 +58,67 @@ def make_machine(start_state, transitions, event_name_cb): def main(): parser = optparse.OptionParser() - parser.add_option("-f", "--file", dest="filename", - help="write svg to FILE", metavar="FILE") - parser.add_option("-t", "--tasks", dest="tasks", - action='store_true', - help="use task state transitions", - default=False) - parser.add_option("-r", "--retries", dest="retries", - action='store_true', - help="use retry state transitions", - default=False) - parser.add_option("-e", "--engines", dest="engines", - action='store_true', - help="use engine state transitions", - default=False) - parser.add_option("-w", "--wbe-requests", dest="wbe_requests", - action='store_true', - help="use wbe request transitions", - default=False) - parser.add_option("-j", "--jobs", dest="jobs", - action='store_true', - help="use job transitions", - default=False) - parser.add_option("--flow", dest="flow", - action='store_true', - help="use flow transitions", - default=False) - parser.add_option("-T", 
"--format", dest="format", - help="output in given format", - default='svg') + parser.add_option( + "-f", + "--file", + dest="filename", + help="write svg to FILE", + metavar="FILE", + ) + parser.add_option( + "-t", + "--tasks", + dest="tasks", + action='store_true', + help="use task state transitions", + default=False, + ) + parser.add_option( + "-r", + "--retries", + dest="retries", + action='store_true', + help="use retry state transitions", + default=False, + ) + parser.add_option( + "-e", + "--engines", + dest="engines", + action='store_true', + help="use engine state transitions", + default=False, + ) + parser.add_option( + "-w", + "--wbe-requests", + dest="wbe_requests", + action='store_true', + help="use wbe request transitions", + default=False, + ) + parser.add_option( + "-j", + "--jobs", + dest="jobs", + action='store_true', + help="use job transitions", + default=False, + ) + parser.add_option( + "--flow", + dest="flow", + action='store_true', + help="use flow transitions", + default=False, + ) + parser.add_option( + "-T", + "--format", + dest="format", + help="output in given format", + default='svg', + ) (options, args) = parser.parse_args() if options.filename is None: @@ -103,30 +134,40 @@ def main(): ] provided = sum([int(i) for i in types]) if provided > 1: - parser.error("Only one of task/retry/engines/wbe requests/jobs/flow" - " may be specified.") + parser.error( + "Only one of task/retry/engines/wbe requests/jobs/flow" + " may be specified." + ) if provided == 0: - parser.error("One of task/retry/engines/wbe requests/jobs/flow" - " must be specified.") + parser.error( + "One of task/retry/engines/wbe requests/jobs/flow" + " must be specified." 
+ ) event_name_cb = lambda start_state, end_state: "on_%s" % end_state.lower() internal_states = list() ordering = 'in' if options.tasks: source_type = "Tasks" - source = make_machine(states.PENDING, - list(states._ALLOWED_TASK_TRANSITIONS), - event_name_cb) + source = make_machine( + states.PENDING, + list(states._ALLOWED_TASK_TRANSITIONS), + event_name_cb, + ) elif options.retries: source_type = "Retries" - source = make_machine(states.PENDING, - list(states._ALLOWED_RETRY_TRANSITIONS), - event_name_cb) + source = make_machine( + states.PENDING, + list(states._ALLOWED_RETRY_TRANSITIONS), + event_name_cb, + ) elif options.flow: source_type = "Flow" - source = make_machine(states.PENDING, - list(states._ALLOWED_FLOW_TRANSITIONS), - event_name_cb) + source = make_machine( + states.PENDING, + list(states._ALLOWED_FLOW_TRANSITIONS), + event_name_cb, + ) elif options.engines: source_type = "Engines" b = builder.MachineBuilder(DummyRuntime(), mock.MagicMock()) @@ -135,14 +176,18 @@ def main(): ordering = 'out' elif options.wbe_requests: source_type = "WBE requests" - source = make_machine(protocol.WAITING, - list(protocol._ALLOWED_TRANSITIONS), - event_name_cb) + source = make_machine( + protocol.WAITING, + list(protocol._ALLOWED_TRANSITIONS), + event_name_cb, + ) elif options.jobs: source_type = "Jobs" - source = make_machine(states.UNCLAIMED, - list(states._ALLOWED_JOB_TRANSITIONS), - event_name_cb) + source = make_machine( + states.UNCLAIMED, + list(states._ALLOWED_JOB_TRANSITIONS), + event_name_cb, + ) graph_attrs = { 'ordering': ordering, @@ -176,8 +221,13 @@ def main(): edge_attrs['fontcolor'] = 'green' return edge_attrs - g = pydot.convert(source, graph_name, graph_attrs=graph_attrs, - node_attrs_cb=node_attrs_cb, edge_attrs_cb=edge_attrs_cb) + g = pydot.convert( + source, + graph_name, + graph_attrs=graph_attrs, + node_attrs_cb=node_attrs_cb, + edge_attrs_cb=edge_attrs_cb, + ) print("*" * len(graph_name)) print(graph_name) print("*" * len(graph_name)) diff --git 
a/tools/subunit_trace.py b/tools/subunit_trace.py index b174e8a87..5d0ba424e 100755 --- a/tools/subunit_trace.py +++ b/tools/subunit_trace.py @@ -34,7 +34,6 @@ RESULTS = {} class Starts(testtools.StreamResult): - def __init__(self, output): super().__init__() self._output = output @@ -43,14 +42,31 @@ class Starts(testtools.StreamResult): self._neednewline = False self._emitted = set() - def status(self, test_id=None, test_status=None, test_tags=None, - runnable=True, file_name=None, file_bytes=None, eof=False, - mime_type=None, route_code=None, timestamp=None): + def status( + self, + test_id=None, + test_status=None, + test_tags=None, + runnable=True, + file_name=None, + file_bytes=None, + eof=False, + mime_type=None, + route_code=None, + timestamp=None, + ): super().status( - test_id, test_status, - test_tags=test_tags, runnable=runnable, file_name=file_name, - file_bytes=file_bytes, eof=eof, mime_type=mime_type, - route_code=route_code, timestamp=timestamp) + test_id, + test_status, + test_tags=test_tags, + runnable=runnable, + file_name=file_name, + file_bytes=file_bytes, + eof=eof, + mime_type=mime_type, + route_code=route_code, + timestamp=timestamp, + ) if not test_id: if not file_bytes: return @@ -58,9 +74,11 @@ class Starts(testtools.StreamResult): mime_type = 'text/plain; charset=utf-8' primary, sub, parameters = mimeparse.parse_mime_type(mime_type) content_type = testtools.content_type.ContentType( - primary, sub, parameters) + primary, sub, parameters + ) content = testtools.content.Content( - content_type, lambda: [file_bytes]) + content_type, lambda: [file_bytes] + ) text = content.as_text() if text and text[-1] not in '\r\n': self._neednewline = True @@ -77,8 +95,9 @@ class Starts(testtools.StreamResult): timestr = timestamp.isoformat() else: timestr = '' - self._output.write('%s: %s%s [start]\n' % - (timestr, worker, test_id)) + self._output.write( + '%s: %s%s [start]\n' % (timestr, worker, test_id) + ) self._emitted.add(test_id) @@ -97,7 +116,7 @@ 
def cleanup_test_name(name, strip_tags=True, strip_scenarios=False): tags_end = name.find(']') if tags_start > 0 and tags_end > tags_start: newname = name[:tags_start] - newname += name[tags_end + 1:] + newname += name[tags_end + 1 :] name = newname if strip_scenarios: @@ -105,7 +124,7 @@ def cleanup_test_name(name, strip_tags=True, strip_scenarios=False): tags_end = name.find(')') if tags_start > 0 and tags_end > tags_start: newname = name[:tags_start] - newname += name[tags_end + 1:] + newname += name[tags_end + 1 :] name = newname return name @@ -118,7 +137,9 @@ def get_duration(timestamps): else: delta = end - start duration = '%d.%06ds' % ( - delta.days * DAY_SECONDS + delta.seconds, delta.microseconds) + delta.days * DAY_SECONDS + delta.seconds, + delta.microseconds, + ) return duration @@ -174,21 +195,29 @@ def show_outcome(stream, test, print_failures=False, failonly=False): if status == 'fail': FAILS.append(test) - stream.write('{{{}}} {} [{}] ... FAILED\n'.format( - worker, name, duration)) + stream.write( + '{{{}}} {} [{}] ... FAILED\n'.format(worker, name, duration) + ) if not print_failures: print_attachments(stream, test, all_channels=True) elif not failonly: if status == 'success': - stream.write('{{{}}} {} [{}] ... ok\n'.format( - worker, name, duration)) + stream.write( + '{{{}}} {} [{}] ... ok\n'.format(worker, name, duration) + ) print_attachments(stream, test) elif status == 'skip': - stream.write('{{{}}} {} ... SKIPPED: {}\n'.format( - worker, name, test['details']['reason'].as_text())) + stream.write( + '{{{}}} {} ... SKIPPED: {}\n'.format( + worker, name, test['details']['reason'].as_text() + ) + ) else: - stream.write('{{{}}} {} [{}] ... {}\n'.format( - worker, name, duration, test['status'])) + stream.write( + '{{{}}} {} [{}] ... 
{}\n'.format( + worker, name, duration, test['status'] + ) + ) if not print_failures: print_attachments(stream, test, all_channels=True) @@ -240,8 +269,9 @@ def worker_stats(worker): def print_summary(stream): stream.write("\n======\nTotals\n======\n") - stream.write("Run: {} in {} sec.\n".format(count_tests('status', '.*'), - run_time())) + stream.write( + "Run: {} in {} sec.\n".format(count_tests('status', '.*'), run_time()) + ) stream.write(" - Passed: %s\n" % count_tests('status', 'success')) stream.write(" - Skipped: %s\n" % count_tests('status', 'skip')) stream.write(" - Failed: %s\n" % count_tests('status', 'fail')) @@ -254,38 +284,55 @@ def print_summary(stream): if w not in RESULTS: stream.write( " - WARNING: missing Worker %s! " - "Race in testr accounting.\n" % w) + "Race in testr accounting.\n" % w + ) else: num, time = worker_stats(w) - stream.write(" - Worker %s (%s tests) => %ss\n" % - (w, num, time)) + stream.write( + " - Worker %s (%s tests) => %ss\n" % (w, num, time) + ) def parse_args(): parser = argparse.ArgumentParser() - parser.add_argument('--no-failure-debug', '-n', action='store_true', - dest='print_failures', help='Disable printing failure ' - 'debug information in realtime') - parser.add_argument('--fails', '-f', action='store_true', - dest='post_fails', help='Print failure debug ' - 'information after the stream is proccesed') - parser.add_argument('--failonly', action='store_true', - dest='failonly', help="Don't print success items", - default=( - os.environ.get('TRACE_FAILONLY', False) - is not False)) + parser.add_argument( + '--no-failure-debug', + '-n', + action='store_true', + dest='print_failures', + help='Disable printing failure debug information in realtime', + ) + parser.add_argument( + '--fails', + '-f', + action='store_true', + dest='post_fails', + help='Print failure debug information after the stream is proccesed', + ) + parser.add_argument( + '--failonly', + action='store_true', + dest='failonly', + help="Don't print 
success items", + default=(os.environ.get('TRACE_FAILONLY', False) is not False), + ) return parser.parse_args() def main(): args = parse_args() stream = subunit.ByteStreamToStreamResult( - sys.stdin, non_subunit_name='stdout') + sys.stdin, non_subunit_name='stdout' + ) starts = Starts(sys.stdout) outcomes = testtools.StreamToDict( - functools.partial(show_outcome, sys.stdout, - print_failures=args.print_failures, - failonly=args.failonly)) + functools.partial( + show_outcome, + sys.stdout, + print_failures=args.print_failures, + failonly=args.failonly, + ) + ) summary = testtools.StreamSummary() result = testtools.CopyStreamResult([starts, outcomes, summary]) result.startTestRun() @@ -299,7 +346,7 @@ def main(): if args.post_fails: print_fails(sys.stdout) print_summary(sys.stdout) - return (0 if summary.wasSuccessful() else 1) + return 0 if summary.wasSuccessful() else 1 if __name__ == '__main__': diff --git a/tox.ini b/tox.ini index 8b2310194..e5578de07 100644 --- a/tox.ini +++ b/tox.ini @@ -73,14 +73,14 @@ deps = commands = pylint taskflow [flake8] +# We only enable the hacking (H) checks +select = H builtins = _ -exclude = .venv,.tox,dist,doc,*egg,.git,build,tools # H203: Use assertIs(Not)None to check for None # H204: Use assert(Not)Equal() # H205: Use assert{Greater,Less}[Equal] -# H904: Delay string interpolations at logging calls -enable-extensions = H203,H204,H205,H904 -ignore = E305,E402,E721,E731,E741,W503,W504 +enable-extensions = H203,H204,H205 +exclude = .venv,.tox,dist,doc,*egg,.git,build,tools [hacking] import_exceptions =