Database API and engine changes for stack tags

Allow tagging of stacks with simple string tags.

blueprint stack-tags

Change-Id: I65e1e8e87515595edae332c2ff7e0e82ded409ce
changes/50/159250/11
Jason Dunsmore 8 years ago
parent ecb1b1eb04
commit 9817ed77ef
  1. 1
      heat/api/openstack/v1/views/stacks_view.py
  2. 12
      heat/db/api.py
  3. 30
      heat/db/sqlalchemy/api.py
  4. 21
      heat/engine/api.py
  5. 19
      heat/engine/stack.py
  6. 11
      heat/objects/stack.py
  7. 39
      heat/objects/stack_tag.py
  8. 8
      heat/rpc/api.py
  9. 32
      heat/tests/db/test_sqlalchemy_api.py
  10. 34
      heat/tests/test_engine_api_utils.py
  11. 33
      heat/tests/test_stack.py
  12. 43
      heat/tests/test_stack_update.py

@ -31,6 +31,7 @@ basic_keys = (
rpc_api.STACK_OWNER,
rpc_api.STACK_PARENT,
rpc_api.STACK_USER_PROJECT_ID,
rpc_api.STACK_TAGS,
)

@ -76,6 +76,18 @@ def resource_data_delete(resource, key):
return IMPL.resource_data_delete(resource, key)
def stack_tags_set(context, stack_id, tags):
return IMPL.stack_tags_set(context, stack_id, tags)
def stack_tags_delete(context, stack_id):
return IMPL.stack_tags_delete(context, stack_id)
def stack_tags_get(context, stack_id):
return IMPL.stack_tags_get(context, stack_id)
def resource_get(context, resource_id):
return IMPL.resource_get(context, resource_id)

@ -200,6 +200,36 @@ def resource_data_get(resource, key):
return result.value
def stack_tags_set(context, stack_id, tags):
session = get_session()
with session.begin():
stack_tags_delete(context, stack_id)
result = []
for tag in tags:
stack_tag = models.StackTag()
stack_tag.tag = tag
stack_tag.stack_id = stack_id
stack_tag.save(session=session)
result.append(stack_tag)
return result or None
def stack_tags_delete(context, stack_id):
session = get_session()
with session.begin():
result = stack_tags_get(context, stack_id)
if result:
for tag in result:
tag.delete()
def stack_tags_get(context, stack_id):
result = (model_query(context, models.StackTag)
.filter_by(stack_id=stack_id)
.all())
return result or None
def _encrypt(value):
if value is not None:
return crypt.encrypt(value.encode('utf-8'))

@ -15,6 +15,7 @@ import collections
from oslo_log import log as logging
from oslo_utils import timeutils
import six
from heat.common.i18n import _
from heat.common.i18n import _LE
@ -62,6 +63,25 @@ def extract_args(params):
raise ValueError(_('Invalid adopt data: %s') % exc)
kwargs[rpc_api.PARAM_ADOPT_STACK_DATA] = adopt_data
tags = params.get(rpc_api.PARAM_TAGS)
if tags:
if not isinstance(tags, list):
raise ValueError(_('Invalid tags, not a list: %s') % tags)
for tag in tags:
if not isinstance(tag, six.string_types):
raise ValueError(_('Invalid tag, "%s" is not a string') % tag)
if len(tag) > 80:
raise ValueError(_('Invalid tag, "%s" is longer than 80 '
'characters') % tag)
# Comma is not allowed as per the API WG tagging guidelines
if ',' in tag:
raise ValueError(_('Invalid tag, "%s" contains a comma') % tag)
kwargs[rpc_api.PARAM_TAGS] = tags
return kwargs
@ -105,6 +125,7 @@ def format_stack(stack, preview=False):
rpc_api.STACK_OWNER: stack.username,
rpc_api.STACK_PARENT: stack.owner_id,
rpc_api.STACK_USER_PROJECT_ID: stack.stack_user_project_id,
rpc_api.STACK_TAGS: stack.tags,
}
if not preview:

