Rewind original stream body when refreshing.
When refreshing credentials, the original request is re-sent after the credentials are refreshed. If the body of that request is a stream, the stream contents are read in the initial request, and the stream must be rewound before the request is re-sent. Otherwise, the original message body will be different (because stream data was skipped).
This commit is contained in:
@@ -554,6 +554,11 @@ class OAuth2Credentials(Credentials):
|
|||||||
else:
|
else:
|
||||||
headers['user-agent'] = self.user_agent
|
headers['user-agent'] = self.user_agent
|
||||||
|
|
||||||
|
body_stream_position = None
|
||||||
|
if all(getattr(body, stream_prop, None) for stream_prop in
|
||||||
|
('read', 'seek', 'tell')):
|
||||||
|
body_stream_position = body.tell()
|
||||||
|
|
||||||
resp, content = request_orig(uri, method, body, clean_headers(headers),
|
resp, content = request_orig(uri, method, body, clean_headers(headers),
|
||||||
redirections, connection_type)
|
redirections, connection_type)
|
||||||
|
|
||||||
@@ -567,6 +572,9 @@ class OAuth2Credentials(Credentials):
|
|||||||
refresh_attempt + 1, max_refresh_attempts)
|
refresh_attempt + 1, max_refresh_attempts)
|
||||||
self._refresh(request_orig)
|
self._refresh(request_orig)
|
||||||
self.apply(headers)
|
self.apply(headers)
|
||||||
|
if body_stream_position is not None:
|
||||||
|
body.seek(body_stream_position)
|
||||||
|
|
||||||
resp, content = request_orig(uri, method, body, clean_headers(headers),
|
resp, content = request_orig(uri, method, body, clean_headers(headers),
|
||||||
redirections, connection_type)
|
redirections, connection_type)
|
||||||
|
|
||||||
|
|||||||
@@ -100,17 +100,16 @@ class HttpMockSequence(object):
|
|||||||
connection_type=None):
|
connection_type=None):
|
||||||
resp, content = self._iterable.pop(0)
|
resp, content = self._iterable.pop(0)
|
||||||
self.requests.append({'uri': uri, 'body': body, 'headers': headers})
|
self.requests.append({'uri': uri, 'body': body, 'headers': headers})
|
||||||
|
# Read any underlying stream before sending the request.
|
||||||
|
body_stream_content = body.read() if getattr(body, 'read', None) else None
|
||||||
if content == 'echo_request_headers':
|
if content == 'echo_request_headers':
|
||||||
content = headers
|
content = headers
|
||||||
elif content == 'echo_request_headers_as_json':
|
elif content == 'echo_request_headers_as_json':
|
||||||
content = json.dumps(headers)
|
content = json.dumps(headers)
|
||||||
elif content == 'echo_request_body':
|
elif content == 'echo_request_body':
|
||||||
if hasattr(body, 'read'):
|
content = body if body_stream_content is None else body_stream_content
|
||||||
content = body.read()
|
|
||||||
else:
|
|
||||||
content = body
|
|
||||||
elif content == 'echo_request_uri':
|
elif content == 'echo_request_uri':
|
||||||
content = uri
|
content = uri
|
||||||
elif not isinstance(content, bytes):
|
elif not isinstance(content, bytes):
|
||||||
raise TypeError("http content should be bytes: %r" % (content,))
|
raise TypeError('http content should be bytes: %r' % (content,))
|
||||||
return httplib2.Response(resp), content
|
return httplib2.Response(resp), content
|
||||||
|
|||||||
@@ -32,12 +32,15 @@ import tempfile
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from .http_mock import HttpMockSequence
|
from .http_mock import HttpMockSequence
|
||||||
|
import six
|
||||||
|
|
||||||
from oauth2client import file
|
from oauth2client import file
|
||||||
from oauth2client import locked_file
|
from oauth2client import locked_file
|
||||||
from oauth2client import multistore_file
|
from oauth2client import multistore_file
|
||||||
from oauth2client import util
|
from oauth2client import util
|
||||||
from oauth2client.client import AccessTokenCredentials
|
from oauth2client.client import AccessTokenCredentials
|
||||||
from oauth2client.client import OAuth2Credentials
|
from oauth2client.client import OAuth2Credentials
|
||||||
|
from six.moves import http_client
|
||||||
try:
|
try:
|
||||||
# Python2
|
# Python2
|
||||||
from future_builtins import oct
|
from future_builtins import oct
|
||||||
@@ -154,15 +157,17 @@ class OAuth2ClientFileTests(unittest.TestCase):
|
|||||||
access_token = '1/3w'
|
access_token = '1/3w'
|
||||||
token_response = {'access_token': access_token, 'expires_in': 3600}
|
token_response = {'access_token': access_token, 'expires_in': 3600}
|
||||||
http = HttpMockSequence([
|
http = HttpMockSequence([
|
||||||
({'status': '401'}, b'Initial token expired'),
|
({'status': str(http_client.UNAUTHORIZED)}, b'Initial token expired'),
|
||||||
({'status': '401'}, b'Store token expired'),
|
({'status': str(http_client.UNAUTHORIZED)}, b'Store token expired'),
|
||||||
({'status': '200'}, json.dumps(token_response).encode('utf-8')),
|
({'status': str(http_client.OK)},
|
||||||
({'status': '200'}, b'Valid response to original request')
|
json.dumps(token_response).encode('utf-8')),
|
||||||
|
({'status': str(http_client.OK)},
|
||||||
|
b'Valid response to original request')
|
||||||
])
|
])
|
||||||
|
|
||||||
credentials.authorize(http)
|
credentials.authorize(http)
|
||||||
http.request('https://example.com')
|
http.request('https://example.com')
|
||||||
self.assertEquals(credentials.access_token, access_token)
|
self.assertEqual(credentials.access_token, access_token)
|
||||||
|
|
||||||
def test_token_refresh_good_store(self):
|
def test_token_refresh_good_store(self):
|
||||||
expiration = datetime.datetime.utcnow() + datetime.timedelta(minutes=15)
|
expiration = datetime.datetime.utcnow() + datetime.timedelta(minutes=15)
|
||||||
@@ -178,6 +183,34 @@ class OAuth2ClientFileTests(unittest.TestCase):
|
|||||||
credentials._refresh(lambda x: x)
|
credentials._refresh(lambda x: x)
|
||||||
self.assertEquals(credentials.access_token, 'bar')
|
self.assertEquals(credentials.access_token, 'bar')
|
||||||
|
|
||||||
|
def test_token_refresh_stream_body(self):
|
||||||
|
expiration = datetime.datetime.utcnow() + datetime.timedelta(minutes=15)
|
||||||
|
credentials = self.create_test_credentials(expiration=expiration)
|
||||||
|
|
||||||
|
s = file.Storage(FILENAME)
|
||||||
|
s.put(credentials)
|
||||||
|
credentials = s.get()
|
||||||
|
new_cred = copy.copy(credentials)
|
||||||
|
new_cred.access_token = 'bar'
|
||||||
|
s.put(new_cred)
|
||||||
|
|
||||||
|
valid_access_token = '1/3w'
|
||||||
|
token_response = {'access_token': valid_access_token, 'expires_in': 3600}
|
||||||
|
http = HttpMockSequence([
|
||||||
|
({'status': str(http_client.UNAUTHORIZED)}, b'Initial token expired'),
|
||||||
|
({'status': str(http_client.UNAUTHORIZED)}, b'Store token expired'),
|
||||||
|
({'status': str(http_client.OK)},
|
||||||
|
json.dumps(token_response).encode('utf-8')),
|
||||||
|
({'status': str(http_client.OK)}, 'echo_request_body')
|
||||||
|
])
|
||||||
|
|
||||||
|
body = six.StringIO('streaming body')
|
||||||
|
|
||||||
|
credentials.authorize(http)
|
||||||
|
_, content = http.request('https://example.com', body=body)
|
||||||
|
self.assertEqual(content, 'streaming body')
|
||||||
|
self.assertEqual(credentials.access_token, valid_access_token)
|
||||||
|
|
||||||
def test_credentials_delete(self):
|
def test_credentials_delete(self):
|
||||||
credentials = self.create_test_credentials()
|
credentials = self.create_test_credentials()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user