Make sure snapshot belongs to stack for actions

Check the snapshot belongs to stack when deleting and showing
stack's snapshot, and restoring from snapshot.

Change-Id: I8ce170b40b05ae17669524d75f80e06e39986673
Closes-Bug: #1437602
changes/66/168766/7
huangtianhua 8 years ago
parent dd92a7b2bb
commit af6b0db444
  1. 1
      heat/api/middleware/fault.py
  2. 5
      heat/common/exception.py
  3. 4
      heat/db/api.py
  4. 9
      heat/db/sqlalchemy/api.py
  5. 11
      heat/engine/service.py
  6. 4
      heat/engine/stack.py
  7. 5
      heat/objects/snapshot.py
  8. 14
      heat/tests/db/test_sqlalchemy_api.py
  9. 65
      heat/tests/test_engine_service.py

@ -59,6 +59,7 @@ class FaultWrapper(wsgi.Middleware):
'ResourceActionNotSupported': webob.exc.HTTPBadRequest,
'ResourceNotFound': webob.exc.HTTPNotFound,
'ResourceTypeNotFound': webob.exc.HTTPNotFound,
'SnapshotNotFound': webob.exc.HTTPNotFound,
'ResourceNotAvailable': webob.exc.HTTPNotFound,
'PhysicalResourceNotFound': webob.exc.HTTPNotFound,
'InvalidTenant': webob.exc.HTTPForbidden,

@ -315,6 +315,11 @@ class ResourceNotFound(HeatException):
"in Stack %(stack_name)s.")
class SnapshotNotFound(HeatException):
msg_fmt = _("The Snapshot (%(snapshot)s) for Stack (%(stack)s) "
"could not be found.")
class ResourceTypeNotFound(HeatException):
msg_fmt = _("The Resource Type (%(type_name)s) could not be found.")

@ -317,6 +317,10 @@ def snapshot_get(context, snapshot_id):
return IMPL.snapshot_get(context, snapshot_id)
def snapshot_get_by_stack(context, snapshot_id, stack):
return IMPL.snapshot_get_by_stack(context, snapshot_id, stack)
def snapshot_update(context, snapshot_id, values):
return IMPL.snapshot_update(context, snapshot_id, values)

@ -870,6 +870,15 @@ def snapshot_get(context, snapshot_id):
return result
def snapshot_get_by_stack(context, snapshot_id, stack):
snapshot = snapshot_get(context, snapshot_id)
if snapshot.stack_id != stack.id:
raise exception.SnapshotNotFound(snapshot=snapshot_id,
stack=stack.name)
return snapshot
def snapshot_update(context, snapshot_id, values):
snapshot = snapshot_get(context, snapshot_id)
snapshot.update(values)

@ -1254,7 +1254,9 @@ class EngineService(service.Service):
@context.request_context
def show_snapshot(self, cnxt, stack_identity, snapshot_id):
snapshot = snapshot_object.Snapshot.get_by_id(cnxt, snapshot_id)
s = self._get_stack(cnxt, stack_identity)
snapshot = snapshot_object.Snapshot.get_snapshot_by_stack(
cnxt, snapshot_id, s)
return api.format_snapshot(snapshot)
@context.request_context
@ -1265,7 +1267,8 @@ class EngineService(service.Service):
s = self._get_stack(cnxt, stack_identity)
stack = parser.Stack.load(cnxt, stack=s)
snapshot = snapshot_object.Snapshot.get_by_id(cnxt, snapshot_id)
snapshot = snapshot_object.Snapshot.get_snapshot_by_stack(
cnxt, snapshot_id, s)
self.thread_group_mgr.start(
stack.id, _delete_snapshot, stack, snapshot)
@ -1288,9 +1291,9 @@ class EngineService(service.Service):
stack.restore(snapshot)
s = self._get_stack(cnxt, stack_identity)
snapshot = snapshot_object.Snapshot.get_by_id(cnxt, snapshot_id)
stack = parser.Stack.load(cnxt, stack=s)
snapshot = snapshot_object.Snapshot.get_snapshot_by_stack(
cnxt, snapshot_id, s)
self.thread_group_mgr.start_with_lock(cnxt, stack, self.engine_id,
_stack_restore, stack, snapshot)

@ -1262,10 +1262,6 @@ class Stack(collections.Mapping):
'''
Restore the given snapshot, invoking handle_restore on all resources.
'''
if snapshot.stack_id != self.id:
self.state_set(self.RESTORE, self.FAILED,
"Can't restore snapshot from other stack")
return
self.updated_time = datetime.datetime.utcnow()
template = tmpl.Template(snapshot.data['template'], env=self.env)

@ -54,9 +54,10 @@ class Snapshot(base.VersionedObject,
context, cls(), db_api.snapshot_create(context, values))
@classmethod
def get_by_id(cls, context, snapshot_id):
def get_snapshot_by_stack(cls, context, snapshot_id, stack):
return cls._from_db_object(
context, cls(), db_api.snapshot_get(context, snapshot_id))
context, cls(), db_api.snapshot_get_by_stack(
context, snapshot_id, stack))
@classmethod
def update(cls, context, snapshot_id, values):

