diff --git a/heat/engine/cfn/template.py b/heat/engine/cfn/template.py index a1a641a33..66ef73488 100644 --- a/heat/engine/cfn/template.py +++ b/heat/engine/cfn/template.py @@ -27,6 +27,10 @@ class CfnTemplate(template.Template): SECTIONS_NO_DIRECT_ACCESS = set([PARAMETERS, VERSION]) + def __init__(self, template, *args, **kwargs): + super(CfnTemplate, self).__init__(template, *args, **kwargs) + self.version = self._version() + def __getitem__(self, section): '''Get the relevant section in the template.''' if section not in self.SECTIONS: @@ -42,7 +46,7 @@ class CfnTemplate(template.Template): return self.t.get(section, default) - def version(self): + def _version(self): for key in ('HeatTemplateFormatVersion', 'AWSTemplateFormatVersion'): if key in self.t: return key, self.t[key] @@ -64,4 +68,4 @@ class CfnTemplate(template.Template): context=context) def functions(self): - return functions.function_mapping(*self.version()) + return functions.function_mapping(*self.version) diff --git a/heat/engine/hot/__init__.py b/heat/engine/hot/__init__.py index 76d553283..a6ea5f8d2 100644 --- a/heat/engine/hot/__init__.py +++ b/heat/engine/hot/__init__.py @@ -58,7 +58,10 @@ class HOTemplate(template.Template): cfn_template.CfnTemplate.OUTPUTS: OUTPUTS} def __init__(self, template, *args, **kwargs): - version = template[self.VERSION] + # All user templates are forced to include a version string. This is + # just a convenient default for unit tests. + version = template.get(self.VERSION, '2013-05-23') + if version not in self.VERSIONS: msg = _('"%(version)s" is not a valid ' 'heat_template_version. Should be one of: ' @@ -67,6 +70,7 @@ class HOTemplate(template.Template): 'valid': str(self.VERSIONS)}) super(HOTemplate, self).__init__(template, *args, **kwargs) + self.version = self.VERSION, version def __getitem__(self, section): """"Get the relevant section in the template.""" @@ -146,14 +150,6 @@ class HOTemplate(template.Template): return cfn_outputs - def version(self): - if self.VERSION in self.t: - return self.VERSION, self.t[self.VERSION] - - # All user templates are forced to include a version string. This is - # just a convenient default for unit tests. - return self.VERSION, '2013-05-23' - def param_schemata(self): params = self.t.get(self.PARAMETERS, {}).iteritems() return dict((name, HOTParamSchema.from_dict(schema)) @@ -166,7 +162,7 @@ class HOTemplate(template.Template): def functions(self): from heat.engine.hot import functions - return functions.function_mapping(*self.version()) + return functions.function_mapping(*self.version) class HOTParamSchema(parameters.Schema): diff --git a/heat/engine/template.py b/heat/engine/template.py index e530a7246..677e33f01 100644 --- a/heat/engine/template.py +++ b/heat/engine/template.py @@ -72,11 +72,6 @@ class Template(collections.Mapping): '''Return the number of sections.''' return len(self.SECTIONS) - len(self.SECTIONS_NO_DIRECT_ACCESS) - @abc.abstractmethod - def version(self): - '''Return a (versionkey, version) tuple for the template.''' - pass - @abc.abstractmethod def param_schemata(self): '''Return a dict of parameters.Schema objects for the parameters.''' diff --git a/heat/tests/test_hot.py b/heat/tests/test_hot.py index 777b19f77..21b4bb202 100644 --- a/heat/tests/test_hot.py +++ b/heat/tests/test_hot.py @@ -311,8 +311,8 @@ class HOTemplateTest(HeatTestCase): tmpl_str = "heat_template_version: 2013-05-23" hot_tmpl = template_format.parse(tmpl_str) parsed_tmpl = template.Template(hot_tmpl) - expected = '2013-05-23' - observed = parsed_tmpl.version()[1] + expected = ('heat_template_version', '2013-05-23') + observed = parsed_tmpl.version self.assertEqual(expected, observed) diff --git a/heat/tests/test_parser.py b/heat/tests/test_parser.py index 77b5edffb..3e3f595a1 100644 --- a/heat/tests/test_parser.py +++ b/heat/tests/test_parser.py @@ -169,12 +169,12 @@ class TemplateTest(HeatTestCase): def test_aws_version(self): tmpl = parser.Template(mapping_template) self.assertEqual(('AWSTemplateFormatVersion', '2010-09-09'), - tmpl.version()) + tmpl.version) def test_heat_version(self): tmpl = parser.Template(resource_template) self.assertEqual(('HeatTemplateFormatVersion', '2012-12-12'), - tmpl.version()) + tmpl.version) def test_invalid_template(self): scanner_error = '''