Merge "SG protocol validation to allow numbers or names" into stable/liberty

This commit is contained in:
Jenkins 2016-04-01 15:15:43 +00:00 committed by Gerrit Code Review
commit 7b088a2694
4 changed files with 87 additions and 10 deletions

View File

@ -130,6 +130,13 @@ PROTO_NUM_ICMP = 1
PROTO_NUM_ICMP_V6 = 58
PROTO_NUM_UDP = 17
IP_PROTOCOL_MAP = {PROTO_NAME_TCP: PROTO_NUM_TCP,
PROTO_NAME_UDP: PROTO_NUM_UDP,
PROTO_NAME_ICMP: PROTO_NUM_ICMP,
PROTO_NAME_ICMP_V6: PROTO_NUM_ICMP_V6}
IP_PROTOCOL_NUM_TO_NAME_MAP = {str(v): k for k, v in IP_PROTOCOL_MAP.items()}
# List of ICMPv6 types that should be allowed by default:
# Multicast Listener Query (130),
# Multicast Listener Report (131),

View File

@ -37,11 +37,6 @@ from neutron.extensions import securitygroup as ext_sg
LOG = logging.getLogger(__name__)
IP_PROTOCOL_MAP = {constants.PROTO_NAME_TCP: constants.PROTO_NUM_TCP,
constants.PROTO_NAME_UDP: constants.PROTO_NUM_UDP,
constants.PROTO_NAME_ICMP: constants.PROTO_NUM_ICMP,
constants.PROTO_NAME_ICMP_V6: constants.PROTO_NUM_ICMP_V6}
class SecurityGroup(model_base.BASEV2, models_v2.HasId, models_v2.HasTenant):
"""Represents a v2 neutron security group."""
@ -418,7 +413,18 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase):
# problems with comparing int and string in PostgreSQL. Here this
# string is converted to int to give an opportunity to use it as
# before.
return int(IP_PROTOCOL_MAP.get(protocol, protocol))
return int(constants.IP_PROTOCOL_MAP.get(protocol, protocol))
def _get_ip_proto_name_and_num(self, protocol):
if protocol is None:
return
protocol = str(protocol)
if protocol in constants.IP_PROTOCOL_MAP:
return [protocol, str(constants.IP_PROTOCOL_MAP.get(protocol))]
elif protocol in constants.IP_PROTOCOL_NUM_TO_NAME_MAP:
return [constants.IP_PROTOCOL_NUM_TO_NAME_MAP.get(protocol),
protocol]
return [protocol, protocol]
def _validate_port_range(self, rule):
"""Check that port_range is valid."""
@ -525,6 +531,10 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase):
value = sgr.get(key)
if value:
res[key] = [value]
# protocol field will get corresponding name and number
value = sgr.get('protocol')
if value:
res['protocol'] = self._get_ip_proto_name_and_num(value)
return res
def _check_for_duplicate_rules(self, context, security_group_rules):
@ -553,10 +563,23 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase):
# relying on this behavior. Therefore, we do the filtering
# below to check for these corner cases.
for db_rule in db_rules:
# need to remove id from db_rule for matching
id = db_rule.pop('id')
if (security_group_rule['security_group_rule'] == db_rule):
raise ext_sg.SecurityGroupRuleExists(id=id)
rule_id = db_rule.pop('id', None)
# remove protocol and match separately for number and type
db_protocol = db_rule.pop('protocol', None)
sg_protocol = (
security_group_rule['security_group_rule'].pop('protocol',
None))
is_protocol_matching = (
self._get_ip_proto_name_and_num(db_protocol) ==
self._get_ip_proto_name_and_num(sg_protocol))
are_rules_matching = (
security_group_rule['security_group_rule'] == db_rule)
# reinstate protocol field for further processing
if sg_protocol:
security_group_rule['security_group_rule']['protocol'] = (
sg_protocol)
if (is_protocol_matching and are_rules_matching):
raise ext_sg.SecurityGroupRuleExists(id=rule_id)
def _validate_ip_prefix(self, rule):
"""Check that a valid cidr was specified as remote_ip_prefix

View File

@ -16,6 +16,7 @@ import testtools
from neutron.callbacks import exceptions
from neutron.callbacks import registry
from neutron.common import constants
from neutron import context
from neutron.db import common_db_mixin
from neutron.db import securitygroups_db
@ -83,3 +84,15 @@ class SecurityGroupDbMixinTestCase(testlib_api.SqlTestCase):
with testtools.ExpectedException(
securitygroup.SecurityGroupRuleNotFound):
self.mixin.delete_security_group_rule(self.ctx, 'foo_rule')
def test_get_ip_proto_name_and_num(self):
protocols = [constants.PROTO_NAME_UDP, str(constants.PROTO_NUM_TCP),
'blah', '111']
protocol_names_nums = (
[[constants.PROTO_NAME_UDP, str(constants.PROTO_NUM_UDP)],
[constants.PROTO_NAME_TCP, str(constants.PROTO_NUM_TCP)],
['blah', 'blah'], ['111', '111']])
for i, protocol in enumerate(protocols):
self.assertEqual(protocol_names_nums[i],
self.mixin._get_ip_proto_name_and_num(protocol))

View File

@ -954,6 +954,40 @@ class TestSecurityGroups(SecurityGroupDBTestCase):
self.deserialize(self.fmt, res)
self.assertEqual(res.status_int, webob.exc.HTTPConflict.code)
def test_create_security_group_rule_duplicate_rules_proto_name_num(self):
name = 'webservers'
description = 'my webservers'
with self.security_group(name, description) as sg:
security_group_id = sg['security_group']['id']
with self.security_group_rule(security_group_id):
rule = self._build_security_group_rule(
sg['security_group']['id'], 'ingress',
const.PROTO_NAME_TCP, '22', '22')
self._create_security_group_rule(self.fmt, rule)
rule = self._build_security_group_rule(
sg['security_group']['id'], 'ingress',
const.PROTO_NUM_TCP, '22', '22')
res = self._create_security_group_rule(self.fmt, rule)
self.deserialize(self.fmt, res)
self.assertEqual(webob.exc.HTTPConflict.code, res.status_int)
def test_create_security_group_rule_duplicate_rules_proto_num_name(self):
name = 'webservers'
description = 'my webservers'
with self.security_group(name, description) as sg:
security_group_id = sg['security_group']['id']
with self.security_group_rule(security_group_id):
rule = self._build_security_group_rule(
sg['security_group']['id'], 'ingress',
const.PROTO_NUM_UDP, '50', '100')
self._create_security_group_rule(self.fmt, rule)
rule = self._build_security_group_rule(
sg['security_group']['id'], 'ingress',
const.PROTO_NAME_UDP, '50', '100')
res = self._create_security_group_rule(self.fmt, rule)
self.deserialize(self.fmt, res)
self.assertEqual(webob.exc.HTTPConflict.code, res.status_int)
def test_create_security_group_rule_min_port_greater_max(self):
name = 'webservers'
description = 'my webservers'