diff --git a/neutron/api/rpc/handlers/securitygroups_rpc.py b/neutron/api/rpc/handlers/securitygroups_rpc.py index f46d5e4b967..23beac95814 100644 --- a/neutron/api/rpc/handlers/securitygroups_rpc.py +++ b/neutron/api/rpc/handlers/securitygroups_rpc.py @@ -431,6 +431,10 @@ class SecurityGroupServerAPIShim(sg_rpc_base.SecurityGroupInfoAPIMixin): for sg_id in p['security_group_ids'])) return [(sg_id, ) for sg_id in sg_ids] - def _is_security_group_stateful(self, context, sg_id): - sg = self.rcache.get_resource_by_id(resources.SECURITYGROUP, sg_id) - return sg.stateful + def _get_sgs_stateful_flag(self, context, sg_ids): + sgs_stateful = {} + for sg_id in sg_ids: + sg = self.rcache.get_resource_by_id(resources.SECURITYGROUP, sg_id) + sgs_stateful[sg_id] = sg.stateful + + return sgs_stateful diff --git a/neutron/db/securitygroups_rpc_base.py b/neutron/db/securitygroups_rpc_base.py index e970f6d4a54..3f8df31e2a6 100644 --- a/neutron/db/securitygroups_rpc_base.py +++ b/neutron/db/securitygroups_rpc_base.py @@ -211,12 +211,10 @@ class SecurityGroupInfoAPIMixin(object): # this set will be serialized into a list by rpc code remote_address_group_info[remote_ag_id][ethertype] = set() direction = rule_in_db['direction'] - stateful = self._is_security_group_stateful(context, - security_group_id) rule_dict = { 'direction': direction, 'ethertype': ethertype, - 'stateful': stateful} + } for key in ('protocol', 'port_range_min', 'port_range_max', 'remote_ip_prefix', 'remote_group_id', @@ -234,6 +232,13 @@ class SecurityGroupInfoAPIMixin(object): if rule_dict not in sg_info['security_groups'][security_group_id]: sg_info['security_groups'][security_group_id].append( rule_dict) + + # Populate the security group "stateful" flag in the SGs list of rules. + for sg_id, stateful in self._get_sgs_stateful_flag( + context, sg_info['security_groups'].keys()).items(): + for rule in sg_info['security_groups'][sg_id]: + rule['stateful'] = stateful + # Update the security groups info if they don't have any rules sg_ids = self._select_sg_ids_for_ports(context, ports) for (sg_id, ) in sg_ids: @@ -427,13 +432,13 @@ class SecurityGroupInfoAPIMixin(object): """ raise NotImplementedError() - def _is_security_group_stateful(self, context, sg_id): - """Return whether the security group is stateful or not. + def _get_sgs_stateful_flag(self, context, sg_id): + """Return the security groups stateful flag. - Return True if the security group associated with the given ID - is stateful, else False. + Returns a dictionary with the SG ID as key and the stateful flag: + {sg_1: True, sg_2: False, ...} """ - return True + raise NotImplementedError() class SecurityGroupServerRpcMixin(SecurityGroupInfoAPIMixin, @@ -526,5 +531,5 @@ class SecurityGroupServerRpcMixin(SecurityGroupInfoAPIMixin, return ips_by_group @db_api.retry_if_session_inactive() - def _is_security_group_stateful(self, context, sg_id): - return sg_obj.SecurityGroup.get_sg_by_id(context, sg_id).stateful + def _get_sgs_stateful_flag(self, context, sg_ids): + return sg_obj.SecurityGroup.get_sgs_stateful_flag(context, sg_ids) diff --git a/neutron/objects/securitygroup.py b/neutron/objects/securitygroup.py index 59c3c60b4e4..f2558d8f6ee 100644 --- a/neutron/objects/securitygroup.py +++ b/neutron/objects/securitygroup.py @@ -11,6 +11,7 @@ # under the License. from neutron_lib import context as context_lib +from neutron_lib.db import api as db_api from neutron_lib.objects import common_types from neutron_lib.utils import net as net_utils from oslo_utils import versionutils @@ -130,6 +131,13 @@ class SecurityGroup(rbac_db.NeutronRbacObject): security_group_ids=[obj_id]) return {port.tenant_id for port in port_objs} + @classmethod + @db_api.CONTEXT_READER + def get_sgs_stateful_flag(cls, context, sg_ids): + query = context.session.query(cls.db_model.id, cls.db_model.stateful) + query = query.filter(cls.db_model.id.in_(sg_ids)) + return dict(query.all()) + @base.NeutronObjectRegistry.register class DefaultSecurityGroup(base.NeutronDbObject): diff --git a/neutron/tests/unit/objects/test_securitygroup.py b/neutron/tests/unit/objects/test_securitygroup.py index 53a7901d3ac..011cabeb810 100644 --- a/neutron/tests/unit/objects/test_securitygroup.py +++ b/neutron/tests/unit/objects/test_securitygroup.py @@ -210,6 +210,22 @@ class SecurityGroupDbObjTestCase(test_base.BaseDbObjectTestCase, self.assertEqual(len(sg_obj.rules), 0) self.assertIsNone(listed_objs[0].rules) + def test_get_sgs_stateful_flag(self): + for obj in self.objs: + obj.create() + + sg_ids = tuple(sg.id for sg in self.objs) + sgs_stateful = securitygroup.SecurityGroup.get_sgs_stateful_flag( + self.context, sg_ids) + for sg_id, stateful in sgs_stateful.items(): + for obj in (obj for obj in self.objs if obj.id == sg_id): + self.assertEqual(obj.stateful, stateful) + + sg_ids = sg_ids + ('random_id_not_present', ) + sgs_stateful = securitygroup.SecurityGroup.get_sgs_stateful_flag( + self.context, sg_ids) + self.assertEqual(len(self.objs), len(sgs_stateful)) + class DefaultSecurityGroupIfaceObjTestCase(test_base.BaseObjectIfaceTestCase):