diff --git a/watcher/api/controllers/v1/audit_template.py b/watcher/api/controllers/v1/audit_template.py index ffc2c0e8b..ca81b16a2 100644 --- a/watcher/api/controllers/v1/audit_template.py +++ b/watcher/api/controllers/v1/audit_template.py @@ -148,18 +148,23 @@ class AuditTemplatePostType(wtypes.Base): "included and excluded together")) if audit_template.strategy: - available_strategies = objects.Strategy.list( - AuditTemplatePostType._ctx) - available_strategies_map = { - s.uuid: s for s in available_strategies} - if audit_template.strategy not in available_strategies_map: + try: + if (common_utils.is_uuid_like(audit_template.strategy) or + common_utils.is_int_like(audit_template.strategy)): + strategy = objects.Strategy.get( + AuditTemplatePostType._ctx, audit_template.strategy) + else: + strategy = objects.Strategy.get_by_name( + AuditTemplatePostType._ctx, audit_template.strategy) + except Exception: raise exception.InvalidStrategy( strategy=audit_template.strategy) - strategy = available_strategies_map[audit_template.strategy] # Check that the strategy we indicate is actually related to the # specified goal if strategy.goal_id != goal.id: + available_strategies = objects.Strategy.list( + AuditTemplatePostType._ctx) choices = ["'%s' (%s)" % (s.uuid, s.name) for s in available_strategies] raise exception.InvalidStrategy( diff --git a/watcher/tests/api/v1/test_audit_templates.py b/watcher/tests/api/v1/test_audit_templates.py index dc990e4d2..6735fbc24 100644 --- a/watcher/tests/api/v1/test_audit_templates.py +++ b/watcher/tests/api/v1/test_audit_templates.py @@ -555,6 +555,35 @@ class TestPost(FunctionalTestWithSetup): response.json['created_at']).replace(tzinfo=None) self.assertEqual(test_time, return_created_at) + @mock.patch.object(timeutils, 'utcnow') + def test_create_audit_template_with_strategy_name(self, mock_utcnow): + audit_template_dict = post_get_test_audit_template( + goal=self.fake_goal1.uuid, + strategy=self.fake_strategy1.name) + test_time = datetime.datetime(2000, 1, 1, 0, 0) + mock_utcnow.return_value = test_time + + response = self.post_json('/audit_templates', audit_template_dict) + self.assertEqual('application/json', response.content_type) + self.assertEqual(201, response.status_int) + # Check location header + self.assertIsNotNone(response.location) + expected_location = \ + '/v1/audit_templates/%s' % response.json['uuid'] + self.assertEqual(urlparse.urlparse(response.location).path, + expected_location) + self.assertTrue(utils.is_uuid_like(response.json['uuid'])) + self.assertNotIn('updated_at', response.json.keys) + self.assertNotIn('deleted_at', response.json.keys) + self.assertEqual(self.fake_goal1.uuid, response.json['goal_uuid']) + self.assertEqual(self.fake_strategy1.uuid, + response.json['strategy_uuid']) + self.assertEqual(self.fake_strategy1.name, + response.json['strategy_name']) + return_created_at = timeutils.parse_isotime( + response.json['created_at']).replace(tzinfo=None) + self.assertEqual(test_time, return_created_at) + def test_create_audit_template_validation_with_aggregates(self): scope = [{'compute': [{'host_aggregates': [{'id': '*'}]}, {'availability_zones': [{'name': 'AZ1'},