diff --git a/neutron/db/securitygroups_db.py b/neutron/db/securitygroups_db.py index 9ebd30b0ef1..6fdbd0c3a99 100644 --- a/neutron/db/securitygroups_db.py +++ b/neutron/db/securitygroups_db.py @@ -530,27 +530,27 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase): query = self._model_query(context, DefaultSecurityGroup) # the next loop should do 2 iterations at max while True: - with db_api.autonested_transaction(context.session): + try: + default_group = query.filter_by(tenant_id=tenant_id).one() + except exc.NoResultFound: + security_group = { + 'security_group': + {'name': 'default', + 'tenant_id': tenant_id, + 'description': _('Default security group')} + } try: - default_group = query.filter_by(tenant_id=tenant_id).one() - except exc.NoResultFound: - security_group = { - 'security_group': - {'name': 'default', - 'tenant_id': tenant_id, - 'description': _('Default security group')} - } - try: + with db_api.autonested_transaction(context.session): ret = self.create_security_group( context, security_group, default_sg=True) - except exception.DBDuplicateEntry as ex: - LOG.debug("Duplicate default security group %s was " - "not created", ex.value) - continue - else: - return ret['id'] + except exception.DBDuplicateEntry as ex: + LOG.debug("Duplicate default security group %s was " + "not created", ex.value) + continue else: - return default_group['security_group_id'] + return ret['id'] + else: + return default_group['security_group_id'] def _get_security_groups_on_port(self, context, port): """Check that all security groups on port belong to tenant. diff --git a/neutron/tests/unit/test_extension_security_group.py b/neutron/tests/unit/test_extension_security_group.py index 6ffca895b6a..f13c19a978c 100644 --- a/neutron/tests/unit/test_extension_security_group.py +++ b/neutron/tests/unit/test_extension_security_group.py @@ -266,11 +266,27 @@ class TestSecurityGroups(SecurityGroupDBTestCase): self._assert_sg_rule_has_kvs(v6_rule, expected) def test_skip_duplicate_default_sg_error(self): - # can't always raise, or create_security_group will hang + num_called = [0] + original_func = self.plugin.create_security_group + + def side_effect(context, security_group, default_sg): + # can't always raise, or create_security_group will hang + self.assertTrue(default_sg) + self.assertTrue(num_called[0] < 2) + num_called[0] += 1 + ret = original_func(context, security_group, default_sg) + if num_called[0] == 1: + return ret + # make another call to cause an exception. + # NOTE(yamamoto): raising the exception by ourselves + # doesn't update the session state appropriately. + self.assertRaises(exc.DBDuplicateEntry, + original_func, context, security_group, + default_sg) + with mock.patch.object(SecurityGroupTestPlugin, 'create_security_group', - side_effect=[exc.DBDuplicateEntry(), - {'id': 'foo'}]): + side_effect=side_effect): self.plugin.create_network( context.get_admin_context(), {'network': {'name': 'foo',