More test cases for cluster DB APIs

This commit is contained in:
tengqm 2015-01-03 15:55:26 +08:00
parent f08f3d32e3
commit 8b24aad0d4
2 changed files with 353 additions and 12 deletions

View File

@ -141,7 +141,7 @@ def cluster_get_all_by_parent(context, parent):
def cluster_get_by_name_and_parent(context, cluster_name, parent): def cluster_get_by_name_and_parent(context, cluster_name, parent):
query = soft_delete_aware_query(context, models.Cluster).\ query = soft_delete_aware_query(context, models.Cluster).\
filter_by(tenant=context.tenant_id).\ filter_by(project=context.tenant_id).\
filter_by(name=cluster_name).\ filter_by(name=cluster_name).\
filter_by(parent=parent) filter_by(parent=parent)
return query.first() return query.first()
@ -155,13 +155,11 @@ def cluster_get_by_name(context, cluster_name):
def _query_cluster_get_all(context, tenant_safe=True, show_deleted=False, def _query_cluster_get_all(context, tenant_safe=True, show_deleted=False,
show_nested=False): show_nested=False):
q0 = soft_delete_aware_query(context, models.Cluster, query = soft_delete_aware_query(context, models.Cluster,
show_deleted=show_deleted) show_deleted=show_deleted)
if show_nested: if not show_nested:
query = q0.filter_by(backup=False) query = query.filter_by(parent=None)
else:
query = q0.filter_by(parent=None)
if tenant_safe: if tenant_safe:
query = query.filter_by(project=context.tenant_id) query = query.filter_by(project=context.tenant_id)
@ -285,10 +283,6 @@ def node_get_all(context):
def node_get_all_by_cluster(context, cluster_id): def node_get_all_by_cluster(context, cluster_id):
query = model_query(context, models.Node).filter_by(cluster_id=cluster_id) query = model_query(context, models.Node).filter_by(cluster_id=cluster_id)
nodes = query.all() nodes = query.all()
if not nodes:
msg = _("No nodes for cluster %s were found") % cluster_id
raise exception.NotFound(msg)
return dict((node.name, node) for node in nodes) return dict((node.name, node) for node in nodes)

View File

