diff --git a/neutron/db/securitygroups_db.py b/neutron/db/securitygroups_db.py index 9116df6ec2b..494927e714a 100644 --- a/neutron/db/securitygroups_db.py +++ b/neutron/db/securitygroups_db.py @@ -192,6 +192,10 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase): raise ext_sg.SecurityGroupNotFound(id=id) return sg + def _check_security_group(self, context, id, **kwargs): + if not sg_obj.SecurityGroup.objects_exist(context, id=id, **kwargs): + raise ext_sg.SecurityGroupNotFound(id=id) + @db_api.retry_if_session_inactive() def delete_security_group(self, context, id): filters = {'security_group_id': [id]} @@ -325,10 +329,7 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase): security_group_id = self._validate_security_group_rules( context, security_group_rules) with db_api.CONTEXT_WRITER.using(context): - if not self.get_security_group(context, security_group_id): - raise ext_sg.SecurityGroupNotFound(id=security_group_id) - - self._check_for_duplicate_rules(context, rules) + self._check_for_duplicate_rules(context, security_group_id, rules) ret = [] for rule_dict in rules: res_rule_dict = self._create_security_group_rule( @@ -351,7 +352,8 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase): def _create_security_group_rule(self, context, security_group_rule, validate=True): if validate: - self._validate_security_group_rule(context, security_group_rule) + sg_id = self._validate_security_group_rule(context, + security_group_rule) rule_dict = security_group_rule['security_group_rule'] remote_ip_prefix = rule_dict.get('remote_ip_prefix') if remote_ip_prefix: @@ -391,8 +393,8 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase): exc_cls=ext_sg.SecurityGroupConflict, **kwargs) with db_api.CONTEXT_WRITER.using(context): if validate: - self._check_for_duplicate_rules_in_db(context, - security_group_rule) + self._check_for_duplicate_rules(context, sg_id, + [security_group_rule]) sg_rule = sg_obj.SecurityGroupRule(context, **args) sg_rule.create() @@ -525,15 +527,15 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase): remote_group_id = rule['remote_group_id'] # Check that remote_group_id exists for tenant if remote_group_id: - self.get_security_group(context, remote_group_id, - tenant_id=rule['tenant_id']) + self._check_security_group(context, remote_group_id, + project_id=rule['tenant_id']) security_group_id = rule['security_group_id'] # Confirm that the tenant has permission # to add rules to this security group. - self.get_security_group(context, security_group_id, - tenant_id=rule['tenant_id']) + self._check_security_group(context, security_group_id, + project_id=rule['tenant_id']) return security_group_id def _validate_security_group_rules(self, context, security_group_rules): @@ -576,73 +578,54 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase): res['protocol'] = self._get_ip_proto_name_and_num(value) return res - def _rules_equal(self, rule1, rule2): - """Determines if two rules are equal ignoring id field.""" - rule1_copy = rule1.copy() - rule2_copy = rule2.copy() - rule1_copy.pop('id', None) - rule2_copy.pop('id', None) - return rule1_copy == rule2_copy + def _rule_to_key(self, rule): + def _normalize_rule_value(key, value): + # This string is used as a placeholder for str(None), but shorter. + none_char = '+' - def _check_for_duplicate_rules(self, context, security_group_rules): - for i in security_group_rules: - found_self = False - for j in security_group_rules: - if self._rules_equal(i['security_group_rule'], - j['security_group_rule']): - if found_self: - raise ext_sg.DuplicateSecurityGroupRuleInPost(rule=i) - found_self = True + if key == 'remote_ip_prefix': + all_address = ['0.0.0.0/0', '::/0', None] + if value in all_address: + return none_char + elif value is None: + return none_char + elif key == 'protocol': + return str(self._get_ip_proto_name_and_num(value)) + return str(value) - self._check_for_duplicate_rules_in_db(context, i) + comparison_keys = [ + 'direction', + 'ethertype', + 'port_range_max', + 'port_range_min', + 'protocol', + 'remote_group_id', + 'remote_ip_prefix', + 'security_group_id' + ] + return '_'.join([_normalize_rule_value(x, rule.get(x)) + for x in comparison_keys]) - def _check_for_duplicate_rules_in_db(self, context, security_group_rule): - # Check in database if rule exists - filters = self._make_security_group_rule_filter_dict( - security_group_rule) - rule_dict = security_group_rule['security_group_rule'].copy() - rule_dict.pop('description', None) - keys = rule_dict.keys() - fields = list(keys) + ['id'] - if 'remote_ip_prefix' not in fields: - fields += ['remote_ip_prefix'] - db_rules = self.get_security_group_rules(context, filters, - fields=fields) - # Note(arosen): the call to get_security_group_rules wildcards - # values in the filter that have a value of [None]. For - # example, filters = {'remote_group_id': [None]} will return - # all security group rules regardless of their value of - # remote_group_id. Therefore it is not possible to do this - # query unless the behavior of _get_collection() - # is changed which cannot be because other methods are already - # relying on this behavior. Therefore, we do the filtering - # below to check for these corner cases. - rule_dict.pop('id', None) - sg_protocol = rule_dict.pop('protocol', None) - remote_ip_prefix = rule_dict.pop('remote_ip_prefix', None) - for db_rule in db_rules: - rule_id = db_rule.pop('id', None) - # remove protocol and match separately for number and type - db_protocol = db_rule.pop('protocol', None) - is_protocol_matching = ( - self._get_ip_proto_name_and_num(db_protocol) == - self._get_ip_proto_name_and_num(sg_protocol)) - db_remote_ip_prefix = db_rule.pop('remote_ip_prefix', None) - duplicate_ip_prefix = self._validate_duplicate_ip_prefix( - remote_ip_prefix, db_remote_ip_prefix) - if (is_protocol_matching and duplicate_ip_prefix and - rule_dict == db_rule): - raise ext_sg.SecurityGroupRuleExists(rule_id=rule_id) + def _check_for_duplicate_rules(self, context, security_group_id, + new_security_group_rules): + # First up, check for any duplicates in the new rules. + new_rules_set = set() + for i in new_security_group_rules: + rule_key = self._rule_to_key(i['security_group_rule']) + if rule_key in new_rules_set: + raise ext_sg.DuplicateSecurityGroupRuleInPost(rule=i) + new_rules_set.add(rule_key) - def _validate_duplicate_ip_prefix(self, ip_prefix, other_ip_prefix): - if other_ip_prefix is not None: - other_ip_prefix = str(other_ip_prefix) - all_address = ['0.0.0.0/0', '::/0', None] - if ip_prefix == other_ip_prefix: - return True - elif ip_prefix in all_address and other_ip_prefix in all_address: - return True - return False + # Now, let's make sure none of the new rules conflict with + # existing rules; note that we do *not* store the db rules + # in the set, as we assume they were already checked, + # when added. + sg = self.get_security_group(context, security_group_id) + if sg: + for i in sg['security_group_rules']: + rule_key = self._rule_to_key(i) + if rule_key in new_rules_set: + raise ext_sg.SecurityGroupRuleExists(rule_id=i.get('id')) def _validate_ip_prefix(self, rule): """Check that a valid cidr was specified as remote_ip_prefix diff --git a/neutron/objects/base.py b/neutron/objects/base.py index d58ef16eb79..aaf1aa976ed 100644 --- a/neutron/objects/base.py +++ b/neutron/objects/base.py @@ -878,6 +878,6 @@ class NeutronDbObject(NeutronObject): if validate_filters: cls.validate_filters(**kwargs) # Succeed if at least a single object matches; no need to fetch more - return bool(obj_db_api.get_object( + return bool(obj_db_api.count( cls, context, **cls.modify_fields_to_db(kwargs)) ) diff --git a/neutron/tests/unit/db/test_securitygroups_db.py b/neutron/tests/unit/db/test_securitygroups_db.py index 283ab379548..a5254acfe54 100644 --- a/neutron/tests/unit/db/test_securitygroups_db.py +++ b/neutron/tests/unit/db/test_securitygroups_db.py @@ -103,7 +103,7 @@ class SecurityGroupDbMixinTestCase(testlib_api.SqlTestCase): def test_create_security_group_rule_conflict(self): with mock.patch.object(self.mixin, '_validate_security_group_rule'),\ mock.patch.object(self.mixin, - '_check_for_duplicate_rules_in_db'),\ + '_check_for_duplicate_rules'),\ mock.patch.object(registry, "notify") as mock_notify: mock_notify.side_effect = exceptions.CallbackFailure(Exception()) with testtools.ExpectedException( @@ -111,9 +111,9 @@ class SecurityGroupDbMixinTestCase(testlib_api.SqlTestCase): self.mixin.create_security_group_rule( self.ctx, mock.MagicMock()) - def test__check_for_duplicate_rules_in_db_does_not_drop_protocol(self): - with mock.patch.object(self.mixin, 'get_security_group_rules', - return_value=[mock.Mock()]): + def test__check_for_duplicate_rules_does_not_drop_protocol(self): + with mock.patch.object(self.mixin, 'get_security_group', + return_value=None): context = mock.Mock() rule_dict = { 'security_group_rule': {'protocol': None, @@ -121,7 +121,7 @@ class SecurityGroupDbMixinTestCase(testlib_api.SqlTestCase): 'security_group_id': 'fake', 'direction': 'fake'} } - self.mixin._check_for_duplicate_rules_in_db(context, rule_dict) + self.mixin._check_for_duplicate_rules(context, 'fake', [rule_dict]) self.assertIn('protocol', rule_dict['security_group_rule']) def test__check_for_duplicate_rules_ignores_rule_id(self): @@ -132,33 +132,20 @@ class SecurityGroupDbMixinTestCase(testlib_api.SqlTestCase): # in this case as this test, tests that the id fields are dropped # while being compared. This is in the case if a plugin specifies # the rule ids themselves. - self.assertRaises(securitygroup.DuplicateSecurityGroupRuleInPost, - self.mixin._check_for_duplicate_rules, - context, rules) - - def test__check_for_duplicate_rules_in_db_ignores_rule_id(self): - db_rules = {'protocol': 'tcp', 'id': 'fake', 'tenant_id': 'fake', - 'direction': 'ingress', 'security_group_id': 'fake'} - with mock.patch.object(self.mixin, 'get_security_group_rules', - return_value=[db_rules]): - context = mock.Mock() - rule_dict = { - 'security_group_rule': {'protocol': 'tcp', - 'id': 'fake2', - 'tenant_id': 'fake', - 'security_group_id': 'fake', - 'direction': 'ingress'} - } - self.assertRaises(securitygroup.SecurityGroupRuleExists, - self.mixin._check_for_duplicate_rules_in_db, - context, rule_dict) + with mock.patch.object(self.mixin, 'get_security_group', + return_value=None): + self.assertRaises(securitygroup.DuplicateSecurityGroupRuleInPost, + self.mixin._check_for_duplicate_rules, + context, 'fake', rules) def test_check_for_duplicate_diff_rules_remote_ip_prefix_ipv4(self): - db_rules = {'id': 'fake', 'tenant_id': 'fake', 'ethertype': 'IPv4', - 'direction': 'ingress', 'security_group_id': 'fake', - 'remote_ip_prefix': None} - with mock.patch.object(self.mixin, 'get_security_group_rules', - return_value=[db_rules]): + fake_secgroup = copy.deepcopy(FAKE_SECGROUP) + fake_secgroup['security_group_rules'] = \ + [{'id': 'fake', 'tenant_id': 'fake', 'ethertype': 'IPv4', + 'direction': 'ingress', 'security_group_id': 'fake', + 'remote_ip_prefix': None}] + with mock.patch.object(self.mixin, 'get_security_group', + return_value=fake_secgroup): context = mock.Mock() rule_dict = { 'security_group_rule': {'id': 'fake2', @@ -169,15 +156,17 @@ class SecurityGroupDbMixinTestCase(testlib_api.SqlTestCase): 'remote_ip_prefix': '0.0.0.0/0'} } self.assertRaises(securitygroup.SecurityGroupRuleExists, - self.mixin._check_for_duplicate_rules_in_db, - context, rule_dict) + self.mixin._check_for_duplicate_rules, + context, 'fake', [rule_dict]) def test_check_for_duplicate_diff_rules_remote_ip_prefix_ipv6(self): - db_rules = {'id': 'fake', 'tenant_id': 'fake', 'ethertype': 'IPv6', - 'direction': 'ingress', 'security_group_id': 'fake', - 'remote_ip_prefix': None} - with mock.patch.object(self.mixin, 'get_security_group_rules', - return_value=[db_rules]): + fake_secgroup = copy.deepcopy(FAKE_SECGROUP) + fake_secgroup['security_group_rules'] = \ + [{'id': 'fake', 'tenant_id': 'fake', 'ethertype': 'IPv6', + 'direction': 'ingress', 'security_group_id': 'fake', + 'remote_ip_prefix': None}] + with mock.patch.object(self.mixin, 'get_security_group', + return_value=fake_secgroup): context = mock.Mock() rule_dict = { 'security_group_rule': {'id': 'fake2', @@ -188,8 +177,8 @@ class SecurityGroupDbMixinTestCase(testlib_api.SqlTestCase): 'remote_ip_prefix': '::/0'} } self.assertRaises(securitygroup.SecurityGroupRuleExists, - self.mixin._check_for_duplicate_rules_in_db, - context, rule_dict) + self.mixin._check_for_duplicate_rules, + context, 'fake', [rule_dict]) def test_delete_security_group_rule_in_use(self): with mock.patch.object(registry, "notify") as mock_notify: