SG protocol validation to allow numbers or names
SG rule protocol provided is validated against the DB rules' protocols for both number and name. The filter provided to DB is modified so that it is queried for records with both the protocol name and number, instead of exactly the type provided with the input. The returned DB rule record's protocol field is validated against the supplied SG protocol field for both name or number. This way, user is still allowed to enter protocol name or number to create a rule, and API compatibility is maintained. Closes-Bug: #1215181 (cherry picked from commit913a64cc11
) Also squashed the following regression fix: === Don't drop 'protocol' from client supplied security_group_rule dict If protocol was present in the dict, but was None, then it was never re-instantiated after being popped out of the dict. This later resulted in KeyError when trying to access the key on the dict. Change-Id: I4985e7b54117bee3241d7365cb438197a09b9b86 Closes-Bug: #1566327 (cherry picked from commit5a41caa47a
) === Change-Id: If4ad684e961433b8d9d3ec8fe2810585d3f6a093
This commit is contained in:
parent
b435ec56af
commit
93d719a554
|
@ -128,6 +128,8 @@ IP_PROTOCOL_MAP = {PROTO_NAME_AH: PROTO_NUM_AH,
|
|||
|
||||
IP_PROTOCOL_NAME_ALIASES = {PROTO_NAME_IPV6_ICMP_LEGACY: PROTO_NAME_IPV6_ICMP}
|
||||
|
||||
IP_PROTOCOL_NUM_TO_NAME_MAP = {str(v): k for k, v in IP_PROTOCOL_MAP.items()}
|
||||
|
||||
# List of ICMPv6 types that should be allowed by default:
|
||||
# Multicast Listener Query (130),
|
||||
# Multicast Listener Report (131),
|
||||
|
|
|
@ -424,6 +424,17 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase):
|
|||
protocol = constants.IP_PROTOCOL_NAME_ALIASES[protocol]
|
||||
return int(constants.IP_PROTOCOL_MAP.get(protocol, protocol))
|
||||
|
||||
def _get_ip_proto_name_and_num(self, protocol):
|
||||
if protocol is None:
|
||||
return
|
||||
protocol = str(protocol)
|
||||
if protocol in constants.IP_PROTOCOL_MAP:
|
||||
return [protocol, str(constants.IP_PROTOCOL_MAP.get(protocol))]
|
||||
elif protocol in constants.IP_PROTOCOL_NUM_TO_NAME_MAP:
|
||||
return [constants.IP_PROTOCOL_NUM_TO_NAME_MAP.get(protocol),
|
||||
protocol]
|
||||
return [protocol, protocol]
|
||||
|
||||
def _validate_port_range(self, rule):
|
||||
"""Check that port_range is valid."""
|
||||
if (rule['port_range_min'] is None and
|
||||
|
@ -539,6 +550,10 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase):
|
|||
value = sgr.get(key)
|
||||
if value:
|
||||
res[key] = [value]
|
||||
# protocol field will get corresponding name and number
|
||||
value = sgr.get('protocol')
|
||||
if value:
|
||||
res['protocol'] = self._get_ip_proto_name_and_num(value)
|
||||
return res
|
||||
|
||||
def _check_for_duplicate_rules(self, context, security_group_rules):
|
||||
|
@ -569,9 +584,16 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase):
|
|||
# is changed which cannot be because other methods are already
|
||||
# relying on this behavior. Therefore, we do the filtering
|
||||
# below to check for these corner cases.
|
||||
rule_dict = security_group_rule['security_group_rule'].copy()
|
||||
sg_protocol = rule_dict.pop('protocol', None)
|
||||
for db_rule in db_rules:
|
||||
rule_id = db_rule.pop('id', None)
|
||||
if (security_group_rule['security_group_rule'] == db_rule):
|
||||
# remove protocol and match separately for number and type
|
||||
db_protocol = db_rule.pop('protocol', None)
|
||||
is_protocol_matching = (
|
||||
self._get_ip_proto_name_and_num(db_protocol) ==
|
||||
self._get_ip_proto_name_and_num(sg_protocol))
|
||||
if (is_protocol_matching and rule_dict == db_rule):
|
||||
raise ext_sg.SecurityGroupRuleExists(rule_id=rule_id)
|
||||
|
||||
def _validate_ip_prefix(self, rule):
|
||||
|
|
|
@ -91,6 +91,19 @@ class SecurityGroupDbMixinTestCase(testlib_api.SqlTestCase):
|
|||
self.mixin.create_security_group_rule(
|
||||
self.ctx, mock.MagicMock())
|
||||
|
||||
def test__check_for_duplicate_rules_in_db_does_not_drop_protocol(self):
|
||||
with mock.patch.object(self.mixin, 'get_security_group_rules',
|
||||
return_value=[mock.Mock()]):
|
||||
context = mock.Mock()
|
||||
rule_dict = {
|
||||
'security_group_rule': {'protocol': None,
|
||||
'tenant_id': 'fake',
|
||||
'security_group_id': 'fake',
|
||||
'direction': 'fake'}
|
||||
}
|
||||
self.mixin._check_for_duplicate_rules_in_db(context, rule_dict)
|
||||
self.assertIn('protocol', rule_dict['security_group_rule'])
|
||||
|
||||
def test_delete_security_group_rule_in_use(self):
|
||||
with mock.patch.object(registry, "notify") as mock_notify:
|
||||
mock_notify.side_effect = exceptions.CallbackFailure(Exception())
|
||||
|
@ -235,3 +248,15 @@ class SecurityGroupDbMixinTestCase(testlib_api.SqlTestCase):
|
|||
mock_notify.assert_has_calls([mock.call('security_group_rule',
|
||||
'precommit_delete', mock.ANY, context=mock.ANY,
|
||||
security_group_rule_id=mock.ANY)])
|
||||
|
||||
def test_get_ip_proto_name_and_num(self):
|
||||
protocols = [constants.PROTO_NAME_UDP, str(constants.PROTO_NUM_TCP),
|
||||
'blah', '111']
|
||||
protocol_names_nums = (
|
||||
[[constants.PROTO_NAME_UDP, str(constants.PROTO_NUM_UDP)],
|
||||
[constants.PROTO_NAME_TCP, str(constants.PROTO_NUM_TCP)],
|
||||
['blah', 'blah'], ['111', '111']])
|
||||
|
||||
for i, protocol in enumerate(protocols):
|
||||
self.assertEqual(protocol_names_nums[i],
|
||||
self.mixin._get_ip_proto_name_and_num(protocol))
|
||||
|
|
|
@ -978,6 +978,40 @@ class TestSecurityGroups(SecurityGroupDBTestCase):
|
|||
self.assertIn(sgr['security_group_rule']['id'],
|
||||
res.json['NeutronError']['message'])
|
||||
|
||||
def test_create_security_group_rule_duplicate_rules_proto_name_num(self):
|
||||
name = 'webservers'
|
||||
description = 'my webservers'
|
||||
with self.security_group(name, description) as sg:
|
||||
security_group_id = sg['security_group']['id']
|
||||
with self.security_group_rule(security_group_id):
|
||||
rule = self._build_security_group_rule(
|
||||
sg['security_group']['id'], 'ingress',
|
||||
const.PROTO_NAME_TCP, '22', '22')
|
||||
self._create_security_group_rule(self.fmt, rule)
|
||||
rule = self._build_security_group_rule(
|
||||
sg['security_group']['id'], 'ingress',
|
||||
const.PROTO_NUM_TCP, '22', '22')
|
||||
res = self._create_security_group_rule(self.fmt, rule)
|
||||
self.deserialize(self.fmt, res)
|
||||
self.assertEqual(webob.exc.HTTPConflict.code, res.status_int)
|
||||
|
||||
def test_create_security_group_rule_duplicate_rules_proto_num_name(self):
|
||||
name = 'webservers'
|
||||
description = 'my webservers'
|
||||
with self.security_group(name, description) as sg:
|
||||
security_group_id = sg['security_group']['id']
|
||||
with self.security_group_rule(security_group_id):
|
||||
rule = self._build_security_group_rule(
|
||||
sg['security_group']['id'], 'ingress',
|
||||
const.PROTO_NUM_UDP, '50', '100')
|
||||
self._create_security_group_rule(self.fmt, rule)
|
||||
rule = self._build_security_group_rule(
|
||||
sg['security_group']['id'], 'ingress',
|
||||
const.PROTO_NAME_UDP, '50', '100')
|
||||
res = self._create_security_group_rule(self.fmt, rule)
|
||||
self.deserialize(self.fmt, res)
|
||||
self.assertEqual(webob.exc.HTTPConflict.code, res.status_int)
|
||||
|
||||
def test_create_security_group_rule_min_port_greater_max(self):
|
||||
name = 'webservers'
|
||||
description = 'my webservers'
|
||||
|
|
Loading…
Reference in New Issue