SG protocol validation to allow numbers or names
SG rule protocol provided is validated against the DB rules' protocols for both number and name. The filter provided to DB is modified so that it is queried for records with both the protocol name and number, instead of exactly the type provided with the input. The returned DB rule record's protocol field is validated against the supplied SG protocol field for both name or number. This way, user is still allowed to enter protocol name or number to create a rule, and API compatibility is maintained. Change-Id: If4ad684e961433b8d9d3ec8fe2810585d3f6a093 Closes-Bug: #1215181
This commit is contained in:
parent
deb3a61a62
commit
913a64cc11
|
@ -131,6 +131,8 @@ IP_PROTOCOL_NAME_ALIASES = {PROTO_NAME_IPV6_ICMP_LEGACY: PROTO_NAME_IPV6_ICMP}
|
|||
VALID_DSCP_MARKS = [0, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34,
|
||||
36, 38, 40, 46, 48, 56]
|
||||
|
||||
IP_PROTOCOL_NUM_TO_NAME_MAP = {str(v): k for k, v in IP_PROTOCOL_MAP.items()}
|
||||
|
||||
# List of ICMPv6 types that should be allowed by default:
|
||||
# Multicast Listener Query (130),
|
||||
# Multicast Listener Report (131),
|
||||
|
|
|
@ -424,6 +424,17 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase):
|
|||
protocol = constants.IP_PROTOCOL_NAME_ALIASES[protocol]
|
||||
return int(constants.IP_PROTOCOL_MAP.get(protocol, protocol))
|
||||
|
||||
def _get_ip_proto_name_and_num(self, protocol):
|
||||
if protocol is None:
|
||||
return
|
||||
protocol = str(protocol)
|
||||
if protocol in constants.IP_PROTOCOL_MAP:
|
||||
return [protocol, str(constants.IP_PROTOCOL_MAP.get(protocol))]
|
||||
elif protocol in constants.IP_PROTOCOL_NUM_TO_NAME_MAP:
|
||||
return [constants.IP_PROTOCOL_NUM_TO_NAME_MAP.get(protocol),
|
||||
protocol]
|
||||
return [protocol, protocol]
|
||||
|
||||
def _validate_port_range(self, rule):
|
||||
"""Check that port_range is valid."""
|
||||
if (rule['port_range_min'] is None and
|
||||
|
@ -539,6 +550,10 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase):
|
|||
value = sgr.get(key)
|
||||
if value:
|
||||
res[key] = [value]
|
||||
# protocol field will get corresponding name and number
|
||||
value = sgr.get('protocol')
|
||||
if value:
|
||||
res['protocol'] = self._get_ip_proto_name_and_num(value)
|
||||
return res
|
||||
|
||||
def _check_for_duplicate_rules(self, context, security_group_rules):
|
||||
|
@ -571,7 +586,21 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase):
|
|||
# below to check for these corner cases.
|
||||
for db_rule in db_rules:
|
||||
rule_id = db_rule.pop('id', None)
|
||||
if (security_group_rule['security_group_rule'] == db_rule):
|
||||
# remove protocol and match separately for number and type
|
||||
db_protocol = db_rule.pop('protocol', None)
|
||||
sg_protocol = (
|
||||
security_group_rule['security_group_rule'].pop('protocol',
|
||||
None))
|
||||
is_protocol_matching = (
|
||||
self._get_ip_proto_name_and_num(db_protocol) ==
|
||||
self._get_ip_proto_name_and_num(sg_protocol))
|
||||
are_rules_matching = (
|
||||
security_group_rule['security_group_rule'] == db_rule)
|
||||
# reinstate protocol field for further processing
|
||||
if sg_protocol:
|
||||
security_group_rule['security_group_rule']['protocol'] = (
|
||||
sg_protocol)
|
||||
if (is_protocol_matching and are_rules_matching):
|
||||
raise ext_sg.SecurityGroupRuleExists(rule_id=rule_id)
|
||||
|
||||
def _validate_ip_prefix(self, rule):
|
||||
|
|
|
@ -235,3 +235,15 @@ class SecurityGroupDbMixinTestCase(testlib_api.SqlTestCase):
|
|||
mock_notify.assert_has_calls([mock.call('security_group_rule',
|
||||
'precommit_delete', mock.ANY, context=mock.ANY,
|
||||
security_group_rule_id=mock.ANY)])
|
||||
|
||||
def test_get_ip_proto_name_and_num(self):
|
||||
protocols = [constants.PROTO_NAME_UDP, str(constants.PROTO_NUM_TCP),
|
||||
'blah', '111']
|
||||
protocol_names_nums = (
|
||||
[[constants.PROTO_NAME_UDP, str(constants.PROTO_NUM_UDP)],
|
||||
[constants.PROTO_NAME_TCP, str(constants.PROTO_NUM_TCP)],
|
||||
['blah', 'blah'], ['111', '111']])
|
||||
|
||||
for i, protocol in enumerate(protocols):
|
||||
self.assertEqual(protocol_names_nums[i],
|
||||
self.mixin._get_ip_proto_name_and_num(protocol))
|
||||
|
|
|
@ -978,6 +978,40 @@ class TestSecurityGroups(SecurityGroupDBTestCase):
|
|||
self.assertIn(sgr['security_group_rule']['id'],
|
||||
res.json['NeutronError']['message'])
|
||||
|
||||
def test_create_security_group_rule_duplicate_rules_proto_name_num(self):
|
||||
name = 'webservers'
|
||||
description = 'my webservers'
|
||||
with self.security_group(name, description) as sg:
|
||||
security_group_id = sg['security_group']['id']
|
||||
with self.security_group_rule(security_group_id):
|
||||
rule = self._build_security_group_rule(
|
||||
sg['security_group']['id'], 'ingress',
|
||||
const.PROTO_NAME_TCP, '22', '22')
|
||||
self._create_security_group_rule(self.fmt, rule)
|
||||
rule = self._build_security_group_rule(
|
||||
sg['security_group']['id'], 'ingress',
|
||||
const.PROTO_NUM_TCP, '22', '22')
|
||||
res = self._create_security_group_rule(self.fmt, rule)
|
||||
self.deserialize(self.fmt, res)
|
||||
self.assertEqual(webob.exc.HTTPConflict.code, res.status_int)
|
||||
|
||||
def test_create_security_group_rule_duplicate_rules_proto_num_name(self):
|
||||
name = 'webservers'
|
||||
description = 'my webservers'
|
||||
with self.security_group(name, description) as sg:
|
||||
security_group_id = sg['security_group']['id']
|
||||
with self.security_group_rule(security_group_id):
|
||||
rule = self._build_security_group_rule(
|
||||
sg['security_group']['id'], 'ingress',
|
||||
const.PROTO_NUM_UDP, '50', '100')
|
||||
self._create_security_group_rule(self.fmt, rule)
|
||||
rule = self._build_security_group_rule(
|
||||
sg['security_group']['id'], 'ingress',
|
||||
const.PROTO_NAME_UDP, '50', '100')
|
||||
res = self._create_security_group_rule(self.fmt, rule)
|
||||
self.deserialize(self.fmt, res)
|
||||
self.assertEqual(webob.exc.HTTPConflict.code, res.status_int)
|
||||
|
||||
def test_create_security_group_rule_min_port_greater_max(self):
|
||||
name = 'webservers'
|
||||
description = 'my webservers'
|
||||
|
|
Loading…
Reference in New Issue