From f605bcc812e2f0eaeca7c62da759f16a7478a6b0 Mon Sep 17 00:00:00 2001 From: Zane Bitter Date: Wed, 31 Oct 2012 20:18:43 +0100 Subject: [PATCH] Separate Parameters implementation from Properties There are only 3 types allowed for Parameters (String, Number and CommaDelimitedList), but we are currently allowing more due to a shared implementation with Properties (which is an internal implementation detail). This creates a separate implementation for Parameters with only the allowed types. Change-Id: If51ec538893a582da2caa0356c25e515e9d8004e Signed-off-by: Zane Bitter --- heat/engine/parameters.py | 262 +++++++++++++++++++++++++++++ heat/engine/parser.py | 59 +------ heat/tests/test_parameters.py | 307 ++++++++++++++++++++++++++++++++++ heat/tests/test_parser.py | 44 ----- 4 files changed, 577 insertions(+), 95 deletions(-) create mode 100644 heat/engine/parameters.py create mode 100644 heat/tests/test_parameters.py diff --git a/heat/engine/parameters.py b/heat/engine/parameters.py new file mode 100644 index 0000000000..dbe3869f01 --- /dev/null +++ b/heat/engine/parameters.py @@ -0,0 +1,262 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import collections +import re + +from heat.engine import template + +PARAMETER_KEYS = ( + TYPE, DEFAULT, NO_ECHO, VALUES, PATTERN, + MAX_LENGTH, MIN_LENGTH, MAX_VALUE, MIN_VALUE, + DESCRIPTION, CONSTRAINT_DESCRIPTION +) = ( + 'Type', 'Default', 'NoEcho', 'AllowedValues', 'AllowedPattern', + 'MaxLength', 'MinLength', 'MaxValue', 'MinValue', + 'Description', 'ConstraintDescription' +) +PARAMETER_TYPES = ( + STRING, NUMBER, COMMA_DELIMITED_LIST +) = ( + 'String', 'Number', 'CommaDelimitedList' +) +(PARAM_STACK_NAME, PARAM_REGION) = ('AWS::StackName', 'AWS::Region') + + +class Parameter(object): + '''A template parameter.''' + + def __new__(cls, name, schema, value=None): + '''Create a new Parameter of the appropriate type.''' + if cls is not Parameter: + return super(Parameter, cls).__new__(cls) + + param_type = schema[TYPE] + if param_type == STRING: + ParamClass = StringParam + elif param_type == NUMBER: + ParamClass = NumberParam + elif param_type == COMMA_DELIMITED_LIST: + ParamClass = CommaDelimitedListParam + else: + raise ValueError('Invalid Parameter type "%s"' % param_type) + + return ParamClass(name, schema, value) + + def __init__(self, name, schema, value=None): + ''' + Initialise the Parameter with a name, schema and optional user-supplied + value. + ''' + self.name = name + self.schema = schema + self.user_value = value + self._constraint_error = self.schema.get(CONSTRAINT_DESCRIPTION) + + if self.has_default(): + self._validate(self.default()) + + if self.user_value is not None: + self._validate(self.user_value) + + def _error_msg(self, message): + return '%s %s' % (self.name, self._constraint_error or message) + + def _validate(self, value): + if VALUES in self.schema: + allowed = self.schema[VALUES] + if value not in allowed: + message = '%s not in %s %s' % (value, VALUES, allowed) + raise ValueError(self._error_msg(message)) + + def value(self): + '''Get the parameter value, optionally sanitising it for output.''' + if self.user_value is not None: + return self.user_value + + if self.has_default(): + return self.default() + + raise KeyError('Missing parameter %s' % self.name) + + def description(self): + '''Return the description of the parameter.''' + return self.schema.get(DESCRIPTION, '') + + def has_default(self): + '''Return whether the parameter has a default value.''' + return DEFAULT in self.schema + + def default(self): + '''Return the default value of the parameter.''' + return self.schema.get(DEFAULT) + + def __str__(self): + '''Return a string representation of the parameter''' + return self.value() + + +class NumberParam(Parameter): + '''A template parameter of type "Number".''' + + @staticmethod + def str_to_num(s): + '''Convert a string to an integer (if possible) or float.''' + try: + return int(s) + except ValueError: + return float(s) + + def _validate(self, value): + '''Check that the supplied value is compatible with the constraints.''' + num = self.str_to_num(value) + minn = self.str_to_num(self.schema.get(MIN_VALUE, value)) + maxn = self.str_to_num(self.schema.get(MAX_VALUE, value)) + + if num > maxn or num < minn: + raise ValueError(self._error_msg('%s is out of range' % value)) + + Parameter._validate(self, value) + + def __int__(self): + '''Return an integer representation of the parameter''' + return int(self.value()) + + def __float__(self): + '''Return a float representation of the parameter''' + return float(self.value()) + + +class StringParam(Parameter): + '''A template parameter of type "String".''' + + def _validate(self, value): + '''Check that the supplied value is compatible with the constraints''' + if not isinstance(value, basestring): + raise ValueError(self._error_msg('value must be a string')) + + length = len(value) + if MAX_LENGTH in self.schema: + max_length = int(self.schema[MAX_LENGTH]) + if length > max_length: + message = 'length (%d) overflows %s %s' % (length, + MAX_LENGTH, + max_length) + raise ValueError(self._error_msg(message)) + + if MIN_LENGTH in self.schema: + min_length = int(self.schema[MIN_LENGTH]) + if length < min_length: + message = 'length (%d) underflows %s %d' % (length, + MIN_LENGTH, + min_length) + raise ValueError(self._error_msg(message)) + + if PATTERN in self.schema: + pattern = self.schema[PATTERN] + match = re.match(pattern, value) + if match is None or match.end() != length: + message = '"%s" does not match %s "%s"' % (value, + PATTERN, + pattern) + raise ValueError(self._error_msg(message)) + + Parameter._validate(self, value) + + +class CommaDelimitedListParam(Parameter, collections.Sequence): + '''A template parameter of type "CommaDelimitedList".''' + + def _validate(self, value): + '''Check that the supplied value is compatible with the constraints''' + try: + sp = value.split(',') + except AttributeError: + raise ValueError('Value must be a comma-delimited list string') + + for li in self: + Parameter._validate(self, li) + + def __len__(self): + '''Return the length of the list''' + return len(self.value().split(',')) + + def __getitem__(self, index): + '''Return an item from the list''' + return self.value().split(',')[index] + + +class Parameters(collections.Mapping): + ''' + The parameters of a stack, with type checking, defaults &c. specified by + the stack's template. + ''' + + def __init__(self, stack_name, tmpl, user_params={}): + ''' + Create the parameter container for a stack from the stack name and + template, optionally setting the user-supplied parameter values. + ''' + def parameters(): + if stack_name is not None: + yield Parameter(PARAM_STACK_NAME, + {TYPE: STRING, + DESCRIPTION: 'Stack Name', + DEFAULT: stack_name}) + yield Parameter(PARAM_REGION, + {TYPE: STRING, + DEFAULT: 'ap-southeast-1', + VALUES: ['us-east-1', + 'us-west-1', 'us-west-2', + 'sa-east-1', + 'eu-west-1', + 'ap-southeast-1', + 'ap-northeast-1']}) + + for name, schema in tmpl[template.PARAMETERS].iteritems(): + yield Parameter(name, schema, user_params.get(name)) + + self.params = dict((p.name, p) for p in parameters()) + + def __contains__(self, key): + '''Return whether the specified parameter exists''' + return key in self.params + + def __iter__(self): + '''Return an iterator over the parameter names.''' + return iter(self.params) + + def __len__(self): + '''Return the number of parameters defined''' + return len(self.params) + + def __getitem__(self, key): + '''Get a parameter value.''' + return self.params[key].value() + + def map(self, func, filter_func=lambda p: True): + ''' + Map the supplied filter function onto each Parameter (with an + optional filter function) and return the resulting dictionary. + ''' + return dict((n, func(p)) for n, p in self.params.iteritems() + if filter_func(p)) + + def user_parameters(self): + ''' + Return a dictionary of all the parameters passed in by the user + ''' + return self.map(lambda p: p.user_value, + lambda p: p.user_value is not None) diff --git a/heat/engine/parser.py b/heat/engine/parser.py index 7c01818aed..dbefdc8397 100644 --- a/heat/engine/parser.py +++ b/heat/engine/parser.py @@ -19,12 +19,12 @@ import functools import copy from heat.common import exception -from heat.engine import checkeddict from heat.engine import dependencies from heat.engine import identifier from heat.engine import resources from heat.engine import template from heat.engine import timestamp +from heat.engine.parameters import Parameters from heat.engine.template import Template from heat.db import api as db_api @@ -35,51 +35,6 @@ logger = logging.getLogger('heat.engine.parser') (PARAM_STACK_NAME, PARAM_REGION) = ('AWS::StackName', 'AWS::Region') -class Parameters(checkeddict.CheckedDict): - ''' - The parameters of a stack, with type checking, defaults &c. specified by - the stack's template. - ''' - - def __init__(self, stack_name, tmpl, user_params={}): - ''' - Create the parameter container for a stack from the stack name and - template, optionally setting the initial set of parameters. - ''' - checkeddict.CheckedDict.__init__(self, template.PARAMETERS) - self._init_schemata(tmpl[template.PARAMETERS]) - - self[PARAM_STACK_NAME] = stack_name - self.update(user_params) - - def _init_schemata(self, schemata): - ''' - Initialise the parameter schemata with the pseudo-parameters and the - list of schemata obtained from the template. - ''' - self.addschema(PARAM_STACK_NAME, {"Description": "AWS StackName", - "Type": "String"}) - self.addschema(PARAM_REGION, { - "Description": "AWS Regions", - "Default": "ap-southeast-1", - "Type": "String", - "AllowedValues": ["us-east-1", "us-west-1", "us-west-2", - "sa-east-1", "eu-west-1", "ap-southeast-1", - "ap-northeast-1"], - "ConstraintDescription": "must be a valid EC2 instance type.", - }) - - for param, schema in schemata.items(): - self.addschema(param, copy.deepcopy(schema)) - - def user_parameters(self): - ''' - Return a dictionary of all the parameters passed in by the user - ''' - return dict((k, v['Value']) for k, v in self.data.iteritems() - if 'Value' in v) - - class Stack(object): CREATE_IN_PROGRESS = 'CREATE_IN_PROGRESS' CREATE_FAILED = 'CREATE_FAILED' @@ -241,14 +196,16 @@ class Stack(object): 'Parameters': []} return response - def format_param(p): + def describe_param(p): return {'NoEcho': 'false', - 'ParameterKey': p, - 'Description': self.parameters.get_attr(p, 'Description'), - 'DefaultValue': self.parameters.get_attr(p, 'Default')} + 'ParameterKey': p.name, + 'Description': p.description(), + 'DefaultValue': p.default()} + + params = self.parameters.map(describe_param) response = {'Description': 'Successfully validated', - 'Parameters': [format_param(p) for p in self.parameters]} + 'Parameters': params.values()} return response diff --git a/heat/tests/test_parameters.py b/heat/tests/test_parameters.py new file mode 100644 index 0000000000..9f92ba228e --- /dev/null +++ b/heat/tests/test_parameters.py @@ -0,0 +1,307 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + + +import nose +import unittest +from nose.plugins.attrib import attr +import mox +import json + +from heat.common import context +from heat.common import exception +from heat.engine import parameters + + +@attr(tag=['unit', 'parameters']) +@attr(speed='fast') +class ParameterTest(unittest.TestCase): + def test_new_string(self): + p = parameters.Parameter('p', {'Type': 'String'}) + self.assertTrue(isinstance(p, parameters.StringParam)) + + def test_new_number(self): + p = parameters.Parameter('p', {'Type': 'Number'}) + self.assertTrue(isinstance(p, parameters.NumberParam)) + + def test_new_list(self): + p = parameters.Parameter('p', {'Type': 'CommaDelimitedList'}) + self.assertTrue(isinstance(p, parameters.CommaDelimitedListParam)) + + def test_new_bad_type(self): + self.assertRaises(ValueError, parameters.Parameter, + 'p', {'Type': 'List'}) + + def test_new_no_type(self): + self.assertRaises(KeyError, parameters.Parameter, + 'p', {'Default': 'blarg'}) + + def test_default_no_override(self): + p = parameters.Parameter('defaulted', {'Type': 'String', + 'Default': 'blarg'}) + self.assertTrue(p.has_default()) + self.assertEqual(p.default(), 'blarg') + self.assertEqual(p.value(), 'blarg') + + def test_default_override(self): + p = parameters.Parameter('defaulted', + {'Type': 'String', + 'Default': 'blarg'}, + 'wibble') + self.assertTrue(p.has_default()) + self.assertEqual(p.default(), 'blarg') + self.assertEqual(p.value(), 'wibble') + + def test_default_invalid(self): + schema = {'Type': 'String', + 'AllowedValues': ['foo'], + 'ConstraintDescription': 'wibble', + 'Default': 'bar'} + try: + parameters.Parameter('p', schema, 'foo') + except ValueError as ve: + msg = str(ve) + self.assertNotEqual(msg.find('wibble'), -1) + else: + self.fail('ValueError not raised') + + def test_description(self): + description = 'Description of the parameter' + p = parameters.Parameter('p', {'Type': 'String', + 'Description': description}) + self.assertEqual(p.description(), description) + + def test_no_description(self): + p = parameters.Parameter('p', {'Type': 'String'}) + self.assertEqual(p.description(), '') + + def test_string_len_good(self): + schema = {'Type': 'String', + 'MinLength': '3', + 'MaxLength': '3'} + p = parameters.Parameter('p', schema, 'foo') + self.assertEqual(p.value(), 'foo') + + def test_string_underflow(self): + schema = {'Type': 'String', + 'ConstraintDescription': 'wibble', + 'MinLength': '4'} + try: + parameters.Parameter('p', schema, 'foo') + except ValueError as ve: + msg = str(ve) + self.assertNotEqual(msg.find('wibble'), -1) + else: + self.fail('ValueError not raised') + + def test_string_overflow(self): + schema = {'Type': 'String', + 'ConstraintDescription': 'wibble', + 'MaxLength': '2'} + try: + parameters.Parameter('p', schema, 'foo') + except ValueError as ve: + msg = str(ve) + self.assertNotEqual(msg.find('wibble'), -1) + else: + self.fail('ValueError not raised') + + def test_string_pattern_good(self): + schema = {'Type': 'String', + 'AllowedPattern': '[a-z]*'} + p = parameters.Parameter('p', schema, 'foo') + self.assertEqual(p.value(), 'foo') + + def test_string_pattern_bad_prefix(self): + schema = {'Type': 'String', + 'ConstraintDescription': 'wibble', + 'AllowedPattern': '[a-z]*'} + try: + parameters.Parameter('p', schema, '1foo') + except ValueError as ve: + msg = str(ve) + self.assertNotEqual(msg.find('wibble'), -1) + else: + self.fail('ValueError not raised') + + def test_string_pattern_bad_suffix(self): + schema = {'Type': 'String', + 'ConstraintDescription': 'wibble', + 'AllowedPattern': '[a-z]*'} + try: + parameters.Parameter('p', schema, 'foo1') + except ValueError as ve: + msg = str(ve) + self.assertNotEqual(msg.find('wibble'), -1) + else: + self.fail('ValueError not raised') + + def test_string_value_list_good(self): + schema = {'Type': 'String', + 'AllowedValues': ['foo', 'bar', 'baz']} + p = parameters.Parameter('p', schema, 'bar') + self.assertEqual(p.value(), 'bar') + + def test_string_value_list_bad(self): + schema = {'Type': 'String', + 'ConstraintDescription': 'wibble', + 'AllowedValues': ['foo', 'bar', 'baz']} + try: + parameters.Parameter('p', schema, 'blarg') + except ValueError as ve: + msg = str(ve) + self.assertNotEqual(msg.find('wibble'), -1) + else: + self.fail('ValueError not raised') + + def test_number_int_good(self): + schema = {'Type': 'Number', + 'MinValue': '3', + 'MaxValue': '3'} + p = parameters.Parameter('p', schema, '3') + self.assertEqual(p.value(), '3') + + def test_number_float_good(self): + schema = {'Type': 'Number', + 'MinValue': '3.0', + 'MaxValue': '3.0'} + p = parameters.Parameter('p', schema, '3.0') + self.assertEqual(p.value(), '3.0') + + def test_number_low(self): + schema = {'Type': 'Number', + 'ConstraintDescription': 'wibble', + 'MinValue': '4'} + try: + parameters.Parameter('p', schema, '3') + except ValueError as ve: + msg = str(ve) + self.assertNotEqual(msg.find('wibble'), -1) + else: + self.fail('ValueError not raised') + + def test_number_high(self): + schema = {'Type': 'Number', + 'ConstraintDescription': 'wibble', + 'MaxValue': '2'} + try: + parameters.Parameter('p', schema, '3') + except ValueError as ve: + msg = str(ve) + self.assertNotEqual(msg.find('wibble'), -1) + else: + self.fail('ValueError not raised') + + def test_number_value_list_good(self): + schema = {'Type': 'Number', + 'AllowedValues': ['1', '3', '5']} + p = parameters.Parameter('p', schema, '5') + self.assertEqual(p.value(), '5') + + def test_number_value_list_bad(self): + schema = {'Type': 'Number', + 'ConstraintDescription': 'wibble', + 'AllowedValues': ['1', '3', '5']} + try: + parameters.Parameter('p', schema, '2') + except ValueError as ve: + msg = str(ve) + self.assertNotEqual(msg.find('wibble'), -1) + else: + self.fail('ValueError not raised') + + def test_list_value_list_good(self): + schema = {'Type': 'CommaDelimitedList', + 'AllowedValues': ['foo', 'bar', 'baz']} + p = parameters.Parameter('p', schema, 'baz,foo,bar') + self.assertEqual(p.value(), 'baz,foo,bar') + + def test_list_value_list_bad(self): + schema = {'Type': 'CommaDelimitedList', + 'ConstraintDescription': 'wibble', + 'AllowedValues': ['foo', 'bar', 'baz']} + try: + parameters.Parameter('p', schema, 'foo,baz,blarg') + except ValueError as ve: + msg = str(ve) + self.assertNotEqual(msg.find('wibble'), -1) + else: + self.fail('ValueError not raised') + + +params_schema = json.loads('''{ + "Parameters" : { + "User" : { "Type": "String" }, + "Defaulted" : { + "Type": "String", + "Default": "foobar" + } + } +}''') + + +@attr(tag=['unit', 'parameters']) +@attr(speed='fast') +class ParametersTest(unittest.TestCase): + def test_pseudo_params(self): + params = parameters.Parameters('test_stack', {"Parameters": {}}) + + self.assertEqual(params['AWS::StackName'], 'test_stack') + self.assertTrue('AWS::Region' in params) + + def test_user_param(self): + user_params = {'User': 'wibble'} + params = parameters.Parameters('test', params_schema, user_params) + self.assertEqual(params.user_parameters(), user_params) + + def test_user_param_nonexist(self): + params = parameters.Parameters('test', params_schema) + self.assertEqual(params.user_parameters(), {}) + + def test_schema_invariance(self): + params1 = parameters.Parameters('test', params_schema, + {'Defaulted': 'wibble'}) + self.assertEqual(params1['Defaulted'], 'wibble') + + params2 = parameters.Parameters('test', params_schema) + self.assertEqual(params2['Defaulted'], 'foobar') + + def test_to_dict(self): + template = {'Parameters': {'Foo': {'Type': 'String'}, + 'Bar': {'Type': 'Number', 'Default': '42'}}} + params = parameters.Parameters('test_params', template, {'Foo': 'foo'}) + + as_dict = dict(params) + self.assertEqual(as_dict['Foo'], 'foo') + self.assertEqual(as_dict['Bar'], '42') + self.assertEqual(as_dict['AWS::StackName'], 'test_params') + self.assertTrue('AWS::Region' in as_dict) + + def test_map(self): + template = {'Parameters': {'Foo': {'Type': 'String'}, + 'Bar': {'Type': 'Number', 'Default': '42'}}} + params = parameters.Parameters('test_params', template, {'Foo': 'foo'}) + + expected = {'Foo': False, + 'Bar': True, + 'AWS::Region': True, + 'AWS::StackName': True} + + self.assertEqual(params.map(lambda p: p.has_default()), expected) + + +# allows testing of the test directly, shown below +if __name__ == '__main__': + sys.argv.append(__file__) + nose.main() diff --git a/heat/tests/test_parser.py b/heat/tests/test_parser.py index 4f56c78155..dfc10fb713 100644 --- a/heat/tests/test_parser.py +++ b/heat/tests/test_parser.py @@ -273,50 +273,6 @@ class TemplateTest(unittest.TestCase): dict_snippet) -params_schema = json.loads('''{ - "Parameters" : { - "User" : { "Type": "String" }, - "Defaulted" : { - "Type": "String", - "Default": "foobar" - } - } -}''') - - -@attr(tag=['unit', 'parser', 'parameters']) -@attr(speed='fast') -class ParametersTest(unittest.TestCase): - def test_pseudo_params(self): - params = parser.Parameters('test_stack', {"Parameters": {}}) - - self.assertEqual(params['AWS::StackName'], 'test_stack') - self.assertTrue('AWS::Region' in params) - - def test_user_param(self): - params = parser.Parameters('test', params_schema, {'User': 'wibble'}) - user_params = params.user_parameters() - self.assertEqual(user_params['User'], 'wibble') - - def test_user_param_default(self): - params = parser.Parameters('test', params_schema) - user_params = params.user_parameters() - self.assertTrue('Defaulted' not in user_params) - - def test_user_param_nonexist(self): - params = parser.Parameters('test', params_schema) - user_params = params.user_parameters() - self.assertTrue('User' not in user_params) - - def test_schema_invariance(self): - params1 = parser.Parameters('test', params_schema) - params1['Defaulted'] = "wibble" - self.assertEqual(params1['Defaulted'], 'wibble') - - params2 = parser.Parameters('test', params_schema) - self.assertEqual(params2['Defaulted'], 'foobar') - - @attr(tag=['unit', 'parser', 'stack']) @attr(speed='fast') class StackTest(unittest.TestCase):