Browse Source

Ignore project_safe restriction for admin users

DB operations were previously restricted to resources belonging to the
project used for authentication. This change removes this restriction so
that admin users can access/modify resources belonging to any project.

Change-Id: I7882ebeb194137e682bdb7ab90f03587c636a7f8
changes/66/706966/2
Duc Truong 1 week ago
parent
commit
be0cf7617b
10 changed files with 115 additions and 91 deletions
  1. +11
    -0
      releasenotes/notes/db-ignore-project_safe-for-admins-2986f15e74cd1d1c.yaml
  2. +24
    -72
      senlin/db/sqlalchemy/api.py
  3. +43
    -0
      senlin/db/sqlalchemy/utils.py
  4. +3
    -1
      senlin/tests/unit/db/test_action_api.py
  5. +6
    -3
      senlin/tests/unit/db/test_cluster_api.py
  6. +6
    -4
      senlin/tests/unit/db/test_event_api.py
  7. +10
    -5
      senlin/tests/unit/db/test_node_api.py
  8. +4
    -2
      senlin/tests/unit/db/test_policy_api.py
  9. +4
    -2
      senlin/tests/unit/db/test_profile_api.py
  10. +4
    -2
      senlin/tests/unit/db/test_receiver_api.py

+ 11
- 0
releasenotes/notes/db-ignore-project_safe-for-admins-2986f15e74cd1d1c.yaml View File

@@ -0,0 +1,11 @@
---
features:
- |
Admin role users can now access and modify all resources (clusters, nodes,
etc) regardless of which project that belong to.
security:
- |
Removed the restriction for admin role users that prevented access/changes
to resources (clusters, nodes, etc) belonging to projects not matching the
project used for authentication. Access for non-admin users is still
isolated to their project used for authentication.

+ 24
- 72
senlin/db/sqlalchemy/api.py View File

@@ -90,9 +90,7 @@ def query_by_short_id(context, model_query, model, short_id,
project_safe=True):
q = model_query()
q = q.filter(model.id.like('%s%%' % short_id))

if project_safe:
q = q.filter_by(project=context.project_id)
q = utils.filter_query_by_project(q, project_safe, context)

