Merge "SG protocol validation to allow numbers or names" into stable/mitaka
This commit is contained in:
commit
d791170f55
|
@ -128,6 +128,8 @@ IP_PROTOCOL_MAP = {PROTO_NAME_AH: PROTO_NUM_AH,
|
|||
|
||||
IP_PROTOCOL_NAME_ALIASES = {PROTO_NAME_IPV6_ICMP_LEGACY: PROTO_NAME_IPV6_ICMP}
|
||||
|
||||
IP_PROTOCOL_NUM_TO_NAME_MAP = {str(v): k for k, v in IP_PROTOCOL_MAP.items()}
|
||||
|
||||
# List of ICMPv6 types that should be allowed by default:
|
||||
# Multicast Listener Query (130),
|
||||
# Multicast Listener Report (131),
|
||||
|
|
|
@ -424,6 +424,17 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase):
|
|||
protocol = constants.IP_PROTOCOL_NAME_ALIASES[protocol]
|
||||
return int(constants.IP_PROTOCOL_MAP.get(protocol, protocol))
|
||||
|
||||
def _get_ip_proto_name_and_num(self, protocol):
|
||||
if protocol is None:
|
||||
return
|
||||
protocol = str(protocol)
|
||||
if protocol in constants.IP_PROTOCOL_MAP:
|
||||
return [protocol, str(constants.IP_PROTOCOL_MAP.get(protocol))]
|
||||
elif protocol in constants.IP_PROTOCOL_NUM_TO_NAME_MAP:
|
||||
return [constants.IP_PROTOCOL_NUM_TO_NAME_MAP.get(protocol),
|
||||
protocol]
|
||||
return [protocol, protocol]
|
||||
|
||||
def _validate_port_range(self, rule):
|
||||
"""Check that port_range is valid."""
|
||||
if (rule['port_range_min'] is None and
|
||||
|
@ -539,6 +550,10 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase):
|
|||
value = sgr.get(key)
|
||||
if value:
|
||||
res[key] = [value]
|
||||
# protocol field will get corresponding name and number
|
||||
value = sgr.get('protocol')
|
||||
if value:
|
||||
res['protocol'] = self._get_ip_proto_name_and_num(value)
|
||||
return res
|
||||
|
||||
def _check_for_duplicate_rules(self, context, security_group_rules):
|
||||
|
@ -569,9 +584,16 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase):
|
|||
# is changed which cannot be because other methods are already
|
||||
# relying on this behavior. Therefore, we do the filtering
|
||||
# below to check for these corner cases.
|
||||
rule_dict = security_group_rule['security_group_rule'].copy()
|
||||
sg_protocol = rule_dict.pop('protocol', None)
|
||||
for db_rule in db_rules:
|
||||
rule_id = db_rule.pop('id', None)
|
||||
if (security_group_rule['security_group_rule'] == db_rule):
|
||||
# remove protocol and match separately for number and type
|
||||
db_protocol = db_rule.pop('protocol', None)
|
||||
is_protocol_matching = (
|
||||
self._get_ip_proto_name_and_num(db_protocol) ==
|
||||
self._get_ip_proto_name_and_num(sg_protocol))
|
||||
if (is_protocol_matching and rule_dict == db_rule):
|
||||
raise ext_sg.SecurityGroupRuleExists(rule_id=rule_id)
|
||||
|
||||
def _validate_ip_prefix(self, rule):
|
||||
|
|
|
@ -91,6 +91,19 @@ class SecurityGroupDbMixinTestCase(testlib_api.SqlTestCase):
|
|||
self.mixin.create_security_group_rule(
|
||||
self.ctx, mock.MagicMock())
|
||||
|
||||
def test__check_for_duplicate_rules_in_db_does_not_drop_protocol(self):
|
||||
with mock.patch.object(self.mixin, 'get_security_group_rules',
|
||||
return_value=[mock.Mock()]):
|
||||
context = mock.Mock()
|
||||
rule_dict = {
|
||||
'security_group_rule': {'protocol': None,
|
||||
'tenant_id': 'fake',
|
||||
'security_group_id': 'fake',
|
||||
'direction': 'fake'}
|
||||
}
|
||||
self.mixin._check_for_duplicate_rules_in_db(context, rule_dict)
|
||||
self.assertIn('protocol', rule_dict['security_group_rule'])
|
||||
|
||||
def test_delete_security_group_rule_in_use(self):
|
||||
with mock.patch.object(registry, "notify") as mock_notify:
|
||||
mock_notify.side_effect = exceptions.CallbackFailure(Exception())
|
||||
|
@ -235,3 +248,15 @@ class SecurityGroupDbMixinTestCase(testlib_api.SqlTestCase):
|
|||
mock_notify.assert_has_calls([mock.call('security_group_rule',
|
||||
'precommit_delete', mock.ANY, context=mock.ANY,
|
||||
security_group_rule_id=mock.ANY)])
|
||||
|
||||
def test_get_ip_proto_name_and_num(self):
|
||||
protocols = [constants.PROTO_NAME_UDP, str(constants.PROTO_NUM_TCP),
|
||||
'blah', '111']
|
||||
protocol_names_nums = (
|
||||
[[constants.PROTO_NAME_UDP, str(constants.PROTO_NUM_UDP)],
|
||||
[constants.PROTO_NAME_TCP, str(constants.PROTO_NUM_TCP)],
|
||||
['blah', 'blah'], ['111', '111']])
|
||||
|
||||
for i, protocol in enumerate(protocols):
|
||||
self.assertEqual(protocol_names_nums[i],
|
||||
self.mixin._get_ip_proto_name_and_num(protocol))
|
||||
|
|
|
@ -978,6 +978,40 @@ class TestSecurityGroups(SecurityGroupDBTestCase):
|
|||
self.assertIn(sgr['security_group_rule']['id'],
|
||||
res.json['NeutronError']['message'])
|
||||
|
||||
def test_create_security_group_rule_duplicate_rules_proto_name_num(self):
|
||||
name = 'webservers'
|
||||
description = 'my webservers'
|
||||
with self.security_group(name, description) as sg:
|
||||
security_group_id = sg['security_group']['id']
|
||||
with self.security_group_rule(security_group_id):
|
||||
rule = self._build_security_group_rule(
|
||||
sg['security_group']['id'], 'ingress',
|
||||
const.PROTO_NAME_TCP, '22', '22')
|
||||
self._create_security_group_rule(self.fmt, rule)
|
||||
rule = self._build_security_group_rule(
|
||||
sg['security_group']['id'], 'ingress',
|
||||
const.PROTO_NUM_TCP, '22', '22')
|
||||
res = self._create_security_group_rule(self.fmt, rule)
|
||||
self.deserialize(self.fmt, res)
|
||||
self.assertEqual(webob.exc.HTTPConflict.code, res.status_int)
|
||||
|
||||
def test_create_security_group_rule_duplicate_rules_proto_num_name(self):
|
||||
name = 'webservers'
|
||||
description = 'my webservers'
|
||||
with self.security_group(name, description) as sg:
|
||||
security_group_id = sg['security_group']['id']
|
||||
with self.security_group_rule(security_group_id):
|
||||
rule = self._build_security_group_rule(
|
||||
sg['security_group']['id'], 'ingress',
|
||||
const.PROTO_NUM_UDP, '50', '100')
|
||||
self._create_security_group_rule(self.fmt, rule)
|
||||
rule = self._build_security_group_rule(
|
||||
sg['security_group']['id'], 'ingress',
|
||||
const.PROTO_NAME_UDP, '50', '100')
|
||||
res = self._create_security_group_rule(self.fmt, rule)
|
||||
self.deserialize(self.fmt, res)
|
||||
self.assertEqual(webob.exc.HTTPConflict.code, res.status_int)
|
||||
|
||||
def test_create_security_group_rule_min_port_greater_max(self):
|
||||
name = 'webservers'
|
||||
description = 'my webservers'
|
||||
|
|
Loading…
Reference in New Issue