From 427412999c21529d0020d905af6d40e1f563b30d Mon Sep 17 00:00:00 2001 From: Danny Hermes Date: Mon, 24 Aug 2015 15:21:36 -0700 Subject: [PATCH] Getting to 100% coverage for xsrfutil module. --- oauth2client/xsrfutil.py | 27 ++---- tests/test_xsrfutil.py | 193 ++++++++++++++++++++++++++++++++++++++- 2 files changed, 198 insertions(+), 22 deletions(-) diff --git a/oauth2client/xsrfutil.py b/oauth2client/xsrfutil.py index 5c40832..685eb46 100644 --- a/oauth2client/xsrfutil.py +++ b/oauth2client/xsrfutil.py @@ -16,10 +16,12 @@ """Helper methods for creating & verifying XSRF tokens.""" import base64 +import binascii import hmac +import six import time -import six +from oauth2client._helpers import _to_bytes from oauth2client import util __authors__ = [ @@ -31,20 +33,11 @@ __authors__ = [ DELIMITER = b':' # 1 hour in seconds -DEFAULT_TIMEOUT_SECS = 1 * 60 * 60 - - -def _force_bytes(s): - if isinstance(s, bytes): - return s - s = str(s) - if isinstance(s, six.text_type): - return s.encode('utf-8') - return s +DEFAULT_TIMEOUT_SECS = 60 * 60 @util.positional(2) -def generate_token(key, user_id, action_id="", when=None): +def generate_token(key, user_id, action_id='', when=None): """Generates a URL-safe token for the given user, action, time tuple. Args: @@ -58,12 +51,12 @@ def generate_token(key, user_id, action_id="", when=None): Returns: A string XSRF protection token. """ - when = _force_bytes(when or int(time.time())) - digester = hmac.new(_force_bytes(key)) - digester.update(_force_bytes(user_id)) + digester = hmac.new(_to_bytes(key, encoding='utf-8')) + digester.update(_to_bytes(str(user_id), encoding='utf-8')) digester.update(DELIMITER) - digester.update(_force_bytes(action_id)) + digester.update(_to_bytes(action_id, encoding='utf-8')) digester.update(DELIMITER) + when = _to_bytes(str(when or int(time.time())), encoding='utf-8') digester.update(when) digest = digester.digest() @@ -94,7 +87,7 @@ def validate_token(key, token, user_id, action_id="", current_time=None): try: decoded = base64.urlsafe_b64decode(token) token_time = int(decoded.split(DELIMITER)[-1]) - except (TypeError, ValueError): + except (TypeError, ValueError, binascii.Error): return False if current_time is None: current_time = time.time() diff --git a/tests/test_xsrfutil.py b/tests/test_xsrfutil.py index 50b79b4..6f2214c 100644 --- a/tests/test_xsrfutil.py +++ b/tests/test_xsrfutil.py @@ -16,24 +16,207 @@ Unit tests for oauth2client.xsrfutil. """ +import base64 import unittest +import mock + +from oauth2client._helpers import _to_bytes from oauth2client import xsrfutil # Jan 17 2008, 5:40PM -TEST_KEY = 'test key' +TEST_KEY = b'test key' +# Jan. 17, 2008 22:40:32.081230 UTC TEST_TIME = 1200609642081230 TEST_USER_ID_1 = 123832983 TEST_USER_ID_2 = 938297432 -TEST_ACTION_ID_1 = 'some_action' -TEST_ACTION_ID_2 = 'some_other_action' -TEST_EXTRA_INFO_1 = 'extra_info_1' -TEST_EXTRA_INFO_2 = 'more_extra_info' +TEST_ACTION_ID_1 = b'some_action' +TEST_ACTION_ID_2 = b'some_other_action' +TEST_EXTRA_INFO_1 = b'extra_info_1' +TEST_EXTRA_INFO_2 = b'more_extra_info' __author__ = 'jcgregorio@google.com (Joe Gregorio)' +class Test_generate_token(unittest.TestCase): + + def test_bad_positional(self): + # Need 2 positional arguments. + self.assertRaises(TypeError, xsrfutil.generate_token, None) + # At most 2 positional arguments. + self.assertRaises(TypeError, xsrfutil.generate_token, None, None, None) + + def test_it(self): + digest = b'foobar' + curr_time = 1440449755.74 + digester = mock.MagicMock() + digester.digest = mock.MagicMock(name='digest', return_value=digest) + with mock.patch('oauth2client.xsrfutil.hmac') as hmac: + hmac.new = mock.MagicMock(name='new', return_value=digester) + token = xsrfutil.generate_token(TEST_KEY, + TEST_USER_ID_1, + action_id=TEST_ACTION_ID_1, + when=TEST_TIME) + hmac.new.assert_called_once_with(TEST_KEY) + digester.digest.assert_called_once_with() + + expected_digest_calls = [ + mock.call.update(_to_bytes(str(TEST_USER_ID_1))), + mock.call.update(xsrfutil.DELIMITER), + mock.call.update(TEST_ACTION_ID_1), + mock.call.update(xsrfutil.DELIMITER), + mock.call.update(_to_bytes(str(TEST_TIME))), + ] + self.assertEqual(digester.method_calls, expected_digest_calls) + + expected_token_as_bytes = (digest + xsrfutil.DELIMITER + + _to_bytes(str(TEST_TIME))) + expected_token = base64.urlsafe_b64encode( + expected_token_as_bytes) + self.assertEqual(token, expected_token) + + def test_with_system_time(self): + digest = b'foobar' + curr_time = 1440449755.74 + digester = mock.MagicMock() + digester.digest = mock.MagicMock(name='digest', return_value=digest) + with mock.patch('oauth2client.xsrfutil.hmac') as hmac: + hmac.new = mock.MagicMock(name='new', return_value=digester) + + with mock.patch('oauth2client.xsrfutil.time') as time: + time.time = mock.MagicMock(name='time', return_value=curr_time) + # when= is omitted + token = xsrfutil.generate_token(TEST_KEY, + TEST_USER_ID_1, + action_id=TEST_ACTION_ID_1) + + hmac.new.assert_called_once_with(TEST_KEY) + time.time.assert_called_once_with() + digester.digest.assert_called_once_with() + + expected_digest_calls = [ + mock.call.update(_to_bytes(str(TEST_USER_ID_1))), + mock.call.update(xsrfutil.DELIMITER), + mock.call.update(TEST_ACTION_ID_1), + mock.call.update(xsrfutil.DELIMITER), + mock.call.update(_to_bytes(str(int(curr_time)))), + ] + self.assertEqual(digester.method_calls, expected_digest_calls) + + expected_token_as_bytes = (digest + xsrfutil.DELIMITER + + _to_bytes(str(int(curr_time)))) + expected_token = base64.urlsafe_b64encode( + expected_token_as_bytes) + self.assertEqual(token, expected_token) + + +class Test_validate_token(unittest.TestCase): + + def test_bad_positional(self): + # Need 3 positional arguments. + self.assertRaises(TypeError, xsrfutil.validate_token, None, None) + # At most 3 positional arguments. + self.assertRaises(TypeError, xsrfutil.validate_token, + None, None, None, None) + + def test_no_token(self): + key = token = user_id = None + self.assertFalse(xsrfutil.validate_token(key, token, user_id)) + + def test_token_not_valid_base64(self): + key = user_id = None + token = b'a' # Bad padding + self.assertFalse(xsrfutil.validate_token(key, token, user_id)) + + def test_token_non_integer(self): + key = user_id = None + token = base64.b64encode(b'abc' + xsrfutil.DELIMITER + b'xyz') + self.assertFalse(xsrfutil.validate_token(key, token, user_id)) + + def test_token_too_old_implicit_current_time(self): + token_time = 123456789 + curr_time = token_time + xsrfutil.DEFAULT_TIMEOUT_SECS + 1 + + key = user_id = None + token = base64.b64encode(_to_bytes(str(token_time))) + with mock.patch('oauth2client.xsrfutil.time') as time: + time.time = mock.MagicMock(name='time', return_value=curr_time) + self.assertFalse(xsrfutil.validate_token(key, token, user_id)) + time.time.assert_called_once_with() + + def test_token_too_old_explicit_current_time(self): + token_time = 123456789 + curr_time = token_time + xsrfutil.DEFAULT_TIMEOUT_SECS + 1 + + key = user_id = None + token = base64.b64encode(_to_bytes(str(token_time))) + self.assertFalse(xsrfutil.validate_token(key, token, user_id, + current_time=curr_time)) + + def test_token_length_differs_from_generated(self): + token_time = 123456789 + # Make sure it isn't too old. + curr_time = token_time + xsrfutil.DEFAULT_TIMEOUT_SECS - 1 + + key = object() + user_id = object() + action_id = object() + token = base64.b64encode(_to_bytes(str(token_time))) + generated_token = b'a' + # Make sure the token length comparison will fail. + self.assertNotEqual(len(token), len(generated_token)) + + with mock.patch('oauth2client.xsrfutil.generate_token', + return_value=generated_token) as gen_tok: + self.assertFalse(xsrfutil.validate_token(key, token, user_id, + current_time=curr_time, + action_id=action_id)) + gen_tok.assert_called_once_with(key, user_id, action_id=action_id, + when=token_time) + + def test_token_differs_from_generated_but_same_length(self): + token_time = 123456789 + # Make sure it isn't too old. + curr_time = token_time + xsrfutil.DEFAULT_TIMEOUT_SECS - 1 + + key = object() + user_id = object() + action_id = object() + token = base64.b64encode(_to_bytes(str(token_time))) + # It is encoded as b'MTIzNDU2Nzg5', which has length 12. + generated_token = b'M' * 12 + # Make sure the token length comparison will succeed, but the token + # comparison will fail. + self.assertEqual(len(token), len(generated_token)) + self.assertNotEqual(token, generated_token) + + with mock.patch('oauth2client.xsrfutil.generate_token', + return_value=generated_token) as gen_tok: + self.assertFalse(xsrfutil.validate_token(key, token, user_id, + current_time=curr_time, + action_id=action_id)) + gen_tok.assert_called_once_with(key, user_id, action_id=action_id, + when=token_time) + + def test_success(self): + token_time = 123456789 + # Make sure it isn't too old. + curr_time = token_time + xsrfutil.DEFAULT_TIMEOUT_SECS - 1 + + key = object() + user_id = object() + action_id = object() + token = base64.b64encode(_to_bytes(str(token_time))) + with mock.patch('oauth2client.xsrfutil.generate_token', + return_value=token) as gen_tok: + self.assertTrue(xsrfutil.validate_token(key, token, user_id, + current_time=curr_time, + action_id=action_id)) + gen_tok.assert_called_once_with(key, user_id, action_id=action_id, + when=token_time) + + class XsrfUtilTests(unittest.TestCase): """Test xsrfutil functions."""