Limit min<=max port check to TCP/UDP in secgroup rule

icmp_type and icmp_code are mapped to port_min_range and port_max_range
respectively. For ICMP there is no constraint between type and code.
Thus port range min<=max check should be enforced only for TCP and UDP.

Also makes sure that ICMP type/code are 0 to 255 (both inclusive).
Previously a value with 0 to 65535 were accepted for ICMP type/code.

Fixes bug 1197760
Fixes bug 1197769

Change-Id: I70aaf6e02fee461fa97dc254db906d9efa173669
This commit is contained in:
Akihiro MOTOKI 2013-07-05 01:27:18 +09:00
parent 93efc1dd78
commit 24e6ef332d
4 changed files with 152 additions and 16 deletions

View File

@ -34,7 +34,10 @@ INTERFACE_KEY = '_interfaces'
IPv4 = 'IPv4' IPv4 = 'IPv4'
IPv6 = 'IPv6' IPv6 = 'IPv6'
ICMP_PROTOCOL = 1
TCP_PROTOCOL = 6
UDP_PROTOCOL = 17 UDP_PROTOCOL = 17
DHCP_RESPONSE_PORT = 68 DHCP_RESPONSE_PORT = 68
MIN_VLAN_TAG = 1 MIN_VLAN_TAG = 1

View File

@ -22,6 +22,7 @@ from sqlalchemy.orm import exc
from sqlalchemy.orm import scoped_session from sqlalchemy.orm import scoped_session
from neutron.api.v2 import attributes as attr from neutron.api.v2 import attributes as attr
from neutron.common import constants
from neutron.db import db_base_plugin_v2 from neutron.db import db_base_plugin_v2
from neutron.db import model_base from neutron.db import model_base
from neutron.db import models_v2 from neutron.db import models_v2
@ -29,6 +30,11 @@ from neutron.extensions import securitygroup as ext_sg
from neutron.openstack.common import uuidutils from neutron.openstack.common import uuidutils
IP_PROTOCOL_MAP = {'tcp': constants.TCP_PROTOCOL,
'udp': constants.UDP_PROTOCOL,
'icmp': constants.ICMP_PROTOCOL}
class SecurityGroup(model_base.BASEV2, models_v2.HasId, models_v2.HasTenant): class SecurityGroup(model_base.BASEV2, models_v2.HasId, models_v2.HasTenant):
"""Represents a v2 neutron security group.""" """Represents a v2 neutron security group."""
@ -284,6 +290,32 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase):
return self.create_security_group_rule_bulk_native(context, return self.create_security_group_rule_bulk_native(context,
bulk_rule)[0] bulk_rule)[0]
def _get_ip_proto_number(self, protocol):
if protocol is None:
return
return IP_PROTOCOL_MAP.get(protocol, protocol)
def _validate_port_range(self, rule):
"""Check that port_range is valid."""
if (rule['port_range_min'] is None and
rule['port_range_max'] is None):
return
if not rule['protocol']:
raise ext_sg.SecurityGroupProtocolRequiredWithPorts()
ip_proto = self._get_ip_proto_number(rule['protocol'])
if ip_proto in [constants.TCP_PROTOCOL, constants.UDP_PROTOCOL]:
if (rule['port_range_min'] is not None and
rule['port_range_min'] <= rule['port_range_max']):
pass
else:
raise ext_sg.SecurityGroupInvalidPortRange()
elif ip_proto == constants.ICMP_PROTOCOL:
for attr, field in [('port_range_min', 'type'),
('port_range_max', 'code')]:
if rule[attr] > 255:
raise ext_sg.SecurityGroupInvalidIcmpValue(
field=field, attr=attr, value=rule[attr])
def _validate_security_group_rules(self, context, security_group_rule): def _validate_security_group_rules(self, context, security_group_rule):
"""Check that rules being installed. """Check that rules being installed.
@ -297,16 +329,7 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase):
rule = rules.get('security_group_rule') rule = rules.get('security_group_rule')
new_rules.add(rule['security_group_id']) new_rules.add(rule['security_group_id'])
# Check that port_range's are valid self._validate_port_range(rule)
if (rule['port_range_min'] is None and
rule['port_range_max'] is None):
pass
elif (rule['port_range_min'] is not None and
rule['port_range_min'] <= rule['port_range_max']):
if not rule['protocol']:
raise ext_sg.SecurityGroupProtocolRequiredWithPorts()
else:
raise ext_sg.SecurityGroupInvalidPortRange()
if rule['remote_ip_prefix'] and rule['remote_group_id']: if rule['remote_ip_prefix'] and rule['remote_group_id']:
raise ext_sg.SecurityGroupRemoteGroupAndRemoteIpPrefix() raise ext_sg.SecurityGroupRemoteGroupAndRemoteIpPrefix()

View File

@ -39,6 +39,11 @@ class SecurityGroupInvalidPortValue(qexception.InvalidInput):
message = _("Invalid value for port %(port)s") message = _("Invalid value for port %(port)s")
class SecurityGroupInvalidIcmpValue(qexception.InvalidInput):
message = _("Invalid value for ICMP %(field)s (%(attr)s) "
"%(value)s. It must be 0 to 255.")
class SecurityGroupInUse(qexception.InUse): class SecurityGroupInUse(qexception.InUse):
message = _("Security Group %(id)s in use.") message = _("Security Group %(id)s in use.")

View File

