Merge "Refactor validation of arguments to 'repeat' intrinsic function"
This commit is contained in:
commit
73fc655d1c
@ -522,34 +522,39 @@ class Repeat(function.Function):
|
||||
def __init__(self, stack, fn_name, args):
|
||||
super(Repeat, self).__init__(stack, fn_name, args)
|
||||
|
||||
self._for_each, self._template = self._parse_args()
|
||||
|
||||
def _parse_args(self):
|
||||
if not isinstance(self.args, collections.Mapping):
|
||||
raise TypeError(_('Arguments to "%s" must be a map') %
|
||||
self.fn_name)
|
||||
|
||||
# We don't check for invalid keys appearing here, which is wrong but
|
||||
# it's probably too late to change
|
||||
try:
|
||||
for_each = self.args['for_each']
|
||||
template = self.args['template']
|
||||
except (KeyError, TypeError):
|
||||
self._for_each = self.args['for_each']
|
||||
self._template = self.args['template']
|
||||
except KeyError:
|
||||
example = ('''repeat:
|
||||
template: This is %var%
|
||||
for_each:
|
||||
%var%: ['a', 'b', 'c']''')
|
||||
raise KeyError(_('"repeat" syntax should be %s') %
|
||||
example)
|
||||
raise KeyError(_('"repeat" syntax should be %s') % example)
|
||||
|
||||
if not isinstance(for_each, function.Function):
|
||||
if not isinstance(for_each, collections.Mapping):
|
||||
def validate(self):
|
||||
super(Repeat, self).validate()
|
||||
|
||||
if not isinstance(self._for_each, function.Function):
|
||||
if not isinstance(self._for_each, collections.Mapping):
|
||||
raise TypeError(_('The "for_each" argument to "%s" must '
|
||||
'contain a map') % self.fn_name)
|
||||
for v in six.itervalues(for_each):
|
||||
if not isinstance(v, (list, function.Function)):
|
||||
raise TypeError(_('The values of the "for_each" argument '
|
||||
'to "%s" must be lists') % self.fn_name)
|
||||
|
||||
return for_each, template
|
||||
if not all(self._valid_list(v) for v in self._for_each.values()):
|
||||
raise TypeError(_('The values of the "for_each" argument '
|
||||
'to "%s" must be lists') % self.fn_name)
|
||||
|
||||
@staticmethod
|
||||
def _valid_list(arg):
|
||||
return (isinstance(arg, (collections.Sequence,
|
||||
function.Function)) and
|
||||
not isinstance(arg, six.string_types))
|
||||
|
||||
def _do_replacement(self, keys, values, template):
|
||||
if isinstance(template, six.string_types):
|
||||
@ -566,16 +571,15 @@ class Repeat(function.Function):
|
||||
|
||||
def result(self):
|
||||
for_each = function.resolve(self._for_each)
|
||||
keys = list(six.iterkeys(for_each))
|
||||
lists = [for_each[key] for key in keys]
|
||||
if not all(isinstance(l, list) for l in lists):
|
||||
if not all(self._valid_list(l) for l in for_each.values()):
|
||||
raise TypeError(_('The values of the "for_each" argument to '
|
||||
'"%s" must be lists') % self.fn_name)
|
||||
|
||||
template = function.resolve(self._template)
|
||||
|
||||
return [self._do_replacement(keys, items, template)
|
||||
for items in itertools.product(*lists)]
|
||||
keys, lists = six.moves.zip(*for_each.items())
|
||||
return [self._do_replacement(keys, replacements, template)
|
||||
for replacements in itertools.product(*lists)]
|
||||
|
||||
|
||||
class Digest(function.Function):
|
||||
|
@ -917,11 +917,6 @@ class HOTemplateTest(common.HeatTestCase):
|
||||
'foreach': {'%var%': ['a', 'b', 'c']}}}
|
||||
self.assertRaises(KeyError, self.resolve, snippet, tmpl)
|
||||
|
||||
# for_each is not a map
|
||||
snippet = {'repeat': {'template': 'this is %var%',
|
||||
'for_each': '%var%'}}
|
||||
self.assertRaises(TypeError, self.resolve, snippet, tmpl)
|
||||
|
||||
# value given to for_each entry is not a list
|
||||
snippet = {'repeat': {'template': 'this is %var%',
|
||||
'for_each': {'%var%': 'a'}}}
|
||||
@ -932,6 +927,15 @@ class HOTemplateTest(common.HeatTestCase):
|
||||
'for_each': {'%var%': ['a', 'b', 'c']}}}
|
||||
self.assertRaises(KeyError, self.resolve, snippet, tmpl)
|
||||
|
||||
def test_repeat_bad_arg_type(self):
|
||||
tmpl = template.Template(hot_kilo_tpl_empty)
|
||||
|
||||
# for_each is not a map
|
||||
snippet = {'repeat': {'template': 'this is %var%',
|
||||
'for_each': '%var%'}}
|
||||
repeat = tmpl.parse(None, snippet)
|
||||
self.assertRaises(TypeError, function.validate, repeat)
|
||||
|
||||
def test_digest(self):
|
||||
snippet = {'digest': ['md5', 'foobar']}
|
||||
snippet_resolved = '3858f62230ac3c915f300c664312c63f'
|
||||
|
Loading…
x
Reference in New Issue
Block a user