Merge "SG protocol validation to allow numbers or names" into stable/liberty
This commit is contained in:
commit
7b088a2694
|
@ -130,6 +130,13 @@ PROTO_NUM_ICMP = 1
|
|||
PROTO_NUM_ICMP_V6 = 58
|
||||
PROTO_NUM_UDP = 17
|
||||
|
||||
IP_PROTOCOL_MAP = {PROTO_NAME_TCP: PROTO_NUM_TCP,
|
||||
PROTO_NAME_UDP: PROTO_NUM_UDP,
|
||||
PROTO_NAME_ICMP: PROTO_NUM_ICMP,
|
||||
PROTO_NAME_ICMP_V6: PROTO_NUM_ICMP_V6}
|
||||
|
||||
IP_PROTOCOL_NUM_TO_NAME_MAP = {str(v): k for k, v in IP_PROTOCOL_MAP.items()}
|
||||
|
||||
# List of ICMPv6 types that should be allowed by default:
|
||||
# Multicast Listener Query (130),
|
||||
# Multicast Listener Report (131),
|
||||
|
|
|
@ -37,11 +37,6 @@ from neutron.extensions import securitygroup as ext_sg
|
|||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
IP_PROTOCOL_MAP = {constants.PROTO_NAME_TCP: constants.PROTO_NUM_TCP,
|
||||
constants.PROTO_NAME_UDP: constants.PROTO_NUM_UDP,
|
||||
constants.PROTO_NAME_ICMP: constants.PROTO_NUM_ICMP,
|
||||
constants.PROTO_NAME_ICMP_V6: constants.PROTO_NUM_ICMP_V6}
|
||||
|
||||
|
||||
class SecurityGroup(model_base.BASEV2, models_v2.HasId, models_v2.HasTenant):
|
||||
"""Represents a v2 neutron security group."""
|
||||
|
@ -418,7 +413,18 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase):
|
|||
# problems with comparing int and string in PostgreSQL. Here this
|
||||
# string is converted to int to give an opportunity to use it as
|
||||
# before.
|
||||
return int(IP_PROTOCOL_MAP.get(protocol, protocol))
|
||||
return int(constants.IP_PROTOCOL_MAP.get(protocol, protocol))
|
||||
|
||||
def _get_ip_proto_name_and_num(self, protocol):
|
||||
if protocol is None:
|
||||
return
|
||||
protocol = str(protocol)
|
||||
if protocol in constants.IP_PROTOCOL_MAP:
|
||||
return [protocol, str(constants.IP_PROTOCOL_MAP.get(protocol))]
|
||||
elif protocol in constants.IP_PROTOCOL_NUM_TO_NAME_MAP:
|
||||
return [constants.IP_PROTOCOL_NUM_TO_NAME_MAP.get(protocol),
|
||||
protocol]
|
||||
return [protocol, protocol]
|
||||
|
||||
def _validate_port_range(self, rule):
|
||||
"""Check that port_range is valid."""
|
||||
|
@ -525,6 +531,10 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase):
|
|||
value = sgr.get(key)
|
||||
if value:
|
||||
res[key] = [value]
|
||||
# protocol field will get corresponding name and number
|
||||
value = sgr.get('protocol')
|
||||
if value:
|
||||
res['protocol'] = self._get_ip_proto_name_and_num(value)
|
||||
return res
|
||||
|
||||
def _check_for_duplicate_rules(self, context, security_group_rules):
|
||||
|
@ -553,10 +563,23 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase):
|
|||
# relying on this behavior. Therefore, we do the filtering
|
||||
# below to check for these corner cases.
|
||||
for db_rule in db_rules:
|
||||
# need to remove id from db_rule for matching
|
||||
id = db_rule.pop('id')
|
||||
if (security_group_rule['security_group_rule'] == db_rule):
|
||||
raise ext_sg.SecurityGroupRuleExists(id=id)
|
||||
rule_id = db_rule.pop('id', None)
|
||||
# remove protocol and match separately for number and type
|
||||
db_protocol = db_rule.pop('protocol', None)
|
||||
sg_protocol = (
|
||||
security_group_rule['security_group_rule'].pop('protocol',
|
||||
None))
|
||||
is_protocol_matching = (
|
||||
self._get_ip_proto_name_and_num(db_protocol) ==
|
||||
self._get_ip_proto_name_and_num(sg_protocol))
|
||||
are_rules_matching = (
|
||||
security_group_rule['security_group_rule'] == db_rule)
|
||||
# reinstate protocol field for further processing
|
||||
if sg_protocol:
|
||||
security_group_rule['security_group_rule']['protocol'] = (
|
||||
sg_protocol)
|
||||
if (is_protocol_matching and are_rules_matching):
|
||||
raise ext_sg.SecurityGroupRuleExists(id=rule_id)
|
||||
|
||||
def _validate_ip_prefix(self, rule):
|
||||
"""Check that a valid cidr was specified as remote_ip_prefix
|
||||
|
|
|
@ -16,6 +16,7 @@ import testtools
|
|||
|
||||
from neutron.callbacks import exceptions
|
||||
from neutron.callbacks import registry
|
||||
from neutron.common import constants
|
||||
from neutron import context
|
||||
from neutron.db import common_db_mixin
|
||||
from neutron.db import securitygroups_db
|
||||
|
@ -83,3 +84,15 @@ class SecurityGroupDbMixinTestCase(testlib_api.SqlTestCase):
|
|||
with testtools.ExpectedException(
|
||||
securitygroup.SecurityGroupRuleNotFound):
|
||||
self.mixin.delete_security_group_rule(self.ctx, 'foo_rule')
|
||||
|
||||
def test_get_ip_proto_name_and_num(self):
|
||||
protocols = [constants.PROTO_NAME_UDP, str(constants.PROTO_NUM_TCP),
|
||||
'blah', '111']
|
||||
protocol_names_nums = (
|
||||
[[constants.PROTO_NAME_UDP, str(constants.PROTO_NUM_UDP)],
|
||||
[constants.PROTO_NAME_TCP, str(constants.PROTO_NUM_TCP)],
|
||||
['blah', 'blah'], ['111', '111']])
|
||||
|
||||
for i, protocol in enumerate(protocols):
|
||||
self.assertEqual(protocol_names_nums[i],
|
||||
self.mixin._get_ip_proto_name_and_num(protocol))
|
||||
|
|
|
@ -954,6 +954,40 @@ class TestSecurityGroups(SecurityGroupDBTestCase):
|
|||
self.deserialize(self.fmt, res)
|
||||
self.assertEqual(res.status_int, webob.exc.HTTPConflict.code)
|
||||
|
||||
def test_create_security_group_rule_duplicate_rules_proto_name_num(self):
|
||||
name = 'webservers'
|
||||
description = 'my webservers'
|
||||
with self.security_group(name, description) as sg:
|
||||
security_group_id = sg['security_group']['id']
|
||||
with self.security_group_rule(security_group_id):
|
||||
rule = self._build_security_group_rule(
|
||||
sg['security_group']['id'], 'ingress',
|
||||
const.PROTO_NAME_TCP, '22', '22')
|
||||
self._create_security_group_rule(self.fmt, rule)
|
||||
rule = self._build_security_group_rule(
|
||||
sg['security_group']['id'], 'ingress',
|
||||
const.PROTO_NUM_TCP, '22', '22')
|
||||
res = self._create_security_group_rule(self.fmt, rule)
|
||||
self.deserialize(self.fmt, res)
|
||||
self.assertEqual(webob.exc.HTTPConflict.code, res.status_int)
|
||||
|
||||
def test_create_security_group_rule_duplicate_rules_proto_num_name(self):
|
||||
name = 'webservers'
|
||||
description = 'my webservers'
|
||||
with self.security_group(name, description) as sg:
|
||||
security_group_id = sg['security_group']['id']
|
||||
with self.security_group_rule(security_group_id):
|
||||
rule = self._build_security_group_rule(
|
||||
sg['security_group']['id'], 'ingress',
|
||||
const.PROTO_NUM_UDP, '50', '100')
|
||||
self._create_security_group_rule(self.fmt, rule)
|
||||
rule = self._build_security_group_rule(
|
||||
sg['security_group']['id'], 'ingress',
|
||||
const.PROTO_NAME_UDP, '50', '100')
|
||||
res = self._create_security_group_rule(self.fmt, rule)
|
||||
self.deserialize(self.fmt, res)
|
||||
self.assertEqual(webob.exc.HTTPConflict.code, res.status_int)
|
||||
|
||||
def test_create_security_group_rule_min_port_greater_max(self):
|
||||
name = 'webservers'
|
||||
description = 'my webservers'
|
||||
|
|
Loading…
Reference in New Issue