diff --git a/taskflow/engines/worker_based/protocol.py b/taskflow/engines/worker_based/protocol.py index 63556c25..d2d8e34d 100644 --- a/taskflow/engines/worker_based/protocol.py +++ b/taskflow/engines/worker_based/protocol.py @@ -20,6 +20,7 @@ import threading import fasteners import futurist +from oslo_serialization import jsonutils from oslo_utils import reflection from oslo_utils import timeutils import six @@ -100,6 +101,17 @@ RESPONSE = 'RESPONSE' LOG = logging.getLogger(__name__) +def failure_to_dict(failure): + failure_dict = failure.to_dict() + try: + # it's possible the exc_args can't be serialized as JSON + # if that's the case, just get the failure without them + jsonutils.dumps(failure_dict) + return failure_dict + except (TypeError, ValueError): + return failure.to_dict(include_args=False) + + @six.add_metaclass(abc.ABCMeta) class Message(object): """Base class for all message types.""" @@ -301,6 +313,7 @@ class Request(Message): convert all `failure.Failure` objects into dictionaries (which will then be reconstituted by the receiver). """ + request = { 'task_cls': reflection.get_class_name(self._task), 'task_name': self._task.name, @@ -311,14 +324,14 @@ class Request(Message): if 'result' in self._kwargs: result = self._kwargs['result'] if isinstance(result, ft.Failure): - request['result'] = ('failure', result.to_dict()) + request['result'] = ('failure', failure_to_dict(result)) else: request['result'] = ('success', result) if 'failures' in self._kwargs: failures = self._kwargs['failures'] request['failures'] = {} for task, failure in six.iteritems(failures): - request['failures'][task] = failure.to_dict() + request['failures'][task] = failure_to_dict(failure) return request def set_result(self, result): diff --git a/taskflow/engines/worker_based/server.py b/taskflow/engines/worker_based/server.py index 4c3b32bc..978ab3ac 100644 --- a/taskflow/engines/worker_based/server.py +++ b/taskflow/engines/worker_based/server.py @@ -179,7 +179,7 @@ class Server(object): with misc.capture_failure() as failure: LOG.warn("Failed to parse request contents from message '%s'", ku.DelayedPretty(message), exc_info=True) - reply_callback(result=failure.to_dict()) + reply_callback(result=pr.failure_to_dict(failure)) return # Now fetch the task endpoint (and action handler on it). @@ -191,7 +191,7 @@ class Server(object): " to continue processing request message '%s'", work.task_cls, ku.DelayedPretty(message), exc_info=True) - reply_callback(result=failure.to_dict()) + reply_callback(result=pr.failure_to_dict(failure)) return else: try: @@ -202,7 +202,7 @@ class Server(object): " '%s', unable to continue processing request" " message '%s'", work.action, endpoint, ku.DelayedPretty(message), exc_info=True) - reply_callback(result=failure.to_dict()) + reply_callback(result=pr.failure_to_dict(failure)) return else: try: @@ -212,7 +212,7 @@ class Server(object): LOG.warn("The '%s' task '%s' generation for request" " message '%s' failed", endpoint, work.action, ku.DelayedPretty(message), exc_info=True) - reply_callback(result=failure.to_dict()) + reply_callback(result=pr.failure_to_dict(failure)) return else: if not reply_callback(state=pr.RUNNING): @@ -240,7 +240,7 @@ class Server(object): LOG.warn("The '%s' endpoint '%s' execution for request" " message '%s' failed", endpoint, work.action, ku.DelayedPretty(message), exc_info=True) - reply_callback(result=failure.to_dict()) + reply_callback(result=pr.failure_to_dict(failure)) else: # And be done with it! if isinstance(result, ft.Failure): diff --git a/taskflow/tests/unit/test_failure.py b/taskflow/tests/unit/test_failure.py index 6d12ac09..bc95dd0f 100644 --- a/taskflow/tests/unit/test_failure.py +++ b/taskflow/tests/unit/test_failure.py @@ -283,6 +283,16 @@ class FailureObjectTestCase(test.TestCase): text = captured.pformat(traceback=True) self.assertIn("Traceback (most recent call last):", text) + def test_no_capture_exc_args(self): + captured = _captured_failure(Exception("I am not valid JSON")) + fail_obj = failure.Failure(exception_str=captured.exception_str, + traceback_str=captured.traceback_str, + exc_type_names=list(captured), + exc_args=list(captured.exception_args)) + fail_json = fail_obj.to_dict(include_args=False) + self.assertNotEqual(fail_obj.exception_args, fail_json['exc_args']) + self.assertEqual(fail_json['exc_args'], tuple()) + class WrappedFailureTestCase(test.TestCase): diff --git a/taskflow/tests/unit/worker_based/test_protocol.py b/taskflow/tests/unit/worker_based/test_protocol.py index 71116864..a0127eb9 100644 --- a/taskflow/tests/unit/worker_based/test_protocol.py +++ b/taskflow/tests/unit/worker_based/test_protocol.py @@ -162,6 +162,14 @@ class TestProtocol(test.TestCase): failures={self.task.name: a_failure.to_dict()}) self.assertEqual(expected, request.to_dict()) + def test_to_dict_with_invalid_json_failures(self): + exc = RuntimeError(Exception("I am not valid JSON")) + a_failure = failure.Failure.from_exception(exc) + request = self.request(failures={self.task.name: a_failure}) + expected = self.request_to_dict( + failures={self.task.name: a_failure.to_dict(include_args=False)}) + self.assertEqual(expected, request.to_dict()) + @mock.patch('oslo_utils.timeutils.now') def test_pending_not_expired(self, now): now.return_value = 0 diff --git a/taskflow/types/failure.py b/taskflow/types/failure.py index 9b7b2182..ec33dd93 100644 --- a/taskflow/types/failure.py +++ b/taskflow/types/failure.py @@ -499,14 +499,18 @@ class Failure(mixins.StrMixin): data['causes'] = tuple(cls.from_dict(d) for d in causes) return cls(**data) - def to_dict(self): - """Converts this object to a dictionary.""" + def to_dict(self, include_args=True): + """Converts this object to a dictionary. + + :param include_args: boolean indicating whether to include the + exception args in the output. + """ return { 'exception_str': self.exception_str, 'traceback_str': self.traceback_str, 'exc_type_names': list(self), 'version': self.DICT_VERSION, - 'exc_args': self.exception_args, + 'exc_args': self.exception_args if include_args else tuple(), 'causes': [f.to_dict() for f in self.causes], }