diff --git a/neutron/db/securitygroups_db.py b/neutron/db/securitygroups_db.py index 16d04aceaa7..95f354d74d8 100644 --- a/neutron/db/securitygroups_db.py +++ b/neutron/db/securitygroups_db.py @@ -22,6 +22,7 @@ from neutron_lib.callbacks import resources from neutron_lib import constants from neutron_lib import context as context_lib from neutron_lib import exceptions as n_exc +from neutron_lib.objects import exceptions as obj_exc from neutron_lib.utils import helpers from neutron_lib.utils import net from oslo_utils import uuidutils @@ -41,6 +42,9 @@ from neutron.objects import base as base_obj from neutron.objects import securitygroup as sg_obj +DEFAULT_SG_DESCRIPTION = _('Default security group') + + @resource_extend.has_resource_extenders @registry.has_registry_receivers class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase): @@ -805,10 +809,13 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase): 'security_group': {'name': 'default', 'tenant_id': tenant_id, - 'description': _('Default security group')} + 'description': DEFAULT_SG_DESCRIPTION} } - return self.create_security_group(context, security_group, - default_sg=True)['id'] + try: + return self.create_security_group(context, security_group, + default_sg=True)['id'] + except obj_exc.NeutronDbObjectDuplicateEntry: + return self._get_default_sg_id(context, tenant_id) def _get_security_groups_on_port(self, context, port): """Check that all security groups on port belong to tenant. diff --git a/neutron/tests/unit/db/test_securitygroups_db.py b/neutron/tests/unit/db/test_securitygroups_db.py index 92b2f0259e7..f0bfdf38b63 100644 --- a/neutron/tests/unit/db/test_securitygroups_db.py +++ b/neutron/tests/unit/db/test_securitygroups_db.py @@ -20,6 +20,7 @@ from neutron_lib.callbacks import registry from neutron_lib.callbacks import resources from neutron_lib import constants from neutron_lib import context +from neutron_lib.objects import exceptions as obj_exc import sqlalchemy import testtools @@ -517,3 +518,49 @@ class SecurityGroupDbMixinTestCase(testlib_api.SqlTestCase): for rule in (rule for rule in rules_after if rule not in rules_before): self.assertEqual('tenant_1', rule['tenant_id']) self.assertEqual(self.sg_user['id'], rule['security_group_id']) + + def test__ensure_default_security_group(self): + with mock.patch.object( + self.mixin, '_get_default_sg_id') as get_default_sg_id,\ + mock.patch.object( + self.mixin, 'create_security_group') as create_sg: + get_default_sg_id.return_value = None + self.mixin._ensure_default_security_group(self.ctx, 'tenant_1') + create_sg.assert_called_once_with( + self.ctx, + {'security_group': { + 'name': 'default', + 'tenant_id': 'tenant_1', + 'description': securitygroups_db.DEFAULT_SG_DESCRIPTION}}, + default_sg=True) + get_default_sg_id.assert_called_once_with(self.ctx, 'tenant_1') + + def test__ensure_default_security_group_already_exists(self): + with mock.patch.object( + self.mixin, '_get_default_sg_id') as get_default_sg_id,\ + mock.patch.object( + self.mixin, 'create_security_group') as create_sg: + get_default_sg_id.return_value = 'default_sg_id' + self.mixin._ensure_default_security_group(self.ctx, 'tenant_1') + create_sg.assert_not_called() + get_default_sg_id.assert_called_once_with(self.ctx, 'tenant_1') + + def test__ensure_default_security_group_created_in_parallel(self): + with mock.patch.object( + self.mixin, '_get_default_sg_id') as get_default_sg_id,\ + mock.patch.object( + self.mixin, 'create_security_group') as create_sg: + get_default_sg_id.side_effect = [None, 'default_sg_id'] + create_sg.side_effect = obj_exc.NeutronDbObjectDuplicateEntry( + mock.Mock(), mock.Mock()) + self.mixin._ensure_default_security_group(self.ctx, 'tenant_1') + create_sg.assert_called_once_with( + self.ctx, + {'security_group': { + 'name': 'default', + 'tenant_id': 'tenant_1', + 'description': securitygroups_db.DEFAULT_SG_DESCRIPTION}}, + default_sg=True) + get_default_sg_id.assert_has_calls([ + mock.call(self.ctx, 'tenant_1'), + mock.call(self.ctx, 'tenant_1')])