Enforce multi-tenancy for actions

This patch fixes the db layer and engine layer about action
multi-tenancy support. There may need some more stricter unit tests for
this support.

Change-Id: If1a0fbcc3a36be57d9accedad647235f6a29940a
This commit is contained in:
tengqm 2016-02-17 10:18:54 -05:00
parent 3536553042
commit d0c7ecabf8
8 changed files with 70 additions and 34 deletions

View File

@ -299,16 +299,18 @@ def action_update(context, action_id, values):
return IMPL.action_update(context, action_id, values)
def action_get(context, action_id, refresh=False):
return IMPL.action_get(context, action_id, refresh=refresh)
def action_get(context, action_id, project_safe=True, refresh=False):
return IMPL.action_get(context, action_id, project_safe=project_safe,
refresh=refresh)
def action_get_by_name(context, name):
return IMPL.action_get_by_name(context, name)
def action_get_by_name(context, name, project_safe=True):
return IMPL.action_get_by_name(context, name, project_safe=project_safe)
def action_get_by_short_id(context, short_id):
return IMPL.action_get_by_short_id(context, short_id)
def action_get_by_short_id(context, short_id, project_safe=True):
return IMPL.action_get_by_short_id(context, short_id,
project_safe=project_safe)
def action_get_all_by_owner(context, owner):

View File

@ -847,21 +847,27 @@ def action_update(context, action_id, values):
action.save(session)
def action_get(context, action_id, refresh=False):
def action_get(context, action_id, project_safe=True, refresh=False):
session = _session(context)
action = session.query(models.Action).get(action_id)
if action is None:
return None
if project_safe is True and action.project != context.project:
return None
session.refresh(action)
return action
def action_get_by_name(context, name):
return query_by_name(context, models.Action, name)
def action_get_by_name(context, name, project_safe=True):
return query_by_name(context, models.Action, name,
project_safe=project_safe)
def action_get_by_short_id(context, short_id):
return query_by_short_id(context, models.Action, short_id)
def action_get_by_short_id(context, short_id, project_safe=True):
return query_by_short_id(context, models.Action, short_id,
project_safe=project_safe)
def action_get_all_by_owner(context, owner_id):

View File

@ -96,14 +96,20 @@ class Action(object):
# TODO(Yanyan Hu): Replace context with DB session
if not context:
self.user = kwargs.get('user')
self.project = kwargs.get('project')
self.domain = kwargs.get('domain')
params = {
'user': kwargs.get('user'),
'project': kwargs.get('project'),
'domain': kwargs.get('domain'),
'user': self.user,
'project': self.project,
'domain': self.domain,
'is_admin': False
}
self.context = req_context.RequestContext.from_dict(params)
else:
self.user = context.user
self.project = context.project
self.domain = context.domain
self.context = context
# TODO(Qiming): make description a db column
@ -177,6 +183,9 @@ class Action(object):
'created_at': self.created_at,
'updated_at': self.updated_at,
'data': self.data,
'user': self.user,
'project': self.project,
'domain': self.domain,
}
if self.id:
@ -216,6 +225,9 @@ class Action(object):
'created_at': record.created_at,
'updated_at': record.updated_at,
'data': record.data,
'user': record.user,
'project': record.project,
'domain': record.domain,
}
return cls(record.target, record.action, context=context, **kwargs)

View File

@ -1778,7 +1778,7 @@ class EngineService(service.Service):
return {'action': action.id}
def action_find(self, context, identity):
def action_find(self, context, identity, project_safe=True):
"""Find an action with the given identity.
:param context: An instance of the request context.
@ -1787,13 +1787,17 @@ class EngineService(service.Service):
matching action is found.
"""
if uuidutils.is_uuid_like(identity):
action = db_api.action_get(context, identity)
action = db_api.action_get(context, identity,
project_safe=project_safe)
if not action:
action = db_api.action_get_by_name(context, identity)
action = db_api.action_get_by_name(context, identity,
project_safe=project_safe)
else:
action = db_api.action_get_by_name(context, identity)
action = db_api.action_get_by_name(context, identity,
project_safe=project_safe)
if not action:
action = db_api.action_get_by_short_id(context, identity)
action = db_api.action_get_by_short_id(
context, identity, project_safe=project_safe)
if not action:
raise exception.ActionNotFound(action=identity)

View File

@ -25,6 +25,9 @@ from senlin.tests.unit.db import shared
def _create_action(context, action=shared.sample_action, **kwargs):
data = parser.simple_parse(action)
data['user'] = context.user
data['project'] = context.project
data['domain'] = context.domain
data.update(kwargs)
return db_api.action_create(context, data)
@ -36,7 +39,7 @@ class DBAPIActionTest(base.SenlinTestCase):
def test_action_create(self):
data = parser.simple_parse(shared.sample_action)
action = db_api.action_create(self.ctx, data)
action = _create_action(self.ctx)
self.assertIsNotNone(action)
self.assertEqual(data['name'], action.name)
@ -47,11 +50,13 @@ class DBAPIActionTest(base.SenlinTestCase):
self.assertEqual(data['status'], action.status)
self.assertEqual(data['status_reason'], action.status_reason)
self.assertEqual(10, action.inputs['max_size'])
self.assertEqual(self.ctx.user, action.user)
self.assertEqual(self.ctx.project, action.project)
self.assertEqual(self.ctx.domain, action.domain)
self.assertIsNone(action.outputs)
def test_action_update(self):
data = parser.simple_parse(shared.sample_action)
action = db_api.action_create(self.ctx, data)
action = _create_action(self.ctx)
values = {
'status': 'ERROR',
'status_reason': 'Cluster creation failed',

View File

@ -195,6 +195,9 @@ class ActionBaseTest(base.SenlinTestCase):
self.assertEqual(obj.created_at, action_obj.created_at)
self.assertEqual(obj.updated_at, action_obj.updated_at)
self.assertEqual(obj.data, action_obj.data)
self.assertEqual(obj.user, action_obj.user)
self.assertEqual(obj.project, action_obj.project)
self.assertEqual(obj.domain, action_obj.domain)
def test_from_db_record_with_empty_fields(self):
values = copy.deepcopy(self.action_values)
@ -441,7 +444,7 @@ class ActionBaseTest(base.SenlinTestCase):
self.assertEqual('FAKE_STATUS', res)
self.assertEqual('FAKE_STATUS', action.status)
mock_get.assert_called_once_with(action.context, 'FAKE_ID',
refresh=True)
project_safe=True, refresh=True)
@mock.patch.object(action_base, 'wallclock')
def test_is_timeout(self, mock_time):

View File

@ -23,11 +23,11 @@ class CustomActionTest(base.SenlinTestCase):
self.ctx = utils.dummy_context()
def test_init(self):
obj = ca.CustomAction(self.ctx, 'OBJID', 'OBJECT_ACTION')
obj = ca.CustomAction('OBJID', 'OBJECT_ACTION', self.ctx)
self.assertIsNotNone(obj)
def test_execute(self):
obj = ca.CustomAction(self.ctx, 'OBJID', 'OBJECT_ACTION')
obj = ca.CustomAction('OBJID', 'OBJECT_ACTION', self.ctx)
params = {'key': 'value'}
res = obj.execute(**params)

View File

@ -40,7 +40,7 @@ class ActionTest(base.SenlinTestCase):
result = self.eng.action_find(self.ctx, aid)
self.assertEqual(x_action, result)
mock_get.assert_called_once_with(self.ctx, aid)
mock_get.assert_called_once_with(self.ctx, aid, project_safe=True)
@mock.patch.object(db_api, 'action_get_by_name')
@mock.patch.object(db_api, 'action_get')
@ -50,11 +50,12 @@ class ActionTest(base.SenlinTestCase):
mock_get.return_value = None
aid = uuidutils.generate_uuid()
result = self.eng.action_find(self.ctx, aid)
result = self.eng.action_find(self.ctx, aid, False)
self.assertEqual(x_action, result)
mock_get.assert_called_once_with(self.ctx, aid)
mock_get_name.assert_called_once_with(self.ctx, aid)
mock_get.assert_called_once_with(self.ctx, aid, project_safe=False)
mock_get_name.assert_called_once_with(self.ctx, aid,
project_safe=False)
@mock.patch.object(db_api, 'action_get_by_name')
def test_action_find_by_name(self, mock_get_name):
@ -65,7 +66,7 @@ class ActionTest(base.SenlinTestCase):
result = self.eng.action_find(self.ctx, aid)
self.assertEqual(x_action, result)
mock_get_name.assert_called_once_with(self.ctx, aid)
mock_get_name.assert_called_once_with(self.ctx, aid, project_safe=True)
@mock.patch.object(db_api, 'action_get_by_short_id')
@mock.patch.object(db_api, 'action_get_by_name')
@ -75,11 +76,13 @@ class ActionTest(base.SenlinTestCase):
mock_get_name.return_value = None
aid = 'abcd-1234-abcd'
result = self.eng.action_find(self.ctx, aid)
result = self.eng.action_find(self.ctx, aid, False)
self.assertEqual(x_action, result)
mock_get_name.assert_called_once_with(self.ctx, aid)
mock_get_shortid.assert_called_once_with(self.ctx, aid)
mock_get_name.assert_called_once_with(self.ctx, aid,
project_safe=False)
mock_get_shortid.assert_called_once_with(self.ctx, aid,
project_safe=False)
@mock.patch.object(db_api, 'action_get_by_name')
def test_action_find_not_found(self, mock_get_name):
@ -90,7 +93,8 @@ class ActionTest(base.SenlinTestCase):
self.ctx, 'bogus')
self.assertEqual('The action (bogus) could not be found.',
six.text_type(ex))
mock_get_name.assert_called_once_with(self.ctx, 'bogus')
mock_get_name.assert_called_once_with(self.ctx, 'bogus',
project_safe=True)
@mock.patch.object(action_base.Action, 'load_all')
def test_action_list(self, mock_load):