diff --git a/heat/api/openstack/v1/stacks.py b/heat/api/openstack/v1/stacks.py index ad29a62f48..9aa51a29db 100644 --- a/heat/api/openstack/v1/stacks.py +++ b/heat/api/openstack/v1/stacks.py @@ -21,6 +21,7 @@ from heat.api.openstack.v1 import util from heat.api.openstack.v1.views import stacks_view from heat.common import environment_format from heat.common import identifier +from heat.common import param_utils from heat.common import serializers from heat.common import template_format from heat.common import urlfetch @@ -167,8 +168,8 @@ class StackController(object): filter_params = util.get_allowed_params(req.params, filter_whitelist) if engine_api.PARAM_SHOW_DELETED in params: - show_del = util.extract_bool(params[engine_api.PARAM_SHOW_DELETED]) - params[engine_api.PARAM_SHOW_DELETED] = show_del + params[engine_api.PARAM_SHOW_DELETED] = param_utils.extract_bool( + params[engine_api.PARAM_SHOW_DELETED]) if not filter_params: filter_params = None @@ -223,14 +224,13 @@ class StackController(object): """ data = InstantiationData(body) - args = util.extract_args(data.args()) result = self.rpc_client.preview_stack(req.context, data.stack_name(), data.template(), data.environment(), data.files(), - args) + data.args()) formatted_stack = stacks_view.format_stack(req, result) return {'stack': formatted_stack} @@ -241,14 +241,13 @@ class StackController(object): Create a new stack """ data = InstantiationData(body) - args = util.extract_args(data.args()) result = self.rpc_client.create_stack(req.context, data.stack_name(), data.template(), data.environment(), data.files(), - args) + data.args()) formatted_stack = stacks_view.format_stack( req, @@ -310,14 +309,13 @@ class StackController(object): Update an existing stack with a new template and/or parameters """ data = InstantiationData(body) - args = util.extract_args(data.args()) self.rpc_client.update_stack(req.context, identity, data.template(), data.environment(), data.files(), - args) + data.args()) raise exc.HTTPAccepted() diff --git a/heat/api/openstack/v1/util.py b/heat/api/openstack/v1/util.py index 490ed80657..a23bb32efe 100644 --- a/heat/api/openstack/v1/util.py +++ b/heat/api/openstack/v1/util.py @@ -16,13 +16,7 @@ from functools import wraps from webob import exc from heat.common import identifier -from heat.common import template_format from heat.openstack.common.gettextutils import _ -from heat.openstack.common import log as logging -from heat.openstack.common import strutils -from heat.rpc import api - -logger = logging.getLogger(__name__) def policy_enforce(handler): @@ -106,49 +100,3 @@ def get_allowed_params(params, whitelist): allowed_params[key] = value return allowed_params - - -def extract_args(params): - ''' - Extract any arguments passed as parameters through the API and return them - as a dictionary. This allows us to filter the passed args and do type - conversion where appropriate - ''' - kwargs = {} - timeout_mins = params.get(api.PARAM_TIMEOUT) - if timeout_mins not in ('0', 0, None): - try: - timeout = int(timeout_mins) - except (ValueError, TypeError): - logger.exception(_('Timeout conversion failed')) - else: - if timeout > 0: - kwargs[api.PARAM_TIMEOUT] = timeout - else: - raise ValueError(_('Invalid timeout value %s') % timeout) - - if api.PARAM_DISABLE_ROLLBACK in params: - disable_rollback = extract_bool(params[api.PARAM_DISABLE_ROLLBACK]) - kwargs[api.PARAM_DISABLE_ROLLBACK] = disable_rollback - - adopt_data = params.get(api.PARAM_ADOPT_STACK_DATA) - if adopt_data: - adopt_data = template_format.simple_parse(adopt_data) - if not isinstance(adopt_data, dict): - raise ValueError( - _('Unexpected adopt data "%s". Adopt data must be a dict.') - % adopt_data) - kwargs[api.PARAM_ADOPT_STACK_DATA] = adopt_data - - return kwargs - - -def extract_bool(subject): - ''' - Convert any true/false string to its corresponding boolean value, - regardless of case. - ''' - if str(subject).lower() not in ('true', 'false'): - raise ValueError(_('Unrecognized value "%(value)s, acceptable values ' - 'are: true, false.') % {'value': subject}) - return strutils.bool_from_string(subject, strict=True) diff --git a/heat/common/param_utils.py b/heat/common/param_utils.py new file mode 100644 index 0000000000..36e0bc0b08 --- /dev/null +++ b/heat/common/param_utils.py @@ -0,0 +1,26 @@ +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from heat.openstack.common.gettextutils import _ +from heat.openstack.common import strutils + + +def extract_bool(subject): + ''' + Convert any true/false string to its corresponding boolean value, + regardless of case. + ''' + if str(subject).lower() not in ('true', 'false'): + raise ValueError(_('Unrecognized value "%(value)s, acceptable values ' + 'are: true, false.') % {'value': subject}) + return strutils.bool_from_string(subject, strict=True) diff --git a/heat/engine/api.py b/heat/engine/api.py index 78b4a9e23f..2b983cb250 100644 --- a/heat/engine/api.py +++ b/heat/engine/api.py @@ -11,6 +11,8 @@ # License for the specific language governing permissions and limitations # under the License. +from heat.common import param_utils +from heat.common import template_format from heat.engine import constraints as constr from heat.openstack.common.gettextutils import _ from heat.openstack.common import log as logging @@ -20,6 +22,46 @@ from heat.rpc import api logger = logging.getLogger(__name__) +def extract_args(params): + ''' + Extract any arguments passed as parameters through the API and return them + as a dictionary. This allows us to filter the passed args and do type + conversion where appropriate + ''' + kwargs = {} + timeout_mins = params.get(api.PARAM_TIMEOUT) + if timeout_mins not in ('0', 0, None): + try: + timeout = int(timeout_mins) + except (ValueError, TypeError): + logger.exception(_('Timeout conversion failed')) + else: + if timeout > 0: + kwargs[api.PARAM_TIMEOUT] = timeout + else: + raise ValueError(_('Invalid timeout value %s') % timeout) + + if api.PARAM_DISABLE_ROLLBACK in params: + disable_rollback = param_utils.extract_bool( + params[api.PARAM_DISABLE_ROLLBACK]) + kwargs[api.PARAM_DISABLE_ROLLBACK] = disable_rollback + + if api.PARAM_SHOW_DELETED in params: + params[api.PARAM_SHOW_DELETED] = param_utils.extract_bool( + params[api.PARAM_SHOW_DELETED]) + + adopt_data = params.get(api.PARAM_ADOPT_STACK_DATA) + if adopt_data: + adopt_data = template_format.simple_parse(adopt_data) + if not isinstance(adopt_data, dict): + raise ValueError( + _('Unexpected adopt data "%s". Adopt data must be a dict.') + % adopt_data) + kwargs[api.PARAM_ADOPT_STACK_DATA] = adopt_data + + return kwargs + + def format_stack_outputs(stack, outputs): ''' Return a representation of the given output template for the given stack diff --git a/heat/engine/service.py b/heat/engine/service.py index 70a91155d7..2ad4b71f38 100644 --- a/heat/engine/service.py +++ b/heat/engine/service.py @@ -495,8 +495,9 @@ class EngineService(service.Service): tmpl = parser.Template(template, files=files) self._validate_new_stack(cnxt, stack_name, tmpl) + common_params = api.extract_args(args) env = environment.Environment(params) - stack = parser.Stack(cnxt, stack_name, tmpl, env, **args) + stack = parser.Stack(cnxt, stack_name, tmpl, env, **common_params) self._validate_deferred_auth_context(cnxt, stack) stack.validate() @@ -538,8 +539,9 @@ class EngineService(service.Service): tmpl = parser.Template(template, files=files) self._validate_new_stack(cnxt, stack_name, tmpl) + common_params = api.extract_args(args) env = environment.Environment(params) - stack = parser.Stack(cnxt, stack_name, tmpl, env, **args) + stack = parser.Stack(cnxt, stack_name, tmpl, env, **common_params) self._validate_deferred_auth_context(cnxt, stack) @@ -590,10 +592,12 @@ class EngineService(service.Service): raise exception.RequestLimitExceeded( message=exception.StackResourceLimitExceeded.msg_fmt) stack_name = current_stack.name - args.setdefault(rpc_api.PARAM_TIMEOUT, current_stack.timeout_mins) + common_params = api.extract_args(args) + common_params.setdefault(rpc_api.PARAM_TIMEOUT, + current_stack.timeout_mins) env = environment.Environment(params) updated_stack = parser.Stack(cnxt, stack_name, tmpl, - env, **args) + env, **common_params) updated_stack.parameters.set_stack_id(current_stack.identifier()) self._validate_deferred_auth_context(cnxt, updated_stack) diff --git a/heat/tests/test_api_openstack_v1_util.py b/heat/tests/test_api_openstack_v1_util.py index f2aea29297..fa143debea 100644 --- a/heat/tests/test_api_openstack_v1_util.py +++ b/heat/tests/test_api_openstack_v1_util.py @@ -11,11 +11,9 @@ # License for the specific language governing permissions and limitations # under the License. -import json import mock from webob import exc -import six from heat.api.openstack.v1 import util from heat.common import context @@ -120,90 +118,3 @@ class TestPolicyEnforce(HeatTestCase): self.assertRaises(exc.HTTPForbidden, self.controller.an_action, self.req, tenant_id='foo') - - -class TestExtractArgs(HeatTestCase): - def test_timeout_extract(self): - p = {'timeout_mins': '5'} - args = util.extract_args(p) - self.assertEqual(5, args['timeout_mins']) - - def test_timeout_extract_zero(self): - p = {'timeout_mins': '0'} - args = util.extract_args(p) - self.assertNotIn('timeout_mins', args) - - def test_timeout_extract_garbage(self): - p = {'timeout_mins': 'wibble'} - args = util.extract_args(p) - self.assertNotIn('timeout_mins', args) - - def test_timeout_extract_none(self): - p = {'timeout_mins': None} - args = util.extract_args(p) - self.assertNotIn('timeout_mins', args) - - def test_timeout_extract_negative(self): - p = {'timeout_mins': '-100'} - error = self.assertRaises(ValueError, util.extract_args, p) - self.assertIn('Invalid timeout value', six.text_type(error)) - - def test_timeout_extract_not_present(self): - args = util.extract_args({}) - self.assertNotIn('timeout_mins', args) - - def test_adopt_stack_data_extract_present(self): - p = {'adopt_stack_data': json.dumps({'Resources': {}})} - args = util.extract_args(p) - self.assertTrue(args.get('adopt_stack_data')) - - def test_invalid_adopt_stack_data(self): - p = {'adopt_stack_data': json.dumps("foo")} - error = self.assertRaises(ValueError, util.extract_args, p) - self.assertEqual( - 'Unexpected adopt data "foo". Adopt data must be a dict.', - six.text_type(error)) - - def test_adopt_stack_data_extract_not_present(self): - args = util.extract_args({}) - self.assertNotIn('adopt_stack_data', args) - - def test_disable_rollback_extract_true(self): - args = util.extract_args({'disable_rollback': True}) - self.assertIn('disable_rollback', args) - self.assertTrue(args.get('disable_rollback')) - - args = util.extract_args({'disable_rollback': 'True'}) - self.assertIn('disable_rollback', args) - self.assertTrue(args.get('disable_rollback')) - - args = util.extract_args({'disable_rollback': 'true'}) - self.assertIn('disable_rollback', args) - self.assertTrue(args.get('disable_rollback')) - - def test_disable_rollback_extract_false(self): - args = util.extract_args({'disable_rollback': False}) - self.assertIn('disable_rollback', args) - self.assertFalse(args.get('disable_rollback')) - - args = util.extract_args({'disable_rollback': 'False'}) - self.assertIn('disable_rollback', args) - self.assertFalse(args.get('disable_rollback')) - - args = util.extract_args({'disable_rollback': 'false'}) - self.assertIn('disable_rollback', args) - self.assertFalse(args.get('disable_rollback')) - - def test_disable_rollback_extract_bad(self): - self.assertRaises(ValueError, util.extract_args, - {'disable_rollback': 'bad'}) - - -class TestExtractBool(HeatTestCase): - def test_extract_bool(self): - for value in ('True', 'true', 'TRUE', True): - self.assertTrue(util.extract_bool(value)) - for value in ('False', 'false', 'FALSE', False): - self.assertFalse(util.extract_bool(value)) - for value in ('foo', 't', 'f', 'yes', 'no', 'y', 'n', '1', '0', None): - self.assertRaises(ValueError, util.extract_bool, value) diff --git a/heat/tests/test_common_param_utils.py b/heat/tests/test_common_param_utils.py new file mode 100644 index 0000000000..5c0d7f2a42 --- /dev/null +++ b/heat/tests/test_common_param_utils.py @@ -0,0 +1,25 @@ +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from heat.common import param_utils +from heat.tests.common import HeatTestCase + + +class TestExtractBool(HeatTestCase): + def test_extract_bool(self): + for value in ('True', 'true', 'TRUE', True): + self.assertTrue(param_utils.extract_bool(value)) + for value in ('False', 'false', 'FALSE', False): + self.assertFalse(param_utils.extract_bool(value)) + for value in ('foo', 't', 'f', 'yes', 'no', 'y', 'n', '1', '0', None): + self.assertRaises(ValueError, param_utils.extract_bool, value) diff --git a/heat/tests/test_engine_api_utils.py b/heat/tests/test_engine_api_utils.py index 2238f6375d..eecaedf656 100644 --- a/heat/tests/test_engine_api_utils.py +++ b/heat/tests/test_engine_api_utils.py @@ -14,7 +14,9 @@ from datetime import datetime import uuid +import json import mock +import six from heat.common.identifier import EventIdentifier from heat.common import template_format @@ -790,3 +792,80 @@ class FormatSoftwareConfigDeploymentTest(HeatTestCase): def test_format_software_deployment_none(self): self.assertIsNone(api.format_software_deployment(None)) + + +class TestExtractArgs(HeatTestCase): + def test_timeout_extract(self): + p = {'timeout_mins': '5'} + args = api.extract_args(p) + self.assertEqual(5, args['timeout_mins']) + + def test_timeout_extract_zero(self): + p = {'timeout_mins': '0'} + args = api.extract_args(p) + self.assertNotIn('timeout_mins', args) + + def test_timeout_extract_garbage(self): + p = {'timeout_mins': 'wibble'} + args = api.extract_args(p) + self.assertNotIn('timeout_mins', args) + + def test_timeout_extract_none(self): + p = {'timeout_mins': None} + args = api.extract_args(p) + self.assertNotIn('timeout_mins', args) + + def test_timeout_extract_negative(self): + p = {'timeout_mins': '-100'} + error = self.assertRaises(ValueError, api.extract_args, p) + self.assertIn('Invalid timeout value', six.text_type(error)) + + def test_timeout_extract_not_present(self): + args = api.extract_args({}) + self.assertNotIn('timeout_mins', args) + + def test_adopt_stack_data_extract_present(self): + p = {'adopt_stack_data': json.dumps({'Resources': {}})} + args = api.extract_args(p) + self.assertTrue(args.get('adopt_stack_data')) + + def test_invalid_adopt_stack_data(self): + p = {'adopt_stack_data': json.dumps("foo")} + error = self.assertRaises(ValueError, api.extract_args, p) + self.assertEqual( + 'Unexpected adopt data "foo". Adopt data must be a dict.', + six.text_type(error)) + + def test_adopt_stack_data_extract_not_present(self): + args = api.extract_args({}) + self.assertNotIn('adopt_stack_data', args) + + def test_disable_rollback_extract_true(self): + args = api.extract_args({'disable_rollback': True}) + self.assertIn('disable_rollback', args) + self.assertTrue(args.get('disable_rollback')) + + args = api.extract_args({'disable_rollback': 'True'}) + self.assertIn('disable_rollback', args) + self.assertTrue(args.get('disable_rollback')) + + args = api.extract_args({'disable_rollback': 'true'}) + self.assertIn('disable_rollback', args) + self.assertTrue(args.get('disable_rollback')) + + def test_disable_rollback_extract_false(self): + args = api.extract_args({'disable_rollback': False}) + self.assertIn('disable_rollback', args) + self.assertFalse(args.get('disable_rollback')) + + args = api.extract_args({'disable_rollback': 'False'}) + self.assertIn('disable_rollback', args) + self.assertFalse(args.get('disable_rollback')) + + args = api.extract_args({'disable_rollback': 'false'}) + self.assertIn('disable_rollback', args) + self.assertFalse(args.get('disable_rollback')) + + def test_disable_rollback_extract_bad(self): + self.assertRaises(ValueError, api.extract_args, + {'disable_rollback': 'bad'})