Rewind original stream body when refreshing.
When refreshing credentials, the original request is re-sent after the credentials are refreshed. If the body of that request is a stream, the stream contents are read in the initial request, and the stream must be rewound before the request is re-sent. Otherwise, the original message body will be different (because stream data was skipped).
This commit is contained in:
@@ -554,6 +554,11 @@ class OAuth2Credentials(Credentials):
|
||||
else:
|
||||
headers['user-agent'] = self.user_agent
|
||||
|
||||
body_stream_position = None
|
||||
if all(getattr(body, stream_prop, None) for stream_prop in
|
||||
('read', 'seek', 'tell')):
|
||||
body_stream_position = body.tell()
|
||||
|
||||
resp, content = request_orig(uri, method, body, clean_headers(headers),
|
||||
redirections, connection_type)
|
||||
|
||||
@@ -567,6 +572,9 @@ class OAuth2Credentials(Credentials):
|
||||
refresh_attempt + 1, max_refresh_attempts)
|
||||
self._refresh(request_orig)
|
||||
self.apply(headers)
|
||||
if body_stream_position is not None:
|
||||
body.seek(body_stream_position)
|
||||
|
||||
resp, content = request_orig(uri, method, body, clean_headers(headers),
|
||||
redirections, connection_type)
|
||||
|
||||
|
||||
@@ -100,17 +100,16 @@ class HttpMockSequence(object):
|
||||
connection_type=None):
|
||||
resp, content = self._iterable.pop(0)
|
||||
self.requests.append({'uri': uri, 'body': body, 'headers': headers})
|
||||
# Read any underlying stream before sending the request.
|
||||
body_stream_content = body.read() if getattr(body, 'read', None) else None
|
||||
if content == 'echo_request_headers':
|
||||
content = headers
|
||||
elif content == 'echo_request_headers_as_json':
|
||||
content = json.dumps(headers)
|
||||
elif content == 'echo_request_body':
|
||||
if hasattr(body, 'read'):
|
||||
content = body.read()
|
||||
else:
|
||||
content = body
|
||||
content = body if body_stream_content is None else body_stream_content
|
||||
elif content == 'echo_request_uri':
|
||||
content = uri
|
||||
elif not isinstance(content, bytes):
|
||||
raise TypeError("http content should be bytes: %r" % (content,))
|
||||
raise TypeError('http content should be bytes: %r' % (content,))
|
||||
return httplib2.Response(resp), content
|
||||
|
||||
@@ -32,12 +32,15 @@ import tempfile
|
||||
import unittest
|
||||
|
||||
from .http_mock import HttpMockSequence
|
||||
import six
|
||||
|
||||
from oauth2client import file
|
||||
from oauth2client import locked_file
|
||||
from oauth2client import multistore_file
|
||||
from oauth2client import util
|
||||
from oauth2client.client import AccessTokenCredentials
|
||||
from oauth2client.client import OAuth2Credentials
|
||||
from six.moves import http_client
|
||||
try:
|
||||
# Python2
|
||||
from future_builtins import oct
|
||||
@@ -154,15 +157,17 @@ class OAuth2ClientFileTests(unittest.TestCase):
|
||||
access_token = '1/3w'
|
||||
token_response = {'access_token': access_token, 'expires_in': 3600}
|
||||
http = HttpMockSequence([
|
||||
({'status': '401'}, b'Initial token expired'),
|
||||
({'status': '401'}, b'Store token expired'),
|
||||
({'status': '200'}, json.dumps(token_response).encode('utf-8')),
|
||||
({'status': '200'}, b'Valid response to original request')
|
||||
({'status': str(http_client.UNAUTHORIZED)}, b'Initial token expired'),
|
||||
({'status': str(http_client.UNAUTHORIZED)}, b'Store token expired'),
|
||||
({'status': str(http_client.OK)},
|
||||
json.dumps(token_response).encode('utf-8')),
|
||||
({'status': str(http_client.OK)},
|
||||
b'Valid response to original request')
|
||||
])
|
||||
|
||||
credentials.authorize(http)
|
||||
http.request('https://example.com')
|
||||
self.assertEquals(credentials.access_token, access_token)
|
||||
self.assertEqual(credentials.access_token, access_token)
|
||||
|
||||
def test_token_refresh_good_store(self):
|
||||
expiration = datetime.datetime.utcnow() + datetime.timedelta(minutes=15)
|
||||
@@ -178,6 +183,34 @@ class OAuth2ClientFileTests(unittest.TestCase):
|
||||
credentials._refresh(lambda x: x)
|
||||
self.assertEquals(credentials.access_token, 'bar')
|
||||
|
||||
def test_token_refresh_stream_body(self):
|
||||
expiration = datetime.datetime.utcnow() + datetime.timedelta(minutes=15)
|
||||
credentials = self.create_test_credentials(expiration=expiration)
|
||||
|
||||
s = file.Storage(FILENAME)
|
||||
s.put(credentials)
|
||||
credentials = s.get()
|
||||
new_cred = copy.copy(credentials)
|
||||
new_cred.access_token = 'bar'
|
||||
s.put(new_cred)
|
||||
|
||||
valid_access_token = '1/3w'
|
||||
token_response = {'access_token': valid_access_token, 'expires_in': 3600}
|
||||
http = HttpMockSequence([
|
||||
({'status': str(http_client.UNAUTHORIZED)}, b'Initial token expired'),
|
||||
({'status': str(http_client.UNAUTHORIZED)}, b'Store token expired'),
|
||||
({'status': str(http_client.OK)},
|
||||
json.dumps(token_response).encode('utf-8')),
|
||||
({'status': str(http_client.OK)}, 'echo_request_body')
|
||||
])
|
||||
|
||||
body = six.StringIO('streaming body')
|
||||
|
||||
credentials.authorize(http)
|
||||
_, content = http.request('https://example.com', body=body)
|
||||
self.assertEqual(content, 'streaming body')
|
||||
self.assertEqual(credentials.access_token, valid_access_token)
|
||||
|
||||
def test_credentials_delete(self):
|
||||
credentials = self.create_test_credentials()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user