@ -633,6 +633,57 @@ class TestSecurityGroups(SecurityGroupDBTestCase):
for k, v, in keys: for k, v, in keys:
self.assertEqual(rule['security_group_rule'][k], v) self.assertEqual(rule['security_group_rule'][k], v)
def test_create_security_group_rule_icmp_with_type_and_code(self):
name = 'webservers'
description = 'my webservers'
with self.security_group(name, description) as sg:
security_group_id = sg['security_group']['id']
direction = "ingress"
remote_ip_prefix = "10.0.0.0/24"
protocol = 'icmp'
# port_range_min (ICMP type) is greater than port_range_max
# (ICMP code) in order to confirm min <= max port check is
# not called for ICMP.
port_range_min = 8
port_range_max = 5
keys = [('remote_ip_prefix', remote_ip_prefix),
('security_group_id', security_group_id),
('direction', direction),
('protocol', protocol),
('port_range_min', port_range_min),
('port_range_max', port_range_max)]
with self.security_group_rule(security_group_id, direction,
protocol, port_range_min,
port_range_max,
remote_ip_prefix) as rule:
for k, v, in keys:
self.assertEqual(rule['security_group_rule'][k], v)
def test_create_security_group_rule_icmp_with_type_only(self):
name = 'webservers'
description = 'my webservers'
with self.security_group(name, description) as sg:
security_group_id = sg['security_group']['id']
direction = "ingress"
remote_ip_prefix = "10.0.0.0/24"
protocol = 'icmp'
# ICMP type
port_range_min = 8
# ICMP code
port_range_max = None
keys = [('remote_ip_prefix', remote_ip_prefix),
('security_group_id', security_group_id),
('direction', direction),
('protocol', protocol),
('port_range_min', port_range_min),
('port_range_max', port_range_max)]
with self.security_group_rule(security_group_id, direction,
protocol, port_range_min,
port_range_max,
remote_ip_prefix) as rule:
for k, v, in keys:
self.assertEqual(rule['security_group_rule'][k], v)
def test_create_security_group_source_group_ip_and_ip_prefix(self): def test_create_security_group_source_group_ip_and_ip_prefix(self):
security_group_id = "4cd70774-cc67-4a87-9b39-7d1db38eb087" security_group_id = "4cd70774-cc67-4a87-9b39-7d1db38eb087"
direction = "ingress" direction = "ingress"
@ -757,12 +808,14 @@ class TestSecurityGroups(SecurityGroupDBTestCase):
with self.security_group(name, description) as sg: with self.security_group(name, description) as sg:
security_group_id = sg['security_group']['id'] security_group_id = sg['security_group']['id']
with self.security_group_rule(security_group_id): with self.security_group_rule(security_group_id):
rule = self._build_security_group_rule( for protocol in ['tcp', 'udp', 6, 17]:
sg['security_group']['id'], 'ingress', 'tcp', '50', '22') rule = self._build_security_group_rule(
self._create_security_group_rule(self.fmt, rule) sg['security_group']['id'],
res = self._create_security_group_rule(self.fmt, rule) 'ingress', protocol, '50', '22')
self.deserialize(self.fmt, res) self._create_security_group_rule(self.fmt, rule)
self.assertEqual(res.status_int, 400) res = self._create_security_group_rule(self.fmt, rule)
self.deserialize(self.fmt, res)
self.assertEqual(res.status_int, 400)
def test_create_security_group_rule_ports_but_no_protocol(self): def test_create_security_group_rule_ports_but_no_protocol(self):
name = 'webservers' name = 'webservers'
@ -777,6 +830,58 @@ class TestSecurityGroups(SecurityGroupDBTestCase):
self.deserialize(self.fmt, res) self.deserialize(self.fmt, res)
self.assertEqual(res.status_int, 400) self.assertEqual(res.status_int, 400)
def test_create_security_group_rule_port_range_min_only(self):
name = 'webservers'
description = 'my webservers'
with self.security_group(name, description) as sg:
security_group_id = sg['security_group']['id']
with self.security_group_rule(security_group_id):
rule = self._build_security_group_rule(
sg['security_group']['id'], 'ingress', 'tcp', '22', None)
self._create_security_group_rule(self.fmt, rule)
res = self._create_security_group_rule(self.fmt, rule)
self.deserialize(self.fmt, res)
self.assertEqual(res.status_int, 400)
def test_create_security_group_rule_port_range_max_only(self):
name = 'webservers'
description = 'my webservers'
with self.security_group(name, description) as sg:
security_group_id = sg['security_group']['id']
with self.security_group_rule(security_group_id):
rule = self._build_security_group_rule(
sg['security_group']['id'], 'ingress', 'tcp', None, '22')
self._create_security_group_rule(self.fmt, rule)
res = self._create_security_group_rule(self.fmt, rule)
self.deserialize(self.fmt, res)
self.assertEqual(res.status_int, 400)
def test_create_security_group_rule_icmp_type_too_big(self):
name = 'webservers'
description = 'my webservers'
with self.security_group(name, description) as sg:
security_group_id = sg['security_group']['id']
with self.security_group_rule(security_group_id):
rule = self._build_security_group_rule(
sg['security_group']['id'], 'ingress', 'icmp', '256', None)
self._create_security_group_rule(self.fmt, rule)
res = self._create_security_group_rule(self.fmt, rule)
self.deserialize(self.fmt, res)
self.assertEqual(res.status_int, 400)
def test_create_security_group_rule_icmp_code_too_big(self):
name = 'webservers'
description = 'my webservers'
with self.security_group(name, description) as sg:
security_group_id = sg['security_group']['id']
with self.security_group_rule(security_group_id):
rule = self._build_security_group_rule(
sg['security_group']['id'], 'ingress', 'icmp', '8', '256')
self._create_security_group_rule(self.fmt, rule)
res = self._create_security_group_rule(self.fmt, rule)
self.deserialize(self.fmt, res)
self.assertEqual(res.status_int, 400)
def test_list_ports_security_group(self): def test_list_ports_security_group(self):
with self.network() as n: with self.network() as n:
with self.subnet(n): with self.subnet(n):