Python 3: call functions from memcache_crypt.py with bytes as input
These functions expect bytes as input, but in Python 3 they were given text strings. Change-Id: I39fa15b8d5d56dc536e0bd71a50cf27da3d03046
This commit is contained in:
@@ -1022,6 +1022,8 @@ class AuthProtocol(object):
|
|||||||
# Note that 'invalid' and (data, expires) are the only
|
# Note that 'invalid' and (data, expires) are the only
|
||||||
# valid types of serialized cache entries, so there is not
|
# valid types of serialized cache entries, so there is not
|
||||||
# a collision with jsonutils.loads(serialized) == None.
|
# a collision with jsonutils.loads(serialized) == None.
|
||||||
|
if not isinstance(serialized, six.string_types):
|
||||||
|
serialized = serialized.decode('utf-8')
|
||||||
cached = jsonutils.loads(serialized)
|
cached = jsonutils.loads(serialized)
|
||||||
if cached == 'invalid':
|
if cached == 'invalid':
|
||||||
self.LOG.debug('Cached Token %s is marked unauthorized',
|
self.LOG.debug('Cached Token %s is marked unauthorized',
|
||||||
@@ -1053,14 +1055,20 @@ class AuthProtocol(object):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
serialized_data = jsonutils.dumps(data)
|
serialized_data = jsonutils.dumps(data)
|
||||||
|
if isinstance(serialized_data, six.text_type):
|
||||||
|
serialized_data = serialized_data.encode('utf-8')
|
||||||
if self._memcache_security_strategy is None:
|
if self._memcache_security_strategy is None:
|
||||||
cache_key = CACHE_KEY_TEMPLATE % token_id
|
cache_key = CACHE_KEY_TEMPLATE % token_id
|
||||||
data_to_store = serialized_data
|
data_to_store = serialized_data
|
||||||
else:
|
else:
|
||||||
|
secret_key = self._memcache_secret_key
|
||||||
|
if isinstance(secret_key, six.string_types):
|
||||||
|
secret_key = secret_key.encode('utf-8')
|
||||||
|
security_strategy = self._memcache_security_strategy
|
||||||
|
if isinstance(security_strategy, six.string_types):
|
||||||
|
security_strategy = security_strategy.encode('utf-8')
|
||||||
keys = memcache_crypt.derive_keys(
|
keys = memcache_crypt.derive_keys(
|
||||||
token_id,
|
token_id, secret_key, security_strategy)
|
||||||
self._memcache_secret_key,
|
|
||||||
self._memcache_security_strategy)
|
|
||||||
cache_key = CACHE_KEY_TEMPLATE % memcache_crypt.get_cache_key(keys)
|
cache_key = CACHE_KEY_TEMPLATE % memcache_crypt.get_cache_key(keys)
|
||||||
data_to_store = memcache_crypt.protect_data(keys, serialized_data)
|
data_to_store = memcache_crypt.protect_data(keys, serialized_data)
|
||||||
|
|
||||||
|
@@ -35,6 +35,8 @@ import hashlib
|
|||||||
import hmac
|
import hmac
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
|
import six
|
||||||
|
import sys
|
||||||
|
|
||||||
# make sure pycrypto is available
|
# make sure pycrypto is available
|
||||||
try:
|
try:
|
||||||
@@ -82,19 +84,26 @@ def assert_crypto_availability(f):
|
|||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
def constant_time_compare(first, second):
|
if sys.version_info >= (3, 3):
|
||||||
"""Returns True if both string inputs are equal, otherwise False.
|
constant_time_compare = hmac.compare_digest
|
||||||
|
else:
|
||||||
|
def constant_time_compare(first, second):
|
||||||
|
"""Returns True if both string inputs are equal, otherwise False.
|
||||||
|
|
||||||
This function should take a constant amount of time regardless of
|
This function should take a constant amount of time regardless of
|
||||||
how many characters in the strings match.
|
how many characters in the strings match.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if len(first) != len(second):
|
if len(first) != len(second):
|
||||||
return False
|
return False
|
||||||
result = 0
|
result = 0
|
||||||
for x, y in zip(first, second):
|
if six.PY3 and isinstance(first, bytes) and isinstance(second, bytes):
|
||||||
result |= ord(x) ^ ord(y)
|
for x, y in zip(first, second):
|
||||||
return result == 0
|
result |= x ^ y
|
||||||
|
else:
|
||||||
|
for x, y in zip(first, second):
|
||||||
|
result |= ord(x) ^ ord(y)
|
||||||
|
return result == 0
|
||||||
|
|
||||||
|
|
||||||
def derive_keys(token, secret, strategy):
|
def derive_keys(token, secret, strategy):
|
||||||
@@ -132,7 +141,7 @@ def encrypt_data(key, data):
|
|||||||
iv = os.urandom(16)
|
iv = os.urandom(16)
|
||||||
cipher = AES.new(key, AES.MODE_CBC, iv)
|
cipher = AES.new(key, AES.MODE_CBC, iv)
|
||||||
padding = 16 - len(data) % 16
|
padding = 16 - len(data) % 16
|
||||||
return iv + cipher.encrypt(data + chr(padding) * padding)
|
return iv + cipher.encrypt(data + six.int2byte(padding) * padding)
|
||||||
|
|
||||||
|
|
||||||
@assert_crypto_availability
|
@assert_crypto_availability
|
||||||
@@ -147,8 +156,7 @@ def decrypt_data(key, data):
|
|||||||
|
|
||||||
# Strip the last n padding bytes where n is the last value in
|
# Strip the last n padding bytes where n is the last value in
|
||||||
# the plaintext
|
# the plaintext
|
||||||
padding = ord(result[-1])
|
return result[:-1 * six.byte2int([result[-1]])]
|
||||||
return result[:-1 * padding]
|
|
||||||
|
|
||||||
|
|
||||||
def protect_data(keys, data):
|
def protect_data(keys, data):
|
||||||
@@ -156,7 +164,7 @@ def protect_data(keys, data):
|
|||||||
protected string suitable for storage in the cache.
|
protected string suitable for storage in the cache.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if keys['strategy'] == 'ENCRYPT':
|
if keys['strategy'] == b'ENCRYPT':
|
||||||
data = encrypt_data(keys['ENCRYPTION'], data)
|
data = encrypt_data(keys['ENCRYPTION'], data)
|
||||||
|
|
||||||
encoded_data = base64.b64encode(data)
|
encoded_data = base64.b64encode(data)
|
||||||
@@ -188,7 +196,7 @@ def unprotect_data(keys, signed_data):
|
|||||||
data = base64.b64decode(signed_data[DIGEST_LENGTH_B64:])
|
data = base64.b64decode(signed_data[DIGEST_LENGTH_B64:])
|
||||||
|
|
||||||
# then if necessary decrypt the data
|
# then if necessary decrypt the data
|
||||||
if keys['strategy'] == 'ENCRYPT':
|
if keys['strategy'] == b'ENCRYPT':
|
||||||
data = decrypt_data(keys['ENCRYPTION'], data)
|
data = decrypt_data(keys['ENCRYPTION'], data)
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
@@ -789,7 +789,7 @@ class CommonAuthTokenMiddlewareTest(object):
|
|||||||
'memcache_secret_key': 'mysecret'
|
'memcache_secret_key': 'mysecret'
|
||||||
}
|
}
|
||||||
self.set_middleware(conf=conf)
|
self.set_middleware(conf=conf)
|
||||||
token = 'my_token'
|
token = b'my_token'
|
||||||
some_time_later = timeutils.utcnow() + datetime.timedelta(hours=4)
|
some_time_later = timeutils.utcnow() + datetime.timedelta(hours=4)
|
||||||
expires = timeutils.strtime(some_time_later)
|
expires = timeutils.strtime(some_time_later)
|
||||||
data = ('this_data', expires)
|
data = ('this_data', expires)
|
||||||
@@ -805,7 +805,7 @@ class CommonAuthTokenMiddlewareTest(object):
|
|||||||
'memcache_secret_key': 'mysecret'
|
'memcache_secret_key': 'mysecret'
|
||||||
}
|
}
|
||||||
self.set_middleware(conf=conf)
|
self.set_middleware(conf=conf)
|
||||||
token = 'my_token'
|
token = b'my_token'
|
||||||
some_time_later = timeutils.utcnow() + datetime.timedelta(hours=4)
|
some_time_later = timeutils.utcnow() + datetime.timedelta(hours=4)
|
||||||
expires = timeutils.strtime(some_time_later)
|
expires = timeutils.strtime(some_time_later)
|
||||||
data = ('this_data', expires)
|
data = ('this_data', expires)
|
||||||
|
@@ -12,6 +12,7 @@
|
|||||||
# License for the specific language governing permissions and limitations
|
# License for the specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
|
|
||||||
|
import six
|
||||||
import testtools
|
import testtools
|
||||||
|
|
||||||
from keystoneclient.middleware import memcache_crypt
|
from keystoneclient.middleware import memcache_crypt
|
||||||
@@ -19,7 +20,7 @@ from keystoneclient.middleware import memcache_crypt
|
|||||||
|
|
||||||
class MemcacheCryptPositiveTests(testtools.TestCase):
|
class MemcacheCryptPositiveTests(testtools.TestCase):
|
||||||
def _setup_keys(self, strategy):
|
def _setup_keys(self, strategy):
|
||||||
return memcache_crypt.derive_keys('token', 'secret', strategy)
|
return memcache_crypt.derive_keys(b'token', b'secret', strategy)
|
||||||
|
|
||||||
def test_constant_time_compare(self):
|
def test_constant_time_compare(self):
|
||||||
# make sure it works as a compare, the "constant time" aspect
|
# make sure it works as a compare, the "constant time" aspect
|
||||||
@@ -32,8 +33,18 @@ class MemcacheCryptPositiveTests(testtools.TestCase):
|
|||||||
self.assertFalse(ctc('abc', 'abc\x00'))
|
self.assertFalse(ctc('abc', 'abc\x00'))
|
||||||
self.assertFalse(ctc('', 'abc'))
|
self.assertFalse(ctc('', 'abc'))
|
||||||
|
|
||||||
|
# For Python 3, we want to test these functions with both str and bytes
|
||||||
|
# as input.
|
||||||
|
if six.PY3:
|
||||||
|
self.assertTrue(ctc(b'abcd', b'abcd'))
|
||||||
|
self.assertTrue(ctc(b'', b''))
|
||||||
|
self.assertFalse(ctc(b'abcd', b'efgh'))
|
||||||
|
self.assertFalse(ctc(b'abc', b'abcd'))
|
||||||
|
self.assertFalse(ctc(b'abc', b'abc\x00'))
|
||||||
|
self.assertFalse(ctc(b'', b'abc'))
|
||||||
|
|
||||||
def test_derive_keys(self):
|
def test_derive_keys(self):
|
||||||
keys = memcache_crypt.derive_keys('token', 'secret', 'strategy')
|
keys = self._setup_keys(b'strategy')
|
||||||
self.assertEqual(len(keys['ENCRYPTION']),
|
self.assertEqual(len(keys['ENCRYPTION']),
|
||||||
len(keys['CACHE_KEY']))
|
len(keys['CACHE_KEY']))
|
||||||
self.assertEqual(len(keys['CACHE_KEY']),
|
self.assertEqual(len(keys['CACHE_KEY']),
|
||||||
@@ -43,20 +54,20 @@ class MemcacheCryptPositiveTests(testtools.TestCase):
|
|||||||
self.assertIn('strategy', keys.keys())
|
self.assertIn('strategy', keys.keys())
|
||||||
|
|
||||||
def test_key_strategy_diff(self):
|
def test_key_strategy_diff(self):
|
||||||
k1 = self._setup_keys('MAC')
|
k1 = self._setup_keys(b'MAC')
|
||||||
k2 = self._setup_keys('ENCRYPT')
|
k2 = self._setup_keys(b'ENCRYPT')
|
||||||
self.assertNotEqual(k1, k2)
|
self.assertNotEqual(k1, k2)
|
||||||
|
|
||||||
def test_sign_data(self):
|
def test_sign_data(self):
|
||||||
keys = self._setup_keys('MAC')
|
keys = self._setup_keys(b'MAC')
|
||||||
sig = memcache_crypt.sign_data(keys['MAC'], 'data')
|
sig = memcache_crypt.sign_data(keys['MAC'], b'data')
|
||||||
self.assertEqual(len(sig), memcache_crypt.DIGEST_LENGTH_B64)
|
self.assertEqual(len(sig), memcache_crypt.DIGEST_LENGTH_B64)
|
||||||
|
|
||||||
def test_encryption(self):
|
def test_encryption(self):
|
||||||
keys = self._setup_keys('ENCRYPT')
|
keys = self._setup_keys(b'ENCRYPT')
|
||||||
# what you put in is what you get out
|
# what you put in is what you get out
|
||||||
for data in ['data', '1234567890123456', '\x00\xFF' * 13
|
for data in [b'data', b'1234567890123456', b'\x00\xFF' * 13
|
||||||
] + [chr(x % 256) * x for x in range(768)]:
|
] + [six.int2byte(x % 256) * x for x in range(768)]:
|
||||||
crypt = memcache_crypt.encrypt_data(keys['ENCRYPTION'], data)
|
crypt = memcache_crypt.encrypt_data(keys['ENCRYPTION'], data)
|
||||||
decrypt = memcache_crypt.decrypt_data(keys['ENCRYPTION'], crypt)
|
decrypt = memcache_crypt.decrypt_data(keys['ENCRYPTION'], crypt)
|
||||||
self.assertEqual(data, decrypt)
|
self.assertEqual(data, decrypt)
|
||||||
@@ -65,12 +76,12 @@ class MemcacheCryptPositiveTests(testtools.TestCase):
|
|||||||
keys['ENCRYPTION'], crypt[:-1])
|
keys['ENCRYPTION'], crypt[:-1])
|
||||||
|
|
||||||
def test_protect_wrappers(self):
|
def test_protect_wrappers(self):
|
||||||
data = 'My Pretty Little Data'
|
data = b'My Pretty Little Data'
|
||||||
for strategy in ['MAC', 'ENCRYPT']:
|
for strategy in [b'MAC', b'ENCRYPT']:
|
||||||
keys = self._setup_keys(strategy)
|
keys = self._setup_keys(strategy)
|
||||||
protected = memcache_crypt.protect_data(keys, data)
|
protected = memcache_crypt.protect_data(keys, data)
|
||||||
self.assertNotEqual(protected, data)
|
self.assertNotEqual(protected, data)
|
||||||
if strategy == 'ENCRYPT':
|
if strategy == b'ENCRYPT':
|
||||||
self.assertNotIn(data, protected)
|
self.assertNotIn(data, protected)
|
||||||
unprotected = memcache_crypt.unprotect_data(keys, protected)
|
unprotected = memcache_crypt.unprotect_data(keys, protected)
|
||||||
self.assertEqual(data, unprotected)
|
self.assertEqual(data, unprotected)
|
||||||
|
Reference in New Issue
Block a user