diff --git a/neutron/db/securitygroups_db.py b/neutron/db/securitygroups_db.py index 33ca2a5cd9a..ee821f0b09d 100644 --- a/neutron/db/securitygroups_db.py +++ b/neutron/db/securitygroups_db.py @@ -194,6 +194,18 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase): raise ext_sg.SecurityGroupNotFound(id=id) return sg + def _check_security_group(self, context, id, tenant_id=None): + if tenant_id: + tmp_context_tenant_id = context.tenant_id + context.tenant_id = tenant_id + + try: + if not sg_obj.SecurityGroup.objects_exist(context, id=id): + raise ext_sg.SecurityGroupNotFound(id=id) + finally: + if tenant_id: + context.tenant_id = tmp_context_tenant_id + @db_api.retry_if_session_inactive() def delete_security_group(self, context, id): filters = {'security_group_id': [id]} @@ -324,10 +336,7 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase): security_group_id = self._validate_security_group_rules( context, security_group_rules) with db_api.context_manager.writer.using(context): - if not self.get_security_group(context, security_group_id): - raise ext_sg.SecurityGroupNotFound(id=security_group_id) - - self._check_for_duplicate_rules(context, rules) + self._check_for_duplicate_rules(context, security_group_id, rules) ret = [] for rule_dict in rules: res_rule_dict = self._create_security_group_rule( @@ -350,7 +359,8 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase): def _create_security_group_rule(self, context, security_group_rule, validate=True): if validate: - self._validate_security_group_rule(context, security_group_rule) + sg_id = self._validate_security_group_rule(context, + security_group_rule) rule_dict = security_group_rule['security_group_rule'] remote_ip_prefix = rule_dict.get('remote_ip_prefix') if remote_ip_prefix: @@ -390,8 +400,8 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase): exc_cls=ext_sg.SecurityGroupConflict, **kwargs) with db_api.context_manager.writer.using(context): if validate: - self._check_for_duplicate_rules_in_db(context, - security_group_rule) + self._check_for_duplicate_rules(context, sg_id, + [security_group_rule]) sg_rule = sg_obj.SecurityGroupRule(context, **args) sg_rule.create() @@ -521,15 +531,15 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase): remote_group_id = rule['remote_group_id'] # Check that remote_group_id exists for tenant if remote_group_id: - self.get_security_group(context, remote_group_id, - tenant_id=rule['tenant_id']) + self._check_security_group(context, remote_group_id, + tenant_id=rule['tenant_id']) security_group_id = rule['security_group_id'] # Confirm that the tenant has permission # to add rules to this security group. - self.get_security_group(context, security_group_id, - tenant_id=rule['tenant_id']) + self._check_security_group(context, security_group_id, + tenant_id=rule['tenant_id']) return security_group_id def _validate_security_group_rules(self, context, security_group_rules): @@ -572,73 +582,54 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase): res['protocol'] = self._get_ip_proto_name_and_num(value) return res - def _rules_equal(self, rule1, rule2): - """Determines if two rules are equal ignoring id field.""" - rule1_copy = rule1.copy() - rule2_copy = rule2.copy() - rule1_copy.pop('id', None) - rule2_copy.pop('id', None) - return rule1_copy == rule2_copy + def _rule_to_key(self, rule): + def _normalize_rule_value(key, value): + # This string is used as a placeholder for str(None), but shorter. + none_char = '+' - def _check_for_duplicate_rules(self, context, security_group_rules): - for i in security_group_rules: - found_self = False - for j in security_group_rules: - if self._rules_equal(i['security_group_rule'], - j['security_group_rule']): - if found_self: - raise ext_sg.DuplicateSecurityGroupRuleInPost(rule=i) - found_self = True + if key == 'remote_ip_prefix': + all_address = ['0.0.0.0/0', '::/0', None] + if value in all_address: + return none_char + elif value is None: + return none_char + elif key == 'protocol': + return str(self._get_ip_proto_name_and_num(value)) + return str(value) - self._check_for_duplicate_rules_in_db(context, i) + comparison_keys = [ + 'direction', + 'ethertype', + 'port_range_max', + 'port_range_min', + 'protocol', + 'remote_group_id', + 'remote_ip_prefix', + 'security_group_id' + ] + return '_'.join([_normalize_rule_value(x, rule.get(x)) + for x in comparison_keys]) - def _check_for_duplicate_rules_in_db(self, context, security_group_rule): - # Check in database if rule exists - filters = self._make_security_group_rule_filter_dict( - security_group_rule) - rule_dict = security_group_rule['security_group_rule'].copy() - rule_dict.pop('description', None) - keys = rule_dict.keys() - fields = list(keys) + ['id'] - if 'remote_ip_prefix' not in fields: - fields += ['remote_ip_prefix'] - db_rules = self.get_security_group_rules(context, filters, - fields=fields) - # Note(arosen): the call to get_security_group_rules wildcards - # values in the filter that have a value of [None]. For - # example, filters = {'remote_group_id': [None]} will return - # all security group rules regardless of their value of - # remote_group_id. Therefore it is not possible to do this - # query unless the behavior of _get_collection() - # is changed which cannot be because other methods are already - # relying on this behavior. Therefore, we do the filtering - # below to check for these corner cases. - rule_dict.pop('id', None) - sg_protocol = rule_dict.pop('protocol', None) - remote_ip_prefix = rule_dict.pop('remote_ip_prefix', None) - for db_rule in db_rules: - rule_id = db_rule.pop('id', None) - # remove protocol and match separately for number and type - db_protocol = db_rule.pop('protocol', None) - is_protocol_matching = ( - self._get_ip_proto_name_and_num(db_protocol) == - self._get_ip_proto_name_and_num(sg_protocol)) - db_remote_ip_prefix = db_rule.pop('remote_ip_prefix', None) - duplicate_ip_prefix = self._validate_duplicate_ip_prefix( - remote_ip_prefix, db_remote_ip_prefix) - if (is_protocol_matching and duplicate_ip_prefix and - rule_dict == db_rule): - raise ext_sg.SecurityGroupRuleExists(rule_id=rule_id) + def _check_for_duplicate_rules(self, context, security_group_id, + new_security_group_rules): + # First up, check for any duplicates in the new rules. + new_rules_set = set() + for i in new_security_group_rules: + rule_key = self._rule_to_key(i['security_group_rule']) + if rule_key in new_rules_set: + raise ext_sg.DuplicateSecurityGroupRuleInPost(rule=i) + new_rules_set.add(rule_key) - def _validate_duplicate_ip_prefix(self, ip_prefix, other_ip_prefix): - if other_ip_prefix is not None: - other_ip_prefix = str(other_ip_prefix) - all_address = ['0.0.0.0/0', '::/0', None] - if ip_prefix == other_ip_prefix: - return True - elif ip_prefix in all_address and other_ip_prefix in all_address: - return True - return False + # Now, let's make sure none of the new rules conflict with + # existing rules; note that we do *not* store the db rules + # in the set, as we assume they were already checked, + # when added. + sg = self.get_security_group(context, security_group_id) + if sg: + for i in sg['security_group_rules']: + rule_key = self._rule_to_key(i) + if rule_key in new_rules_set: + raise ext_sg.SecurityGroupRuleExists(rule_id=i.get('id')) def _validate_ip_prefix(self, rule): """Check that a valid cidr was specified as remote_ip_prefix diff --git a/neutron/objects/base.py b/neutron/objects/base.py index 584dc6f79c1..a4a5716ddef 100644 --- a/neutron/objects/base.py +++ b/neutron/objects/base.py @@ -709,6 +709,6 @@ class NeutronDbObject(NeutronObject): if validate_filters: cls.validate_filters(**kwargs) # Succeed if at least a single object matches; no need to fetch more - return bool(obj_db_api.get_object( + return bool(obj_db_api.count( context, cls.db_model, **cls.modify_fields_to_db(kwargs)) ) diff --git a/neutron/tests/unit/db/test_securitygroups_db.py b/neutron/tests/unit/db/test_securitygroups_db.py index 1586dbaef0b..43577728be9 100644 --- a/neutron/tests/unit/db/test_securitygroups_db.py +++ b/neutron/tests/unit/db/test_securitygroups_db.py @@ -103,7 +103,7 @@ class SecurityGroupDbMixinTestCase(testlib_api.SqlTestCase): def test_create_security_group_rule_conflict(self): with mock.patch.object(self.mixin, '_validate_security_group_rule'),\ mock.patch.object(self.mixin, - '_check_for_duplicate_rules_in_db'),\ + '_check_for_duplicate_rules'),\ mock.patch.object(registry, "notify") as mock_notify: mock_notify.side_effect = exceptions.CallbackFailure(Exception()) with testtools.ExpectedException( @@ -111,9 +111,9 @@ class SecurityGroupDbMixinTestCase(testlib_api.SqlTestCase): self.mixin.create_security_group_rule( self.ctx, mock.MagicMock()) - def test__check_for_duplicate_rules_in_db_does_not_drop_protocol(self): - with mock.patch.object(self.mixin, 'get_security_group_rules', - return_value=[mock.Mock()]): + def test__check_for_duplicate_rules_does_not_drop_protocol(self): + with mock.patch.object(self.mixin, 'get_security_group', + return_value=None): context = mock.Mock() rule_dict = { 'security_group_rule': {'protocol': None, @@ -121,7 +121,7 @@ class SecurityGroupDbMixinTestCase(testlib_api.SqlTestCase): 'security_group_id': 'fake', 'direction': 'fake'} } - self.mixin._check_for_duplicate_rules_in_db(context, rule_dict) + self.mixin._check_for_duplicate_rules(context, 'fake', [rule_dict]) self.assertIn('protocol', rule_dict['security_group_rule']) def test__check_for_duplicate_rules_ignores_rule_id(self): @@ -132,33 +132,20 @@ class SecurityGroupDbMixinTestCase(testlib_api.SqlTestCase): # in this case as this test, tests that the id fields are dropped # while being compared. This is in the case if a plugin specifies # the rule ids themselves. - self.assertRaises(securitygroup.DuplicateSecurityGroupRuleInPost, - self.mixin._check_for_duplicate_rules, - context, rules) - - def test__check_for_duplicate_rules_in_db_ignores_rule_id(self): - db_rules = {'protocol': 'tcp', 'id': 'fake', 'tenant_id': 'fake', - 'direction': 'ingress', 'security_group_id': 'fake'} - with mock.patch.object(self.mixin, 'get_security_group_rules', - return_value=[db_rules]): - context = mock.Mock() - rule_dict = { - 'security_group_rule': {'protocol': 'tcp', - 'id': 'fake2', - 'tenant_id': 'fake', - 'security_group_id': 'fake', - 'direction': 'ingress'} - } - self.assertRaises(securitygroup.SecurityGroupRuleExists, - self.mixin._check_for_duplicate_rules_in_db, - context, rule_dict) + with mock.patch.object(self.mixin, 'get_security_group', + return_value=None): + self.assertRaises(securitygroup.DuplicateSecurityGroupRuleInPost, + self.mixin._check_for_duplicate_rules, + context, 'fake', rules) def test_check_for_duplicate_diff_rules_remote_ip_prefix_ipv4(self): - db_rules = {'id': 'fake', 'tenant_id': 'fake', 'ethertype': 'IPv4', - 'direction': 'ingress', 'security_group_id': 'fake', - 'remote_ip_prefix': None} - with mock.patch.object(self.mixin, 'get_security_group_rules', - return_value=[db_rules]): + fake_secgroup = copy.deepcopy(FAKE_SECGROUP) + fake_secgroup['security_group_rules'] = \ + [{'id': 'fake', 'tenant_id': 'fake', 'ethertype': 'IPv4', + 'direction': 'ingress', 'security_group_id': 'fake', + 'remote_ip_prefix': None}] + with mock.patch.object(self.mixin, 'get_security_group', + return_value=fake_secgroup): context = mock.Mock() rule_dict = { 'security_group_rule': {'id': 'fake2', @@ -169,15 +156,17 @@ class SecurityGroupDbMixinTestCase(testlib_api.SqlTestCase): 'remote_ip_prefix': '0.0.0.0/0'} } self.assertRaises(securitygroup.SecurityGroupRuleExists, - self.mixin._check_for_duplicate_rules_in_db, - context, rule_dict) + self.mixin._check_for_duplicate_rules, + context, 'fake', [rule_dict]) def test_check_for_duplicate_diff_rules_remote_ip_prefix_ipv6(self): - db_rules = {'id': 'fake', 'tenant_id': 'fake', 'ethertype': 'IPv6', - 'direction': 'ingress', 'security_group_id': 'fake', - 'remote_ip_prefix': None} - with mock.patch.object(self.mixin, 'get_security_group_rules', - return_value=[db_rules]): + fake_secgroup = copy.deepcopy(FAKE_SECGROUP) + fake_secgroup['security_group_rules'] = \ + [{'id': 'fake', 'tenant_id': 'fake', 'ethertype': 'IPv6', + 'direction': 'ingress', 'security_group_id': 'fake', + 'remote_ip_prefix': None}] + with mock.patch.object(self.mixin, 'get_security_group', + return_value=fake_secgroup): context = mock.Mock() rule_dict = { 'security_group_rule': {'id': 'fake2', @@ -188,8 +177,8 @@ class SecurityGroupDbMixinTestCase(testlib_api.SqlTestCase): 'remote_ip_prefix': '::/0'} } self.assertRaises(securitygroup.SecurityGroupRuleExists, - self.mixin._check_for_duplicate_rules_in_db, - context, rule_dict) + self.mixin._check_for_duplicate_rules, + context, 'fake', [rule_dict]) def test_delete_security_group_rule_in_use(self): with mock.patch.object(registry, "notify") as mock_notify: diff --git a/neutron/tests/unit/extensions/test_securitygroup.py b/neutron/tests/unit/extensions/test_securitygroup.py index 4e8b23b41ad..524b98a425e 100644 --- a/neutron/tests/unit/extensions/test_securitygroup.py +++ b/neutron/tests/unit/extensions/test_securitygroup.py @@ -129,6 +129,10 @@ class SecurityGroupsTestCase(test_db_base_plugin_v2.NeutronDbPluginV2TestCase): # create a specific auth context for this request security_group_rule_req.environ['neutron.context'] = ( context.Context('', kwargs['tenant_id'])) + elif kwargs.get('admin_context'): + security_group_rule_req.environ['neutron.context'] = ( + context.Context(user_id='admin', tenant_id='admin-tenant', + is_admin=True)) return security_group_rule_req.get_response(self.ext_api) def _make_security_group(self, fmt, name, description, **kwargs): @@ -695,6 +699,50 @@ class TestSecurityGroups(SecurityGroupDBTestCase): for k, v, in keys: self.assertEqual(sg_rule[0][k], v) + # This test case checks that admins from a different tenant can add rules + # as themselves. This is an odd behavior, with some weird GET semantics, + # but this test is checking that we don't break that old behavior, at least + # until we make a conscious choice to do so. + def test_create_security_group_rules_admin_tenant(self): + name = 'webservers' + description = 'my webservers' + with self.security_group(name, description) as sg: + # Add a couple normal rules + rule = self._build_security_group_rule( + sg['security_group']['id'], "ingress", const.PROTO_NAME_TCP, + port_range_min=22, port_range_max=22, + remote_ip_prefix="10.0.0.0/24", + ethertype=const.IPv4) + self._make_security_group_rule(self.fmt, rule) + + rule = self._build_security_group_rule( + sg['security_group']['id'], "ingress", const.PROTO_NAME_TCP, + port_range_min=22, port_range_max=22, + remote_ip_prefix="10.0.1.0/24", + ethertype=const.IPv4) + self._make_security_group_rule(self.fmt, rule) + + # Let's add a rule as admin, with a different tenant_id. The + # results of this call are arguably a bug, but it is past behavior. + rule = self._build_security_group_rule( + sg['security_group']['id'], "ingress", const.PROTO_NAME_TCP, + port_range_min=22, port_range_max=22, + remote_ip_prefix="10.0.2.0/24", + ethertype=const.IPv4, + tenant_id='admin-tenant') + self._make_security_group_rule(self.fmt, rule, admin_context=True) + + # Now, let's make sure all the rules are there, with their odd + # tenant_id behavior. + res = self.new_list_request('security-groups') + sgs = self.deserialize(self.fmt, res.get_response(self.ext_api)) + for sg in sgs['security_groups']: + if sg['name'] == "webservers": + rules = sg['security_group_rules'] + self.assertEqual(len(rules), 5) + self.assertNotEqual(rules[3]['tenant_id'], 'admin-tenant') + self.assertEqual(rules[4]['tenant_id'], 'admin-tenant') + def test_get_security_group_on_port_from_wrong_tenant(self): plugin = directory.get_plugin() if not hasattr(plugin, '_get_security_groups_on_port'):