Fix bugs with auth, batch and retries.

Reviewed in http://codereview.appspot.com/5633052/.
This commit is contained in:
Joe Gregorio
2012-02-09 14:15:44 -05:00
parent f2326c0524
commit 654f4a22d0
5 changed files with 486 additions and 195 deletions

View File

@@ -37,6 +37,7 @@ import urllib
import urlparse import urlparse
import uuid import uuid
from email.generator import Generator
from email.mime.multipart import MIMEMultipart from email.mime.multipart import MIMEMultipart
from email.mime.nonmultipart import MIMENonMultipart from email.mime.nonmultipart import MIMENonMultipart
from email.parser import FeedParser from email.parser import FeedParser
@@ -498,9 +499,12 @@ class BatchHttpRequest(object):
# Global callback to be called for each individual response in the batch. # Global callback to be called for each individual response in the batch.
self._callback = callback self._callback = callback
# A map from id to (request, callback) pairs. # A map from id to request.
self._requests = {} self._requests = {}
# A map from id to callback.
self._callbacks = {}
# List of request ids, in the order in which they were added. # List of request ids, in the order in which they were added.
self._order = [] self._order = []
@@ -510,6 +514,39 @@ class BatchHttpRequest(object):
# Unique ID on which to base the Content-ID headers. # Unique ID on which to base the Content-ID headers.
self._base_id = None self._base_id = None
# A map from request id to (headers, content) response pairs
self._responses = {}
# A map of id(Credentials) that have been refreshed.
self._refreshed_credentials = {}
def _refresh_and_apply_credentials(self, request, http):
"""Refresh the credentials and apply to the request.
Args:
request: HttpRequest, the request.
http: httplib2.Http, the global http object for the batch.
"""
# For the credentials to refresh, but only once per refresh_token
# If there is no http per the request then refresh the http passed in
# via execute()
creds = None
if request.http is not None and hasattr(request.http.request,
'credentials'):
creds = request.http.request.credentials
elif http is not None and hasattr(http.request, 'credentials'):
creds = http.request.credentials
if creds is not None:
if id(creds) not in self._refreshed_credentials:
creds.refresh(http)
self._refreshed_credentials[id(creds)] = 1
# Only apply the credentials if we are using the http object passed in,
# otherwise apply() will get called during _serialize_request().
if request.http is None or not hasattr(request.http.request,
'credentials'):
creds.apply(request.headers)
def _id_to_header(self, id_): def _id_to_header(self, id_):
"""Convert an id to a Content-ID header value. """Convert an id to a Content-ID header value.
@@ -568,6 +605,10 @@ class BatchHttpRequest(object):
msg = MIMENonMultipart(major, minor) msg = MIMENonMultipart(major, minor)
headers = request.headers.copy() headers = request.headers.copy()
if request.http is not None and hasattr(request.http.request,
'credentials'):
request.http.request.credentials.apply(headers)
# MIMENonMultipart adds its own Content-Type header. # MIMENonMultipart adds its own Content-Type header.
if 'content-type' in headers: if 'content-type' in headers:
del headers['content-type'] del headers['content-type']
@@ -581,7 +622,13 @@ class BatchHttpRequest(object):
msg.set_payload(request.body) msg.set_payload(request.body)
msg['content-length'] = str(len(request.body)) msg['content-length'] = str(len(request.body))
body = msg.as_string(False) # Serialize the mime message.
fp = StringIO.StringIO()
# maxheaderlen=0 means don't line wrap headers.
g = Generator(fp, maxheaderlen=0)
g.flatten(msg, unixfrom=False)
body = fp.getvalue()
# Strip off the \n\n that the MIME lib tacks onto the end of the payload. # Strip off the \n\n that the MIME lib tacks onto the end of the payload.
if request.body is None: if request.body is None:
body = body[:-2] body = body[:-2]
@@ -661,9 +708,71 @@ class BatchHttpRequest(object):
raise BatchError("Resumable requests cannot be used in a batch request.") raise BatchError("Resumable requests cannot be used in a batch request.")
if request_id in self._requests: if request_id in self._requests:
raise KeyError("A request with this ID already exists: %s" % request_id) raise KeyError("A request with this ID already exists: %s" % request_id)
self._requests[request_id] = (request, callback) self._requests[request_id] = request
self._callbacks[request_id] = callback
self._order.append(request_id) self._order.append(request_id)
def _execute(self, http, order, requests):
"""Serialize batch request, send to server, process response.
Args:
http: httplib2.Http, an http object to be used to make the request with.
order: list, list of request ids in the order they were added to the
batch.
request: list, list of request objects to send.
Raises:
httplib2.Error if a transport error has occured.
apiclient.errors.BatchError if the response is the wrong format.
"""
message = MIMEMultipart('mixed')
# Message should not write out it's own headers.
setattr(message, '_write_headers', lambda self: None)
# Add all the individual requests.
for request_id in order:
request = requests[request_id]
msg = MIMENonMultipart('application', 'http')
msg['Content-Transfer-Encoding'] = 'binary'
msg['Content-ID'] = self._id_to_header(request_id)
body = self._serialize_request(request)
msg.set_payload(body)
message.attach(msg)
body = message.as_string()
headers = {}
headers['content-type'] = ('multipart/mixed; '
'boundary="%s"') % message.get_boundary()
resp, content = http.request(self._batch_uri, 'POST', body=body,
headers=headers)
if resp.status >= 300:
raise HttpError(resp, content, self._batch_uri)
# Now break out the individual responses and store each one.
boundary, _ = content.split(None, 1)
# Prepend with a content-type header so FeedParser can handle it.
header = 'content-type: %s\r\n\r\n' % resp['content-type']
for_parser = header + content
parser = FeedParser()
parser.feed(for_parser)
mime_response = parser.close()
if not mime_response.is_multipart():
raise BatchError("Response not in multipart/mixed format.", resp,
content)
for part in mime_response.get_payload():
request_id = self._header_to_id(part['Content-ID'])
headers, content = self._deserialize_response(part.get_payload())
self._responses[request_id] = (headers, content)
def execute(self, http=None): def execute(self, http=None):
"""Execute all the requests as a single batched HTTP request. """Execute all the requests as a single batched HTTP request.
@@ -676,84 +785,61 @@ class BatchHttpRequest(object):
None None
Raises: Raises:
apiclient.errors.HttpError if the response was not a 2xx.
httplib2.Error if a transport error has occured. httplib2.Error if a transport error has occured.
apiclient.errors.BatchError if the response is the wrong format. apiclient.errors.BatchError if the response is the wrong format.
""" """
# If http is not supplied use the first valid one given in the requests.
if http is None: if http is None:
for request_id in self._order: for request_id in self._order:
request, callback = self._requests[request_id] request = self._requests[request_id]
if request is not None: if request is not None:
http = request.http http = request.http
break break
if http is None: if http is None:
raise ValueError("Missing a valid http object.") raise ValueError("Missing a valid http object.")
self._execute(http, self._order, self._requests)
msgRoot = MIMEMultipart('mixed') # Loop over all the requests and check for 401s. For each 401 request the
# msgRoot should not write out it's own headers # credentials should be refreshed and then sent again in a separate batch.
setattr(msgRoot, '_write_headers', lambda self: None) redo_requests = {}
redo_order = []
# Add all the individual requests.
for request_id in self._order: for request_id in self._order:
request, callback = self._requests[request_id] headers, content = self._responses[request_id]
if headers['status'] == '401':
redo_order.append(request_id)
request = self._requests[request_id]
self._refresh_and_apply_credentials(request, http)
redo_requests[request_id] = request
msg = MIMENonMultipart('application', 'http') if redo_requests:
msg['Content-Transfer-Encoding'] = 'binary' self._execute(http, redo_order, redo_requests)
msg['Content-ID'] = self._id_to_header(request_id)
body = self._serialize_request(request) # Now process all callbacks that are erroring, and raise an exception for
msg.set_payload(body) # ones that return a non-2xx response? Or add extra parameter to callback
msgRoot.attach(msg) # that contains an HttpError?
body = msgRoot.as_string() for request_id in self._order:
headers, content = self._responses[request_id]
headers = {} request = self._requests[request_id]
headers['content-type'] = ('multipart/mixed; ' callback = self._callbacks[request_id]
'boundary="%s"') % msgRoot.get_boundary()
resp, content = http.request(self._batch_uri, 'POST', body=body, response = None
headers=headers) exception = None
try:
r = httplib2.Response(headers)
response = request.postproc(r, content)
except HttpError, e:
exception = e
if resp.status >= 300: if callback is not None:
raise HttpError(resp, content, self._batch_uri) callback(request_id, response, exception)
# Now break up the response and process each one with the correct postproc
# and trigger the right callbacks.
boundary, _ = content.split(None, 1)
# Prepend with a content-type header so FeedParser can handle it.
header = 'content-type: %s\r\n\r\n' % resp['content-type']
for_parser = header + content
parser = FeedParser()
parser.feed(for_parser)
respRoot = parser.close()
if not respRoot.is_multipart():
raise BatchError("Response not in multipart/mixed format.", resp,
content)
parts = respRoot.get_payload()
for part in parts:
request_id = self._header_to_id(part['Content-ID'])
headers, content = self._deserialize_response(part.get_payload())
# TODO(jcgregorio) Remove this temporary hack once the server stops
# gzipping individual response bodies.
if content[0] != '{':
gzipped_content = content
content = gzip.GzipFile(
fileobj=StringIO.StringIO(gzipped_content)).read()
request, cb = self._requests[request_id]
postproc = request.postproc
response = postproc(resp, content)
if cb is not None:
cb(request_id, response)
if self._callback is not None: if self._callback is not None:
self._callback(request_id, response) self._callback(request_id, response, exception)
class HttpRequestMock(object): class HttpRequestMock(object):

