Make sure snapshot belongs to stack for actions

Check the snapshot belongs to stack when deleting and showing
stack's snapshot, and restoring from snapshot.

Change-Id: I8ce170b40b05ae17669524d75f80e06e39986673
Closes-Bug: #1437602
This commit is contained in:
huangtianhua 2015-03-30 11:33:43 +08:00
parent dd92a7b2bb
commit af6b0db444
9 changed files with 100 additions and 18 deletions

View File

@ -59,6 +59,7 @@ class FaultWrapper(wsgi.Middleware):
'ResourceActionNotSupported': webob.exc.HTTPBadRequest, 'ResourceActionNotSupported': webob.exc.HTTPBadRequest,
'ResourceNotFound': webob.exc.HTTPNotFound, 'ResourceNotFound': webob.exc.HTTPNotFound,
'ResourceTypeNotFound': webob.exc.HTTPNotFound, 'ResourceTypeNotFound': webob.exc.HTTPNotFound,
'SnapshotNotFound': webob.exc.HTTPNotFound,
'ResourceNotAvailable': webob.exc.HTTPNotFound, 'ResourceNotAvailable': webob.exc.HTTPNotFound,
'PhysicalResourceNotFound': webob.exc.HTTPNotFound, 'PhysicalResourceNotFound': webob.exc.HTTPNotFound,
'InvalidTenant': webob.exc.HTTPForbidden, 'InvalidTenant': webob.exc.HTTPForbidden,

View File

@ -315,6 +315,11 @@ class ResourceNotFound(HeatException):
"in Stack %(stack_name)s.") "in Stack %(stack_name)s.")
class SnapshotNotFound(HeatException):
msg_fmt = _("The Snapshot (%(snapshot)s) for Stack (%(stack)s) "
"could not be found.")
class ResourceTypeNotFound(HeatException): class ResourceTypeNotFound(HeatException):
msg_fmt = _("The Resource Type (%(type_name)s) could not be found.") msg_fmt = _("The Resource Type (%(type_name)s) could not be found.")

View File

@ -317,6 +317,10 @@ def snapshot_get(context, snapshot_id):
return IMPL.snapshot_get(context, snapshot_id) return IMPL.snapshot_get(context, snapshot_id)
def snapshot_get_by_stack(context, snapshot_id, stack):
return IMPL.snapshot_get_by_stack(context, snapshot_id, stack)
def snapshot_update(context, snapshot_id, values): def snapshot_update(context, snapshot_id, values):
return IMPL.snapshot_update(context, snapshot_id, values) return IMPL.snapshot_update(context, snapshot_id, values)

View File

@ -870,6 +870,15 @@ def snapshot_get(context, snapshot_id):
return result return result
def snapshot_get_by_stack(context, snapshot_id, stack):
snapshot = snapshot_get(context, snapshot_id)
if snapshot.stack_id != stack.id:
raise exception.SnapshotNotFound(snapshot=snapshot_id,
stack=stack.name)
return snapshot
def snapshot_update(context, snapshot_id, values): def snapshot_update(context, snapshot_id, values):
snapshot = snapshot_get(context, snapshot_id) snapshot = snapshot_get(context, snapshot_id)
snapshot.update(values) snapshot.update(values)

View File

@ -1254,7 +1254,9 @@ class EngineService(service.Service):
@context.request_context @context.request_context
def show_snapshot(self, cnxt, stack_identity, snapshot_id): def show_snapshot(self, cnxt, stack_identity, snapshot_id):
snapshot = snapshot_object.Snapshot.get_by_id(cnxt, snapshot_id) s = self._get_stack(cnxt, stack_identity)
snapshot = snapshot_object.Snapshot.get_snapshot_by_stack(
cnxt, snapshot_id, s)
return api.format_snapshot(snapshot) return api.format_snapshot(snapshot)
@context.request_context @context.request_context
@ -1265,7 +1267,8 @@ class EngineService(service.Service):
s = self._get_stack(cnxt, stack_identity) s = self._get_stack(cnxt, stack_identity)
stack = parser.Stack.load(cnxt, stack=s) stack = parser.Stack.load(cnxt, stack=s)
snapshot = snapshot_object.Snapshot.get_by_id(cnxt, snapshot_id) snapshot = snapshot_object.Snapshot.get_snapshot_by_stack(
cnxt, snapshot_id, s)
self.thread_group_mgr.start( self.thread_group_mgr.start(
stack.id, _delete_snapshot, stack, snapshot) stack.id, _delete_snapshot, stack, snapshot)
@ -1288,9 +1291,9 @@ class EngineService(service.Service):
stack.restore(snapshot) stack.restore(snapshot)
s = self._get_stack(cnxt, stack_identity) s = self._get_stack(cnxt, stack_identity)
snapshot = snapshot_object.Snapshot.get_by_id(cnxt, snapshot_id)
stack = parser.Stack.load(cnxt, stack=s) stack = parser.Stack.load(cnxt, stack=s)
snapshot = snapshot_object.Snapshot.get_snapshot_by_stack(
cnxt, snapshot_id, s)
self.thread_group_mgr.start_with_lock(cnxt, stack, self.engine_id, self.thread_group_mgr.start_with_lock(cnxt, stack, self.engine_id,
_stack_restore, stack, snapshot) _stack_restore, stack, snapshot)

