From 1a0c4dbf92eddda508164a8578a62147590521ed Mon Sep 17 00:00:00 2001 From: Danny Hermes Date: Wed, 10 Aug 2016 16:53:46 -0700 Subject: [PATCH] Use transport.request in tests. In the process - "spring clean" the modules that were touched - use HttpMock when HttpMockSequence not needed - add some verifications on new HttpMock's --- .../contrib/test_multiprocess_file_storage.py | 62 +++++++---- tests/http_mock.py | 2 +- tests/test_client.py | 51 +++++---- tests/test_file.py | 79 ++++++++------ tests/test_jwt.py | 62 +++++++---- tests/test_service_account.py | 100 ++++++++++++++---- 6 files changed, 237 insertions(+), 119 deletions(-) diff --git a/tests/contrib/test_multiprocess_file_storage.py b/tests/contrib/test_multiprocess_file_storage.py index 9e29e05..0681e58 100644 --- a/tests/contrib/test_multiprocess_file_storage.py +++ b/tests/contrib/test_multiprocess_file_storage.py @@ -24,12 +24,13 @@ import unittest import fasteners import mock -from six import StringIO +import six +from six.moves import urllib_parse from oauth2client import client from oauth2client.contrib import multiprocess_file_storage -from ..http_mock import HttpMockSequence +from .. import http_mock @contextlib.contextmanager @@ -68,11 +69,7 @@ def _generate_token_response_http(new_token='new_token'): 'access_token': new_token, 'expires_in': '3600', }) - http = HttpMockSequence([ - ({'status': '200'}, token_response), - ]) - - return http + return http_mock.HttpMock(data=token_response) class MultiprocessStorageBehaviorTests(unittest.TestCase): @@ -115,6 +112,23 @@ class MultiprocessStorageBehaviorTests(unittest.TestCase): self.assertIsNone(credentials) + def _verify_refresh_payload(self, http, credentials): + self.assertEqual(http.requests, 1) + self.assertEqual(http.uri, credentials.token_uri) + self.assertEqual(http.method, 'POST') + expected_body = { + 'grant_type': ['refresh_token'], + 'client_id': [credentials.client_id], + 'client_secret': [credentials.client_secret], + 'refresh_token': [credentials.refresh_token], + } + self.assertEqual(urllib_parse.parse_qs(http.body), expected_body) + expected_headers = { + 'content-type': 'application/x-www-form-urlencoded', + 'user-agent': credentials.user_agent, + } + self.assertEqual(http.headers, expected_headers) + def test_single_process_refresh(self): store = multiprocess_file_storage.MultiprocessFileStorage( self.filename, 'single-process') @@ -128,6 +142,9 @@ class MultiprocessStorageBehaviorTests(unittest.TestCase): retrieved = store.get() self.assertEqual(retrieved.access_token, 'new_token') + # Verify mocks. + self._verify_refresh_payload(http, credentials) + def test_multi_process_refresh(self): # This will test that two processes attempting to refresh credentials # will only refresh once. @@ -136,6 +153,7 @@ class MultiprocessStorageBehaviorTests(unittest.TestCase): credentials = _create_test_credentials() credentials.set_store(store) store.put(credentials) + actual_token = 'b' def child_process_func( die_event, ready_event, check_event): # pragma: NO COVER @@ -156,10 +174,12 @@ class MultiprocessStorageBehaviorTests(unittest.TestCase): credentials.store.acquire_lock = replacement_acquire_lock - http = _generate_token_response_http('b') + http = _generate_token_response_http(actual_token) credentials.refresh(http) + self.assertEqual(credentials.access_token, actual_token) - self.assertEqual(credentials.access_token, 'b') + # Verify mock http. + self._verify_refresh_payload(http, credentials) check_event = multiprocessing.Event() with scoped_child_process(child_process_func, check_event=check_event): @@ -168,15 +188,17 @@ class MultiprocessStorageBehaviorTests(unittest.TestCase): store._backend._process_lock.acquire(blocking=False)) check_event.set() - # The child process will refresh first, so we should end up - # with 'b' as the token. - http = mock.Mock() + http = _generate_token_response_http('not ' + actual_token) credentials.refresh(http=http) - self.assertEqual(credentials.access_token, 'b') - self.assertFalse(http.request.called) + # The child process will refresh first, so we should end up + # with `actual_token`' as the token. + self.assertEqual(credentials.access_token, actual_token) + + # Make sure the refresh did not make a request. + self.assertEqual(http.requests, 0) retrieved = store.get() - self.assertEqual(retrieved.access_token, 'b') + self.assertEqual(retrieved.access_token, actual_token) def test_read_only_file_fail_lock(self): credentials = _create_test_credentials() @@ -233,7 +255,7 @@ class MultiprocessStorageUnitTests(unittest.TestCase): def test__read_write_credentials_file(self): credentials = _create_test_credentials() - contents = StringIO() + contents = six.StringIO() multiprocess_file_storage._write_credentials_file( contents, {'key': credentials}) @@ -253,23 +275,23 @@ class MultiprocessStorageUnitTests(unittest.TestCase): # the invalid one but still load the valid one. data['credentials']['invalid'] = '123' results = multiprocess_file_storage._load_credentials_file( - StringIO(json.dumps(data))) + six.StringIO(json.dumps(data))) self.assertNotIn('invalid', results) self.assertEqual( results['key'].access_token, credentials.access_token) def test__load_credentials_file_invalid_json(self): - contents = StringIO('{[') + contents = six.StringIO('{[') self.assertEqual( multiprocess_file_storage._load_credentials_file(contents), {}) def test__load_credentials_file_no_file_version(self): - contents = StringIO('{}') + contents = six.StringIO('{}') self.assertEqual( multiprocess_file_storage._load_credentials_file(contents), {}) def test__load_credentials_file_bad_file_version(self): - contents = StringIO(json.dumps({'file_version': 1})) + contents = six.StringIO(json.dumps({'file_version': 1})) self.assertEqual( multiprocess_file_storage._load_credentials_file(contents), {}) diff --git a/tests/http_mock.py b/tests/http_mock.py index f1be78e..a29024f 100644 --- a/tests/http_mock.py +++ b/tests/http_mock.py @@ -75,7 +75,7 @@ class HttpMockSequence(object): ({'status': '200'}, b'{"access_token":"1/3w","expires_in":3600}'), ({'status': '200'}, 'echo_request_headers'), ]) - resp, content = http.request("http://examples.com") + resp, content = http.request('http://examples.com') There are special values you can pass in for content to trigger behavours that are helpful in testing. diff --git a/tests/test_client.py b/tests/test_client.py index 0ac4b3e..3d41da1 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -35,6 +35,7 @@ from oauth2client import _helpers from oauth2client import client from oauth2client import clientsecrets from oauth2client import service_account +from oauth2client import transport from . import http_mock __author__ = 'jcgregorio@google.com (Joe Gregorio)' @@ -899,7 +900,7 @@ class BasicCredentialsTests(unittest.TestCase): ({'status': http_client.OK}, 'echo_request_headers'), ]) http = self.credentials.authorize(http) - resp, content = http.request('http://example.com') + resp, content = transport.request(http, 'http://example.com') self.assertEqual(b'Bearer 1/3w', content[b'Authorization']) self.assertFalse(self.credentials.access_token_expired) self.assertEqual(token_response, self.credentials.token_response) @@ -918,7 +919,7 @@ class BasicCredentialsTests(unittest.TestCase): http = http_mock.HttpMock(data=encoded_response) http = self.credentials.authorize(http) http = self.credentials.authorize(http) - http.request('http://example.com') + transport.request(http, 'http://example.com') def test_token_refresh_failure(self): for status_code in client.REFRESH_STATUS_CODES: @@ -930,7 +931,7 @@ class BasicCredentialsTests(unittest.TestCase): http = self.credentials.authorize(http) with self.assertRaises( client.HttpAccessTokenRefreshError) as exc_manager: - http.request('http://example.com') + transport.request(http, 'http://example.com') self.assertEqual(http_client.BAD_REQUEST, exc_manager.exception.status) self.assertTrue(self.credentials.access_token_expired) @@ -957,7 +958,7 @@ class BasicCredentialsTests(unittest.TestCase): def test_non_401_error_response(self): http = http_mock.HttpMock(headers={'status': http_client.BAD_REQUEST}) http = self.credentials.authorize(http) - resp, content = http.request('http://example.com') + resp, content = transport.request(http, 'http://example.com') self.assertEqual(http_client.BAD_REQUEST, resp.status) self.assertEqual(None, self.credentials.token_response) @@ -1001,7 +1002,8 @@ class BasicCredentialsTests(unittest.TestCase): http = credentials.authorize(http_mock.HttpMock()) headers = {u'foo': 3, b'bar': True, 'baz': b'abc'} cleaned_headers = {b'foo': b'3', b'bar': b'True', b'baz': b'abc'} - http.request(u'http://example.com', method=u'GET', headers=headers) + transport.request( + http, u'http://example.com', method=u'GET', headers=headers) for k, v in cleaned_headers.items(): self.assertTrue(k in http.headers) self.assertEqual(v, http.headers[k]) @@ -1009,8 +1011,9 @@ class BasicCredentialsTests(unittest.TestCase): # Next, test that we do fail on unicode. unicode_str = six.unichr(40960) + 'abcd' with self.assertRaises(client.NonAsciiHeaderError): - http.request(u'http://example.com', method=u'GET', - headers={u'foo': unicode_str}) + transport.request( + http, u'http://example.com', method=u'GET', + headers={u'foo': unicode_str}) def test_no_unicode_in_request_params(self): access_token = u'foo' @@ -1027,8 +1030,9 @@ class BasicCredentialsTests(unittest.TestCase): http = http_mock.HttpMock() http = credentials.authorize(http) - http.request(u'http://example.com', method=u'GET', - headers={u'foo': u'bar'}) + transport.request( + http, u'http://example.com', method=u'GET', + headers={u'foo': u'bar'}) for k, v in six.iteritems(http.headers): self.assertIsInstance(k, six.binary_type) self.assertIsInstance(v, six.binary_type) @@ -1036,8 +1040,8 @@ class BasicCredentialsTests(unittest.TestCase): # Test again with unicode strings that can't simply be converted # to ASCII. with self.assertRaises(client.NonAsciiHeaderError): - http.request( - u'http://example.com', method=u'GET', + transport.request( + http, u'http://example.com', method=u'GET', headers={u'foo': u'\N{COMET}'}) self.credentials.token_response = 'foobar' @@ -1473,7 +1477,7 @@ class BasicCredentialsTests(unittest.TestCase): ({'status': '200'}, 'echo_request_headers'), ]) http = self.credentials.authorize(http) - resp, content = http.request('http://example.com') + resp, content = transport.request(http, 'http://example.com') self.assertEqual(self.credentials.id_token, body) @@ -1492,7 +1496,7 @@ class AccessTokenCredentialsTests(unittest.TestCase): headers={'status': status_code}, data=b'') http = self.credentials.authorize(http) with self.assertRaises(client.AccessTokenCredentialsError): - resp, content = http.request('http://example.com') + resp, content = transport.request(http, 'http://example.com') def test_token_revoke_success(self): _token_revoke_test_helper( @@ -1507,7 +1511,7 @@ class AccessTokenCredentialsTests(unittest.TestCase): def test_non_401_error_response(self): http = http_mock.HttpMock(headers={'status': http_client.BAD_REQUEST}) http = self.credentials.authorize(http) - resp, content = http.request('http://example.com') + resp, content = transport.request(http, 'http://example.com') self.assertEqual(http_client.BAD_REQUEST, resp.status) def test_auth_header_sent(self): @@ -1515,7 +1519,7 @@ class AccessTokenCredentialsTests(unittest.TestCase): ({'status': '200'}, 'echo_request_headers'), ]) http = self.credentials.authorize(http) - resp, content = http.request('http://example.com') + resp, content = transport.request(http, 'http://example.com') self.assertEqual(b'Bearer foo', content[b'Authorization']) @@ -1551,7 +1555,7 @@ class TestAssertionCredentials(unittest.TestCase): ({'status': '200'}, 'echo_request_headers'), ]) http = self.credentials.authorize(http) - resp, content = http.request('http://example.com') + resp, content = transport.request(http, 'http://example.com') self.assertEqual(b'Bearer 1/3w', content[b'Authorization']) def test_token_revoke_success(self): @@ -1742,16 +1746,17 @@ class OAuth2WebServerFlowTest(unittest.TestCase): device_code, user_code, None, ver_url, None) self.assertEqual(result, expected) self.assertEqual(len(http.requests), 1) - self.assertEqual( - http.requests[0]['uri'], oauth2client.GOOGLE_DEVICE_URI) - body = http.requests[0]['body'] - self.assertEqual(urllib.parse.parse_qs(body), - {'client_id': [flow.client_id], - 'scope': [flow.scope]}) + info = http.requests[0] + self.assertEqual(info['uri'], oauth2client.GOOGLE_DEVICE_URI) + expected_body = { + 'client_id': [flow.client_id], + 'scope': [flow.scope], + } + self.assertEqual(urllib.parse.parse_qs(info['body']), expected_body) headers = {'content-type': 'application/x-www-form-urlencoded'} if extra_headers is not None: headers.update(extra_headers) - self.assertEqual(http.requests[0]['headers'], headers) + self.assertEqual(info['headers'], headers) def test_step1_get_device_and_user_codes(self): self._step1_get_device_and_user_codes_helper() diff --git a/tests/test_file.py b/tests/test_file.py index 8604344..437ad6d 100644 --- a/tests/test_file.py +++ b/tests/test_file.py @@ -12,10 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Oauth2client.file tests - -Unit tests for oauth2client.file -""" +"""Unit tests for oauth2client.file.""" import copy import datetime @@ -30,11 +27,13 @@ import warnings import mock import six from six.moves import http_client +from six.moves import urllib_parse from oauth2client import _helpers from oauth2client import client -from oauth2client import file -from .http_mock import HttpMockSequence +from oauth2client import file as file_module +from oauth2client import transport +from . import http_mock try: # Python2 @@ -80,14 +79,14 @@ class OAuth2ClientFileTests(unittest.TestCase): @mock.patch('warnings.warn') def test_non_existent_file_storage(self, warn_mock): - storage = file.Storage(FILENAME) + storage = file_module.Storage(FILENAME) credentials = storage.get() warn_mock.assert_called_with( _helpers._MISSING_FILE_MESSAGE.format(FILENAME)) self.assertIsNone(credentials) def test_directory_file_storage(self): - storage = file.Storage(FILENAME) + storage = file_module.Storage(FILENAME) os.mkdir(FILENAME) try: with self.assertRaises(IOError): @@ -99,7 +98,7 @@ class OAuth2ClientFileTests(unittest.TestCase): def test_no_sym_link_credentials(self): SYMFILENAME = FILENAME + '.sym' os.symlink(FILENAME, SYMFILENAME) - storage = file.Storage(SYMFILENAME) + storage = file_module.Storage(SYMFILENAME) try: with self.assertRaises(IOError): storage.get() @@ -116,7 +115,7 @@ class OAuth2ClientFileTests(unittest.TestCase): # Storage should be not be able to read that object, as the capability # to read and write credentials as pickled objects has been removed. - storage = file.Storage(FILENAME) + storage = file_module.Storage(FILENAME) read_credentials = storage.get() self.assertIsNone(read_credentials) @@ -134,7 +133,7 @@ class OAuth2ClientFileTests(unittest.TestCase): datetime.timedelta(minutes=15)) credentials = self._create_test_credentials(expiration=expiration) - storage = file.Storage(FILENAME) + storage = file_module.Storage(FILENAME) storage.put(credentials) credentials = storage.get() new_cred = copy.copy(credentials) @@ -143,13 +142,29 @@ class OAuth2ClientFileTests(unittest.TestCase): access_token = '1/3w' token_response = {'access_token': access_token, 'expires_in': 3600} - http = HttpMockSequence([ - ({'status': '200'}, json.dumps(token_response).encode('utf-8')), - ]) + response_content = json.dumps(token_response).encode('utf-8') + http = http_mock.HttpMock(data=response_content) - credentials._refresh(http.request) + credentials._refresh(http) self.assertEquals(credentials.access_token, access_token) + # Verify mocks. + self.assertEqual(http.requests, 1) + self.assertEqual(http.uri, credentials.token_uri) + self.assertEqual(http.method, 'POST') + expected_body = { + 'grant_type': ['refresh_token'], + 'client_id': [credentials.client_id], + 'client_secret': [credentials.client_secret], + 'refresh_token': [credentials.refresh_token], + } + self.assertEqual(urllib_parse.parse_qs(http.body), expected_body) + expected_headers = { + 'content-type': 'application/x-www-form-urlencoded', + 'user-agent': credentials.user_agent, + } + self.assertEqual(http.headers, expected_headers) + def test_token_refresh_store_expires_soon(self): # Tests the case where an access token that is valid when it is read # from the store expires before the original request succeeds. @@ -157,7 +172,7 @@ class OAuth2ClientFileTests(unittest.TestCase): datetime.timedelta(minutes=15)) credentials = self._create_test_credentials(expiration=expiration) - storage = file.Storage(FILENAME) + storage = file_module.Storage(FILENAME) storage.put(credentials) credentials = storage.get() new_cred = copy.copy(credentials) @@ -166,19 +181,19 @@ class OAuth2ClientFileTests(unittest.TestCase): access_token = '1/3w' token_response = {'access_token': access_token, 'expires_in': 3600} - http = HttpMockSequence([ - ({'status': str(int(http_client.UNAUTHORIZED))}, + http = http_mock.HttpMockSequence([ + ({'status': http_client.UNAUTHORIZED}, b'Initial token expired'), - ({'status': str(int(http_client.UNAUTHORIZED))}, + ({'status': http_client.UNAUTHORIZED}, b'Store token expired'), - ({'status': str(int(http_client.OK))}, + ({'status': http_client.OK}, json.dumps(token_response).encode('utf-8')), - ({'status': str(int(http_client.OK))}, + ({'status': http_client.OK}, b'Valid response to original request') ]) credentials.authorize(http) - http.request('https://example.com') + transport.request(http, 'https://example.com') self.assertEqual(credentials.access_token, access_token) def test_token_refresh_good_store(self): @@ -186,7 +201,7 @@ class OAuth2ClientFileTests(unittest.TestCase): datetime.timedelta(minutes=15)) credentials = self._create_test_credentials(expiration=expiration) - storage = file.Storage(FILENAME) + storage = file_module.Storage(FILENAME) storage.put(credentials) credentials = storage.get() new_cred = copy.copy(credentials) @@ -201,7 +216,7 @@ class OAuth2ClientFileTests(unittest.TestCase): datetime.timedelta(minutes=15)) credentials = self._create_test_credentials(expiration=expiration) - storage = file.Storage(FILENAME) + storage = file_module.Storage(FILENAME) storage.put(credentials) credentials = storage.get() new_cred = copy.copy(credentials) @@ -211,27 +226,27 @@ class OAuth2ClientFileTests(unittest.TestCase): valid_access_token = '1/3w' token_response = {'access_token': valid_access_token, 'expires_in': 3600} - http = HttpMockSequence([ - ({'status': str(int(http_client.UNAUTHORIZED))}, + http = http_mock.HttpMockSequence([ + ({'status': http_client.UNAUTHORIZED}, b'Initial token expired'), - ({'status': str(int(http_client.UNAUTHORIZED))}, + ({'status': http_client.UNAUTHORIZED}, b'Store token expired'), - ({'status': str(int(http_client.OK))}, + ({'status': http_client.OK}, json.dumps(token_response).encode('utf-8')), - ({'status': str(int(http_client.OK))}, 'echo_request_body') + ({'status': http_client.OK}, 'echo_request_body') ]) body = six.StringIO('streaming body') credentials.authorize(http) - _, content = http.request('https://example.com', body=body) + _, content = transport.request(http, 'https://example.com', body=body) self.assertEqual(content, 'streaming body') self.assertEqual(credentials.access_token, valid_access_token) def test_credentials_delete(self): credentials = self._create_test_credentials() - storage = file.Storage(FILENAME) + storage = file_module.Storage(FILENAME) storage.put(credentials) credentials = storage.get() self.assertIsNotNone(credentials) @@ -245,7 +260,7 @@ class OAuth2ClientFileTests(unittest.TestCase): credentials = client.AccessTokenCredentials(access_token, user_agent) - storage = file.Storage(FILENAME) + storage = file_module.Storage(FILENAME) credentials = storage.put(credentials) credentials = storage.get() diff --git a/tests/test_jwt.py b/tests/test_jwt.py index fdbb37f..4ac0495 100644 --- a/tests/test_jwt.py +++ b/tests/test_jwt.py @@ -20,13 +20,15 @@ import time import unittest import mock +from six.moves import http_client from oauth2client import _helpers from oauth2client import client from oauth2client import crypt -from oauth2client import file +from oauth2client import file as file_module from oauth2client import service_account -from .http_mock import HttpMockSequence +from oauth2client import transport +from . import http_mock __author__ = 'jcgregorio@google.com (Joe Gregorio)' @@ -114,25 +116,30 @@ class CryptTests(unittest.TestCase): self.assertEqual('billy bob', contents['user']) self.assertEqual('data', contents['metadata']['meta']) + def _verify_http_mock(self, http): + self.assertEqual(http.requests, 1) + self.assertEqual(http.uri, client.ID_TOKEN_VERIFICATION_CERTS) + self.assertEqual(http.method, 'GET') + self.assertIsNone(http.body) + self.assertIsNone(http.headers) + def test_verify_id_token_with_certs_uri(self): jwt = self._create_signed_jwt() - http = HttpMockSequence([ - ({'status': '200'}, datafile('certs.json')), - ]) - + http = http_mock.HttpMock(data=datafile('certs.json')) contents = client.verify_id_token( jwt, 'some_audience_address@testing.gserviceaccount.com', http=http) self.assertEqual('billy bob', contents['user']) self.assertEqual('data', contents['metadata']['meta']) + # Verify mocks. + self._verify_http_mock(http) + def test_verify_id_token_with_certs_uri_default_http(self): jwt = self._create_signed_jwt() - http = HttpMockSequence([ - ({'status': '200'}, datafile('certs.json')), - ]) + http = http_mock.HttpMock(data=datafile('certs.json')) with mock.patch('oauth2client.transport._CACHED_HTTP', new=http): contents = client.verify_id_token( @@ -141,17 +148,23 @@ class CryptTests(unittest.TestCase): self.assertEqual('billy bob', contents['user']) self.assertEqual('data', contents['metadata']['meta']) + # Verify mocks. + self._verify_http_mock(http) + def test_verify_id_token_with_certs_uri_fails(self): jwt = self._create_signed_jwt() test_email = 'some_audience_address@testing.gserviceaccount.com' - http = HttpMockSequence([ - ({'status': '404'}, datafile('certs.json')), - ]) + http = http_mock.HttpMock( + headers={'status': http_client.NOT_FOUND}, + data=datafile('certs.json')) with self.assertRaises(client.VerifyJwtTokenError): client.verify_id_token(jwt, test_email, http=http) + # Verify mocks. + self._verify_http_mock(http) + def test_verify_id_token_bad_tokens(self): private_key = datafile('privatekey.' + self.format_) @@ -261,12 +274,13 @@ class SignedJwtAssertionCredentialsTests(unittest.TestCase): def test_credentials_good(self): credentials = self._make_credentials() - http = HttpMockSequence([ - ({'status': '200'}, b'{"access_token":"1/3w","expires_in":3600}'), - ({'status': '200'}, 'echo_request_headers'), + http = http_mock.HttpMockSequence([ + ({'status': http_client.OK}, + b'{"access_token":"1/3w","expires_in":3600}'), + ({'status': http_client.OK}, 'echo_request_headers'), ]) http = credentials.authorize(http) - resp, content = http.request('http://example.org') + resp, content = transport.request(http, 'http://example.org') self.assertEqual(b'Bearer 1/3w', content[b'Authorization']) def test_credentials_to_from_json(self): @@ -280,14 +294,16 @@ class SignedJwtAssertionCredentialsTests(unittest.TestCase): self.assertEqual(credentials._kwargs, restored._kwargs) def _credentials_refresh(self, credentials): - http = HttpMockSequence([ - ({'status': '200'}, b'{"access_token":"1/3w","expires_in":3600}'), - ({'status': '401'}, b''), - ({'status': '200'}, b'{"access_token":"3/3w","expires_in":3600}'), - ({'status': '200'}, 'echo_request_headers'), + http = http_mock.HttpMockSequence([ + ({'status': http_client.OK}, + b'{"access_token":"1/3w","expires_in":3600}'), + ({'status': http_client.UNAUTHORIZED}, b''), + ({'status': http_client.OK}, + b'{"access_token":"3/3w","expires_in":3600}'), + ({'status': http_client.OK}, 'echo_request_headers'), ]) http = credentials.authorize(http) - _, content = http.request('http://example.org') + _, content = transport.request(http, 'http://example.org') return content def test_credentials_refresh_without_storage(self): @@ -300,7 +316,7 @@ class SignedJwtAssertionCredentialsTests(unittest.TestCase): filehandle, filename = tempfile.mkstemp() os.close(filehandle) - store = file.Storage(filename) + store = file_module.Storage(filename) store.put(credentials) credentials.set_store(store) diff --git a/tests/test_service_account.py b/tests/test_service_account.py index 7dc8ad0..d6b2f07 100644 --- a/tests/test_service_account.py +++ b/tests/test_service_account.py @@ -407,13 +407,15 @@ class JWTAccessCredentialsTests(unittest.TestCase): time.return_value = T1 token_info = self.jwt.get_access_token() + certs = {'key': datafile('public_cert.pem')} payload = crypt.verify_signed_jwt_with_certs( - token_info.access_token, - {'key': datafile('public_cert.pem')}, audience=self.url) + token_info.access_token, certs, audience=self.url) + self.assertEqual(len(payload), 5) self.assertEqual(payload['iss'], self.service_account_email) self.assertEqual(payload['sub'], self.service_account_email) self.assertEqual(payload['iat'], T1) self.assertEqual(payload['exp'], T1_EXPIRY) + self.assertEqual(payload['aud'], self.url) self.assertEqual(token_info.expires_in, T1_EXPIRY - T1) # Verify that we vend the same token after 100 seconds @@ -444,19 +446,20 @@ class JWTAccessCredentialsTests(unittest.TestCase): utcnow.return_value = T1_DATE time.return_value = T1 - token_info = self.jwt.get_access_token( - additional_claims={'aud': 'https://test2.url.com', - 'sub': 'dummy2@google.com' - }) + audience = 'https://test2.url.com' + subject = 'dummy2@google.com' + claims = {'aud': audience, 'sub': subject} + token_info = self.jwt.get_access_token(additional_claims=claims) + certs = {'key': datafile('public_cert.pem')} payload = crypt.verify_signed_jwt_with_certs( - token_info.access_token, - {'key': datafile('public_cert.pem')}, - audience='https://test2.url.com') + token_info.access_token, certs, audience=audience) expires_in = token_info.expires_in + self.assertEqual(len(payload), 5) self.assertEqual(payload['iss'], self.service_account_email) - self.assertEqual(payload['sub'], 'dummy2@google.com') + self.assertEqual(payload['sub'], subject) self.assertEqual(payload['iat'], T1) self.assertEqual(payload['exp'], T1_EXPIRY) + self.assertEqual(payload['aud'], audience) self.assertEqual(expires_in, T1_EXPIRY - T1) def test_revoke(self): @@ -502,13 +505,15 @@ class JWTAccessCredentialsTests(unittest.TestCase): self.assertIsNone(info['body']) self.assertEqual(len(info['headers']), 1) bearer, token = info['headers'][b'Authorization'].split() + self.assertEqual(bearer, b'Bearer') payload = crypt.verify_signed_jwt_with_certs( token, certs, audience=self.url) + self.assertEqual(len(payload), 5) self.assertEqual(payload['iss'], self.service_account_email) self.assertEqual(payload['sub'], self.service_account_email) self.assertEqual(payload['iat'], T1) self.assertEqual(payload['exp'], T1_EXPIRY) - self.assertEqual(bearer, b'Bearer') + self.assertEqual(payload['aud'], self.url) @mock.patch('oauth2client.client._UTCNOW') @mock.patch('time.time') @@ -538,53 +543,108 @@ class JWTAccessCredentialsTests(unittest.TestCase): self.assertIsNone(info['body']) self.assertEqual(len(info['headers']), 1) bearer, token = info['headers'][b'Authorization'].split() + self.assertEqual(bearer, b'Bearer') certs = {'key': datafile('public_cert.pem')} payload = crypt.verify_signed_jwt_with_certs( token, certs, audience=self.url) + self.assertEqual(len(payload), 5) self.assertEqual(payload['iss'], self.service_account_email) self.assertEqual(payload['sub'], self.service_account_email) self.assertEqual(payload['iat'], T1) self.assertEqual(payload['exp'], T1_EXPIRY) - self.assertEqual(bearer, b'Bearer') + self.assertEqual(payload['aud'], self.url) @mock.patch('oauth2client.client._UTCNOW') def test_authorize_stale_token(self, utcnow): utcnow.return_value = T1_DATE # Create an initial token - h = http_mock.HttpMockSequence([ + http = http_mock.HttpMockSequence([ ({'status': http_client.OK}, b''), ({'status': http_client.OK}, b''), ]) - self.jwt.authorize(h) - h.request(self.url) + self.jwt.authorize(http) + transport.request(http, self.url) token_1 = self.jwt.access_token # Expire the token utcnow.return_value = T3_DATE - h.request(self.url) + transport.request(http, self.url) token_2 = self.jwt.access_token self.assertEquals(self.jwt.token_expiry, T3_EXPIRY_DATE) self.assertNotEqual(token_1, token_2) + # Verify mocks. + certs = {'key': datafile('public_cert.pem')} + self.assertEqual(len(http.requests), 2) + issued_at_vals = (T1, T3) + exp_vals = (T1_EXPIRY, T3_EXPIRY) + for info, issued_at, exp_val in zip(http.requests, issued_at_vals, + exp_vals): + self.assertEqual(info['uri'], self.url) + self.assertEqual(info['method'], 'GET') + self.assertIsNone(info['body']) + self.assertEqual(len(info['headers']), 1) + bearer, token = info['headers'][b'Authorization'].split() + self.assertEqual(bearer, b'Bearer') + # To parse the token, skip the time check, since this + # test intentionally has stale tokens. + with mock.patch('oauth2client.crypt._verify_time_range', + return_value=True): + payload = crypt.verify_signed_jwt_with_certs( + token, certs, audience=self.url) + self.assertEqual(len(payload), 5) + self.assertEqual(payload['iss'], self.service_account_email) + self.assertEqual(payload['sub'], self.service_account_email) + self.assertEqual(payload['iat'], issued_at) + self.assertEqual(payload['exp'], exp_val) + self.assertEqual(payload['aud'], self.url) + @mock.patch('oauth2client.client._UTCNOW') def test_authorize_401(self, utcnow): utcnow.return_value = T1_DATE - h = http_mock.HttpMockSequence([ + http = http_mock.HttpMockSequence([ ({'status': http_client.OK}, b''), ({'status': http_client.UNAUTHORIZED}, b''), ({'status': http_client.OK}, b''), ]) - self.jwt.authorize(h) - h.request(self.url) + self.jwt.authorize(http) + transport.request(http, self.url) token_1 = self.jwt.access_token utcnow.return_value = T2_DATE - self.assertEquals(h.request(self.url)[0].status, 200) + response, _ = transport.request(http, self.url) + self.assertEquals(response.status, http_client.OK) token_2 = self.jwt.access_token # Check the 401 forced a new token self.assertNotEqual(token_1, token_2) + # Verify mocks. + certs = {'key': datafile('public_cert.pem')} + self.assertEqual(len(http.requests), 3) + issued_at_vals = (T1, T1, T2) + exp_vals = (T1_EXPIRY, T1_EXPIRY, T2_EXPIRY) + for info, issued_at, exp_val in zip(http.requests, issued_at_vals, + exp_vals): + self.assertEqual(info['uri'], self.url) + self.assertEqual(info['method'], 'GET') + self.assertIsNone(info['body']) + self.assertEqual(len(info['headers']), 1) + bearer, token = info['headers'][b'Authorization'].split() + self.assertEqual(bearer, b'Bearer') + # To parse the token, skip the time check, since this + # test intentionally has stale tokens. + with mock.patch('oauth2client.crypt._verify_time_range', + return_value=True): + payload = crypt.verify_signed_jwt_with_certs( + token, certs, audience=self.url) + self.assertEqual(len(payload), 5) + self.assertEqual(payload['iss'], self.service_account_email) + self.assertEqual(payload['sub'], self.service_account_email) + self.assertEqual(payload['iat'], issued_at) + self.assertEqual(payload['exp'], exp_val) + self.assertEqual(payload['aud'], self.url) + @mock.patch('oauth2client.client._UTCNOW') def test_refresh(self, utcnow): utcnow.return_value = T1_DATE