Merge "SG protocol validation to allow numbers or names" into stable/mitaka

This commit is contained in:
Jenkins 2016-04-12 11:03:50 +00:00 committed by Gerrit Code Review
commit d791170f55
4 changed files with 84 additions and 1 deletions

View File

@ -128,6 +128,8 @@ IP_PROTOCOL_MAP = {PROTO_NAME_AH: PROTO_NUM_AH,
IP_PROTOCOL_NAME_ALIASES = {PROTO_NAME_IPV6_ICMP_LEGACY: PROTO_NAME_IPV6_ICMP}
IP_PROTOCOL_NUM_TO_NAME_MAP = {str(v): k for k, v in IP_PROTOCOL_MAP.items()}
# List of ICMPv6 types that should be allowed by default:
# Multicast Listener Query (130),
# Multicast Listener Report (131),

View File

@ -424,6 +424,17 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase):
protocol = constants.IP_PROTOCOL_NAME_ALIASES[protocol]
return int(constants.IP_PROTOCOL_MAP.get(protocol, protocol))
def _get_ip_proto_name_and_num(self, protocol):
if protocol is None:
return
protocol = str(protocol)
if protocol in constants.IP_PROTOCOL_MAP:
return [protocol, str(constants.IP_PROTOCOL_MAP.get(protocol))]
elif protocol in constants.IP_PROTOCOL_NUM_TO_NAME_MAP:
return [constants.IP_PROTOCOL_NUM_TO_NAME_MAP.get(protocol),
protocol]
return [protocol, protocol]
def _validate_port_range(self, rule):
"""Check that port_range is valid."""
if (rule['port_range_min'] is None and
@ -539,6 +550,10 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase):
value = sgr.get(key)
if value:
res[key] = [value]
# protocol field will get corresponding name and number
value = sgr.get('protocol')
if value:
res['protocol'] = self._get_ip_proto_name_and_num(value)
return res
def _check_for_duplicate_rules(self, context, security_group_rules):
@ -569,9 +584,16 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase):
# is changed which cannot be because other methods are already
# relying on this behavior. Therefore, we do the filtering
# below to check for these corner cases.
rule_dict = security_group_rule['security_group_rule'].copy()
sg_protocol = rule_dict.pop('protocol', None)
for db_rule in db_rules:
rule_id = db_rule.pop('id', None)
if (security_group_rule['security_group_rule'] == db_rule):
# remove protocol and match separately for number and type
db_protocol = db_rule.pop('protocol', None)
is_protocol_matching = (
self._get_ip_proto_name_and_num(db_protocol) ==
self._get_ip_proto_name_and_num(sg_protocol))
if (is_protocol_matching and rule_dict == db_rule):
raise ext_sg.SecurityGroupRuleExists(rule_id=rule_id)
def _validate_ip_prefix(self, rule):

View File

@ -91,6 +91,19 @@ class SecurityGroupDbMixinTestCase(testlib_api.SqlTestCase):
self.mixin.create_security_group_rule(
self.ctx, mock.MagicMock())
def test__check_for_duplicate_rules_in_db_does_not_drop_protocol(self):
with mock.patch.object(self.mixin, 'get_security_group_rules',
return_value=[mock.Mock()]):
context = mock.Mock()
rule_dict = {
'security_group_rule': {'protocol': None,
'tenant_id': 'fake',
'security_group_id': 'fake',
'direction': 'fake'}
}
self.mixin._check_for_duplicate_rules_in_db(context, rule_dict)
self.assertIn('protocol', rule_dict['security_group_rule'])
def test_delete_security_group_rule_in_use(self):
with mock.patch.object(registry, "notify") as mock_notify:
mock_notify.side_effect = exceptions.CallbackFailure(Exception())
@ -235,3 +248,15 @@ class SecurityGroupDbMixinTestCase(testlib_api.SqlTestCase):
mock_notify.assert_has_calls([mock.call('security_group_rule',
'precommit_delete', mock.ANY, context=mock.ANY,
security_group_rule_id=mock.ANY)])
def test_get_ip_proto_name_and_num(self):
protocols = [constants.PROTO_NAME_UDP, str(constants.PROTO_NUM_TCP),
'blah', '111']
protocol_names_nums = (
[[constants.PROTO_NAME_UDP, str(constants.PROTO_NUM_UDP)],
[constants.PROTO_NAME_TCP, str(constants.PROTO_NUM_TCP)],
['blah', 'blah'], ['111', '111']])
for i, protocol in enumerate(protocols):
self.assertEqual(protocol_names_nums[i],
self.mixin._get_ip_proto_name_and_num(protocol))

View File

@ -978,6 +978,40 @@ class TestSecurityGroups(SecurityGroupDBTestCase):
self.assertIn(sgr['security_group_rule']['id'],
res.json['NeutronError']['message'])
def test_create_security_group_rule_duplicate_rules_proto_name_num(self):
name = 'webservers'
description = 'my webservers'
with self.security_group(name, description) as sg:
security_group_id = sg['security_group']['id']
with self.security_group_rule(security_group_id):
rule = self._build_security_group_rule(
sg['security_group']['id'], 'ingress',
const.PROTO_NAME_TCP, '22', '22')
self._create_security_group_rule(self.fmt, rule)
rule = self._build_security_group_rule(
sg['security_group']['id'], 'ingress',
const.PROTO_NUM_TCP, '22', '22')
res = self._create_security_group_rule(self.fmt, rule)
self.deserialize(self.fmt, res)
self.assertEqual(webob.exc.HTTPConflict.code, res.status_int)
def test_create_security_group_rule_duplicate_rules_proto_num_name(self):
name = 'webservers'
description = 'my webservers'
with self.security_group(name, description) as sg:
security_group_id = sg['security_group']['id']
with self.security_group_rule(security_group_id):
rule = self._build_security_group_rule(
sg['security_group']['id'], 'ingress',
const.PROTO_NUM_UDP, '50', '100')
self._create_security_group_rule(self.fmt, rule)
rule = self._build_security_group_rule(
sg['security_group']['id'], 'ingress',
const.PROTO_NAME_UDP, '50', '100')
res = self._create_security_group_rule(self.fmt, rule)
self.deserialize(self.fmt, res)
self.assertEqual(webob.exc.HTTPConflict.code, res.status_int)
def test_create_security_group_rule_min_port_greater_max(self):
name = 'webservers'
description = 'my webservers'