@ -44,6 +44,7 @@ from heat.engine import update
from heat.objects import resource as resource_objects
from heat.objects import snapshot as snapshot_object
from heat.objects import stack as stack_object
from heat.objects import stack_tag as stack_tag_object
from heat.objects import user_creds as ucreds_object
from heat.rpc import api as rpc_api
@ -83,7 +84,7 @@ class Stack(collections.Mapping):
user_creds_id=None, tenant_id=None,
use_stored_context=False, username=None,
nested_depth=0, strict_validate=True, convergence=False,
current_traversal=None):
current_traversal=None, tags=None):
'''
Initialise from a context, name, Template object and (optionally)
Environment object. The database ID may also be initialised, if the
@ -126,6 +127,7 @@ class Stack(collections.Mapping):
self.strict_validate = strict_validate
self.convergence = convergence
self.current_traversal = current_traversal
self.tags = tags
if use_stored_context:
self.context = self.stored_context()
@ -371,6 +373,9 @@ class Stack(collections.Mapping):
use_stored_context=False):
template = tmpl.Template.load(
context, stack.raw_template_id, stack.raw_template)
tags = None
if stack.tags:
tags = [t.tag for t in stack.tags]
return cls(context, stack.name, template,
stack_id=stack.id,
action=stack.action, status=stack.status,
@ -386,7 +391,7 @@ class Stack(collections.Mapping):
user_creds_id=stack.user_creds_id, tenant_id=stack.tenant,
use_stored_context=use_stored_context,
username=stack.username, convergence=stack.convergence,
current_traversal=stack.current_traversal)
current_traversal=stack.current_traversal, tags=tags)
def get_kwargs_for_cloning(self, keep_status=False, only_db=False):
"""Get common kwargs for calling Stack() for cloning.
@ -465,6 +470,9 @@ class Stack(collections.Mapping):
self.id = new_s.id
self.created_time = new_s.created_at
if self.tags:
stack_tag_object.StackTagList.set(self.context, self.id, self.tags)
self._set_param_stackid()
return self.id
@ -920,6 +928,13 @@ class Stack(collections.Mapping):
self.timeout_mins = newstack.timeout_mins
self._set_param_stackid()
self.tags = newstack.tags
if newstack.tags:
stack_tag_object.StackTagList.set(self.context, self.id,
newstack.tags)
else:
stack_tag_object.StackTagList.delete(self.context, self.id)
try:
updater.start(timeout=self.timeout_secs())
yield

@ -57,7 +57,7 @@ class Stack(
'current_deps': heat_fields.JsonField(),
'prev_raw_template_id': fields.IntegerField(),
'prev_raw_template': fields.ObjectField('RawTemplate'),
'tag': fields.ObjectField('StackTag'),
'tags': fields.ObjectField('StackTagList'),
}
@staticmethod
@ -67,13 +67,12 @@ class Stack(
stack['raw_template'] = (
raw_template.RawTemplate.get_by_id(
context, db_stack['raw_template_id']))
elif field == 'tag':
elif field == 'tags':
if db_stack.get(field) is not None:
stack['tag'] = stack_tag.StackTag.get_obj(
db_stack.get(field)
)
stack['tags'] = stack_tag.StackTagList.get(
context, db_stack['id'])
else:
stack['tag'] = None
stack['tags'] = None
else:
stack[field] = db_stack.__dict__.get(field)
stack._context = context

@ -19,6 +19,8 @@ StackTag object
from oslo_versionedobjects import base
from oslo_versionedobjects import fields
from heat.db import api as db_api
class StackTag(base.VersionedObject,
base.VersionedObjectDictCompat,
@ -32,7 +34,11 @@ class StackTag(base.VersionedObject,
}
@staticmethod
def _from_db_object(tag, db_tag):
def _from_db_object(context, tag, db_tag):
"""Method to help with migration to objects.
Converts a database entity to a formal object.
"""
if db_tag is None:
return None
for field in tag.fields:
@ -42,5 +48,32 @@ class StackTag(base.VersionedObject,
@classmethod
def get_obj(cls, context, tag):
tag_obj = cls._from_db_object(cls(context), tag)
return tag_obj
return cls._from_db_object(cls(context), tag)
class StackTagList(base.VersionedObject,
base.ObjectListBase):
fields = {
'objects': fields.ListOfObjectsField('StackTag'),
}
def __init__(self, *args, **kwargs):
self._changed_fields = set()
super(StackTagList, self).__init__()
@classmethod
def get(cls, context, stack_id):
db_tags = db_api.stack_tags_get(context, stack_id)
if db_tags:
return base.obj_make_list(context, cls(), StackTag, db_tags)
@classmethod
def set(cls, context, stack_id, tags):
db_tags = db_api.stack_tags_set(context, stack_id, tags)
if db_tags:
return base.obj_make_list(context, cls(), StackTag, db_tags)
@classmethod
def delete(cls, context, stack_id):
db_api.stack_tags_delete(context, stack_id)

@ -18,12 +18,12 @@ PARAM_KEYS = (
PARAM_TIMEOUT, PARAM_DISABLE_ROLLBACK, PARAM_ADOPT_STACK_DATA,
PARAM_SHOW_DELETED, PARAM_SHOW_NESTED, PARAM_EXISTING,
PARAM_CLEAR_PARAMETERS, PARAM_GLOBAL_TENANT, PARAM_LIMIT,
PARAM_NESTED_DEPTH,
PARAM_NESTED_DEPTH, PARAM_TAGS
) = (
'timeout_mins', 'disable_rollback', 'adopt_stack_data',
'show_deleted', 'show_nested', 'existing',
'clear_parameters', 'global_tenant', 'limit',
'nested_depth',
'nested_depth', 'tags'
)
STACK_KEYS = (
@ -34,7 +34,7 @@ STACK_KEYS = (
STACK_PARAMETERS, STACK_OUTPUTS, STACK_ACTION,
STACK_STATUS, STACK_STATUS_DATA, STACK_CAPABILITIES,
STACK_DISABLE_ROLLBACK, STACK_TIMEOUT, STACK_OWNER,
STACK_PARENT, STACK_USER_PROJECT_ID
STACK_PARENT, STACK_USER_PROJECT_ID, STACK_TAGS
) = (
'stack_name', 'stack_identity',
'creation_time', 'updated_time', 'deletion_time',
@ -43,7 +43,7 @@ STACK_KEYS = (
'parameters', 'outputs', 'stack_action',
'stack_status', 'stack_status_reason', 'capabilities',
'disable_rollback', 'timeout_mins', 'stack_owner',
'parent', 'stack_user_project_id'
'parent', 'stack_user_project_id', 'tags'
)
STACK_OUTPUT_KEYS = (

@ -1324,6 +1324,38 @@ class DBAPIUserCredsTest(common.HeatTestCase):
self.assertIn(exp_msg, six.text_type(err))
class DBAPIStackTagTest(common.HeatTestCase):
def setUp(self):
super(DBAPIStackTagTest, self).setUp()
self.ctx = utils.dummy_context()
self.template = create_raw_template(self.ctx)
self.user_creds = create_user_creds(self.ctx)
self.stack = create_stack(self.ctx, self.template, self.user_creds)
def test_stack_tags_set(self):
tags = db_api.stack_tags_set(self.ctx, self.stack.id, ['tag1', 'tag2'])
self.assertEqual(self.stack.id, tags[0].stack_id)
self.assertEqual('tag1', tags[0].tag)
tags = db_api.stack_tags_set(self.ctx, self.stack.id, [])
self.assertIsNone(tags)
def test_stack_tags_get(self):
db_api.stack_tags_set(self.ctx, self.stack.id, ['tag1', 'tag2'])
tags = db_api.stack_tags_get(self.ctx, self.stack.id)
self.assertEqual(self.stack.id, tags[0].stack_id)
self.assertEqual('tag1', tags[0].tag)
tags = db_api.stack_tags_get(self.ctx, UUID1)
self.assertIsNone(tags)
def test_stack_tags_delete(self):
db_api.stack_tags_set(self.ctx, self.stack.id, ['tag1', 'tag2'])
db_api.stack_tags_delete(self.ctx, self.stack.id)
tags = db_api.stack_tags_get(self.ctx, self.stack.id)
self.assertIsNone(tags)
class DBAPIStackTest(common.HeatTestCase):
def setUp(self):
super(DBAPIStackTest, self).setUp()

@ -304,6 +304,7 @@ class FormatTest(common.HeatTestCase):
'stack_user_project_id': None,
'template_description': 'No description',
'timeout_mins': None,
'tags': None,
'parameters': {
'AWS::Region': 'ap-southeast-1',
'AWS::StackId': aws_id,
@ -1056,3 +1057,36 @@ class TestExtractArgs(common.HeatTestCase):
def test_disable_rollback_extract_bad(self):
self.assertRaises(ValueError, api.extract_args,
{'disable_rollback': 'bad'})
def test_tags_extract(self):
p = {'tags': ["tag1", "tag2"]}
args = api.extract_args(p)
self.assertEqual(['tag1', 'tag2'], args['tags'])
def test_tags_extract_not_present(self):
args = api.extract_args({})
self.assertNotIn('tags', args)
def test_tags_extract_not_map(self):
p = {'tags': {"foo": "bar"}}
exc = self.assertRaises(ValueError, api.extract_args, p)
self.assertIn('Invalid tags, not a list: ', six.text_type(exc))
def test_tags_extract_not_string(self):
p = {'tags': ["tag1", 2]}
exc = self.assertRaises(ValueError, api.extract_args, p)
self.assertIn('Invalid tag, "2" is not a string', six.text_type(exc))
def test_tags_extract_over_limit(self):
p = {'tags': ["tag1", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"]}
exc = self.assertRaises(ValueError, api.extract_args, p)
self.assertIn('Invalid tag, "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa'
'aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" is longer '
'than 80 characters', six.text_type(exc))
def test_tags_extract_comma(self):
p = {'tags': ["tag1", 'tag2,']}
exc = self.assertRaises(ValueError, api.extract_args, p)
self.assertIn('Invalid tag, "tag2," contains a comma',
six.text_type(exc))

@ -32,6 +32,7 @@ from heat.engine import scheduler
from heat.engine import stack
from heat.engine import template
from heat.objects import stack as stack_object
from heat.objects import stack_tag as stack_tag_object
from heat.objects import user_creds as ucreds_object
from heat.tests import common
from heat.tests import fakes
@ -290,7 +291,8 @@ class StackTest(common.HeatTestCase):
use_stored_context=False,
username=mox.IgnoreArg(),
convergence=False,
current_traversal=None)
current_traversal=None,
tags=mox.IgnoreArg())
self.m.ReplayAll()
stack.Stack.load(self.ctx, stack_id=self.stack.id,
@ -1124,6 +1126,35 @@ class StackTest(common.HeatTestCase):
ctx_expected['auth_token'] = None
self.assertEqual(ctx_expected, self.stack.stored_context().to_dict())
def test_load_reads_tags(self):
self.stack = stack.Stack(self.ctx, 'stack_tags', self.tmpl)
self.stack.store()
stack_id = self.stack.id
test_stack = stack.Stack.load(self.ctx, stack_id=stack_id)
self.assertIsNone(test_stack.tags)
self.stack = stack.Stack(self.ctx, 'stack_name', self.tmpl,
tags=['tag1', 'tag2'])
self.stack.store()
stack_id = self.stack.id
test_stack = stack.Stack.load(self.ctx, stack_id=stack_id)
self.assertEqual(['tag1', 'tag2'], test_stack.tags)
def test_store_saves_tags(self):
self.stack = stack.Stack(self.ctx, 'tags_stack', self.tmpl)
self.stack.store()
db_tags = stack_tag_object.StackTagList.get(self.stack.context,
self.stack.id)
self.assertIsNone(db_tags)
self.stack = stack.Stack(self.ctx, 'tags_stack', self.tmpl,
tags=['tag1', 'tag2'])
self.stack.store()
db_tags = stack_tag_object.StackTagList.get(self.stack.context,
self.stack.id)
self.assertEqual('tag1', db_tags[0].tag)
self.assertEqual('tag2', db_tags[1].tag)
def test_store_saves_creds(self):
"""
A user_creds entry is created on first stack store

@ -187,6 +187,49 @@ class StackUpdateTest(common.HeatTestCase):
self.stack.state)
self.assertEqual(True, self.stack.disable_rollback)
def test_update_tags(self):
tmpl = {'HeatTemplateFormatVersion': '2012-12-12',
'Description': 'ATemplate',
'Resources': {'AResource': {'Type': 'GenericResourceType'}}}
self.stack = stack.Stack(self.ctx, 'update_test_stack',
template.Template(tmpl),
tags=['tag1', 'tag2'])
self.stack.store()
self.stack.create()
self.assertEqual((stack.Stack.CREATE, stack.Stack.COMPLETE),
self.stack.state)
self.assertEqual(['tag1', 'tag2'], self.stack.tags)
updated_stack = stack.Stack(self.ctx, 'updated_stack',
template.Template(tmpl),
tags=['tag3', 'tag4'])
self.stack.update(updated_stack)
self.assertEqual((stack.Stack.UPDATE, stack.Stack.COMPLETE),
self.stack.state)
self.assertEqual(['tag3', 'tag4'], self.stack.tags)
def test_update_tags_remove_all(self):
tmpl = {'HeatTemplateFormatVersion': '2012-12-12',
'Description': 'ATemplate',
'Resources': {'AResource': {'Type': 'GenericResourceType'}}}
self.stack = stack.Stack(self.ctx, 'update_test_stack',
template.Template(tmpl),
tags=['tag1', 'tag2'])
self.stack.store()
self.stack.create()
self.assertEqual((stack.Stack.CREATE, stack.Stack.COMPLETE),
self.stack.state)
self.assertEqual(['tag1', 'tag2'], self.stack.tags)
updated_stack = stack.Stack(self.ctx, 'updated_stack',
template.Template(tmpl))
self.stack.update(updated_stack)
self.assertEqual((stack.Stack.UPDATE, stack.Stack.COMPLETE),
self.stack.state)
self.assertEqual(None, self.stack.tags)
def test_update_modify_ok_replace(self):
tmpl = {'HeatTemplateFormatVersion': '2012-12-12',
'Resources': {'AResource': {'Type': 'ResourceWithPropsType',

Loading…
Cancel
Save