if q.count() == 1:
return q.first()
@@ -105,9 +103,7 @@ def query_by_short_id(context, model_query, model, short_id,
def query_by_name(context, model_query, name, project_safe=True):
q = model_query()
q = q.filter_by(name=name)

if project_safe:
q = q.filter_by(project=context.project_id)
q = utils.filter_query_by_project(q, project_safe, context)

if q.count() == 1:
return q.first()
@@ -143,10 +139,7 @@ def cluster_get(context, cluster_id, project_safe=True):
if cluster is None:
return None

if project_safe:
if context.project_id != cluster.project:
return None
return cluster
return utils.check_resource_project(context, cluster, project_safe)


def cluster_get_by_name(context, name, project_safe=True):
@@ -161,9 +154,8 @@ def cluster_get_by_short_id(context, short_id, project_safe=True):

def _query_cluster_get_all(context, project_safe=True):
query = cluster_model_query()
query = utils.filter_query_by_project(query, project_safe, context)

if project_safe:
query = query.filter_by(project=context.project_id)
return query


@@ -260,11 +252,7 @@ def node_get(context, node_id, project_safe=True):
if not node:
return None

if project_safe:
if context.project_id != node.project:
return None

return node
return utils.check_resource_project(context, node, project_safe)


def node_get_by_name(context, name, project_safe=True):
@@ -283,8 +271,7 @@ def _query_node_get_all(context, project_safe=True, cluster_id=None):
if cluster_id is not None:
query = query.filter_by(cluster_id=cluster_id)

if project_safe:
query = query.filter_by(project=context.project_id)
query = utils.filter_query_by_project(query, project_safe, context)

return query

@@ -330,8 +317,7 @@ def node_count_by_cluster(context, cluster_id, **kwargs):
query = node_model_query()
query = query.filter_by(cluster_id=cluster_id)
query = query.filter_by(**kwargs)
if project_safe:
query = query.filter_by(project=context.project_id)
query = utils.filter_query_by_project(query, project_safe, context)

return query.count()

@@ -612,11 +598,7 @@ def policy_get(context, policy_id, project_safe=True):
if policy is None:
return None

if project_safe:
if context.project_id != policy.project:
return None

return policy
return utils.check_resource_project(context, policy, project_safe)


def policy_get_by_name(context, name, project_safe=True):
@@ -632,9 +614,7 @@ def policy_get_by_short_id(context, short_id, project_safe=True):
def policy_get_all(context, limit=None, marker=None, sort=None, filters=None,
project_safe=True):
query = policy_model_query()

if project_safe:
query = query.filter_by(project=context.project_id)
query = utils.filter_query_by_project(query, project_safe, context)

if filters:
query = utils.exact_filter(query, models.Policy, filters)
@@ -861,11 +841,7 @@ def profile_get(context, profile_id, project_safe=True):
if profile is None:
return None

if project_safe:
if context.project_id != profile.project:
return None

return profile
return utils.check_resource_project(context, profile, project_safe)


def profile_get_by_name(context, name, project_safe=True):
@@ -881,9 +857,7 @@ def profile_get_by_short_id(context, short_id, project_safe=True):
def profile_get_all(context, limit=None, marker=None, sort=None, filters=None,
project_safe=True):
query = profile_model_query()

if project_safe:
query = query.filter_by(project=context.project_id)
query = utils.filter_query_by_project(query, project_safe, context)

if filters:
query = utils.exact_filter(query, models.Profile, filters)
@@ -996,11 +970,7 @@ def event_create(context, values):
@retry_on_deadlock
def event_get(context, event_id, project_safe=True):
event = event_model_query().get(event_id)
if project_safe and event is not None:
if event.project != context.project_id:
return None

return event
return utils.check_resource_project(context, event, project_safe)


def event_get_by_short_id(context, short_id, project_safe=True):
@@ -1023,8 +993,7 @@ def _event_filter_paginate_query(context, query, filters=None,
def event_get_all(context, limit=None, marker=None, sort=None, filters=None,
project_safe=True):
query = event_model_query()
if project_safe:
query = query.filter_by(project=context.project_id)
query = utils.filter_query_by_project(query, project_safe, context)

return _event_filter_paginate_query(context, query, filters=filters,
limit=limit, marker=marker, sort=sort)
@@ -1032,9 +1001,8 @@ def event_get_all(context, limit=None, marker=None, sort=None, filters=None,

def event_count_by_cluster(context, cluster_id, project_safe=True):
query = event_model_query()
query = utils.filter_query_by_project(query, project_safe, context)

if project_safe:
query = query.filter_by(project=context.project_id)
count = query.filter_by(cluster_id=cluster_id).count()

return count
@@ -1044,9 +1012,7 @@ def event_get_all_by_cluster(context, cluster_id, limit=None, marker=None,
sort=None, filters=None, project_safe=True):
query = event_model_query()
query = query.filter_by(cluster_id=cluster_id)

if project_safe:
query = query.filter_by(project=context.project_id)
query = utils.filter_query_by_project(query, project_safe, context)

return _event_filter_paginate_query(context, query, filters=filters,
limit=limit, marker=marker, sort=sort)
@@ -1057,8 +1023,7 @@ def event_prune(context, cluster_id, project_safe=True):
with session_for_write() as session:
query = session.query(models.Event).with_for_update()
query = query.filter_by(cluster_id=cluster_id)
if project_safe:
query = query.filter_by(project=context.project_id)
query = utils.filter_query_by_project(query, project_safe, context)

return query.delete(synchronize_session='fetch')

@@ -1117,17 +1082,13 @@ def action_get(context, action_id, project_safe=True, refresh=False):
if action is None:
return None

if project_safe:
if action.project != context.project_id:
return None

return action
return utils.check_resource_project(context, action, project_safe)


def action_list_active_scaling(context, cluster_id=None, project_safe=True):
query = action_model_query()
if project_safe:
query = query.filter_by(project=context.project_id)
query = utils.filter_query_by_project(query, project_safe, context)
if cluster_id:
query = query.filter_by(target=cluster_id)
query = query.filter(
@@ -1159,8 +1120,7 @@ def action_get_all_by_owner(context, owner_id):

def action_get_all_active_by_target(context, target_id, project_safe=True):
query = action_model_query()
if project_safe:
query = query.filter_by(project=context.project_id)
query = utils.filter_query_by_project(query, project_safe, context)
query = query.filter_by(target=target_id)
query = query.filter(
models.Action.status.in_(
@@ -1175,8 +1135,7 @@ def action_get_all_active_by_target(context, target_id, project_safe=True):
def action_get_all(context, filters=None, limit=None, marker=None, sort=None,
project_safe=True):
query = action_model_query()
if project_safe:
query = query.filter_by(project=context.project_id)
query = utils.filter_query_by_project(query, project_safe, context)

if filters:
query = utils.exact_filter(query, models.Action, filters)
@@ -1480,9 +1439,7 @@ def action_delete_by_target(context, target, action=None,

with session_for_write() as session:
q = session.query(models.Action).filter_by(target=target)

if project_safe:
q = q.filter_by(project=context.project_id)
q = utils.filter_query_by_project(q, project_safe, context)

if action:
q = q.filter(models.Action.action.in_(action))
@@ -1537,18 +1494,13 @@ def receiver_get(context, receiver_id, project_safe=True):
if not receiver:
return None

if project_safe:
if context.project_id != receiver.project:
return None

return receiver
return utils.check_resource_project(context, receiver, project_safe)


def receiver_get_all(context, limit=None, marker=None, filters=None, sort=None,
project_safe=True):
query = receiver_model_query()
if project_safe:
query = query.filter_by(project=context.project_id)
query = utils.filter_query_by_project(query, project_safe, context)

if filters:
query = utils.exact_filter(query, models.Receiver, filters)

+ 43
- 0
senlin/db/sqlalchemy/utils.py View File

@@ -46,6 +46,49 @@ def exact_filter(query, model, filters):
return query


def filter_query_by_project(q, project_safe, context):
"""Filters a query to the context's project

Returns the updated query, Adds filter to limit project to the
context's project for non-admin users. For admin users,
the query is returned unmodified.

:param query: query to apply filters to
:param project_safe: boolean indicating if project restriction filter
should be applied
:param context: context of the query

"""

if project_safe and not context.is_admin:
return q.filter_by(project=context.project_id)

return q


def check_resource_project(context, resource, project_safe):
"""Check if the resource's project matches the context's project

For non-admin users, if project_safe is set and the resource's project
does not match the context's project, none is returned.
Otherwise return the resource unmodified.

:param context: context of the call
:param resource: resource to check
:param project_safe: boolean indicating if project restriction should be
checked.
"""

if resource is None:
return resource

if project_safe and not context.is_admin:
if context.project_id != resource.project:
return None

return resource


def get_sort_params(value, default_key=None):
"""Parse a string into a list of sort_keys and a list of sort_dirs.


+ 3
- 1
senlin/tests/unit/db/test_action_api.py View File

@@ -152,8 +152,10 @@ class DBAPIActionTest(base.SenlinTestCase):
parser.simple_parse(shared.sample_action)
action = _create_action(self.ctx)
new_ctx = utils.dummy_context(project='another-project', is_admin=True)

retobj = db_api.action_get(new_ctx, action.id, project_safe=True)
self.assertIsNone(retobj)
self.assertIsNotNone(retobj)

retobj = db_api.action_get(new_ctx, action.id, project_safe=False)
self.assertIsNotNone(retobj)


+ 6
- 3
senlin/tests/unit/db/test_cluster_api.py View File

@@ -81,7 +81,9 @@ class DBAPIClusterTest(base.SenlinTestCase):
is_admin=True)
ret_cluster = db_api.cluster_get(admin_ctx, cluster.id,
project_safe=True)
self.assertIsNone(ret_cluster)
self.assertEqual(cluster.id, ret_cluster.id)
self.assertEqual('db_test_cluster_name', ret_cluster.name)

ret_cluster = db_api.cluster_get(admin_ctx, cluster.id,
project_safe=False)
self.assertEqual(cluster.id, ret_cluster.id)
@@ -228,7 +230,8 @@ class DBAPIClusterTest(base.SenlinTestCase):
admin_ctx = utils.dummy_context(project='another-project',
is_admin=True)
clusters = db_api.cluster_get_all(admin_ctx, project_safe=True)
self.assertEqual(0, len(clusters))
self.assertEqual(5, len(clusters))

clusters = db_api.cluster_get_all(admin_ctx, project_safe=False)
self.assertEqual(5, len(clusters))

@@ -372,7 +375,7 @@ class DBAPIClusterTest(base.SenlinTestCase):

admin_ctx = utils.dummy_context(project='another-project',
is_admin=True)
self.assertEqual(0, db_api.cluster_count_all(admin_ctx,
self.assertEqual(5, db_api.cluster_count_all(admin_ctx,
project_safe=True))
self.assertEqual(5, db_api.cluster_count_all(admin_ctx,
project_safe=False))

+ 6
- 4
senlin/tests/unit/db/test_event_api.py View File

@@ -102,7 +102,7 @@ class DBAPIEventTest(base.SenlinTestCase):
admin_ctx = utils.dummy_context(project='a-different-project',
is_admin=True)
res = db_api.event_get(admin_ctx, event.id, project_safe=True)
self.assertIsNone(res)
self.assertIsNotNone(res)
res = db_api.event_get(admin_ctx, event.id, project_safe=False)
self.assertIsNotNone(res)

@@ -274,7 +274,8 @@ class DBAPIEventTest(base.SenlinTestCase):
admin_ctx = utils.dummy_context(project='another-project',
is_admin=True)
events = db_api.event_get_all(admin_ctx, project_safe=True)
self.assertEqual(0, len(events))
self.assertEqual(3, len(events))

events = db_api.event_get_all(admin_ctx, project_safe=False)
self.assertEqual(3, len(events))

@@ -354,7 +355,7 @@ class DBAPIEventTest(base.SenlinTestCase):
is_admin=True)
events = db_api.event_get_all_by_cluster(admin_ctx, cluster1.id,
project_safe=True)
self.assertEqual(0, len(events))
self.assertEqual(2, len(events))
events = db_api.event_get_all_by_cluster(admin_ctx, cluster1.id,
project_safe=False)
self.assertEqual(2, len(events))
@@ -423,7 +424,8 @@ class DBAPIEventTest(base.SenlinTestCase):

res = db_api.event_count_by_cluster(admin_ctx, cluster1.id,
project_safe=True)
self.assertEqual(0, res)
self.assertEqual(1, res)

res = db_api.event_count_by_cluster(admin_ctx, cluster1.id,
project_safe=False)
self.assertEqual(1, res)

+ 10
- 5
senlin/tests/unit/db/test_node_api.py View File

@@ -78,7 +78,8 @@ class DBAPINodeTest(base.SenlinTestCase):
admin_ctx = utils.dummy_context(project='a_different_project',
is_admin=True)
node = db_api.node_get(admin_ctx, res.id, project_safe=True)
self.assertIsNone(node)
self.assertIsNotNone(node)

node = db_api.node_get(admin_ctx, res.id, project_safe=False)
self.assertIsNotNone(node)

@@ -153,7 +154,8 @@ class DBAPINodeTest(base.SenlinTestCase):
is_admin=True)
res = db_api.node_get_by_short_id(admin_ctx, node_id[:11],
project_safe=True)
self.assertIsNone(res)
self.assertIsNotNone(res)

res = db_api.node_get_by_short_id(admin_ctx, node_id[:11],
project_safe=False)
self.assertIsNotNone(res)
@@ -368,7 +370,8 @@ class DBAPINodeTest(base.SenlinTestCase):
admin_ctx = utils.dummy_context(project='a_different_project',
is_admin=True)
results = db_api.node_get_all(admin_ctx, project_safe=True)
self.assertEqual(0, len(results))
self.assertEqual(2, len(results))

results = db_api.node_get_all(admin_ctx, project_safe=False)
self.assertEqual(2, len(results))

@@ -444,7 +447,8 @@ class DBAPINodeTest(base.SenlinTestCase):
admin_ctx = utils.dummy_context(project='a_different_project',
is_admin=True)
nodes = db_api.node_get_all_by_cluster(admin_ctx, self.cluster.id)
self.assertEqual(0, len(nodes))
self.assertEqual(2, len(nodes))

nodes = db_api.node_get_all_by_cluster(admin_ctx, self.cluster.id,
project_safe=False)
self.assertEqual(2, len(nodes))
@@ -498,7 +502,8 @@ class DBAPINodeTest(base.SenlinTestCase):
is_admin=True)
res = db_api.node_count_by_cluster(admin_ctx, self.cluster.id,
project_safe=True)
self.assertEqual(0, res)
self.assertEqual(2, res)

res = db_api.node_count_by_cluster(admin_ctx, self.cluster.id,
project_safe=False)
self.assertEqual(2, res)

+ 4
- 2
senlin/tests/unit/db/test_policy_api.py View File

@@ -90,7 +90,8 @@ class DBAPIPolicyTest(base.SenlinTestCase):
admin_ctx = utils.dummy_context(project='a-different-project',
is_admin=True)
res = db_api.policy_get(admin_ctx, policy.id, project_safe=True)
self.assertIsNone(res)
self.assertIsNotNone(res)

res = db_api.policy_get(admin_ctx, policy.id, project_safe=False)
self.assertIsNotNone(res)

@@ -235,7 +236,8 @@ class DBAPIPolicyTest(base.SenlinTestCase):
admin_ctx = utils.dummy_context(project='a-different-project',
is_admin=True)
policies = db_api.policy_get_all(admin_ctx, project_safe=True)
self.assertEqual(0, len(policies))
self.assertEqual(2, len(policies))

policies = db_api.policy_get_all(admin_ctx, project_safe=False)
self.assertEqual(2, len(policies))


+ 4
- 2
senlin/tests/unit/db/test_profile_api.py View File

@@ -57,7 +57,8 @@ class DBAPIProfileTest(base.SenlinTestCase):
admin_ctx = utils.dummy_context(project='a-different-project',
is_admin=True)
res = db_api.profile_get(admin_ctx, profile.id, project_safe=True)
self.assertIsNone(res)
self.assertIsNotNone(res)

res = db_api.profile_get(admin_ctx, profile.id, project_safe=False)
self.assertIsNotNone(res)

@@ -186,7 +187,8 @@ class DBAPIProfileTest(base.SenlinTestCase):
admin_ctx = utils.dummy_context(project='a-different-project',
is_admin=True)
profiles = db_api.profile_get_all(admin_ctx, project_safe=True)
self.assertEqual(0, len(profiles))
self.assertEqual(2, len(profiles))

profiles = db_api.profile_get_all(admin_ctx, project_safe=False)
self.assertEqual(2, len(profiles))


+ 4
- 2
senlin/tests/unit/db/test_receiver_api.py View File

@@ -85,7 +85,8 @@ class DBAPIReceiverTest(base.SenlinTestCase):
r = self._create_receiver(self.ctx)

res = db_api.receiver_get(admin_ctx, r.id, project_safe=True)
self.assertIsNone(res)
self.assertIsNotNone(res)

res = db_api.receiver_get(admin_ctx, r.id, project_safe=False)
self.assertIsNotNone(res)

@@ -295,7 +296,8 @@ class DBAPIReceiverTest(base.SenlinTestCase):
admin_ctx = utils.dummy_context(project='a-different-project',
is_admin=True)
results = db_api.receiver_get_all(admin_ctx, project_safe=True)
self.assertEqual(0, len(results))
self.assertEqual(2, len(results))

results = db_api.receiver_get_all(admin_ctx, project_safe=False)
self.assertEqual(2, len(results))


Loading…
Cancel
Save