diff --git a/oauth2client/_openssl_crypt.py b/oauth2client/_openssl_crypt.py index d024cf3..5d9baca 100644 --- a/oauth2client/_openssl_crypt.py +++ b/oauth2client/_openssl_crypt.py @@ -112,7 +112,7 @@ class OpenSSLSigner(object): Raises: OpenSSL.crypto.Error if the key can't be parsed. """ - parsed_pem_key = _parse_pem_key(key) + parsed_pem_key = _parse_pem_key(_to_bytes(key)) if parsed_pem_key: pkey = crypto.load_privatekey(crypto.FILETYPE_PEM, parsed_pem_key) else: diff --git a/oauth2client/crypt.py b/oauth2client/crypt.py index c871cbb..340ee4a 100644 --- a/oauth2client/crypt.py +++ b/oauth2client/crypt.py @@ -71,7 +71,7 @@ else: # pragma: NO COVER Verifier = RsaVerifier -def make_signed_jwt(signer, payload): +def make_signed_jwt(signer, payload, key_id=None): """Make a signed JWT. See http://self-issued.info/docs/draft-jones-json-web-token.html. @@ -79,11 +79,14 @@ def make_signed_jwt(signer, payload): Args: signer: crypt.Signer, Cryptographic signer. payload: dict, Dictionary of data to convert to JSON and then sign. + key_id: string, (Optional) Key ID header. Returns: string, The JWT for the payload. """ header = {'typ': 'JWT', 'alg': 'RS256'} + if key_id is not None: + header['kid'] = key_id segments = [ _urlsafe_b64encode(_json_encode(header)), @@ -91,7 +94,7 @@ def make_signed_jwt(signer, payload): ] signing_input = b'.'.join(segments) - signature = signer.sign(signing_input) + signature = signer.sign(signing_input).rstrip(b'=') segments.append(_urlsafe_b64encode(signature)) logger.debug(str(segments)) diff --git a/oauth2client/service_account.py b/oauth2client/service_account.py index 46ed04a..d087fc8 100644 --- a/oauth2client/service_account.py +++ b/oauth2client/service_account.py @@ -12,20 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""A service account credentials class. - -This credentials class is implemented on top of rsa library. -""" +"""oauth2client Service account credentials class.""" import base64 import datetime import json import time -from pyasn1.codec.ber import decoder -from pyasn1_modules.rfc5208 import PrivateKeyInfo -import rsa - from oauth2client import GOOGLE_REVOKE_URI from oauth2client import GOOGLE_TOKEN_URI from oauth2client._helpers import _json_encode @@ -35,6 +28,7 @@ from oauth2client._helpers import _urlsafe_b64encode from oauth2client import util from oauth2client.client import AssertionCredentials from oauth2client.client import EXPIRY_FORMAT +from oauth2client import crypt class _ServiceAccountCredentials(AssertionCredentials): @@ -43,10 +37,9 @@ class _ServiceAccountCredentials(AssertionCredentials): MAX_TOKEN_LIFETIME_SECS = 3600 # 1 hour in seconds NON_SERIALIZED_MEMBERS = ( - frozenset(['_private_key']) | + frozenset(['_signer']) | AssertionCredentials.NON_SERIALIZED_MEMBERS) - def __init__(self, service_account_id, service_account_email, private_key_id, private_key_pkcs8_text, scopes, user_agent=None, token_uri=GOOGLE_TOKEN_URI, @@ -59,8 +52,8 @@ class _ServiceAccountCredentials(AssertionCredentials): self._service_account_id = service_account_id self._service_account_email = service_account_email self._private_key_id = private_key_id - self._private_key = _get_private_key(private_key_pkcs8_text) self._private_key_pkcs8_text = private_key_pkcs8_text + self._signer = crypt.Signer.from_string(self._private_key_pkcs8_text) self._scopes = util.scopes_to_string(scopes) self._user_agent = user_agent self._token_uri = token_uri @@ -69,39 +62,20 @@ class _ServiceAccountCredentials(AssertionCredentials): def _generate_assertion(self): """Generate the assertion that will be used in the request.""" - - header = { - 'alg': 'RS256', - 'typ': 'JWT', - 'kid': self._private_key_id - } - now = int(time.time()) payload = { 'aud': self._token_uri, 'scope': self._scopes, 'iat': now, - 'exp': now + _ServiceAccountCredentials.MAX_TOKEN_LIFETIME_SECS, - 'iss': self._service_account_email + 'exp': now + self.MAX_TOKEN_LIFETIME_SECS, + 'iss': self._service_account_email, } payload.update(self._kwargs) - - first_segment = _urlsafe_b64encode(_json_encode(header)) - second_segment = _urlsafe_b64encode(_json_encode(payload)) - assertion_input = first_segment + b'.' + second_segment - - # Sign the assertion. - rsa_bytes = rsa.pkcs1.sign(assertion_input, self._private_key, - 'SHA-256') - signature = base64.urlsafe_b64encode(rsa_bytes).rstrip(b'=') - - return assertion_input + b'.' + signature + return crypt.make_signed_jwt(self._signer, payload, + key_id=self._private_key_id) def sign_blob(self, blob): - # Ensure that it is bytes - blob = _to_bytes(blob, encoding='utf-8') - return (self._private_key_id, - rsa.pkcs1.sign(blob, self._private_key, 'SHA-256')) + return self._private_key_id, self._signer.sign(blob) @property def service_account_email(self): @@ -149,13 +123,3 @@ class _ServiceAccountCredentials(AssertionCredentials): token_uri=self._token_uri, revoke_uri=self._revoke_uri, **self._kwargs) - - -def _get_private_key(private_key_pkcs8_text): - """Get an RSA private key object from a pkcs8 representation.""" - private_key_pkcs8_text = _to_bytes(private_key_pkcs8_text) - der = rsa.pem.load_pem(private_key_pkcs8_text, 'PRIVATE KEY') - asn1_private_key, _ = decoder.decode(der, asn1Spec=PrivateKeyInfo()) - return rsa.PrivateKey.load_pkcs1( - asn1_private_key.getComponentByName('privateKey').asOctets(), - format='DER') diff --git a/tests/test_client.py b/tests/test_client.py index 131e72e..708db4d 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -628,15 +628,17 @@ class GoogleCredentialsTests(unittest2.TestCase): self.assertEqual(creds.__dict__, creds2.__dict__) def test_to_from_json_service_account(self): - self.maxDiff=None credentials_file = datafile( os.path.join('gcloud', _WELL_KNOWN_CREDENTIALS_FILE)) - creds = GoogleCredentials.from_stream(credentials_file) + creds1 = GoogleCredentials.from_stream(credentials_file) + # Convert to and then back from json. + creds2 = GoogleCredentials.from_json(creds1.to_json()) - json = creds.to_json() - creds2 = GoogleCredentials.from_json(json) - - self.assertEqual(creds.__dict__, creds2.__dict__) + creds1_vals = creds1.__dict__ + creds1_vals.pop('_signer') + creds2_vals = creds2.__dict__ + creds2_vals.pop('_signer') + self.assertEqual(creds1_vals, creds2_vals) def test_parse_expiry(self): dt = datetime.datetime(2016, 1, 1) diff --git a/tests/test_service_account.py b/tests/test_service_account.py index 09d6234..df48f31 100644 --- a/tests/test_service_account.py +++ b/tests/test_service_account.py @@ -91,13 +91,29 @@ class ServiceAccountCredentialsTests(unittest.TestCase): self.assertEqual('dummy_scope', new_credentials._scopes) @mock.patch('oauth2client.client._UTCNOW') - @mock.patch('rsa.pkcs1.sign', return_value=b'signed-value') - def test_access_token(self, sign_func, utcnow): + def test_access_token(self, utcnow): # Configure the patch. seconds = 11 NOW = datetime.datetime(1992, 12, 31, second=seconds) utcnow.return_value = NOW + # Create a custom credentials with a mock signer. + signer = mock.MagicMock() + signed_value = b'signed-content' + signer.sign = mock.MagicMock(name='sign', + return_value=signed_value) + signer_patch = mock.patch('oauth2client.crypt.Signer.from_string', + return_value=signer) + with signer_patch as signer_factory: + credentials = _ServiceAccountCredentials( + self.service_account_id, + self.service_account_email, + self.private_key_id, + self.private_key, + '', + ) + + # Begin testing. lifetime = 2 # number of seconds in which the token expires EXPIRY_TIME = datetime.datetime(1992, 12, 31, second=seconds + lifetime) @@ -120,51 +136,51 @@ class ServiceAccountCredentialsTests(unittest.TestCase): ]) # Get Access Token, First attempt. - self.assertEqual(self.credentials.access_token, None) - self.assertFalse(self.credentials.access_token_expired) - self.assertEqual(self.credentials.token_expiry, None) - token = self.credentials.get_access_token(http=http) - self.assertEqual(self.credentials.token_expiry, EXPIRY_TIME) + self.assertEqual(credentials.access_token, None) + self.assertFalse(credentials.access_token_expired) + self.assertEqual(credentials.token_expiry, None) + token = credentials.get_access_token(http=http) + self.assertEqual(credentials.token_expiry, EXPIRY_TIME) self.assertEqual(token1, token.access_token) self.assertEqual(lifetime, token.expires_in) self.assertEqual(token_response_first, - self.credentials.token_response) + credentials.token_response) # Two utcnow calls are expected: # - get_access_token() -> _do_refresh_request (setting expires in) # - get_access_token() -> _expires_in() expected_utcnow_calls = [mock.call()] * 2 self.assertEqual(expected_utcnow_calls, utcnow.mock_calls) - # One rsa.pkcs1.sign expected: Actual refresh was needed. - self.assertEqual(len(sign_func.mock_calls), 1) + # One call to sign() expected: Actual refresh was needed. + self.assertEqual(len(signer.sign.mock_calls), 1) # Get Access Token, Second Attempt (not expired) - self.assertEqual(self.credentials.access_token, token1) - self.assertFalse(self.credentials.access_token_expired) - token = self.credentials.get_access_token(http=http) + self.assertEqual(credentials.access_token, token1) + self.assertFalse(credentials.access_token_expired) + token = credentials.get_access_token(http=http) # Make sure no refresh occurred since the token was not expired. self.assertEqual(token1, token.access_token) self.assertEqual(lifetime, token.expires_in) - self.assertEqual(token_response_first, self.credentials.token_response) + self.assertEqual(token_response_first, credentials.token_response) # Three more utcnow calls are expected: # - access_token_expired # - get_access_token() -> access_token_expired # - get_access_token -> _expires_in expected_utcnow_calls = [mock.call()] * (2 + 3) self.assertEqual(expected_utcnow_calls, utcnow.mock_calls) - # No rsa.pkcs1.sign expected: the token was not expired. - self.assertEqual(len(sign_func.mock_calls), 1 + 0) + # No call to sign() expected: the token was not expired. + self.assertEqual(len(signer.sign.mock_calls), 1 + 0) # Get Access Token, Third Attempt (force expiration) - self.assertEqual(self.credentials.access_token, token1) - self.credentials.token_expiry = NOW # Manually force expiry. - self.assertTrue(self.credentials.access_token_expired) - token = self.credentials.get_access_token(http=http) + self.assertEqual(credentials.access_token, token1) + credentials.token_expiry = NOW # Manually force expiry. + self.assertTrue(credentials.access_token_expired) + token = credentials.get_access_token(http=http) # Make sure refresh occurred since the token was not expired. self.assertEqual(token2, token.access_token) self.assertEqual(lifetime, token.expires_in) - self.assertFalse(self.credentials.access_token_expired) + self.assertFalse(credentials.access_token_expired) self.assertEqual(token_response_second, - self.credentials.token_response) + credentials.token_response) # Five more utcnow calls are expected: # - access_token_expired # - get_access_token -> access_token_expired @@ -173,10 +189,10 @@ class ServiceAccountCredentialsTests(unittest.TestCase): # - access_token_expired expected_utcnow_calls = [mock.call()] * (2 + 3 + 5) self.assertEqual(expected_utcnow_calls, utcnow.mock_calls) - # One more rsa.pkcs1.sign expected: Actual refresh was needed. - self.assertEqual(len(sign_func.mock_calls), 1 + 0 + 1) + # One more call to sign() expected: Actual refresh was needed. + self.assertEqual(len(signer.sign.mock_calls), 1 + 0 + 1) - self.assertEqual(self.credentials.access_token, token2) + self.assertEqual(credentials.access_token, token2) if __name__ == '__main__': # pragma: NO COVER