321 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			321 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright 2014 Google Inc. All rights reserved.
 | |
| #
 | |
| # Licensed under the Apache License, Version 2.0 (the "License");
 | |
| # you may not use this file except in compliance with the License.
 | |
| # You may obtain a copy of the License at
 | |
| #
 | |
| #      http://www.apache.org/licenses/LICENSE-2.0
 | |
| #
 | |
| # Unless required by applicable law or agreed to in writing, software
 | |
| # distributed under the License is distributed on an "AS IS" BASIS,
 | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| # See the License for the specific language governing permissions and
 | |
| # limitations under the License.
 | |
| 
 | |
| import base64
 | |
| import os
 | |
| import unittest
 | |
| 
 | |
| import mock
 | |
| 
 | |
| from oauth2client import _helpers
 | |
| from oauth2client.client import HAS_OPENSSL
 | |
| from oauth2client.client import SignedJwtAssertionCredentials
 | |
| from oauth2client import crypt
 | |
| 
 | |
| 
 | |
def datafile(filename):
    """Return the raw bytes of a fixture file from the ``data`` directory.

    Args:
        filename: Name of a file inside the ``data`` directory located
            next to this module.

    Returns:
        The complete contents of the file, read in binary mode.
    """
    path = os.path.join(os.path.dirname(__file__), 'data', filename)
    # The context manager closes the handle even when read() raises, which
    # the previous open()/read()/close() sequence did not guarantee.
    with open(path, 'rb') as file_obj:
        return file_obj.read()
 | |
| 
 | |
| 
 | |
class Test__bad_pkcs12_key_as_pem(unittest.TestCase):
    """Tests for ``crypt._bad_pkcs12_key_as_pem``."""

    def test_fails(self):
        # Invoked with no arguments, the function is expected to raise
        # NotImplementedError.
        unsupported = crypt._bad_pkcs12_key_as_pem
        self.assertRaises(NotImplementedError, unsupported)
 | |
| 
 | |
| 
 | |
class Test_pkcs12_key_as_pem(unittest.TestCase):
    """Tests for ``crypt.pkcs12_key_as_pem``."""

    def _make_signed_jwt_creds(self, private_key_file='privatekey.p12',
                               private_key=None):
        """Build ``SignedJwtAssertionCredentials`` around a private key.

        When ``private_key`` is falsy, the key bytes are loaded from the
        ``private_key_file`` fixture instead.
        """
        if not private_key:
            private_key = datafile(private_key_file)
        credentials = SignedJwtAssertionCredentials(
            'some_account@example.com', private_key,
            scope='read+write', sub='joe@example.org')
        return credentials

    def _succeeds_helper(self, password=None):
        """Convert the fixture key to PEM and compare against the fixtures.

        A ``password`` of ``None`` means "use the password stored on the
        credentials".
        """
        self.assertEqual(True, HAS_OPENSSL)

        creds = self._make_signed_jwt_creds()
        if password is None:
            key_password = creds.private_key_password
        else:
            key_password = password
        converted = crypt.pkcs12_key_as_pem(creds.private_key, key_password)

        # Either the parsed contents of pem_from_pkcs12.pem or the raw
        # contents of pem_from_pkcs12_alternate.pem is an accepted result.
        parsed_pem = _helpers._parse_pem_key(datafile('pem_from_pkcs12.pem'))
        alternate_pem = datafile('pem_from_pkcs12_alternate.pem')
        acceptable = [parsed_pem, alternate_pem]
        self.assertTrue(converted in acceptable)

    def test_succeeds(self):
        self._succeeds_helper()

    def test_succeeds_with_unicode_password(self):
        self._succeeds_helper(password=u'notasecret')

    def test_with_nonsense_key(self):
        from OpenSSL import crypto as openssl_crypto
        bogus = self._make_signed_jwt_creds(private_key=b'NOT_A_KEY')
        conversion_args = (bogus.private_key, bogus.private_key_password)
        self.assertRaises(openssl_crypto.Error, crypt.pkcs12_key_as_pem,
                          *conversion_args)
 | |
| 
 | |
| 
 | |
