diff --git a/senlin/db/sqlalchemy/api.py b/senlin/db/sqlalchemy/api.py index a004623b6..9ffd40eb5 100644 --- a/senlin/db/sqlalchemy/api.py +++ b/senlin/db/sqlalchemy/api.py @@ -141,7 +141,7 @@ def cluster_get_all_by_parent(context, parent): def cluster_get_by_name_and_parent(context, cluster_name, parent): query = soft_delete_aware_query(context, models.Cluster).\ - filter_by(tenant=context.tenant_id).\ + filter_by(project=context.tenant_id).\ filter_by(name=cluster_name).\ filter_by(parent=parent) return query.first() @@ -155,13 +155,11 @@ def cluster_get_by_name(context, cluster_name): def _query_cluster_get_all(context, tenant_safe=True, show_deleted=False, show_nested=False): - q0 = soft_delete_aware_query(context, models.Cluster, - show_deleted=show_deleted) + query = soft_delete_aware_query(context, models.Cluster, + show_deleted=show_deleted) - if show_nested: - query = q0.filter_by(backup=False) - else: - query = q0.filter_by(parent=None) + if not show_nested: + query = query.filter_by(parent=None) if tenant_safe: query = query.filter_by(project=context.tenant_id) @@ -285,10 +283,6 @@ def node_get_all(context): def node_get_all_by_cluster(context, cluster_id): query = model_query(context, models.Node).filter_by(cluster_id=cluster_id) nodes = query.all() - if not nodes: - msg = _("No nodes for cluster %s were found") % cluster_id - raise exception.NotFound(msg) - return dict((node.name, node) for node in nodes) diff --git a/senlin/tests/db/test_cluster_api.py b/senlin/tests/db/test_cluster_api.py index abe69b448..6d5e6dbd7 100644 --- a/senlin/tests/db/test_cluster_api.py +++ b/senlin/tests/db/test_cluster_api.py @@ -10,6 +10,9 @@ # License for the specific language governing permissions and limitations # under the License. +import datetime +import mock + from senlin.common import exception from senlin.db.sqlalchemy import api as db_api from senlin.tests.common import base @@ -111,7 +114,7 @@ class DBAPIClusterTest(base.SenlinTestCase): self.assertEqual(cluster.id, ret_cluster.id) self.assertEqual('db_test_cluster_name', ret_cluster.name) - def test_cluster_get_by_name(self): + def test_cluster_get_by_name_default(self): cluster = shared.create_cluster(self.ctx, self.profile) ret_cluster = db_api.cluster_get_by_name(self.ctx, cluster.name) self.assertIsNotNone(ret_cluster) @@ -123,6 +126,23 @@ class DBAPIClusterTest(base.SenlinTestCase): self.ctx.tenant_id = 'abc' self.assertIsNone(db_api.cluster_get_by_name(self.ctx, 'abc')) + def test_cluster_get_by_name(self): + cluster = shared.create_cluster(self.ctx, self.profile, + name='cluster', project=UUID2) + + res = db_api.cluster_get_by_name(self.ctx, 'cluster') + self.assertIsNone(res) + + self.ctx.tenant_id = UUID3 + self.assertIsNone(db_api.cluster_get_by_name(self.ctx, 'cluster')) + + self.ctx.tenant_id = UUID2 + res = db_api.cluster_get_by_name(self.ctx, 'cluster') + self.assertEqual(cluster.id, res.id) + + db_api.cluster_delete(self.ctx, cluster.id) + self.assertIsNone(db_api.cluster_get_by_name(self.ctx, 'cluster')) + def test_cluster_get_all(self): values = [ {'name': 'cluster1'}, @@ -189,6 +209,314 @@ class DBAPIClusterTest(base.SenlinTestCase): clusters = db_api.cluster_get_all(self.ctx, tenant_safe=False) self.assertEqual(5, len(clusters)) + def test_get_sort_keys_returns_empty_list_if_no_keys(self): + sort_keys = None + mapping = {} + + filtered_keys = db_api._get_sort_keys(sort_keys, mapping) + self.assertEqual([], filtered_keys) + + def test_get_sort_keys_whitelists_single_key(self): + sort_key = 'foo' + mapping = {'foo': 'Foo'} + + filtered_keys = db_api._get_sort_keys(sort_key, mapping) + self.assertEqual(['Foo'], filtered_keys) + + def test_get_sort_keys_whitelists_multiple_keys(self): + sort_keys = ['foo', 'bar', 'nope'] + mapping = {'foo': 'Foo', 'bar': 'Bar'} + + filtered_keys = db_api._get_sort_keys(sort_keys, mapping) + self.assertIn('Foo', filtered_keys) + self.assertIn('Bar', filtered_keys) + self.assertEqual(2, len(filtered_keys)) + + @mock.patch.object(db_api.utils, 'paginate_query') + def test_paginate_query_raises_invalid_sort_key(self, mock_paginate_query): + query = mock.Mock() + model = mock.Mock() + + mock_paginate_query.side_effect = db_api.utils.InvalidSortKey() + self.assertRaises(exception.Invalid, db_api._paginate_query, + self.ctx, query, model, sort_keys=['foo']) + + @mock.patch.object(db_api.utils, 'paginate_query') + @mock.patch.object(db_api, 'model_query') + def test_paginate_query_gets_model_marker(self, mock_query, + mock_paginate_query): + query = mock.Mock() + model = mock.Mock() + marker = mock.Mock() + + mock_query_object = mock.Mock() + mock_query_object.get.return_value = 'real_marker' + mock_query.return_value = mock_query_object + + db_api._paginate_query(self.ctx, query, model, marker=marker) + mock_query_object.get.assert_called_once_with(marker) + args, _ = mock_paginate_query.call_args + self.assertIn('real_marker', args) + + @mock.patch.object(db_api.utils, 'paginate_query') + def test_paginate_query_default_sorts_by_created_at_and_id( + self, mock_paginate_query): + query = mock.Mock() + model = mock.Mock() + db_api._paginate_query(self.ctx, query, model, sort_keys=None) + args, _ = mock_paginate_query.call_args + self.assertIn(['created_time', 'id'], args) + + @mock.patch.object(db_api.utils, 'paginate_query') + def test_paginate_query_default_sorts_dir_by_desc(self, + mock_paginate_query): + query = mock.Mock() + model = mock.Mock() + db_api._paginate_query(self.ctx, query, model, sort_dir=None) + args, _ = mock_paginate_query.call_args + self.assertIn('desc', args) + + @mock.patch.object(db_api.utils, 'paginate_query') + def test_paginate_query_uses_given_sort_plus_id(self, + mock_paginate_query): + query = mock.Mock() + model = mock.Mock() + db_api._paginate_query(self.ctx, query, model, sort_keys=['name']) + args, _ = mock_paginate_query.call_args + self.assertIn(['name', 'id'], args) + + @mock.patch.object(db_api, '_paginate_query') + def test_filter_and_page_query_paginates_query(self, mock_paginate_query): + query = mock.Mock() + db_api._filter_and_page_query(self.ctx, query) + + assert mock_paginate_query.called + + @mock.patch.object(db_api, '_events_paginate_query') + def test_events_filter_and_page_query(self, mock_events_paginate_query): + query = mock.Mock() + db_api._events_filter_and_page_query(self.ctx, query) + + assert mock_events_paginate_query.called + + @mock.patch.object(db_api.db_filters, 'exact_filter') + def test_filter_and_page_query_handles_no_filters(self, mock_db_filter): + query = mock.Mock() + db_api._filter_and_page_query(self.ctx, query) + + mock_db_filter.assert_called_once_with(mock.ANY, mock.ANY, {}) + + @mock.patch.object(db_api.db_filters, 'exact_filter') + def test_events_filter_and_page_query_handles_no_filters(self, + mock_db_filter): + query = mock.Mock() + db_api._events_filter_and_page_query(self.ctx, query) + + mock_db_filter.assert_called_once_with(mock.ANY, mock.ANY, {}) + + @mock.patch.object(db_api.db_filters, 'exact_filter') + def test_filter_and_page_query_applies_filters(self, mock_db_filter): + query = mock.Mock() + filters = {'foo': 'bar'} + db_api._filter_and_page_query(self.ctx, query, filters=filters) + + assert mock_db_filter.called + + @mock.patch.object(db_api.db_filters, 'exact_filter') + def test_events_filter_and_page_query_applies_filters(self, + mock_db_filter): + query = mock.Mock() + filters = {'foo': 'bar'} + db_api._events_filter_and_page_query(self.ctx, query, filters=filters) + + assert mock_db_filter.called + + @mock.patch.object(db_api, '_paginate_query') + def test_filter_and_page_query_whitelists_sort_keys(self, + mock_paginate_query): + query = mock.Mock() + sort_keys = ['name', 'foo'] + db_api._filter_and_page_query(self.ctx, query, sort_keys=sort_keys) + + args, _ = mock_paginate_query.call_args + self.assertIn(['name'], args) + + def test_nested_cluster_get_by_name(self): + cluster1 = shared.create_cluster(self.ctx, self.profile, + name='cluster1') + cluster2 = shared.create_cluster(self.ctx, self.profile, + name='cluster2', + parent=cluster1.id) + + result = db_api.cluster_get_by_name(self.ctx, 'cluster2') + self.assertEqual(cluster2.id, result.id) + + db_api.cluster_delete(self.ctx, cluster2.id) + result = db_api.cluster_get_by_name(self.ctx, 'cluster2') + self.assertIsNone(result) + + def test_cluster_get_by_name_and_parent(self): + cluster1 = shared.create_cluster(self.ctx, self.profile, + name='cluster1') + cluster2 = shared.create_cluster(self.ctx, self.profile, + name='cluster2', + parent=cluster1.id) + + result = db_api.cluster_get_by_name_and_parent(self.ctx, 'cluster2', + None) + self.assertIsNone(result) + + result = db_api.cluster_get_by_name_and_parent(self.ctx, 'cluster2', + cluster1.id) + self.assertEqual(cluster2.id, result.id) + + def test_cluster_get_show_deleted_context(self): + cluster = shared.create_cluster(self.ctx, self.profile) + + self.assertFalse(self.ctx.show_deleted) + result = db_api.cluster_get(self.ctx, cluster.id) + self.assertEqual(cluster.id, result.id) + + db_api.cluster_delete(self.ctx, cluster.id) + result = db_api.cluster_get(self.ctx, cluster.id) + self.assertIsNone(result) + + self.ctx.show_deleted = True + result = db_api.cluster_get(self.ctx, cluster.id) + self.assertEqual(cluster.id, result.id) + + def test_cluster_get_all_show_deleted(self): + clusters = [shared.create_cluster(self.ctx, self.profile) + for x in range(3)] + + results = db_api.cluster_get_all(self.ctx) + self.assertEqual(3, len(results)) + + db_api.cluster_delete(self.ctx, clusters[0].id) + results = db_api.cluster_get_all(self.ctx) + self.assertEqual(2, len(results)) + + results = db_api.cluster_get_all(self.ctx, show_deleted=True) + self.assertEqual(3, len(results)) + + def test_cluster_get_all_show_nested(self): + cluster1 = shared.create_cluster(self.ctx, self.profile, + name='cluster1') + cluster2 = shared.create_cluster(self.ctx, self.profile, + name='cluster2', + parent=cluster1.id) + + cl_db = db_api.cluster_get_all(self.ctx) + self.assertEqual(1, len(cl_db)) + self.assertEqual(cluster1.id, cl_db[0].id) + + cl_db = db_api.cluster_get_all(self.ctx, show_nested=True) + self.assertEqual(2, len(cl_db)) + cl_ids = [s.id for s in cl_db] + self.assertIn(cluster1.id, cl_ids) + self.assertIn(cluster2.id, cl_ids) + + def test_cluster_get_all_with_filters(self): + shared.create_cluster(self.ctx, self.profile, name='foo') + shared.create_cluster(self.ctx, self.profile, name='bar') + + filters = {'name': ['bar', 'quux']} + results = db_api.cluster_get_all(self.ctx, filters=filters) + self.assertEqual(1, len(results)) + self.assertEqual('bar', results[0]['name']) + + filters = {'name': 'foo'} + results = db_api.cluster_get_all(self.ctx, filters=filters) + self.assertEqual(1, len(results)) + self.assertEqual('foo', results[0]['name']) + + def test_cluster_get_all_returns_all_if_no_filters(self): + shared.create_cluster(self.ctx, self.profile) + shared.create_cluster(self.ctx, self.profile) + + filters = None + results = db_api.cluster_get_all(self.ctx, filters=filters) + + self.assertEqual(2, len(results)) + + def test_cluster_get_all_default_sort_dir(self): + dt = datetime.datetime + clusters = [shared.create_cluster(self.ctx, self.profile, + created_time=dt.utcnow()) + for x in range(3)] + + st_db = db_api.cluster_get_all(self.ctx, sort_dir='asc') + self.assertEqual(3, len(st_db)) + self.assertEqual(clusters[0].id, st_db[0].id) + self.assertEqual(clusters[1].id, st_db[1].id) + self.assertEqual(clusters[2].id, st_db[2].id) + + def test_cluster_get_all_str_sort_keys(self): + dt = datetime.datetime + clusters = [shared.create_cluster(self.ctx, self.profile, + created_time=dt.utcnow()) + for x in range(3)] + + st_db = db_api.cluster_get_all(self.ctx, sort_keys='created_time') + self.assertEqual(3, len(st_db)) + self.assertEqual(clusters[0].id, st_db[0].id) + self.assertEqual(clusters[1].id, st_db[1].id) + self.assertEqual(clusters[2].id, st_db[2].id) + + @mock.patch.object(db_api.utils, 'paginate_query') + def test_cluster_get_all_filters_sort_keys(self, mock_paginate): + sort_keys = ['name', 'status', 'created_time', + 'updated_time', 'parent'] + db_api.cluster_get_all(self.ctx, sort_keys=sort_keys) + + args = mock_paginate.call_args[0] + used_sort_keys = set(args[3]) + expected_keys = set(['name', 'status', 'created_time', + 'updated_time', 'id']) + self.assertEqual(expected_keys, used_sort_keys) + + def test_cluster_get_all_marker(self): + dt = datetime.datetime + clusters = [shared.create_cluster(self.ctx, self.profile, + created_time=dt.utcnow()) + for x in range(3)] + cl_db = db_api.cluster_get_all(self.ctx, marker=clusters[1].id) + self.assertEqual(1, len(cl_db)) + self.assertEqual(clusters[0].id, cl_db[0].id) + + def test_cluster_get_all_non_existing_marker(self): + [shared.create_cluster(self.ctx, self.profile) for x in range(3)] + uuid = 'this cluster doesnt exist' + st_db = db_api.cluster_get_all(self.ctx, marker=uuid) + self.assertEqual(3, len(st_db)) + + def test_cluster_get_all_doesnt_mutate_sort_keys(self): + [shared.create_cluster(self.ctx, self.profile) for x in range(3)] + sort_keys = ['id'] + db_api.cluster_get_all(self.ctx, sort_keys=sort_keys) + self.assertEqual(['id'], sort_keys) + + def test_cluster_count_all(self): + clusters = [shared.create_cluster(self.ctx, self.profile) + for i in range(3)] + + cl_db = db_api.cluster_count_all(self.ctx) + self.assertEqual(3, cl_db) + + db_api.cluster_delete(self.ctx, clusters[0].id) + cl_db = db_api.cluster_count_all(self.ctx) + self.assertEqual(2, cl_db) + # show deleted + cl_db = db_api.cluster_count_all(self.ctx, show_deleted=True) + self.assertEqual(3, cl_db) + + db_api.cluster_delete(self.ctx, clusters[1].id) + cl_db = db_api.cluster_count_all(self.ctx) + self.assertEqual(1, cl_db) + # show deleted + cl_db = db_api.cluster_count_all(self.ctx, show_deleted=True) + self.assertEqual(3, cl_db) + def test_cluster_count_all_with_regular_tenant(self): values = [ {'tenant_id': UUID1}, @@ -218,6 +546,25 @@ class DBAPIClusterTest(base.SenlinTestCase): self.assertEqual(5, db_api.cluster_count_all(self.ctx, tenant_safe=False)) + def test_cluster_count_all_show_nested(self): + cluster1 = shared.create_cluster(self.ctx, self.profile, name='c1') + shared.create_cluster(self.ctx, self.profile, name='c2', + parent=cluster1.id) + + results = db_api.cluster_count_all(self.ctx) + self.assertEqual(1, results) + results = db_api.cluster_count_all(self.ctx, show_nested=True) + self.assertEqual(2, results) + + def test_cluster_count_all_with_filters(self): + shared.create_cluster(self.ctx, self.profile, name='foo') + shared.create_cluster(self.ctx, self.profile, name='bar') + shared.create_cluster(self.ctx, self.profile, name='bar') + filters = {'name': 'bar'} + + cl_db = db_api.cluster_count_all(self.ctx, filters=filters) + self.assertEqual(2, cl_db) + def _deleted_cluster_existance(self, ctx, clusters, existing, deleted): for s in existing: self.assertIsNotNone(db_api.cluster_get(ctx, clusters[s].id,