Calculate template version only once

Make the template version an attribute rather than a method, so we don't
have to repeatedly recalculate it.

Change-Id: I65392d83aa931eb1fd9998664ad77d0c6fa5eb97
This commit is contained in:
Zane Bitter 2014-03-03 19:26:08 -05:00
parent f7a52b5005
commit 480087cfb4
5 changed files with 16 additions and 21 deletions

View File

@ -27,6 +27,10 @@ class CfnTemplate(template.Template):
SECTIONS_NO_DIRECT_ACCESS = set([PARAMETERS, VERSION]) SECTIONS_NO_DIRECT_ACCESS = set([PARAMETERS, VERSION])
def __init__(self, template, *args, **kwargs):
super(CfnTemplate, self).__init__(template, *args, **kwargs)
self.version = self._version()
def __getitem__(self, section): def __getitem__(self, section):
'''Get the relevant section in the template.''' '''Get the relevant section in the template.'''
if section not in self.SECTIONS: if section not in self.SECTIONS:
@ -42,7 +46,7 @@ class CfnTemplate(template.Template):
return self.t.get(section, default) return self.t.get(section, default)
def version(self): def _version(self):
for key in ('HeatTemplateFormatVersion', 'AWSTemplateFormatVersion'): for key in ('HeatTemplateFormatVersion', 'AWSTemplateFormatVersion'):
if key in self.t: if key in self.t:
return key, self.t[key] return key, self.t[key]
@ -64,4 +68,4 @@ class CfnTemplate(template.Template):
context=context) context=context)
def functions(self): def functions(self):
return functions.function_mapping(*self.version()) return functions.function_mapping(*self.version)

View File

@ -58,7 +58,10 @@ class HOTemplate(template.Template):
cfn_template.CfnTemplate.OUTPUTS: OUTPUTS} cfn_template.CfnTemplate.OUTPUTS: OUTPUTS}
def __init__(self, template, *args, **kwargs): def __init__(self, template, *args, **kwargs):
version = template[self.VERSION] # All user templates are forced to include a version string. This is
# just a convenient default for unit tests.
version = template.get(self.VERSION, '2013-05-23')
if version not in self.VERSIONS: if version not in self.VERSIONS:
msg = _('"%(version)s" is not a valid ' msg = _('"%(version)s" is not a valid '
'heat_template_version. Should be one of: ' 'heat_template_version. Should be one of: '
@ -67,6 +70,7 @@ class HOTemplate(template.Template):
'valid': str(self.VERSIONS)}) 'valid': str(self.VERSIONS)})
super(HOTemplate, self).__init__(template, *args, **kwargs) super(HOTemplate, self).__init__(template, *args, **kwargs)
self.version = self.VERSION, version
def __getitem__(self, section): def __getitem__(self, section):
""""Get the relevant section in the template.""" """"Get the relevant section in the template."""
@ -146,14 +150,6 @@ class HOTemplate(template.Template):
return cfn_outputs return cfn_outputs
def version(self):
if self.VERSION in self.t:
return self.VERSION, self.t[self.VERSION]
# All user templates are forced to include a version string. This is
# just a convenient default for unit tests.
return self.VERSION, '2013-05-23'
def param_schemata(self): def param_schemata(self):
params = self.t.get(self.PARAMETERS, {}).iteritems() params = self.t.get(self.PARAMETERS, {}).iteritems()
return dict((name, HOTParamSchema.from_dict(schema)) return dict((name, HOTParamSchema.from_dict(schema))
@ -166,7 +162,7 @@ class HOTemplate(template.Template):
def functions(self): def functions(self):
from heat.engine.hot import functions from heat.engine.hot import functions
return functions.function_mapping(*self.version()) return functions.function_mapping(*self.version)
class HOTParamSchema(parameters.Schema): class HOTParamSchema(parameters.Schema):

View File

@ -72,11 +72,6 @@ class Template(collections.Mapping):
'''Return the number of sections.''' '''Return the number of sections.'''
return len(self.SECTIONS) - len(self.SECTIONS_NO_DIRECT_ACCESS) return len(self.SECTIONS) - len(self.SECTIONS_NO_DIRECT_ACCESS)
@abc.abstractmethod
def version(self):
'''Return a (versionkey, version) tuple for the template.'''
pass
@abc.abstractmethod @abc.abstractmethod
def param_schemata(self): def param_schemata(self):
'''Return a dict of parameters.Schema objects for the parameters.''' '''Return a dict of parameters.Schema objects for the parameters.'''

View File

@ -311,8 +311,8 @@ class HOTemplateTest(HeatTestCase):
tmpl_str = "heat_template_version: 2013-05-23" tmpl_str = "heat_template_version: 2013-05-23"
hot_tmpl = template_format.parse(tmpl_str) hot_tmpl = template_format.parse(tmpl_str)
parsed_tmpl = template.Template(hot_tmpl) parsed_tmpl = template.Template(hot_tmpl)
expected = '2013-05-23' expected = ('heat_template_version', '2013-05-23')
observed = parsed_tmpl.version()[1] observed = parsed_tmpl.version
self.assertEqual(expected, observed) self.assertEqual(expected, observed)

View File

@ -169,12 +169,12 @@ class TemplateTest(HeatTestCase):
def test_aws_version(self): def test_aws_version(self):
tmpl = parser.Template(mapping_template) tmpl = parser.Template(mapping_template)
self.assertEqual(('AWSTemplateFormatVersion', '2010-09-09'), self.assertEqual(('AWSTemplateFormatVersion', '2010-09-09'),
tmpl.version()) tmpl.version)
def test_heat_version(self): def test_heat_version(self):
tmpl = parser.Template(resource_template) tmpl = parser.Template(resource_template)
self.assertEqual(('HeatTemplateFormatVersion', '2012-12-12'), self.assertEqual(('HeatTemplateFormatVersion', '2012-12-12'),
tmpl.version()) tmpl.version)
def test_invalid_template(self): def test_invalid_template(self):
scanner_error = ''' scanner_error = '''