Get sec group ids after address group update

This change adds code to retrieve for the agent the security group ids
affected by an update or deletion of an address group.

Also adds event notificatoins to add and remove addresses from address
groups.

Co-authored-by: Hang Yang <hangyang@verizonmedia.com>
Change-Id: I34766b96cb775356664f5e0d48a08a22ac6898e2
This commit is contained in:
Miguel Lavalle 2020-11-19 18:45:52 -06:00 committed by Hang Yang
parent 6db15a004d
commit 92359b6fb9
5 changed files with 61 additions and 29 deletions

View File

@ -248,13 +248,17 @@ class SecurityGroupAgentRpc(object):
def address_group_updated(self, address_group_id): def address_group_updated(self, address_group_id):
LOG.info("Address group updated %r", address_group_id) LOG.info("Address group updated %r", address_group_id)
# TODO(mlavalle) A follow up patch in the address groups implementation security_group_ids = (
# series will add more code here self.plugin_rpc.get_secgroup_ids_for_address_group(
address_group_id))
self.security_groups_rule_updated(security_group_ids)
def address_group_deleted(self, address_group_id): def address_group_deleted(self, address_group_id):
LOG.info("Address group deleted %r", address_group_id) LOG.info("Address group deleted %r", address_group_id)
# TODO(mlavalle) A follow up patch in the address groups implementation security_group_ids = (
# series will add more code here self.plugin_rpc.get_secgroup_ids_for_address_group(
address_group_id))
self.security_groups_rule_updated(security_group_ids)
def remove_devices_filter(self, device_ids): def remove_devices_filter(self, device_ids):
if not device_ids: if not device_ids:

View File

@ -246,13 +246,10 @@ class SecurityGroupServerAPIShim(sg_rpc_base.SecurityGroupInfoAPIMixin):
# error. # error.
raise NotImplementedError() raise NotImplementedError()
def get_address_group_details(self, address_group_id): def get_secgroup_ids_for_address_group(self, address_group_id):
ag_obj = self.rcache.get_resource_by_id(resources.ADDRESSGROUP, filters = {'remote_address_group_id': (address_group_id, )}
address_group_id) return set([rule.security_group_id for rule in
if not ag_obj: self.rcache.get_resources('SecurityGroupRule', filters)])
LOG.debug("Address group %s does not exist in cache.",
address_group_id)
return ag_obj
def _add_child_sg_rules(self, rtype, event, trigger, context, updated, def _add_child_sg_rules(self, rtype, event, trigger, context, updated,
**kwargs): **kwargs):

View File

@ -108,6 +108,12 @@ class AddressGroupDbMixin(ag_ext.AddressGroupPluginBase):
addr_assoc = ag_obj.AddressAssociation(context, **args) addr_assoc = ag_obj.AddressAssociation(context, **args)
addr_assoc.create() addr_assoc.create()
ag.update() # reload synthetic fields ag.update() # reload synthetic fields
# TODO(hangyang) this notification should be updated to publish when
# the callback handler handle_event, class _ObjectChangeHandler in
# neutron.plugins.ml2.ovo_rpc is updated to receive notifications with
# new style payload objects as argument.
registry.notify(ADDRESS_GROUP, events.AFTER_UPDATE, self,
context=context, address_group_id=ag.id)
return {'address_group': self._make_address_group_dict(ag)} return {'address_group': self._make_address_group_dict(ag)}
def remove_addresses(self, context, address_group_id, addresses): def remove_addresses(self, context, address_group_id, addresses):
@ -121,6 +127,12 @@ class AddressGroupDbMixin(ag_ext.AddressGroupPluginBase):
ag_obj.AddressAssociation.delete_objects( ag_obj.AddressAssociation.delete_objects(
context, address_group_id=address_group_id, address=addr) context, address_group_id=address_group_id, address=addr)
ag.update() # reload synthetic fields ag.update() # reload synthetic fields
# TODO(hangyang) this notification should be updated to publish when
# the callback handler handle_event, class _ObjectChangeHandler in
# neutron.plugins.ml2.ovo_rpc is updated to receive notifications with
# new style payload objects as argument.
registry.notify(ADDRESS_GROUP, events.AFTER_UPDATE, self,
context=context, address_group_id=ag.id)
return {'address_group': self._make_address_group_dict(ag)} return {'address_group': self._make_address_group_dict(ag)}
def create_address_group(self, context, address_group): def create_address_group(self, context, address_group):
@ -132,15 +144,15 @@ class AddressGroupDbMixin(ag_ext.AddressGroupPluginBase):
'description': fields['description']} 'description': fields['description']}
ag = ag_obj.AddressGroup(context, **args) ag = ag_obj.AddressGroup(context, **args)
ag.create() ag.create()
if fields.get('addresses') is not constants.ATTR_NOT_SPECIFIED:
self.add_addresses(context, ag.id, fields)
ag.update() # reload synthetic fields
# TODO(mlavalle) this notification should be updated to publish when # TODO(mlavalle) this notification should be updated to publish when
# the callback handler handle_event, class _ObjectChangeHandler in # the callback handler handle_event, class _ObjectChangeHandler in
# neutron.plugins.ml2.ovo_rpc is updated to receive notifications with # neutron.plugins.ml2.ovo_rpc is updated to receive notifications with
# new style payload objects as argument. # new style payload objects as argument.
registry.notify(ADDRESS_GROUP, events.AFTER_CREATE, self, registry.notify(ADDRESS_GROUP, events.AFTER_CREATE, self,
context=context, address_group_id=ag.id) context=context, address_group_id=ag.id)
if fields.get('addresses') is not constants.ATTR_NOT_SPECIFIED:
self.add_addresses(context, ag.id, fields)
ag.update() # reload synthetic fields
return self._make_address_group_dict(ag) return self._make_address_group_dict(ag)
def update_address_group(self, context, id, address_group): def update_address_group(self, context, id, address_group):

View File

@ -126,6 +126,8 @@ class SecurityGroupServerAPIShimTestCase(base.BaseTestCase):
port_range_min=400, port_range_min=400,
remote_group_id=attrs['id'], remote_group_id=attrs['id'],
revision_number=1, revision_number=1,
remote_address_group_id=kwargs.get('remote_address_group_id',
None),
) )
attrs['rules'] = [sg_rule] attrs['rules'] = [sg_rule]
attrs.update(**kwargs) attrs.update(**kwargs)
@ -194,16 +196,15 @@ class SecurityGroupServerAPIShimTestCase(base.BaseTestCase):
self.sg_agent.security_groups_member_updated.assert_called_with( self.sg_agent.security_groups_member_updated.assert_called_with(
{s1.id}) {s1.id})
def test_get_address_group_details(self): def test_get_secgroup_ids_for_address_group(self):
ag = self._make_address_group_ovo() ag = self._make_address_group_ovo()
retrieved_ag = self.shim.get_address_group_details(ag.id) sg1 = self._make_security_group_ovo(remote_address_group_id=ag.id)
self.assertEqual(ag.id, retrieved_ag.id) sg2 = self._make_security_group_ovo(remote_address_group_id=ag.id)
self.assertEqual(ag.name, retrieved_ag.name) sg3 = self._make_security_group_ovo()
self.assertEqual(ag.description, retrieved_ag.description) sec_group_ids = self.shim.get_secgroup_ids_for_address_group(ag.id)
self.assertEqual(ag.addresses[0].address, self.assertEqual(set([sg1.id, sg2.id]), set(sec_group_ids))
retrieved_ag.addresses[0].address) self.assertEqual(2, len(sec_group_ids))
self.assertEqual(ag.addresses[1].address, self.assertNotIn(sg3.id, sec_group_ids)
retrieved_ag.addresses[1].address)
def test_address_group_update_events(self): def test_address_group_update_events(self):
ag = self._make_address_group_ovo() ag = self._make_address_group_ovo()

View File

@ -39,14 +39,24 @@ class OVOServerRpcInterfaceTestCase(test_plugin.Ml2PluginV2TestCase):
self.ovo_push_interface_p.stop() self.ovo_push_interface_p.stop()
self.plugin.ovo_notifier = ovo_rpc.OVOServerRpcInterface() self.plugin.ovo_notifier = ovo_rpc.OVOServerRpcInterface()
def _assert_object_received(self, ovotype, oid=None, event=None): def _assert_object_received(self, ovotype, oid=None, event=None,
count=1):
self.plugin.ovo_notifier.wait() self.plugin.ovo_notifier.wait()
match = 0
for obj, evt in self.received: for obj, evt in self.received:
if isinstance(obj, ovotype): if isinstance(obj, ovotype):
if (obj.id == oid or not oid) and (not event or event == evt): if (obj.id == oid or not oid) and (not event or event == evt):
return obj match += 1
self.fail("Could not find OVO %s with ID %s in %s" % if count == 1:
(ovotype, oid, self.received)) return obj
if count > 1:
self.assertEqual(
match, count,
"Could not find match %s for OVO %s with ID %s in %s" %
(match, ovotype, oid, self.received))
return
self.fail("Could not find OVO %s with ID %s or event %s in %s" %
(ovotype, oid, event, self.received))
def test_network_lifecycle(self): def test_network_lifecycle(self):
with self.network() as n: with self.network() as n:
@ -112,11 +122,19 @@ class OVOServerRpcInterfaceTestCase(test_plugin.Ml2PluginV2TestCase):
'addresses': ['10.0.0.1/32', 'addresses': ['10.0.0.1/32',
'2001:db8::/32']}}) '2001:db8::/32']}})
self._assert_object_received( self._assert_object_received(
address_group.AddressGroup, ag['id'], 'updated') address_group.AddressGroup, ag['id'], 'updated', 2)
self.plugin.update_address_group(self.ctx, ag['id'], self.plugin.update_address_group(self.ctx, ag['id'],
{'address_group': {'name': 'an-address-group-other-name'}}) {'address_group': {'name': 'an-address-group-other-name'}})
self._assert_object_received( self._assert_object_received(
address_group.AddressGroup, ag['id'], 'updated') address_group.AddressGroup, ag['id'], 'updated', 3)
self.plugin.add_addresses(self.ctx, ag['id'],
{'addresses': ['10.0.0.2/32']})
self._assert_object_received(
address_group.AddressGroup, ag['id'], 'updated', 4)
self.plugin.remove_addresses(self.ctx, ag['id'],
{'addresses': ['10.0.0.1/32']})
self._assert_object_received(
address_group.AddressGroup, ag['id'], 'updated', 5)
self.plugin.delete_address_group(self.ctx, ag['id']) self.plugin.delete_address_group(self.ctx, ag['id'])
self._assert_object_received( self._assert_object_received(
address_group.AddressGroup, ag['id'], 'deleted') address_group.AddressGroup, ag['id'], 'deleted')