Python 3: call functions from memcache_crypt.py with bytes as input
These functions expect bytes as input, but in Python 3 they were given text strings. Change-Id: I39fa15b8d5d56dc536e0bd71a50cf27da3d03046
This commit is contained in:
@@ -1022,6 +1022,8 @@ class AuthProtocol(object):
|
||||
# Note that 'invalid' and (data, expires) are the only
|
||||
# valid types of serialized cache entries, so there is not
|
||||
# a collision with jsonutils.loads(serialized) == None.
|
||||
if not isinstance(serialized, six.string_types):
|
||||
serialized = serialized.decode('utf-8')
|
||||
cached = jsonutils.loads(serialized)
|
||||
if cached == 'invalid':
|
||||
self.LOG.debug('Cached Token %s is marked unauthorized',
|
||||
@@ -1053,14 +1055,20 @@ class AuthProtocol(object):
|
||||
|
||||
"""
|
||||
serialized_data = jsonutils.dumps(data)
|
||||
if isinstance(serialized_data, six.text_type):
|
||||
serialized_data = serialized_data.encode('utf-8')
|
||||
if self._memcache_security_strategy is None:
|
||||
cache_key = CACHE_KEY_TEMPLATE % token_id
|
||||
data_to_store = serialized_data
|
||||
else:
|
||||
secret_key = self._memcache_secret_key
|
||||
if isinstance(secret_key, six.string_types):
|
||||
secret_key = secret_key.encode('utf-8')
|
||||
security_strategy = self._memcache_security_strategy
|
||||
if isinstance(security_strategy, six.string_types):
|
||||
security_strategy = security_strategy.encode('utf-8')
|
||||
keys = memcache_crypt.derive_keys(
|
||||
token_id,
|
||||
self._memcache_secret_key,
|
||||
self._memcache_security_strategy)
|
||||
token_id, secret_key, security_strategy)
|
||||
cache_key = CACHE_KEY_TEMPLATE % memcache_crypt.get_cache_key(keys)
|
||||
data_to_store = memcache_crypt.protect_data(keys, serialized_data)
|
||||
|
||||
|
@@ -35,6 +35,8 @@ import hashlib
|
||||
import hmac
|
||||
import math
|
||||
import os
|
||||
import six
|
||||
import sys
|
||||
|
||||
# make sure pycrypto is available
|
||||
try:
|
||||
@@ -82,6 +84,9 @@ def assert_crypto_availability(f):
|
||||
return wrapper
|
||||
|
||||
|
||||
if sys.version_info >= (3, 3):
|
||||
constant_time_compare = hmac.compare_digest
|
||||
else:
|
||||
def constant_time_compare(first, second):
|
||||
"""Returns True if both string inputs are equal, otherwise False.
|
||||
|
||||
@@ -92,6 +97,10 @@ def constant_time_compare(first, second):
|
||||
if len(first) != len(second):
|
||||
return False
|
||||
result = 0
|
||||
if six.PY3 and isinstance(first, bytes) and isinstance(second, bytes):
|
||||
for x, y in zip(first, second):
|
||||
result |= x ^ y
|
||||
else:
|
||||
for x, y in zip(first, second):
|
||||
result |= ord(x) ^ ord(y)
|
||||
return result == 0
|
||||
@@ -132,7 +141,7 @@ def encrypt_data(key, data):
|
||||
iv = os.urandom(16)
|
||||
cipher = AES.new(key, AES.MODE_CBC, iv)
|
||||
padding = 16 - len(data) % 16
|
||||
return iv + cipher.encrypt(data + chr(padding) * padding)
|
||||
return iv + cipher.encrypt(data + six.int2byte(padding) * padding)
|
||||
|
||||
|
||||
@assert_crypto_availability
|
||||
@@ -147,8 +156,7 @@ def decrypt_data(key, data):
|
||||
|
||||
# Strip the last n padding bytes where n is the last value in
|
||||
# the plaintext
|
||||
padding = ord(result[-1])
|
||||
return result[:-1 * padding]
|
||||
return result[:-1 * six.byte2int([result[-1]])]
|
||||
|
||||
|
||||
def protect_data(keys, data):
|
||||
@@ -156,7 +164,7 @@ def protect_data(keys, data):
|
||||
protected string suitable for storage in the cache.
|
||||
|
||||
"""
|
||||
if keys['strategy'] == 'ENCRYPT':
|
||||
if keys['strategy'] == b'ENCRYPT':
|
||||
data = encrypt_data(keys['ENCRYPTION'], data)
|
||||
|
||||
encoded_data = base64.b64encode(data)
|
||||
@@ -188,7 +196,7 @@ def unprotect_data(keys, signed_data):
|
||||
data = base64.b64decode(signed_data[DIGEST_LENGTH_B64:])
|
||||
|
||||
# then if necessary decrypt the data
|
||||
if keys['strategy'] == 'ENCRYPT':
|
||||
if keys['strategy'] == b'ENCRYPT':
|
||||
data = decrypt_data(keys['ENCRYPTION'], data)
|
||||
|
||||
return data
|
||||
|
@@ -789,7 +789,7 @@ class CommonAuthTokenMiddlewareTest(object):
|
||||
'memcache_secret_key': 'mysecret'
|
||||
}
|
||||
self.set_middleware(conf=conf)
|
||||
token = 'my_token'
|
||||
token = b'my_token'
|
||||
some_time_later = timeutils.utcnow() + datetime.timedelta(hours=4)
|
||||
expires = timeutils.strtime(some_time_later)
|
||||
data = ('this_data', expires)
|
||||
@@ -805,7 +805,7 @@ class CommonAuthTokenMiddlewareTest(object):
|
||||
'memcache_secret_key': 'mysecret'
|
||||
}
|
||||
self.set_middleware(conf=conf)
|
||||
token = 'my_token'
|
||||
token = b'my_token'
|
||||
some_time_later = timeutils.utcnow() + datetime.timedelta(hours=4)
|
||||
expires = timeutils.strtime(some_time_later)
|
||||
data = ('this_data', expires)
|
||||
|
@@ -12,6 +12,7 @@
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import six
|
||||
import testtools
|
||||
|
||||
from keystoneclient.middleware import memcache_crypt
|
||||
@@ -19,7 +20,7 @@ from keystoneclient.middleware import memcache_crypt
|
||||
|
||||
class MemcacheCryptPositiveTests(testtools.TestCase):
|
||||
def _setup_keys(self, strategy):
|
||||
return memcache_crypt.derive_keys('token', 'secret', strategy)
|
||||
return memcache_crypt.derive_keys(b'token', b'secret', strategy)
|
||||
|
||||
def test_constant_time_compare(self):
|
||||
# make sure it works as a compare, the "constant time" aspect
|
||||
@@ -32,8 +33,18 @@ class MemcacheCryptPositiveTests(testtools.TestCase):
|
||||
self.assertFalse(ctc('abc', 'abc\x00'))
|
||||
self.assertFalse(ctc('', 'abc'))
|
||||
|
||||
# For Python 3, we want to test these functions with both str and bytes
|
||||
# as input.
|
||||
if six.PY3:
|
||||
self.assertTrue(ctc(b'abcd', b'abcd'))
|
||||
self.assertTrue(ctc(b'', b''))
|
||||
self.assertFalse(ctc(b'abcd', b'efgh'))
|
||||
self.assertFalse(ctc(b'abc', b'abcd'))
|
||||
self.assertFalse(ctc(b'abc', b'abc\x00'))
|
||||
self.assertFalse(ctc(b'', b'abc'))
|
||||
|
||||
def test_derive_keys(self):
|
||||
keys = memcache_crypt.derive_keys('token', 'secret', 'strategy')
|
||||
keys = self._setup_keys(b'strategy')
|
||||
self.assertEqual(len(keys['ENCRYPTION']),
|
||||
len(keys['CACHE_KEY']))
|
||||
self.assertEqual(len(keys['CACHE_KEY']),
|
||||
@@ -43,20 +54,20 @@ class MemcacheCryptPositiveTests(testtools.TestCase):
|
||||
self.assertIn('strategy', keys.keys())
|
||||
|
||||
def test_key_strategy_diff(self):
|
||||
k1 = self._setup_keys('MAC')
|
||||
k2 = self._setup_keys('ENCRYPT')
|
||||
k1 = self._setup_keys(b'MAC')
|
||||
k2 = self._setup_keys(b'ENCRYPT')
|
||||
self.assertNotEqual(k1, k2)
|
||||
|
||||
def test_sign_data(self):
|
||||
keys = self._setup_keys('MAC')
|
||||
sig = memcache_crypt.sign_data(keys['MAC'], 'data')
|
||||
keys = self._setup_keys(b'MAC')
|
||||
sig = memcache_crypt.sign_data(keys['MAC'], b'data')
|
||||
self.assertEqual(len(sig), memcache_crypt.DIGEST_LENGTH_B64)
|
||||
|
||||
def test_encryption(self):
|
||||
keys = self._setup_keys('ENCRYPT')
|
||||
keys = self._setup_keys(b'ENCRYPT')
|
||||
# what you put in is what you get out
|
||||
for data in ['data', '1234567890123456', '\x00\xFF' * 13
|
||||
] + [chr(x % 256) * x for x in range(768)]:
|
||||
for data in [b'data', b'1234567890123456', b'\x00\xFF' * 13
|
||||
] + [six.int2byte(x % 256) * x for x in range(768)]:
|
||||
crypt = memcache_crypt.encrypt_data(keys['ENCRYPTION'], data)
|
||||
decrypt = memcache_crypt.decrypt_data(keys['ENCRYPTION'], crypt)
|
||||
self.assertEqual(data, decrypt)
|
||||
@@ -65,12 +76,12 @@ class MemcacheCryptPositiveTests(testtools.TestCase):
|
||||
keys['ENCRYPTION'], crypt[:-1])
|
||||
|
||||
def test_protect_wrappers(self):
|
||||
data = 'My Pretty Little Data'
|
||||
for strategy in ['MAC', 'ENCRYPT']:
|
||||
data = b'My Pretty Little Data'
|
||||
for strategy in [b'MAC', b'ENCRYPT']:
|
||||
keys = self._setup_keys(strategy)
|
||||
protected = memcache_crypt.protect_data(keys, data)
|
||||
self.assertNotEqual(protected, data)
|
||||
if strategy == 'ENCRYPT':
|
||||
if strategy == b'ENCRYPT':
|
||||
self.assertNotIn(data, protected)
|
||||
unprotected = memcache_crypt.unprotect_data(keys, protected)
|
||||
self.assertEqual(data, unprotected)
|
||||
|
Reference in New Issue
Block a user