Compute: use instance object for refresh_instance_security_rules

Make sure that the instance object is passed in the RPC message and not
an instance dictionary.

blueprint liberty-objects

Change-Id: I4ac05960b4900d11917739e2f25fa9926463ed95
This commit is contained in:
Gary Kotton 2015-06-02 04:24:41 -07:00 committed by John Garbutt
parent 422809c289
commit 12fbe6f082
5 changed files with 49 additions and 56 deletions

View File

@ -4172,14 +4172,11 @@ class SecurityGroupAPI(base.Base, security_group_base.SecurityGroupBase):
def trigger_rules_refresh(self, context, id):
"""Called when a rule is added to or removed from a security_group."""
security_group = self.db.security_group_get(
context, id, columns_to_join=['instances'])
for instance in security_group['instances']:
if instance['host'] is not None:
instances = objects.InstanceList.get_by_security_group_id(context, id)
for instance in instances:
if instance.host is not None:
self.compute_rpcapi.refresh_instance_security_rules(
context, instance['host'], instance)
context, instance.host, instance)
def trigger_members_refresh(self, context, group_ids):
"""Called when a security group gains a new or loses a member.
@ -4197,25 +4194,19 @@ class SecurityGroupAPI(base.Base, security_group_base.SecurityGroupBase):
group_id))
# ..then we distill the rules into the groups to which they belong..
security_groups = set()
for rule in security_group_rules:
security_group = self.db.security_group_get(
context, rule['parent_group_id'],
columns_to_join=['instances'])
security_groups.add(security_group)
# ..then we find the instances that are members of these groups..
instances = {}
for security_group in security_groups:
for instance in security_group['instances']:
if instance['uuid'] not in instances:
instances[instance['uuid']] = instance
for rule in security_group_rules:
sg_instances = objects.InstanceList.get_by_security_group_id(
context, rule['parent_group_id'])
for instance in sg_instances:
if instance.uuid not in instances:
instances[instance.uuid] = instance
# ..then we send a request to refresh the rules for each instance.
for instance in instances.values():
if instance['host']:
if instance.host:
self.compute_rpcapi.refresh_instance_security_rules(
context, instance['host'], instance)
context, instance.host, instance)
def get_instance_security_groups(self, context, instance_uuid,
detailed=False):

View File

@ -649,7 +649,7 @@ class ComputeVirtAPI(virtapi.VirtAPI):
class ComputeManager(manager.Manager):
"""Manages the running instances from creation to destruction."""
target = messaging.Target(version='4.3')
target = messaging.Target(version='4.4')
# How long to wait in seconds before re-issuing a shutdown
# signal to a instance during power off. The overall

View File

@ -300,6 +300,7 @@ class ComputeAPI(object):
* 4.1 - Make prep_resize() and resize_instance() send Flavor object
* 4.2 - Add migration argument to live_migration()
* 4.3 - Added get_mks_console method
* 4.4 - Make refresh_instance_security_rules send an instance object
'''
VERSION_ALIASES = {
@ -925,10 +926,11 @@ class ComputeAPI(object):
security_group_id=security_group_id)
def refresh_instance_security_rules(self, ctxt, host, instance):
version = '4.0'
# TODO(danms): This needs to be fixed for objects!
instance_p = jsonutils.to_primitive(instance)
version = '4.4'
if not self.client.can_send_version(version):
version = '4.0'
instance = objects_base.obj_to_primitive(instance)
cctxt = self.client.prepare(server=_compute_host(None, instance),
version=version)
cctxt.cast(ctxt, 'refresh_instance_security_rules',
instance=instance_p)
instance=instance)

View File

@ -9814,15 +9814,17 @@ class ComputeAPITestCase(BaseTestCase):
mock_rule = db_fakes.FakeModel({'parent_group_id': 1})
return [mock_rule]
def group_get(*args, **kwargs):
mock_group = db_fakes.FakeModel({'instances': [instance]})
return mock_group
@staticmethod
def get_by_security_group_id(context, security_group_id):
return [instance]
self.stubs.Set(
self.compute_api.db,
'security_group_rule_get_by_security_group_grantee',
rule_get)
self.stubs.Set(self.compute_api.db, 'security_group_get', group_get)
self.stubs.Set(objects.InstanceList, 'get_by_security_group_id',
get_by_security_group_id)
rpcapi = compute_rpcapi.ComputeAPI
self.mox.StubOutWithMock(rpcapi, 'refresh_instance_security_rules')
@ -9836,24 +9838,25 @@ class ComputeAPITestCase(BaseTestCase):
def test_secgroup_refresh_once(self):
instance = self._create_fake_instance_obj()
@staticmethod
def get_by_security_group_id(context, security_group_id):
return [instance]
def rule_get(*args, **kwargs):
mock_rule = db_fakes.FakeModel({'parent_group_id': 1})
return [mock_rule]
def group_get(*args, **kwargs):
mock_group = db_fakes.FakeModel({'instances': [instance]})
return mock_group
self.stubs.Set(
self.compute_api.db,
'security_group_rule_get_by_security_group_grantee',
rule_get)
self.stubs.Set(self.compute_api.db, 'security_group_get', group_get)
self.stubs.Set(objects.InstanceList, 'get_by_security_group_id',
get_by_security_group_id)
rpcapi = compute_rpcapi.ComputeAPI
self.mox.StubOutWithMock(rpcapi, 'refresh_instance_security_rules')
rpcapi.refresh_instance_security_rules(self.context,
instance['host'],
instance.host,
instance)
self.mox.ReplayAll()
@ -9884,12 +9887,11 @@ class ComputeAPITestCase(BaseTestCase):
def test_secrule_refresh(self):
instance = self._create_fake_instance_obj()
def group_get(*args, **kwargs):
mock_group = db_fakes.FakeModel({'instances': [instance]})
return mock_group
self.stubs.Set(self.compute_api.db, 'security_group_get', group_get)
@staticmethod
def get_by_security_group_id(context, security_group_id):
return [instance]
self.stubs.Set(objects.InstanceList, 'get_by_security_group_id',
get_by_security_group_id)
rpcapi = compute_rpcapi.ComputeAPI
self.mox.StubOutWithMock(rpcapi, 'refresh_instance_security_rules')
rpcapi.refresh_instance_security_rules(self.context,
@ -9902,12 +9904,11 @@ class ComputeAPITestCase(BaseTestCase):
def test_secrule_refresh_once(self):
instance = self._create_fake_instance_obj()
def group_get(*args, **kwargs):
mock_group = db_fakes.FakeModel({'instances': [instance]})
return mock_group
self.stubs.Set(self.compute_api.db, 'security_group_get', group_get)
@staticmethod
def get_by_security_group_id(context, security_group_id):
return [instance]
self.stubs.Set(objects.InstanceList, 'get_by_security_group_id',
get_by_security_group_id)
rpcapi = compute_rpcapi.ComputeAPI
self.mox.StubOutWithMock(rpcapi, 'refresh_instance_security_rules')
rpcapi.refresh_instance_security_rules(self.context,
@ -9918,12 +9919,11 @@ class ComputeAPITestCase(BaseTestCase):
self.security_group_api.trigger_rules_refresh(self.context, [1, 2])
def test_secrule_refresh_none(self):
def group_get(*args, **kwargs):
mock_group = db_fakes.FakeModel({'instances': []})
return mock_group
self.stubs.Set(self.compute_api.db, 'security_group_get', group_get)
@staticmethod
def get_by_security_group_id(context, security_group_id):
return []
self.stubs.Set(objects.InstanceList, 'get_by_security_group_id',
get_by_security_group_id)
rpcapi = compute_rpcapi.ComputeAPI
self.mox.StubOutWithMock(rpcapi, 'refresh_instance_security_rules')
self.mox.ReplayAll()

View File

@ -327,10 +327,10 @@ class ComputeRpcAPITestCase(test.NoDBTestCase):
security_group_id='id', host='host', version='4.0')
def test_refresh_instance_security_rules(self):
expected_args = {'instance': self.fake_instance}
expected_args = {'instance': self.fake_instance_obj}
self._test_compute_api('refresh_instance_security_rules', 'cast',
expected_args, host='fake_host',
instance=self.fake_instance_obj, version='4.0')
instance=self.fake_instance_obj, version='4.4')
def test_remove_aggregate_host(self):
self._test_compute_api('remove_aggregate_host', 'cast',