diff --git a/tests/test_flask.py b/tests/test_flask.py index e5add1f..33b7cca 100644 --- a/tests/test_flask.py +++ b/tests/test_flask.py @@ -1,5 +1,5 @@ import unittest -from flask import Flask, json +from flask import Flask, json, abort from wsmeext.flask import signature from wsme.api import Response from wsme.types import Base, text @@ -46,6 +46,20 @@ def get_model(name): return Model(name=name) +@test_app.route('/models//secret') +@signature(Model, text) +def model_secret(name): + abort(403, description="You're not allowed in there!") + + +@test_app.route('/models//custom-error') +@signature(Model, text) +def model_custom_error(name): + class CustomError(Exception): + code = 412 + raise CustomError("FOO!") + + @test_app.route('/models', methods=['POST']) @signature(Model, body=Model) def post_model(body): @@ -111,6 +125,40 @@ class FlaskrTestCase(unittest.TestCase): resp = self.app.get('/status_response') assert resp.status_code == 201 + def test_custom_clientside_error(self): + r = self.app.get( + '/models/test/secret', + headers={'Accept': 'application/json'} + ) + assert r.status_code == 403, r.status_code + assert json.loads(r.data)['faultstring'] == '403: Forbidden' + + r = self.app.get( + '/models/test/secret', + headers={'Accept': 'application/xml'} + ) + assert r.status_code == 403, r.status_code + assert r.data == ('Server' + '403: Forbidden' + '') + + def test_custom_non_http_clientside_error(self): + r = self.app.get( + '/models/test/custom-error', + headers={'Accept': 'application/json'} + ) + assert r.status_code == 412, r.status_code + assert json.loads(r.data)['faultstring'] == 'FOO!' + + r = self.app.get( + '/models/test/custom-error', + headers={'Accept': 'application/xml'} + ) + assert r.status_code == 412, r.status_code + assert r.data == ('Server' + 'FOO!' + '') + def test_serversideerror(self): r = self.app.get('/divide_by_zero') assert r.status_code == 500 diff --git a/wsmeext/flask.py b/wsmeext/flask.py index f9f5b79..4abd3dd 100644 --- a/wsmeext/flask.py +++ b/wsmeext/flask.py @@ -9,6 +9,7 @@ import wsme.api import wsme.rest.json import wsme.rest.xml import wsme.rest.args +from wsmeext.utils import is_valid_code import flask @@ -78,10 +79,19 @@ def signature(*args, **kw): res.mimetype = dataformat.content_type res.status_code = status_code except: - data = wsme.api.format_exception(sys.exc_info()) + try: + exception_info = sys.exc_info() + orig_exception = exception_info[1] + orig_code = getattr(orig_exception, 'code', None) + data = wsme.api.format_exception(exception_info) + finally: + del exception_info + res = flask.make_response(dataformat.encode_error(None, data)) if data['faultcode'] == 'client': res.status_code = 400 + elif orig_code and is_valid_code(orig_code): + res.status_code = orig_code else: res.status_code = 500 return res