Merge "Refactor validation of arguments to 'repeat' intrinsic function"

This commit is contained in:
Jenkins 2016-03-10 09:50:20 +00:00 committed by Gerrit Code Review
commit 73fc655d1c
2 changed files with 33 additions and 25 deletions

View File

@ -522,34 +522,39 @@ class Repeat(function.Function):
def __init__(self, stack, fn_name, args):
super(Repeat, self).__init__(stack, fn_name, args)
self._for_each, self._template = self._parse_args()
def _parse_args(self):
if not isinstance(self.args, collections.Mapping):
raise TypeError(_('Arguments to "%s" must be a map') %
self.fn_name)
# We don't check for invalid keys appearing here, which is wrong but
# it's probably too late to change
try:
for_each = self.args['for_each']
template = self.args['template']
except (KeyError, TypeError):
self._for_each = self.args['for_each']
self._template = self.args['template']
except KeyError:
example = ('''repeat:
template: This is %var%
for_each:
%var%: ['a', 'b', 'c']''')
raise KeyError(_('"repeat" syntax should be %s') %
example)
raise KeyError(_('"repeat" syntax should be %s') % example)
if not isinstance(for_each, function.Function):
if not isinstance(for_each, collections.Mapping):
def validate(self):
super(Repeat, self).validate()
if not isinstance(self._for_each, function.Function):
if not isinstance(self._for_each, collections.Mapping):
raise TypeError(_('The "for_each" argument to "%s" must '
'contain a map') % self.fn_name)
for v in six.itervalues(for_each):
if not isinstance(v, (list, function.Function)):
raise TypeError(_('The values of the "for_each" argument '
'to "%s" must be lists') % self.fn_name)
return for_each, template
if not all(self._valid_list(v) for v in self._for_each.values()):
raise TypeError(_('The values of the "for_each" argument '
'to "%s" must be lists') % self.fn_name)
@staticmethod
def _valid_list(arg):
return (isinstance(arg, (collections.Sequence,
function.Function)) and
not isinstance(arg, six.string_types))
def _do_replacement(self, keys, values, template):
if isinstance(template, six.string_types):
@ -566,16 +571,15 @@ class Repeat(function.Function):
def result(self):
for_each = function.resolve(self._for_each)
keys = list(six.iterkeys(for_each))
lists = [for_each[key] for key in keys]
if not all(isinstance(l, list) for l in lists):
if not all(self._valid_list(l) for l in for_each.values()):
raise TypeError(_('The values of the "for_each" argument to '
'"%s" must be lists') % self.fn_name)
template = function.resolve(self._template)
return [self._do_replacement(keys, items, template)
for items in itertools.product(*lists)]
keys, lists = six.moves.zip(*for_each.items())
return [self._do_replacement(keys, replacements, template)
for replacements in itertools.product(*lists)]
class Digest(function.Function):

View File

@ -917,11 +917,6 @@ class HOTemplateTest(common.HeatTestCase):
'foreach': {'%var%': ['a', 'b', 'c']}}}
self.assertRaises(KeyError, self.resolve, snippet, tmpl)
# for_each is not a map
snippet = {'repeat': {'template': 'this is %var%',
'for_each': '%var%'}}
self.assertRaises(TypeError, self.resolve, snippet, tmpl)
# value given to for_each entry is not a list
snippet = {'repeat': {'template': 'this is %var%',
'for_each': {'%var%': 'a'}}}
@ -932,6 +927,15 @@ class HOTemplateTest(common.HeatTestCase):
'for_each': {'%var%': ['a', 'b', 'c']}}}
self.assertRaises(KeyError, self.resolve, snippet, tmpl)
def test_repeat_bad_arg_type(self):
tmpl = template.Template(hot_kilo_tpl_empty)
# for_each is not a map
snippet = {'repeat': {'template': 'this is %var%',
'for_each': '%var%'}}
repeat = tmpl.parse(None, snippet)
self.assertRaises(TypeError, function.validate, repeat)
def test_digest(self):
snippet = {'digest': ['md5', 'foobar']}
snippet_resolved = '3858f62230ac3c915f300c664312c63f'