diff --git a/vmware_nsx/plugins/nsx_p/plugin.py b/vmware_nsx/plugins/nsx_p/plugin.py index 1db3fe625f..e4f4429f3f 100644 --- a/vmware_nsx/plugins/nsx_p/plugin.py +++ b/vmware_nsx/plugins/nsx_p/plugin.py @@ -2544,38 +2544,29 @@ class NsxPolicyPlugin(nsx_plugin_common.NsxPluginV3Base): cond_val=scope_and_tag, cond_key=policy_constants.CONDITION_KEY_TAG, cond_member_type=policy_constants.CONDITION_MEMBER_PORT) - # Create the group - try: - self.nsxpolicy.group.create_or_overwrite_with_conditions( - nsx_name, NSX_P_GLOBAL_DOMAIN_ID, group_id=sg_id, - description=secgroup.get('description'), - conditions=[condition], tags=tags) - except Exception as e: - msg = (_("Failed to create NSX group for SG %(sg)s: " - "%(e)s") % {'sg': sg_id, 'e': e}) - raise nsx_exc.NsxPluginException(err_msg=msg) category = NSX_P_REGULAR_SECTION_CATEGORY if secgroup.get(provider_sg.PROVIDER) is True: category = NSX_P_PROVIDER_SECTION_CATEGORY - # create the communication map (=section) and entries (=rules) try: - if entries: - self.nsxpolicy.comm_map.create_with_entries( - nsx_name, NSX_P_GLOBAL_DOMAIN_ID, map_id=sg_id, + with policy_trans.NsxPolicyTransaction(): + # Create the group + self.nsxpolicy.group.create_or_overwrite_with_conditions( + nsx_name, NSX_P_GLOBAL_DOMAIN_ID, group_id=sg_id, description=secgroup.get('description'), - entries=entries, - tags=tags, category=category) - else: + conditions=[condition], tags=tags) + + # create the communication map (=section) and entries (=rules) self.nsxpolicy.comm_map.create_or_overwrite_map_only( nsx_name, NSX_P_GLOBAL_DOMAIN_ID, map_id=sg_id, description=secgroup.get('description'), tags=tags, category=category) + for entry in entries: + self.nsxpolicy.comm_map.create_entry_from_def(entry) except Exception as e: - msg = (_("Failed to create NSX communication map for SG %(sg)s: " + msg = (_("Failed to create NSX resources for SG %(sg)s: " "%(e)s") % {'sg': sg_id, 'e': e}) - self.nsxpolicy.group.delete(NSX_P_GLOBAL_DOMAIN_ID, sg_id) raise nsx_exc.NsxPluginException(err_msg=msg) def _get_rule_ip_protocol(self, sg_rule): @@ -2643,8 +2634,7 @@ class NsxPolicyPlugin(nsx_plugin_common.NsxPluginV3Base): def _create_security_group_backend_rule(self, context, map_id, sg_rule, secgroup_logging, - is_provider_sg=False, - create_rule=True): + is_provider_sg=False): """Create backend resources for a DFW rule All rule resources (service, groups) will be created @@ -2706,34 +2696,20 @@ class NsxPolicyPlugin(nsx_plugin_common.NsxPluginV3Base): this_group_id)] action = (policy_constants.ACTION_DENY if is_provider_sg else policy_constants.ACTION_ALLOW) - if create_rule: - self.nsxpolicy.comm_map.create_entry( - nsx_name, NSX_P_GLOBAL_DOMAIN_ID, - map_id, entry_id=sg_rule['id'], - description=sg_rule.get('description'), - service_ids=[service] if service else None, - ip_protocol=ip_protocol, - action=action, - source_groups=[source] if source else None, - dest_groups=[destination] if destination else None, - scope=scope, - direction=direction, logged=logging, - tag=sg_rule.get('project_id')) - else: - # Just return the rule entry without creating it - rule_entry = self.nsxpolicy.comm_map.build_entry( - nsx_name, NSX_P_GLOBAL_DOMAIN_ID, - map_id, entry_id=sg_rule['id'], - description=sg_rule.get('description'), - service_ids=[service] if service else None, - ip_protocol=ip_protocol, - action=action, - source_groups=[source] if source else None, - dest_groups=[destination] if destination else None, - scope=scope, - tag=sg_rule.get('project_id'), - direction=direction, logged=logging) - return rule_entry + # Just return the rule entry without creating it + rule_entry = self.nsxpolicy.comm_map.build_entry( + nsx_name, NSX_P_GLOBAL_DOMAIN_ID, + map_id, entry_id=sg_rule['id'], + description=sg_rule.get('description'), + service_ids=[service] if service else None, + ip_protocol=ip_protocol, + action=action, + source_groups=[source] if source else None, + dest_groups=[destination] if destination else None, + scope=scope, + tag=sg_rule.get('project_id'), + direction=direction, logged=logging) + return rule_entry def create_security_group(self, context, security_group, default_sg=False): secgroup = security_group['security_group'] @@ -2768,11 +2744,13 @@ class NsxPolicyPlugin(nsx_plugin_common.NsxPluginV3Base): sg_rules = secgroup_db['security_group_rules'] secgroup_logging = secgroup.get(sg_logging.LOGGING, False) backend_rules = [] - for sg_rule in sg_rules: - rule_entry = self._create_security_group_backend_rule( - context, secgroup_db['id'], sg_rule, - secgroup_logging, create_rule=False) - backend_rules.append(rule_entry) + with policy_trans.NsxPolicyTransaction(): + # Create all the rules resources in a single transaction + for sg_rule in sg_rules: + rule_entry = self._create_security_group_backend_rule( + context, secgroup_db['id'], sg_rule, + secgroup_logging) + backend_rules.append(rule_entry) # Create Group & communication map on the NSX self._create_security_group_backend_resources( context, secgroup, backend_rules) @@ -2885,13 +2863,29 @@ class NsxPolicyPlugin(nsx_plugin_common.NsxPluginV3Base): is_provider_sg = sg.get(provider_sg.PROVIDER) secgroup_logging = self._is_security_group_logged(context, sg_id) - for rule_data in rules_db: - #TODO(asarfaty): Consider using update_entries with all the rules - # if multiple rules are added - # create the NSX backend rule - self._create_security_group_backend_rule( - context, sg_id, rule_data, secgroup_logging, - is_provider_sg=is_provider_sg) + category = (NSX_P_PROVIDER_SECTION_CATEGORY if is_provider_sg + else NSX_P_REGULAR_SECTION_CATEGORY) + # Create the NSX backend rules in a single transaction + with policy_trans.NsxPolicyTransaction(): + # Build new rules and relevant objects + backend_rules = [] + for rule_data in rules_db: + rule_entry = self._create_security_group_backend_rule( + context, sg_id, rule_data, secgroup_logging, + is_provider_sg=is_provider_sg) + + backend_rules.append(rule_entry) + + # Add the old rules + for rule in sg['security_group_rules']: + rule_entry = self.nsxpolicy.comm_map.build_entry( + NSX_P_GLOBAL_DOMAIN_ID, sg_id, rule['id']) + backend_rules.append(rule_entry) + + # Update the policy with all the rules. + self.nsxpolicy.comm_map.update_with_entries( + NSX_P_GLOBAL_DOMAIN_ID, sg_id, entries=backend_rules, + category=category) return rules_db diff --git a/vmware_nsx/tests/unit/extensions/test_provider_security_groups.py b/vmware_nsx/tests/unit/extensions/test_provider_security_groups.py index 09c263f2bb..9a9f10d6ff 100644 --- a/vmware_nsx/tests/unit/extensions/test_provider_security_groups.py +++ b/vmware_nsx/tests/unit/extensions/test_provider_security_groups.py @@ -28,8 +28,6 @@ from vmware_nsx.extensions import providersecuritygroup as provider_sg from vmware_nsx.tests.unit.nsx_p import test_plugin as test_nsxp_plugin from vmware_nsx.tests.unit.nsx_v import test_plugin as test_nsxv_plugin from vmware_nsx.tests.unit.nsx_v3 import test_plugin as test_nsxv3_plugin -from vmware_nsxlib.v3 import nsx_constants -from vmware_nsxlib.v3.policy import constants as policy_constants PLUGIN_NAME = ('vmware_nsx.tests.unit.extensions.' @@ -404,23 +402,7 @@ class TestNSXpProviderSecurityGrp(test_nsxp_plugin.NsxPPluginTestCaseMixin, sg_id = provider_secgroup['security_group']['id'] with mock.patch("vmware_nsxlib.v3.policy.core_resources." - "NsxPolicyCommunicationMapApi.create_entry" - ) as entry_create: - with self.security_group_rule(security_group_id=sg_id) as rule: - rule_data = rule['security_group_rule'] - rule_id = rule_data['id'] - scope = [self.plugin.nsxpolicy.group.get_path( - policy_constants.DEFAULT_DOMAIN, sg_id)] - entry_create.assert_called_once_with( - rule_id, policy_constants.DEFAULT_DOMAIN, - sg_id, entry_id=rule_id, - description='', - direction=nsx_constants.IN, - ip_protocol=nsx_constants.IPV4, - action=policy_constants.ACTION_DENY, - service_ids=mock.ANY, - source_groups=mock.ANY, - dest_groups=mock.ANY, - scope=scope, - logged=False, - tag=rule_data['project_id']) + "NsxPolicyCommunicationMapApi.update_with_entries" + ) as entry_create,\ + self.security_group_rule(security_group_id=sg_id): + entry_create.assert_called_once() diff --git a/vmware_nsx/tests/unit/nsx_p/test_plugin.py b/vmware_nsx/tests/unit/nsx_p/test_plugin.py index 8bce299fa3..4b7b40b217 100644 --- a/vmware_nsx/tests/unit/nsx_p/test_plugin.py +++ b/vmware_nsx/tests/unit/nsx_p/test_plugin.py @@ -1213,7 +1213,7 @@ class NsxPTestSecurityGroup(common_v3.FixExternalNetBaseTest, ) as group_create,\ mock.patch("vmware_nsxlib.v3.policy.core_resources." "NsxPolicyCommunicationMapApi." - "create_with_entries") as comm_map_create,\ + "create_or_overwrite_map_only") as comm_map_create,\ self.security_group(name, description) as sg: sg_id = sg['security_group']['id'] nsx_name = utils.get_name_and_uuid(name, sg_id) @@ -1225,7 +1225,6 @@ class NsxPTestSecurityGroup(common_v3.FixExternalNetBaseTest, nsx_name, policy_constants.DEFAULT_DOMAIN, map_id=sg_id, description=description, tags=mock.ANY, - entries=mock.ANY, category=policy_constants.CATEGORY_ENVIRONMENT) def _create_provider_security_group(self): @@ -1299,28 +1298,13 @@ class NsxPTestSecurityGroup(common_v3.FixExternalNetBaseTest, with self.security_group(name, description) as sg: sg_id = sg['security_group']['id'] with mock.patch("vmware_nsxlib.v3.policy.core_resources." - "NsxPolicyCommunicationMapApi.create_entry" - ) as entry_create,\ + "NsxPolicyCommunicationMapApi.update_with_entries" + ) as update_policy,\ self.security_group_rule(sg_id, direction, protocol, port_range_min, port_range_max, - remote_ip_prefix) as rule: - rule_id = rule['security_group_rule']['id'] - scope = [self.plugin.nsxpolicy.group.get_path( - policy_constants.DEFAULT_DOMAIN, sg_id)] - entry_create.assert_called_once_with( - rule_id, policy_constants.DEFAULT_DOMAIN, - sg_id, entry_id=rule_id, - description='', - direction=nsx_constants.IN, - ip_protocol=nsx_constants.IPV4, - action=policy_constants.ACTION_ALLOW, - service_ids=mock.ANY, - source_groups=mock.ANY, - dest_groups=mock.ANY, - scope=scope, - logged=False, - tag=mock.ANY) + remote_ip_prefix): + update_policy.assert_called_once() def test_create_security_group_rule_with_remote_group(self): with self.security_group() as sg1, self.security_group() as sg2: