From 654f4a22d03e618d16e8a569192d15f68547d5f4 Mon Sep 17 00:00:00 2001 From: Joe Gregorio Date: Thu, 9 Feb 2012 14:15:44 -0500 Subject: [PATCH] Fix bugs with auth, batch and retries. Reviewed in http://codereview.appspot.com/5633052/. --- apiclient/http.py | 206 +++++++++++++++++++-------- oauth2client/client.py | 172 +++++++++++++++-------- tests/test_http.py | 247 ++++++++++++++++++++++++++------- tests/test_oauth2client.py | 46 +++--- tests/test_oauth2client_jwt.py | 10 +- 5 files changed, 486 insertions(+), 195 deletions(-) diff --git a/apiclient/http.py b/apiclient/http.py index 94eb266..ff61cb1 100644 --- a/apiclient/http.py +++ b/apiclient/http.py @@ -37,6 +37,7 @@ import urllib import urlparse import uuid +from email.generator import Generator from email.mime.multipart import MIMEMultipart from email.mime.nonmultipart import MIMENonMultipart from email.parser import FeedParser @@ -498,9 +499,12 @@ class BatchHttpRequest(object): # Global callback to be called for each individual response in the batch. self._callback = callback - # A map from id to (request, callback) pairs. + # A map from id to request. self._requests = {} + # A map from id to callback. + self._callbacks = {} + # List of request ids, in the order in which they were added. self._order = [] @@ -510,6 +514,39 @@ class BatchHttpRequest(object): # Unique ID on which to base the Content-ID headers. self._base_id = None + # A map from request id to (headers, content) response pairs + self._responses = {} + + # A map of id(Credentials) that have been refreshed. + self._refreshed_credentials = {} + + def _refresh_and_apply_credentials(self, request, http): + """Refresh the credentials and apply to the request. + + Args: + request: HttpRequest, the request. + http: httplib2.Http, the global http object for the batch. + """ + # For the credentials to refresh, but only once per refresh_token + # If there is no http per the request then refresh the http passed in + # via execute() + creds = None + if request.http is not None and hasattr(request.http.request, + 'credentials'): + creds = request.http.request.credentials + elif http is not None and hasattr(http.request, 'credentials'): + creds = http.request.credentials + if creds is not None: + if id(creds) not in self._refreshed_credentials: + creds.refresh(http) + self._refreshed_credentials[id(creds)] = 1 + + # Only apply the credentials if we are using the http object passed in, + # otherwise apply() will get called during _serialize_request(). + if request.http is None or not hasattr(request.http.request, + 'credentials'): + creds.apply(request.headers) + def _id_to_header(self, id_): """Convert an id to a Content-ID header value. @@ -568,6 +605,10 @@ class BatchHttpRequest(object): msg = MIMENonMultipart(major, minor) headers = request.headers.copy() + if request.http is not None and hasattr(request.http.request, + 'credentials'): + request.http.request.credentials.apply(headers) + # MIMENonMultipart adds its own Content-Type header. if 'content-type' in headers: del headers['content-type'] @@ -581,7 +622,13 @@ class BatchHttpRequest(object): msg.set_payload(request.body) msg['content-length'] = str(len(request.body)) - body = msg.as_string(False) + # Serialize the mime message. + fp = StringIO.StringIO() + # maxheaderlen=0 means don't line wrap headers. + g = Generator(fp, maxheaderlen=0) + g.flatten(msg, unixfrom=False) + body = fp.getvalue() + # Strip off the \n\n that the MIME lib tacks onto the end of the payload. if request.body is None: body = body[:-2] @@ -661,9 +708,71 @@ class BatchHttpRequest(object): raise BatchError("Resumable requests cannot be used in a batch request.") if request_id in self._requests: raise KeyError("A request with this ID already exists: %s" % request_id) - self._requests[request_id] = (request, callback) + self._requests[request_id] = request + self._callbacks[request_id] = callback self._order.append(request_id) + def _execute(self, http, order, requests): + """Serialize batch request, send to server, process response. + + Args: + http: httplib2.Http, an http object to be used to make the request with. + order: list, list of request ids in the order they were added to the + batch. + request: list, list of request objects to send. + + Raises: + httplib2.Error if a transport error has occured. + apiclient.errors.BatchError if the response is the wrong format. + """ + message = MIMEMultipart('mixed') + # Message should not write out it's own headers. + setattr(message, '_write_headers', lambda self: None) + + # Add all the individual requests. + for request_id in order: + request = requests[request_id] + + msg = MIMENonMultipart('application', 'http') + msg['Content-Transfer-Encoding'] = 'binary' + msg['Content-ID'] = self._id_to_header(request_id) + + body = self._serialize_request(request) + msg.set_payload(body) + message.attach(msg) + + body = message.as_string() + + headers = {} + headers['content-type'] = ('multipart/mixed; ' + 'boundary="%s"') % message.get_boundary() + + resp, content = http.request(self._batch_uri, 'POST', body=body, + headers=headers) + + if resp.status >= 300: + raise HttpError(resp, content, self._batch_uri) + + # Now break out the individual responses and store each one. + boundary, _ = content.split(None, 1) + + # Prepend with a content-type header so FeedParser can handle it. + header = 'content-type: %s\r\n\r\n' % resp['content-type'] + for_parser = header + content + + parser = FeedParser() + parser.feed(for_parser) + mime_response = parser.close() + + if not mime_response.is_multipart(): + raise BatchError("Response not in multipart/mixed format.", resp, + content) + + for part in mime_response.get_payload(): + request_id = self._header_to_id(part['Content-ID']) + headers, content = self._deserialize_response(part.get_payload()) + self._responses[request_id] = (headers, content) + def execute(self, http=None): """Execute all the requests as a single batched HTTP request. @@ -676,84 +785,61 @@ class BatchHttpRequest(object): None Raises: - apiclient.errors.HttpError if the response was not a 2xx. httplib2.Error if a transport error has occured. apiclient.errors.BatchError if the response is the wrong format. """ + + # If http is not supplied use the first valid one given in the requests. if http is None: for request_id in self._order: - request, callback = self._requests[request_id] + request = self._requests[request_id] if request is not None: http = request.http break + if http is None: raise ValueError("Missing a valid http object.") + self._execute(http, self._order, self._requests) - msgRoot = MIMEMultipart('mixed') - # msgRoot should not write out it's own headers - setattr(msgRoot, '_write_headers', lambda self: None) + # Loop over all the requests and check for 401s. For each 401 request the + # credentials should be refreshed and then sent again in a separate batch. + redo_requests = {} + redo_order = [] - # Add all the individual requests. for request_id in self._order: - request, callback = self._requests[request_id] + headers, content = self._responses[request_id] + if headers['status'] == '401': + redo_order.append(request_id) + request = self._requests[request_id] + self._refresh_and_apply_credentials(request, http) + redo_requests[request_id] = request - msg = MIMENonMultipart('application', 'http') - msg['Content-Transfer-Encoding'] = 'binary' - msg['Content-ID'] = self._id_to_header(request_id) + if redo_requests: + self._execute(http, redo_order, redo_requests) - body = self._serialize_request(request) - msg.set_payload(body) - msgRoot.attach(msg) + # Now process all callbacks that are erroring, and raise an exception for + # ones that return a non-2xx response? Or add extra parameter to callback + # that contains an HttpError? - body = msgRoot.as_string() + for request_id in self._order: + headers, content = self._responses[request_id] - headers = {} - headers['content-type'] = ('multipart/mixed; ' - 'boundary="%s"') % msgRoot.get_boundary() + request = self._requests[request_id] + callback = self._callbacks[request_id] - resp, content = http.request(self._batch_uri, 'POST', body=body, - headers=headers) + response = None + exception = None + try: + r = httplib2.Response(headers) + response = request.postproc(r, content) + except HttpError, e: + exception = e - if resp.status >= 300: - raise HttpError(resp, content, self._batch_uri) - - # Now break up the response and process each one with the correct postproc - # and trigger the right callbacks. - boundary, _ = content.split(None, 1) - - # Prepend with a content-type header so FeedParser can handle it. - header = 'content-type: %s\r\n\r\n' % resp['content-type'] - for_parser = header + content - - parser = FeedParser() - parser.feed(for_parser) - respRoot = parser.close() - - if not respRoot.is_multipart(): - raise BatchError("Response not in multipart/mixed format.", resp, - content) - - parts = respRoot.get_payload() - for part in parts: - request_id = self._header_to_id(part['Content-ID']) - - headers, content = self._deserialize_response(part.get_payload()) - - # TODO(jcgregorio) Remove this temporary hack once the server stops - # gzipping individual response bodies. - if content[0] != '{': - gzipped_content = content - content = gzip.GzipFile( - fileobj=StringIO.StringIO(gzipped_content)).read() - - request, cb = self._requests[request_id] - postproc = request.postproc - response = postproc(resp, content) - if cb is not None: - cb(request_id, response) + if callback is not None: + callback(request_id, response, exception) if self._callback is not None: - self._callback(request_id, response) + self._callback(request_id, response, exception) class HttpRequestMock(object): diff --git a/oauth2client/client.py b/oauth2client/client.py index c88b358..ce033ca 100644 --- a/oauth2client/client.py +++ b/oauth2client/client.py @@ -129,6 +129,23 @@ class Credentials(object): """ _abstract() + def refresh(self, http): + """Forces a refresh of the access_token. + + Args: + http: httplib2.Http, an http object to be used to make the refresh + request. + """ + _abstract() + + def apply(self, headers): + """Add the authorization to the headers. + + Args: + headers: dict, the headers to add the Authorization header to. + """ + _abstract() + def _to_json(self, strip): """Utility function for creating a JSON representation of an instance of Credentials. @@ -324,6 +341,92 @@ class OAuth2Credentials(Credentials): # refreshed. self.invalid = False + def authorize(self, http): + """Authorize an httplib2.Http instance with these credentials. + + The modified http.request method will add authentication headers to each + request and will refresh access_tokens when a 401 is received on a + request. In addition the http.request method has a credentials property, + http.request.credentials, which is the Credentials object that authorized + it. + + Args: + http: An instance of httplib2.Http + or something that acts like it. + + Returns: + A modified instance of http that was passed in. + + Example: + + h = httplib2.Http() + h = credentials.authorize(h) + + You can't create a new OAuth subclass of httplib2.Authenication + because it never gets passed the absolute URI, which is needed for + signing. So instead we have to overload 'request' with a closure + that adds in the Authorization header and then calls the original + version of 'request()'. + """ + request_orig = http.request + + # The closure that will replace 'httplib2.Http.request'. + def new_request(uri, method='GET', body=None, headers=None, + redirections=httplib2.DEFAULT_MAX_REDIRECTS, + connection_type=None): + if not self.access_token: + logger.info('Attempting refresh to obtain initial access_token') + self._refresh(request_orig) + + # Modify the request headers to add the appropriate + # Authorization header. + if headers is None: + headers = {} + self.apply(headers) + + if self.user_agent is not None: + if 'user-agent' in headers: + headers['user-agent'] = self.user_agent + ' ' + headers['user-agent'] + else: + headers['user-agent'] = self.user_agent + + resp, content = request_orig(uri, method, body, headers, + redirections, connection_type) + + if resp.status == 401: + logger.info('Refreshing due to a 401') + self._refresh(request_orig) + self.apply(headers) + return request_orig(uri, method, body, headers, + redirections, connection_type) + else: + return (resp, content) + + # Replace the request method with our own closure. + http.request = new_request + + # Set credentials as a property of the request method. + setattr(http.request, 'credentials', self) + + return http + + def refresh(self, http): + """Forces a refresh of the access_token. + + Args: + http: httplib2.Http, an http object to be used to make the refresh + request. + """ + self._refresh(http.request) + + def apply(self, headers): + """Add the authorization to the headers. + + Args: + headers: dict, the headers to add the Authorization header to. + """ + headers['Authorization'] = 'Bearer ' + self.access_token + def to_json(self): return self._to_json(Credentials.NON_SERIALIZED_MEMBERS) @@ -431,6 +534,13 @@ class OAuth2Credentials(Credentials): This method first checks by reading the Storage object if available. If a refresh is still needed, it holds the Storage lock until the refresh is completed. + + Args: + http_request: callable, a callable that matches the method signature of + httplib2.Http.request, used to make the refresh request. + + Raises: + AccessTokenRefreshError: When the refresh fails. """ if not self.store: self._do_refresh_request(http_request) @@ -451,8 +561,8 @@ class OAuth2Credentials(Credentials): """Refresh the access_token using the refresh_token. Args: - http: An instance of httplib2.Http.request - or something that acts like it. + http_request: callable, a callable that matches the method signature of + httplib2.Http.request, used to make the refresh request. Raises: AccessTokenRefreshError: When the refresh fails. @@ -491,64 +601,6 @@ class OAuth2Credentials(Credentials): pass raise AccessTokenRefreshError(error_msg) - def authorize(self, http): - """Authorize an httplib2.Http instance with these credentials. - - Args: - http: An instance of httplib2.Http - or something that acts like it. - - Returns: - A modified instance of http that was passed in. - - Example: - - h = httplib2.Http() - h = credentials.authorize(h) - - You can't create a new OAuth subclass of httplib2.Authenication - because it never gets passed the absolute URI, which is needed for - signing. So instead we have to overload 'request' with a closure - that adds in the Authorization header and then calls the original - version of 'request()'. - """ - request_orig = http.request - - # The closure that will replace 'httplib2.Http.request'. - def new_request(uri, method='GET', body=None, headers=None, - redirections=httplib2.DEFAULT_MAX_REDIRECTS, - connection_type=None): - if not self.access_token: - logger.info('Attempting refresh to obtain initial access_token') - self._refresh(request_orig) - - # Modify the request headers to add the appropriate - # Authorization header. - if headers is None: - headers = {} - headers['authorization'] = 'OAuth ' + self.access_token - - if self.user_agent is not None: - if 'user-agent' in headers: - headers['user-agent'] = self.user_agent + ' ' + headers['user-agent'] - else: - headers['user-agent'] = self.user_agent - - resp, content = request_orig(uri, method, body, headers, - redirections, connection_type) - - if resp.status == 401: - logger.info('Refreshing due to a 401') - self._refresh(request_orig) - headers['authorization'] = 'OAuth ' + self.access_token - return request_orig(uri, method, body, headers, - redirections, connection_type) - else: - return (resp, content) - - http.request = new_request - return http - class AccessTokenCredentials(OAuth2Credentials): """Credentials object for OAuth 2.0. diff --git a/tests/test_http.py b/tests/test_http.py index 67e8aa6..cb7a832 100644 --- a/tests/test_http.py +++ b/tests/test_http.py @@ -35,6 +35,52 @@ from apiclient.http import MediaUpload from apiclient.http import MediaInMemoryUpload from apiclient.http import set_user_agent from apiclient.model import JsonModel +from oauth2client.client import Credentials + + +class MockCredentials(Credentials): + """Mock class for all Credentials objects.""" + def __init__(self, bearer_token): + super(MockCredentials, self).__init__() + self._authorized = 0 + self._refreshed = 0 + self._applied = 0 + self._bearer_token = bearer_token + + def authorize(self, http): + self._authorized += 1 + + request_orig = http.request + + # The closure that will replace 'httplib2.Http.request'. + def new_request(uri, method='GET', body=None, headers=None, + redirections=httplib2.DEFAULT_MAX_REDIRECTS, + connection_type=None): + # Modify the request headers to add the appropriate + # Authorization header. + if headers is None: + headers = {} + self.apply(headers) + + resp, content = request_orig(uri, method, body, headers, + redirections, connection_type) + + return resp, content + + # Replace the request method with our own closure. + http.request = new_request + + # Set credentials as a property of the request method. + setattr(http.request, 'credentials', self) + + return http + + def refresh(self, http): + self._refreshed += 1 + + def apply(self, headers): + self._applied += 1 + headers['authorization'] = self._bearer_token + ' ' + str(self._refreshed) DATA_DIR = os.path.join(os.path.dirname(__file__), 'data') @@ -52,7 +98,7 @@ class TestUserAgent(unittest.TestCase): http = set_user_agent(http, "my_app/5.5") resp, content = http.request("http://example.com") - self.assertEqual(content['user-agent'], 'my_app/5.5') + self.assertEqual('my_app/5.5', content['user-agent']) def test_set_user_agent_nested(self): http = HttpMockSequence([ @@ -62,25 +108,25 @@ class TestUserAgent(unittest.TestCase): http = set_user_agent(http, "my_app/5.5") http = set_user_agent(http, "my_library/0.1") resp, content = http.request("http://example.com") - self.assertEqual(content['user-agent'], 'my_app/5.5 my_library/0.1') + self.assertEqual('my_app/5.5 my_library/0.1', content['user-agent']) def test_media_file_upload_to_from_json(self): upload = MediaFileUpload( datafile('small.png'), chunksize=500, resumable=True) - self.assertEquals('image/png', upload.mimetype()) - self.assertEquals(190, upload.size()) - self.assertEquals(True, upload.resumable()) - self.assertEquals(500, upload.chunksize()) - self.assertEquals('PNG', upload.getbytes(1, 3)) + self.assertEqual('image/png', upload.mimetype()) + self.assertEqual(190, upload.size()) + self.assertEqual(True, upload.resumable()) + self.assertEqual(500, upload.chunksize()) + self.assertEqual('PNG', upload.getbytes(1, 3)) json = upload.to_json() new_upload = MediaUpload.new_from_json(json) - self.assertEquals('image/png', new_upload.mimetype()) - self.assertEquals(190, new_upload.size()) - self.assertEquals(True, new_upload.resumable()) - self.assertEquals(500, new_upload.chunksize()) - self.assertEquals('PNG', new_upload.getbytes(1, 3)) + self.assertEqual('image/png', new_upload.mimetype()) + self.assertEqual(190, new_upload.size()) + self.assertEqual(True, new_upload.resumable()) + self.assertEqual(500, new_upload.chunksize()) + self.assertEqual('PNG', new_upload.getbytes(1, 3)) def test_http_request_to_from_json(self): @@ -103,13 +149,13 @@ class TestUserAgent(unittest.TestCase): json = req.to_json() new_req = HttpRequest.from_json(json, http, _postproc) - self.assertEquals(new_req.headers, - {'content-type': - 'multipart/related; boundary="---flubber"'}) - self.assertEquals(new_req.uri, 'http://example.com') - self.assertEquals(new_req.body, '{}') - self.assertEquals(new_req.http, http) - self.assertEquals(new_req.resumable.to_json(), media_upload.to_json()) + self.assertEqual({'content-type': + 'multipart/related; boundary="---flubber"'}, + new_req.headers) + self.assertEqual('http://example.com', new_req.uri) + self.assertEqual('{}', new_req.body) + self.assertEqual(http, new_req.http) + self.assertEqual(media_upload.to_json(), new_req.resumable.to_json()) EXPECTED = """POST /someapi/v1/collection/?foo=bar HTTP/1.1 Content-Type: application/json @@ -153,6 +199,50 @@ ETag: "etag/sheep"\r\n\r\n{"baz": "qux"} --batch_foobarbaz--""" +BATCH_RESPONSE_WITH_401 = """--batch_foobarbaz +Content-Type: application/http +Content-Transfer-Encoding: binary +Content-ID: + +HTTP/1.1 401 Authoration Required +Content-Type application/json +Content-Length: 14 +ETag: "etag/pony"\r\n\r\n{"error": {"message": + "Authorizaton failed."}} + +--batch_foobarbaz +Content-Type: application/http +Content-Transfer-Encoding: binary +Content-ID: + +HTTP/1.1 200 OK +Content-Type application/json +Content-Length: 14 +ETag: "etag/sheep"\r\n\r\n{"baz": "qux"} +--batch_foobarbaz--""" + + +BATCH_SINGLE_RESPONSE = """--batch_foobarbaz +Content-Type: application/http +Content-Transfer-Encoding: binary +Content-ID: + +HTTP/1.1 200 OK +Content-Type application/json +Content-Length: 14 +ETag: "etag/pony"\r\n\r\n{"foo": 42} +--batch_foobarbaz--""" + +class Callbacks(object): + def __init__(self): + self.responses = {} + self.exceptions = {} + + def f(self, request_id, response, exception): + self.responses[request_id] = response + self.exceptions[request_id] = exception + + class TestBatch(unittest.TestCase): def setUp(self): @@ -196,7 +286,7 @@ class TestBatch(unittest.TestCase): methodId=None, resumable=None) s = batch._serialize_request(request).splitlines() - self.assertEquals(s, EXPECTED.splitlines()) + self.assertEqual(EXPECTED.splitlines(), s) def test_serialize_request_media_body(self): batch = BatchHttpRequest() @@ -213,9 +303,9 @@ class TestBatch(unittest.TestCase): headers={'content-type': 'application/json'}, methodId=None, resumable=None) + # Just testing it shouldn't raise an exception. s = batch._serialize_request(request).splitlines() - def test_serialize_request_no_body(self): batch = BatchHttpRequest() request = HttpRequest( @@ -228,30 +318,30 @@ class TestBatch(unittest.TestCase): methodId=None, resumable=None) s = batch._serialize_request(request).splitlines() - self.assertEquals(s, NO_BODY_EXPECTED.splitlines()) + self.assertEqual(NO_BODY_EXPECTED.splitlines(), s) def test_deserialize_response(self): batch = BatchHttpRequest() resp, content = batch._deserialize_response(RESPONSE) - self.assertEquals(resp.status, 200) - self.assertEquals(resp.reason, 'OK') - self.assertEquals(resp.version, 11) - self.assertEquals(content, '{"answer": 42}') + self.assertEqual(200, resp.status) + self.assertEqual('OK', resp.reason) + self.assertEqual(11, resp.version) + self.assertEqual('{"answer": 42}', content) def test_new_id(self): batch = BatchHttpRequest() id_ = batch._new_id() - self.assertEquals(id_, '1') + self.assertEqual('1', id_) id_ = batch._new_id() - self.assertEquals(id_, '2') + self.assertEqual('2', id_) batch.add(self.request1, request_id='3') id_ = batch._new_id() - self.assertEquals(id_, '4') + self.assertEqual('4', id_) def test_add(self): batch = BatchHttpRequest() @@ -267,13 +357,6 @@ class TestBatch(unittest.TestCase): self.assertRaises(BatchError, batch.add, self.request1, request_id='1') def test_execute(self): - class Callbacks(object): - def __init__(self): - self.responses = {} - - def f(self, request_id, response): - self.responses[request_id] = response - batch = BatchHttpRequest() callbacks = Callbacks() @@ -285,8 +368,10 @@ class TestBatch(unittest.TestCase): BATCH_RESPONSE), ]) batch.execute(http) - self.assertEqual(callbacks.responses['1'], {'foo': 42}) - self.assertEqual(callbacks.responses['2'], {'baz': 'qux'}) + self.assertEqual({'foo': 42}, callbacks.responses['1']) + self.assertEqual(None, callbacks.exceptions['1']) + self.assertEqual({'baz': 'qux'}, callbacks.responses['2']) + self.assertEqual(None, callbacks.exceptions['2']) def test_execute_request_body(self): batch = BatchHttpRequest() @@ -311,14 +396,82 @@ class TestBatch(unittest.TestCase): header = parts[1].splitlines()[1] self.assertEqual('Content-Type: application/http', header) + def test_execute_refresh_and_retry_on_401(self): + batch = BatchHttpRequest() + callbacks = Callbacks() + cred_1 = MockCredentials('Foo') + cred_2 = MockCredentials('Bar') + + http = HttpMockSequence([ + ({'status': '200', + 'content-type': 'multipart/mixed; boundary="batch_foobarbaz"'}, + BATCH_RESPONSE_WITH_401), + ({'status': '200', + 'content-type': 'multipart/mixed; boundary="batch_foobarbaz"'}, + BATCH_SINGLE_RESPONSE), + ]) + + creds_http_1 = HttpMockSequence([]) + cred_1.authorize(creds_http_1) + + creds_http_2 = HttpMockSequence([]) + cred_2.authorize(creds_http_2) + + self.request1.http = creds_http_1 + self.request2.http = creds_http_2 + + batch.add(self.request1, callback=callbacks.f) + batch.add(self.request2, callback=callbacks.f) + batch.execute(http) + + self.assertEqual({'foo': 42}, callbacks.responses['1']) + self.assertEqual(None, callbacks.exceptions['1']) + self.assertEqual({'baz': 'qux'}, callbacks.responses['2']) + self.assertEqual(None, callbacks.exceptions['2']) + + self.assertEqual(1, cred_1._refreshed) + self.assertEqual(0, cred_2._refreshed) + + self.assertEqual(1, cred_1._authorized) + self.assertEqual(1, cred_2._authorized) + + self.assertEqual(1, cred_2._applied) + self.assertEqual(2, cred_1._applied) + + def test_http_errors_passed_to_callback(self): + batch = BatchHttpRequest() + callbacks = Callbacks() + cred_1 = MockCredentials('Foo') + cred_2 = MockCredentials('Bar') + + http = HttpMockSequence([ + ({'status': '200', + 'content-type': 'multipart/mixed; boundary="batch_foobarbaz"'}, + BATCH_RESPONSE_WITH_401), + ({'status': '200', + 'content-type': 'multipart/mixed; boundary="batch_foobarbaz"'}, + BATCH_RESPONSE_WITH_401), + ]) + + creds_http_1 = HttpMockSequence([]) + cred_1.authorize(creds_http_1) + + creds_http_2 = HttpMockSequence([]) + cred_2.authorize(creds_http_2) + + self.request1.http = creds_http_1 + self.request2.http = creds_http_2 + + batch.add(self.request1, callback=callbacks.f) + batch.add(self.request2, callback=callbacks.f) + batch.execute(http) + + self.assertEqual(None, callbacks.responses['1']) + self.assertEqual(401, callbacks.exceptions['1'].resp.status) + self.assertEqual({u'baz': u'qux'}, callbacks.responses['2']) + self.assertEqual(None, callbacks.exceptions['2']) + def test_execute_global_callback(self): - class Callbacks(object): - def __init__(self): - self.responses = {} - - def f(self, request_id, response): - self.responses[request_id] = response - callbacks = Callbacks() batch = BatchHttpRequest(callback=callbacks.f) @@ -330,8 +483,8 @@ class TestBatch(unittest.TestCase): BATCH_RESPONSE), ]) batch.execute(http) - self.assertEqual(callbacks.responses['1'], {'foo': 42}) - self.assertEqual(callbacks.responses['2'], {'baz': 'qux'}) + self.assertEqual({'foo': 42}, callbacks.responses['1']) + self.assertEqual({'baz': 'qux'}, callbacks.responses['2']) def test_media_inmemory_upload(self): media = MediaInMemoryUpload('abcdef', 'text/plain', chunksize=10, diff --git a/tests/test_oauth2client.py b/tests/test_oauth2client.py index d50d6a0..3e512a1 100644 --- a/tests/test_oauth2client.py +++ b/tests/test_oauth2client.py @@ -70,7 +70,7 @@ class OAuth2CredentialsTests(unittest.TestCase): ]) http = self.credentials.authorize(http) resp, content = http.request("http://example.com") - self.assertEqual(content['authorization'], 'OAuth 1/3w') + self.assertEqual('Bearer 1/3w', content['Authorization']) def test_token_refresh_failure(self): http = HttpMockSequence([ @@ -95,11 +95,11 @@ class OAuth2CredentialsTests(unittest.TestCase): def test_to_from_json(self): json = self.credentials.to_json() instance = OAuth2Credentials.from_json(json) - self.assertEquals(type(instance), OAuth2Credentials) + self.assertEqual(OAuth2Credentials, type(instance)) instance.token_expiry = None self.credentials.token_expiry = None - self.assertEquals(self.credentials.__dict__, instance.__dict__) + self.assertEqual(instance.__dict__, self.credentials.__dict__) class AccessTokenCredentialsTests(unittest.TestCase): @@ -136,7 +136,7 @@ class AccessTokenCredentialsTests(unittest.TestCase): ]) http = self.credentials.authorize(http) resp, content = http.request('http://example.com') - self.assertEqual(content['authorization'], 'OAuth foo') + self.assertEqual('Bearer foo', content['Authorization']) class TestAssertionCredentials(unittest.TestCase): @@ -155,8 +155,8 @@ class TestAssertionCredentials(unittest.TestCase): def test_assertion_body(self): body = urlparse.parse_qs(self.credentials._generate_refresh_request_body()) - self.assertEqual(body['assertion'][0], self.assertion_text) - self.assertEqual(body['assertion_type'][0], self.assertion_type) + self.assertEqual(self.assertion_text, body['assertion'][0]) + self.assertEqual(self.assertion_type, body['assertion_type'][0]) def test_assertion_refresh(self): http = HttpMockSequence([ @@ -165,7 +165,7 @@ class TestAssertionCredentials(unittest.TestCase): ]) http = self.credentials.authorize(http) resp, content = http.request("http://example.com") - self.assertEqual(content['authorization'], 'OAuth 1/3w') + self.assertEqual('Bearer 1/3w', content['Authorization']) class ExtractIdTokenText(unittest.TestCase): @@ -177,7 +177,7 @@ class ExtractIdTokenText(unittest.TestCase): jwt = 'stuff.' + payload + '.signature' extracted = _extract_id_token(jwt) - self.assertEqual(body, extracted) + self.assertEqual(extracted, body) def test_extract_failure(self): body = {'foo': 'bar'} @@ -201,11 +201,11 @@ class OAuth2WebServerFlowTest(unittest.TestCase): parsed = urlparse.urlparse(authorize_url) q = parse_qs(parsed[4]) - self.assertEqual(q['client_id'][0], 'client_id+1') - self.assertEqual(q['response_type'][0], 'code') - self.assertEqual(q['scope'][0], 'foo') - self.assertEqual(q['redirect_uri'][0], 'OOB_CALLBACK_URN') - self.assertEqual(q['access_type'][0], 'offline') + self.assertEqual('client_id+1', q['client_id'][0]) + self.assertEqual('code', q['response_type'][0]) + self.assertEqual('foo', q['scope'][0]) + self.assertEqual('OOB_CALLBACK_URN', q['redirect_uri'][0]) + self.assertEqual('offline', q['access_type'][0]) def test_override_flow_access_type(self): """Passing access_type overrides the default.""" @@ -220,11 +220,11 @@ class OAuth2WebServerFlowTest(unittest.TestCase): parsed = urlparse.urlparse(authorize_url) q = parse_qs(parsed[4]) - self.assertEqual(q['client_id'][0], 'client_id+1') - self.assertEqual(q['response_type'][0], 'code') - self.assertEqual(q['scope'][0], 'foo') - self.assertEqual(q['redirect_uri'][0], 'OOB_CALLBACK_URN') - self.assertEqual(q['access_type'][0], 'online') + self.assertEqual('client_id+1', q['client_id'][0]) + self.assertEqual('code', q['response_type'][0]) + self.assertEqual('foo', q['scope'][0]) + self.assertEqual('OOB_CALLBACK_URN', q['redirect_uri'][0]) + self.assertEqual('online', q['access_type'][0]) def test_exchange_failure(self): http = HttpMockSequence([ @@ -246,9 +246,9 @@ class OAuth2WebServerFlowTest(unittest.TestCase): ]) credentials = self.flow.step2_exchange('some random code', http) - self.assertEqual(credentials.access_token, 'SlAV32hkKG') - self.assertNotEqual(credentials.token_expiry, None) - self.assertEqual(credentials.refresh_token, '8xLOxBtZp8') + self.assertEqual('SlAV32hkKG', credentials.access_token) + self.assertNotEqual(None, credentials.token_expiry) + self.assertEqual('8xLOxBtZp8', credentials.refresh_token) def test_exchange_no_expires_in(self): http = HttpMockSequence([ @@ -257,7 +257,7 @@ class OAuth2WebServerFlowTest(unittest.TestCase): ]) credentials = self.flow.step2_exchange('some random code', http) - self.assertEqual(credentials.token_expiry, None) + self.assertEqual(None, credentials.token_expiry) def test_exchange_id_token_fail(self): http = HttpMockSequence([ @@ -282,7 +282,7 @@ class OAuth2WebServerFlowTest(unittest.TestCase): ]) credentials = self.flow.step2_exchange('some random code', http) - self.assertEquals(body, credentials.id_token) + self.assertEqual(credentials.id_token, body) if __name__ == '__main__': diff --git a/tests/test_oauth2client_jwt.py b/tests/test_oauth2client_jwt.py index 83551c6..dcbb33c 100644 --- a/tests/test_oauth2client_jwt.py +++ b/tests/test_oauth2client_jwt.py @@ -100,8 +100,8 @@ class CryptTests(unittest.TestCase): certs = {'foo': public_key } audience = 'some_audience_address@testing.gserviceaccount.com' contents = crypt.verify_signed_jwt_with_certs(jwt, certs, audience) - self.assertEquals('billy bob', contents['user']) - self.assertEquals('data', contents['metadata']['meta']) + self.assertEqual('billy bob', contents['user']) + self.assertEqual('data', contents['metadata']['meta']) def test_verify_id_token_with_certs_uri(self): jwt = self._create_signed_jwt() @@ -112,8 +112,8 @@ class CryptTests(unittest.TestCase): contents = verify_id_token(jwt, 'some_audience_address@testing.gserviceaccount.com', http) - self.assertEquals('billy bob', contents['user']) - self.assertEquals('data', contents['metadata']['meta']) + self.assertEqual('billy bob', contents['user']) + self.assertEqual('data', contents['metadata']['meta']) def test_verify_id_token_with_certs_uri_fails(self): jwt = self._create_signed_jwt() @@ -195,7 +195,7 @@ class CryptTests(unittest.TestCase): ]) http = credentials.authorize(http) resp, content = http.request('http://example.org') - self.assertEquals(content['authorization'], 'OAuth 1/3w') + self.assertEqual('Bearer 1/3w', content['Authorization']) if __name__ == '__main__': unittest.main()