Python 3: call functions from memcache_crypt.py with bytes as input

These functions expect bytes as input, but in Python 3 they were given text
strings.

Change-Id: I39fa15b8d5d56dc536e0bd71a50cf27da3d03046
This commit is contained in:
Cyril Roelandt
2014-01-23 23:02:38 +01:00
parent eab811c307
commit d71b5b3460
4 changed files with 60 additions and 33 deletions

View File

@@ -1022,6 +1022,8 @@ class AuthProtocol(object):
# Note that 'invalid' and (data, expires) are the only # Note that 'invalid' and (data, expires) are the only
# valid types of serialized cache entries, so there is not # valid types of serialized cache entries, so there is not
# a collision with jsonutils.loads(serialized) == None. # a collision with jsonutils.loads(serialized) == None.
if not isinstance(serialized, six.string_types):
serialized = serialized.decode('utf-8')
cached = jsonutils.loads(serialized) cached = jsonutils.loads(serialized)
if cached == 'invalid': if cached == 'invalid':
self.LOG.debug('Cached Token %s is marked unauthorized', self.LOG.debug('Cached Token %s is marked unauthorized',
@@ -1053,14 +1055,20 @@ class AuthProtocol(object):
""" """
serialized_data = jsonutils.dumps(data) serialized_data = jsonutils.dumps(data)
if isinstance(serialized_data, six.text_type):
serialized_data = serialized_data.encode('utf-8')
if self._memcache_security_strategy is None: if self._memcache_security_strategy is None:
cache_key = CACHE_KEY_TEMPLATE % token_id cache_key = CACHE_KEY_TEMPLATE % token_id
data_to_store = serialized_data data_to_store = serialized_data
else: else:
secret_key = self._memcache_secret_key
if isinstance(secret_key, six.string_types):
secret_key = secret_key.encode('utf-8')
security_strategy = self._memcache_security_strategy
if isinstance(security_strategy, six.string_types):
security_strategy = security_strategy.encode('utf-8')
keys = memcache_crypt.derive_keys( keys = memcache_crypt.derive_keys(
token_id, token_id, secret_key, security_strategy)
self._memcache_secret_key,
self._memcache_security_strategy)
cache_key = CACHE_KEY_TEMPLATE % memcache_crypt.get_cache_key(keys) cache_key = CACHE_KEY_TEMPLATE % memcache_crypt.get_cache_key(keys)
data_to_store = memcache_crypt.protect_data(keys, serialized_data) data_to_store = memcache_crypt.protect_data(keys, serialized_data)

View File

@@ -35,6 +35,8 @@ import hashlib
import hmac import hmac
import math import math
import os import os
import six
import sys
# make sure pycrypto is available # make sure pycrypto is available
try: try:
@@ -82,6 +84,9 @@ def assert_crypto_availability(f):
return wrapper return wrapper
if sys.version_info >= (3, 3):
constant_time_compare = hmac.compare_digest
else:
def constant_time_compare(first, second): def constant_time_compare(first, second):
"""Returns True if both string inputs are equal, otherwise False. """Returns True if both string inputs are equal, otherwise False.
@@ -92,6 +97,10 @@ def constant_time_compare(first, second):
if len(first) != len(second): if len(first) != len(second):
return False return False
result = 0 result = 0
if six.PY3 and isinstance(first, bytes) and isinstance(second, bytes):
for x, y in zip(first, second):
result |= x ^ y
else:
for x, y in zip(first, second): for x, y in zip(first, second):
result |= ord(x) ^ ord(y) result |= ord(x) ^ ord(y)
return result == 0 return result == 0
@@ -132,7 +141,7 @@ def encrypt_data(key, data):
iv = os.urandom(16) iv = os.urandom(16)
cipher = AES.new(key, AES.MODE_CBC, iv) cipher = AES.new(key, AES.MODE_CBC, iv)
padding = 16 - len(data) % 16 padding = 16 - len(data) % 16
return iv + cipher.encrypt(data + chr(padding) * padding) return iv + cipher.encrypt(data + six.int2byte(padding) * padding)
@assert_crypto_availability @assert_crypto_availability
@@ -147,8 +156,7 @@ def decrypt_data(key, data):
# Strip the last n padding bytes where n is the last value in # Strip the last n padding bytes where n is the last value in
# the plaintext # the plaintext
padding = ord(result[-1]) return result[:-1 * six.byte2int([result[-1]])]
return result[:-1 * padding]
def protect_data(keys, data): def protect_data(keys, data):
@@ -156,7 +164,7 @@ def protect_data(keys, data):
protected string suitable for storage in the cache. protected string suitable for storage in the cache.
""" """
if keys['strategy'] == 'ENCRYPT': if keys['strategy'] == b'ENCRYPT':
data = encrypt_data(keys['ENCRYPTION'], data) data = encrypt_data(keys['ENCRYPTION'], data)
encoded_data = base64.b64encode(data) encoded_data = base64.b64encode(data)
@@ -188,7 +196,7 @@ def unprotect_data(keys, signed_data):
data = base64.b64decode(signed_data[DIGEST_LENGTH_B64:]) data = base64.b64decode(signed_data[DIGEST_LENGTH_B64:])
# then if necessary decrypt the data # then if necessary decrypt the data
if keys['strategy'] == 'ENCRYPT': if keys['strategy'] == b'ENCRYPT':
data = decrypt_data(keys['ENCRYPTION'], data) data = decrypt_data(keys['ENCRYPTION'], data)
return data return data

View File

@@ -789,7 +789,7 @@ class CommonAuthTokenMiddlewareTest(object):
'memcache_secret_key': 'mysecret' 'memcache_secret_key': 'mysecret'
} }
self.set_middleware(conf=conf) self.set_middleware(conf=conf)
token = 'my_token' token = b'my_token'
some_time_later = timeutils.utcnow() + datetime.timedelta(hours=4) some_time_later = timeutils.utcnow() + datetime.timedelta(hours=4)
expires = timeutils.strtime(some_time_later) expires = timeutils.strtime(some_time_later)
data = ('this_data', expires) data = ('this_data', expires)
@@ -805,7 +805,7 @@ class CommonAuthTokenMiddlewareTest(object):
'memcache_secret_key': 'mysecret' 'memcache_secret_key': 'mysecret'
} }
self.set_middleware(conf=conf) self.set_middleware(conf=conf)
token = 'my_token' token = b'my_token'
some_time_later = timeutils.utcnow() + datetime.timedelta(hours=4) some_time_later = timeutils.utcnow() + datetime.timedelta(hours=4)
expires = timeutils.strtime(some_time_later) expires = timeutils.strtime(some_time_later)
data = ('this_data', expires) data = ('this_data', expires)

