From be0cf7617bbe4a627f08d4ece58cb3d6cd52d5ea Mon Sep 17 00:00:00 2001 From: Duc Truong Date: Mon, 10 Feb 2020 19:48:29 +0000 Subject: [PATCH] Ignore project_safe restriction for admin users DB operations were previously restricted to resources belonging to the project used for authentication. This change removes this restriction so that admin users can access/modify resources belonging to any project. Change-Id: I7882ebeb194137e682bdb7ab90f03587c636a7f8 --- ...ject_safe-for-admins-2986f15e74cd1d1c.yaml | 11 +++ senlin/db/sqlalchemy/api.py | 96 +++++-------------- senlin/db/sqlalchemy/utils.py | 43 +++++++++ senlin/tests/unit/db/test_action_api.py | 4 +- senlin/tests/unit/db/test_cluster_api.py | 9 +- senlin/tests/unit/db/test_event_api.py | 10 +- senlin/tests/unit/db/test_node_api.py | 15 ++- senlin/tests/unit/db/test_policy_api.py | 6 +- senlin/tests/unit/db/test_profile_api.py | 6 +- senlin/tests/unit/db/test_receiver_api.py | 6 +- 10 files changed, 115 insertions(+), 91 deletions(-) create mode 100644 releasenotes/notes/db-ignore-project_safe-for-admins-2986f15e74cd1d1c.yaml diff --git a/releasenotes/notes/db-ignore-project_safe-for-admins-2986f15e74cd1d1c.yaml b/releasenotes/notes/db-ignore-project_safe-for-admins-2986f15e74cd1d1c.yaml new file mode 100644 index 000000000..5442a192d --- /dev/null +++ b/releasenotes/notes/db-ignore-project_safe-for-admins-2986f15e74cd1d1c.yaml @@ -0,0 +1,11 @@ +--- +features: + - | + Admin role users can now access and modify all resources (clusters, nodes, + etc) regardless of which project that belong to. +security: + - | + Removed the restriction for admin role users that prevented access/changes + to resources (clusters, nodes, etc) belonging to projects not matching the + project used for authentication. Access for non-admin users is still + isolated to their project used for authentication. diff --git a/senlin/db/sqlalchemy/api.py b/senlin/db/sqlalchemy/api.py index 532b09806..52de1947b 100755 --- a/senlin/db/sqlalchemy/api.py +++ b/senlin/db/sqlalchemy/api.py @@ -90,9 +90,7 @@ def query_by_short_id(context, model_query, model, short_id, project_safe=True): q = model_query() q = q.filter(model.id.like('%s%%' % short_id)) - - if project_safe: - q = q.filter_by(project=context.project_id) + q = utils.filter_query_by_project(q, project_safe, context) if q.count() == 1: return q.first() @@ -105,9 +103,7 @@ def query_by_short_id(context, model_query, model, short_id, def query_by_name(context, model_query, name, project_safe=True): q = model_query() q = q.filter_by(name=name) - - if project_safe: - q = q.filter_by(project=context.project_id) + q = utils.filter_query_by_project(q, project_safe, context) if q.count() == 1: return q.first() @@ -143,10 +139,7 @@ def cluster_get(context, cluster_id, project_safe=True): if cluster is None: return None - if project_safe: - if context.project_id != cluster.project: - return None - return cluster + return utils.check_resource_project(context, cluster, project_safe) def cluster_get_by_name(context, name, project_safe=True): @@ -161,9 +154,8 @@ def cluster_get_by_short_id(context, short_id, project_safe=True): def _query_cluster_get_all(context, project_safe=True): query = cluster_model_query() + query = utils.filter_query_by_project(query, project_safe, context) - if project_safe: - query = query.filter_by(project=context.project_id) return query @@ -260,11 +252,7 @@ def node_get(context, node_id, project_safe=True): if not node: return None - if project_safe: - if context.project_id != node.project: - return None - - return node + return utils.check_resource_project(context, node, project_safe) def node_get_by_name(context, name, project_safe=True): @@ -283,8 +271,7 @@ def _query_node_get_all(context, project_safe=True, cluster_id=None): if cluster_id is not None: query = query.filter_by(cluster_id=cluster_id) - if project_safe: - query = query.filter_by(project=context.project_id) + query = utils.filter_query_by_project(query, project_safe, context) return query @@ -330,8 +317,7 @@ def node_count_by_cluster(context, cluster_id, **kwargs): query = node_model_query() query = query.filter_by(cluster_id=cluster_id) query = query.filter_by(**kwargs) - if project_safe: - query = query.filter_by(project=context.project_id) + query = utils.filter_query_by_project(query, project_safe, context) return query.count() @@ -612,11 +598,7 @@ def policy_get(context, policy_id, project_safe=True): if policy is None: return None - if project_safe: - if context.project_id != policy.project: - return None - - return policy + return utils.check_resource_project(context, policy, project_safe) def policy_get_by_name(context, name, project_safe=True): @@ -632,9 +614,7 @@ def policy_get_by_short_id(context, short_id, project_safe=True): def policy_get_all(context, limit=None, marker=None, sort=None, filters=None, project_safe=True): query = policy_model_query() - - if project_safe: - query = query.filter_by(project=context.project_id) + query = utils.filter_query_by_project(query, project_safe, context) if filters: query = utils.exact_filter(query, models.Policy, filters) @@ -861,11 +841,7 @@ def profile_get(context, profile_id, project_safe=True): if profile is None: return None - if project_safe: - if context.project_id != profile.project: - return None - - return profile + return utils.check_resource_project(context, profile, project_safe) def profile_get_by_name(context, name, project_safe=True): @@ -881,9 +857,7 @@ def profile_get_by_short_id(context, short_id, project_safe=True): def profile_get_all(context, limit=None, marker=None, sort=None, filters=None, project_safe=True): query = profile_model_query() - - if project_safe: - query = query.filter_by(project=context.project_id) + query = utils.filter_query_by_project(query, project_safe, context) if filters: query = utils.exact_filter(query, models.Profile, filters) @@ -996,11 +970,7 @@ def event_create(context, values): @retry_on_deadlock def event_get(context, event_id, project_safe=True): event = event_model_query().get(event_id) - if project_safe and event is not None: - if event.project != context.project_id: - return None - - return event + return utils.check_resource_project(context, event, project_safe) def event_get_by_short_id(context, short_id, project_safe=True): @@ -1023,8 +993,7 @@ def _event_filter_paginate_query(context, query, filters=None, def event_get_all(context, limit=None, marker=None, sort=None, filters=None, project_safe=True): query = event_model_query() - if project_safe: - query = query.filter_by(project=context.project_id) + query = utils.filter_query_by_project(query, project_safe, context) return _event_filter_paginate_query(context, query, filters=filters, limit=limit, marker=marker, sort=sort) @@ -1032,9 +1001,8 @@ def event_get_all(context, limit=None, marker=None, sort=None, filters=None, def event_count_by_cluster(context, cluster_id, project_safe=True): query = event_model_query() + query = utils.filter_query_by_project(query, project_safe, context) - if project_safe: - query = query.filter_by(project=context.project_id) count = query.filter_by(cluster_id=cluster_id).count() return count @@ -1044,9 +1012,7 @@ def event_get_all_by_cluster(context, cluster_id, limit=None, marker=None, sort=None, filters=None, project_safe=True): query = event_model_query() query = query.filter_by(cluster_id=cluster_id) - - if project_safe: - query = query.filter_by(project=context.project_id) + query = utils.filter_query_by_project(query, project_safe, context) return _event_filter_paginate_query(context, query, filters=filters, limit=limit, marker=marker, sort=sort) @@ -1057,8 +1023,7 @@ def event_prune(context, cluster_id, project_safe=True): with session_for_write() as session: query = session.query(models.Event).with_for_update() query = query.filter_by(cluster_id=cluster_id) - if project_safe: - query = query.filter_by(project=context.project_id) + query = utils.filter_query_by_project(query, project_safe, context) return query.delete(synchronize_session='fetch') @@ -1117,17 +1082,13 @@ def action_get(context, action_id, project_safe=True, refresh=False): if action is None: return None - if project_safe: - if action.project != context.project_id: - return None - - return action + return utils.check_resource_project(context, action, project_safe) def action_list_active_scaling(context, cluster_id=None, project_safe=True): query = action_model_query() - if project_safe: - query = query.filter_by(project=context.project_id) + query = utils.filter_query_by_project(query, project_safe, context) + if cluster_id: query = query.filter_by(target=cluster_id) query = query.filter( @@ -1159,8 +1120,7 @@ def action_get_all_by_owner(context, owner_id): def action_get_all_active_by_target(context, target_id, project_safe=True): query = action_model_query() - if project_safe: - query = query.filter_by(project=context.project_id) + query = utils.filter_query_by_project(query, project_safe, context) query = query.filter_by(target=target_id) query = query.filter( models.Action.status.in_( @@ -1175,8 +1135,7 @@ def action_get_all_active_by_target(context, target_id, project_safe=True): def action_get_all(context, filters=None, limit=None, marker=None, sort=None, project_safe=True): query = action_model_query() - if project_safe: - query = query.filter_by(project=context.project_id) + query = utils.filter_query_by_project(query, project_safe, context) if filters: query = utils.exact_filter(query, models.Action, filters) @@ -1480,9 +1439,7 @@ def action_delete_by_target(context, target, action=None, with session_for_write() as session: q = session.query(models.Action).filter_by(target=target) - - if project_safe: - q = q.filter_by(project=context.project_id) + q = utils.filter_query_by_project(q, project_safe, context) if action: q = q.filter(models.Action.action.in_(action)) @@ -1537,18 +1494,13 @@ def receiver_get(context, receiver_id, project_safe=True): if not receiver: return None - if project_safe: - if context.project_id != receiver.project: - return None - - return receiver + return utils.check_resource_project(context, receiver, project_safe) def receiver_get_all(context, limit=None, marker=None, filters=None, sort=None, project_safe=True): query = receiver_model_query() - if project_safe: - query = query.filter_by(project=context.project_id) + query = utils.filter_query_by_project(query, project_safe, context) if filters: query = utils.exact_filter(query, models.Receiver, filters) diff --git a/senlin/db/sqlalchemy/utils.py b/senlin/db/sqlalchemy/utils.py index 6c5906ca6..1daa2a996 100644 --- a/senlin/db/sqlalchemy/utils.py +++ b/senlin/db/sqlalchemy/utils.py @@ -46,6 +46,49 @@ def exact_filter(query, model, filters): return query +def filter_query_by_project(q, project_safe, context): + """Filters a query to the context's project + + Returns the updated query, Adds filter to limit project to the + context's project for non-admin users. For admin users, + the query is returned unmodified. + + :param query: query to apply filters to + :param project_safe: boolean indicating if project restriction filter + should be applied + :param context: context of the query + + """ + + if project_safe and not context.is_admin: + return q.filter_by(project=context.project_id) + + return q + + +def check_resource_project(context, resource, project_safe): + """Check if the resource's project matches the context's project + + For non-admin users, if project_safe is set and the resource's project + does not match the context's project, none is returned. + Otherwise return the resource unmodified. + + :param context: context of the call + :param resource: resource to check + :param project_safe: boolean indicating if project restriction should be + checked. + """ + + if resource is None: + return resource + + if project_safe and not context.is_admin: + if context.project_id != resource.project: + return None + + return resource + + def get_sort_params(value, default_key=None): """Parse a string into a list of sort_keys and a list of sort_dirs. diff --git a/senlin/tests/unit/db/test_action_api.py b/senlin/tests/unit/db/test_action_api.py index 339c12aea..577cec563 100755 --- a/senlin/tests/unit/db/test_action_api.py +++ b/senlin/tests/unit/db/test_action_api.py @@ -152,8 +152,10 @@ class DBAPIActionTest(base.SenlinTestCase): parser.simple_parse(shared.sample_action) action = _create_action(self.ctx) new_ctx = utils.dummy_context(project='another-project', is_admin=True) + retobj = db_api.action_get(new_ctx, action.id, project_safe=True) - self.assertIsNone(retobj) + self.assertIsNotNone(retobj) + retobj = db_api.action_get(new_ctx, action.id, project_safe=False) self.assertIsNotNone(retobj) diff --git a/senlin/tests/unit/db/test_cluster_api.py b/senlin/tests/unit/db/test_cluster_api.py index 0871d1061..358491957 100644 --- a/senlin/tests/unit/db/test_cluster_api.py +++ b/senlin/tests/unit/db/test_cluster_api.py @@ -81,7 +81,9 @@ class DBAPIClusterTest(base.SenlinTestCase): is_admin=True) ret_cluster = db_api.cluster_get(admin_ctx, cluster.id, project_safe=True) - self.assertIsNone(ret_cluster) + self.assertEqual(cluster.id, ret_cluster.id) + self.assertEqual('db_test_cluster_name', ret_cluster.name) + ret_cluster = db_api.cluster_get(admin_ctx, cluster.id, project_safe=False) self.assertEqual(cluster.id, ret_cluster.id) @@ -228,7 +230,8 @@ class DBAPIClusterTest(base.SenlinTestCase): admin_ctx = utils.dummy_context(project='another-project', is_admin=True) clusters = db_api.cluster_get_all(admin_ctx, project_safe=True) - self.assertEqual(0, len(clusters)) + self.assertEqual(5, len(clusters)) + clusters = db_api.cluster_get_all(admin_ctx, project_safe=False) self.assertEqual(5, len(clusters)) @@ -372,7 +375,7 @@ class DBAPIClusterTest(base.SenlinTestCase): admin_ctx = utils.dummy_context(project='another-project', is_admin=True) - self.assertEqual(0, db_api.cluster_count_all(admin_ctx, + self.assertEqual(5, db_api.cluster_count_all(admin_ctx, project_safe=True)) self.assertEqual(5, db_api.cluster_count_all(admin_ctx, project_safe=False)) diff --git a/senlin/tests/unit/db/test_event_api.py b/senlin/tests/unit/db/test_event_api.py index c027de717..a6445b1c8 100644 --- a/senlin/tests/unit/db/test_event_api.py +++ b/senlin/tests/unit/db/test_event_api.py @@ -102,7 +102,7 @@ class DBAPIEventTest(base.SenlinTestCase): admin_ctx = utils.dummy_context(project='a-different-project', is_admin=True) res = db_api.event_get(admin_ctx, event.id, project_safe=True) - self.assertIsNone(res) + self.assertIsNotNone(res) res = db_api.event_get(admin_ctx, event.id, project_safe=False) self.assertIsNotNone(res) @@ -274,7 +274,8 @@ class DBAPIEventTest(base.SenlinTestCase): admin_ctx = utils.dummy_context(project='another-project', is_admin=True) events = db_api.event_get_all(admin_ctx, project_safe=True) - self.assertEqual(0, len(events)) + self.assertEqual(3, len(events)) + events = db_api.event_get_all(admin_ctx, project_safe=False) self.assertEqual(3, len(events)) @@ -354,7 +355,7 @@ class DBAPIEventTest(base.SenlinTestCase): is_admin=True) events = db_api.event_get_all_by_cluster(admin_ctx, cluster1.id, project_safe=True) - self.assertEqual(0, len(events)) + self.assertEqual(2, len(events)) events = db_api.event_get_all_by_cluster(admin_ctx, cluster1.id, project_safe=False) self.assertEqual(2, len(events)) @@ -423,7 +424,8 @@ class DBAPIEventTest(base.SenlinTestCase): res = db_api.event_count_by_cluster(admin_ctx, cluster1.id, project_safe=True) - self.assertEqual(0, res) + self.assertEqual(1, res) + res = db_api.event_count_by_cluster(admin_ctx, cluster1.id, project_safe=False) self.assertEqual(1, res) diff --git a/senlin/tests/unit/db/test_node_api.py b/senlin/tests/unit/db/test_node_api.py index c654dfff0..89ab0045e 100644 --- a/senlin/tests/unit/db/test_node_api.py +++ b/senlin/tests/unit/db/test_node_api.py @@ -78,7 +78,8 @@ class DBAPINodeTest(base.SenlinTestCase): admin_ctx = utils.dummy_context(project='a_different_project', is_admin=True) node = db_api.node_get(admin_ctx, res.id, project_safe=True) - self.assertIsNone(node) + self.assertIsNotNone(node) + node = db_api.node_get(admin_ctx, res.id, project_safe=False) self.assertIsNotNone(node) @@ -153,7 +154,8 @@ class DBAPINodeTest(base.SenlinTestCase): is_admin=True) res = db_api.node_get_by_short_id(admin_ctx, node_id[:11], project_safe=True) - self.assertIsNone(res) + self.assertIsNotNone(res) + res = db_api.node_get_by_short_id(admin_ctx, node_id[:11], project_safe=False) self.assertIsNotNone(res) @@ -368,7 +370,8 @@ class DBAPINodeTest(base.SenlinTestCase): admin_ctx = utils.dummy_context(project='a_different_project', is_admin=True) results = db_api.node_get_all(admin_ctx, project_safe=True) - self.assertEqual(0, len(results)) + self.assertEqual(2, len(results)) + results = db_api.node_get_all(admin_ctx, project_safe=False) self.assertEqual(2, len(results)) @@ -444,7 +447,8 @@ class DBAPINodeTest(base.SenlinTestCase): admin_ctx = utils.dummy_context(project='a_different_project', is_admin=True) nodes = db_api.node_get_all_by_cluster(admin_ctx, self.cluster.id) - self.assertEqual(0, len(nodes)) + self.assertEqual(2, len(nodes)) + nodes = db_api.node_get_all_by_cluster(admin_ctx, self.cluster.id, project_safe=False) self.assertEqual(2, len(nodes)) @@ -498,7 +502,8 @@ class DBAPINodeTest(base.SenlinTestCase): is_admin=True) res = db_api.node_count_by_cluster(admin_ctx, self.cluster.id, project_safe=True) - self.assertEqual(0, res) + self.assertEqual(2, res) + res = db_api.node_count_by_cluster(admin_ctx, self.cluster.id, project_safe=False) self.assertEqual(2, res) diff --git a/senlin/tests/unit/db/test_policy_api.py b/senlin/tests/unit/db/test_policy_api.py index 670c22af9..89b8565bc 100644 --- a/senlin/tests/unit/db/test_policy_api.py +++ b/senlin/tests/unit/db/test_policy_api.py @@ -90,7 +90,8 @@ class DBAPIPolicyTest(base.SenlinTestCase): admin_ctx = utils.dummy_context(project='a-different-project', is_admin=True) res = db_api.policy_get(admin_ctx, policy.id, project_safe=True) - self.assertIsNone(res) + self.assertIsNotNone(res) + res = db_api.policy_get(admin_ctx, policy.id, project_safe=False) self.assertIsNotNone(res) @@ -235,7 +236,8 @@ class DBAPIPolicyTest(base.SenlinTestCase): admin_ctx = utils.dummy_context(project='a-different-project', is_admin=True) policies = db_api.policy_get_all(admin_ctx, project_safe=True) - self.assertEqual(0, len(policies)) + self.assertEqual(2, len(policies)) + policies = db_api.policy_get_all(admin_ctx, project_safe=False) self.assertEqual(2, len(policies)) diff --git a/senlin/tests/unit/db/test_profile_api.py b/senlin/tests/unit/db/test_profile_api.py index 17543bc71..9c62cfd3c 100644 --- a/senlin/tests/unit/db/test_profile_api.py +++ b/senlin/tests/unit/db/test_profile_api.py @@ -57,7 +57,8 @@ class DBAPIProfileTest(base.SenlinTestCase): admin_ctx = utils.dummy_context(project='a-different-project', is_admin=True) res = db_api.profile_get(admin_ctx, profile.id, project_safe=True) - self.assertIsNone(res) + self.assertIsNotNone(res) + res = db_api.profile_get(admin_ctx, profile.id, project_safe=False) self.assertIsNotNone(res) @@ -186,7 +187,8 @@ class DBAPIProfileTest(base.SenlinTestCase): admin_ctx = utils.dummy_context(project='a-different-project', is_admin=True) profiles = db_api.profile_get_all(admin_ctx, project_safe=True) - self.assertEqual(0, len(profiles)) + self.assertEqual(2, len(profiles)) + profiles = db_api.profile_get_all(admin_ctx, project_safe=False) self.assertEqual(2, len(profiles)) diff --git a/senlin/tests/unit/db/test_receiver_api.py b/senlin/tests/unit/db/test_receiver_api.py index e14079860..0032fc051 100644 --- a/senlin/tests/unit/db/test_receiver_api.py +++ b/senlin/tests/unit/db/test_receiver_api.py @@ -85,7 +85,8 @@ class DBAPIReceiverTest(base.SenlinTestCase): r = self._create_receiver(self.ctx) res = db_api.receiver_get(admin_ctx, r.id, project_safe=True) - self.assertIsNone(res) + self.assertIsNotNone(res) + res = db_api.receiver_get(admin_ctx, r.id, project_safe=False) self.assertIsNotNone(res) @@ -295,7 +296,8 @@ class DBAPIReceiverTest(base.SenlinTestCase): admin_ctx = utils.dummy_context(project='a-different-project', is_admin=True) results = db_api.receiver_get_all(admin_ctx, project_safe=True) - self.assertEqual(0, len(results)) + self.assertEqual(2, len(results)) + results = db_api.receiver_get_all(admin_ctx, project_safe=False) self.assertEqual(2, len(results))