Rewind original stream body when refreshing.

When refreshing credentials, the original request is re-sent
after the credentials are refreshed.

If the body of that request is a stream, the stream contents
are read in the initial request, and the stream must be rewound
before the request is re-sent. Otherwise, the original message
body will be different (because stream data was skipped).
This commit is contained in:
Travis Hobrla
2015-05-11 15:42:45 -07:00
parent d93ed1ed1e
commit 7d517d9295
3 changed files with 50 additions and 10 deletions

View File

@@ -554,6 +554,11 @@ class OAuth2Credentials(Credentials):
else: else:
headers['user-agent'] = self.user_agent headers['user-agent'] = self.user_agent
body_stream_position = None
if all(getattr(body, stream_prop, None) for stream_prop in
('read', 'seek', 'tell')):
body_stream_position = body.tell()
resp, content = request_orig(uri, method, body, clean_headers(headers), resp, content = request_orig(uri, method, body, clean_headers(headers),
redirections, connection_type) redirections, connection_type)
@@ -567,6 +572,9 @@ class OAuth2Credentials(Credentials):
refresh_attempt + 1, max_refresh_attempts) refresh_attempt + 1, max_refresh_attempts)
self._refresh(request_orig) self._refresh(request_orig)
self.apply(headers) self.apply(headers)
if body_stream_position is not None:
body.seek(body_stream_position)
resp, content = request_orig(uri, method, body, clean_headers(headers), resp, content = request_orig(uri, method, body, clean_headers(headers),
redirections, connection_type) redirections, connection_type)

View File

@@ -100,17 +100,16 @@ class HttpMockSequence(object):
connection_type=None): connection_type=None):
resp, content = self._iterable.pop(0) resp, content = self._iterable.pop(0)
self.requests.append({'uri': uri, 'body': body, 'headers': headers}) self.requests.append({'uri': uri, 'body': body, 'headers': headers})
# Read any underlying stream before sending the request.
body_stream_content = body.read() if getattr(body, 'read', None) else None
if content == 'echo_request_headers': if content == 'echo_request_headers':
content = headers content = headers
elif content == 'echo_request_headers_as_json': elif content == 'echo_request_headers_as_json':
content = json.dumps(headers) content = json.dumps(headers)
elif content == 'echo_request_body': elif content == 'echo_request_body':
if hasattr(body, 'read'): content = body if body_stream_content is None else body_stream_content
content = body.read()
else:
content = body
elif content == 'echo_request_uri': elif content == 'echo_request_uri':
content = uri content = uri
elif not isinstance(content, bytes): elif not isinstance(content, bytes):
raise TypeError("http content should be bytes: %r" % (content,)) raise TypeError('http content should be bytes: %r' % (content,))
return httplib2.Response(resp), content return httplib2.Response(resp), content

View File

@@ -32,12 +32,15 @@ import tempfile
import unittest import unittest
from .http_mock import HttpMockSequence from .http_mock import HttpMockSequence
import six
from oauth2client import file from oauth2client import file
from oauth2client import locked_file from oauth2client import locked_file
from oauth2client import multistore_file from oauth2client import multistore_file
from oauth2client import util from oauth2client import util
from oauth2client.client import AccessTokenCredentials from oauth2client.client import AccessTokenCredentials
from oauth2client.client import OAuth2Credentials from oauth2client.client import OAuth2Credentials
from six.moves import http_client
try: try:
# Python2 # Python2
from future_builtins import oct from future_builtins import oct
@@ -154,15 +157,17 @@ class OAuth2ClientFileTests(unittest.TestCase):
access_token = '1/3w' access_token = '1/3w'
token_response = {'access_token': access_token, 'expires_in': 3600} token_response = {'access_token': access_token, 'expires_in': 3600}
http = HttpMockSequence([ http = HttpMockSequence([
({'status': '401'}, b'Initial token expired'), ({'status': str(http_client.UNAUTHORIZED)}, b'Initial token expired'),
({'status': '401'}, b'Store token expired'), ({'status': str(http_client.UNAUTHORIZED)}, b'Store token expired'),
({'status': '200'}, json.dumps(token_response).encode('utf-8')), ({'status': str(http_client.OK)},
({'status': '200'}, b'Valid response to original request') json.dumps(token_response).encode('utf-8')),
({'status': str(http_client.OK)},
b'Valid response to original request')
]) ])
credentials.authorize(http) credentials.authorize(http)
http.request('https://example.com') http.request('https://example.com')
self.assertEquals(credentials.access_token, access_token) self.assertEqual(credentials.access_token, access_token)
def test_token_refresh_good_store(self): def test_token_refresh_good_store(self):
expiration = datetime.datetime.utcnow() + datetime.timedelta(minutes=15) expiration = datetime.datetime.utcnow() + datetime.timedelta(minutes=15)
@@ -178,6 +183,34 @@ class OAuth2ClientFileTests(unittest.TestCase):
credentials._refresh(lambda x: x) credentials._refresh(lambda x: x)
self.assertEquals(credentials.access_token, 'bar') self.assertEquals(credentials.access_token, 'bar')
def test_token_refresh_stream_body(self):
expiration = datetime.datetime.utcnow() + datetime.timedelta(minutes=15)
credentials = self.create_test_credentials(expiration=expiration)
s = file.Storage(FILENAME)
s.put(credentials)
credentials = s.get()
new_cred = copy.copy(credentials)
new_cred.access_token = 'bar'
s.put(new_cred)
valid_access_token = '1/3w'
token_response = {'access_token': valid_access_token, 'expires_in': 3600}
http = HttpMockSequence([
({'status': str(http_client.UNAUTHORIZED)}, b'Initial token expired'),
({'status': str(http_client.UNAUTHORIZED)}, b'Store token expired'),
({'status': str(http_client.OK)},
json.dumps(token_response).encode('utf-8')),
({'status': str(http_client.OK)}, 'echo_request_body')
])
body = six.StringIO('streaming body')
credentials.authorize(http)
_, content = http.request('https://example.com', body=body)
self.assertEqual(content, 'streaming body')
self.assertEqual(credentials.access_token, valid_access_token)
def test_credentials_delete(self): def test_credentials_delete(self):
credentials = self.create_test_credentials() credentials = self.create_test_credentials()