@ -1052,6 +1052,20 @@ class SqlAlchemyTest(common.HeatTestCase):
self.assertEqual(values['status'], snapshot.status)
self.assertIsNotNone(snapshot.created_at)
def test_snapshot_get_by_another_stack(self):
template = create_raw_template(self.ctx)
user_creds = create_user_creds(self.ctx)
stack = create_stack(self.ctx, template, user_creds)
stack1 = create_stack(self.ctx, template, user_creds)
values = {'tenant': self.ctx.tenant_id, 'status': 'IN_PROGRESS',
'stack_id': stack.id}
snapshot = db_api.snapshot_create(self.ctx, values)
self.assertIsNotNone(snapshot)
snapshot_id = snapshot.id
self.assertRaises(exception.SnapshotNotFound,
db_api.snapshot_get_by_stack,
self.ctx, snapshot_id, stack1)
def test_snapshot_get_not_found_invalid_tenant(self):
template = create_raw_template(self.ctx)
user_creds = create_user_creds(self.ctx)

@ -4380,18 +4380,40 @@ class SnapshotServiceTest(common.HeatTestCase):
sid = stack.store()
s = stack_object.Stack.get_by_id(self.ctx, sid)
stack.state_set(stack.CREATE, stack.COMPLETE, 'mock completion')
if stub:
self.m.StubOutWithMock(parser.Stack, 'load')
stack.state_set(stack.CREATE, stack.COMPLETE, 'mock completion')
parser.Stack.load(self.ctx, stack=s).MultipleTimes().AndReturn(stack)
parser.Stack.load(self.ctx,
stack=s).MultipleTimes().AndReturn(stack)
return stack
def test_show_snapshot_not_found(self):
stack1 = self._create_stack(stub=False)
snapshot_id = str(uuid.uuid4())
ex = self.assertRaises(dispatcher.ExpectedException,
self.engine.show_snapshot,
self.ctx, None, snapshot_id)
self.ctx, stack1.identifier(),
snapshot_id)
expected = 'Snapshot with id %s not found' % snapshot_id
self.assertEqual(exception.NotFound, ex.exc_info[0])
self.assertIn(expected, six.text_type(ex.exc_info[1]))
def test_show_snapshot_not_belong_to_stack(self):
stack1 = self._create_stack(stub=False)
snapshot1 = self.engine.stack_snapshot(
self.ctx, stack1.identifier(), 'snap1')
self.engine.thread_group_mgr.groups[stack1.id].wait()
snapshot_id = snapshot1['id']
stack2 = self._create_stack(stub=False)
ex = self.assertRaises(dispatcher.ExpectedException,
self.engine.show_snapshot,
self.ctx, stack2.identifier(),
snapshot_id)
expected = ('The Snapshot (%(snapshot)s) for Stack (%(stack)s) '
'could not be found') % {'snapshot': snapshot_id,
'stack': stack2.name}
self.assertEqual(exception.SnapshotNotFound, ex.exc_info[0])
self.assertIn(expected, six.text_type(ex.exc_info[1]))
def test_create_snapshot(self):
stack = self._create_stack()
@ -4433,6 +4455,28 @@ class SnapshotServiceTest(common.HeatTestCase):
self.ctx, stack.identifier(), snapshot_id)
self.assertEqual(exception.NotFound, ex.exc_info[0])
def test_delete_snapshot_not_belong_to_stack(self):
stack1 = self._create_stack()
self.m.ReplayAll()
snapshot1 = self.engine.stack_snapshot(
self.ctx, stack1.identifier(), 'snap1')
self.engine.thread_group_mgr.groups[stack1.id].wait()
snapshot_id = snapshot1['id']
self.m.UnsetStubs()
stack2 = self._create_stack()
self.m.ReplayAll()
ex = self.assertRaises(dispatcher.ExpectedException,
self.engine.delete_snapshot,
self.ctx,
stack2.identifier(),
snapshot_id)
expected = ('The Snapshot (%(snapshot)s) for Stack (%(stack)s) '
'could not be found') % {'snapshot': snapshot_id,
'stack': stack2.name}
self.assertEqual(exception.SnapshotNotFound, ex.exc_info[0])
self.assertIn(expected, six.text_type(ex.exc_info[1]))
def test_delete_snapshot(self):
stack = self._create_stack()
self.m.ReplayAll()
@ -4488,8 +4532,13 @@ class SnapshotServiceTest(common.HeatTestCase):
self.m.UnsetStubs()
stack2 = self._create_stack()
self.m.ReplayAll()
self.engine.stack_restore(self.ctx, stack2.identifier(), snapshot_id)
self.engine.thread_group_mgr.groups[stack2.id].wait()
self.assertEqual((stack2.RESTORE, stack2.FAILED), stack2.state)
self.assertEqual("Can't restore snapshot from other stack",
stack2.status_reason)
ex = self.assertRaises(dispatcher.ExpectedException,
self.engine.stack_restore,
self.ctx,
stack2.identifier(),
snapshot_id)
expected = ('The Snapshot (%(snapshot)s) for Stack (%(stack)s) '
'could not be found') % {'snapshot': snapshot_id,
'stack': stack2.name}
self.assertEqual(exception.SnapshotNotFound, ex.exc_info[0])
self.assertIn(expected, six.text_type(ex.exc_info[1]))

Loading…
Cancel
Save