diff --git a/oauth2client/crypt.py b/oauth2client/crypt.py index de62e06..c450c5c 100644 --- a/oauth2client/crypt.py +++ b/oauth2client/crypt.py @@ -34,24 +34,26 @@ logger = logging.getLogger(__name__) class AppIdentityError(Exception): - pass + """Error to indicate crypto failure.""" + + +def _bad_pkcs12_key_as_pem(*args, **kwargs): + raise NotImplementedError('pkcs12_key_as_pem requires OpenSSL.') try: from oauth2client._openssl_crypt import OpenSSLVerifier from oauth2client._openssl_crypt import OpenSSLSigner from oauth2client._openssl_crypt import pkcs12_key_as_pem -except ImportError: +except ImportError: # pragma: NO COVER OpenSSLVerifier = None OpenSSLSigner = None - - def pkcs12_key_as_pem(*args, **kwargs): - raise NotImplementedError('pkcs12_key_as_pem requires OpenSSL.') + pkcs12_key_as_pem = _bad_pkcs12_key_as_pem try: from oauth2client._pycrypto_crypt import PyCryptoVerifier from oauth2client._pycrypto_crypt import PyCryptoSigner -except ImportError: +except ImportError: # pragma: NO COVER PyCryptoVerifier = None PyCryptoSigner = None @@ -59,10 +61,10 @@ except ImportError: if OpenSSLSigner: Signer = OpenSSLSigner Verifier = OpenSSLVerifier -elif PyCryptoSigner: +elif PyCryptoSigner: # pragma: NO COVER Signer = PyCryptoSigner Verifier = PyCryptoVerifier -else: +else: # pragma: NO COVER raise ImportError('No encryption library found. Please install either ' 'PyOpenSSL, or PyCrypto 2.6 or later') @@ -95,7 +97,107 @@ def make_signed_jwt(signer, payload): return b'.'.join(segments) -def verify_signed_jwt_with_certs(jwt, certs, audience): +def _verify_signature(message, signature, certs): + """Verifies signed content using a list of certificates. + + Args: + message: string or bytes, The message to verify. + signature: string or bytes, The signature on the message. + certs: iterable, certificates in PEM format. + + Raises: + AppIdentityError: If none of the certificates can verify the message + against the signature. + """ + for pem in certs: + verifier = Verifier.from_string(pem, is_x509_cert=True) + if verifier.verify(message, signature): + return + + # If we have not returned, no certificate confirms the signature. + raise AppIdentityError('Invalid token signature') + + +def _check_audience(payload_dict, audience): + """Checks audience field from a JWT payload. + + Does nothing if the passed in ``audience`` is null. + + Args: + payload_dict: dict, A dictionary containing a JWT payload. + audience: string or NoneType, an audience to check for in + the JWT payload. + + Raises: + AppIdentityError: If there is no ``'aud'`` field in the payload + dictionary but there is an ``audience`` to check. + AppIdentityError: If the ``'aud'`` field in the payload dictionary + does not match the ``audience``. + """ + if audience is None: + return + + audience_in_payload = payload_dict.get('aud') + if audience_in_payload is None: + raise AppIdentityError('No aud field in token: %s' % + (payload_dict,)) + if audience_in_payload != audience: + raise AppIdentityError('Wrong recipient, %s != %s: %s' % + (audience_in_payload, audience, payload_dict)) + + +def _verify_time_range(payload_dict): + """Verifies the issued at and expiration from a JWT payload. + + Makes sure the current time (in UTC) falls between the issued at and + expiration for the JWT (with some skew allowed for via + ``CLOCK_SKEW_SECS``). + + Args: + payload_dict: dict, A dictionary containing a JWT payload. + + Raises: + AppIdentityError: If there is no ``'iat'`` field in the payload + dictionary. + AppIdentityError: If there is no ``'exp'`` field in the payload + dictionary. + AppIdentityError: If the JWT expiration is too far in the future (i.e. + if the expiration would imply a token lifetime + longer than what is allowed.) + AppIdentityError: If the token appears to have been issued in the + future (up to clock skew). + AppIdentityError: If the token appears to have expired in the past + (up to clock skew). + """ + # Get the current time to use throughout. + now = int(time.time()) + + # Make sure issued at and expiration are in the payload. + issued_at = payload_dict.get('iat') + if issued_at is None: + raise AppIdentityError('No iat field in token: %s' % (payload_dict,)) + expiration = payload_dict.get('exp') + if expiration is None: + raise AppIdentityError('No exp field in token: %s' % (payload_dict,)) + + # Make sure the expiration gives an acceptable token lifetime. + if expiration >= now + MAX_TOKEN_LIFETIME_SECS: + raise AppIdentityError('exp field too far in future: %s' % + (payload_dict,)) + + # Make sure (up to clock skew) that the token wasn't issued in the future. + earliest = issued_at - CLOCK_SKEW_SECS + if now < earliest: + raise AppIdentityError('Token used too early, %d < %d: %s' % + (now, earliest, payload_dict)) + # Make sure (up to clock skew) that the token isn't already expired. + latest = expiration + CLOCK_SKEW_SECS + if now > latest: + raise AppIdentityError('Token used too late, %d > %d: %s' % + (now, latest, payload_dict)) + + +def verify_signed_jwt_with_certs(jwt, certs, audience=None): """Verify a JWT against public certs. See http://self-issued.info/docs/draft-jones-json-web-token.html. @@ -110,63 +212,32 @@ def verify_signed_jwt_with_certs(jwt, certs, audience): dict, The deserialized JSON payload in the JWT. Raises: - AppIdentityError if any checks are failed. + AppIdentityError: if any checks are failed. """ jwt = _to_bytes(jwt) - segments = jwt.split(b'.') - if len(segments) != 3: - raise AppIdentityError('Wrong number of segments in token: %s' % jwt) - signed = segments[0] + b'.' + segments[1] + if jwt.count(b'.') != 2: + raise AppIdentityError( + 'Wrong number of segments in token: %s' % (jwt,)) - signature = _urlsafe_b64decode(segments[2]) + header, payload, signature = jwt.split(b'.') + message_to_sign = header + b'.' + payload + signature = _urlsafe_b64decode(signature) # Parse token. - json_body = _urlsafe_b64decode(segments[1]) + payload_bytes = _urlsafe_b64decode(payload) try: - parsed = json.loads(_from_bytes(json_body)) + payload_dict = json.loads(_from_bytes(payload_bytes)) except: - raise AppIdentityError('Can\'t parse token: %s' % json_body) + raise AppIdentityError('Can\'t parse token: %s' % (payload_bytes,)) - # Check signature. - verified = False - for pem in certs.values(): - verifier = Verifier.from_string(pem, True) - if verifier.verify(signed, signature): - verified = True - break - if not verified: - raise AppIdentityError('Invalid token signature: %s' % jwt) + # Verify that the signature matches the message. + _verify_signature(message_to_sign, signature, certs.values()) - # Check creation timestamp. - iat = parsed.get('iat') - if iat is None: - raise AppIdentityError('No iat field in token: %s' % json_body) - earliest = iat - CLOCK_SKEW_SECS - - # Check expiration timestamp. - now = int(time.time()) - exp = parsed.get('exp') - if exp is None: - raise AppIdentityError('No exp field in token: %s' % json_body) - if exp >= now + MAX_TOKEN_LIFETIME_SECS: - raise AppIdentityError('exp field too far in future: %s' % json_body) - latest = exp + CLOCK_SKEW_SECS - - if now < earliest: - raise AppIdentityError('Token used too early, %d < %d: %s' % - (now, earliest, json_body)) - if now > latest: - raise AppIdentityError('Token used too late, %d > %d: %s' % - (now, latest, json_body)) + # Verify the issued at and created times in the payload. + _verify_time_range(payload_dict) # Check audience. - if audience is not None: - aud = parsed.get('aud') - if aud is None: - raise AppIdentityError('No aud field in token: %s' % json_body) - if aud != audience: - raise AppIdentityError('Wrong recipient, %s != %s: %s' % - (aud, audience, json_body)) + _check_audience(payload_dict, audience) - return parsed + return payload_dict diff --git a/tests/test_crypt.py b/tests/test_crypt.py index da8797a..06e76d0 100644 --- a/tests/test_crypt.py +++ b/tests/test_crypt.py @@ -12,16 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -import mock +import base64 import os import sys import unittest -try: - reload -except NameError: - # For Python3 (though importlib should be used, silly 3.3). - from imp import reload +import mock from oauth2client import _helpers from oauth2client.client import HAS_OPENSSL @@ -36,6 +32,12 @@ def datafile(filename): return data +class Test__bad_pkcs12_key_as_pem(unittest.TestCase): + + def test_fails(self): + self.assertRaises(NotImplementedError, crypt._bad_pkcs12_key_as_pem) + + class Test_pkcs12_key_as_pem(unittest.TestCase): def _make_signed_jwt_creds(self, private_key_file='privatekey.p12', @@ -73,3 +75,243 @@ class Test_pkcs12_key_as_pem(unittest.TestCase): self.assertRaises(crypto.Error, crypt.pkcs12_key_as_pem, credentials.private_key, credentials.private_key_password) + + +class Test__verify_signature(unittest.TestCase): + + def test_success_single_cert(self): + cert_value = 'cert-value' + certs = [cert_value] + message = object() + signature = object() + + verifier = mock.MagicMock() + verifier.verify = mock.MagicMock(name='verify', return_value=True) + with mock.patch('oauth2client.crypt.Verifier') as Verifier: + Verifier.from_string = mock.MagicMock(name='from_string', + return_value=verifier) + result = crypt._verify_signature(message, signature, certs) + self.assertEqual(result, None) + + # Make sure our mocks were called as expected. + Verifier.from_string.assert_called_once_with(cert_value, + is_x509_cert=True) + verifier.verify.assert_called_once_with(message, signature) + + def test_success_multiple_certs(self): + cert_value1 = 'cert-value1' + cert_value2 = 'cert-value2' + cert_value3 = 'cert-value3' + certs = [cert_value1, cert_value2, cert_value3] + message = object() + signature = object() + + verifier = mock.MagicMock() + # Use side_effect to force all 3 cert values to be used by failing + # to verify on the first two. + verifier.verify = mock.MagicMock(name='verify', + side_effect=[False, False, True]) + with mock.patch('oauth2client.crypt.Verifier') as Verifier: + Verifier.from_string = mock.MagicMock(name='from_string', + return_value=verifier) + result = crypt._verify_signature(message, signature, certs) + self.assertEqual(result, None) + + # Make sure our mocks were called three times. + expected_from_string_calls = [ + mock.call(cert_value1, is_x509_cert=True), + mock.call(cert_value2, is_x509_cert=True), + mock.call(cert_value3, is_x509_cert=True), + ] + self.assertEqual(Verifier.from_string.mock_calls, + expected_from_string_calls) + expected_verify_calls = [mock.call(message, signature)] * 3 + self.assertEqual(verifier.verify.mock_calls, + expected_verify_calls) + + def test_failure(self): + cert_value = 'cert-value' + certs = [cert_value] + message = object() + signature = object() + + verifier = mock.MagicMock() + verifier.verify = mock.MagicMock(name='verify', return_value=False) + with mock.patch('oauth2client.crypt.Verifier') as Verifier: + Verifier.from_string = mock.MagicMock(name='from_string', + return_value=verifier) + self.assertRaises(crypt.AppIdentityError, crypt._verify_signature, + message, signature, certs) + + # Make sure our mocks were called as expected. + Verifier.from_string.assert_called_once_with(cert_value, + is_x509_cert=True) + verifier.verify.assert_called_once_with(message, signature) + + +class Test__check_audience(unittest.TestCase): + + def test_null_audience(self): + result = crypt._check_audience(None, None) + self.assertEqual(result, None) + + def test_success(self): + audience = 'audience' + payload_dict = {'aud': audience} + result = crypt._check_audience(payload_dict, audience) + # No exception and no result. + self.assertEqual(result, None) + + def test_missing_aud(self): + audience = 'audience' + payload_dict = {} + self.assertRaises(crypt.AppIdentityError, crypt._check_audience, + payload_dict, audience) + + def test_wrong_aud(self): + audience1 = 'audience1' + audience2 = 'audience2' + self.assertNotEqual(audience1, audience2) + payload_dict = {'aud': audience1} + self.assertRaises(crypt.AppIdentityError, crypt._check_audience, + payload_dict, audience2) + +class Test__verify_time_range(unittest.TestCase): + + def _exception_helper(self, payload_dict): + exception_caught = None + try: + crypt._verify_time_range(payload_dict) + except crypt.AppIdentityError as exc: + exception_caught = exc + + return exception_caught + + def test_without_issued_at(self): + payload_dict = {} + exception_caught = self._exception_helper(payload_dict) + self.assertNotEqual(exception_caught, None) + self.assertTrue(str(exception_caught).startswith( + 'No iat field in token')) + + def test_without_expiration(self): + payload_dict = {'iat': 'iat'} + exception_caught = self._exception_helper(payload_dict) + self.assertNotEqual(exception_caught, None) + self.assertTrue(str(exception_caught).startswith( + 'No exp field in token')) + + def test_with_bad_token_lifetime(self): + current_time = 123456 + payload_dict = { + 'iat': 'iat', + 'exp': current_time + crypt.MAX_TOKEN_LIFETIME_SECS + 1, + } + with mock.patch('oauth2client.crypt.time') as time: + time.time = mock.MagicMock(name='time', + return_value=current_time) + + exception_caught = self._exception_helper(payload_dict) + self.assertNotEqual(exception_caught, None) + self.assertTrue(str(exception_caught).startswith( + 'exp field too far in future')) + + def test_with_issued_at_in_future(self): + current_time = 123456 + payload_dict = { + 'iat': current_time + crypt.CLOCK_SKEW_SECS + 1, + 'exp': current_time + crypt.MAX_TOKEN_LIFETIME_SECS - 1, + } + with mock.patch('oauth2client.crypt.time') as time: + time.time = mock.MagicMock(name='time', + return_value=current_time) + + exception_caught = self._exception_helper(payload_dict) + self.assertNotEqual(exception_caught, None) + self.assertTrue(str(exception_caught).startswith( + 'Token used too early')) + + def test_with_expiration_in_the_past(self): + current_time = 123456 + payload_dict = { + 'iat': current_time, + 'exp': current_time - crypt.CLOCK_SKEW_SECS - 1, + } + with mock.patch('oauth2client.crypt.time') as time: + time.time = mock.MagicMock(name='time', + return_value=current_time) + + exception_caught = self._exception_helper(payload_dict) + self.assertNotEqual(exception_caught, None) + self.assertTrue(str(exception_caught).startswith( + 'Token used too late')) + + def test_success(self): + current_time = 123456 + payload_dict = { + 'iat': current_time, + 'exp': current_time + crypt.MAX_TOKEN_LIFETIME_SECS - 1, + } + with mock.patch('oauth2client.crypt.time') as time: + time.time = mock.MagicMock(name='time', + return_value=current_time) + + exception_caught = self._exception_helper(payload_dict) + self.assertEqual(exception_caught, None) + + +class Test_verify_signed_jwt_with_certs(unittest.TestCase): + + def test_jwt_no_segments(self): + exception_caught = None + try: + crypt.verify_signed_jwt_with_certs(b'', None) + except crypt.AppIdentityError as exc: + exception_caught = exc + + self.assertNotEqual(exception_caught, None) + self.assertTrue(str(exception_caught).startswith( + 'Wrong number of segments in token')) + + def test_jwt_payload_bad_json(self): + header = signature = b'' + payload = base64.b64encode(b'{BADJSON') + jwt = b'.'.join([header, payload, signature]) + + exception_caught = None + try: + crypt.verify_signed_jwt_with_certs(jwt, None) + except crypt.AppIdentityError as exc: + exception_caught = exc + + self.assertNotEqual(exception_caught, None) + self.assertTrue(str(exception_caught).startswith( + 'Can\'t parse token')) + + @mock.patch('oauth2client.crypt._check_audience') + @mock.patch('oauth2client.crypt._verify_time_range') + @mock.patch('oauth2client.crypt._verify_signature') + def test_success(self, verify_sig, verify_time, check_aud): + certs = mock.MagicMock() + cert_values = object() + certs.values = mock.MagicMock(name='values', + return_value=cert_values) + audience = object() + + header = b'header' + signature_bytes = b'signature' + signature = base64.b64encode(signature_bytes) + payload_dict = {'a': 'b'} + payload = base64.b64encode(b'{"a": "b"}') + jwt = b'.'.join([header, payload, signature]) + + result = crypt.verify_signed_jwt_with_certs( + jwt, certs, audience=audience) + self.assertEqual(result, payload_dict) + + message_to_sign = header + b'.' + payload + verify_sig.assert_called_once_with( + message_to_sign, signature_bytes, cert_values) + verify_time.assert_called_once_with(payload_dict) + check_aud.assert_called_once_with(payload_dict, audience) + certs.values.assert_called_once_with()