From d13fc7dbbd00ce411f734653fb711e48dff5ba6d Mon Sep 17 00:00:00 2001 From: Danny Hermes Date: Wed, 17 Feb 2016 14:22:10 -0800 Subject: [PATCH] Implement ServiceAccountCredentials.from_p12_keyfile_buffer(). Fixes #412. --- oauth2client/service_account.py | 71 +++++++++++++++++++++++++++++---- tests/test_service_account.py | 34 +++++++++++++--- 2 files changed, 92 insertions(+), 13 deletions(-) diff --git a/oauth2client/service_account.py b/oauth2client/service_account.py index f18f192..3c9bffe 100644 --- a/oauth2client/service_account.py +++ b/oauth2client/service_account.py @@ -215,6 +215,38 @@ class ServiceAccountCredentials(AssertionCredentials): """ return cls._from_parsed_json_keyfile(keyfile_dict, scopes) + @classmethod + def _from_p12_keyfile_contents(cls, service_account_email, + private_key_pkcs12, + private_key_password=None, scopes=''): + """Factory constructor from JSON keyfile. + + Args: + service_account_email: string, The email associated with the + service account. + private_key_pkcs12: string, The contents of a PKCS#12 keyfile. + private_key_password: string, (Optional) Password for PKCS#12 + private key. Defaults to ``notasecret``. + scopes: List or string, (Optional) Scopes to use when acquiring an + access token. + + Returns: + ServiceAccountCredentials, a credentials object created from + the keyfile. + + Raises: + NotImplementedError if pyOpenSSL is not installed / not the + active crypto library. + """ + if private_key_password is None: + private_key_password = _PASSWORD_DEFAULT + signer = crypt.Signer.from_string(private_key_pkcs12, + private_key_password) + credentials = cls(service_account_email, signer, scopes=scopes) + credentials._private_key_pkcs12 = private_key_pkcs12 + credentials._private_key_password = private_key_password + return credentials + @classmethod def from_p12_keyfile(cls, service_account_email, filename, private_key_password=None, scopes=''): @@ -239,14 +271,37 @@ class ServiceAccountCredentials(AssertionCredentials): """ with open(filename, 'rb') as file_obj: private_key_pkcs12 = file_obj.read() - if private_key_password is None: - private_key_password = _PASSWORD_DEFAULT - signer = crypt.Signer.from_string(private_key_pkcs12, - private_key_password) - credentials = cls(service_account_email, signer, scopes=scopes) - credentials._private_key_pkcs12 = private_key_pkcs12 - credentials._private_key_password = private_key_password - return credentials + return cls._from_p12_keyfile_contents( + service_account_email, private_key_pkcs12, + private_key_password=private_key_password, scopes=scopes) + + @classmethod + def from_p12_keyfile_buffer(cls, service_account_email, file_buffer, + private_key_password=None, scopes=''): + """Factory constructor from JSON keyfile. + + Args: + service_account_email: string, The email associated with the + service account. + file_buffer: stream, A buffer that implements ``read()`` + and contains the PKCS#12 key contents. + private_key_password: string, (Optional) Password for PKCS#12 + private key. Defaults to ``notasecret``. + scopes: List or string, (Optional) Scopes to use when acquiring an + access token. + + Returns: + ServiceAccountCredentials, a credentials object created from + the keyfile. + + Raises: + NotImplementedError if pyOpenSSL is not installed / not the + active crypto library. + """ + private_key_pkcs12 = file_buffer.read() + return cls._from_p12_keyfile_contents( + service_account_email, private_key_pkcs12, + private_key_password=private_key_password, scopes=scopes) def _generate_assertion(self): """Generate the assertion that will be used in the request.""" diff --git a/tests/test_service_account.py b/tests/test_service_account.py index 3c91d19..0b03319 100644 --- a/tests/test_service_account.py +++ b/tests/test_service_account.py @@ -156,10 +156,10 @@ class ServiceAccountCredentialsTests(unittest2.TestCase): private_key_password=private_key_password, scopes=scopes) self.assertIsInstance(creds, ServiceAccountCredentials) - self.assertEqual(creds.client_id, None) + self.assertIsNone(creds.client_id) self.assertEqual(creds._service_account_email, service_account_email) - self.assertEqual(creds._private_key_id, None) - self.assertEqual(creds._private_key_pkcs8_pem, None) + self.assertIsNone(creds._private_key_id) + self.assertIsNone(creds._private_key_pkcs8_pem) self.assertEqual(creds._private_key_pkcs12, key_contents) if private_key_password is not None: self.assertEqual(creds._private_key_password, private_key_password) @@ -173,6 +173,30 @@ class ServiceAccountCredentialsTests(unittest2.TestCase): self._from_p12_keyfile_helper(private_key_password=password, scopes=['foo', 'bar']) + def test_from_p12_keyfile_buffer(self): + service_account_email = 'name@email.com' + filename = data_filename('privatekey.p12') + private_key_password = 'notasecret' + scopes = ['foo', 'bar'] + with open(filename, 'rb') as file_obj: + key_contents = file_obj.read() + # Seek back to the beginning so the buffer can be + # passed to the constructor. + file_obj.seek(0) + creds = ServiceAccountCredentials.from_p12_keyfile_buffer( + service_account_email, file_obj, + private_key_password=private_key_password, + scopes=scopes) + # Check the created object. + self.assertIsInstance(creds, ServiceAccountCredentials) + self.assertIsNone(creds.client_id) + self.assertEqual(creds._service_account_email, service_account_email) + self.assertIsNone(creds._private_key_id) + self.assertIsNone(creds._private_key_pkcs8_pem) + self.assertEqual(creds._private_key_pkcs12, key_contents) + self.assertEqual(creds._private_key_password, private_key_password) + self.assertEqual(creds._scopes, ' '.join(scopes)) + def test_create_scoped_required_without_scopes(self): self.assertTrue(self.credentials.create_scoped_required()) @@ -236,9 +260,9 @@ class ServiceAccountCredentialsTests(unittest2.TestCase): ]) # Get Access Token, First attempt. - self.assertEqual(credentials.access_token, None) + self.assertIsNone(credentials.access_token) self.assertFalse(credentials.access_token_expired) - self.assertEqual(credentials.token_expiry, None) + self.assertIsNone(credentials.token_expiry) token = credentials.get_access_token(http=http) self.assertEqual(credentials.token_expiry, EXPIRY_TIME) self.assertEqual(token1, token.access_token)