diff --git a/heat/engine/cfn/template.py b/heat/engine/cfn/template.py index b2288449b4..e7f3b65eb4 100644 --- a/heat/engine/cfn/template.py +++ b/heat/engine/cfn/template.py @@ -140,7 +140,7 @@ class CfnTemplateBase(template_common.CommonTemplate): return rsrc_defn.ResourceDefinition(name, **kwargs) - conditions = template_common.Conditions(self.conditions(stack)) + conditions = self.conditions(stack) def defns(): for name, snippet in resources.items(): @@ -155,8 +155,14 @@ class CfnTemplateBase(template_common.CommonTemplate): cond_name = defn.condition_name() if cond_name is not None: - path = [self.RESOURCES, name, self.RES_CONDITION] - if not conditions.is_enabled(cond_name, path): + try: + enabled = conditions.is_enabled(cond_name) + except ValueError as exc: + path = [self.RESOURCES, name, self.RES_CONDITION] + message = six.text_type(exc) + raise exception.StackValidationFailed(path=path, + message=message) + if not enabled: continue yield name, defn diff --git a/heat/engine/conditions.py b/heat/engine/conditions.py new file mode 100644 index 0000000000..e54269cccb --- /dev/null +++ b/heat/engine/conditions.py @@ -0,0 +1,62 @@ +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import collections + +import six + +from heat.common.i18n import _ + +from heat.common import exception +from heat.engine import function + + +class Conditions(object): + def __init__(self, conditions_dict): + assert isinstance(conditions_dict, collections.Mapping) + self._conditions = conditions_dict + self._resolved = {} + + def validate(self): + for name, cond in six.iteritems(self._conditions): + self._check_condition_type(name, cond) + function.validate(cond) + + def _resolve(self, condition_name): + resolved = function.resolve(self._conditions[condition_name]) + self._check_condition_type(condition_name, resolved) + return resolved + + def _check_condition_type(self, condition_name, condition_defn): + if not isinstance(condition_defn, (bool, function.Function)): + msg_data = {'cd': condition_name, 'definition': condition_defn} + message = _('The definition of condition "%(cd)s" is invalid: ' + '%(definition)s') % msg_data + raise exception.StackValidationFailed( + error='Condition validation error', + message=message) + + def is_enabled(self, condition_name): + if condition_name is None: + return True + + if not (isinstance(condition_name, six.string_types) and + condition_name in self._conditions): + raise ValueError(_('Invalid condition "%s"') % condition_name) + + if condition_name not in self._resolved: + self._resolved[condition_name] = self._resolve(condition_name) + return self._resolved[condition_name] + + def __repr__(self): + return 'Conditions(%r)' % self._conditions diff --git a/heat/engine/hot/functions.py b/heat/engine/hot/functions.py index 3216a8836d..a24f4a81dd 100644 --- a/heat/engine/hot/functions.py +++ b/heat/engine/hot/functions.py @@ -950,15 +950,11 @@ class If(function.Macro): '[condition_name, value_if_true, value_if_false]') raise ValueError(msg % self.fn_name) - cd = self.get_condition(cd_name) + cd = self._get_condition(cd_name) return parse_func(value_if_true if cd else value_if_false) - def get_condition(self, cd_name): - conditions = self.template.conditions(self.stack) - if cd_name not in conditions: - raise KeyError(_('Invalid condition name "%s"') % cd_name) - - return conditions[cd_name] + def _get_condition(self, cd_name): + return self.template.conditions(self.stack).is_enabled(cd_name) class Not(function.Function): diff --git a/heat/engine/hot/template.py b/heat/engine/hot/template.py index 5bbe2f4172..7cd5221c20 100644 --- a/heat/engine/hot/template.py +++ b/heat/engine/hot/template.py @@ -228,8 +228,7 @@ class HOTemplate20130523(template_common.CommonTemplate): def resource_definitions(self, stack): resources = self.t.get(self.RESOURCES) or {} - - conditions = template_common.Conditions(self.conditions(stack)) + conditions = self.conditions(stack) def defns(): for name, snippet in six.iteritems(resources): @@ -244,8 +243,14 @@ class HOTemplate20130523(template_common.CommonTemplate): cond_name = defn.condition_name() if cond_name is not None: - path = [self.RESOURCES, name, self.RES_CONDITION] - if not conditions.is_enabled(cond_name, path): + try: + enabled = conditions.is_enabled(cond_name) + except ValueError as exc: + path = [self.RESOURCES, name, self.RES_CONDITION] + message = six.text_type(exc) + raise exception.StackValidationFailed(path=path, + message=message) + if not enabled: continue yield name, defn diff --git a/heat/engine/stack.py b/heat/engine/stack.py index 2fbafd022d..bd408559cc 100644 --- a/heat/engine/stack.py +++ b/heat/engine/stack.py @@ -793,6 +793,8 @@ class Stack(collections.Mapping): DeprecationWarning) self.t.validate_resource_definitions(self) + self.t.conditions(self).validate() + # Load the resources definitions (success of which implies the # definitions are valid) resources = self.resources @@ -808,7 +810,7 @@ class Stack(collections.Mapping): if validate_by_deps: iter_rsc = self.dependencies else: - iter_rsc = six.itervalues(self.resources) + iter_rsc = six.itervalues(resources) for res in iter_rsc: try: diff --git a/heat/engine/template.py b/heat/engine/template.py index d71e6eb2f5..06f563be5a 100644 --- a/heat/engine/template.py +++ b/heat/engine/template.py @@ -23,6 +23,7 @@ from stevedore import extension from heat.common import exception from heat.common.i18n import _ +from heat.engine import conditions from heat.engine import environment from heat.engine import function from heat.engine import output @@ -254,7 +255,7 @@ class Template(collections.Mapping): def conditions(self, stack): """Return a dictionary of resolved conditions.""" - return {} + return conditions.Conditions({}) def outputs(self, stack): warnings.warn("The default implementation of the outputs() method " diff --git a/heat/engine/template_common.py b/heat/engine/template_common.py index 3086353268..f665423542 100644 --- a/heat/engine/template_common.py +++ b/heat/engine/template_common.py @@ -17,6 +17,7 @@ import six from heat.common import exception from heat.common.i18n import _ +from heat.engine import conditions from heat.engine import function from heat.engine import output from heat.engine import template @@ -85,44 +86,29 @@ class CommonTemplate(template.Template): six.string_types, 'string', name, data) - def _resolve_conditions(self, stack): - cd_snippet = self._get_condition_definitions() - result = {} - for cd_key, cd_value in six.iteritems(cd_snippet): - # hasn't been resolved yet - if not isinstance(cd_value, bool): - condition_func = self.parse_condition( - stack, cd_value, '.'.join([self.CONDITIONS, cd_key])) - resolved_cd_value = function.resolve(condition_func) - result[cd_key] = resolved_cd_value - else: - result[cd_key] = cd_value - - return result - def _get_condition_definitions(self): """Return the condition definitions of template.""" return {} def conditions(self, stack): if self._conditions is None: - resolved_cds = self._resolve_conditions(stack) - if resolved_cds: - for cd_key, cd_value in six.iteritems(resolved_cds): - if not isinstance(cd_value, bool): - msg_data = {'cd': cd_key, 'definition': cd_value} - message = _('The definition of condition "%(cd)s" is ' - 'invalid: %(definition)s') % msg_data - raise exception.StackValidationFailed( - error='Condition validation error', - message=message) + raw_defs = self._get_condition_definitions() + if not isinstance(raw_defs, collections.Mapping): + message = _('Condition definitions must be a map. Found a ' + '%s instead') % type(raw_defs).__name__ + raise exception.StackValidationFailed( + error='Conditions validation error', + message=message) - self._conditions = resolved_cds + parsed = {n: self.parse_condition(stack, c, + '.'.join([self.CONDITIONS, n])) + for n, c in raw_defs.items()} + self._conditions = conditions.Conditions(parsed) return self._conditions def outputs(self, stack): - conditions = Conditions(self.conditions(stack)) + conds = self.conditions(stack) outputs = self.t.get(self.OUTPUTS) or {} @@ -148,8 +134,15 @@ class CommonTemplate(template.Template): if hasattr(self, 'OUTPUT_CONDITION'): cond_name = val.get(self.OUTPUT_CONDITION) - path = [self.OUTPUTS, key, self.OUTPUT_CONDITION] - if not conditions.is_enabled(cond_name, path): + try: + enabled = conds.is_enabled(cond_name) + except ValueError as exc: + path = [self.OUTPUTS, key, self.OUTPUT_CONDITION] + message = six.text_type(exc) + raise exception.StackValidationFailed(path=path, + message=message) + + if not enabled: yield key, output.OutputDefinition(key, None, description) continue @@ -161,19 +154,3 @@ class CommonTemplate(template.Template): yield key, output.OutputDefinition(key, value_def, description) return dict(get_outputs()) - - -class Conditions(object): - def __init__(self, conditions_dict): - self._conditions = conditions_dict - - def is_enabled(self, condition_name, path=None): - if condition_name is None: - return True - - if condition_name not in self._conditions: - message = _('Invalid condition "%s"') % condition_name - raise exception.StackValidationFailed(path=path, - message=message) - - return self._conditions[condition_name] diff --git a/heat/tests/test_hot.py b/heat/tests/test_hot.py index b11334d6ca..6be84fa609 100644 --- a/heat/tests/test_hot.py +++ b/heat/tests/test_hot.py @@ -20,6 +20,7 @@ from heat.common import identifier from heat.common import template_format from heat.engine.cfn import functions as cfn_functions from heat.engine import check_resource as cr +from heat.engine import conditions from heat.engine import environment from heat.engine import function from heat.engine.hot import functions as hot_functions @@ -1168,14 +1169,16 @@ class HOTemplateTest(common.HeatTestCase): tmpl = template.Template(hot_newton_tpl_empty) stack = parser.Stack(utils.dummy_context(), 'test_if_function', tmpl) - tmpl._conditions = {'create_prod': True} - resolved = self.resolve(snippet, tmpl, stack) - self.assertEqual('value_if_true', resolved) + with mock.patch.object(tmpl, 'conditions') as conds: + conds.return_value = conditions.Conditions({'create_prod': True}) + resolved = self.resolve(snippet, tmpl, stack) + self.assertEqual('value_if_true', resolved) # when condition evaluates to false, if function # resolve to value_if_false - tmpl._conditions = {'create_prod': False} - resolved = self.resolve(snippet, tmpl, stack) - self.assertEqual('value_if_false', resolved) + with mock.patch.object(tmpl, 'conditions') as conds: + conds.return_value = conditions.Conditions({'create_prod': False}) + resolved = self.resolve(snippet, tmpl, stack) + self.assertEqual('value_if_false', resolved) def test_if_invalid_args(self): snippet = {'if': ['create_prod', 'one_value']} @@ -1191,11 +1194,13 @@ class HOTemplateTest(common.HeatTestCase): tmpl = template.Template(hot_newton_tpl_empty) stack = parser.Stack(utils.dummy_context(), 'test_if_function', tmpl) - tmpl._conditions = {'create_prod': True} - exc = self.assertRaises(exception.StackValidationFailed, - self.resolve, snippet, tmpl, stack) - self.assertIn('Invalid condition name "cd_not_existing"', + with mock.patch.object(tmpl, 'conditions') as conds: + conds.return_value = conditions.Conditions({'create_prod': True}) + exc = self.assertRaises(exception.StackValidationFailed, + self.resolve, snippet, tmpl, stack) + self.assertIn('Invalid condition "cd_not_existing"', six.text_type(exc)) + self.assertIn('if:', six.text_type(exc)) def test_repeat(self): """Test repeat function.""" diff --git a/heat/tests/test_template.py b/heat/tests/test_template.py index 26170b87b4..38ff5bc311 100644 --- a/heat/tests/test_template.py +++ b/heat/tests/test_template.py @@ -32,7 +32,6 @@ from heat.engine import parameters from heat.engine import rsrc_defn from heat.engine import stack from heat.engine import template -from heat.engine import template_common from heat.tests import common from heat.tests.openstack.nova import fakes as fakes_nova from heat.tests import utils @@ -322,7 +321,7 @@ class TestTemplateConditionParser(common.HeatTestCase): tmpl = template.Template(t) stk = stack.Stack(self.ctx, 'test_condition_with_get_attr_func', tmpl) ex = self.assertRaises(exception.StackValidationFailed, - tmpl._resolve_conditions, stk) + tmpl.conditions, stk) self.assertIn('"get_attr" is invalid', six.text_type(ex)) self.assertIn('conditions.prod_env.equals[1].get_attr', six.text_type(ex)) @@ -331,14 +330,14 @@ class TestTemplateConditionParser(common.HeatTestCase): tmpl.t['conditions']['prod_env'] = {'get_resource': 'R1'} stk = stack.Stack(self.ctx, 'test_condition_with_get_attr_func', tmpl) ex = self.assertRaises(exception.StackValidationFailed, - tmpl._resolve_conditions, stk) + tmpl.conditions, stk) self.assertIn('"get_resource" is invalid', six.text_type(ex)) # test with get_attr in top level of a condition tmpl.t['conditions']['prod_env'] = {'get_attr': [None, 'att']} stk = stack.Stack(self.ctx, 'test_condition_with_get_attr_func', tmpl) ex = self.assertRaises(exception.StackValidationFailed, - tmpl._resolve_conditions, stk) + tmpl.conditions, stk) self.assertIn('"get_attr" is invalid', six.text_type(ex)) def test_condition_resolved_not_boolean(self): @@ -357,8 +356,9 @@ class TestTemplateConditionParser(common.HeatTestCase): tmpl = template.Template(t) stk = stack.Stack(self.ctx, 'test_condition_not_boolean', tmpl) + conditions = tmpl.conditions(stk) ex = self.assertRaises(exception.StackValidationFailed, - tmpl.conditions, stk) + conditions.is_enabled, 'prod_env') self.assertIn('The definition of condition "prod_env" is invalid', six.text_type(ex)) @@ -367,18 +367,12 @@ class TestTemplateConditionParser(common.HeatTestCase): # test condition name is invalid stk = stack.Stack(self.ctx, 'test_res_invalid_condition', tmpl) - conds = template_common.Conditions(tmpl.conditions(stk)) - ex = self.assertRaises(exception.StackValidationFailed, - conds.is_enabled, 'invalid_cd', - 'resources.r1.condition') + conds = tmpl.conditions(stk) + ex = self.assertRaises(ValueError, conds.is_enabled, 'invalid_cd') self.assertIn('Invalid condition "invalid_cd"', six.text_type(ex)) - self.assertIn('resources.r1.condition', six.text_type(ex)) # test condition name is not string - ex = self.assertRaises(exception.StackValidationFailed, - conds.is_enabled, 111, - 'resources.r1.condition') + ex = self.assertRaises(ValueError, conds.is_enabled, 111) self.assertIn('Invalid condition "111"', six.text_type(ex)) - self.assertIn('resources.r1.condition', six.text_type(ex)) def test_parse_output_condition_invalid(self): stk = stack.Stack(self.ctx,