diff --git a/keystoneclient/middleware/auth_token.py b/keystoneclient/middleware/auth_token.py index 8a6add57d..88dca1865 100644 --- a/keystoneclient/middleware/auth_token.py +++ b/keystoneclient/middleware/auth_token.py @@ -1022,6 +1022,8 @@ class AuthProtocol(object): # Note that 'invalid' and (data, expires) are the only # valid types of serialized cache entries, so there is not # a collision with jsonutils.loads(serialized) == None. + if not isinstance(serialized, six.string_types): + serialized = serialized.decode('utf-8') cached = jsonutils.loads(serialized) if cached == 'invalid': self.LOG.debug('Cached Token %s is marked unauthorized', @@ -1053,14 +1055,20 @@ class AuthProtocol(object): """ serialized_data = jsonutils.dumps(data) + if isinstance(serialized_data, six.text_type): + serialized_data = serialized_data.encode('utf-8') if self._memcache_security_strategy is None: cache_key = CACHE_KEY_TEMPLATE % token_id data_to_store = serialized_data else: + secret_key = self._memcache_secret_key + if isinstance(secret_key, six.string_types): + secret_key = secret_key.encode('utf-8') + security_strategy = self._memcache_security_strategy + if isinstance(security_strategy, six.string_types): + security_strategy = security_strategy.encode('utf-8') keys = memcache_crypt.derive_keys( - token_id, - self._memcache_secret_key, - self._memcache_security_strategy) + token_id, secret_key, security_strategy) cache_key = CACHE_KEY_TEMPLATE % memcache_crypt.get_cache_key(keys) data_to_store = memcache_crypt.protect_data(keys, serialized_data) diff --git a/keystoneclient/middleware/memcache_crypt.py b/keystoneclient/middleware/memcache_crypt.py index 878f2e944..8bae5068f 100644 --- a/keystoneclient/middleware/memcache_crypt.py +++ b/keystoneclient/middleware/memcache_crypt.py @@ -35,6 +35,8 @@ import hashlib import hmac import math import os +import six +import sys # make sure pycrypto is available try: @@ -82,19 +84,26 @@ def assert_crypto_availability(f): return wrapper -def constant_time_compare(first, second): - """Returns True if both string inputs are equal, otherwise False. +if sys.version_info >= (3, 3): + constant_time_compare = hmac.compare_digest +else: + def constant_time_compare(first, second): + """Returns True if both string inputs are equal, otherwise False. - This function should take a constant amount of time regardless of - how many characters in the strings match. + This function should take a constant amount of time regardless of + how many characters in the strings match. - """ - if len(first) != len(second): - return False - result = 0 - for x, y in zip(first, second): - result |= ord(x) ^ ord(y) - return result == 0 + """ + if len(first) != len(second): + return False + result = 0 + if six.PY3 and isinstance(first, bytes) and isinstance(second, bytes): + for x, y in zip(first, second): + result |= x ^ y + else: + for x, y in zip(first, second): + result |= ord(x) ^ ord(y) + return result == 0 def derive_keys(token, secret, strategy): @@ -132,7 +141,7 @@ def encrypt_data(key, data): iv = os.urandom(16) cipher = AES.new(key, AES.MODE_CBC, iv) padding = 16 - len(data) % 16 - return iv + cipher.encrypt(data + chr(padding) * padding) + return iv + cipher.encrypt(data + six.int2byte(padding) * padding) @assert_crypto_availability @@ -147,8 +156,7 @@ def decrypt_data(key, data): # Strip the last n padding bytes where n is the last value in # the plaintext - padding = ord(result[-1]) - return result[:-1 * padding] + return result[:-1 * six.byte2int([result[-1]])] def protect_data(keys, data): @@ -156,7 +164,7 @@ def protect_data(keys, data): protected string suitable for storage in the cache. """ - if keys['strategy'] == 'ENCRYPT': + if keys['strategy'] == b'ENCRYPT': data = encrypt_data(keys['ENCRYPTION'], data) encoded_data = base64.b64encode(data) @@ -188,7 +196,7 @@ def unprotect_data(keys, signed_data): data = base64.b64decode(signed_data[DIGEST_LENGTH_B64:]) # then if necessary decrypt the data - if keys['strategy'] == 'ENCRYPT': + if keys['strategy'] == b'ENCRYPT': data = decrypt_data(keys['ENCRYPTION'], data) return data diff --git a/keystoneclient/tests/test_auth_token_middleware.py b/keystoneclient/tests/test_auth_token_middleware.py index db7bae271..bbd64d2c2 100644 --- a/keystoneclient/tests/test_auth_token_middleware.py +++ b/keystoneclient/tests/test_auth_token_middleware.py @@ -789,7 +789,7 @@ class CommonAuthTokenMiddlewareTest(object): 'memcache_secret_key': 'mysecret' } self.set_middleware(conf=conf) - token = 'my_token' + token = b'my_token' some_time_later = timeutils.utcnow() + datetime.timedelta(hours=4) expires = timeutils.strtime(some_time_later) data = ('this_data', expires) @@ -805,7 +805,7 @@ class CommonAuthTokenMiddlewareTest(object): 'memcache_secret_key': 'mysecret' } self.set_middleware(conf=conf) - token = 'my_token' + token = b'my_token' some_time_later = timeutils.utcnow() + datetime.timedelta(hours=4) expires = timeutils.strtime(some_time_later) data = ('this_data', expires) diff --git a/keystoneclient/tests/test_memcache_crypt.py b/keystoneclient/tests/test_memcache_crypt.py index 500a50986..159898b95 100644 --- a/keystoneclient/tests/test_memcache_crypt.py +++ b/keystoneclient/tests/test_memcache_crypt.py @@ -12,6 +12,7 @@ # License for the specific language governing permissions and limitations # under the License. +import six import testtools from keystoneclient.middleware import memcache_crypt @@ -19,7 +20,7 @@ from keystoneclient.middleware import memcache_crypt class MemcacheCryptPositiveTests(testtools.TestCase): def _setup_keys(self, strategy): - return memcache_crypt.derive_keys('token', 'secret', strategy) + return memcache_crypt.derive_keys(b'token', b'secret', strategy) def test_constant_time_compare(self): # make sure it works as a compare, the "constant time" aspect @@ -32,8 +33,18 @@ class MemcacheCryptPositiveTests(testtools.TestCase): self.assertFalse(ctc('abc', 'abc\x00')) self.assertFalse(ctc('', 'abc')) + # For Python 3, we want to test these functions with both str and bytes + # as input. + if six.PY3: + self.assertTrue(ctc(b'abcd', b'abcd')) + self.assertTrue(ctc(b'', b'')) + self.assertFalse(ctc(b'abcd', b'efgh')) + self.assertFalse(ctc(b'abc', b'abcd')) + self.assertFalse(ctc(b'abc', b'abc\x00')) + self.assertFalse(ctc(b'', b'abc')) + def test_derive_keys(self): - keys = memcache_crypt.derive_keys('token', 'secret', 'strategy') + keys = self._setup_keys(b'strategy') self.assertEqual(len(keys['ENCRYPTION']), len(keys['CACHE_KEY'])) self.assertEqual(len(keys['CACHE_KEY']), @@ -43,20 +54,20 @@ class MemcacheCryptPositiveTests(testtools.TestCase): self.assertIn('strategy', keys.keys()) def test_key_strategy_diff(self): - k1 = self._setup_keys('MAC') - k2 = self._setup_keys('ENCRYPT') + k1 = self._setup_keys(b'MAC') + k2 = self._setup_keys(b'ENCRYPT') self.assertNotEqual(k1, k2) def test_sign_data(self): - keys = self._setup_keys('MAC') - sig = memcache_crypt.sign_data(keys['MAC'], 'data') + keys = self._setup_keys(b'MAC') + sig = memcache_crypt.sign_data(keys['MAC'], b'data') self.assertEqual(len(sig), memcache_crypt.DIGEST_LENGTH_B64) def test_encryption(self): - keys = self._setup_keys('ENCRYPT') + keys = self._setup_keys(b'ENCRYPT') # what you put in is what you get out - for data in ['data', '1234567890123456', '\x00\xFF' * 13 - ] + [chr(x % 256) * x for x in range(768)]: + for data in [b'data', b'1234567890123456', b'\x00\xFF' * 13 + ] + [six.int2byte(x % 256) * x for x in range(768)]: crypt = memcache_crypt.encrypt_data(keys['ENCRYPTION'], data) decrypt = memcache_crypt.decrypt_data(keys['ENCRYPTION'], crypt) self.assertEqual(data, decrypt) @@ -65,12 +76,12 @@ class MemcacheCryptPositiveTests(testtools.TestCase): keys['ENCRYPTION'], crypt[:-1]) def test_protect_wrappers(self): - data = 'My Pretty Little Data' - for strategy in ['MAC', 'ENCRYPT']: + data = b'My Pretty Little Data' + for strategy in [b'MAC', b'ENCRYPT']: keys = self._setup_keys(strategy) protected = memcache_crypt.protect_data(keys, data) self.assertNotEqual(protected, data) - if strategy == 'ENCRYPT': + if strategy == b'ENCRYPT': self.assertNotIn(data, protected) unprotected = memcache_crypt.unprotect_data(keys, protected) self.assertEqual(data, unprotected)