class Test__verify_signature(unittest.TestCase):
    """Tests for ``crypt._verify_signature`` with ``crypt.Verifier`` mocked."""

    @staticmethod
    def _stub_verifier(**verify_config):
        """Return a mock verifier whose ``verify`` mock uses the given config.

        ``verify_config`` is forwarded to ``mock.MagicMock`` (for example
        ``return_value=...`` or ``side_effect=...``).
        """
        stub = mock.MagicMock()
        stub.verify = mock.MagicMock(name='verify', **verify_config)
        return stub

    def test_success_single_cert(self):
        only_cert = 'cert-value'
        message, signature = object(), object()
        stub = self._stub_verifier(return_value=True)

        with mock.patch('oauth2client.crypt.Verifier') as verifier_cls:
            verifier_cls.from_string = mock.MagicMock(
                name='from_string', return_value=stub)
            outcome = crypt._verify_signature(
                message, signature, [only_cert])
            self.assertEqual(outcome, None)

            # Make sure our mocks were called as expected.
            verifier_cls.from_string.assert_called_once_with(
                only_cert, is_x509_cert=True)
            stub.verify.assert_called_once_with(message, signature)

    def test_success_multiple_certs(self):
        cert_names = ('cert-value1', 'cert-value2', 'cert-value3')
        message, signature = object(), object()

        # The first two verify() calls fail and the third succeeds, which
        # forces all three cert values to be tried.
        stub = self._stub_verifier(side_effect=[False, False, True])

        with mock.patch('oauth2client.crypt.Verifier') as verifier_cls:
            verifier_cls.from_string = mock.MagicMock(
                name='from_string', return_value=stub)
            outcome = crypt._verify_signature(
                message, signature, list(cert_names))
            self.assertEqual(outcome, None)

            # Make sure our mocks were called three times.
            from_string_calls = [
                mock.call(name, is_x509_cert=True) for name in cert_names
            ]
            self.assertEqual(verifier_cls.from_string.mock_calls,
                             from_string_calls)
            verify_calls = [
                mock.call(message, signature) for _ in cert_names
            ]
            self.assertEqual(stub.verify.mock_calls, verify_calls)

    def test_failure(self):
        only_cert = 'cert-value'
        message, signature = object(), object()
        stub = self._stub_verifier(return_value=False)

        with mock.patch('oauth2client.crypt.Verifier') as verifier_cls:
            verifier_cls.from_string = mock.MagicMock(
                name='from_string', return_value=stub)
            self.assertRaises(
                crypt.AppIdentityError, crypt._verify_signature,
                message, signature, [only_cert])

            # Make sure our mocks were called as expected.
            verifier_cls.from_string.assert_called_once_with(
                only_cert, is_x509_cert=True)
            stub.verify.assert_called_once_with(message, signature)
 | |
| 
 | |
| 
 | |
class Test__check_audience(unittest.TestCase):
    """Tests for ``crypt._check_audience``."""

    def test_null_audience(self):
        outcome = crypt._check_audience(None, None)
        self.assertEqual(outcome, None)

    def test_success(self):
        expected_audience = 'audience'
        outcome = crypt._check_audience(
            {'aud': expected_audience}, expected_audience)
        # No exception and no result.
        self.assertEqual(outcome, None)

    def test_missing_aud(self):
        # The payload carries no 'aud' key at all.
        self.assertRaises(crypt.AppIdentityError, crypt._check_audience,
                          {}, 'audience')

    def test_wrong_aud(self):
        token_audience, requested_audience = 'audience1', 'audience2'
        self.assertNotEqual(token_audience, requested_audience)
        self.assertRaises(crypt.AppIdentityError, crypt._check_audience,
                          {'aud': token_audience}, requested_audience)
 | |
| 
 | |
