Refactor xsrfutil

This commit is contained in:
INADA Naoki
2014-08-29 19:28:39 +09:00
parent a1f3885e20
commit ec1fb365ca
2 changed files with 26 additions and 39 deletions

View File

@@ -25,17 +25,27 @@ import base64
import hmac
import time
import six
from oauth2client import util
# Delimiter character
DELIMITER = ':'
DELIMITER = b':'
ENCODING = 'utf-8'
# 1 hour in seconds
DEFAULT_TIMEOUT_SECS = 1*60*60
def _force_bytes(s):
if isinstance(s, bytes):
return s
s = str(s)
if isinstance(s, six.text_type):
return s.encode('utf-8')
return s
@util.positional(2)
def generate_token(key, user_id, action_id="", when=None):
"""Generates a URL-safe token for the given user, action, time tuple.
@@ -51,22 +61,16 @@ def generate_token(key, user_id, action_id="", when=None):
Returns:
A string XSRF protection token.
"""
when = when or int(time.time())
decoded_key = '{key}{user_id}{delim}{action_id}{delim}{time}'.format(key=key,
user_id=user_id,
action_id=action_id,
delim=DELIMITER,
time=when).encode(ENCODING)
digester = hmac.new(decoded_key)
when = _force_bytes(when or int(time.time()))
digester = hmac.new(_force_bytes(key))
digester.update(_force_bytes(user_id))
digester.update(DELIMITER)
digester.update(_force_bytes(action_id))
digester.update(DELIMITER)
digester.update(when)
digest = digester.digest()
decoded_token = '{digest}{delim}{time}'.format(digest=digest, delim=DELIMITER, time=when)
try:
token = base64.urlsafe_b64encode(decoded_token.encode(ENCODING))
except UnicodeDecodeError:
token = base64.urlsafe_b64encode(decoded_token)
token = base64.urlsafe_b64encode(digest + DELIMITER + when)
return token
@@ -92,17 +96,9 @@ def validate_token(key, token, user_id, action_id="", current_time=None):
return False
try:
decoded = base64.urlsafe_b64decode(token)
# Decode is needed for Python3
# It will fail for Python2
token_time = int(decoded.decode(ENCODING).split(DELIMITER)[-1])
token_time = int(decoded.split(DELIMITER)[-1])
except (TypeError, ValueError):
try:
# Try again, in case it fails here
decoded = base64.urlsafe_b64decode(token)
# Decode is not needed for Python2
token_time = int(decoded.split(DELIMITER)[-1])
except (TypeError, ValueError):
return False
return False
if current_time is None:
current_time = time.time()
# If the token is too old it's not valid.
@@ -117,15 +113,6 @@ def validate_token(key, token, user_id, action_id="", current_time=None):
# Perform constant time comparison to avoid timing attacks
different = 0
try:
# Python3
for x, y in zip(token, expected_token):
different |= x ^ y
except (TypeError, ValueError):
# Python2
for x, y in zip(token.encode(ENCODING), expected_token.encode(ENCODING)):
different |= ord(x) ^ ord(y)
if different:
return False
return True
for x, y in zip(bytearray(token), bytearray(expected_token)):
different |= x ^ y
return different == 0

View File

@@ -96,7 +96,7 @@ class XsrfUtilTests(unittest.TestCase):
# Invalid with extra garbage
self.assertFalse(xsrfutil.validate_token(TEST_KEY,
token.decode('utf-8') + 'x',
token + b'x',
TEST_USER_ID_1,
action_id=TEST_ACTION_ID_1,
current_time=later15mins))