Working on secure DB access (part 2)

* Added a hook for all secure models to automatically set proper project id
* Removed all direct assignments of project id
* Added a simple test
* Other minor changes

Change-Id: Ie0922a6f43f974738c23981dc399136138dd5c75
This commit is contained in:
Renat Akhmerov 2015-01-19 16:13:22 +06:00
parent 693227ffcf
commit 14a7b44bb0
6 changed files with 56 additions and 41 deletions

View File

@ -19,6 +19,7 @@ import uuid
from oslo.db.sqlalchemy import models as oslo_models
import sqlalchemy as sa
from sqlalchemy import event
from sqlalchemy.ext import declarative
from sqlalchemy.orm import attributes
@ -72,16 +73,36 @@ class _MistralModelBase(oslo_models.ModelBase, oslo_models.TimestampMixin):
return '%s %s' % (type(self).__name__, self.to_dict().__repr__())
class MistralSecureModelMixin(object):
"""Mixin adding model properties related to security."""
scope = sa.Column(sa.String(80), default="private")
project_id = sa.Column(sa.String(80), default=db_base.DEFAULT_PROJECT_ID)
def datetime_to_str(dct, attr_name):
if dct.get(attr_name) is not None:
dct[attr_name] = dct[attr_name].isoformat(' ')
MistralModelBase = declarative.declarative_base(cls=_MistralModelBase)
# Secure model related stuff.
class MistralSecureModelBase(MistralModelBase):
"""Mixin adding model properties related to security."""
__abstract__ = True
scope = sa.Column(sa.String(80), default='private')
project_id = sa.Column(sa.String(80), default=db_base.get_project_id)
def _set_project_id(target, value, oldvalue, initiator):
return db_base.get_project_id()
def register_secure_model_hooks():
# Make sure 'project_id' is always properly set.
for sec_model_class in MistralSecureModelBase.__subclasses__():
event.listen(
sec_model_class.project_id,
'set',
_set_project_id,
retval=True
)

View File

@ -17,6 +17,10 @@ from oslo.config import cfg
from mistral import context as auth_context
# Make sure to import 'auth_enable' option before using it.
cfg.CONF.import_opt('auth_enable', 'mistral.config', group='pecan')
CONF = cfg.CONF
DEFAULT_PROJECT_ID = "<default-project>"
@ -25,4 +29,4 @@ def get_project_id():
if CONF.pecan.auth_enable and auth_context.has_ctx():
return auth_context.ctx().project_id
else:
return DEFAULT_PROJECT_ID
return DEFAULT_PROJECT_ID

View File

@ -1,4 +1,4 @@
# Copyright 2013 - Mirantis, Inc.
# Copyright 2015 - Mirantis, Inc.
# Copyright 2015 - StackStorm, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");

View File

@ -1,4 +1,4 @@
# Copyright 2013 - Mirantis, Inc.
# Copyright 2015 - Mirantis, Inc.
# Copyright 2015 - StackStorm, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -20,7 +20,6 @@ from oslo.config import cfg
from oslo.db import exception as db_exc
import sqlalchemy as sa
from mistral import context
from mistral.db.sqlalchemy import base as b
from mistral.db import v2 as db_base
from mistral.db.v2.sqlalchemy import models
@ -115,10 +114,10 @@ def _get_collection_sorted_by_time(model, **kwargs):
def _get_db_object_by_name(model, name):
query = b.model_query(model)
proj = query.filter_by(name=name, project_id=db_base.get_project_id())
private = query.filter_by(name=name, project_id=db_base.get_project_id())
public = query.filter_by(name=name, scope='public')
return proj.union(public).first()
return private.union(public).first()
def _get_db_object_by_id(model, id):
@ -153,9 +152,6 @@ def create_workbook(values, session=None):
wb.update(values.copy())
# TODO(rakhmerov): It needs to be refactored.
wb['project_id'] = context.ctx().project_id
try:
wb.save(session=session)
except db_exc.DBDuplicateEntry as e:
@ -174,8 +170,6 @@ def update_workbook(name, values, session=None):
"Workbook not found [workbook_name=%s]" % name)
wb.update(values.copy())
# TODO(rakhmerov): It needs to be refactored.
wb['project_id'] = context.ctx().project_id
return wb
@ -236,9 +230,6 @@ def create_workflow(values, session=None):
wf.update(values.copy())
# TODO(rakhmerov): It needs to be refactored.
wf['project_id'] = db_base.get_project_id()
try:
wf.save(session=session)
except db_exc.DBDuplicateEntry as e:
@ -257,8 +248,6 @@ def update_workflow(name, values, session=None):
"Workflow not found [workflow_name=%s]" % name)
wf.update(values.copy())
# TODO(rakhmerov): It needs to be refactored.
wf['project_id'] = db_base.get_project_id()
return wf
@ -702,13 +691,6 @@ def create_environment(values, session=None):
env.update(values)
# Default environment to private unless specified.
if not getattr(env, 'scope', None):
env['scope'] = 'private'
if context.ctx().project_id:
env['project_id'] = context.ctx().project_id
try:
env.save(session=session)
except db_exc.DBDuplicateEntry as e:
@ -732,9 +714,6 @@ def update_environment(name, values, session=None):
if not getattr(env, 'scope', None):
env['scope'] = 'private'
if context.ctx().project_id:
env['project_id'] = context.ctx().project_id
return env

View File

@ -1,4 +1,4 @@
# Copyright 2013 - Mirantis, Inc.
# Copyright 2015 - Mirantis, Inc.
# Copyright 2015 - StackStorm, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -22,7 +22,7 @@ from mistral.db.sqlalchemy import model_base as mb
from mistral.db.sqlalchemy import types as st
class Workbook(mb.MistralModelBase, mb.MistralSecureModelMixin):
class Workbook(mb.MistralSecureModelBase):
"""Contains info about workbook (including definition in Mistral DSL)."""
__tablename__ = 'workbooks_v2'
@ -38,7 +38,7 @@ class Workbook(mb.MistralModelBase, mb.MistralSecureModelMixin):
tags = sa.Column(st.JsonListType())
class Workflow(mb.MistralModelBase, mb.MistralSecureModelMixin):
class Workflow(mb.MistralSecureModelBase):
"""Contains info about workflow (including definition in Mistral DSL)."""
__tablename__ = 'workflows_v2'
@ -56,7 +56,7 @@ class Workflow(mb.MistralModelBase, mb.MistralSecureModelMixin):
trust_id = sa.Column(sa.String(80))
class Execution(mb.MistralModelBase, mb.MistralSecureModelMixin):
class Execution(mb.MistralSecureModelBase):
"""Contains workflow execution information."""
__tablename__ = 'executions_v2'
@ -74,7 +74,7 @@ class Execution(mb.MistralModelBase, mb.MistralSecureModelMixin):
parent_task_id = sa.Column(sa.String(36))
class Task(mb.MistralModelBase, mb.MistralSecureModelMixin):
class Task(mb.MistralSecureModelBase):
"""Contains task runtime information."""
__tablename__ = 'tasks_v2'
@ -127,7 +127,7 @@ class DelayedCall(mb.MistralModelBase):
execution_time = sa.Column(sa.DateTime, nullable=False)
class Action(mb.MistralModelBase, mb.MistralSecureModelMixin):
class Action(mb.MistralSecureModelBase):
"""Contains info about registered Actions."""
__tablename__ = 'actions_v2'
@ -153,7 +153,7 @@ class Action(mb.MistralModelBase, mb.MistralSecureModelMixin):
is_system = sa.Column(sa.Boolean())
class Environment(mb.MistralModelBase, mb.MistralSecureModelMixin):
class Environment(mb.MistralSecureModelBase):
"""Contains environment variables for workflow execution."""
__tablename__ = 'environments_v2'
@ -175,7 +175,7 @@ def _calc_workflow_input_hash(context):
return hashlib.sha256(json.dumps(sorted(d.items()))).hexdigest()
class CronTrigger(mb.MistralModelBase, mb.MistralSecureModelMixin):
class CronTrigger(mb.MistralSecureModelBase):
"""Contains info about cron triggers."""
__tablename__ = 'cron_triggers_v2'
@ -208,3 +208,7 @@ class CronTrigger(mb.MistralModelBase, mb.MistralSecureModelMixin):
mb.datetime_to_str(d, 'next_execution_time')
return d
# Register all hooks related to secure models.
mb.register_secure_model_hooks()

View File

@ -21,6 +21,7 @@ import datetime
from oslo.config import cfg
from mistral import context as auth_context
from mistral.db import v2 as db_base
from mistral.db.v2.sqlalchemy import api as db_api
from mistral import exceptions as exc
from mistral.tests import base as test_base
@ -331,6 +332,12 @@ class WorkflowTest(SQLAlchemyTest):
created0 = db_api.create_workflow(WORKFLOWS[0])
created1 = db_api.create_workflow(WORKFLOWS[1])
fetched0 = db_api.load_workflow(created0.name)
fetched1 = db_api.load_workflow(created1.name)
self.assertEqual(db_base.get_project_id(), fetched0.project_id)
self.assertEqual(db_base.get_project_id(), fetched1.project_id)
fetched = db_api.get_workflows()
self.assertEqual(2, len(fetched))