From 913a64cc1175b3bd7efc7abe34895c32bf39a696 Mon Sep 17 00:00:00 2001 From: Sreekumar S Date: Fri, 22 Jan 2016 19:09:49 +0530 Subject: [PATCH] SG protocol validation to allow numbers or names SG rule protocol provided is validated against the DB rules' protocols for both number and name. The filter provided to DB is modified so that it is queried for records with both the protocol name and number, instead of exactly the type provided with the input. The returned DB rule record's protocol field is validated against the supplied SG protocol field for both name or number. This way, user is still allowed to enter protocol name or number to create a rule, and API compatibility is maintained. Change-Id: If4ad684e961433b8d9d3ec8fe2810585d3f6a093 Closes-Bug: #1215181 --- neutron/common/constants.py | 2 ++ neutron/db/securitygroups_db.py | 31 ++++++++++++++++- .../tests/unit/db/test_securitygroups_db.py | 12 +++++++ .../unit/extensions/test_securitygroup.py | 34 +++++++++++++++++++ 4 files changed, 78 insertions(+), 1 deletion(-) diff --git a/neutron/common/constants.py b/neutron/common/constants.py index 410012ca360..b2d9cd64e5a 100644 --- a/neutron/common/constants.py +++ b/neutron/common/constants.py @@ -131,6 +131,8 @@ IP_PROTOCOL_NAME_ALIASES = {PROTO_NAME_IPV6_ICMP_LEGACY: PROTO_NAME_IPV6_ICMP} VALID_DSCP_MARKS = [0, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 38, 40, 46, 48, 56] +IP_PROTOCOL_NUM_TO_NAME_MAP = {str(v): k for k, v in IP_PROTOCOL_MAP.items()} + # List of ICMPv6 types that should be allowed by default: # Multicast Listener Query (130), # Multicast Listener Report (131), diff --git a/neutron/db/securitygroups_db.py b/neutron/db/securitygroups_db.py index 2758a9d717f..81b2ef91edf 100644 --- a/neutron/db/securitygroups_db.py +++ b/neutron/db/securitygroups_db.py @@ -424,6 +424,17 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase): protocol = constants.IP_PROTOCOL_NAME_ALIASES[protocol] return int(constants.IP_PROTOCOL_MAP.get(protocol, protocol)) + def _get_ip_proto_name_and_num(self, protocol): + if protocol is None: + return + protocol = str(protocol) + if protocol in constants.IP_PROTOCOL_MAP: + return [protocol, str(constants.IP_PROTOCOL_MAP.get(protocol))] + elif protocol in constants.IP_PROTOCOL_NUM_TO_NAME_MAP: + return [constants.IP_PROTOCOL_NUM_TO_NAME_MAP.get(protocol), + protocol] + return [protocol, protocol] + def _validate_port_range(self, rule): """Check that port_range is valid.""" if (rule['port_range_min'] is None and @@ -539,6 +550,10 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase): value = sgr.get(key) if value: res[key] = [value] + # protocol field will get corresponding name and number + value = sgr.get('protocol') + if value: + res['protocol'] = self._get_ip_proto_name_and_num(value) return res def _check_for_duplicate_rules(self, context, security_group_rules): @@ -571,7 +586,21 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase): # below to check for these corner cases. for db_rule in db_rules: rule_id = db_rule.pop('id', None) - if (security_group_rule['security_group_rule'] == db_rule): + # remove protocol and match separately for number and type + db_protocol = db_rule.pop('protocol', None) + sg_protocol = ( + security_group_rule['security_group_rule'].pop('protocol', + None)) + is_protocol_matching = ( + self._get_ip_proto_name_and_num(db_protocol) == + self._get_ip_proto_name_and_num(sg_protocol)) + are_rules_matching = ( + security_group_rule['security_group_rule'] == db_rule) + # reinstate protocol field for further processing + if sg_protocol: + security_group_rule['security_group_rule']['protocol'] = ( + sg_protocol) + if (is_protocol_matching and are_rules_matching): raise ext_sg.SecurityGroupRuleExists(rule_id=rule_id) def _validate_ip_prefix(self, rule): diff --git a/neutron/tests/unit/db/test_securitygroups_db.py b/neutron/tests/unit/db/test_securitygroups_db.py index 473118c5e72..248a1c784da 100644 --- a/neutron/tests/unit/db/test_securitygroups_db.py +++ b/neutron/tests/unit/db/test_securitygroups_db.py @@ -235,3 +235,15 @@ class SecurityGroupDbMixinTestCase(testlib_api.SqlTestCase): mock_notify.assert_has_calls([mock.call('security_group_rule', 'precommit_delete', mock.ANY, context=mock.ANY, security_group_rule_id=mock.ANY)]) + + def test_get_ip_proto_name_and_num(self): + protocols = [constants.PROTO_NAME_UDP, str(constants.PROTO_NUM_TCP), + 'blah', '111'] + protocol_names_nums = ( + [[constants.PROTO_NAME_UDP, str(constants.PROTO_NUM_UDP)], + [constants.PROTO_NAME_TCP, str(constants.PROTO_NUM_TCP)], + ['blah', 'blah'], ['111', '111']]) + + for i, protocol in enumerate(protocols): + self.assertEqual(protocol_names_nums[i], + self.mixin._get_ip_proto_name_and_num(protocol)) diff --git a/neutron/tests/unit/extensions/test_securitygroup.py b/neutron/tests/unit/extensions/test_securitygroup.py index 06cbdfafb96..c0e99852ccf 100644 --- a/neutron/tests/unit/extensions/test_securitygroup.py +++ b/neutron/tests/unit/extensions/test_securitygroup.py @@ -978,6 +978,40 @@ class TestSecurityGroups(SecurityGroupDBTestCase): self.assertIn(sgr['security_group_rule']['id'], res.json['NeutronError']['message']) + def test_create_security_group_rule_duplicate_rules_proto_name_num(self): + name = 'webservers' + description = 'my webservers' + with self.security_group(name, description) as sg: + security_group_id = sg['security_group']['id'] + with self.security_group_rule(security_group_id): + rule = self._build_security_group_rule( + sg['security_group']['id'], 'ingress', + const.PROTO_NAME_TCP, '22', '22') + self._create_security_group_rule(self.fmt, rule) + rule = self._build_security_group_rule( + sg['security_group']['id'], 'ingress', + const.PROTO_NUM_TCP, '22', '22') + res = self._create_security_group_rule(self.fmt, rule) + self.deserialize(self.fmt, res) + self.assertEqual(webob.exc.HTTPConflict.code, res.status_int) + + def test_create_security_group_rule_duplicate_rules_proto_num_name(self): + name = 'webservers' + description = 'my webservers' + with self.security_group(name, description) as sg: + security_group_id = sg['security_group']['id'] + with self.security_group_rule(security_group_id): + rule = self._build_security_group_rule( + sg['security_group']['id'], 'ingress', + const.PROTO_NUM_UDP, '50', '100') + self._create_security_group_rule(self.fmt, rule) + rule = self._build_security_group_rule( + sg['security_group']['id'], 'ingress', + const.PROTO_NAME_UDP, '50', '100') + res = self._create_security_group_rule(self.fmt, rule) + self.deserialize(self.fmt, res) + self.assertEqual(webob.exc.HTTPConflict.code, res.status_int) + def test_create_security_group_rule_min_port_greater_max(self): name = 'webservers' description = 'my webservers'