diff --git a/heatclient/osc/v1/stack.py b/heatclient/osc/v1/stack.py index 1d3ec7c7..92010a8c 100644 --- a/heatclient/osc/v1/stack.py +++ b/heatclient/osc/v1/stack.py @@ -251,11 +251,13 @@ class UpdateStack(command.ShowOne): ) parser.add_argument( '--rollback', metavar='', + default='keep', + choices=['enabled', 'disabled', 'keep'], help=_('Set rollback on update failure. ' 'Value "enabled" sets rollback to enabled. ' 'Value "disabled" sets rollback to disabled. ' 'Value "keep" uses the value of existing stack to be ' - 'updated (default)') + 'updated.') ) parser.add_argument( '--dry-run', action="store_true", @@ -370,13 +372,8 @@ class UpdateStack(command.ShowOne): if parsed_args.clear_parameter: fields['clear_parameters'] = list(parsed_args.clear_parameter) - if parsed_args.rollback: - rollback = parsed_args.rollback.strip().lower() - if rollback not in ('enabled', 'disabled', 'keep'): - msg = _('--rollback invalid value: %s') % parsed_args.rollback - raise exc.CommandError(msg) - if rollback != 'keep': - fields['disable_rollback'] = rollback == 'disabled' + if parsed_args.rollback != 'keep': + fields['disable_rollback'] = (parsed_args.rollback == 'disabled') if parsed_args.dry_run: if parsed_args.show_nested: diff --git a/heatclient/tests/unit/osc/v1/test_stack.py b/heatclient/tests/unit/osc/v1/test_stack.py index fd641964..7bf76a02 100644 --- a/heatclient/tests/unit/osc/v1/test_stack.py +++ b/heatclient/tests/unit/osc/v1/test_stack.py @@ -267,16 +267,6 @@ class TestStackUpdate(TestStack): self.assertNotIn('disable_rollback', self.defaults) self.stack_client.update.assert_called_with(**self.defaults) - def test_stack_update_rollback_invalid(self): - arglist = ['my_stack', '-t', self.template_path, '--rollback', 'foo'] - kwargs = copy.deepcopy(self.defaults) - kwargs['disable_rollback'] = False - parsed_args = self.check_parser(self.cmd, arglist, []) - - ex = self.assertRaises(exc.CommandError, self.cmd.take_action, - parsed_args) - self.assertEqual("--rollback invalid value: foo", str(ex)) - def test_stack_update_parameters(self): template_path = ('/'.join(self.template_path.split('/')[:-1]) + '/parameters.yaml')