diff --git a/oauth2client/service_account.py b/oauth2client/service_account.py index 8fca83b..ce7f78e 100644 --- a/oauth2client/service_account.py +++ b/oauth2client/service_account.py @@ -81,6 +81,12 @@ class ServiceAccountCredentials(AssertionCredentials): service account. user_agent: string, (Optional) User agent to use when sending request. + token_uri: string, URI for token endpoint. For convenience defaults + to Google's endpoints but any OAuth 2.0 provider can be + used. + revoke_uri: string, URI for revoke endpoint. For convenience defaults + to Google's endpoints but any OAuth 2.0 provider can be + used. kwargs: dict, Extra key-value pairs (both strings) to send in the payload body when making an assertion. """ @@ -106,10 +112,13 @@ class ServiceAccountCredentials(AssertionCredentials): private_key_id=None, client_id=None, user_agent=None, + token_uri=GOOGLE_TOKEN_URI, + revoke_uri=GOOGLE_REVOKE_URI, **kwargs): super(ServiceAccountCredentials, self).__init__( - None, user_agent=user_agent) + None, user_agent=user_agent, token_uri=token_uri, + revoke_uri=revoke_uri) self._service_account_email = service_account_email self._signer = signer @@ -145,7 +154,8 @@ class ServiceAccountCredentials(AssertionCredentials): strip, to_serialize=to_serialize) @classmethod - def _from_parsed_json_keyfile(cls, keyfile_dict, scopes): + def _from_parsed_json_keyfile(cls, keyfile_dict, scopes, + token_uri=None, revoke_uri=None): """Helper for factory constructors from JSON keyfile. Args: @@ -153,6 +163,12 @@ class ServiceAccountCredentials(AssertionCredentials): containing the contents of the JSON keyfile. scopes: List or string, Scopes to use when acquiring an access token. + token_uri: string, URI for OAuth 2.0 provider token endpoint. + If unset and not present in keyfile_dict, defaults + to Google's endpoints. + revoke_uri: string, URI for OAuth 2.0 provider revoke endpoint. + If unset and not present in keyfile_dict, defaults + to Google's endpoints. Returns: ServiceAccountCredentials, a credentials object created from @@ -172,22 +188,35 @@ class ServiceAccountCredentials(AssertionCredentials): private_key_pkcs8_pem = keyfile_dict['private_key'] private_key_id = keyfile_dict['private_key_id'] client_id = keyfile_dict['client_id'] + if not token_uri: + token_uri = keyfile_dict.get('token_uri', GOOGLE_TOKEN_URI) + if not revoke_uri: + revoke_uri = keyfile_dict.get('revoke_uri', GOOGLE_REVOKE_URI) signer = crypt.Signer.from_string(private_key_pkcs8_pem) credentials = cls(service_account_email, signer, scopes=scopes, private_key_id=private_key_id, - client_id=client_id) + client_id=client_id, token_uri=token_uri, + revoke_uri=revoke_uri) credentials._private_key_pkcs8_pem = private_key_pkcs8_pem return credentials @classmethod - def from_json_keyfile_name(cls, filename, scopes=''): + def from_json_keyfile_name(cls, filename, scopes='', + token_uri=None, revoke_uri=None): + """Factory constructor from JSON keyfile by name. Args: filename: string, The location of the keyfile. scopes: List or string, (Optional) Scopes to use when acquiring an access token. + token_uri: string, URI for OAuth 2.0 provider token endpoint. + If unset and not present in the key file, defaults + to Google's endpoints. + revoke_uri: string, URI for OAuth 2.0 provider revoke endpoint. + If unset and not present in the key file, defaults + to Google's endpoints. Returns: ServiceAccountCredentials, a credentials object created from @@ -200,10 +229,13 @@ class ServiceAccountCredentials(AssertionCredentials): """ with open(filename, 'r') as file_obj: client_credentials = json.load(file_obj) - return cls._from_parsed_json_keyfile(client_credentials, scopes) + return cls._from_parsed_json_keyfile(client_credentials, scopes, + token_uri=token_uri, + revoke_uri=revoke_uri) @classmethod - def from_json_keyfile_dict(cls, keyfile_dict, scopes=''): + def from_json_keyfile_dict(cls, keyfile_dict, scopes='', + token_uri=None, revoke_uri=None): """Factory constructor from parsed JSON keyfile. Args: @@ -211,6 +243,12 @@ class ServiceAccountCredentials(AssertionCredentials): containing the contents of the JSON keyfile. scopes: List or string, (Optional) Scopes to use when acquiring an access token. + token_uri: string, URI for OAuth 2.0 provider token endpoint. + If unset and not present in keyfile_dict, defaults + to Google's endpoints. + revoke_uri: string, URI for OAuth 2.0 provider revoke endpoint. + If unset and not present in keyfile_dict, defaults + to Google's endpoints. Returns: ServiceAccountCredentials, a credentials object created from @@ -221,12 +259,16 @@ class ServiceAccountCredentials(AssertionCredentials): KeyError, if one of the expected keys is not present in the keyfile. """ - return cls._from_parsed_json_keyfile(keyfile_dict, scopes) + return cls._from_parsed_json_keyfile(keyfile_dict, scopes, + token_uri=token_uri, + revoke_uri=revoke_uri) @classmethod def _from_p12_keyfile_contents(cls, service_account_email, private_key_pkcs12, - private_key_password=None, scopes=''): + private_key_password=None, scopes='', + token_uri=GOOGLE_TOKEN_URI, + revoke_uri=GOOGLE_REVOKE_URI): """Factory constructor from JSON keyfile. Args: @@ -237,6 +279,12 @@ class ServiceAccountCredentials(AssertionCredentials): private key. Defaults to ``notasecret``. scopes: List or string, (Optional) Scopes to use when acquiring an access token. + token_uri: string, URI for token endpoint. For convenience defaults + to Google's endpoints but any OAuth 2.0 provider can be + used. + revoke_uri: string, URI for revoke endpoint. For convenience + defaults to Google's endpoints but any OAuth 2.0 + provider can be used. Returns: ServiceAccountCredentials, a credentials object created from @@ -252,14 +300,18 @@ class ServiceAccountCredentials(AssertionCredentials): raise NotImplementedError(_PKCS12_ERROR) signer = crypt.Signer.from_string(private_key_pkcs12, private_key_password) - credentials = cls(service_account_email, signer, scopes=scopes) + credentials = cls(service_account_email, signer, scopes=scopes, + token_uri=token_uri, revoke_uri=revoke_uri) credentials._private_key_pkcs12 = private_key_pkcs12 credentials._private_key_password = private_key_password return credentials @classmethod def from_p12_keyfile(cls, service_account_email, filename, - private_key_password=None, scopes=''): + private_key_password=None, scopes='', + token_uri=GOOGLE_TOKEN_URI, + revoke_uri=GOOGLE_REVOKE_URI): + """Factory constructor from JSON keyfile. Args: @@ -270,6 +322,12 @@ class ServiceAccountCredentials(AssertionCredentials): private key. Defaults to ``notasecret``. scopes: List or string, (Optional) Scopes to use when acquiring an access token. + token_uri: string, URI for token endpoint. For convenience defaults + to Google's endpoints but any OAuth 2.0 provider can be + used. + revoke_uri: string, URI for revoke endpoint. For convenience + defaults to Google's endpoints but any OAuth 2.0 + provider can be used. Returns: ServiceAccountCredentials, a credentials object created from @@ -283,11 +341,14 @@ class ServiceAccountCredentials(AssertionCredentials): private_key_pkcs12 = file_obj.read() return cls._from_p12_keyfile_contents( service_account_email, private_key_pkcs12, - private_key_password=private_key_password, scopes=scopes) + private_key_password=private_key_password, scopes=scopes, + token_uri=token_uri, revoke_uri=revoke_uri) @classmethod def from_p12_keyfile_buffer(cls, service_account_email, file_buffer, - private_key_password=None, scopes=''): + private_key_password=None, scopes='', + token_uri=GOOGLE_TOKEN_URI, + revoke_uri=GOOGLE_REVOKE_URI): """Factory constructor from JSON keyfile. Args: @@ -299,6 +360,12 @@ class ServiceAccountCredentials(AssertionCredentials): private key. Defaults to ``notasecret``. scopes: List or string, (Optional) Scopes to use when acquiring an access token. + token_uri: string, URI for token endpoint. For convenience defaults + to Google's endpoints but any OAuth 2.0 provider can be + used. + revoke_uri: string, URI for revoke endpoint. For convenience + defaults to Google's endpoints but any OAuth 2.0 + provider can be used. Returns: ServiceAccountCredentials, a credentials object created from @@ -311,7 +378,8 @@ class ServiceAccountCredentials(AssertionCredentials): private_key_pkcs12 = file_buffer.read() return cls._from_p12_keyfile_contents( service_account_email, private_key_pkcs12, - private_key_password=private_key_password, scopes=scopes) + private_key_password=private_key_password, scopes=scopes, + token_uri=token_uri, revoke_uri=revoke_uri) def _generate_assertion(self): """Generate the assertion that will be used in the request.""" @@ -508,6 +576,8 @@ class _JWTAccessCredentials(ServiceAccountCredentials): private_key_id=None, client_id=None, user_agent=None, + token_uri=GOOGLE_TOKEN_URI, + revoke_uri=GOOGLE_REVOKE_URI, additional_claims=None): if additional_claims is None: additional_claims = {} @@ -517,6 +587,8 @@ class _JWTAccessCredentials(ServiceAccountCredentials): private_key_id=private_key_id, client_id=client_id, user_agent=user_agent, + token_uri=token_uri, + revoke_uri=revoke_uri, **additional_claims) def authorize(self, http): @@ -595,7 +667,8 @@ class _JWTAccessCredentials(ServiceAccountCredentials): # JWTAccessCredentials are unscoped by definition return True - def create_scoped(self, scopes): + def create_scoped(self, scopes, token_uri=GOOGLE_TOKEN_URI, + revoke_uri=GOOGLE_REVOKE_URI): # Returns an OAuth2 credentials with the given scope result = ServiceAccountCredentials(self._service_account_email, self._signer, @@ -603,9 +676,9 @@ class _JWTAccessCredentials(ServiceAccountCredentials): private_key_id=self._private_key_id, client_id=self.client_id, user_agent=self._user_agent, + token_uri=token_uri, + revoke_uri=revoke_uri, **self._kwargs) - result.token_uri = self.token_uri - result.revoke_uri = self.revoke_uri if self._private_key_pkcs8_pem is not None: result._private_key_pkcs8_pem = self._private_key_pkcs8_pem if self._private_key_pkcs12 is not None: diff --git a/tests/test_service_account.py b/tests/test_service_account.py index 9697bc5..fe9b795 100644 --- a/tests/test_service_account.py +++ b/tests/test_service_account.py @@ -33,6 +33,8 @@ from oauth2client.service_account import _JWTAccessCredentials from oauth2client.service_account import ServiceAccountCredentials from oauth2client.service_account import SERVICE_ACCOUNT +from six import BytesIO + def data_filename(filename): return os.path.join(os.path.dirname(__file__), 'data', filename) @@ -96,14 +98,16 @@ class ServiceAccountCredentialsTests(unittest2.TestCase): self.credentials.service_account_email) @staticmethod - def _from_json_keyfile_name_helper(payload, scopes=None): + def _from_json_keyfile_name_helper(payload, scopes=None, + token_uri=None, revoke_uri=None): filehandle, filename = tempfile.mkstemp() os.close(filehandle) try: with open(filename, 'w') as file_obj: json.dump(payload, file_obj) return ServiceAccountCredentials.from_json_keyfile_name( - filename, scopes=scopes) + filename, scopes=scopes, token_uri=token_uri, + revoke_uri=revoke_uri) finally: os.remove(filename) @@ -122,17 +126,27 @@ class ServiceAccountCredentialsTests(unittest2.TestCase): 'private_key': private_key, } scopes = ['foo', 'bar'] - creds = self._from_json_keyfile_name_helper(payload, scopes=scopes) - self.assertIsInstance(creds, ServiceAccountCredentials) - self.assertEqual(creds.client_id, client_id) - self.assertEqual(creds._service_account_email, client_email) - self.assertEqual(creds._private_key_id, private_key_id) - self.assertEqual(creds._private_key_pkcs8_pem, private_key) - self.assertEqual(creds._scopes, ' '.join(scopes)) - # Check stub. - self.assertEqual(creds._signer, signer_factory.return_value) + token_uri = 'baz' + revoke_uri = 'qux' + base_creds = self._from_json_keyfile_name_helper( + payload, scopes=scopes, token_uri=token_uri, revoke_uri=revoke_uri) + self.assertEqual(base_creds._signer, signer_factory.return_value) signer_factory.assert_called_once_with(private_key) + payload['token_uri'] = token_uri + payload['revoke_uri'] = revoke_uri + creds_with_uris_from_file = self._from_json_keyfile_name_helper( + payload, scopes=scopes) + for creds in (base_creds, creds_with_uris_from_file): + self.assertIsInstance(creds, ServiceAccountCredentials) + self.assertEqual(creds.client_id, client_id) + self.assertEqual(creds._service_account_email, client_email) + self.assertEqual(creds._private_key_id, private_key_id) + self.assertEqual(creds._private_key_pkcs8_pem, private_key) + self.assertEqual(creds._scopes, ' '.join(scopes)) + self.assertEqual(creds.token_uri, token_uri) + self.assertEqual(creds.revoke_uri, revoke_uri) + def test_from_json_keyfile_name_factory_bad_type(self): type_ = 'bad-type' self.assertNotEqual(type_, SERVICE_ACCOUNT) @@ -148,24 +162,33 @@ class ServiceAccountCredentialsTests(unittest2.TestCase): with self.assertRaises(KeyError): self._from_json_keyfile_name_helper(payload) - def _from_p12_keyfile_helper(self, private_key_password=None, scopes=''): + def _from_p12_keyfile_helper(self, private_key_password=None, scopes='', + token_uri=None, revoke_uri=None): service_account_email = 'name@email.com' filename = data_filename('privatekey.p12') with open(filename, 'rb') as file_obj: key_contents = file_obj.read() - creds = ServiceAccountCredentials.from_p12_keyfile( + creds_from_filename = ServiceAccountCredentials.from_p12_keyfile( service_account_email, filename, private_key_password=private_key_password, - scopes=scopes) - self.assertIsInstance(creds, ServiceAccountCredentials) - self.assertIsNone(creds.client_id) - self.assertEqual(creds._service_account_email, service_account_email) - self.assertIsNone(creds._private_key_id) - self.assertIsNone(creds._private_key_pkcs8_pem) - self.assertEqual(creds._private_key_pkcs12, key_contents) - if private_key_password is not None: - self.assertEqual(creds._private_key_password, private_key_password) - self.assertEqual(creds._scopes, ' '.join(scopes)) + scopes=scopes, token_uri=token_uri, revoke_uri=revoke_uri) + creds_from_file_contents = ( + ServiceAccountCredentials.from_p12_keyfile_buffer( + service_account_email, BytesIO(key_contents), + private_key_password=private_key_password, + scopes=scopes, token_uri=token_uri, revoke_uri=revoke_uri)) + for creds in (creds_from_filename, creds_from_file_contents): + self.assertIsInstance(creds, ServiceAccountCredentials) + self.assertIsNone(creds.client_id) + self.assertEqual(creds._service_account_email, service_account_email) + self.assertIsNone(creds._private_key_id) + self.assertIsNone(creds._private_key_pkcs8_pem) + self.assertEqual(creds._private_key_pkcs12, key_contents) + if private_key_password is not None: + self.assertEqual(creds._private_key_password, private_key_password) + self.assertEqual(creds._scopes, ' '.join(scopes)) + self.assertEqual(creds.token_uri, token_uri) + self.assertEqual(creds.revoke_uri, revoke_uri) def _p12_not_implemented_helper(self): service_account_email = 'name@email.com' @@ -188,31 +211,8 @@ class ServiceAccountCredentialsTests(unittest2.TestCase): def test_from_p12_keyfile_explicit(self): password = 'notasecret' self._from_p12_keyfile_helper(private_key_password=password, - scopes=['foo', 'bar']) - - def test_from_p12_keyfile_buffer(self): - service_account_email = 'name@email.com' - filename = data_filename('privatekey.p12') - private_key_password = 'notasecret' - scopes = ['foo', 'bar'] - with open(filename, 'rb') as file_obj: - key_contents = file_obj.read() - # Seek back to the beginning so the buffer can be - # passed to the constructor. - file_obj.seek(0) - creds = ServiceAccountCredentials.from_p12_keyfile_buffer( - service_account_email, file_obj, - private_key_password=private_key_password, - scopes=scopes) - # Check the created object. - self.assertIsInstance(creds, ServiceAccountCredentials) - self.assertIsNone(creds.client_id) - self.assertEqual(creds._service_account_email, service_account_email) - self.assertIsNone(creds._private_key_id) - self.assertIsNone(creds._private_key_pkcs8_pem) - self.assertEqual(creds._private_key_pkcs12, key_contents) - self.assertEqual(creds._private_key_password, private_key_password) - self.assertEqual(creds._scopes, ' '.join(scopes)) + scopes=['foo', 'bar'], + token_uri='baz', revoke_uri='qux') def test_create_scoped_required_without_scopes(self): self.assertTrue(self.credentials.create_scoped_required())