diff --git a/keystoneclient/middleware/auth_token.py b/keystoneclient/middleware/auth_token.py index e2628b624..e606b9c5a 100644 --- a/keystoneclient/middleware/auth_token.py +++ b/keystoneclient/middleware/auth_token.py @@ -769,6 +769,8 @@ class AuthProtocol(object): :no longer raises ServiceError since it no longer makes RPC """ + token_id = None + try: token_id = cms.cms_hash_token(user_token) cached = self._cache_get(token_id) @@ -784,12 +786,13 @@ class AuthProtocol(object): return data except NetworkError as e: self.LOG.debug('Token validation failure.', exc_info=True) - self.LOG.warn("Authorization failed for token %s", user_token) + self.LOG.warn("Authorization failed for token %s", token_id) raise InvalidUserToken('Token authorization failed') except Exception as e: self.LOG.debug('Token validation failure.', exc_info=True) - self._cache_store_invalid(user_token) - self.LOG.warn("Authorization failed for token %s", user_token) + if token_id: + self._cache_store_invalid(token_id) + self.LOG.warn("Authorization failed for token %s", token_id) raise InvalidUserToken('Token authorization failed') def _token_is_v2(self, token_info): @@ -930,20 +933,20 @@ class AuthProtocol(object): env_key = self._header_to_env_var(key) return env.get(env_key, default) - def _cache_get(self, token, ignore_expires=False): + def _cache_get(self, token_id, ignore_expires=False): """Return token information from cache. If token is invalid raise InvalidUserToken return token only if fresh (not expired). """ - if self._cache and token: + if self._cache and token_id: if self._memcache_security_strategy is None: - key = CACHE_KEY_TEMPLATE % token + key = CACHE_KEY_TEMPLATE % token_id serialized = self._cache.get(key) else: keys = memcache_crypt.derive_keys( - token, + token_id, self._memcache_secret_key, self._memcache_security_strategy) cache_key = CACHE_KEY_TEMPLATE % ( @@ -968,17 +971,18 @@ class AuthProtocol(object): # a collision with json.loads(serialized) == None. cached = json.loads(serialized) if cached == 'invalid': - self.LOG.debug('Cached Token %s is marked unauthorized', token) + self.LOG.debug('Cached Token %s is marked unauthorized', + token_id) raise InvalidUserToken('Token authorization failed') data, expires = cached if ignore_expires or time.time() < float(expires): - self.LOG.debug('Returning cached token %s', token) + self.LOG.debug('Returning cached token %s', token_id) return data else: - self.LOG.debug('Cached Token %s seems expired', token) + self.LOG.debug('Cached Token %s seems expired', token_id) - def _cache_store(self, token, data): + def _cache_store(self, token_id, data): """Store value into memcache. data may be the string 'invalid' or a tuple like (data, expires) @@ -986,11 +990,11 @@ class AuthProtocol(object): """ serialized_data = json.dumps(data) if self._memcache_security_strategy is None: - cache_key = CACHE_KEY_TEMPLATE % token + cache_key = CACHE_KEY_TEMPLATE % token_id data_to_store = serialized_data else: keys = memcache_crypt.derive_keys( - token, + token_id, self._memcache_secret_key, self._memcache_security_strategy) cache_key = CACHE_KEY_TEMPLATE % memcache_crypt.get_cache_key(keys) @@ -1025,7 +1029,7 @@ class AuthProtocol(object): raise InvalidUserToken('Token authorization failed') return expires - def _cache_put(self, token, data, expires): + def _cache_put(self, token_id, data, expires): """Put token data into the cache. Stores the parsed expire date in cache allowing @@ -1033,15 +1037,15 @@ class AuthProtocol(object): """ if self._cache: - self.LOG.debug('Storing %s token in memcache', token) - self._cache_store(token, (data, expires)) + self.LOG.debug('Storing %s token in memcache', token_id) + self._cache_store(token_id, (data, expires)) - def _cache_store_invalid(self, token): + def _cache_store_invalid(self, token_id): """Store invalid token in cache.""" if self._cache: self.LOG.debug( - 'Marking token %s as unauthorized in memcache', token) - self._cache_store(token, 'invalid') + 'Marking token %s as unauthorized in memcache', token_id) + self._cache_store(token_id, 'invalid') def cert_file_missing(self, proc_output, file_name): return (file_name in proc_output and not os.path.exists(file_name)) diff --git a/tests/test_auth_token_middleware.py b/tests/test_auth_token_middleware.py index 439522599..8ec13fe98 100644 --- a/tests/test_auth_token_middleware.py +++ b/tests/test_auth_token_middleware.py @@ -1076,9 +1076,17 @@ class AuthTokenMiddlewareTest(BaseAuthTokenMiddlewareTest): self.middleware(req.environ, self.start_fake_response) self.assertEqual(self.response_status, 401) - def test_memcache_set_invalid(self): + def test_memcache_set_invalid_uuid(self): req = webob.Request.blank('/') - token = 'invalid-token' + token = uuid.uuid4().hex + req.headers['X-Auth-Token'] = token + self.middleware(req.environ, self.start_fake_response) + self.assertRaises(auth_token.InvalidUserToken, + self._get_cached_token, token) + + def test_memcache_set_invalid_signed(self): + req = webob.Request.blank('/') + token = self.token_dict['signed_token_scoped_expired'] req.headers['X-Auth-Token'] = token self.middleware(req.environ, self.start_fake_response) self.assertRaises(auth_token.InvalidUserToken,