View File

@@ -12,6 +12,7 @@
# License for the specific language governing permissions and limitations # License for the specific language governing permissions and limitations
# under the License. # under the License.
import six
import testtools import testtools
from keystoneclient.middleware import memcache_crypt from keystoneclient.middleware import memcache_crypt
@@ -19,7 +20,7 @@ from keystoneclient.middleware import memcache_crypt
class MemcacheCryptPositiveTests(testtools.TestCase): class MemcacheCryptPositiveTests(testtools.TestCase):
def _setup_keys(self, strategy): def _setup_keys(self, strategy):
return memcache_crypt.derive_keys('token', 'secret', strategy) return memcache_crypt.derive_keys(b'token', b'secret', strategy)
def test_constant_time_compare(self): def test_constant_time_compare(self):
# make sure it works as a compare, the "constant time" aspect # make sure it works as a compare, the "constant time" aspect
@@ -32,8 +33,18 @@ class MemcacheCryptPositiveTests(testtools.TestCase):
self.assertFalse(ctc('abc', 'abc\x00')) self.assertFalse(ctc('abc', 'abc\x00'))
self.assertFalse(ctc('', 'abc')) self.assertFalse(ctc('', 'abc'))
# For Python 3, we want to test these functions with both str and bytes
# as input.
if six.PY3:
self.assertTrue(ctc(b'abcd', b'abcd'))
self.assertTrue(ctc(b'', b''))
self.assertFalse(ctc(b'abcd', b'efgh'))
self.assertFalse(ctc(b'abc', b'abcd'))
self.assertFalse(ctc(b'abc', b'abc\x00'))
self.assertFalse(ctc(b'', b'abc'))
def test_derive_keys(self): def test_derive_keys(self):
keys = memcache_crypt.derive_keys('token', 'secret', 'strategy') keys = self._setup_keys(b'strategy')
self.assertEqual(len(keys['ENCRYPTION']), self.assertEqual(len(keys['ENCRYPTION']),
len(keys['CACHE_KEY'])) len(keys['CACHE_KEY']))
self.assertEqual(len(keys['CACHE_KEY']), self.assertEqual(len(keys['CACHE_KEY']),
@@ -43,20 +54,20 @@ class MemcacheCryptPositiveTests(testtools.TestCase):
self.assertIn('strategy', keys.keys()) self.assertIn('strategy', keys.keys())
def test_key_strategy_diff(self): def test_key_strategy_diff(self):
k1 = self._setup_keys('MAC') k1 = self._setup_keys(b'MAC')
k2 = self._setup_keys('ENCRYPT') k2 = self._setup_keys(b'ENCRYPT')
self.assertNotEqual(k1, k2) self.assertNotEqual(k1, k2)
def test_sign_data(self): def test_sign_data(self):
keys = self._setup_keys('MAC') keys = self._setup_keys(b'MAC')
sig = memcache_crypt.sign_data(keys['MAC'], 'data') sig = memcache_crypt.sign_data(keys['MAC'], b'data')
self.assertEqual(len(sig), memcache_crypt.DIGEST_LENGTH_B64) self.assertEqual(len(sig), memcache_crypt.DIGEST_LENGTH_B64)
def test_encryption(self): def test_encryption(self):
keys = self._setup_keys('ENCRYPT') keys = self._setup_keys(b'ENCRYPT')
# what you put in is what you get out # what you put in is what you get out
for data in ['data', '1234567890123456', '\x00\xFF' * 13 for data in [b'data', b'1234567890123456', b'\x00\xFF' * 13
] + [chr(x % 256) * x for x in range(768)]: ] + [six.int2byte(x % 256) * x for x in range(768)]:
crypt = memcache_crypt.encrypt_data(keys['ENCRYPTION'], data) crypt = memcache_crypt.encrypt_data(keys['ENCRYPTION'], data)
decrypt = memcache_crypt.decrypt_data(keys['ENCRYPTION'], crypt) decrypt = memcache_crypt.decrypt_data(keys['ENCRYPTION'], crypt)
self.assertEqual(data, decrypt) self.assertEqual(data, decrypt)
@@ -65,12 +76,12 @@ class MemcacheCryptPositiveTests(testtools.TestCase):
keys['ENCRYPTION'], crypt[:-1]) keys['ENCRYPTION'], crypt[:-1])
def test_protect_wrappers(self): def test_protect_wrappers(self):
data = 'My Pretty Little Data' data = b'My Pretty Little Data'
for strategy in ['MAC', 'ENCRYPT']: for strategy in [b'MAC', b'ENCRYPT']:
keys = self._setup_keys(strategy) keys = self._setup_keys(strategy)
protected = memcache_crypt.protect_data(keys, data) protected = memcache_crypt.protect_data(keys, data)
self.assertNotEqual(protected, data) self.assertNotEqual(protected, data)
if strategy == 'ENCRYPT': if strategy == b'ENCRYPT':
self.assertNotIn(data, protected) self.assertNotIn(data, protected)
unprotected = memcache_crypt.unprotect_data(keys, protected) unprotected = memcache_crypt.unprotect_data(keys, protected)
self.assertEqual(data, unprotected) self.assertEqual(data, unprotected)