Removing novel crypto in service_account.py.
This commit is contained in:
@@ -112,7 +112,7 @@ class OpenSSLSigner(object):
|
||||
Raises:
|
||||
OpenSSL.crypto.Error if the key can't be parsed.
|
||||
"""
|
||||
parsed_pem_key = _parse_pem_key(key)
|
||||
parsed_pem_key = _parse_pem_key(_to_bytes(key))
|
||||
if parsed_pem_key:
|
||||
pkey = crypto.load_privatekey(crypto.FILETYPE_PEM, parsed_pem_key)
|
||||
else:
|
||||
|
||||
@@ -71,7 +71,7 @@ else: # pragma: NO COVER
|
||||
Verifier = RsaVerifier
|
||||
|
||||
|
||||
def make_signed_jwt(signer, payload):
|
||||
def make_signed_jwt(signer, payload, key_id=None):
|
||||
"""Make a signed JWT.
|
||||
|
||||
See http://self-issued.info/docs/draft-jones-json-web-token.html.
|
||||
@@ -79,11 +79,14 @@ def make_signed_jwt(signer, payload):
|
||||
Args:
|
||||
signer: crypt.Signer, Cryptographic signer.
|
||||
payload: dict, Dictionary of data to convert to JSON and then sign.
|
||||
key_id: string, (Optional) Key ID header.
|
||||
|
||||
Returns:
|
||||
string, The JWT for the payload.
|
||||
"""
|
||||
header = {'typ': 'JWT', 'alg': 'RS256'}
|
||||
if key_id is not None:
|
||||
header['kid'] = key_id
|
||||
|
||||
segments = [
|
||||
_urlsafe_b64encode(_json_encode(header)),
|
||||
@@ -91,7 +94,7 @@ def make_signed_jwt(signer, payload):
|
||||
]
|
||||
signing_input = b'.'.join(segments)
|
||||
|
||||
signature = signer.sign(signing_input)
|
||||
signature = signer.sign(signing_input).rstrip(b'=')
|
||||
segments.append(_urlsafe_b64encode(signature))
|
||||
|
||||
logger.debug(str(segments))
|
||||
|
||||
@@ -12,20 +12,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""A service account credentials class.
|
||||
|
||||
This credentials class is implemented on top of rsa library.
|
||||
"""
|
||||
"""oauth2client Service account credentials class."""
|
||||
|
||||
import base64
|
||||
import datetime
|
||||
import json
|
||||
import time
|
||||
|
||||
from pyasn1.codec.ber import decoder
|
||||
from pyasn1_modules.rfc5208 import PrivateKeyInfo
|
||||
import rsa
|
||||
|
||||
from oauth2client import GOOGLE_REVOKE_URI
|
||||
from oauth2client import GOOGLE_TOKEN_URI
|
||||
from oauth2client._helpers import _json_encode
|
||||
@@ -35,6 +28,7 @@ from oauth2client._helpers import _urlsafe_b64encode
|
||||
from oauth2client import util
|
||||
from oauth2client.client import AssertionCredentials
|
||||
from oauth2client.client import EXPIRY_FORMAT
|
||||
from oauth2client import crypt
|
||||
|
||||
|
||||
class _ServiceAccountCredentials(AssertionCredentials):
|
||||
@@ -43,10 +37,9 @@ class _ServiceAccountCredentials(AssertionCredentials):
|
||||
MAX_TOKEN_LIFETIME_SECS = 3600 # 1 hour in seconds
|
||||
|
||||
NON_SERIALIZED_MEMBERS = (
|
||||
frozenset(['_private_key']) |
|
||||
frozenset(['_signer']) |
|
||||
AssertionCredentials.NON_SERIALIZED_MEMBERS)
|
||||
|
||||
|
||||
def __init__(self, service_account_id, service_account_email,
|
||||
private_key_id, private_key_pkcs8_text, scopes,
|
||||
user_agent=None, token_uri=GOOGLE_TOKEN_URI,
|
||||
@@ -59,8 +52,8 @@ class _ServiceAccountCredentials(AssertionCredentials):
|
||||
self._service_account_id = service_account_id
|
||||
self._service_account_email = service_account_email
|
||||
self._private_key_id = private_key_id
|
||||
self._private_key = _get_private_key(private_key_pkcs8_text)
|
||||
self._private_key_pkcs8_text = private_key_pkcs8_text
|
||||
self._signer = crypt.Signer.from_string(self._private_key_pkcs8_text)
|
||||
self._scopes = util.scopes_to_string(scopes)
|
||||
self._user_agent = user_agent
|
||||
self._token_uri = token_uri
|
||||
@@ -69,39 +62,20 @@ class _ServiceAccountCredentials(AssertionCredentials):
|
||||
|
||||
def _generate_assertion(self):
|
||||
"""Generate the assertion that will be used in the request."""
|
||||
|
||||
header = {
|
||||
'alg': 'RS256',
|
||||
'typ': 'JWT',
|
||||
'kid': self._private_key_id
|
||||
}
|
||||
|
||||
now = int(time.time())
|
||||
payload = {
|
||||
'aud': self._token_uri,
|
||||
'scope': self._scopes,
|
||||
'iat': now,
|
||||
'exp': now + _ServiceAccountCredentials.MAX_TOKEN_LIFETIME_SECS,
|
||||
'iss': self._service_account_email
|
||||
'exp': now + self.MAX_TOKEN_LIFETIME_SECS,
|
||||
'iss': self._service_account_email,
|
||||
}
|
||||
payload.update(self._kwargs)
|
||||
|
||||
first_segment = _urlsafe_b64encode(_json_encode(header))
|
||||
second_segment = _urlsafe_b64encode(_json_encode(payload))
|
||||
assertion_input = first_segment + b'.' + second_segment
|
||||
|
||||
# Sign the assertion.
|
||||
rsa_bytes = rsa.pkcs1.sign(assertion_input, self._private_key,
|
||||
'SHA-256')
|
||||
signature = base64.urlsafe_b64encode(rsa_bytes).rstrip(b'=')
|
||||
|
||||
return assertion_input + b'.' + signature
|
||||
return crypt.make_signed_jwt(self._signer, payload,
|
||||
key_id=self._private_key_id)
|
||||
|
||||
def sign_blob(self, blob):
|
||||
# Ensure that it is bytes
|
||||
blob = _to_bytes(blob, encoding='utf-8')
|
||||
return (self._private_key_id,
|
||||
rsa.pkcs1.sign(blob, self._private_key, 'SHA-256'))
|
||||
return self._private_key_id, self._signer.sign(blob)
|
||||
|
||||
@property
|
||||
def service_account_email(self):
|
||||
@@ -149,13 +123,3 @@ class _ServiceAccountCredentials(AssertionCredentials):
|
||||
token_uri=self._token_uri,
|
||||
revoke_uri=self._revoke_uri,
|
||||
**self._kwargs)
|
||||
|
||||
|
||||
def _get_private_key(private_key_pkcs8_text):
|
||||
"""Get an RSA private key object from a pkcs8 representation."""
|
||||
private_key_pkcs8_text = _to_bytes(private_key_pkcs8_text)
|
||||
der = rsa.pem.load_pem(private_key_pkcs8_text, 'PRIVATE KEY')
|
||||
asn1_private_key, _ = decoder.decode(der, asn1Spec=PrivateKeyInfo())
|
||||
return rsa.PrivateKey.load_pkcs1(
|
||||
asn1_private_key.getComponentByName('privateKey').asOctets(),
|
||||
format='DER')
|
||||
|
||||
@@ -628,15 +628,17 @@ class GoogleCredentialsTests(unittest2.TestCase):
|
||||
self.assertEqual(creds.__dict__, creds2.__dict__)
|
||||
|
||||
def test_to_from_json_service_account(self):
|
||||
self.maxDiff=None
|
||||
credentials_file = datafile(
|
||||
os.path.join('gcloud', _WELL_KNOWN_CREDENTIALS_FILE))
|
||||
creds = GoogleCredentials.from_stream(credentials_file)
|
||||
creds1 = GoogleCredentials.from_stream(credentials_file)
|
||||
# Convert to and then back from json.
|
||||
creds2 = GoogleCredentials.from_json(creds1.to_json())
|
||||
|
||||
json = creds.to_json()
|
||||
creds2 = GoogleCredentials.from_json(json)
|
||||
|
||||
self.assertEqual(creds.__dict__, creds2.__dict__)
|
||||
creds1_vals = creds1.__dict__
|
||||
creds1_vals.pop('_signer')
|
||||
creds2_vals = creds2.__dict__
|
||||
creds2_vals.pop('_signer')
|
||||
self.assertEqual(creds1_vals, creds2_vals)
|
||||
|
||||
def test_parse_expiry(self):
|
||||
dt = datetime.datetime(2016, 1, 1)
|
||||
|
||||
@@ -91,13 +91,29 @@ class ServiceAccountCredentialsTests(unittest.TestCase):
|
||||
self.assertEqual('dummy_scope', new_credentials._scopes)
|
||||
|
||||
@mock.patch('oauth2client.client._UTCNOW')
|
||||
@mock.patch('rsa.pkcs1.sign', return_value=b'signed-value')
|
||||
def test_access_token(self, sign_func, utcnow):
|
||||
def test_access_token(self, utcnow):
|
||||
# Configure the patch.
|
||||
seconds = 11
|
||||
NOW = datetime.datetime(1992, 12, 31, second=seconds)
|
||||
utcnow.return_value = NOW
|
||||
|
||||
# Create a custom credentials with a mock signer.
|
||||
signer = mock.MagicMock()
|
||||
signed_value = b'signed-content'
|
||||
signer.sign = mock.MagicMock(name='sign',
|
||||
return_value=signed_value)
|
||||
signer_patch = mock.patch('oauth2client.crypt.Signer.from_string',
|
||||
return_value=signer)
|
||||
with signer_patch as signer_factory:
|
||||
credentials = _ServiceAccountCredentials(
|
||||
self.service_account_id,
|
||||
self.service_account_email,
|
||||
self.private_key_id,
|
||||
self.private_key,
|
||||
'',
|
||||
)
|
||||
|
||||
# Begin testing.
|
||||
lifetime = 2 # number of seconds in which the token expires
|
||||
EXPIRY_TIME = datetime.datetime(1992, 12, 31,
|
||||
second=seconds + lifetime)
|
||||
@@ -120,51 +136,51 @@ class ServiceAccountCredentialsTests(unittest.TestCase):
|
||||
])
|
||||
|
||||
# Get Access Token, First attempt.
|
||||
self.assertEqual(self.credentials.access_token, None)
|
||||
self.assertFalse(self.credentials.access_token_expired)
|
||||
self.assertEqual(self.credentials.token_expiry, None)
|
||||
token = self.credentials.get_access_token(http=http)
|
||||
self.assertEqual(self.credentials.token_expiry, EXPIRY_TIME)
|
||||
self.assertEqual(credentials.access_token, None)
|
||||
self.assertFalse(credentials.access_token_expired)
|
||||
self.assertEqual(credentials.token_expiry, None)
|
||||
token = credentials.get_access_token(http=http)
|
||||
self.assertEqual(credentials.token_expiry, EXPIRY_TIME)
|
||||
self.assertEqual(token1, token.access_token)
|
||||
self.assertEqual(lifetime, token.expires_in)
|
||||
self.assertEqual(token_response_first,
|
||||
self.credentials.token_response)
|
||||
credentials.token_response)
|
||||
# Two utcnow calls are expected:
|
||||
# - get_access_token() -> _do_refresh_request (setting expires in)
|
||||
# - get_access_token() -> _expires_in()
|
||||
expected_utcnow_calls = [mock.call()] * 2
|
||||
self.assertEqual(expected_utcnow_calls, utcnow.mock_calls)
|
||||
# One rsa.pkcs1.sign expected: Actual refresh was needed.
|
||||
self.assertEqual(len(sign_func.mock_calls), 1)
|
||||
# One call to sign() expected: Actual refresh was needed.
|
||||
self.assertEqual(len(signer.sign.mock_calls), 1)
|
||||
|
||||
# Get Access Token, Second Attempt (not expired)
|
||||
self.assertEqual(self.credentials.access_token, token1)
|
||||
self.assertFalse(self.credentials.access_token_expired)
|
||||
token = self.credentials.get_access_token(http=http)
|
||||
self.assertEqual(credentials.access_token, token1)
|
||||
self.assertFalse(credentials.access_token_expired)
|
||||
token = credentials.get_access_token(http=http)
|
||||
# Make sure no refresh occurred since the token was not expired.
|
||||
self.assertEqual(token1, token.access_token)
|
||||
self.assertEqual(lifetime, token.expires_in)
|
||||
self.assertEqual(token_response_first, self.credentials.token_response)
|
||||
self.assertEqual(token_response_first, credentials.token_response)
|
||||
# Three more utcnow calls are expected:
|
||||
# - access_token_expired
|
||||
# - get_access_token() -> access_token_expired
|
||||
# - get_access_token -> _expires_in
|
||||
expected_utcnow_calls = [mock.call()] * (2 + 3)
|
||||
self.assertEqual(expected_utcnow_calls, utcnow.mock_calls)
|
||||
# No rsa.pkcs1.sign expected: the token was not expired.
|
||||
self.assertEqual(len(sign_func.mock_calls), 1 + 0)
|
||||
# No call to sign() expected: the token was not expired.
|
||||
self.assertEqual(len(signer.sign.mock_calls), 1 + 0)
|
||||
|
||||
# Get Access Token, Third Attempt (force expiration)
|
||||
self.assertEqual(self.credentials.access_token, token1)
|
||||
self.credentials.token_expiry = NOW # Manually force expiry.
|
||||
self.assertTrue(self.credentials.access_token_expired)
|
||||
token = self.credentials.get_access_token(http=http)
|
||||
self.assertEqual(credentials.access_token, token1)
|
||||
credentials.token_expiry = NOW # Manually force expiry.
|
||||
self.assertTrue(credentials.access_token_expired)
|
||||
token = credentials.get_access_token(http=http)
|
||||
# Make sure refresh occurred since the token was not expired.
|
||||
self.assertEqual(token2, token.access_token)
|
||||
self.assertEqual(lifetime, token.expires_in)
|
||||
self.assertFalse(self.credentials.access_token_expired)
|
||||
self.assertFalse(credentials.access_token_expired)
|
||||
self.assertEqual(token_response_second,
|
||||
self.credentials.token_response)
|
||||
credentials.token_response)
|
||||
# Five more utcnow calls are expected:
|
||||
# - access_token_expired
|
||||
# - get_access_token -> access_token_expired
|
||||
@@ -173,10 +189,10 @@ class ServiceAccountCredentialsTests(unittest.TestCase):
|
||||
# - access_token_expired
|
||||
expected_utcnow_calls = [mock.call()] * (2 + 3 + 5)
|
||||
self.assertEqual(expected_utcnow_calls, utcnow.mock_calls)
|
||||
# One more rsa.pkcs1.sign expected: Actual refresh was needed.
|
||||
self.assertEqual(len(sign_func.mock_calls), 1 + 0 + 1)
|
||||
# One more call to sign() expected: Actual refresh was needed.
|
||||
self.assertEqual(len(signer.sign.mock_calls), 1 + 0 + 1)
|
||||
|
||||
self.assertEqual(self.credentials.access_token, token2)
|
||||
self.assertEqual(credentials.access_token, token2)
|
||||
|
||||
|
||||
if __name__ == '__main__': # pragma: NO COVER
|
||||
|
||||
Reference in New Issue
Block a user