View File

@@ -129,6 +129,23 @@ class Credentials(object):
""" """
_abstract() _abstract()
def refresh(self, http):
"""Forces a refresh of the access_token.
Args:
http: httplib2.Http, an http object to be used to make the refresh
request.
"""
_abstract()
def apply(self, headers):
"""Add the authorization to the headers.
Args:
headers: dict, the headers to add the Authorization header to.
"""
_abstract()
def _to_json(self, strip): def _to_json(self, strip):
"""Utility function for creating a JSON representation of an instance of Credentials. """Utility function for creating a JSON representation of an instance of Credentials.
@@ -324,6 +341,92 @@ class OAuth2Credentials(Credentials):
# refreshed. # refreshed.
self.invalid = False self.invalid = False
def authorize(self, http):
"""Authorize an httplib2.Http instance with these credentials.
The modified http.request method will add authentication headers to each
request and will refresh access_tokens when a 401 is received on a
request. In addition the http.request method has a credentials property,
http.request.credentials, which is the Credentials object that authorized
it.
Args:
http: An instance of httplib2.Http
or something that acts like it.
Returns:
A modified instance of http that was passed in.
Example:
h = httplib2.Http()
h = credentials.authorize(h)
You can't create a new OAuth subclass of httplib2.Authenication
because it never gets passed the absolute URI, which is needed for
signing. So instead we have to overload 'request' with a closure
that adds in the Authorization header and then calls the original
version of 'request()'.
"""
request_orig = http.request
# The closure that will replace 'httplib2.Http.request'.
def new_request(uri, method='GET', body=None, headers=None,
redirections=httplib2.DEFAULT_MAX_REDIRECTS,
connection_type=None):
if not self.access_token:
logger.info('Attempting refresh to obtain initial access_token')
self._refresh(request_orig)
# Modify the request headers to add the appropriate
# Authorization header.
if headers is None:
headers = {}
self.apply(headers)
if self.user_agent is not None:
if 'user-agent' in headers:
headers['user-agent'] = self.user_agent + ' ' + headers['user-agent']
else:
headers['user-agent'] = self.user_agent
resp, content = request_orig(uri, method, body, headers,
redirections, connection_type)
if resp.status == 401:
logger.info('Refreshing due to a 401')
self._refresh(request_orig)
self.apply(headers)
return request_orig(uri, method, body, headers,
redirections, connection_type)
else:
return (resp, content)
# Replace the request method with our own closure.
http.request = new_request
# Set credentials as a property of the request method.
setattr(http.request, 'credentials', self)
return http
def refresh(self, http):
"""Forces a refresh of the access_token.
Args:
http: httplib2.Http, an http object to be used to make the refresh
request.
"""
self._refresh(http.request)
def apply(self, headers):
"""Add the authorization to the headers.
Args:
headers: dict, the headers to add the Authorization header to.
"""
headers['Authorization'] = 'Bearer ' + self.access_token
def to_json(self): def to_json(self):
return self._to_json(Credentials.NON_SERIALIZED_MEMBERS) return self._to_json(Credentials.NON_SERIALIZED_MEMBERS)
@@ -431,6 +534,13 @@ class OAuth2Credentials(Credentials):
This method first checks by reading the Storage object if available. This method first checks by reading the Storage object if available.
If a refresh is still needed, it holds the Storage lock until the If a refresh is still needed, it holds the Storage lock until the
refresh is completed. refresh is completed.
Args:
http_request: callable, a callable that matches the method signature of
httplib2.Http.request, used to make the refresh request.
Raises:
AccessTokenRefreshError: When the refresh fails.
""" """
if not self.store: if not self.store:
self._do_refresh_request(http_request) self._do_refresh_request(http_request)
@@ -451,8 +561,8 @@ class OAuth2Credentials(Credentials):
"""Refresh the access_token using the refresh_token. """Refresh the access_token using the refresh_token.
Args: Args:
http: An instance of httplib2.Http.request http_request: callable, a callable that matches the method signature of
or something that acts like it. httplib2.Http.request, used to make the refresh request.
Raises: Raises:
AccessTokenRefreshError: When the refresh fails. AccessTokenRefreshError: When the refresh fails.
@@ -491,64 +601,6 @@ class OAuth2Credentials(Credentials):
pass pass
raise AccessTokenRefreshError(error_msg) raise AccessTokenRefreshError(error_msg)
def authorize(self, http):
"""Authorize an httplib2.Http instance with these credentials.
Args:
http: An instance of httplib2.Http
or something that acts like it.
Returns:
A modified instance of http that was passed in.
Example:
h = httplib2.Http()
h = credentials.authorize(h)
You can't create a new OAuth subclass of httplib2.Authenication
because it never gets passed the absolute URI, which is needed for
signing. So instead we have to overload 'request' with a closure
that adds in the Authorization header and then calls the original
version of 'request()'.
"""
request_orig = http.request
# The closure that will replace 'httplib2.Http.request'.
def new_request(uri, method='GET', body=None, headers=None,
redirections=httplib2.DEFAULT_MAX_REDIRECTS,
connection_type=None):
if not self.access_token:
logger.info('Attempting refresh to obtain initial access_token')
self._refresh(request_orig)
# Modify the request headers to add the appropriate
# Authorization header.
if headers is None:
headers = {}
headers['authorization'] = 'OAuth ' + self.access_token
if self.user_agent is not None:
if 'user-agent' in headers:
headers['user-agent'] = self.user_agent + ' ' + headers['user-agent']
else:
headers['user-agent'] = self.user_agent
resp, content = request_orig(uri, method, body, headers,
redirections, connection_type)
if resp.status == 401:
logger.info('Refreshing due to a 401')
self._refresh(request_orig)
headers['authorization'] = 'OAuth ' + self.access_token
return request_orig(uri, method, body, headers,
redirections, connection_type)
else:
return (resp, content)
http.request = new_request
return http
class AccessTokenCredentials(OAuth2Credentials): class AccessTokenCredentials(OAuth2Credentials):
"""Credentials object for OAuth 2.0. """Credentials object for OAuth 2.0.

View File

@@ -35,6 +35,52 @@ from apiclient.http import MediaUpload
from apiclient.http import MediaInMemoryUpload from apiclient.http import MediaInMemoryUpload
from apiclient.http import set_user_agent from apiclient.http import set_user_agent
from apiclient.model import JsonModel from apiclient.model import JsonModel
from oauth2client.client import Credentials
class MockCredentials(Credentials):
"""Mock class for all Credentials objects."""
def __init__(self, bearer_token):
super(MockCredentials, self).__init__()
self._authorized = 0
self._refreshed = 0
self._applied = 0
self._bearer_token = bearer_token
def authorize(self, http):
self._authorized += 1
request_orig = http.request
# The closure that will replace 'httplib2.Http.request'.
def new_request(uri, method='GET', body=None, headers=None,
redirections=httplib2.DEFAULT_MAX_REDIRECTS,
connection_type=None):
# Modify the request headers to add the appropriate
# Authorization header.
if headers is None:
headers = {}
self.apply(headers)
resp, content = request_orig(uri, method, body, headers,
redirections, connection_type)
return resp, content
# Replace the request method with our own closure.
http.request = new_request
# Set credentials as a property of the request method.
setattr(http.request, 'credentials', self)
return http
def refresh(self, http):
self._refreshed += 1
def apply(self, headers):
self._applied += 1
headers['authorization'] = self._bearer_token + ' ' + str(self._refreshed)
DATA_DIR = os.path.join(os.path.dirname(__file__), 'data') DATA_DIR = os.path.join(os.path.dirname(__file__), 'data')
@@ -52,7 +98,7 @@ class TestUserAgent(unittest.TestCase):
http = set_user_agent(http, "my_app/5.5") http = set_user_agent(http, "my_app/5.5")
resp, content = http.request("http://example.com") resp, content = http.request("http://example.com")
self.assertEqual(content['user-agent'], 'my_app/5.5') self.assertEqual('my_app/5.5', content['user-agent'])
def test_set_user_agent_nested(self): def test_set_user_agent_nested(self):
http = HttpMockSequence([ http = HttpMockSequence([
@@ -62,25 +108,25 @@ class TestUserAgent(unittest.TestCase):
http = set_user_agent(http, "my_app/5.5") http = set_user_agent(http, "my_app/5.5")
http = set_user_agent(http, "my_library/0.1") http = set_user_agent(http, "my_library/0.1")
resp, content = http.request("http://example.com") resp, content = http.request("http://example.com")
self.assertEqual(content['user-agent'], 'my_app/5.5 my_library/0.1') self.assertEqual('my_app/5.5 my_library/0.1', content['user-agent'])
def test_media_file_upload_to_from_json(self): def test_media_file_upload_to_from_json(self):
upload = MediaFileUpload( upload = MediaFileUpload(
datafile('small.png'), chunksize=500, resumable=True) datafile('small.png'), chunksize=500, resumable=True)
self.assertEquals('image/png', upload.mimetype()) self.assertEqual('image/png', upload.mimetype())
self.assertEquals(190, upload.size()) self.assertEqual(190, upload.size())
self.assertEquals(True, upload.resumable()) self.assertEqual(True, upload.resumable())
self.assertEquals(500, upload.chunksize()) self.assertEqual(500, upload.chunksize())
self.assertEquals('PNG', upload.getbytes(1, 3)) self.assertEqual('PNG', upload.getbytes(1, 3))
json = upload.to_json() json = upload.to_json()
new_upload = MediaUpload.new_from_json(json) new_upload = MediaUpload.new_from_json(json)
self.assertEquals('image/png', new_upload.mimetype()) self.assertEqual('image/png', new_upload.mimetype())
self.assertEquals(190, new_upload.size()) self.assertEqual(190, new_upload.size())
self.assertEquals(True, new_upload.resumable()) self.assertEqual(True, new_upload.resumable())
self.assertEquals(500, new_upload.chunksize()) self.assertEqual(500, new_upload.chunksize())
self.assertEquals('PNG', new_upload.getbytes(1, 3)) self.assertEqual('PNG', new_upload.getbytes(1, 3))
def test_http_request_to_from_json(self): def test_http_request_to_from_json(self):
@@ -103,13 +149,13 @@ class TestUserAgent(unittest.TestCase):
json = req.to_json() json = req.to_json()
new_req = HttpRequest.from_json(json, http, _postproc) new_req = HttpRequest.from_json(json, http, _postproc)
self.assertEquals(new_req.headers, self.assertEqual({'content-type':
{'content-type': 'multipart/related; boundary="---flubber"'},
'multipart/related; boundary="---flubber"'}) new_req.headers)
self.assertEquals(new_req.uri, 'http://example.com') self.assertEqual('http://example.com', new_req.uri)
self.assertEquals(new_req.body, '{}') self.assertEqual('{}', new_req.body)
self.assertEquals(new_req.http, http) self.assertEqual(http, new_req.http)
self.assertEquals(new_req.resumable.to_json(), media_upload.to_json()) self.assertEqual(media_upload.to_json(), new_req.resumable.to_json())
EXPECTED = """POST /someapi/v1/collection/?foo=bar HTTP/1.1 EXPECTED = """POST /someapi/v1/collection/?foo=bar HTTP/1.1
Content-Type: application/json Content-Type: application/json
@@ -153,6 +199,50 @@ ETag: "etag/sheep"\r\n\r\n{"baz": "qux"}
--batch_foobarbaz--""" --batch_foobarbaz--"""
BATCH_RESPONSE_WITH_401 = """--batch_foobarbaz
Content-Type: application/http
Content-Transfer-Encoding: binary
Content-ID: <randomness+1>
HTTP/1.1 401 Authoration Required
Content-Type application/json
Content-Length: 14
ETag: "etag/pony"\r\n\r\n{"error": {"message":
"Authorizaton failed."}}
--batch_foobarbaz
Content-Type: application/http
Content-Transfer-Encoding: binary
Content-ID: <randomness+2>
HTTP/1.1 200 OK
Content-Type application/json
Content-Length: 14
ETag: "etag/sheep"\r\n\r\n{"baz": "qux"}
--batch_foobarbaz--"""
BATCH_SINGLE_RESPONSE = """--batch_foobarbaz
Content-Type: application/http
Content-Transfer-Encoding: binary
Content-ID: <randomness+1>
HTTP/1.1 200 OK
Content-Type application/json
Content-Length: 14
ETag: "etag/pony"\r\n\r\n{"foo": 42}
--batch_foobarbaz--"""
class Callbacks(object):
def __init__(self):
self.responses = {}
self.exceptions = {}
def f(self, request_id, response, exception):
self.responses[request_id] = response
self.exceptions[request_id] = exception
class TestBatch(unittest.TestCase): class TestBatch(unittest.TestCase):
def setUp(self): def setUp(self):
@@ -196,7 +286,7 @@ class TestBatch(unittest.TestCase):
methodId=None, methodId=None,
resumable=None) resumable=None)
s = batch._serialize_request(request).splitlines() s = batch._serialize_request(request).splitlines()
self.assertEquals(s, EXPECTED.splitlines()) self.assertEqual(EXPECTED.splitlines(), s)
def test_serialize_request_media_body(self): def test_serialize_request_media_body(self):
batch = BatchHttpRequest() batch = BatchHttpRequest()
@@ -213,9 +303,9 @@ class TestBatch(unittest.TestCase):
headers={'content-type': 'application/json'}, headers={'content-type': 'application/json'},
methodId=None, methodId=None,
resumable=None) resumable=None)
# Just testing it shouldn't raise an exception.
s = batch._serialize_request(request).splitlines() s = batch._serialize_request(request).splitlines()
def test_serialize_request_no_body(self): def test_serialize_request_no_body(self):
batch = BatchHttpRequest() batch = BatchHttpRequest()
request = HttpRequest( request = HttpRequest(
@@ -228,30 +318,30 @@ class TestBatch(unittest.TestCase):
methodId=None, methodId=None,
resumable=None) resumable=None)
s = batch._serialize_request(request).splitlines() s = batch._serialize_request(request).splitlines()
self.assertEquals(s, NO_BODY_EXPECTED.splitlines()) self.assertEqual(NO_BODY_EXPECTED.splitlines(), s)
def test_deserialize_response(self): def test_deserialize_response(self):
batch = BatchHttpRequest() batch = BatchHttpRequest()
resp, content = batch._deserialize_response(RESPONSE) resp, content = batch._deserialize_response(RESPONSE)
self.assertEquals(resp.status, 200) self.assertEqual(200, resp.status)
self.assertEquals(resp.reason, 'OK') self.assertEqual('OK', resp.reason)
self.assertEquals(resp.version, 11) self.assertEqual(11, resp.version)
self.assertEquals(content, '{"answer": 42}') self.assertEqual('{"answer": 42}', content)
def test_new_id(self): def test_new_id(self):
batch = BatchHttpRequest() batch = BatchHttpRequest()
id_ = batch._new_id() id_ = batch._new_id()
self.assertEquals(id_, '1') self.assertEqual('1', id_)
id_ = batch._new_id() id_ = batch._new_id()
self.assertEquals(id_, '2') self.assertEqual('2', id_)
batch.add(self.request1, request_id='3') batch.add(self.request1, request_id='3')
id_ = batch._new_id() id_ = batch._new_id()
self.assertEquals(id_, '4') self.assertEqual('4', id_)
def test_add(self): def test_add(self):
batch = BatchHttpRequest() batch = BatchHttpRequest()
@@ -267,13 +357,6 @@ class TestBatch(unittest.TestCase):
self.assertRaises(BatchError, batch.add, self.request1, request_id='1') self.assertRaises(BatchError, batch.add, self.request1, request_id='1')
def test_execute(self): def test_execute(self):
class Callbacks(object):
def __init__(self):
self.responses = {}
def f(self, request_id, response):
self.responses[request_id] = response
batch = BatchHttpRequest() batch = BatchHttpRequest()
callbacks = Callbacks() callbacks = Callbacks()
@@ -285,8 +368,10 @@ class TestBatch(unittest.TestCase):
BATCH_RESPONSE), BATCH_RESPONSE),
]) ])
batch.execute(http) batch.execute(http)
self.assertEqual(callbacks.responses['1'], {'foo': 42}) self.assertEqual({'foo': 42}, callbacks.responses['1'])
self.assertEqual(callbacks.responses['2'], {'baz': 'qux'}) self.assertEqual(None, callbacks.exceptions['1'])
self.assertEqual({'baz': 'qux'}, callbacks.responses['2'])
self.assertEqual(None, callbacks.exceptions['2'])
def test_execute_request_body(self): def test_execute_request_body(self):
batch = BatchHttpRequest() batch = BatchHttpRequest()
@@ -311,14 +396,82 @@ class TestBatch(unittest.TestCase):
header = parts[1].splitlines()[1] header = parts[1].splitlines()[1]
self.assertEqual('Content-Type: application/http', header) self.assertEqual('Content-Type: application/http', header)
def test_execute_refresh_and_retry_on_401(self):
batch = BatchHttpRequest()
callbacks = Callbacks()
cred_1 = MockCredentials('Foo')
cred_2 = MockCredentials('Bar')
http = HttpMockSequence([
({'status': '200',
'content-type': 'multipart/mixed; boundary="batch_foobarbaz"'},
BATCH_RESPONSE_WITH_401),
({'status': '200',
'content-type': 'multipart/mixed; boundary="batch_foobarbaz"'},
BATCH_SINGLE_RESPONSE),
])
creds_http_1 = HttpMockSequence([])
cred_1.authorize(creds_http_1)
creds_http_2 = HttpMockSequence([])
cred_2.authorize(creds_http_2)
self.request1.http = creds_http_1
self.request2.http = creds_http_2
batch.add(self.request1, callback=callbacks.f)
batch.add(self.request2, callback=callbacks.f)
batch.execute(http)
self.assertEqual({'foo': 42}, callbacks.responses['1'])
self.assertEqual(None, callbacks.exceptions['1'])
self.assertEqual({'baz': 'qux'}, callbacks.responses['2'])
self.assertEqual(None, callbacks.exceptions['2'])
self.assertEqual(1, cred_1._refreshed)
self.assertEqual(0, cred_2._refreshed)
self.assertEqual(1, cred_1._authorized)
self.assertEqual(1, cred_2._authorized)
self.assertEqual(1, cred_2._applied)
self.assertEqual(2, cred_1._applied)
def test_http_errors_passed_to_callback(self):
batch = BatchHttpRequest()
callbacks = Callbacks()
cred_1 = MockCredentials('Foo')
cred_2 = MockCredentials('Bar')
http = HttpMockSequence([
({'status': '200',
'content-type': 'multipart/mixed; boundary="batch_foobarbaz"'},
BATCH_RESPONSE_WITH_401),
({'status': '200',
'content-type': 'multipart/mixed; boundary="batch_foobarbaz"'},
BATCH_RESPONSE_WITH_401),
])
creds_http_1 = HttpMockSequence([])
cred_1.authorize(creds_http_1)
creds_http_2 = HttpMockSequence([])
cred_2.authorize(creds_http_2)
self.request1.http = creds_http_1
self.request2.http = creds_http_2
batch.add(self.request1, callback=callbacks.f)
batch.add(self.request2, callback=callbacks.f)
batch.execute(http)
self.assertEqual(None, callbacks.responses['1'])
self.assertEqual(401, callbacks.exceptions['1'].resp.status)
self.assertEqual({u'baz': u'qux'}, callbacks.responses['2'])
self.assertEqual(None, callbacks.exceptions['2'])
def test_execute_global_callback(self): def test_execute_global_callback(self):
class Callbacks(object):
def __init__(self):
self.responses = {}
def f(self, request_id, response):
self.responses[request_id] = response
callbacks = Callbacks() callbacks = Callbacks()
batch = BatchHttpRequest(callback=callbacks.f) batch = BatchHttpRequest(callback=callbacks.f)
@@ -330,8 +483,8 @@ class TestBatch(unittest.TestCase):
BATCH_RESPONSE), BATCH_RESPONSE),
]) ])
batch.execute(http) batch.execute(http)
self.assertEqual(callbacks.responses['1'], {'foo': 42}) self.assertEqual({'foo': 42}, callbacks.responses['1'])
self.assertEqual(callbacks.responses['2'], {'baz': 'qux'}) self.assertEqual({'baz': 'qux'}, callbacks.responses['2'])
def test_media_inmemory_upload(self): def test_media_inmemory_upload(self):
media = MediaInMemoryUpload('abcdef', 'text/plain', chunksize=10, media = MediaInMemoryUpload('abcdef', 'text/plain', chunksize=10,

View File

@@ -70,7 +70,7 @@ class OAuth2CredentialsTests(unittest.TestCase):
]) ])
http = self.credentials.authorize(http) http = self.credentials.authorize(http)
resp, content = http.request("http://example.com") resp, content = http.request("http://example.com")
self.assertEqual(content['authorization'], 'OAuth 1/3w') self.assertEqual('Bearer 1/3w', content['Authorization'])
def test_token_refresh_failure(self): def test_token_refresh_failure(self):
http = HttpMockSequence([ http = HttpMockSequence([
@@ -95,11 +95,11 @@ class OAuth2CredentialsTests(unittest.TestCase):
def test_to_from_json(self): def test_to_from_json(self):
json = self.credentials.to_json() json = self.credentials.to_json()
instance = OAuth2Credentials.from_json(json) instance = OAuth2Credentials.from_json(json)
self.assertEquals(type(instance), OAuth2Credentials) self.assertEqual(OAuth2Credentials, type(instance))
instance.token_expiry = None instance.token_expiry = None
self.credentials.token_expiry = None self.credentials.token_expiry = None
self.assertEquals(self.credentials.__dict__, instance.__dict__) self.assertEqual(instance.__dict__, self.credentials.__dict__)
class AccessTokenCredentialsTests(unittest.TestCase): class AccessTokenCredentialsTests(unittest.TestCase):
@@ -136,7 +136,7 @@ class AccessTokenCredentialsTests(unittest.TestCase):
]) ])
http = self.credentials.authorize(http) http = self.credentials.authorize(http)
resp, content = http.request('http://example.com') resp, content = http.request('http://example.com')
self.assertEqual(content['authorization'], 'OAuth foo') self.assertEqual('Bearer foo', content['Authorization'])
class TestAssertionCredentials(unittest.TestCase): class TestAssertionCredentials(unittest.TestCase):
@@ -155,8 +155,8 @@ class TestAssertionCredentials(unittest.TestCase):
def test_assertion_body(self): def test_assertion_body(self):
body = urlparse.parse_qs(self.credentials._generate_refresh_request_body()) body = urlparse.parse_qs(self.credentials._generate_refresh_request_body())
self.assertEqual(body['assertion'][0], self.assertion_text) self.assertEqual(self.assertion_text, body['assertion'][0])
self.assertEqual(body['assertion_type'][0], self.assertion_type) self.assertEqual(self.assertion_type, body['assertion_type'][0])
def test_assertion_refresh(self): def test_assertion_refresh(self):
http = HttpMockSequence([ http = HttpMockSequence([
@@ -165,7 +165,7 @@ class TestAssertionCredentials(unittest.TestCase):
]) ])
http = self.credentials.authorize(http) http = self.credentials.authorize(http)
resp, content = http.request("http://example.com") resp, content = http.request("http://example.com")
self.assertEqual(content['authorization'], 'OAuth 1/3w') self.assertEqual('Bearer 1/3w', content['Authorization'])
class ExtractIdTokenText(unittest.TestCase): class ExtractIdTokenText(unittest.TestCase):
@@ -177,7 +177,7 @@ class ExtractIdTokenText(unittest.TestCase):
jwt = 'stuff.' + payload + '.signature' jwt = 'stuff.' + payload + '.signature'
extracted = _extract_id_token(jwt) extracted = _extract_id_token(jwt)
self.assertEqual(body, extracted) self.assertEqual(extracted, body)
def test_extract_failure(self): def test_extract_failure(self):
body = {'foo': 'bar'} body = {'foo': 'bar'}
@@ -201,11 +201,11 @@ class OAuth2WebServerFlowTest(unittest.TestCase):
parsed = urlparse.urlparse(authorize_url) parsed = urlparse.urlparse(authorize_url)
q = parse_qs(parsed[4]) q = parse_qs(parsed[4])
self.assertEqual(q['client_id'][0], 'client_id+1') self.assertEqual('client_id+1', q['client_id'][0])
self.assertEqual(q['response_type'][0], 'code') self.assertEqual('code', q['response_type'][0])
self.assertEqual(q['scope'][0], 'foo') self.assertEqual('foo', q['scope'][0])
self.assertEqual(q['redirect_uri'][0], 'OOB_CALLBACK_URN') self.assertEqual('OOB_CALLBACK_URN', q['redirect_uri'][0])
self.assertEqual(q['access_type'][0], 'offline') self.assertEqual('offline', q['access_type'][0])
def test_override_flow_access_type(self): def test_override_flow_access_type(self):
"""Passing access_type overrides the default.""" """Passing access_type overrides the default."""
@@ -220,11 +220,11 @@ class OAuth2WebServerFlowTest(unittest.TestCase):
parsed = urlparse.urlparse(authorize_url) parsed = urlparse.urlparse(authorize_url)
q = parse_qs(parsed[4]) q = parse_qs(parsed[4])
self.assertEqual(q['client_id'][0], 'client_id+1') self.assertEqual('client_id+1', q['client_id'][0])
self.assertEqual(q['response_type'][0], 'code') self.assertEqual('code', q['response_type'][0])
self.assertEqual(q['scope'][0], 'foo') self.assertEqual('foo', q['scope'][0])
self.assertEqual(q['redirect_uri'][0], 'OOB_CALLBACK_URN') self.assertEqual('OOB_CALLBACK_URN', q['redirect_uri'][0])
self.assertEqual(q['access_type'][0], 'online') self.assertEqual('online', q['access_type'][0])
def test_exchange_failure(self): def test_exchange_failure(self):
http = HttpMockSequence([ http = HttpMockSequence([
@@ -246,9 +246,9 @@ class OAuth2WebServerFlowTest(unittest.TestCase):
]) ])
credentials = self.flow.step2_exchange('some random code', http) credentials = self.flow.step2_exchange('some random code', http)
self.assertEqual(credentials.access_token, 'SlAV32hkKG') self.assertEqual('SlAV32hkKG', credentials.access_token)
self.assertNotEqual(credentials.token_expiry, None) self.assertNotEqual(None, credentials.token_expiry)
self.assertEqual(credentials.refresh_token, '8xLOxBtZp8') self.assertEqual('8xLOxBtZp8', credentials.refresh_token)
def test_exchange_no_expires_in(self): def test_exchange_no_expires_in(self):
http = HttpMockSequence([ http = HttpMockSequence([
@@ -257,7 +257,7 @@ class OAuth2WebServerFlowTest(unittest.TestCase):
]) ])
credentials = self.flow.step2_exchange('some random code', http) credentials = self.flow.step2_exchange('some random code', http)
self.assertEqual(credentials.token_expiry, None) self.assertEqual(None, credentials.token_expiry)
def test_exchange_id_token_fail(self): def test_exchange_id_token_fail(self):
http = HttpMockSequence([ http = HttpMockSequence([
@@ -282,7 +282,7 @@ class OAuth2WebServerFlowTest(unittest.TestCase):
]) ])
credentials = self.flow.step2_exchange('some random code', http) credentials = self.flow.step2_exchange('some random code', http)
self.assertEquals(body, credentials.id_token) self.assertEqual(credentials.id_token, body)
if __name__ == '__main__': if __name__ == '__main__':

