Improve node_ids_by_cluster() by adding filters

This adds a 'filters' parameter to the node_ids_by_cluster() DB API and
also exposes it through the versioned Node object. It will be used by
upper layers to save DB access overhead.

Change-Id: I412f8e5ff0e6e9ef5e7380423caeafa9d67ff8eb
This commit is contained in:
tengqm 2017-03-01 01:55:18 -05:00
parent 5740532186
commit d661f890c9
4 changed files with 31 additions and 7 deletions

View File

@ -104,8 +104,8 @@ def node_get_all_by_cluster(context, cluster_id, filters=None,
project_safe=project_safe)
def node_ids_by_cluster(context, cluster_id):
return IMPL.node_ids_by_cluster(context, cluster_id)
def node_ids_by_cluster(context, cluster_id, filters=None):
return IMPL.node_ids_by_cluster(context, cluster_id, filters=None)
def node_count_by_cluster(context, cluster_id, **kwargs):

View File

@ -283,12 +283,14 @@ def node_get_all_by_cluster(context, cluster_id, filters=None,
return query.all()
def node_ids_by_cluster(context, cluster_id):
def node_ids_by_cluster(context, cluster_id, filters=None):
"""an internal API for getting node IDs."""
with session_for_read() as session:
nodes = session.query(models.Node.id).filter_by(
cluster_id=cluster_id).all()
return [n[0] for n in nodes]
query = session.query(models.Node.id).filter_by(cluster_id=cluster_id)
if filters:
query = utils.exact_filter(query, models.Node, filters)
return [n[0] for n in query.all()]
def node_count_by_cluster(context, cluster_id, **kwargs):

View File

@ -136,6 +136,11 @@ class Node(base.SenlinObject, base.VersionedObjectDictCompat):
context, cluster_id, filters=filters, project_safe=project_safe)
return [cls._from_db_object(context, cls(), obj) for obj in objs]
@classmethod
def ids_by_cluster(cls, context, cluster_id, filters=None):
"""An internal API for retrieving node ids only."""
return db_api.node_ids_by_cluster(context, cluster_id, filters=filters)
@classmethod
def count_by_cluster(cls, context, cluster_id, **kwargs):
return db_api.node_count_by_cluster(context, cluster_id, **kwargs)

View File

@ -498,7 +498,7 @@ class DBAPINodeTest(base.SenlinTestCase):
project_safe=False)
self.assertEqual(2, res)
def test_nodeids_by_cluster(self):
def test_ids_by_cluster(self):
node0 = shared.create_node(self.ctx, None, self.profile)
node1 = shared.create_node(self.ctx, self.cluster, self.profile)
node2 = shared.create_node(self.ctx, self.cluster, self.profile)
@ -512,6 +512,23 @@ class DBAPINodeTest(base.SenlinTestCase):
self.assertEqual(1, len(results))
self.assertEqual(node0.id, results[0])
def test_ids_by_cluster_with_filters(self):
node0 = shared.create_node(self.ctx, None, self.profile,
role='slave')
node1 = shared.create_node(self.ctx, self.cluster, self.profile,
role='master')
shared.create_node(self.ctx, self.cluster, self.profile)
results = db_api.node_ids_by_cluster(self.ctx, self.cluster.id,
filters={'role': 'master'})
self.assertEqual(1, len(results))
self.assertEqual(node1.id, results[0])
# retrieve orphan nodes
results = db_api.node_ids_by_cluster(self.ctx, '')
self.assertEqual(1, len(results))
self.assertEqual(node0.id, results[0])
def test_node_update(self):
node = shared.create_node(self.ctx, self.cluster, self.profile)
new_attributes = {