Getting to 100% coverage for xsrfutil module.
This commit is contained in:
@@ -16,10 +16,12 @@
|
||||
"""Helper methods for creating & verifying XSRF tokens."""
|
||||
|
||||
import base64
|
||||
import binascii
|
||||
import hmac
|
||||
import six
|
||||
import time
|
||||
|
||||
import six
|
||||
from oauth2client._helpers import _to_bytes
|
||||
from oauth2client import util
|
||||
|
||||
__authors__ = [
|
||||
@@ -31,20 +33,11 @@ __authors__ = [
|
||||
DELIMITER = b':'
|
||||
|
||||
# 1 hour in seconds
|
||||
DEFAULT_TIMEOUT_SECS = 1 * 60 * 60
|
||||
|
||||
|
||||
def _force_bytes(s):
|
||||
if isinstance(s, bytes):
|
||||
return s
|
||||
s = str(s)
|
||||
if isinstance(s, six.text_type):
|
||||
return s.encode('utf-8')
|
||||
return s
|
||||
DEFAULT_TIMEOUT_SECS = 60 * 60
|
||||
|
||||
|
||||
@util.positional(2)
|
||||
def generate_token(key, user_id, action_id="", when=None):
|
||||
def generate_token(key, user_id, action_id='', when=None):
|
||||
"""Generates a URL-safe token for the given user, action, time tuple.
|
||||
|
||||
Args:
|
||||
@@ -58,12 +51,12 @@ def generate_token(key, user_id, action_id="", when=None):
|
||||
Returns:
|
||||
A string XSRF protection token.
|
||||
"""
|
||||
when = _force_bytes(when or int(time.time()))
|
||||
digester = hmac.new(_force_bytes(key))
|
||||
digester.update(_force_bytes(user_id))
|
||||
digester = hmac.new(_to_bytes(key, encoding='utf-8'))
|
||||
digester.update(_to_bytes(str(user_id), encoding='utf-8'))
|
||||
digester.update(DELIMITER)
|
||||
digester.update(_force_bytes(action_id))
|
||||
digester.update(_to_bytes(action_id, encoding='utf-8'))
|
||||
digester.update(DELIMITER)
|
||||
when = _to_bytes(str(when or int(time.time())), encoding='utf-8')
|
||||
digester.update(when)
|
||||
digest = digester.digest()
|
||||
|
||||
@@ -94,7 +87,7 @@ def validate_token(key, token, user_id, action_id="", current_time=None):
|
||||
try:
|
||||
decoded = base64.urlsafe_b64decode(token)
|
||||
token_time = int(decoded.split(DELIMITER)[-1])
|
||||
except (TypeError, ValueError):
|
||||
except (TypeError, ValueError, binascii.Error):
|
||||
return False
|
||||
if current_time is None:
|
||||
current_time = time.time()
|
||||
|
||||
@@ -16,24 +16,207 @@
|
||||
Unit tests for oauth2client.xsrfutil.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import unittest
|
||||
|
||||
import mock
|
||||
|
||||
from oauth2client._helpers import _to_bytes
|
||||
from oauth2client import xsrfutil
|
||||
|
||||
# Jan 17 2008, 5:40PM
|
||||
TEST_KEY = 'test key'
|
||||
TEST_KEY = b'test key'
|
||||
# Jan. 17, 2008 22:40:32.081230 UTC
|
||||
TEST_TIME = 1200609642081230
|
||||
TEST_USER_ID_1 = 123832983
|
||||
TEST_USER_ID_2 = 938297432
|
||||
TEST_ACTION_ID_1 = 'some_action'
|
||||
TEST_ACTION_ID_2 = 'some_other_action'
|
||||
TEST_EXTRA_INFO_1 = 'extra_info_1'
|
||||
TEST_EXTRA_INFO_2 = 'more_extra_info'
|
||||
TEST_ACTION_ID_1 = b'some_action'
|
||||
TEST_ACTION_ID_2 = b'some_other_action'
|
||||
TEST_EXTRA_INFO_1 = b'extra_info_1'
|
||||
TEST_EXTRA_INFO_2 = b'more_extra_info'
|
||||
|
||||
|
||||
__author__ = 'jcgregorio@google.com (Joe Gregorio)'
|
||||
|
||||
|
||||
class Test_generate_token(unittest.TestCase):
|
||||
|
||||
def test_bad_positional(self):
|
||||
# Need 2 positional arguments.
|
||||
self.assertRaises(TypeError, xsrfutil.generate_token, None)
|
||||
# At most 2 positional arguments.
|
||||
self.assertRaises(TypeError, xsrfutil.generate_token, None, None, None)
|
||||
|
||||
def test_it(self):
|
||||
digest = b'foobar'
|
||||
curr_time = 1440449755.74
|
||||
digester = mock.MagicMock()
|
||||
digester.digest = mock.MagicMock(name='digest', return_value=digest)
|
||||
with mock.patch('oauth2client.xsrfutil.hmac') as hmac:
|
||||
hmac.new = mock.MagicMock(name='new', return_value=digester)
|
||||
token = xsrfutil.generate_token(TEST_KEY,
|
||||
TEST_USER_ID_1,
|
||||
action_id=TEST_ACTION_ID_1,
|
||||
when=TEST_TIME)
|
||||
hmac.new.assert_called_once_with(TEST_KEY)
|
||||
digester.digest.assert_called_once_with()
|
||||
|
||||
expected_digest_calls = [
|
||||
mock.call.update(_to_bytes(str(TEST_USER_ID_1))),
|
||||
mock.call.update(xsrfutil.DELIMITER),
|
||||
mock.call.update(TEST_ACTION_ID_1),
|
||||
mock.call.update(xsrfutil.DELIMITER),
|
||||
mock.call.update(_to_bytes(str(TEST_TIME))),
|
||||
]
|
||||
self.assertEqual(digester.method_calls, expected_digest_calls)
|
||||
|
||||
expected_token_as_bytes = (digest + xsrfutil.DELIMITER +
|
||||
_to_bytes(str(TEST_TIME)))
|
||||
expected_token = base64.urlsafe_b64encode(
|
||||
expected_token_as_bytes)
|
||||
self.assertEqual(token, expected_token)
|
||||
|
||||
def test_with_system_time(self):
|
||||
digest = b'foobar'
|
||||
curr_time = 1440449755.74
|
||||
digester = mock.MagicMock()
|
||||
digester.digest = mock.MagicMock(name='digest', return_value=digest)
|
||||
with mock.patch('oauth2client.xsrfutil.hmac') as hmac:
|
||||
hmac.new = mock.MagicMock(name='new', return_value=digester)
|
||||
|
||||
with mock.patch('oauth2client.xsrfutil.time') as time:
|
||||
time.time = mock.MagicMock(name='time', return_value=curr_time)
|
||||
# when= is omitted
|
||||
token = xsrfutil.generate_token(TEST_KEY,
|
||||
TEST_USER_ID_1,
|
||||
action_id=TEST_ACTION_ID_1)
|
||||
|
||||
hmac.new.assert_called_once_with(TEST_KEY)
|
||||
time.time.assert_called_once_with()
|
||||
digester.digest.assert_called_once_with()
|
||||
|
||||
expected_digest_calls = [
|
||||
mock.call.update(_to_bytes(str(TEST_USER_ID_1))),
|
||||
mock.call.update(xsrfutil.DELIMITER),
|
||||
mock.call.update(TEST_ACTION_ID_1),
|
||||
mock.call.update(xsrfutil.DELIMITER),
|
||||
mock.call.update(_to_bytes(str(int(curr_time)))),
|
||||
]
|
||||
self.assertEqual(digester.method_calls, expected_digest_calls)
|
||||
|
||||
expected_token_as_bytes = (digest + xsrfutil.DELIMITER +
|
||||
_to_bytes(str(int(curr_time))))
|
||||
expected_token = base64.urlsafe_b64encode(
|
||||
expected_token_as_bytes)
|
||||
self.assertEqual(token, expected_token)
|
||||
|
||||
|
||||
class Test_validate_token(unittest.TestCase):
|
||||
|
||||
def test_bad_positional(self):
|
||||
# Need 3 positional arguments.
|
||||
self.assertRaises(TypeError, xsrfutil.validate_token, None, None)
|
||||
# At most 3 positional arguments.
|
||||
self.assertRaises(TypeError, xsrfutil.validate_token,
|
||||
None, None, None, None)
|
||||
|
||||
def test_no_token(self):
|
||||
key = token = user_id = None
|
||||
self.assertFalse(xsrfutil.validate_token(key, token, user_id))
|
||||
|
||||
def test_token_not_valid_base64(self):
|
||||
key = user_id = None
|
||||
token = b'a' # Bad padding
|
||||
self.assertFalse(xsrfutil.validate_token(key, token, user_id))
|
||||
|
||||
def test_token_non_integer(self):
|
||||
key = user_id = None
|
||||
token = base64.b64encode(b'abc' + xsrfutil.DELIMITER + b'xyz')
|
||||
self.assertFalse(xsrfutil.validate_token(key, token, user_id))
|
||||
|
||||
def test_token_too_old_implicit_current_time(self):
|
||||
token_time = 123456789
|
||||
curr_time = token_time + xsrfutil.DEFAULT_TIMEOUT_SECS + 1
|
||||
|
||||
key = user_id = None
|
||||
token = base64.b64encode(_to_bytes(str(token_time)))
|
||||
with mock.patch('oauth2client.xsrfutil.time') as time:
|
||||
time.time = mock.MagicMock(name='time', return_value=curr_time)
|
||||
self.assertFalse(xsrfutil.validate_token(key, token, user_id))
|
||||
time.time.assert_called_once_with()
|
||||
|
||||
def test_token_too_old_explicit_current_time(self):
|
||||
token_time = 123456789
|
||||
curr_time = token_time + xsrfutil.DEFAULT_TIMEOUT_SECS + 1
|
||||
|
||||
key = user_id = None
|
||||
token = base64.b64encode(_to_bytes(str(token_time)))
|
||||
self.assertFalse(xsrfutil.validate_token(key, token, user_id,
|
||||
current_time=curr_time))
|
||||
|
||||
def test_token_length_differs_from_generated(self):
|
||||
token_time = 123456789
|
||||
# Make sure it isn't too old.
|
||||
curr_time = token_time + xsrfutil.DEFAULT_TIMEOUT_SECS - 1
|
||||
|
||||
key = object()
|
||||
user_id = object()
|
||||
action_id = object()
|
||||
token = base64.b64encode(_to_bytes(str(token_time)))
|
||||
generated_token = b'a'
|
||||
# Make sure the token length comparison will fail.
|
||||
self.assertNotEqual(len(token), len(generated_token))
|
||||
|
||||
with mock.patch('oauth2client.xsrfutil.generate_token',
|
||||
return_value=generated_token) as gen_tok:
|
||||
self.assertFalse(xsrfutil.validate_token(key, token, user_id,
|
||||
current_time=curr_time,
|
||||
action_id=action_id))
|
||||
gen_tok.assert_called_once_with(key, user_id, action_id=action_id,
|
||||
when=token_time)
|
||||
|
||||
def test_token_differs_from_generated_but_same_length(self):
|
||||
token_time = 123456789
|
||||
# Make sure it isn't too old.
|
||||
curr_time = token_time + xsrfutil.DEFAULT_TIMEOUT_SECS - 1
|
||||
|
||||
key = object()
|
||||
user_id = object()
|
||||
action_id = object()
|
||||
token = base64.b64encode(_to_bytes(str(token_time)))
|
||||
# It is encoded as b'MTIzNDU2Nzg5', which has length 12.
|
||||
generated_token = b'M' * 12
|
||||
# Make sure the token length comparison will succeed, but the token
|
||||
# comparison will fail.
|
||||
self.assertEqual(len(token), len(generated_token))
|
||||
self.assertNotEqual(token, generated_token)
|
||||
|
||||
with mock.patch('oauth2client.xsrfutil.generate_token',
|
||||
return_value=generated_token) as gen_tok:
|
||||
self.assertFalse(xsrfutil.validate_token(key, token, user_id,
|
||||
current_time=curr_time,
|
||||
action_id=action_id))
|
||||
gen_tok.assert_called_once_with(key, user_id, action_id=action_id,
|
||||
when=token_time)
|
||||
|
||||
def test_success(self):
|
||||
token_time = 123456789
|
||||
# Make sure it isn't too old.
|
||||
curr_time = token_time + xsrfutil.DEFAULT_TIMEOUT_SECS - 1
|
||||
|
||||
key = object()
|
||||
user_id = object()
|
||||
action_id = object()
|
||||
token = base64.b64encode(_to_bytes(str(token_time)))
|
||||
with mock.patch('oauth2client.xsrfutil.generate_token',
|
||||
return_value=token) as gen_tok:
|
||||
self.assertTrue(xsrfutil.validate_token(key, token, user_id,
|
||||
current_time=curr_time,
|
||||
action_id=action_id))
|
||||
gen_tok.assert_called_once_with(key, user_id, action_id=action_id,
|
||||
when=token_time)
|
||||
|
||||
|
||||
class XsrfUtilTests(unittest.TestCase):
|
||||
"""Test xsrfutil functions."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user