@ -10,6 +10,9 @@
# License for the specific language governing permissions and limitations # License for the specific language governing permissions and limitations
# under the License. # under the License.
import datetime
import mock
from senlin.common import exception from senlin.common import exception
from senlin.db.sqlalchemy import api as db_api from senlin.db.sqlalchemy import api as db_api
from senlin.tests.common import base from senlin.tests.common import base
@ -111,7 +114,7 @@ class DBAPIClusterTest(base.SenlinTestCase):
self.assertEqual(cluster.id, ret_cluster.id) self.assertEqual(cluster.id, ret_cluster.id)
self.assertEqual('db_test_cluster_name', ret_cluster.name) self.assertEqual('db_test_cluster_name', ret_cluster.name)
def test_cluster_get_by_name(self): def test_cluster_get_by_name_default(self):
cluster = shared.create_cluster(self.ctx, self.profile) cluster = shared.create_cluster(self.ctx, self.profile)
ret_cluster = db_api.cluster_get_by_name(self.ctx, cluster.name) ret_cluster = db_api.cluster_get_by_name(self.ctx, cluster.name)
self.assertIsNotNone(ret_cluster) self.assertIsNotNone(ret_cluster)
@ -123,6 +126,23 @@ class DBAPIClusterTest(base.SenlinTestCase):
self.ctx.tenant_id = 'abc' self.ctx.tenant_id = 'abc'
self.assertIsNone(db_api.cluster_get_by_name(self.ctx, 'abc')) self.assertIsNone(db_api.cluster_get_by_name(self.ctx, 'abc'))
def test_cluster_get_by_name(self):
cluster = shared.create_cluster(self.ctx, self.profile,
name='cluster', project=UUID2)
res = db_api.cluster_get_by_name(self.ctx, 'cluster')
self.assertIsNone(res)
self.ctx.tenant_id = UUID3
self.assertIsNone(db_api.cluster_get_by_name(self.ctx, 'cluster'))
self.ctx.tenant_id = UUID2
res = db_api.cluster_get_by_name(self.ctx, 'cluster')
self.assertEqual(cluster.id, res.id)
db_api.cluster_delete(self.ctx, cluster.id)
self.assertIsNone(db_api.cluster_get_by_name(self.ctx, 'cluster'))
def test_cluster_get_all(self): def test_cluster_get_all(self):
values = [ values = [
{'name': 'cluster1'}, {'name': 'cluster1'},
@ -189,6 +209,314 @@ class DBAPIClusterTest(base.SenlinTestCase):
clusters = db_api.cluster_get_all(self.ctx, tenant_safe=False) clusters = db_api.cluster_get_all(self.ctx, tenant_safe=False)
self.assertEqual(5, len(clusters)) self.assertEqual(5, len(clusters))
def test_get_sort_keys_returns_empty_list_if_no_keys(self):
sort_keys = None
mapping = {}
filtered_keys = db_api._get_sort_keys(sort_keys, mapping)
self.assertEqual([], filtered_keys)
def test_get_sort_keys_whitelists_single_key(self):
sort_key = 'foo'
mapping = {'foo': 'Foo'}
filtered_keys = db_api._get_sort_keys(sort_key, mapping)
self.assertEqual(['Foo'], filtered_keys)
def test_get_sort_keys_whitelists_multiple_keys(self):
sort_keys = ['foo', 'bar', 'nope']
mapping = {'foo': 'Foo', 'bar': 'Bar'}
filtered_keys = db_api._get_sort_keys(sort_keys, mapping)
self.assertIn('Foo', filtered_keys)
self.assertIn('Bar', filtered_keys)
self.assertEqual(2, len(filtered_keys))
@mock.patch.object(db_api.utils, 'paginate_query')
def test_paginate_query_raises_invalid_sort_key(self, mock_paginate_query):
query = mock.Mock()
model = mock.Mock()
mock_paginate_query.side_effect = db_api.utils.InvalidSortKey()
self.assertRaises(exception.Invalid, db_api._paginate_query,
self.ctx, query, model, sort_keys=['foo'])
@mock.patch.object(db_api.utils, 'paginate_query')
@mock.patch.object(db_api, 'model_query')
def test_paginate_query_gets_model_marker(self, mock_query,
mock_paginate_query):
query = mock.Mock()
model = mock.Mock()
marker = mock.Mock()
mock_query_object = mock.Mock()
mock_query_object.get.return_value = 'real_marker'
mock_query.return_value = mock_query_object
db_api._paginate_query(self.ctx, query, model, marker=marker)
mock_query_object.get.assert_called_once_with(marker)
args, _ = mock_paginate_query.call_args
self.assertIn('real_marker', args)
@mock.patch.object(db_api.utils, 'paginate_query')
def test_paginate_query_default_sorts_by_created_at_and_id(
self, mock_paginate_query):
query = mock.Mock()
model = mock.Mock()
db_api._paginate_query(self.ctx, query, model, sort_keys=None)
args, _ = mock_paginate_query.call_args
self.assertIn(['created_time', 'id'], args)
@mock.patch.object(db_api.utils, 'paginate_query')
def test_paginate_query_default_sorts_dir_by_desc(self,
mock_paginate_query):
query = mock.Mock()
model = mock.Mock()
db_api._paginate_query(self.ctx, query, model, sort_dir=None)
args, _ = mock_paginate_query.call_args
self.assertIn('desc', args)
@mock.patch.object(db_api.utils, 'paginate_query')
def test_paginate_query_uses_given_sort_plus_id(self,
mock_paginate_query):
query = mock.Mock()
model = mock.Mock()
db_api._paginate_query(self.ctx, query, model, sort_keys=['name'])
args, _ = mock_paginate_query.call_args
self.assertIn(['name', 'id'], args)
@mock.patch.object(db_api, '_paginate_query')
def test_filter_and_page_query_paginates_query(self, mock_paginate_query):
query = mock.Mock()
db_api._filter_and_page_query(self.ctx, query)
assert mock_paginate_query.called
@mock.patch.object(db_api, '_events_paginate_query')
def test_events_filter_and_page_query(self, mock_events_paginate_query):
query = mock.Mock()
db_api._events_filter_and_page_query(self.ctx, query)
assert mock_events_paginate_query.called
@mock.patch.object(db_api.db_filters, 'exact_filter')
def test_filter_and_page_query_handles_no_filters(self, mock_db_filter):
query = mock.Mock()
db_api._filter_and_page_query(self.ctx, query)
mock_db_filter.assert_called_once_with(mock.ANY, mock.ANY, {})
@mock.patch.object(db_api.db_filters, 'exact_filter')
def test_events_filter_and_page_query_handles_no_filters(self,
mock_db_filter):
query = mock.Mock()
db_api._events_filter_and_page_query(self.ctx, query)
mock_db_filter.assert_called_once_with(mock.ANY, mock.ANY, {})
@mock.patch.object(db_api.db_filters, 'exact_filter')
def test_filter_and_page_query_applies_filters(self, mock_db_filter):
query = mock.Mock()
filters = {'foo': 'bar'}
db_api._filter_and_page_query(self.ctx, query, filters=filters)
assert mock_db_filter.called
@mock.patch.object(db_api.db_filters, 'exact_filter')
def test_events_filter_and_page_query_applies_filters(self,
mock_db_filter):
query = mock.Mock()
filters = {'foo': 'bar'}
db_api._events_filter_and_page_query(self.ctx, query, filters=filters)
assert mock_db_filter.called
@mock.patch.object(db_api, '_paginate_query')
def test_filter_and_page_query_whitelists_sort_keys(self,
mock_paginate_query):
query = mock.Mock()
sort_keys = ['name', 'foo']
db_api._filter_and_page_query(self.ctx, query, sort_keys=sort_keys)
args, _ = mock_paginate_query.call_args
self.assertIn(['name'], args)
def test_nested_cluster_get_by_name(self):
cluster1 = shared.create_cluster(self.ctx, self.profile,
name='cluster1')
cluster2 = shared.create_cluster(self.ctx, self.profile,
name='cluster2',
parent=cluster1.id)
result = db_api.cluster_get_by_name(self.ctx, 'cluster2')
self.assertEqual(cluster2.id, result.id)
db_api.cluster_delete(self.ctx, cluster2.id)
result = db_api.cluster_get_by_name(self.ctx, 'cluster2')
self.assertIsNone(result)
def test_cluster_get_by_name_and_parent(self):
cluster1 = shared.create_cluster(self.ctx, self.profile,
name='cluster1')
cluster2 = shared.create_cluster(self.ctx, self.profile,
name='cluster2',
parent=cluster1.id)
result = db_api.cluster_get_by_name_and_parent(self.ctx, 'cluster2',
None)
self.assertIsNone(result)
result = db_api.cluster_get_by_name_and_parent(self.ctx, 'cluster2',
cluster1.id)
self.assertEqual(cluster2.id, result.id)
def test_cluster_get_show_deleted_context(self):
cluster = shared.create_cluster(self.ctx, self.profile)
self.assertFalse(self.ctx.show_deleted)
result = db_api.cluster_get(self.ctx, cluster.id)
self.assertEqual(cluster.id, result.id)
db_api.cluster_delete(self.ctx, cluster.id)
result = db_api.cluster_get(self.ctx, cluster.id)
self.assertIsNone(result)
self.ctx.show_deleted = True
result = db_api.cluster_get(self.ctx, cluster.id)
self.assertEqual(cluster.id, result.id)
def test_cluster_get_all_show_deleted(self):
clusters = [shared.create_cluster(self.ctx, self.profile)
for x in range(3)]
results = db_api.cluster_get_all(self.ctx)
self.assertEqual(3, len(results))
db_api.cluster_delete(self.ctx, clusters[0].id)
results = db_api.cluster_get_all(self.ctx)
self.assertEqual(2, len(results))
results = db_api.cluster_get_all(self.ctx, show_deleted=True)
self.assertEqual(3, len(results))
def test_cluster_get_all_show_nested(self):
cluster1 = shared.create_cluster(self.ctx, self.profile,
name='cluster1')
cluster2 = shared.create_cluster(self.ctx, self.profile,
name='cluster2',
parent=cluster1.id)
cl_db = db_api.cluster_get_all(self.ctx)
self.assertEqual(1, len(cl_db))
self.assertEqual(cluster1.id, cl_db[0].id)
cl_db = db_api.cluster_get_all(self.ctx, show_nested=True)
self.assertEqual(2, len(cl_db))
cl_ids = [s.id for s in cl_db]
self.assertIn(cluster1.id, cl_ids)
self.assertIn(cluster2.id, cl_ids)
def test_cluster_get_all_with_filters(self):
shared.create_cluster(self.ctx, self.profile, name='foo')
shared.create_cluster(self.ctx, self.profile, name='bar')
filters = {'name': ['bar', 'quux']}
results = db_api.cluster_get_all(self.ctx, filters=filters)
self.assertEqual(1, len(results))
self.assertEqual('bar', results[0]['name'])
filters = {'name': 'foo'}
results = db_api.cluster_get_all(self.ctx, filters=filters)
self.assertEqual(1, len(results))
self.assertEqual('foo', results[0]['name'])
def test_cluster_get_all_returns_all_if_no_filters(self):
shared.create_cluster(self.ctx, self.profile)
shared.create_cluster(self.ctx, self.profile)
filters = None
results = db_api.cluster_get_all(self.ctx, filters=filters)
self.assertEqual(2, len(results))
def test_cluster_get_all_default_sort_dir(self):
dt = datetime.datetime
clusters = [shared.create_cluster(self.ctx, self.profile,
created_time=dt.utcnow())
for x in range(3)]
st_db = db_api.cluster_get_all(self.ctx, sort_dir='asc')
self.assertEqual(3, len(st_db))
self.assertEqual(clusters[0].id, st_db[0].id)
self.assertEqual(clusters[1].id, st_db[1].id)
self.assertEqual(clusters[2].id, st_db[2].id)
def test_cluster_get_all_str_sort_keys(self):
dt = datetime.datetime
clusters = [shared.create_cluster(self.ctx, self.profile,
created_time=dt.utcnow())
for x in range(3)]
st_db = db_api.cluster_get_all(self.ctx, sort_keys='created_time')
self.assertEqual(3, len(st_db))
self.assertEqual(clusters[0].id, st_db[0].id)
self.assertEqual(clusters[1].id, st_db[1].id)
self.assertEqual(clusters[2].id, st_db[2].id)
@mock.patch.object(db_api.utils, 'paginate_query')
def test_cluster_get_all_filters_sort_keys(self, mock_paginate):
sort_keys = ['name', 'status', 'created_time',
'updated_time', 'parent']
db_api.cluster_get_all(self.ctx, sort_keys=sort_keys)
args = mock_paginate.call_args[0]
used_sort_keys = set(args[3])
expected_keys = set(['name', 'status', 'created_time',
'updated_time', 'id'])
self.assertEqual(expected_keys, used_sort_keys)
def test_cluster_get_all_marker(self):
dt = datetime.datetime
clusters = [shared.create_cluster(self.ctx, self.profile,
created_time=dt.utcnow())
for x in range(3)]
cl_db = db_api.cluster_get_all(self.ctx, marker=clusters[1].id)
self.assertEqual(1, len(cl_db))
self.assertEqual(clusters[0].id, cl_db[0].id)
def test_cluster_get_all_non_existing_marker(self):
[shared.create_cluster(self.ctx, self.profile) for x in range(3)]
uuid = 'this cluster doesnt exist'
st_db = db_api.cluster_get_all(self.ctx, marker=uuid)
self.assertEqual(3, len(st_db))
def test_cluster_get_all_doesnt_mutate_sort_keys(self):
[shared.create_cluster(self.ctx, self.profile) for x in range(3)]
sort_keys = ['id']
db_api.cluster_get_all(self.ctx, sort_keys=sort_keys)
self.assertEqual(['id'], sort_keys)
def test_cluster_count_all(self):
clusters = [shared.create_cluster(self.ctx, self.profile)
for i in range(3)]
cl_db = db_api.cluster_count_all(self.ctx)
self.assertEqual(3, cl_db)
db_api.cluster_delete(self.ctx, clusters[0].id)
cl_db = db_api.cluster_count_all(self.ctx)
self.assertEqual(2, cl_db)
# show deleted
cl_db = db_api.cluster_count_all(self.ctx, show_deleted=True)
self.assertEqual(3, cl_db)
db_api.cluster_delete(self.ctx, clusters[1].id)
cl_db = db_api.cluster_count_all(self.ctx)
self.assertEqual(1, cl_db)
# show deleted
cl_db = db_api.cluster_count_all(self.ctx, show_deleted=True)
self.assertEqual(3, cl_db)
def test_cluster_count_all_with_regular_tenant(self): def test_cluster_count_all_with_regular_tenant(self):
values = [ values = [
{'tenant_id': UUID1}, {'tenant_id': UUID1},
@ -218,6 +546,25 @@ class DBAPIClusterTest(base.SenlinTestCase):
self.assertEqual(5, db_api.cluster_count_all(self.ctx, self.assertEqual(5, db_api.cluster_count_all(self.ctx,
tenant_safe=False)) tenant_safe=False))
def test_cluster_count_all_show_nested(self):
cluster1 = shared.create_cluster(self.ctx, self.profile, name='c1')
shared.create_cluster(self.ctx, self.profile, name='c2',
parent=cluster1.id)
results = db_api.cluster_count_all(self.ctx)
self.assertEqual(1, results)
results = db_api.cluster_count_all(self.ctx, show_nested=True)
self.assertEqual(2, results)
def test_cluster_count_all_with_filters(self):
shared.create_cluster(self.ctx, self.profile, name='foo')
shared.create_cluster(self.ctx, self.profile, name='bar')
shared.create_cluster(self.ctx, self.profile, name='bar')
filters = {'name': 'bar'}
cl_db = db_api.cluster_count_all(self.ctx, filters=filters)
self.assertEqual(2, cl_db)
def _deleted_cluster_existance(self, ctx, clusters, existing, deleted): def _deleted_cluster_existance(self, ctx, clusters, existing, deleted):
for s in existing: for s in existing:
self.assertIsNotNone(db_api.cluster_get(ctx, clusters[s].id, self.assertIsNotNone(db_api.cluster_get(ctx, clusters[s].id,