Merge "Improve the SG RPC callback ``security_group_info_for_ports``" into stable/2023.2

This commit is contained in:
Zuul 2024-01-03 10:29:02 +00:00 committed by Gerrit Code Review
commit 3f3125fcfe
4 changed files with 45 additions and 13 deletions

View File

@ -438,6 +438,10 @@ class SecurityGroupServerAPIShim(sg_rpc_base.SecurityGroupInfoAPIMixin):
for sg_id in p['security_group_ids']))
return [(sg_id, ) for sg_id in sg_ids]
def _is_security_group_stateful(self, context, sg_id):
sg = self.rcache.get_resource_by_id(resources.SECURITYGROUP, sg_id)
return sg.stateful
def _get_sgs_stateful_flag(self, context, sg_ids):
sgs_stateful = {}
for sg_id in sg_ids:
sg = self.rcache.get_resource_by_id(resources.SECURITYGROUP, sg_id)
sgs_stateful[sg_id] = sg.stateful
return sgs_stateful

View File

@ -215,12 +215,10 @@ class SecurityGroupInfoAPIMixin(object):
# this set will be serialized into a list by rpc code
remote_address_group_info[remote_ag_id][ethertype] = set()
direction = rule_in_db['direction']
stateful = self._is_security_group_stateful(context,
security_group_id)
rule_dict = {
'direction': direction,
'ethertype': ethertype,
'stateful': stateful}
}
for key in ('protocol', 'port_range_min', 'port_range_max',
'remote_ip_prefix', 'remote_group_id',
@ -238,6 +236,13 @@ class SecurityGroupInfoAPIMixin(object):
if rule_dict not in sg_info['security_groups'][security_group_id]:
sg_info['security_groups'][security_group_id].append(
rule_dict)
# Populate the security group "stateful" flag in the SGs list of rules.
for sg_id, stateful in self._get_sgs_stateful_flag(
context, sg_info['security_groups'].keys()).items():
for rule in sg_info['security_groups'][sg_id]:
rule['stateful'] = stateful
# Update the security groups info if they don't have any rules
sg_ids = self._select_sg_ids_for_ports(context, ports)
for (sg_id, ) in sg_ids:
@ -431,13 +436,13 @@ class SecurityGroupInfoAPIMixin(object):
"""
raise NotImplementedError()
def _is_security_group_stateful(self, context, sg_id):
"""Return whether the security group is stateful or not.
def _get_sgs_stateful_flag(self, context, sg_id):
"""Return the security groups stateful flag.
Return True if the security group associated with the given ID
is stateful, else False.
Returns a dictionary with the SG ID as key and the stateful flag:
{sg_1: True, sg_2: False, ...}
"""
return True
raise NotImplementedError()
class SecurityGroupServerRpcMixin(SecurityGroupInfoAPIMixin,
@ -534,5 +539,5 @@ class SecurityGroupServerRpcMixin(SecurityGroupInfoAPIMixin,
return ips_by_group
@db_api.retry_if_session_inactive()
def _is_security_group_stateful(self, context, sg_id):
return sg_obj.SecurityGroup.get_sg_by_id(context, sg_id).stateful
def _get_sgs_stateful_flag(self, context, sg_ids):
return sg_obj.SecurityGroup.get_sgs_stateful_flag(context, sg_ids)

View File

@ -133,6 +133,13 @@ class SecurityGroup(rbac_db.NeutronRbacObject):
security_group_ids=[obj_id])
return {port.project_id for port in port_objs}
@classmethod
@db_api.CONTEXT_READER
def get_sgs_stateful_flag(cls, context, sg_ids):
query = context.session.query(cls.db_model.id, cls.db_model.stateful)
query = query.filter(cls.db_model.id.in_(sg_ids))
return dict(query.all())
@base.NeutronObjectRegistry.register
class DefaultSecurityGroup(base.NeutronDbObject):

View File

@ -210,6 +210,22 @@ class SecurityGroupDbObjTestCase(test_base.BaseDbObjectTestCase,
self.assertEqual(len(sg_obj.rules), 0)
self.assertIsNone(listed_objs[0].rules)
def test_get_sgs_stateful_flag(self):
for obj in self.objs:
obj.create()
sg_ids = tuple(sg.id for sg in self.objs)
sgs_stateful = securitygroup.SecurityGroup.get_sgs_stateful_flag(
self.context, sg_ids)
for sg_id, stateful in sgs_stateful.items():
for obj in (obj for obj in self.objs if obj.id == sg_id):
self.assertEqual(obj.stateful, stateful)
sg_ids = sg_ids + ('random_id_not_present', )
sgs_stateful = securitygroup.SecurityGroup.get_sgs_stateful_flag(
self.context, sg_ids)
self.assertEqual(len(self.objs), len(sgs_stateful))
class DefaultSecurityGroupIfaceObjTestCase(test_base.BaseObjectIfaceTestCase):