Regenerate trust when update with different user

Regenerate trust when update stack with different user
We will regenerate (and delete old trust) when user credential
doesn't match with current context (means different user is
operating).

Story: #1752347
Task: #17352

Change-Id: I39795bdbd8ab255150153bf8b1e165b49e1a7027
This commit is contained in:
ricolin 2018-03-28 21:33:18 +08:00 committed by Rico Lin
parent 4a707e45f5
commit d8efcd1780
12 changed files with 258 additions and 43 deletions

View File

@ -91,6 +91,13 @@ class FakeKeystoneClient(object):
trust_id='atrust', trust_id='atrust',
trustor_user_id=self.user_id) trustor_user_id=self.user_id)
def regenerate_trust_context(self):
return context.RequestContext(username=self.username,
password=self.password,
is_admin=False,
trust_id='atrust',
trustor_user_id=self.user_id)
def delete_trust(self, trust_id): def delete_trust(self, trust_id):
pass pass

View File

@ -188,19 +188,7 @@ class KsClientWrapper(object):
return client return client
def create_trust_context(self): def _create_trust_context(self, trustor_user_id, trustor_proj_id):
"""Create a trust using the trustor identity in the current context.
The trust is created with the trustee as the heat service user.
If the current context already contains a trust_id, we do nothing
and return the current context.
Returns a context containing the new trust_id.
"""
if self.context.trust_id:
return self.context
# We need the service admin user ID (not name), as the trustor user # We need the service admin user ID (not name), as the trustor user
# can't lookup the ID in keystoneclient unless they're admin # can't lookup the ID in keystoneclient unless they're admin
# workaround this by getting the user_id from admin_client # workaround this by getting the user_id from admin_client
@ -211,9 +199,6 @@ class KsClientWrapper(object):
LOG.error("Domain admin client authentication failed") LOG.error("Domain admin client authentication failed")
raise exception.AuthorizationFailure() raise exception.AuthorizationFailure()
trustor_user_id = self.context.auth_plugin.get_user_id(self.session)
trustor_proj_id = self.context.auth_plugin.get_project_id(self.session)
role_kw = {} role_kw = {}
# inherit the roles of the trustor, unless set trusts_delegated_roles # inherit the roles of the trustor, unless set trusts_delegated_roles
if cfg.CONF.trusts_delegated_roles: if cfg.CONF.trusts_delegated_roles:
@ -245,6 +230,23 @@ class KsClientWrapper(object):
trust_context.trustor_user_id = trustor_user_id trust_context.trustor_user_id = trustor_user_id
return trust_context return trust_context
def create_trust_context(self):
"""Create a trust using the trustor identity in the current context.
The trust is created with the trustee as the heat service user.
If the current context already contains a trust_id, we do nothing
and return the current context.
Returns a context containing the new trust_id.
"""
if self.context.trust_id:
return self.context
trustor_user_id = self.context.auth_plugin.get_user_id(self.session)
trustor_proj_id = self.context.auth_plugin.get_project_id(self.session)
return self._create_trust_context(trustor_user_id, trustor_proj_id)
def delete_trust(self, trust_id): def delete_trust(self, trust_id):
"""Delete the specified trust.""" """Delete the specified trust."""
try: try:
@ -252,6 +254,23 @@ class KsClientWrapper(object):
except (ks_exception.NotFound, ks_exception.Unauthorized): except (ks_exception.NotFound, ks_exception.Unauthorized):
pass pass
def regenerate_trust_context(self):
"""Regenerate a trust using the trustor identity of current user_id.
The trust is created with the trustee as the heat service user.
Returns a context containing the new trust_id.
"""
old_trust_id = self.context.trust_id
trustor_user_id = self.context.auth_plugin.get_user_id(self.session)
trustor_proj_id = self.context.auth_plugin.get_project_id(self.session)
trust_context = self._create_trust_context(trustor_user_id,
trustor_proj_id)
if old_trust_id:
self.delete_trust(old_trust_id)
return trust_context
def _get_username(self, username): def _get_username(self, username):
if(len(username) > 255): if(len(username) > 255):
LOG.warning("Truncating the username %s to the last 255 " LOG.warning("Truncating the username %s to the last 255 "

View File

@ -1022,9 +1022,11 @@ class EngineService(service.ServiceBase):
LOG.info('Updating stack %s', db_stack.name) LOG.info('Updating stack %s', db_stack.name)
if cfg.CONF.reauthentication_auth_method == 'trusts': if cfg.CONF.reauthentication_auth_method == 'trusts':
current_stack = parser.Stack.load( current_stack = parser.Stack.load(
cnxt, stack=db_stack, use_stored_context=True) cnxt, stack=db_stack, use_stored_context=True,
check_refresh_cred=True)
else: else:
current_stack = parser.Stack.load(cnxt, stack=db_stack) current_stack = parser.Stack.load(cnxt, stack=db_stack,
check_refresh_cred=True)
self.resource_enforcer.enforce_stack(current_stack, self.resource_enforcer.enforce_stack(current_stack,
is_registered_policy=True) is_registered_policy=True)

View File

@ -126,7 +126,7 @@ class Stack(collections.Mapping):
nested_depth=0, strict_validate=True, convergence=False, nested_depth=0, strict_validate=True, convergence=False,
current_traversal=None, tags=None, prev_raw_template_id=None, current_traversal=None, tags=None, prev_raw_template_id=None,
current_deps=None, cache_data=None, current_deps=None, cache_data=None,
deleted_time=None, converge=False): deleted_time=None, converge=False, refresh_cred=False):
"""Initialise the Stack. """Initialise the Stack.
@ -188,6 +188,9 @@ class Stack(collections.Mapping):
self.thread_group_mgr = None self.thread_group_mgr = None
self.converge = converge self.converge = converge
# This flag is use to check whether credential needs to refresh or not
self.refresh_cred = refresh_cred
# strict_validate can be used to disable value validation # strict_validate can be used to disable value validation
# in the resource properties schema, this is useful when # in the resource properties schema, this is useful when
# performing validation when properties reference attributes # performing validation when properties reference attributes
@ -541,10 +544,35 @@ class Stack(collections.Mapping):
{'res': str(res), {'res': str(res),
'err': str(exc)}) 'err': str(exc)})
@classmethod
def _check_refresh_cred(cls, context, stack):
if stack.user_creds_id:
creds_obj = ucreds_object.UserCreds.get_by_id(
context, stack.user_creds_id)
creds = creds_obj.obj_to_primitive()["versioned_object.data"]
stored_context = common_context.StoredContext.from_dict(creds)
if cfg.CONF.deferred_auth_method == 'trusts':
old_trustor_proj_id = stored_context.tenant_id
old_trustor_user_id = stored_context.trustor_user_id
trustor_user_id = context.auth_plugin.get_user_id(
context.clients.client('keystone').session)
trustor_proj_id = context.auth_plugin.get_project_id(
context.clients.client('keystone').session)
return False if (
old_trustor_user_id == trustor_user_id) and (
old_trustor_proj_id == trustor_proj_id
) else True
# Should not raise error or allow refresh credential when we can't find
# user_creds_id in stack
return False
@classmethod @classmethod
def load(cls, context, stack_id=None, stack=None, show_deleted=True, def load(cls, context, stack_id=None, stack=None, show_deleted=True,
use_stored_context=False, force_reload=False, cache_data=None, use_stored_context=False, force_reload=False, cache_data=None,
load_template=True): load_template=True, check_refresh_cred=False):
"""Retrieve a Stack from the database.""" """Retrieve a Stack from the database."""
if stack is None: if stack is None:
stack = stack_object.Stack.get_by_id( stack = stack_object.Stack.get_by_id(
@ -555,13 +583,22 @@ class Stack(collections.Mapping):
message = _('No stack exists with id "%s"') % str(stack_id) message = _('No stack exists with id "%s"') % str(stack_id)
raise exception.NotFound(message) raise exception.NotFound(message)
refresh_cred = False
if check_refresh_cred and (
cfg.CONF.deferred_auth_method == 'trusts'
):
if cls._check_refresh_cred(context, stack):
use_stored_context = False
refresh_cred = True
if force_reload: if force_reload:
stack.refresh() stack.refresh()
return cls._from_db(context, stack, return cls._from_db(context, stack,
use_stored_context=use_stored_context, use_stored_context=use_stored_context,
cache_data=cache_data, cache_data=cache_data,
load_template=load_template) load_template=load_template,
refresh_cred=refresh_cred)
@classmethod @classmethod
def load_all(cls, context, limit=None, marker=None, sort_keys=None, def load_all(cls, context, limit=None, marker=None, sort_keys=None,
@ -595,7 +632,7 @@ class Stack(collections.Mapping):
@classmethod @classmethod
def _from_db(cls, context, stack, def _from_db(cls, context, stack,
use_stored_context=False, cache_data=None, use_stored_context=False, cache_data=None,
load_template=True): load_template=True, refresh_cred=False):
if load_template: if load_template:
template = tmpl.Template.load( template = tmpl.Template.load(
context, stack.raw_template_id, stack.raw_template) context, stack.raw_template_id, stack.raw_template)
@ -619,7 +656,8 @@ class Stack(collections.Mapping):
prev_raw_template_id=stack.prev_raw_template_id, prev_raw_template_id=stack.prev_raw_template_id,
current_deps=stack.current_deps, cache_data=cache_data, current_deps=stack.current_deps, cache_data=cache_data,
nested_depth=stack.nested_depth, nested_depth=stack.nested_depth,
deleted_time=stack.deleted_at) deleted_time=stack.deleted_at,
refresh_cred=refresh_cred)
def get_kwargs_for_cloning(self, keep_status=False, only_db=False, def get_kwargs_for_cloning(self, keep_status=False, only_db=False,
keep_tags=False): keep_tags=False):
@ -687,6 +725,17 @@ class Stack(collections.Mapping):
s['raw_template_id'] = self.t.id s['raw_template_id'] = self.t.id
if self.id is not None: if self.id is not None:
if self.refresh_cred:
keystone = self.clients.client('keystone')
trust_ctx = keystone.regenerate_trust_context()
new_creds = ucreds_object.UserCreds.create(trust_ctx)
s['user_creds_id'] = new_creds.id
self._delete_user_cred(raise_keystone_exception=True)
self.user_creds_id = new_creds.id
self.refresh_cred = False
if exp_trvsl is None and not ignore_traversal_check: if exp_trvsl is None and not ignore_traversal_check:
exp_trvsl = self.current_traversal exp_trvsl = self.current_traversal
@ -1840,11 +1889,10 @@ class Stack(collections.Mapping):
LOG.exception("Failed to retrieve user_creds") LOG.exception("Failed to retrieve user_creds")
return None return None
def _delete_credentials(self, stack_status, reason, abandon): def _delete_user_cred(self, stack_status=None, reason=None,
raise_keystone_exception=False):
# Cleanup stored user_creds so they aren't accessible via # Cleanup stored user_creds so they aren't accessible via
# the soft-deleted stack which remains in the DB # the soft-deleted stack which remains in the DB
# The stack_status and reason passed in are current values, which
# may get rewritten and returned from this method
if self.user_creds_id: if self.user_creds_id:
user_creds = self._try_get_user_creds() user_creds = self._try_get_user_creds()
# If we created a trust, delete it # If we created a trust, delete it
@ -1874,6 +1922,8 @@ class Stack(collections.Mapping):
# Without this, they would need to issue # Without this, they would need to issue
# an additional stack-delete # an additional stack-delete
LOG.exception("Error deleting trust") LOG.exception("Error deleting trust")
if raise_keystone_exception:
raise
# Delete the stored credentials # Delete the stored credentials
try: try:
@ -1883,13 +1933,18 @@ class Stack(collections.Mapping):
LOG.info("Tried to delete user_creds that do not exist " LOG.info("Tried to delete user_creds that do not exist "
"(stack=%(stack)s user_creds_id=%(uc)s)", "(stack=%(stack)s user_creds_id=%(uc)s)",
{'stack': self.id, 'uc': self.user_creds_id}) {'stack': self.id, 'uc': self.user_creds_id})
self.user_creds_id = None
return stack_status, reason
try: def _delete_credentials(self, stack_status, reason, abandon):
self.user_creds_id = None # The stack_status and reason passed in are current values, which
self.store() # may get rewritten and returned from this method
except exception.NotFound: stack_status, reason = self._delete_user_cred(stack_status, reason)
LOG.info("Tried to store a stack that does not exist %s", try:
self.id) self.store()
except exception.NotFound:
LOG.info("Tried to store a stack that does not exist %s",
self.id)
# If the stack has a domain project, delete it # If the stack has a domain project, delete it
if self.stack_user_project_id and not abandon: if self.stack_user_project_id and not abandon:

View File

@ -521,6 +521,65 @@ class KeystoneClientTest(common.HeatTestCase):
self.assertRaises(exception.AuthorizationFailure, self.assertRaises(exception.AuthorizationFailure,
heat_keystoneclient.KeystoneClient, ctx) heat_keystoneclient.KeystoneClient, ctx)
def test_regenerate_trust_context_with_no_exist_trust_id(self):
"""Test regenerate_trust_context."""
class MockTrust(object):
id = 'dtrust123'
mock_ks_auth, mock_auth_ref = self._stubs_auth(user_id='5678',
project_id='42',
stub_trust_context=True,
stub_admin_auth=True)
cfg.CONF.set_override('deferred_auth_method', 'trusts')
trustor_roles = ['heat_stack_owner', 'admin', '__member__']
trustee_roles = trustor_roles
mock_auth_ref.user_id = '5678'
mock_auth_ref.project_id = '42'
self.mock_ks_v3_client.trusts.create.return_value = MockTrust()
ctx = utils.dummy_context(roles=trustor_roles)
ctx.trust_id = None
heat_ks_client = heat_keystoneclient.KeystoneClient(ctx)
trust_context = heat_ks_client.regenerate_trust_context()
self.assertEqual('dtrust123', trust_context.trust_id)
self.assertEqual('5678', trust_context.trustor_user_id)
ks_loading.load_auth_from_conf_options.assert_called_once_with(
cfg.CONF, 'trustee', trust_id=None)
self.mock_ks_v3_client.trusts.create.assert_called_once_with(
trustor_user='5678',
trustee_user='1234',
project='42',
impersonation=True,
allow_redelegation=False,
role_names=trustee_roles)
self.assertEqual(0, self.mock_ks_v3_client.trusts.delete.call_count)
def test_regenerate_trust_context_with_exist_trust_id(self):
"""Test regenerate_trust_context."""
self._stubs_auth(method='trust')
cfg.CONF.set_override('deferred_auth_method', 'trusts')
ctx = utils.dummy_context()
ctx.trust_id = 'atrust123'
ctx.trustor_user_id = 'trustor_user_id'
class MockTrust(object):
id = 'dtrust123'
self.mock_ks_v3_client.trusts.create.return_value = MockTrust()
heat_ks_client = heat_keystoneclient.KeystoneClient(ctx)
trust_context = heat_ks_client.regenerate_trust_context()
self.assertEqual('dtrust123', trust_context.trust_id)
self.mock_ks_v3_client.trusts.delete.assert_called_once_with(
ctx.trust_id)
def test_create_trust_context_trust_id(self): def test_create_trust_context_trust_id(self):
"""Test create_trust_context with existing trust_id.""" """Test create_trust_context with existing trust_id."""

View File

@ -11,13 +11,15 @@
# License for the specific language governing permissions and limitations # License for the specific language governing permissions and limitations
# under the License. # under the License.
from oslo_config import cfg
from heat.common import context
from heat.engine import resource from heat.engine import resource
from heat.tests import common from heat.tests import common
from heat.tests.convergence.framework import fake_resource from heat.tests.convergence.framework import fake_resource
from heat.tests.convergence.framework import processes from heat.tests.convergence.framework import processes
from heat.tests.convergence.framework import scenario from heat.tests.convergence.framework import scenario
from heat.tests.convergence.framework import testutils from heat.tests.convergence.framework import testutils
from oslo_config import cfg
class ScenarioTest(common.HeatTestCase): class ScenarioTest(common.HeatTestCase):
@ -27,6 +29,7 @@ class ScenarioTest(common.HeatTestCase):
def setUp(self): def setUp(self):
super(ScenarioTest, self).setUp() super(ScenarioTest, self).setUp()
self.patchobject(context, 'StoredContext')
resource._register_class('OS::Heat::TestResource', resource._register_class('OS::Heat::TestResource',
fake_resource.TestResource) fake_resource.TestResource)
self.procs = processes.Processes() self.procs = processes.Processes()

View File

@ -159,6 +159,7 @@ class StackServiceUpdateActionsNotSupportedTest(common.HeatTestCase):
self.ctx, old_stack.identifier(), template, self.ctx, old_stack.identifier(), template,
params, None, {}) params, None, {})
self.assertEqual(exception.NotSupported, ex.exc_info[0]) self.assertEqual(exception.NotSupported, ex.exc_info[0])
mock_load.assert_called_once_with(self.ctx, stack=s) mock_load.assert_called_once_with(self.ctx, stack=s,
check_refresh_cred=True)
old_stack.delete() old_stack.delete()

