Refactor xsrfutil
This commit is contained in:
@@ -25,17 +25,27 @@ import base64
|
||||
import hmac
|
||||
import time
|
||||
|
||||
import six
|
||||
from oauth2client import util
|
||||
|
||||
|
||||
# Delimiter character
|
||||
DELIMITER = ':'
|
||||
DELIMITER = b':'
|
||||
|
||||
ENCODING = 'utf-8'
|
||||
|
||||
# 1 hour in seconds
|
||||
DEFAULT_TIMEOUT_SECS = 1*60*60
|
||||
|
||||
|
||||
def _force_bytes(s):
|
||||
if isinstance(s, bytes):
|
||||
return s
|
||||
s = str(s)
|
||||
if isinstance(s, six.text_type):
|
||||
return s.encode('utf-8')
|
||||
return s
|
||||
|
||||
|
||||
@util.positional(2)
|
||||
def generate_token(key, user_id, action_id="", when=None):
|
||||
"""Generates a URL-safe token for the given user, action, time tuple.
|
||||
@@ -51,22 +61,16 @@ def generate_token(key, user_id, action_id="", when=None):
|
||||
Returns:
|
||||
A string XSRF protection token.
|
||||
"""
|
||||
when = when or int(time.time())
|
||||
decoded_key = '{key}{user_id}{delim}{action_id}{delim}{time}'.format(key=key,
|
||||
user_id=user_id,
|
||||
action_id=action_id,
|
||||
delim=DELIMITER,
|
||||
time=when).encode(ENCODING)
|
||||
|
||||
digester = hmac.new(decoded_key)
|
||||
when = _force_bytes(when or int(time.time()))
|
||||
digester = hmac.new(_force_bytes(key))
|
||||
digester.update(_force_bytes(user_id))
|
||||
digester.update(DELIMITER)
|
||||
digester.update(_force_bytes(action_id))
|
||||
digester.update(DELIMITER)
|
||||
digester.update(when)
|
||||
digest = digester.digest()
|
||||
|
||||
decoded_token = '{digest}{delim}{time}'.format(digest=digest, delim=DELIMITER, time=when)
|
||||
|
||||
try:
|
||||
token = base64.urlsafe_b64encode(decoded_token.encode(ENCODING))
|
||||
except UnicodeDecodeError:
|
||||
token = base64.urlsafe_b64encode(decoded_token)
|
||||
token = base64.urlsafe_b64encode(digest + DELIMITER + when)
|
||||
return token
|
||||
|
||||
|
||||
@@ -92,17 +96,9 @@ def validate_token(key, token, user_id, action_id="", current_time=None):
|
||||
return False
|
||||
try:
|
||||
decoded = base64.urlsafe_b64decode(token)
|
||||
# Decode is needed for Python3
|
||||
# It will fail for Python2
|
||||
token_time = int(decoded.decode(ENCODING).split(DELIMITER)[-1])
|
||||
token_time = int(decoded.split(DELIMITER)[-1])
|
||||
except (TypeError, ValueError):
|
||||
try:
|
||||
# Try again, in case it fails here
|
||||
decoded = base64.urlsafe_b64decode(token)
|
||||
# Decode is not needed for Python2
|
||||
token_time = int(decoded.split(DELIMITER)[-1])
|
||||
except (TypeError, ValueError):
|
||||
return False
|
||||
return False
|
||||
if current_time is None:
|
||||
current_time = time.time()
|
||||
# If the token is too old it's not valid.
|
||||
@@ -117,15 +113,6 @@ def validate_token(key, token, user_id, action_id="", current_time=None):
|
||||
|
||||
# Perform constant time comparison to avoid timing attacks
|
||||
different = 0
|
||||
try:
|
||||
# Python3
|
||||
for x, y in zip(token, expected_token):
|
||||
different |= x ^ y
|
||||
except (TypeError, ValueError):
|
||||
# Python2
|
||||
for x, y in zip(token.encode(ENCODING), expected_token.encode(ENCODING)):
|
||||
different |= ord(x) ^ ord(y)
|
||||
if different:
|
||||
return False
|
||||
|
||||
return True
|
||||
for x, y in zip(bytearray(token), bytearray(expected_token)):
|
||||
different |= x ^ y
|
||||
return different == 0
|
||||
|
||||
@@ -96,7 +96,7 @@ class XsrfUtilTests(unittest.TestCase):
|
||||
|
||||
# Invalid with extra garbage
|
||||
self.assertFalse(xsrfutil.validate_token(TEST_KEY,
|
||||
token.decode('utf-8') + 'x',
|
||||
token + b'x',
|
||||
TEST_USER_ID_1,
|
||||
action_id=TEST_ACTION_ID_1,
|
||||
current_time=later15mins))
|
||||
|
||||
Reference in New Issue
Block a user