diff --git a/sahara/service/validations/cluster_templates.py b/sahara/service/validations/cluster_templates.py index 947838c2..03f07326 100644 --- a/sahara/service/validations/cluster_templates.py +++ b/sahara/service/validations/cluster_templates.py @@ -60,24 +60,33 @@ def check_cluster_template_usage(cluster_template_id, **kwargs): 'clusters': ', '.join(users)}) -def check_cluster_template_update(data, **kwargs): +def check_cluster_template_update(cluster_template_id, data, **kwargs): + if data.get('plugin_name') and not data.get('hadoop_version'): + raise ex.InvalidReferenceException( + _("You must specify a hadoop_version value " + "for your plugin_name")) + if data.get('plugin_name'): - b.check_plugin_name_exists(data['plugin_name']) - - if data.get('plugin_name') and data.get('hadoop_version'): - b.check_plugin_supports_version(data['plugin_name'], - data['hadoop_version']) + plugin = data['plugin_name'] + version = data['hadoop_version'] + b.check_plugin_name_exists(plugin) + b.check_plugin_supports_version(plugin, version) b.check_all_configurations(data) + else: + cluster_template = api.get_cluster_template(cluster_template_id) + plugin = cluster_template.plugin_name + if data.get('hadoop_version'): + version = data.get('hadoop_version') + b.check_plugin_supports_version(plugin, version) + else: + version = cluster_template.hadoop_version - if data.get('default_image_id'): - b.check_image_registered(data['default_image_id']) - b.check_required_image_tags(data['plugin_name'], - data['hadoop_version'], - data['default_image_id']) + if data.get('default_image_id'): + b.check_image_registered(data['default_image_id']) + b.check_required_image_tags(plugin, version, data['default_image_id']) - if data.get('anti_affinity'): - b.check_node_processes(data['plugin_name'], data['hadoop_version'], - data['anti_affinity']) + if data.get('anti_affinity'): + b.check_node_processes(plugin, version, data['anti_affinity']) if data.get('neutron_management_network'): b.check_network_exists(data['neutron_management_network']) diff --git a/sahara/tests/unit/service/validation/test_cluster_template_update_validation.py b/sahara/tests/unit/service/validation/test_cluster_template_update_validation.py index 40cb058a..6fe7744f 100644 --- a/sahara/tests/unit/service/validation/test_cluster_template_update_validation.py +++ b/sahara/tests/unit/service/validation/test_cluster_template_update_validation.py @@ -15,9 +15,10 @@ import copy +import mock + from sahara.service import api from sahara.service.validations import cluster_template_schema as ct_schema -from sahara.service.validations import cluster_templates as ct from sahara.tests.unit.service.validation import utils as u @@ -33,7 +34,7 @@ SAMPLE_DATA = { class TestClusterTemplateUpdateValidation(u.ValidationTestCase): def setUp(self): super(TestClusterTemplateUpdateValidation, self).setUp() - self._create_object_fun = ct.check_cluster_template_update + self._create_object_fun = mock.Mock() self.scheme = ct_schema.CLUSTER_TEMPLATE_UPDATE_SCHEMA api.plugin_base.setup_plugins()