diff --git a/swift/common/middleware/formpost.py b/swift/common/middleware/formpost.py index b3dde1832a..383eed9fbe 100644 --- a/swift/common/middleware/formpost.py +++ b/swift/common/middleware/formpost.py @@ -136,10 +136,11 @@ from swift.common.digest import get_allowed_digests, \ extract_digest_and_algorithm, DEFAULT_ALLOWED_DIGESTS from swift.common.utils import streq_const_time, parse_content_disposition, \ parse_mime_headers, iter_multipart_mime_documents, reiterate, \ - close_if_possible, get_logger + closing_if_possible, get_logger from swift.common.registry import register_swift_info from swift.common.wsgi import make_pre_authed_env from swift.common.swob import HTTPUnauthorized, wsgi_to_str, str_to_wsgi +from swift.common.http import is_success from swift.proxy.controllers.base import get_account_info, get_container_info @@ -278,6 +279,7 @@ class FormPost(object): attributes = {} file_attributes = {} subheaders = [] + resp_body = None file_count = 0 for fp in iter_multipart_mime_documents( env['wsgi.input'], boundary, read_chunk_size=READ_CHUNK_SIZE): @@ -302,9 +304,10 @@ class FormPost(object): 'content-encoding' in hdrs: file_attributes['content-encoding'] = \ hdrs['Content-Encoding'] - status, subheaders = \ + status, subheaders, resp_body = \ self._perform_subrequest(env, file_attributes, fp, keys) - if not status.startswith('2'): + status_code = int(status.split(' ', 1)[0]) + if not is_success(status_code): break else: data = b'' @@ -325,6 +328,7 @@ class FormPost(object): status = '400 Bad Request' message = 'no files to process' + status_code = int(status.split(' ', 1)[0]) headers = [(k, v) for k, v in subheaders if k.lower().startswith('access-control')] @@ -333,17 +337,19 @@ class FormPost(object): body = status if message: body = status + '\r\nFormPost: ' + message.title() - headers.extend([('Content-Type', 'text/plain'), - ('Content-Length', len(body))]) if six.PY3: body = body.encode('utf-8') + if not is_success(status_code) and resp_body: + body = resp_body + headers.extend([('Content-Type', 'text/plain'), + ('Content-Length', len(body))]) return status, headers, body - status = status.split(' ', 1)[0] if '?' in redirect: redirect += '&' else: redirect += '?' - redirect += 'status=%s&message=%s' % (quote(status), quote(message)) + redirect += 'status=%s&message=%s' % (quote(str(status_code)), + quote(message)) body = '

' \ 'Click to continue...

' % redirect if six.PY3: @@ -447,8 +453,10 @@ class FormPost(object): # reiterate to ensure the response started, # but drop any data on the floor - close_if_possible(reiterate(self.app(subenv, _start_response))) - return substatus[0], subheaders[0] + resp = self.app(subenv, _start_response) + with closing_if_possible(reiterate(resp)): + body = b''.join(resp) + return substatus[0], subheaders[0], body def _get_keys(self, env): """ diff --git a/test/unit/common/middleware/test_formpost.py b/test/unit/common/middleware/test_formpost.py index d751062e88..4ebada2fe0 100644 --- a/test/unit/common/middleware/test_formpost.py +++ b/test/unit/common/middleware/test_formpost.py @@ -1046,7 +1046,7 @@ class TestFormPost(unittest.TestCase): self.assertTrue(b'201 Created' in body) self.assertEqual(len(self.app.requests), 2) - def test_subrequest_fails(self): + def test_subrequest_fails_redirect_404(self): key = b'abc' sig, env, body = self._make_sig_env_body( '/v1/AUTH_test/container', 'http://brim.net', 1024, 10, @@ -1083,6 +1083,36 @@ class TestFormPost(unittest.TestCase): self.assertTrue(b'http://brim.net?status=404&message=' in body) self.assertEqual(len(self.app.requests), 1) + def test_subrequest_fails_no_redirect_503(self): + key = b'abc' + sig, env, body = self._make_sig_env_body( + '/v1/AUTH_test/container', '', 1024, 10, + int(time() + 86400), key) + env['wsgi.input'] = BytesIO(b'\r\n'.join(body)) + env['swift.infocache'][get_cache_key('AUTH_test')] = ( + self._fake_cache_env('AUTH_test', [key])) + env['swift.infocache'][get_cache_key( + 'AUTH_test', 'container')] = {'meta': {}} + self.app = FakeApp(iter([('503 Server Error', {}, b'some bad news')])) + self.auth = tempauth.filter_factory({})(self.app) + self.formpost = formpost.filter_factory({})(self.auth) + status = [None] + headers = [None] + exc_info = [None] + + def start_response(s, h, e=None): + status[0] = s + headers[0] = h + exc_info[0] = e + + body = b''.join(self.formpost(env, start_response)) + status = status[0] + headers = headers[0] + exc_info = exc_info[0] + self.assertEqual(status, '503 Server Error') + self.assertTrue(b'bad news' in body) + self.assertEqual(len(self.app.requests), 1) + def test_truncated_attr_value(self): key = b'abc' redirect = 'a' * formpost.MAX_VALUE_LENGTH