View File

@ -16,6 +16,7 @@ from unittest import mock
from oslo_config import cfg from oslo_config import cfg
from oslo_messaging import conffixture from oslo_messaging import conffixture
from heat.common import context
from heat.engine import resource as res from heat.engine import resource as res
from heat.engine.resources.aws.ec2 import instance as instances from heat.engine.resources.aws.ec2 import instance as instances
from heat.engine import service from heat.engine import service
@ -32,6 +33,7 @@ class StackEventTest(common.HeatTestCase):
def setUp(self): def setUp(self):
super(StackEventTest, self).setUp() super(StackEventTest, self).setUp()
self.patchobject(context, 'StoredContext')
self.ctx = utils.dummy_context(tenant_id='stack_event_test_tenant') self.ctx = utils.dummy_context(tenant_id='stack_event_test_tenant')
self.eng = service.EngineService('a-host', 'a-topic') self.eng = service.EngineService('a-host', 'a-topic')

View File

@ -18,6 +18,7 @@ from oslo_config import cfg
from oslo_messaging import conffixture from oslo_messaging import conffixture
from oslo_messaging.rpc import dispatcher from oslo_messaging.rpc import dispatcher
from heat.common import context
from heat.common import environment_util as env_util from heat.common import environment_util as env_util
from heat.common import exception from heat.common import exception
from heat.common import messaging from heat.common import messaging
@ -45,6 +46,7 @@ class ServiceStackUpdateTest(common.HeatTestCase):
def setUp(self): def setUp(self):
super(ServiceStackUpdateTest, self).setUp() super(ServiceStackUpdateTest, self).setUp()
self.useFixture(conffixture.ConfFixture(cfg.CONF)) self.useFixture(conffixture.ConfFixture(cfg.CONF))
self.patchobject(context, 'StoredContext')
self.ctx = utils.dummy_context() self.ctx = utils.dummy_context()
self.man = service.EngineService('a-host', 'a-topic') self.man = service.EngineService('a-host', 'a-topic')
self.man.thread_group_mgr = tools.DummyThreadGroupManager() self.man.thread_group_mgr = tools.DummyThreadGroupManager()
@ -103,7 +105,8 @@ class ServiceStackUpdateTest(common.HeatTestCase):
username='test_username', username='test_username',
converge=True converge=True
) )
mock_load.assert_called_once_with(self.ctx, stack=s) mock_load.assert_called_once_with(self.ctx, stack=s,
check_refresh_cred=True)
mock_validate.assert_called_once_with() mock_validate.assert_called_once_with()
def _test_stack_update_with_environment_files(self, stack_name, def _test_stack_update_with_environment_files(self, stack_name,
@ -222,7 +225,8 @@ class ServiceStackUpdateTest(common.HeatTestCase):
username='test_username', username='test_username',
converge=False converge=False
) )
mock_load.assert_called_once_with(self.ctx, stack=s) mock_load.assert_called_once_with(self.ctx, stack=s,
check_refresh_cred=True)
mock_validate.assert_called_once_with() mock_validate.assert_called_once_with()
def test_stack_update_existing_parameters(self): def test_stack_update_existing_parameters(self):
@ -555,7 +559,8 @@ resources:
mock_validate.assert_called_once_with() mock_validate.assert_called_once_with()
mock_tmpl.assert_called_once_with(template, files=None) mock_tmpl.assert_called_once_with(template, files=None)
mock_env.assert_called_once_with(params) mock_env.assert_called_once_with(params)
mock_load.assert_called_once_with(self.ctx, stack=s) mock_load.assert_called_once_with(self.ctx, stack=s,
check_refresh_cred=True)
mock_stack.assert_called_once_with( mock_stack.assert_called_once_with(
self.ctx, stk.name, stk.t, self.ctx, stk.name, stk.t,
convergence=False, convergence=False,
@ -703,7 +708,8 @@ resources:
username='test_username', username='test_username',
converge=False converge=False
) )
mock_load.assert_called_once_with(self.ctx, stack=s) mock_load.assert_called_once_with(self.ctx, stack=s,
check_refresh_cred=True)
mock_validate.assert_called_once_with() mock_validate.assert_called_once_with()
def test_stack_update_stack_id_equal(self): def test_stack_update_stack_id_equal(self):
@ -750,9 +756,11 @@ resources:
old_stack['A'].properties['Foo']) old_stack['A'].properties['Foo'])
self.assertEqual(create_stack['A'].id, old_stack['A'].id) self.assertEqual(create_stack['A'].id, old_stack['A'].id)
mock_load.assert_called_once_with(self.ctx, stack=s) mock_load.assert_called_once_with(self.ctx, stack=s,
check_refresh_cred=True)
def test_stack_update_exceeds_resource_limit(self): def test_stack_update_exceeds_resource_limit(self):
self.patchobject(context, 'StoredContext')
stack_name = 'test_stack_update_exceeds_resource_limit' stack_name = 'test_stack_update_exceeds_resource_limit'
params = {} params = {}
tpl = {'HeatTemplateFormatVersion': '2012-12-12', tpl = {'HeatTemplateFormatVersion': '2012-12-12',
@ -822,7 +830,8 @@ resources:
username='test_username', username='test_username',
converge=False converge=False
) )
mock_load.assert_called_once_with(self.ctx, stack=s) mock_load.assert_called_once_with(self.ctx, stack=s,
check_refresh_cred=True)
mock_validate.assert_called_once_with() mock_validate.assert_called_once_with()
def test_stack_update_nonexist(self): def test_stack_update_nonexist(self):
@ -886,7 +895,8 @@ resources:
user_creds_id=u'1', username='test_username', user_creds_id=u'1', username='test_username',
converge=False converge=False
) )
mock_load.assert_called_once_with(self.ctx, stack=s) mock_load.assert_called_once_with(self.ctx, stack=s,
check_refresh_cred=True)
def test_stack_update_existing_template(self): def test_stack_update_existing_template(self):
'''Update a stack using the same template.''' '''Update a stack using the same template.'''

