diff --git a/heat/engine/resources/loadbalancer.py b/heat/engine/resources/loadbalancer.py index c74c8ed387..6f39c07b00 100644 --- a/heat/engine/resources/loadbalancer.py +++ b/heat/engine/resources/loadbalancer.py @@ -434,25 +434,38 @@ class LoadBalancer(stack_resource.StackResource): contents = lb_template_default return template_format.parse(contents) - def handle_create(self): + def child_params(self): + params = {} + + # If the owning stack defines KeyName, we use that key for the nested + # template, otherwise use no key + if 'KeyName' in self.stack.parameters: + params['KeyName'] = self.stack.parameters['KeyName'] + + return params + + def child_template(self): templ = self.get_parsed_template() + # If the owning stack defines KeyName, we use that key for the nested + # template, otherwise use no key + if 'KeyName' not in self.stack.parameters: + del templ['Resources']['LB_instance']['Properties']['KeyName'] + del templ['Parameters']['KeyName'] + + return templ + + def handle_create(self): + templ = self.child_template() + params = self.child_params() + if self.properties[self.INSTANCES]: md = templ['Resources']['LB_instance']['Metadata'] files = md['AWS::CloudFormation::Init']['config']['files'] cfg = self._haproxy_config(templ, self.properties[self.INSTANCES]) files['/etc/haproxy/haproxy.cfg']['content'] = cfg - # If the owning stack defines KeyName, we use that key for the nested - # template, otherwise use no key - try: - param = {'KeyName': self.stack.parameters['KeyName']} - except KeyError: - del templ['Resources']['LB_instance']['Properties']['KeyName'] - del templ['Parameters']['KeyName'] - param = {} - - return self.create_with_template(templ, param) + return self.create_with_template(templ, params) def handle_update(self, json_snippet, tmpl_diff, prop_diff): ''' diff --git a/heat/tests/test_loadbalancer.py b/heat/tests/test_loadbalancer.py index 2248b3d30a..0e37e7fd04 100644 --- a/heat/tests/test_loadbalancer.py +++ b/heat/tests/test_loadbalancer.py @@ -13,6 +13,7 @@ # under the License. +import mock import mox import re @@ -241,3 +242,41 @@ class LoadBalancerTest(HeatTestCase): t['Resources']['LoadBalancer'], s) self.assertRaises(exception.StackValidationFailed, rsrc.validate) + + def setup_loadbalancer(self, include_keyname=True): + template = template_format.parse(lb_template) + if not include_keyname: + del template['Parameters']['KeyName'] + stack = utils.parse_stack(template) + + resource_name = 'LoadBalancer' + lb_json = template['Resources'][resource_name] + return lb.LoadBalancer(resource_name, lb_json, stack) + + def test_child_params_without_key_name(self): + rsrc = self.setup_loadbalancer(False) + self.assertEqual({}, rsrc.child_params()) + + def test_child_params_with_key_name(self): + rsrc = self.setup_loadbalancer() + params = rsrc.child_params() + self.assertEqual('test', params['KeyName']) + + def test_child_template_without_key_name(self): + rsrc = self.setup_loadbalancer(False) + parsed_template = { + 'Resources': {'LB_instance': {'Properties': {'KeyName': 'foo'}}}, + 'Parameters': {'KeyName': 'foo'} + } + rsrc.get_parsed_template = mock.Mock(return_value=parsed_template) + + tmpl = rsrc.child_template() + self.assertNotIn('KeyName', tmpl['Parameters']) + self.assertNotIn('KeyName', + tmpl['Resources']['LB_instance']['Properties']) + + def test_child_template_with_key_name(self): + rsrc = self.setup_loadbalancer() + rsrc.get_parsed_template = mock.Mock(return_value='foo') + + self.assertEqual('foo', rsrc.child_template())