SG protocol validation to allow numbers or names

SG rule protocol provided is validated against the DB rules'
protocols for both number and name. The filter provided to DB
is modified so that it is queried for records with both the
protocol name and number, instead of exactly the type provided
with the input. The returned DB rule record's protocol field is
validated against the supplied SG protocol field for both name
or number.
This way, user is still allowed to enter protocol name or number
to create a rule, and API compatibility is maintained.

Closes-Bug: #1215181
(cherry picked from commit 913a64cc11)

Also squashed the following regression fix:

===

Don't drop 'protocol' from client supplied security_group_rule dict

If protocol was present in the dict, but was None, then it was never
re-instantiated after being popped out of the dict. This later resulted
in KeyError when trying to access the key on the dict.

Change-Id: I4985e7b54117bee3241d7365cb438197a09b9b86
Closes-Bug: #1566327
(cherry picked from commit 5a41caa47a)

===

Change-Id: If4ad684e961433b8d9d3ec8fe2810585d3f6a093
This commit is contained in:
Sreekumar S 2016-01-22 19:09:49 +05:30 committed by Ihar Hrachyshka
parent b435ec56af
commit 93d719a554
4 changed files with 84 additions and 1 deletions

View File

@ -128,6 +128,8 @@ IP_PROTOCOL_MAP = {PROTO_NAME_AH: PROTO_NUM_AH,
IP_PROTOCOL_NAME_ALIASES = {PROTO_NAME_IPV6_ICMP_LEGACY: PROTO_NAME_IPV6_ICMP}
IP_PROTOCOL_NUM_TO_NAME_MAP = {str(v): k for k, v in IP_PROTOCOL_MAP.items()}
# List of ICMPv6 types that should be allowed by default:
# Multicast Listener Query (130),
# Multicast Listener Report (131),

View File

@ -424,6 +424,17 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase):
protocol = constants.IP_PROTOCOL_NAME_ALIASES[protocol]
return int(constants.IP_PROTOCOL_MAP.get(protocol, protocol))
def _get_ip_proto_name_and_num(self, protocol):
if protocol is None:
return
protocol = str(protocol)
if protocol in constants.IP_PROTOCOL_MAP:
return [protocol, str(constants.IP_PROTOCOL_MAP.get(protocol))]
elif protocol in constants.IP_PROTOCOL_NUM_TO_NAME_MAP:
return [constants.IP_PROTOCOL_NUM_TO_NAME_MAP.get(protocol),
protocol]
return [protocol, protocol]
def _validate_port_range(self, rule):
"""Check that port_range is valid."""
if (rule['port_range_min'] is None and
@ -539,6 +550,10 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase):
value = sgr.get(key)
if value:
res[key] = [value]
# protocol field will get corresponding name and number
value = sgr.get('protocol')
if value:
res['protocol'] = self._get_ip_proto_name_and_num(value)
return res
def _check_for_duplicate_rules(self, context, security_group_rules):
@ -569,9 +584,16 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase):
# is changed which cannot be because other methods are already
# relying on this behavior. Therefore, we do the filtering
# below to check for these corner cases.
rule_dict = security_group_rule['security_group_rule'].copy()
sg_protocol = rule_dict.pop('protocol', None)
for db_rule in db_rules:
rule_id = db_rule.pop('id', None)
if (security_group_rule['security_group_rule'] == db_rule):
# remove protocol and match separately for number and type
db_protocol = db_rule.pop('protocol', None)
is_protocol_matching = (
self._get_ip_proto_name_and_num(db_protocol) ==
self._get_ip_proto_name_and_num(sg_protocol))
if (is_protocol_matching and rule_dict == db_rule):
raise ext_sg.SecurityGroupRuleExists(rule_id=rule_id)
def _validate_ip_prefix(self, rule):

View File

@ -91,6 +91,19 @@ class SecurityGroupDbMixinTestCase(testlib_api.SqlTestCase):
self.mixin.create_security_group_rule(
self.ctx, mock.MagicMock())
def test__check_for_duplicate_rules_in_db_does_not_drop_protocol(self):
with mock.patch.object(self.mixin, 'get_security_group_rules',
return_value=[mock.Mock()]):
context = mock.Mock()
rule_dict = {
'security_group_rule': {'protocol': None,
'tenant_id': 'fake',
'security_group_id': 'fake',
'direction': 'fake'}
}
self.mixin._check_for_duplicate_rules_in_db(context, rule_dict)
self.assertIn('protocol', rule_dict['security_group_rule'])
def test_delete_security_group_rule_in_use(self):
with mock.patch.object(registry, "notify") as mock_notify:
mock_notify.side_effect = exceptions.CallbackFailure(Exception())
@ -235,3 +248,15 @@ class SecurityGroupDbMixinTestCase(testlib_api.SqlTestCase):
mock_notify.assert_has_calls([mock.call('security_group_rule',
'precommit_delete', mock.ANY, context=mock.ANY,
security_group_rule_id=mock.ANY)])
def test_get_ip_proto_name_and_num(self):
protocols = [constants.PROTO_NAME_UDP, str(constants.PROTO_NUM_TCP),
'blah', '111']
protocol_names_nums = (
[[constants.PROTO_NAME_UDP, str(constants.PROTO_NUM_UDP)],
[constants.PROTO_NAME_TCP, str(constants.PROTO_NUM_TCP)],
['blah', 'blah'], ['111', '111']])
for i, protocol in enumerate(protocols):
self.assertEqual(protocol_names_nums[i],
self.mixin._get_ip_proto_name_and_num(protocol))

View File

@ -978,6 +978,40 @@ class TestSecurityGroups(SecurityGroupDBTestCase):
self.assertIn(sgr['security_group_rule']['id'],
res.json['NeutronError']['message'])
def test_create_security_group_rule_duplicate_rules_proto_name_num(self):
name = 'webservers'
description = 'my webservers'
with self.security_group(name, description) as sg:
security_group_id = sg['security_group']['id']
with self.security_group_rule(security_group_id):
rule = self._build_security_group_rule(
sg['security_group']['id'], 'ingress',
const.PROTO_NAME_TCP, '22', '22')
self._create_security_group_rule(self.fmt, rule)
rule = self._build_security_group_rule(
sg['security_group']['id'], 'ingress',
const.PROTO_NUM_TCP, '22', '22')
res = self._create_security_group_rule(self.fmt, rule)
self.deserialize(self.fmt, res)
self.assertEqual(webob.exc.HTTPConflict.code, res.status_int)
def test_create_security_group_rule_duplicate_rules_proto_num_name(self):
name = 'webservers'
description = 'my webservers'
with self.security_group(name, description) as sg:
security_group_id = sg['security_group']['id']
with self.security_group_rule(security_group_id):
rule = self._build_security_group_rule(
sg['security_group']['id'], 'ingress',
const.PROTO_NUM_UDP, '50', '100')
self._create_security_group_rule(self.fmt, rule)
rule = self._build_security_group_rule(
sg['security_group']['id'], 'ingress',
const.PROTO_NAME_UDP, '50', '100')
res = self._create_security_group_rule(self.fmt, rule)
self.deserialize(self.fmt, res)
self.assertEqual(webob.exc.HTTPConflict.code, res.status_int)
def test_create_security_group_rule_min_port_greater_max(self):
name = 'webservers'
description = 'my webservers'