Merge "Improve the SG RPC callback ``security_group_info_for_ports``" into stable/2023.2
This commit is contained in:
commit
3f3125fcfe
|
@ -438,6 +438,10 @@ class SecurityGroupServerAPIShim(sg_rpc_base.SecurityGroupInfoAPIMixin):
|
|||
for sg_id in p['security_group_ids']))
|
||||
return [(sg_id, ) for sg_id in sg_ids]
|
||||
|
||||
def _is_security_group_stateful(self, context, sg_id):
|
||||
sg = self.rcache.get_resource_by_id(resources.SECURITYGROUP, sg_id)
|
||||
return sg.stateful
|
||||
def _get_sgs_stateful_flag(self, context, sg_ids):
|
||||
sgs_stateful = {}
|
||||
for sg_id in sg_ids:
|
||||
sg = self.rcache.get_resource_by_id(resources.SECURITYGROUP, sg_id)
|
||||
sgs_stateful[sg_id] = sg.stateful
|
||||
|
||||
return sgs_stateful
|
||||
|
|
|
@ -215,12 +215,10 @@ class SecurityGroupInfoAPIMixin(object):
|
|||
# this set will be serialized into a list by rpc code
|
||||
remote_address_group_info[remote_ag_id][ethertype] = set()
|
||||
direction = rule_in_db['direction']
|
||||
stateful = self._is_security_group_stateful(context,
|
||||
security_group_id)
|
||||
rule_dict = {
|
||||
'direction': direction,
|
||||
'ethertype': ethertype,
|
||||
'stateful': stateful}
|
||||
}
|
||||
|
||||
for key in ('protocol', 'port_range_min', 'port_range_max',
|
||||
'remote_ip_prefix', 'remote_group_id',
|
||||
|
@ -238,6 +236,13 @@ class SecurityGroupInfoAPIMixin(object):
|
|||
if rule_dict not in sg_info['security_groups'][security_group_id]:
|
||||
sg_info['security_groups'][security_group_id].append(
|
||||
rule_dict)
|
||||
|
||||
# Populate the security group "stateful" flag in the SGs list of rules.
|
||||
for sg_id, stateful in self._get_sgs_stateful_flag(
|
||||
context, sg_info['security_groups'].keys()).items():
|
||||
for rule in sg_info['security_groups'][sg_id]:
|
||||
rule['stateful'] = stateful
|
||||
|
||||
# Update the security groups info if they don't have any rules
|
||||
sg_ids = self._select_sg_ids_for_ports(context, ports)
|
||||
for (sg_id, ) in sg_ids:
|
||||
|
@ -431,13 +436,13 @@ class SecurityGroupInfoAPIMixin(object):
|
|||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def _is_security_group_stateful(self, context, sg_id):
|
||||
"""Return whether the security group is stateful or not.
|
||||
def _get_sgs_stateful_flag(self, context, sg_id):
|
||||
"""Return the security groups stateful flag.
|
||||
|
||||
Return True if the security group associated with the given ID
|
||||
is stateful, else False.
|
||||
Returns a dictionary with the SG ID as key and the stateful flag:
|
||||
{sg_1: True, sg_2: False, ...}
|
||||
"""
|
||||
return True
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class SecurityGroupServerRpcMixin(SecurityGroupInfoAPIMixin,
|
||||
|
@ -534,5 +539,5 @@ class SecurityGroupServerRpcMixin(SecurityGroupInfoAPIMixin,
|
|||
return ips_by_group
|
||||
|
||||
@db_api.retry_if_session_inactive()
|
||||
def _is_security_group_stateful(self, context, sg_id):
|
||||
return sg_obj.SecurityGroup.get_sg_by_id(context, sg_id).stateful
|
||||
def _get_sgs_stateful_flag(self, context, sg_ids):
|
||||
return sg_obj.SecurityGroup.get_sgs_stateful_flag(context, sg_ids)
|
||||
|
|
|
@ -133,6 +133,13 @@ class SecurityGroup(rbac_db.NeutronRbacObject):
|
|||
security_group_ids=[obj_id])
|
||||
return {port.project_id for port in port_objs}
|
||||
|
||||
@classmethod
|
||||
@db_api.CONTEXT_READER
|
||||
def get_sgs_stateful_flag(cls, context, sg_ids):
|
||||
query = context.session.query(cls.db_model.id, cls.db_model.stateful)
|
||||
query = query.filter(cls.db_model.id.in_(sg_ids))
|
||||
return dict(query.all())
|
||||
|
||||
|
||||
@base.NeutronObjectRegistry.register
|
||||
class DefaultSecurityGroup(base.NeutronDbObject):
|
||||
|
|
|
@ -210,6 +210,22 @@ class SecurityGroupDbObjTestCase(test_base.BaseDbObjectTestCase,
|
|||
self.assertEqual(len(sg_obj.rules), 0)
|
||||
self.assertIsNone(listed_objs[0].rules)
|
||||
|
||||
def test_get_sgs_stateful_flag(self):
|
||||
for obj in self.objs:
|
||||
obj.create()
|
||||
|
||||
sg_ids = tuple(sg.id for sg in self.objs)
|
||||
sgs_stateful = securitygroup.SecurityGroup.get_sgs_stateful_flag(
|
||||
self.context, sg_ids)
|
||||
for sg_id, stateful in sgs_stateful.items():
|
||||
for obj in (obj for obj in self.objs if obj.id == sg_id):
|
||||
self.assertEqual(obj.stateful, stateful)
|
||||
|
||||
sg_ids = sg_ids + ('random_id_not_present', )
|
||||
sgs_stateful = securitygroup.SecurityGroup.get_sgs_stateful_flag(
|
||||
self.context, sg_ids)
|
||||
self.assertEqual(len(self.objs), len(sgs_stateful))
|
||||
|
||||
|
||||
class DefaultSecurityGroupIfaceObjTestCase(test_base.BaseObjectIfaceTestCase):
|
||||
|
||||
|
|
Loading…
Reference in New Issue