SG protocol validation to allow numbers or names

SG rule protocol provided is validated against the DB rules'
protocols for both number and name. The filter provided to DB
is modified so that it is queried for records with both the
protocol name and number, instead of exactly the type provided
with the input. The returned DB rule record's protocol field is
validated against the supplied SG protocol field for both name
or number.
This way, user is still allowed to enter protocol name or number
to create a rule, and API compatibility is maintained.

Change-Id: If4ad684e961433b8d9d3ec8fe2810585d3f6a093
Closes-Bug: #1215181
This commit is contained in:
Sreekumar S 2016-01-22 19:09:49 +05:30
parent deb3a61a62
commit 913a64cc11
4 changed files with 78 additions and 1 deletions

View File

@ -131,6 +131,8 @@ IP_PROTOCOL_NAME_ALIASES = {PROTO_NAME_IPV6_ICMP_LEGACY: PROTO_NAME_IPV6_ICMP}
VALID_DSCP_MARKS = [0, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34,
36, 38, 40, 46, 48, 56]
IP_PROTOCOL_NUM_TO_NAME_MAP = {str(v): k for k, v in IP_PROTOCOL_MAP.items()}
# List of ICMPv6 types that should be allowed by default:
# Multicast Listener Query (130),
# Multicast Listener Report (131),

View File

@ -424,6 +424,17 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase):
protocol = constants.IP_PROTOCOL_NAME_ALIASES[protocol]
return int(constants.IP_PROTOCOL_MAP.get(protocol, protocol))
def _get_ip_proto_name_and_num(self, protocol):
if protocol is None:
return
protocol = str(protocol)
if protocol in constants.IP_PROTOCOL_MAP:
return [protocol, str(constants.IP_PROTOCOL_MAP.get(protocol))]
elif protocol in constants.IP_PROTOCOL_NUM_TO_NAME_MAP:
return [constants.IP_PROTOCOL_NUM_TO_NAME_MAP.get(protocol),
protocol]
return [protocol, protocol]
def _validate_port_range(self, rule):
"""Check that port_range is valid."""
if (rule['port_range_min'] is None and
@ -539,6 +550,10 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase):
value = sgr.get(key)
if value:
res[key] = [value]
# protocol field will get corresponding name and number
value = sgr.get('protocol')
if value:
res['protocol'] = self._get_ip_proto_name_and_num(value)
return res
def _check_for_duplicate_rules(self, context, security_group_rules):
@ -571,7 +586,21 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase):
# below to check for these corner cases.
for db_rule in db_rules:
rule_id = db_rule.pop('id', None)
if (security_group_rule['security_group_rule'] == db_rule):
# remove protocol and match separately for number and type
db_protocol = db_rule.pop('protocol', None)
sg_protocol = (
security_group_rule['security_group_rule'].pop('protocol',
None))
is_protocol_matching = (
self._get_ip_proto_name_and_num(db_protocol) ==
self._get_ip_proto_name_and_num(sg_protocol))
are_rules_matching = (
security_group_rule['security_group_rule'] == db_rule)
# reinstate protocol field for further processing
if sg_protocol:
security_group_rule['security_group_rule']['protocol'] = (
sg_protocol)
if (is_protocol_matching and are_rules_matching):
raise ext_sg.SecurityGroupRuleExists(rule_id=rule_id)
def _validate_ip_prefix(self, rule):

View File

@ -235,3 +235,15 @@ class SecurityGroupDbMixinTestCase(testlib_api.SqlTestCase):
mock_notify.assert_has_calls([mock.call('security_group_rule',
'precommit_delete', mock.ANY, context=mock.ANY,
security_group_rule_id=mock.ANY)])
def test_get_ip_proto_name_and_num(self):
protocols = [constants.PROTO_NAME_UDP, str(constants.PROTO_NUM_TCP),
'blah', '111']
protocol_names_nums = (
[[constants.PROTO_NAME_UDP, str(constants.PROTO_NUM_UDP)],
[constants.PROTO_NAME_TCP, str(constants.PROTO_NUM_TCP)],
['blah', 'blah'], ['111', '111']])
for i, protocol in enumerate(protocols):
self.assertEqual(protocol_names_nums[i],
self.mixin._get_ip_proto_name_and_num(protocol))

View File

@ -978,6 +978,40 @@ class TestSecurityGroups(SecurityGroupDBTestCase):
self.assertIn(sgr['security_group_rule']['id'],
res.json['NeutronError']['message'])
def test_create_security_group_rule_duplicate_rules_proto_name_num(self):
name = 'webservers'
description = 'my webservers'
with self.security_group(name, description) as sg:
security_group_id = sg['security_group']['id']
with self.security_group_rule(security_group_id):
rule = self._build_security_group_rule(
sg['security_group']['id'], 'ingress',
const.PROTO_NAME_TCP, '22', '22')
self._create_security_group_rule(self.fmt, rule)
rule = self._build_security_group_rule(
sg['security_group']['id'], 'ingress',
const.PROTO_NUM_TCP, '22', '22')
res = self._create_security_group_rule(self.fmt, rule)
self.deserialize(self.fmt, res)
self.assertEqual(webob.exc.HTTPConflict.code, res.status_int)
def test_create_security_group_rule_duplicate_rules_proto_num_name(self):
name = 'webservers'
description = 'my webservers'
with self.security_group(name, description) as sg:
security_group_id = sg['security_group']['id']
with self.security_group_rule(security_group_id):
rule = self._build_security_group_rule(
sg['security_group']['id'], 'ingress',
const.PROTO_NUM_UDP, '50', '100')
self._create_security_group_rule(self.fmt, rule)
rule = self._build_security_group_rule(
sg['security_group']['id'], 'ingress',
const.PROTO_NAME_UDP, '50', '100')
res = self._create_security_group_rule(self.fmt, rule)
self.deserialize(self.fmt, res)
self.assertEqual(webob.exc.HTTPConflict.code, res.status_int)
def test_create_security_group_rule_min_port_greater_max(self):
name = 'webservers'
description = 'my webservers'