View File

@ -289,7 +289,7 @@ class StackConvergenceServiceCreateUpdateTest(common.HeatTestCase):
self.assertIsInstance(result, dict) self.assertIsInstance(result, dict)
self.assertTrue(result['stack_id']) self.assertTrue(result['stack_id'])
parser.Stack.load.assert_called_once_with( parser.Stack.load.assert_called_once_with(
self.ctx, stack=mock.ANY) self.ctx, stack=mock.ANY, check_refresh_cred=True)
templatem.Template.assert_called_once_with(template, files=None) templatem.Template.assert_called_once_with(template, files=None)
environment.Environment.assert_called_once_with(params) environment.Environment.assert_called_once_with(params)

View File

@ -478,7 +478,7 @@ class StackTest(common.HeatTestCase):
prev_raw_template_id=None, prev_raw_template_id=None,
current_deps=None, cache_data=None, current_deps=None, cache_data=None,
nested_depth=0, nested_depth=0,
deleted_time=None) deleted_time=None, refresh_cred=False)
template.Template.load.assert_called_once_with( template.Template.load.assert_called_once_with(
self.ctx, stk.raw_template_id, stk.raw_template) self.ctx, stk.raw_template_id, stk.raw_template)
@ -1630,6 +1630,31 @@ class StackTest(common.HeatTestCase):
saved_stack = stack.Stack.load(self.ctx, stack_id=stack_ownee.id) saved_stack = stack.Stack.load(self.ctx, stack_id=stack_ownee.id)
self.assertEqual(self.stack.id, saved_stack.owner_id) self.assertEqual(self.stack.id, saved_stack.owner_id)
def _test_load_with_refresh_cred(self, refresh=True):
cfg.CONF.set_override('deferred_auth_method', 'trusts')
self.patchobject(self.ctx.auth_plugin, 'get_user_id',
return_value='old_trustor_user_id')
self.patchobject(self.ctx.auth_plugin, 'get_project_id',
return_value='test_tenant_id')
old_context = utils.dummy_context()
old_context.trust_id = 'atrust123'
old_context.trustor_user_id = (
'trustor_user_id' if refresh else 'old_trustor_user_id')
m_sc = self.patchobject(context, 'StoredContext')
m_sc.from_dict.return_value = old_context
self.stack = stack.Stack(self.ctx, 'test_regenerate_trust', self.tmpl)
self.stack.store()
load_stack = stack.Stack.load(self.ctx, stack_id=self.stack.id,
check_refresh_cred=True)
self.assertEqual(refresh, load_stack.refresh_cred)
def test_load_with_refresh_cred(self):
self._test_load_with_refresh_cred()
def test_load_with_no_refresh_cred(self):
self._test_load_with_refresh_cred(refresh=False)
def test_requires_deferred_auth(self): def test_requires_deferred_auth(self):
tmpl = {'HeatTemplateFormatVersion': '2012-12-12', tmpl = {'HeatTemplateFormatVersion': '2012-12-12',
'Resources': {'AResource': {'Type': 'GenericResourceType'}, 'Resources': {'AResource': {'Type': 'GenericResourceType'},

View File

@ -18,6 +18,7 @@ from unittest import mock
from heat.common import exception from heat.common import exception
from heat.common import template_format from heat.common import template_format
from heat.db.sqlalchemy import api as db_api from heat.db.sqlalchemy import api as db_api
from heat.engine.clients.os.keystone import fake_keystoneclient
from heat.engine import environment from heat.engine import environment
from heat.engine import resource from heat.engine import resource
from heat.engine import rsrc_defn from heat.engine import rsrc_defn
@ -72,6 +73,37 @@ class StackUpdateTest(common.HeatTestCase):
self.assertRaises(exception.NotFound, self.assertRaises(exception.NotFound,
db_api.raw_template_get, self.ctx, raw_template_id) db_api.raw_template_get, self.ctx, raw_template_id)
def test_update_with_refresh_creds(self):
tmpl = {'HeatTemplateFormatVersion': '2012-12-12',
'Resources': {'AResource': {'Type': 'GenericResourceType'}}}
self.stack = stack.Stack(self.ctx, 'update_test_stack',
template.Template(tmpl))
self.stack.store()
self.stack.create()
self.assertEqual((stack.Stack.CREATE, stack.Stack.COMPLETE),
self.stack.state)
tmpl2 = {'HeatTemplateFormatVersion': '2012-12-12',
'Resources': {
'AResource': {'Type': 'GenericResourceType'},
'BResource': {'Type': 'GenericResourceType'}}}
updated_stack = stack.Stack(self.ctx, 'updated_stack',
template.Template(tmpl2))
old_user_creds_id = self.stack.user_creds_id
self.stack.refresh_cred = True
self.stack.context.user_id = '5678'
mock_del_trust = self.patchobject(
fake_keystoneclient.FakeKeystoneClient, 'delete_trust')
self.stack.update(updated_stack)
self.assertEqual((stack.Stack.UPDATE, stack.Stack.COMPLETE),
self.stack.state)
self.assertEqual(1, mock_del_trust.call_count)
self.assertNotEqual(self.stack.user_creds_id, old_user_creds_id)
def test_update_remove(self): def test_update_remove(self):
tmpl = {'HeatTemplateFormatVersion': '2012-12-12', tmpl = {'HeatTemplateFormatVersion': '2012-12-12',
'Resources': { 'Resources': {