View File

@@ -100,8 +100,8 @@ class CryptTests(unittest.TestCase):
certs = {'foo': public_key } certs = {'foo': public_key }
audience = 'some_audience_address@testing.gserviceaccount.com' audience = 'some_audience_address@testing.gserviceaccount.com'
contents = crypt.verify_signed_jwt_with_certs(jwt, certs, audience) contents = crypt.verify_signed_jwt_with_certs(jwt, certs, audience)
self.assertEquals('billy bob', contents['user']) self.assertEqual('billy bob', contents['user'])
self.assertEquals('data', contents['metadata']['meta']) self.assertEqual('data', contents['metadata']['meta'])
def test_verify_id_token_with_certs_uri(self): def test_verify_id_token_with_certs_uri(self):
jwt = self._create_signed_jwt() jwt = self._create_signed_jwt()
@@ -112,8 +112,8 @@ class CryptTests(unittest.TestCase):
contents = verify_id_token(jwt, contents = verify_id_token(jwt,
'some_audience_address@testing.gserviceaccount.com', http) 'some_audience_address@testing.gserviceaccount.com', http)
self.assertEquals('billy bob', contents['user']) self.assertEqual('billy bob', contents['user'])
self.assertEquals('data', contents['metadata']['meta']) self.assertEqual('data', contents['metadata']['meta'])
def test_verify_id_token_with_certs_uri_fails(self): def test_verify_id_token_with_certs_uri_fails(self):
jwt = self._create_signed_jwt() jwt = self._create_signed_jwt()
@@ -195,7 +195,7 @@ class CryptTests(unittest.TestCase):
]) ])
http = credentials.authorize(http) http = credentials.authorize(http)
resp, content = http.request('http://example.org') resp, content = http.request('http://example.org')
self.assertEquals(content['authorization'], 'OAuth 1/3w') self.assertEqual('Bearer 1/3w', content['Authorization'])
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()