diff --git a/taskflow/tests/unit/test_failure.py b/taskflow/tests/unit/test_failure.py index 4ecfa5c5..bad99456 100644 --- a/taskflow/tests/unit/test_failure.py +++ b/taskflow/tests/unit/test_failure.py @@ -168,6 +168,19 @@ class FailureObjectTestCase(test.TestCase): f2 = failure.Failure.from_dict(d_f) self.assertTrue(f.matches(f2)) + def test_bad_root_exception(self): + f = _captured_failure('Woot!') + d_f = f.to_dict() + d_f['exc_type_names'] = ['Junk'] + self.assertRaises(exceptions.InvalidFormat, + failure.Failure.validate, d_f) + + def test_valid_from_dict_to_dict_2(self): + f = _captured_failure('Woot!') + d_f = f.to_dict() + d_f['exc_type_names'] = ['RuntimeError', 'Exception', 'BaseException'] + failure.Failure.validate(d_f) + def test_dont_catch_base_exception(self): try: raise SystemExit() diff --git a/taskflow/types/failure.py b/taskflow/types/failure.py index 2663fb81..72f3710f 100644 --- a/taskflow/types/failure.py +++ b/taskflow/types/failure.py @@ -14,6 +14,7 @@ # License for the specific language governing permissions and limitations # under the License. +import collections import copy import os import sys @@ -28,6 +29,7 @@ from taskflow.utils import iter_utils from taskflow.utils import mixins from taskflow.utils import schema_utils as su + _exception_message = encodeutils.exception_to_unicode @@ -121,6 +123,13 @@ class Failure(mixins.StrMixin): """ DICT_VERSION = 1 + BASE_EXCEPTIONS = ('BaseException', 'Exception') + """ + Root exceptions of all other python exceptions. + + See: https://docs.python.org/2/library/exceptions.html + """ + #: Expected failure schema (in json schema format). SCHEMA = { "$ref": "#/definitions/cause", @@ -206,11 +215,29 @@ class Failure(mixins.StrMixin): @classmethod def validate(cls, data): + """Validate input data matches expected failure ``dict`` format.""" try: su.schema_validate(data, cls.SCHEMA) except su.ValidationError as e: raise exc.InvalidFormat("Failure data not of the" " expected format: %s" % (e.message), e) + else: + # Ensure that all 'exc_type_names' originate from one of + # BASE_EXCEPTIONS, because those are the root exceptions that + # python mandates/provides and anything else is invalid... + causes = collections.deque([data]) + while causes: + cause = causes.popleft() + root_exc_type = cause['exc_type_names'][-1] + if root_exc_type not in cls.BASE_EXCEPTIONS: + raise exc.InvalidFormat( + "Failure data 'exc_type_names' must" + " have an initial exception type that is one" + " of %s types: '%s' is not one of those" + " types" % (cls.BASE_EXCEPTIONS, root_exc_type)) + sub_causes = cause.get('causes') + if sub_causes: + causes.extend(sub_causes) def _matches(self, other): if self is other: