Working on secure DB access (part 2)
* Added a hook for all secure models to automatically set proper project id * Removed all direct assignments of project id * Added a simple test * Other minor changes Change-Id: Ie0922a6f43f974738c23981dc399136138dd5c75
This commit is contained in:
parent
693227ffcf
commit
14a7b44bb0
@ -19,6 +19,7 @@ import uuid
|
|||||||
|
|
||||||
from oslo.db.sqlalchemy import models as oslo_models
|
from oslo.db.sqlalchemy import models as oslo_models
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy import event
|
||||||
from sqlalchemy.ext import declarative
|
from sqlalchemy.ext import declarative
|
||||||
from sqlalchemy.orm import attributes
|
from sqlalchemy.orm import attributes
|
||||||
|
|
||||||
@ -72,16 +73,36 @@ class _MistralModelBase(oslo_models.ModelBase, oslo_models.TimestampMixin):
|
|||||||
return '%s %s' % (type(self).__name__, self.to_dict().__repr__())
|
return '%s %s' % (type(self).__name__, self.to_dict().__repr__())
|
||||||
|
|
||||||
|
|
||||||
class MistralSecureModelMixin(object):
|
|
||||||
"""Mixin adding model properties related to security."""
|
|
||||||
|
|
||||||
scope = sa.Column(sa.String(80), default="private")
|
|
||||||
project_id = sa.Column(sa.String(80), default=db_base.DEFAULT_PROJECT_ID)
|
|
||||||
|
|
||||||
|
|
||||||
def datetime_to_str(dct, attr_name):
|
def datetime_to_str(dct, attr_name):
|
||||||
if dct.get(attr_name) is not None:
|
if dct.get(attr_name) is not None:
|
||||||
dct[attr_name] = dct[attr_name].isoformat(' ')
|
dct[attr_name] = dct[attr_name].isoformat(' ')
|
||||||
|
|
||||||
|
|
||||||
MistralModelBase = declarative.declarative_base(cls=_MistralModelBase)
|
MistralModelBase = declarative.declarative_base(cls=_MistralModelBase)
|
||||||
|
|
||||||
|
|
||||||
|
# Secure model related stuff.
|
||||||
|
|
||||||
|
|
||||||
|
class MistralSecureModelBase(MistralModelBase):
|
||||||
|
"""Mixin adding model properties related to security."""
|
||||||
|
|
||||||
|
__abstract__ = True
|
||||||
|
|
||||||
|
scope = sa.Column(sa.String(80), default='private')
|
||||||
|
project_id = sa.Column(sa.String(80), default=db_base.get_project_id)
|
||||||
|
|
||||||
|
|
||||||
|
def _set_project_id(target, value, oldvalue, initiator):
|
||||||
|
return db_base.get_project_id()
|
||||||
|
|
||||||
|
|
||||||
|
def register_secure_model_hooks():
|
||||||
|
# Make sure 'project_id' is always properly set.
|
||||||
|
for sec_model_class in MistralSecureModelBase.__subclasses__():
|
||||||
|
event.listen(
|
||||||
|
sec_model_class.project_id,
|
||||||
|
'set',
|
||||||
|
_set_project_id,
|
||||||
|
retval=True
|
||||||
|
)
|
||||||
|
@ -17,6 +17,10 @@ from oslo.config import cfg
|
|||||||
from mistral import context as auth_context
|
from mistral import context as auth_context
|
||||||
|
|
||||||
|
|
||||||
|
# Make sure to import 'auth_enable' option before using it.
|
||||||
|
cfg.CONF.import_opt('auth_enable', 'mistral.config', group='pecan')
|
||||||
|
|
||||||
|
|
||||||
CONF = cfg.CONF
|
CONF = cfg.CONF
|
||||||
DEFAULT_PROJECT_ID = "<default-project>"
|
DEFAULT_PROJECT_ID = "<default-project>"
|
||||||
|
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
# Copyright 2013 - Mirantis, Inc.
|
# Copyright 2015 - Mirantis, Inc.
|
||||||
# Copyright 2015 - StackStorm, Inc.
|
# Copyright 2015 - StackStorm, Inc.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
# Copyright 2013 - Mirantis, Inc.
|
# Copyright 2015 - Mirantis, Inc.
|
||||||
# Copyright 2015 - StackStorm, Inc.
|
# Copyright 2015 - StackStorm, Inc.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
@ -20,7 +20,6 @@ from oslo.config import cfg
|
|||||||
from oslo.db import exception as db_exc
|
from oslo.db import exception as db_exc
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
|
|
||||||
from mistral import context
|
|
||||||
from mistral.db.sqlalchemy import base as b
|
from mistral.db.sqlalchemy import base as b
|
||||||
from mistral.db import v2 as db_base
|
from mistral.db import v2 as db_base
|
||||||
from mistral.db.v2.sqlalchemy import models
|
from mistral.db.v2.sqlalchemy import models
|
||||||
@ -115,10 +114,10 @@ def _get_collection_sorted_by_time(model, **kwargs):
|
|||||||
def _get_db_object_by_name(model, name):
|
def _get_db_object_by_name(model, name):
|
||||||
query = b.model_query(model)
|
query = b.model_query(model)
|
||||||
|
|
||||||
proj = query.filter_by(name=name, project_id=db_base.get_project_id())
|
private = query.filter_by(name=name, project_id=db_base.get_project_id())
|
||||||
public = query.filter_by(name=name, scope='public')
|
public = query.filter_by(name=name, scope='public')
|
||||||
|
|
||||||
return proj.union(public).first()
|
return private.union(public).first()
|
||||||
|
|
||||||
|
|
||||||
def _get_db_object_by_id(model, id):
|
def _get_db_object_by_id(model, id):
|
||||||
@ -153,9 +152,6 @@ def create_workbook(values, session=None):
|
|||||||
|
|
||||||
wb.update(values.copy())
|
wb.update(values.copy())
|
||||||
|
|
||||||
# TODO(rakhmerov): It needs to be refactored.
|
|
||||||
wb['project_id'] = context.ctx().project_id
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
wb.save(session=session)
|
wb.save(session=session)
|
||||||
except db_exc.DBDuplicateEntry as e:
|
except db_exc.DBDuplicateEntry as e:
|
||||||
@ -174,8 +170,6 @@ def update_workbook(name, values, session=None):
|
|||||||
"Workbook not found [workbook_name=%s]" % name)
|
"Workbook not found [workbook_name=%s]" % name)
|
||||||
|
|
||||||
wb.update(values.copy())
|
wb.update(values.copy())
|
||||||
# TODO(rakhmerov): It needs to be refactored.
|
|
||||||
wb['project_id'] = context.ctx().project_id
|
|
||||||
|
|
||||||
return wb
|
return wb
|
||||||
|
|
||||||
@ -236,9 +230,6 @@ def create_workflow(values, session=None):
|
|||||||
|
|
||||||
wf.update(values.copy())
|
wf.update(values.copy())
|
||||||
|
|
||||||
# TODO(rakhmerov): It needs to be refactored.
|
|
||||||
wf['project_id'] = db_base.get_project_id()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
wf.save(session=session)
|
wf.save(session=session)
|
||||||
except db_exc.DBDuplicateEntry as e:
|
except db_exc.DBDuplicateEntry as e:
|
||||||
@ -257,8 +248,6 @@ def update_workflow(name, values, session=None):
|
|||||||
"Workflow not found [workflow_name=%s]" % name)
|
"Workflow not found [workflow_name=%s]" % name)
|
||||||
|
|
||||||
wf.update(values.copy())
|
wf.update(values.copy())
|
||||||
# TODO(rakhmerov): It needs to be refactored.
|
|
||||||
wf['project_id'] = db_base.get_project_id()
|
|
||||||
|
|
||||||
return wf
|
return wf
|
||||||
|
|
||||||
@ -702,13 +691,6 @@ def create_environment(values, session=None):
|
|||||||
|
|
||||||
env.update(values)
|
env.update(values)
|
||||||
|
|
||||||
# Default environment to private unless specified.
|
|
||||||
if not getattr(env, 'scope', None):
|
|
||||||
env['scope'] = 'private'
|
|
||||||
|
|
||||||
if context.ctx().project_id:
|
|
||||||
env['project_id'] = context.ctx().project_id
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
env.save(session=session)
|
env.save(session=session)
|
||||||
except db_exc.DBDuplicateEntry as e:
|
except db_exc.DBDuplicateEntry as e:
|
||||||
@ -732,9 +714,6 @@ def update_environment(name, values, session=None):
|
|||||||
if not getattr(env, 'scope', None):
|
if not getattr(env, 'scope', None):
|
||||||
env['scope'] = 'private'
|
env['scope'] = 'private'
|
||||||
|
|
||||||
if context.ctx().project_id:
|
|
||||||
env['project_id'] = context.ctx().project_id
|
|
||||||
|
|
||||||
return env
|
return env
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
# Copyright 2013 - Mirantis, Inc.
|
# Copyright 2015 - Mirantis, Inc.
|
||||||
# Copyright 2015 - StackStorm, Inc.
|
# Copyright 2015 - StackStorm, Inc.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
@ -22,7 +22,7 @@ from mistral.db.sqlalchemy import model_base as mb
|
|||||||
from mistral.db.sqlalchemy import types as st
|
from mistral.db.sqlalchemy import types as st
|
||||||
|
|
||||||
|
|
||||||
class Workbook(mb.MistralModelBase, mb.MistralSecureModelMixin):
|
class Workbook(mb.MistralSecureModelBase):
|
||||||
"""Contains info about workbook (including definition in Mistral DSL)."""
|
"""Contains info about workbook (including definition in Mistral DSL)."""
|
||||||
|
|
||||||
__tablename__ = 'workbooks_v2'
|
__tablename__ = 'workbooks_v2'
|
||||||
@ -38,7 +38,7 @@ class Workbook(mb.MistralModelBase, mb.MistralSecureModelMixin):
|
|||||||
tags = sa.Column(st.JsonListType())
|
tags = sa.Column(st.JsonListType())
|
||||||
|
|
||||||
|
|
||||||
class Workflow(mb.MistralModelBase, mb.MistralSecureModelMixin):
|
class Workflow(mb.MistralSecureModelBase):
|
||||||
"""Contains info about workflow (including definition in Mistral DSL)."""
|
"""Contains info about workflow (including definition in Mistral DSL)."""
|
||||||
|
|
||||||
__tablename__ = 'workflows_v2'
|
__tablename__ = 'workflows_v2'
|
||||||
@ -56,7 +56,7 @@ class Workflow(mb.MistralModelBase, mb.MistralSecureModelMixin):
|
|||||||
trust_id = sa.Column(sa.String(80))
|
trust_id = sa.Column(sa.String(80))
|
||||||
|
|
||||||
|
|
||||||
class Execution(mb.MistralModelBase, mb.MistralSecureModelMixin):
|
class Execution(mb.MistralSecureModelBase):
|
||||||
"""Contains workflow execution information."""
|
"""Contains workflow execution information."""
|
||||||
|
|
||||||
__tablename__ = 'executions_v2'
|
__tablename__ = 'executions_v2'
|
||||||
@ -74,7 +74,7 @@ class Execution(mb.MistralModelBase, mb.MistralSecureModelMixin):
|
|||||||
parent_task_id = sa.Column(sa.String(36))
|
parent_task_id = sa.Column(sa.String(36))
|
||||||
|
|
||||||
|
|
||||||
class Task(mb.MistralModelBase, mb.MistralSecureModelMixin):
|
class Task(mb.MistralSecureModelBase):
|
||||||
"""Contains task runtime information."""
|
"""Contains task runtime information."""
|
||||||
|
|
||||||
__tablename__ = 'tasks_v2'
|
__tablename__ = 'tasks_v2'
|
||||||
@ -127,7 +127,7 @@ class DelayedCall(mb.MistralModelBase):
|
|||||||
execution_time = sa.Column(sa.DateTime, nullable=False)
|
execution_time = sa.Column(sa.DateTime, nullable=False)
|
||||||
|
|
||||||
|
|
||||||
class Action(mb.MistralModelBase, mb.MistralSecureModelMixin):
|
class Action(mb.MistralSecureModelBase):
|
||||||
"""Contains info about registered Actions."""
|
"""Contains info about registered Actions."""
|
||||||
|
|
||||||
__tablename__ = 'actions_v2'
|
__tablename__ = 'actions_v2'
|
||||||
@ -153,7 +153,7 @@ class Action(mb.MistralModelBase, mb.MistralSecureModelMixin):
|
|||||||
is_system = sa.Column(sa.Boolean())
|
is_system = sa.Column(sa.Boolean())
|
||||||
|
|
||||||
|
|
||||||
class Environment(mb.MistralModelBase, mb.MistralSecureModelMixin):
|
class Environment(mb.MistralSecureModelBase):
|
||||||
"""Contains environment variables for workflow execution."""
|
"""Contains environment variables for workflow execution."""
|
||||||
|
|
||||||
__tablename__ = 'environments_v2'
|
__tablename__ = 'environments_v2'
|
||||||
@ -175,7 +175,7 @@ def _calc_workflow_input_hash(context):
|
|||||||
return hashlib.sha256(json.dumps(sorted(d.items()))).hexdigest()
|
return hashlib.sha256(json.dumps(sorted(d.items()))).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
class CronTrigger(mb.MistralModelBase, mb.MistralSecureModelMixin):
|
class CronTrigger(mb.MistralSecureModelBase):
|
||||||
"""Contains info about cron triggers."""
|
"""Contains info about cron triggers."""
|
||||||
|
|
||||||
__tablename__ = 'cron_triggers_v2'
|
__tablename__ = 'cron_triggers_v2'
|
||||||
@ -208,3 +208,7 @@ class CronTrigger(mb.MistralModelBase, mb.MistralSecureModelMixin):
|
|||||||
mb.datetime_to_str(d, 'next_execution_time')
|
mb.datetime_to_str(d, 'next_execution_time')
|
||||||
|
|
||||||
return d
|
return d
|
||||||
|
|
||||||
|
|
||||||
|
# Register all hooks related to secure models.
|
||||||
|
mb.register_secure_model_hooks()
|
||||||
|
@ -21,6 +21,7 @@ import datetime
|
|||||||
from oslo.config import cfg
|
from oslo.config import cfg
|
||||||
|
|
||||||
from mistral import context as auth_context
|
from mistral import context as auth_context
|
||||||
|
from mistral.db import v2 as db_base
|
||||||
from mistral.db.v2.sqlalchemy import api as db_api
|
from mistral.db.v2.sqlalchemy import api as db_api
|
||||||
from mistral import exceptions as exc
|
from mistral import exceptions as exc
|
||||||
from mistral.tests import base as test_base
|
from mistral.tests import base as test_base
|
||||||
@ -331,6 +332,12 @@ class WorkflowTest(SQLAlchemyTest):
|
|||||||
created0 = db_api.create_workflow(WORKFLOWS[0])
|
created0 = db_api.create_workflow(WORKFLOWS[0])
|
||||||
created1 = db_api.create_workflow(WORKFLOWS[1])
|
created1 = db_api.create_workflow(WORKFLOWS[1])
|
||||||
|
|
||||||
|
fetched0 = db_api.load_workflow(created0.name)
|
||||||
|
fetched1 = db_api.load_workflow(created1.name)
|
||||||
|
|
||||||
|
self.assertEqual(db_base.get_project_id(), fetched0.project_id)
|
||||||
|
self.assertEqual(db_base.get_project_id(), fetched1.project_id)
|
||||||
|
|
||||||
fetched = db_api.get_workflows()
|
fetched = db_api.get_workflows()
|
||||||
|
|
||||||
self.assertEqual(2, len(fetched))
|
self.assertEqual(2, len(fetched))
|
||||||
|
Loading…
Reference in New Issue
Block a user