From ad133adea6ee293f0cfc6145a483fa0cfc27faf6 Mon Sep 17 00:00:00 2001 From: Joshua Harlow Date: Tue, 10 Mar 2015 18:01:44 -0700 Subject: [PATCH] Add + use failure json schema validation Change-Id: Ie3aa386c831459a028ba494570bafd53b998126e --- doc/source/utils.rst | 5 ++ taskflow/engines/worker_based/protocol.py | 71 +++++++++++++++---- taskflow/engines/worker_based/server.py | 30 +------- taskflow/tests/unit/test_failure.py | 31 ++++++++ .../tests/unit/worker_based/test_server.py | 8 +-- taskflow/types/failure.py | 50 +++++++++++++ taskflow/utils/schema_utils.py | 34 +++++++++ 7 files changed, 182 insertions(+), 47 deletions(-) create mode 100644 taskflow/utils/schema_utils.py diff --git a/doc/source/utils.rst b/doc/source/utils.rst index 1f774663..6949ccf0 100644 --- a/doc/source/utils.rst +++ b/doc/source/utils.rst @@ -48,6 +48,11 @@ Persistence .. automodule:: taskflow.utils.persistence_utils +Schema +~~~~~~ + +.. automodule:: taskflow.utils.schema_utils + Threading ~~~~~~~~~ diff --git a/taskflow/engines/worker_based/protocol.py b/taskflow/engines/worker_based/protocol.py index 8a137471..b22d61fe 100644 --- a/taskflow/engines/worker_based/protocol.py +++ b/taskflow/engines/worker_based/protocol.py @@ -18,8 +18,6 @@ import abc import threading from concurrent import futures -import jsonschema -from jsonschema import exceptions as schema_exc from oslo_utils import reflection from oslo_utils import timeutils import six @@ -30,6 +28,7 @@ from taskflow import logging from taskflow.types import failure as ft from taskflow.types import timing as tt from taskflow.utils import lock_utils +from taskflow.utils import schema_utils as su # NOTE(skudriashev): This is protocol states and events, which are not # related to task states. @@ -98,12 +97,6 @@ NOTIFY = 'NOTIFY' REQUEST = 'REQUEST' RESPONSE = 'RESPONSE' -# Special jsonschema validation types/adjustments. -_SCHEMA_TYPES = { - # See: https://github.com/Julian/jsonschema/issues/148 - 'array': (list, tuple), -} - LOG = logging.getLogger(__name__) @@ -166,8 +159,8 @@ class Notify(Message): else: schema = cls.SENDER_SCHEMA try: - jsonschema.validate(data, schema, types=_SCHEMA_TYPES) - except schema_exc.ValidationError as e: + su.schema_validate(data, schema) + except su.ValidationError as e: if response: raise excp.InvalidFormat("%s message response data not of the" " expected format: %s" @@ -358,11 +351,57 @@ class Request(Message): @classmethod def validate(cls, data): try: - jsonschema.validate(data, cls.SCHEMA, types=_SCHEMA_TYPES) - except schema_exc.ValidationError as e: + su.schema_validate(data, cls.SCHEMA) + except su.ValidationError as e: raise excp.InvalidFormat("%s message response data not of the" " expected format: %s" % (cls.TYPE, e.message), e) + else: + # Validate all failure dictionaries that *may* be present... + failures = [] + if 'failures' in data: + failures.extend(six.itervalues(data['failures'])) + result = data.get('result') + if result is not None: + result_data_type, result_data = result + if result_data_type == 'failure': + failures.append(result_data) + for fail_data in failures: + ft.Failure.validate(fail_data) + + @staticmethod + def from_dict(data, task_uuid=None): + """Parses **validated** data before it can be further processed. + + All :py:class:`~taskflow.types.failure.Failure` objects that have been + converted to dict(s) on the remote side will now converted back + to py:class:`~taskflow.types.failure.Failure` objects. + """ + task_cls = data['task_cls'] + task_name = data['task_name'] + action = data['action'] + arguments = data.get('arguments', {}) + result = data.get('result') + failures = data.get('failures') + # These arguments will eventually be given to the task executor + # so they need to be in a format it will accept (and using keyword + # argument names that it accepts)... + arguments = { + 'arguments': arguments, + } + if task_uuid is not None: + arguments['task_uuid'] = task_uuid + if result is not None: + result_data_type, result_data = result + if result_data_type == 'failure': + arguments['result'] = ft.Failure.from_dict(result_data) + else: + arguments['result'] = result_data + if failures is not None: + arguments['failures'] = {} + for task, fail_data in six.iteritems(failures): + arguments['failures'][task] = ft.Failure.from_dict(fail_data) + return (task_cls, task_name, action, arguments) class Response(Message): @@ -455,8 +494,12 @@ class Response(Message): @classmethod def validate(cls, data): try: - jsonschema.validate(data, cls.SCHEMA, types=_SCHEMA_TYPES) - except schema_exc.ValidationError as e: + su.schema_validate(data, cls.SCHEMA) + except su.ValidationError as e: raise excp.InvalidFormat("%s message response data not of the" " expected format: %s" % (cls.TYPE, e.message), e) + else: + state = data['state'] + if state == FAILURE and 'result' in data: + ft.Failure.validate(data['result']) diff --git a/taskflow/engines/worker_based/server.py b/taskflow/engines/worker_based/server.py index 1043e879..99a39895 100644 --- a/taskflow/engines/worker_based/server.py +++ b/taskflow/engines/worker_based/server.py @@ -17,7 +17,6 @@ import functools from oslo_utils import reflection -import six from taskflow.engines.worker_based import dispatcher from taskflow.engines.worker_based import protocol as pr @@ -94,32 +93,6 @@ class Server(object): def connection_details(self): return self._proxy.connection_details - @staticmethod - def _parse_request(task_cls, task_name, action, arguments, result=None, - failures=None, **kwargs): - """Parse request before it can be further processed. - - All `failure.Failure` objects that have been converted to dict on the - remote side will now converted back to `failure.Failure` objects. - """ - # These arguments will eventually be given to the task executor - # so they need to be in a format it will accept (and using keyword - # argument names that it accepts)... - arguments = { - 'arguments': arguments, - } - if result is not None: - data_type, data = result - if data_type == 'failure': - arguments['result'] = ft.Failure.from_dict(data) - else: - arguments['result'] = data - if failures is not None: - arguments['failures'] = {} - for key, data in six.iteritems(failures): - arguments['failures'][key] = ft.Failure.from_dict(data) - return (task_cls, task_name, action, arguments) - @staticmethod def _parse_message(message): """Extracts required attributes out of the messages properties. @@ -201,9 +174,8 @@ class Server(object): # parse request to get task name, action and action arguments try: - bundle = self._parse_request(**request) + bundle = pr.Request.from_dict(request, task_uuid=task_uuid) task_cls, task_name, action, arguments = bundle - arguments['task_uuid'] = task_uuid except ValueError: with misc.capture_failure() as failure: LOG.warn("Failed to parse request contents from message '%s'", diff --git a/taskflow/tests/unit/test_failure.py b/taskflow/tests/unit/test_failure.py index c8f83b9b..fab9cb9a 100644 --- a/taskflow/tests/unit/test_failure.py +++ b/taskflow/tests/unit/test_failure.py @@ -137,6 +137,36 @@ class FromExceptionTestCase(test.TestCase, GeneralFailureObjTestsMixin): class FailureObjectTestCase(test.TestCase): + def test_invalids(self): + f = { + 'exception_str': 'blah', + 'traceback_str': 'blah', + 'exc_type_names': [], + } + self.assertRaises(exceptions.InvalidFormat, + failure.Failure.validate, f) + f = { + 'exception_str': 'blah', + 'exc_type_names': ['Exception'], + } + self.assertRaises(exceptions.InvalidFormat, + failure.Failure.validate, f) + f = { + 'exception_str': 'blah', + 'traceback_str': 'blah', + 'exc_type_names': ['Exception'], + 'version': -1, + } + self.assertRaises(exceptions.InvalidFormat, + failure.Failure.validate, f) + + def test_valid_from_dict_to_dict(self): + f = _captured_failure('Woot!') + d_f = f.to_dict() + failure.Failure.validate(d_f) + f2 = failure.Failure.from_dict(d_f) + self.assertTrue(f.matches(f2)) + def test_dont_catch_base_exception(self): try: raise SystemExit() @@ -358,6 +388,7 @@ class FailureCausesTest(test.TestCase): self.assertIsNotNone(f) d_f = f.to_dict() + failure.Failure.validate(d_f) f = failure.Failure.from_dict(d_f) self.assertEqual(2, len(f.causes)) self.assertEqual("Still not working", f.causes[0].exception_str) diff --git a/taskflow/tests/unit/worker_based/test_server.py b/taskflow/tests/unit/worker_based/test_server.py index fea5d1cc..9b7815c4 100644 --- a/taskflow/tests/unit/worker_based/test_server.py +++ b/taskflow/tests/unit/worker_based/test_server.py @@ -108,7 +108,7 @@ class TestServer(test.MockTestCase): def test_parse_request(self): request = self.make_request() - bundle = server.Server._parse_request(**request) + bundle = pr.Request.from_dict(request) task_cls, task_name, action, task_args = bundle self.assertEqual((task_cls, task_name, action, task_args), (self.task.name, self.task.name, self.task_action, @@ -116,7 +116,7 @@ class TestServer(test.MockTestCase): def test_parse_request_with_success_result(self): request = self.make_request(action='revert', result=1) - bundle = server.Server._parse_request(**request) + bundle = pr.Request.from_dict(request) task_cls, task_name, action, task_args = bundle self.assertEqual((task_cls, task_name, action, task_args), (self.task.name, self.task.name, 'revert', @@ -126,7 +126,7 @@ class TestServer(test.MockTestCase): def test_parse_request_with_failure_result(self): a_failure = failure.Failure.from_exception(Exception('test')) request = self.make_request(action='revert', result=a_failure) - bundle = server.Server._parse_request(**request) + bundle = pr.Request.from_dict(request) task_cls, task_name, action, task_args = bundle self.assertEqual((task_cls, task_name, action, task_args), (self.task.name, self.task.name, 'revert', @@ -137,7 +137,7 @@ class TestServer(test.MockTestCase): failures = {'0': failure.Failure.from_exception(Exception('test1')), '1': failure.Failure.from_exception(Exception('test2'))} request = self.make_request(action='revert', failures=failures) - bundle = server.Server._parse_request(**request) + bundle = pr.Request.from_dict(request) task_cls, task_name, action, task_args = bundle self.assertEqual( (task_cls, task_name, action, task_args), diff --git a/taskflow/types/failure.py b/taskflow/types/failure.py index f251b00e..1c98aa2d 100644 --- a/taskflow/types/failure.py +++ b/taskflow/types/failure.py @@ -23,6 +23,7 @@ from oslo_utils import reflection import six from taskflow import exceptions as exc +from taskflow.utils import schema_utils as su def _copy_exc_info(exc_info): @@ -132,6 +133,47 @@ class Failure(object): """ DICT_VERSION = 1 + #: Expected failure schema (in json schema format). + SCHEMA = { + "$ref": "#/definitions/cause", + "definitions": { + "cause": { + "type": "object", + 'properties': { + 'version': { + "type": "integer", + "minimum": 0, + }, + 'exception_str': { + "type": "string", + }, + 'traceback_str': { + "type": "string", + }, + 'exc_type_names': { + "type": "array", + "items": { + "type": "string", + }, + "minItems": 1, + }, + 'causes': { + "type": "array", + "items": { + "$ref": "#/definitions/cause", + }, + } + }, + "required": [ + "exception_str", + 'traceback_str', + 'exc_type_names', + ], + "additionalProperties": True, + }, + }, + } + def __init__(self, exc_info=None, **kwargs): if not kwargs: if exc_info is None: @@ -169,6 +211,14 @@ class Failure(object): """Creates a failure object from a exception instance.""" return cls((type(exception), exception, None)) + @classmethod + def validate(cls, data): + try: + su.schema_validate(data, cls.SCHEMA) + except su.ValidationError as e: + raise exc.InvalidFormat("Failure data not of the" + " expected format: %s" % (e.message), e) + def _matches(self, other): if self is other: return True diff --git a/taskflow/utils/schema_utils.py b/taskflow/utils/schema_utils.py new file mode 100644 index 00000000..8d7c216e --- /dev/null +++ b/taskflow/utils/schema_utils.py @@ -0,0 +1,34 @@ +# -*- coding: utf-8 -*- + +# Copyright (C) 2015 Yahoo! Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import jsonschema +from jsonschema import exceptions as schema_exc + +# Special jsonschema validation types/adjustments. +_SCHEMA_TYPES = { + # See: https://github.com/Julian/jsonschema/issues/148 + 'array': (list, tuple), +} + + +# Expose these types so that people don't have to import the same exceptions. +ValidationError = schema_exc.ValidationError +SchemaError = schema_exc.SchemaError + + +def schema_validate(data, schema): + """Validates given data using provided json schema.""" + jsonschema.validate(data, schema, types=_SCHEMA_TYPES)