Make all cluster DB queries project_safe

This patch makes all cluster realted DB queries project_safe by
default.

Implements: blueprint make-db-query-project-safe
Change-Id: I8bfe501b211df5214ecf5de254fa62025548d651
This commit is contained in:
yanyanhu 2015-10-14 22:50:40 -04:00
parent bc4a96ccf3
commit 1c195e9664
3 changed files with 36 additions and 9 deletions

View File

@ -45,12 +45,14 @@ def cluster_get(context, cluster_id, show_deleted=False, project_safe=True):
project_safe=project_safe)
def cluster_get_by_name(context, cluster_name):
return IMPL.cluster_get_by_name(context, cluster_name)
def cluster_get_by_name(context, cluster_name, project_safe=True):
return IMPL.cluster_get_by_name(context, cluster_name,
project_safe=project_safe)
def cluster_get_by_short_id(context, short_id):
return IMPL.cluster_get_by_short_id(context, short_id)
def cluster_get_by_short_id(context, short_id, project_safe=True):
return IMPL.cluster_get_by_short_id(context, short_id,
project_safe=project_safe)
def cluster_get_all(context, limit=None, marker=None, sort_keys=None,

View File

@ -128,9 +128,15 @@ def soft_delete_aware_query(context, *args, **kwargs):
return query
def query_by_short_id(context, model, short_id, show_deleted=False):
# TODO(Yanyan Hu): Set default value of project_safe to True
def query_by_short_id(context, model, short_id, project_safe=False,
show_deleted=False):
q = soft_delete_aware_query(context, model, show_deleted=show_deleted)
q = q.filter(model.id.like('%s%%' % short_id))
if project_safe:
q = q.filter_by(project=context.project)
if q.count() == 1:
return q.first()
elif q.count() == 0:
@ -139,6 +145,7 @@ def query_by_short_id(context, model, short_id, show_deleted=False):
raise exception.MultipleChoices(arg=short_id)
# TODO(Yanyan Hu): Set default value of project_safe to True
def query_by_name(context, model, name, project_safe=False,
show_deleted=False):
q = soft_delete_aware_query(context, model, show_deleted=show_deleted)
@ -181,12 +188,14 @@ def cluster_get(context, cluster_id, show_deleted=False, project_safe=True):
return cluster
def cluster_get_by_name(context, name):
return query_by_name(context, models.Cluster, name, True)
def cluster_get_by_name(context, name, project_safe=True):
return query_by_name(context, models.Cluster, name,
project_safe=project_safe)
def cluster_get_by_short_id(context, short_id):
return query_by_short_id(context, models.Cluster, short_id)
def cluster_get_by_short_id(context, short_id, project_safe=True):
return query_by_short_id(context, models.Cluster, short_id,
project_safe=project_safe)
def _query_cluster_get_all(context, project_safe=True, show_deleted=False,

View File

@ -165,6 +165,22 @@ class DBAPIClusterTest(base.SenlinTestCase):
res = db_api.cluster_get_by_short_id(self.ctx, 'non-existent')
self.assertIsNone(res)
ctx_new = utils.dummy_context(project='different_project_id')
res = db_api.cluster_get_by_short_id(ctx_new, cid1[:11])
self.assertIsNone(res)
def test_cluster_get_by_short_id_diff_project(self):
cluster1 = shared.create_cluster(self.ctx, self.profile,
id=UUID1,
name='cluster-1')
res = db_api.cluster_get_by_short_id(self.ctx, UUID1[:11])
self.assertEqual(cluster1.id, res.id)
ctx_new = utils.dummy_context(project='different_project_id')
res = db_api.cluster_get_by_short_id(ctx_new, UUID1[:11])
self.assertIsNone(res)
def test_cluster_get_all(self):
values = [
{'name': 'cluster1'},