class Test__verify_time_range(unittest.TestCase):
    """Tests for ``crypt._verify_time_range``."""

    # Fixed value returned by the patched ``time.time`` in the tests below.
    _NOW = 123456

    def _exception_helper(self, payload_dict):
        """Return the ``AppIdentityError`` raised for the payload, or None."""
        try:
            crypt._verify_time_range(payload_dict)
        except crypt.AppIdentityError as error:
            return error
        return None

    def _assert_rejected(self, payload_dict, message_prefix):
        """Assert the payload is rejected with a message starting as given."""
        caught = self._exception_helper(payload_dict)
        self.assertNotEqual(caught, None)
        self.assertTrue(str(caught).startswith(message_prefix))

    def _assert_rejected_at_now(self, payload_dict, message_prefix):
        """Like ``_assert_rejected``, with ``crypt.time.time`` fixed to _NOW."""
        with mock.patch('oauth2client.crypt.time') as fake_time:
            fake_time.time = mock.MagicMock(name='time',
                                            return_value=self._NOW)
            self._assert_rejected(payload_dict, message_prefix)

    def test_without_issued_at(self):
        self._assert_rejected({}, 'No iat field in token')

    def test_without_expiration(self):
        self._assert_rejected({'iat': 'iat'}, 'No exp field in token')

    def test_with_bad_token_lifetime(self):
        too_long_lived = {
            'iat': 'iat',
            'exp': self._NOW + crypt.MAX_TOKEN_LIFETIME_SECS + 1,
        }
        self._assert_rejected_at_now(too_long_lived,
                                     'exp field too far in future')

    def test_with_issued_at_in_future(self):
        issued_too_late = {
            'iat': self._NOW + crypt.CLOCK_SKEW_SECS + 1,
            'exp': self._NOW + crypt.MAX_TOKEN_LIFETIME_SECS - 1,
        }
        self._assert_rejected_at_now(issued_too_late,
                                     'Token used too early')

    def test_with_expiration_in_the_past(self):
        already_expired = {
            'iat': self._NOW,
            'exp': self._NOW - crypt.CLOCK_SKEW_SECS - 1,
        }
        self._assert_rejected_at_now(already_expired,
                                     'Token used too late')

    def test_success(self):
        valid_payload = {
            'iat': self._NOW,
            'exp': self._NOW + crypt.MAX_TOKEN_LIFETIME_SECS - 1,
        }
        with mock.patch('oauth2client.crypt.time') as fake_time:
            fake_time.time = mock.MagicMock(name='time',
                                            return_value=self._NOW)

            caught = self._exception_helper(valid_payload)
            self.assertEqual(caught, None)
 | |
| 
 | |
| 
 | |
class Test_verify_signed_jwt_with_certs(unittest.TestCase):
    """Tests for ``crypt.verify_signed_jwt_with_certs``."""

    def _rejection_for(self, jwt):
        """Return the ``AppIdentityError`` raised for ``jwt``, or None.

        The token is verified with ``None`` passed as the certs argument.
        """
        try:
            crypt.verify_signed_jwt_with_certs(jwt, None)
        except crypt.AppIdentityError as error:
            return error
        return None

    def test_jwt_no_segments(self):
        caught = self._rejection_for(b'')

        self.assertNotEqual(caught, None)
        self.assertTrue(str(caught).startswith(
            'Wrong number of segments in token'))

    def test_jwt_payload_bad_json(self):
        # Empty header and signature around a payload that is not JSON.
        bad_payload = base64.b64encode(b'{BADJSON')
        caught = self._rejection_for(b'.'.join([b'', bad_payload, b'']))

        self.assertNotEqual(caught, None)
        self.assertTrue(str(caught).startswith(
            'Can\'t parse token'))

    @mock.patch('oauth2client.crypt._check_audience')
    @mock.patch('oauth2client.crypt._verify_time_range')
    @mock.patch('oauth2client.crypt._verify_signature')
    def test_success(self, verify_sig, verify_time, check_aud):
        audience = object()
        cert_values = object()
        certs = mock.MagicMock()
        certs.values = mock.MagicMock(name='values',
                                      return_value=cert_values)

        # Assemble a three-segment token: header.payload.signature.
        raw_signature = b'signature'
        header = b'header'
        payload = base64.b64encode(b'{"a": "b"}')
        segments = [header, payload, base64.b64encode(raw_signature)]
        jwt = b'.'.join(segments)
        expected_payload = {'a': 'b'}

        decoded = crypt.verify_signed_jwt_with_certs(
            jwt, certs, audience=audience)
        self.assertEqual(decoded, expected_payload)

        # The signed message is the header and payload joined by a dot;
        # the signature is passed on in its decoded form.
        signed_message = b'.'.join([header, payload])
        verify_sig.assert_called_once_with(
            signed_message, raw_signature, cert_values)
        verify_time.assert_called_once_with(expected_payload)
        check_aud.assert_called_once_with(expected_payload, audience)
        certs.values.assert_called_once_with()
 | |
| 
 | |
| 
 | |
# Run this module's tests when it is executed directly as a script.
if __name__ == '__main__':  # pragma: NO COVER
    unittest.main()
 | 
