diff --git a/neutron/db/securitygroups_db.py b/neutron/db/securitygroups_db.py index 35240999442..33ca2a5cd9a 100644 --- a/neutron/db/securitygroups_db.py +++ b/neutron/db/securitygroups_db.py @@ -465,6 +465,14 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase): rule['port_range_max'] is not None): raise ext_sg.SecurityGroupMissingIcmpType( value=rule['port_range_max']) + else: + # Only the protocols above support port ranges, raise otherwise. + # When min/max are the same it is just a single port. + if (rule['port_range_min'] is not None and + rule['port_range_max'] is not None and + rule['port_range_min'] != rule['port_range_max']): + raise ext_sg.SecurityGroupInvalidProtocolForPortRange( + protocol=ip_proto) def _validate_ethertype_and_protocol(self, rule): """Check if given ethertype and protocol are valid or not""" diff --git a/neutron/extensions/securitygroup.py b/neutron/extensions/securitygroup.py index a4b5172d40e..4cd376a7973 100644 --- a/neutron/extensions/securitygroup.py +++ b/neutron/extensions/securitygroup.py @@ -40,6 +40,11 @@ class SecurityGroupInvalidPortRange(nexception.InvalidInput): "<= port_range_max") +class SecurityGroupInvalidProtocolForPortRange(nexception.InvalidInput): + message = _("Invalid protocol %(protocol)s for port range, only " + "supported for TCP and UDP.") + + class SecurityGroupInvalidPortValue(nexception.InvalidInput): message = _("Invalid value for port %(port)s") diff --git a/neutron/tests/unit/db/test_securitygroups_db.py b/neutron/tests/unit/db/test_securitygroups_db.py index 85c52094acb..1586dbaef0b 100644 --- a/neutron/tests/unit/db/test_securitygroups_db.py +++ b/neutron/tests/unit/db/test_securitygroups_db.py @@ -440,3 +440,16 @@ class SecurityGroupDbMixinTestCase(testlib_api.SqlTestCase): {'port_range_min': pmin, 'port_range_max': pmax, 'protocol': protocol}) + + def test__validate_port_range_exception(self): + self.assertRaises(securitygroup.SecurityGroupInvalidPortValue, + self.mixin._validate_port_range, + {'port_range_min': 0, + 'port_range_max': None, + 'protocol': constants.PROTO_NAME_TCP}) + self.assertRaises( + securitygroup.SecurityGroupInvalidProtocolForPortRange, + self.mixin._validate_port_range, + {'port_range_min': 100, + 'port_range_max': 200, + 'protocol': '111'}) diff --git a/neutron/tests/unit/extensions/test_securitygroup.py b/neutron/tests/unit/extensions/test_securitygroup.py index 9bbd4839868..4e8b23b41ad 100644 --- a/neutron/tests/unit/extensions/test_securitygroup.py +++ b/neutron/tests/unit/extensions/test_securitygroup.py @@ -600,6 +600,45 @@ class TestSecurityGroups(SecurityGroupDBTestCase): self.deserialize(self.fmt, res) self.assertEqual(webob.exc.HTTPCreated.code, res.status_int) + def test_create_security_group_rule_protocol_as_number_with_port(self): + name = 'webservers' + description = 'my webservers' + with self.security_group(name, description) as sg: + security_group_id = sg['security_group']['id'] + protocol = 111 + rule = self._build_security_group_rule( + security_group_id, 'ingress', protocol, '70') + res = self._create_security_group_rule(self.fmt, rule) + self.deserialize(self.fmt, res) + self.assertEqual(webob.exc.HTTPCreated.code, res.status_int) + + def test_create_security_group_rule_protocol_as_number_range(self): + # This is a SG rule with a port range, but treated as a single + # port since min/max are the same. + name = 'webservers' + description = 'my webservers' + with self.security_group(name, description) as sg: + security_group_id = sg['security_group']['id'] + protocol = 111 + rule = self._build_security_group_rule( + security_group_id, 'ingress', protocol, '70', '70') + res = self._create_security_group_rule(self.fmt, rule) + self.deserialize(self.fmt, res) + self.assertEqual(webob.exc.HTTPCreated.code, res.status_int) + + def test_create_security_group_rule_protocol_as_number_range_bad(self): + # Only certain protocols support a SG rule with a port range + name = 'webservers' + description = 'my webservers' + with self.security_group(name, description) as sg: + security_group_id = sg['security_group']['id'] + protocol = 111 + rule = self._build_security_group_rule( + security_group_id, 'ingress', protocol, '70', '71') + res = self._create_security_group_rule(self.fmt, rule) + self.deserialize(self.fmt, res) + self.assertEqual(webob.exc.HTTPBadRequest.code, res.status_int) + def test_create_security_group_rule_case_insensitive(self): name = 'webservers' description = 'my webservers'