View File

@ -1262,10 +1262,6 @@ class Stack(collections.Mapping):
''' '''
Restore the given snapshot, invoking handle_restore on all resources. Restore the given snapshot, invoking handle_restore on all resources.
''' '''
if snapshot.stack_id != self.id:
self.state_set(self.RESTORE, self.FAILED,
"Can't restore snapshot from other stack")
return
self.updated_time = datetime.datetime.utcnow() self.updated_time = datetime.datetime.utcnow()
template = tmpl.Template(snapshot.data['template'], env=self.env) template = tmpl.Template(snapshot.data['template'], env=self.env)

View File

@ -54,9 +54,10 @@ class Snapshot(base.VersionedObject,
context, cls(), db_api.snapshot_create(context, values)) context, cls(), db_api.snapshot_create(context, values))
@classmethod @classmethod
def get_by_id(cls, context, snapshot_id): def get_snapshot_by_stack(cls, context, snapshot_id, stack):
return cls._from_db_object( return cls._from_db_object(
context, cls(), db_api.snapshot_get(context, snapshot_id)) context, cls(), db_api.snapshot_get_by_stack(
context, snapshot_id, stack))
@classmethod @classmethod
def update(cls, context, snapshot_id, values): def update(cls, context, snapshot_id, values):

View File

@ -1052,6 +1052,20 @@ class SqlAlchemyTest(common.HeatTestCase):
self.assertEqual(values['status'], snapshot.status) self.assertEqual(values['status'], snapshot.status)
self.assertIsNotNone(snapshot.created_at) self.assertIsNotNone(snapshot.created_at)
def test_snapshot_get_by_another_stack(self):
template = create_raw_template(self.ctx)
user_creds = create_user_creds(self.ctx)
stack = create_stack(self.ctx, template, user_creds)
stack1 = create_stack(self.ctx, template, user_creds)
values = {'tenant': self.ctx.tenant_id, 'status': 'IN_PROGRESS',
'stack_id': stack.id}
snapshot = db_api.snapshot_create(self.ctx, values)
self.assertIsNotNone(snapshot)
snapshot_id = snapshot.id
self.assertRaises(exception.SnapshotNotFound,
db_api.snapshot_get_by_stack,
self.ctx, snapshot_id, stack1)
def test_snapshot_get_not_found_invalid_tenant(self): def test_snapshot_get_not_found_invalid_tenant(self):
template = create_raw_template(self.ctx) template = create_raw_template(self.ctx)
user_creds = create_user_creds(self.ctx) user_creds = create_user_creds(self.ctx)

View File

@ -4380,18 +4380,40 @@ class SnapshotServiceTest(common.HeatTestCase):
sid = stack.store() sid = stack.store()
s = stack_object.Stack.get_by_id(self.ctx, sid) s = stack_object.Stack.get_by_id(self.ctx, sid)
stack.state_set(stack.CREATE, stack.COMPLETE, 'mock completion')
if stub: if stub:
self.m.StubOutWithMock(parser.Stack, 'load') self.m.StubOutWithMock(parser.Stack, 'load')
stack.state_set(stack.CREATE, stack.COMPLETE, 'mock completion') parser.Stack.load(self.ctx,
parser.Stack.load(self.ctx, stack=s).MultipleTimes().AndReturn(stack) stack=s).MultipleTimes().AndReturn(stack)
return stack return stack
def test_show_snapshot_not_found(self): def test_show_snapshot_not_found(self):
stack1 = self._create_stack(stub=False)
snapshot_id = str(uuid.uuid4()) snapshot_id = str(uuid.uuid4())
ex = self.assertRaises(dispatcher.ExpectedException, ex = self.assertRaises(dispatcher.ExpectedException,
self.engine.show_snapshot, self.engine.show_snapshot,
self.ctx, None, snapshot_id) self.ctx, stack1.identifier(),
snapshot_id)
expected = 'Snapshot with id %s not found' % snapshot_id
self.assertEqual(exception.NotFound, ex.exc_info[0]) self.assertEqual(exception.NotFound, ex.exc_info[0])
self.assertIn(expected, six.text_type(ex.exc_info[1]))
def test_show_snapshot_not_belong_to_stack(self):
stack1 = self._create_stack(stub=False)
snapshot1 = self.engine.stack_snapshot(
self.ctx, stack1.identifier(), 'snap1')
self.engine.thread_group_mgr.groups[stack1.id].wait()
snapshot_id = snapshot1['id']
stack2 = self._create_stack(stub=False)
ex = self.assertRaises(dispatcher.ExpectedException,
self.engine.show_snapshot,
self.ctx, stack2.identifier(),
snapshot_id)
expected = ('The Snapshot (%(snapshot)s) for Stack (%(stack)s) '
'could not be found') % {'snapshot': snapshot_id,
'stack': stack2.name}
self.assertEqual(exception.SnapshotNotFound, ex.exc_info[0])
self.assertIn(expected, six.text_type(ex.exc_info[1]))
def test_create_snapshot(self): def test_create_snapshot(self):
stack = self._create_stack() stack = self._create_stack()
@ -4433,6 +4455,28 @@ class SnapshotServiceTest(common.HeatTestCase):
self.ctx, stack.identifier(), snapshot_id) self.ctx, stack.identifier(), snapshot_id)
self.assertEqual(exception.NotFound, ex.exc_info[0]) self.assertEqual(exception.NotFound, ex.exc_info[0])
def test_delete_snapshot_not_belong_to_stack(self):
stack1 = self._create_stack()
self.m.ReplayAll()
snapshot1 = self.engine.stack_snapshot(
self.ctx, stack1.identifier(), 'snap1')
self.engine.thread_group_mgr.groups[stack1.id].wait()
snapshot_id = snapshot1['id']
self.m.UnsetStubs()
stack2 = self._create_stack()
self.m.ReplayAll()
ex = self.assertRaises(dispatcher.ExpectedException,
self.engine.delete_snapshot,
self.ctx,
stack2.identifier(),
snapshot_id)
expected = ('The Snapshot (%(snapshot)s) for Stack (%(stack)s) '
'could not be found') % {'snapshot': snapshot_id,
'stack': stack2.name}
self.assertEqual(exception.SnapshotNotFound, ex.exc_info[0])
self.assertIn(expected, six.text_type(ex.exc_info[1]))
def test_delete_snapshot(self): def test_delete_snapshot(self):
stack = self._create_stack() stack = self._create_stack()
self.m.ReplayAll() self.m.ReplayAll()
@ -4488,8 +4532,13 @@ class SnapshotServiceTest(common.HeatTestCase):
self.m.UnsetStubs() self.m.UnsetStubs()
stack2 = self._create_stack() stack2 = self._create_stack()
self.m.ReplayAll() self.m.ReplayAll()
self.engine.stack_restore(self.ctx, stack2.identifier(), snapshot_id) ex = self.assertRaises(dispatcher.ExpectedException,
self.engine.thread_group_mgr.groups[stack2.id].wait() self.engine.stack_restore,
self.assertEqual((stack2.RESTORE, stack2.FAILED), stack2.state) self.ctx,
self.assertEqual("Can't restore snapshot from other stack", stack2.identifier(),
stack2.status_reason) snapshot_id)
expected = ('The Snapshot (%(snapshot)s) for Stack (%(stack)s) '
'could not be found') % {'snapshot': snapshot_id,
'stack': stack2.name}
self.assertEqual(exception.SnapshotNotFound, ex.exc_info[0])
self.assertIn(expected, six.text_type(ex.exc_info[1]))