Adds Hyper-V Security Groups implementation

Implements the security groups API in the Hyper-V agent.

To enable security groups on the Hyper-V agent, its config file
must contain the following option:

[SECURITYGROUP]
firewall_driver=neutron.plugins.hyperv.agent.security_groups_driver.HyperVSecurityGroupsDriver

Change-Id: I7556001557cd013c10b7f883dbf371afa8d09626
Implements: blueprint hyperv-security-groups
This commit is contained in:
Claudiu Belu 2014-02-12 16:52:47 -08:00
parent 260a9f5935
commit c823016d31
8 changed files with 730 additions and 14 deletions

View File

@ -27,6 +27,7 @@ from oslo.config import cfg
from neutron.agent.common import config
from neutron.agent import rpc as agent_rpc
from neutron.agent import securitygroups_rpc as sg_rpc
from neutron.common import config as logging_config
from neutron.common import constants as n_const
from neutron.common import topics
@ -70,6 +71,45 @@ CONF.register_opts(agent_opts, "AGENT")
config.register_agent_state_opts_helper(cfg.CONF)
class HyperVSecurityAgent(sg_rpc.SecurityGroupAgentRpcMixin):
# Set RPC API version to 1.1 by default.
RPC_API_VERSION = '1.1'
def __init__(self, context, plugin_rpc):
self.context = context
self.plugin_rpc = plugin_rpc
self.init_firewall()
if sg_rpc.is_firewall_enabled():
self._setup_rpc()
def _setup_rpc(self):
self.topic = topics.AGENT
self.dispatcher = self._create_rpc_dispatcher()
consumers = [[topics.SECURITY_GROUP, topics.UPDATE]]
self.connection = agent_rpc.create_consumers(self.dispatcher,
self.topic,
consumers)
def _create_rpc_dispatcher(self):
rpc_callback = HyperVSecurityCallbackMixin(self)
return dispatcher.RpcDispatcher([rpc_callback])
class HyperVSecurityCallbackMixin(sg_rpc.SecurityGroupAgentRpcCallbackMixin):
# Set RPC API version to 1.1 by default.
RPC_API_VERSION = '1.1'
def __init__(self, sg_agent):
self.sg_agent = sg_agent
class HyperVPluginApi(agent_rpc.PluginApi,
sg_rpc.SecurityGroupServerRpcApiMixin):
pass
class HyperVNeutronAgent(object):
# Set RPC API version to 1.0 by default.
RPC_API_VERSION = '1.0'
@ -103,7 +143,7 @@ class HyperVNeutronAgent(object):
def _setup_rpc(self):
self.agent_id = 'hyperv_%s' % platform.node()
self.topic = topics.AGENT
self.plugin_rpc = agent_rpc.PluginApi(topics.PLUGIN)
self.plugin_rpc = HyperVPluginApi(topics.PLUGIN)
self.state_rpc = agent_rpc.PluginReportStateAPI(topics.PLUGIN)
@ -119,6 +159,9 @@ class HyperVNeutronAgent(object):
self.connection = agent_rpc.create_consumers(self.dispatcher,
self.topic,
consumers)
self.sec_groups_agent = HyperVSecurityAgent(
self.context, self.plugin_rpc)
report_interval = CONF.AGENT.report_interval
if report_interval:
heartbeat = loopingcall.LoopingCall(self._report_state)
@ -165,6 +208,9 @@ class HyperVNeutronAgent(object):
def port_update(self, context, port=None, network_type=None,
segmentation_id=None, physical_network=None):
LOG.debug(_("port_update received"))
if 'security_groups' in port:
self.sec_groups_agent.refresh_firewall()
self._treat_vif_port(
port['id'], port['network_id'],
network_type, physical_network,
@ -311,6 +357,8 @@ class HyperVNeutronAgent(object):
device_details['physical_network'],
device_details['segmentation_id'],
device_details['admin_state_up'])
self.sec_groups_agent.prepare_devices_filter(devices)
self.plugin_rpc.update_device_up(self.context,
device,
self.agent_id,

View File

@ -0,0 +1,136 @@
#Copyright 2014 Cloudbase Solutions SRL
#All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
# @author: Claudiu Belu, Cloudbase Solutions Srl
from neutron.agent import firewall
from neutron.openstack.common import log as logging
from neutron.plugins.hyperv.agent import utilsfactory
from neutron.plugins.hyperv.agent import utilsv2
LOG = logging.getLogger(__name__)
class HyperVSecurityGroupsDriver(firewall.FirewallDriver):
"""Security Groups Driver.
Security Groups implementation for Hyper-V VMs.
"""
_ACL_PROP_MAP = {
'direction': {'ingress': utilsv2.HyperVUtilsV2._ACL_DIR_IN,
'egress': utilsv2.HyperVUtilsV2._ACL_DIR_OUT},
'ethertype': {'IPv4': utilsv2.HyperVUtilsV2._ACL_TYPE_IPV4,
'IPv6': utilsv2.HyperVUtilsV2._ACL_TYPE_IPV6},
'default': "ANY",
'address_default': {'IPv4': '0.0.0.0/0', 'IPv6': '::/0'}
}
def __init__(self):
self._utils = utilsfactory.get_hypervutils()
self._security_ports = {}
def prepare_port_filter(self, port):
LOG.debug('Creating port %s rules' % len(port['security_group_rules']))
# newly created port, add default rules.
if port['device'] not in self._security_ports:
LOG.debug('Creating default reject rules.')
self._utils.create_default_reject_all_rules(port['id'])
self._security_ports[port['device']] = port
self._create_port_rules(port['id'], port['security_group_rules'])
def _create_port_rules(self, port_id, rules):
for rule in rules:
param_map = self._create_param_map(rule)
try:
self._utils.create_security_rule(port_id, **param_map)
except Exception as ex:
LOG.error(_('Hyper-V Exception: %(hyperv_exeption)s while '
'adding rule: %(rule)s'),
dict(hyperv_exeption=ex, rule=rule))
def _remove_port_rules(self, port_id, rules):
for rule in rules:
param_map = self._create_param_map(rule)
try:
self._utils.remove_security_rule(port_id, **param_map)
except Exception as ex:
LOG.error(_('Hyper-V Exception: %(hyperv_exeption)s while '
'removing rule: %(rule)s'),
dict(hyperv_exeption=ex, rule=rule))
def _create_param_map(self, rule):
if 'port_range_min' in rule and 'port_range_max' in rule:
local_port = '%s-%s' % (rule['port_range_min'],
rule['port_range_max'])
else:
local_port = self._ACL_PROP_MAP['default']
return {
'direction': self._ACL_PROP_MAP['direction'][rule['direction']],
'acl_type': self._ACL_PROP_MAP['ethertype'][rule['ethertype']],
'local_port': local_port,
'protocol': self._get_rule_prop_or_default(rule, 'protocol'),
'remote_address': self._get_rule_remote_address(rule)
}
def apply_port_filter(self, port):
LOG.info('Aplying port filter.')
def update_port_filter(self, port):
LOG.info('Updating port rules.')
if port['device'] not in self._security_ports:
self.prepare_port_filter(port)
return
old_port = self._security_ports[port['device']]
rules = old_port['security_group_rules']
param_port_rules = port['security_group_rules']
new_rules = [r for r in param_port_rules if r not in rules]
remove_rules = [r for r in rules if r not in param_port_rules]
LOG.info("Creating %s new rules, removing %s old rules." % (
len(new_rules), len(remove_rules)))
self._remove_port_rules(old_port['id'], remove_rules)
self._create_port_rules(port['id'], new_rules)
self._security_ports[port['device']] = port
def remove_port_filter(self, port):
LOG.info('Removing port filter')
self._security_ports.pop(port['device'], None)
@property
def ports(self):
return self._security_ports
def _get_rule_remote_address(self, rule):
if rule['direction'] is 'ingress':
ip_prefix = 'source_ip_prefix'
else:
ip_prefix = 'dest_ip_prefix'
if ip_prefix in rule:
return rule[ip_prefix]
return self._ACL_PROP_MAP['address_default'][rule['ethertype']]
def _get_rule_prop_or_default(self, rule, prop):
if prop in rule:
return rule[prop]
return self._ACL_PROP_MAP['default']

View File

@ -49,18 +49,24 @@ def _check_min_windows_version(major, minor, build=0):
return map(int, version_str.split('.')) >= [major, minor, build]
def _get_class(v1_class, v2_class, force_v1_flag):
# V2 classes are supported starting from Hyper-V Server 2012 and
# Windows Server 2012 (kernel version 6.2)
if not force_v1_flag and _check_min_windows_version(6, 2):
cls = v2_class
def get_hypervutils():
# V1 virtualization namespace features are supported up to
# Windows Server / Hyper-V Server 2012
# V2 virtualization namespace features are supported starting with
# Windows Server / Hyper-V Server 2012
# Windows Server / Hyper-V Server 2012 R2 uses the V2 namespace and
# introduces additional features
force_v1_flag = CONF.hyperv.force_hyperv_utils_v1
if _check_min_windows_version(6, 3):
if force_v1_flag:
LOG.warning('V1 virtualization namespace no longer supported on '
'Windows Server / Hyper-V Server 2012 R2 or above.')
cls = utilsv2.HyperVUtilsV2R2
elif not force_v1_flag and _check_min_windows_version(6, 2):
cls = utilsv2.HyperVUtilsV2
else:
cls = v1_class
cls = utils.HyperVUtils
LOG.debug(_("Loading class: %(module_name)s.%(class_name)s"),
{'module_name': cls.__module__, 'class_name': cls.__name__})
return cls
def get_hypervutils():
return _get_class(utils.HyperVUtils, utilsv2.HyperVUtilsV2,
CONF.hyperv.force_hyperv_utils_v1)()
return cls()

View File

@ -26,17 +26,32 @@ class HyperVUtilsV2(utils.HyperVUtils):
_ETHERNET_SWITCH_PORT = 'Msvm_EthernetSwitchPort'
_PORT_ALLOC_SET_DATA = 'Msvm_EthernetPortAllocationSettingData'
_PORT_VLAN_SET_DATA = 'Msvm_EthernetSwitchPortVlanSettingData'
_PORT_SECURITY_SET_DATA = 'Msvm_EthernetSwitchPortSecuritySettingData'
_PORT_ALLOC_ACL_SET_DATA = 'Msvm_EthernetSwitchPortAclSettingData'
_PORT_EXT_ACL_SET_DATA = _PORT_ALLOC_ACL_SET_DATA
_LAN_ENDPOINT = 'Msvm_LANEndpoint'
_STATE_DISABLED = 3
_OPERATION_MODE_ACCESS = 1
_ACL_DIR_IN = 1
_ACL_DIR_OUT = 2
_ACL_TYPE_IPV4 = 2
_ACL_TYPE_IPV6 = 3
_ACL_ACTION_ALLOW = 1
_ACL_ACTION_DENY = 2
_ACL_ACTION_METER = 3
_ACL_APPLICABILITY_LOCAL = 1
_ACL_APPLICABILITY_REMOTE = 2
_ACL_DEFAULT = 'ANY'
_IPV4_ANY = '0.0.0.0/0'
_IPV6_ANY = '::/0'
_TCP_PROTOCOL = 'tcp'
_UDP_PROTOCOL = 'udp'
_MAX_WEIGHT = 65500
_wmi_namespace = '//./root/virtualization/v2'
@ -80,6 +95,12 @@ class HyperVUtilsV2(utils.HyperVUtils):
element.path_(), [res_setting_data.GetText_(1)])
self._check_job_status(ret_val, job_path)
def _remove_virt_feature(self, feature_resource):
vs_man_svc = self._conn.Msvm_VirtualSystemManagementService()[0]
(job_path, ret_val) = vs_man_svc.RemoveFeatureSettings(
FeatureSettings=[feature_resource.path_()])
self._check_job_status(ret_val, job_path)
def disconnect_switch_port(
self, vswitch_name, switch_port_name, delete_port):
"""Disconnects the switch port."""
@ -121,7 +142,7 @@ class HyperVUtilsV2(utils.HyperVUtils):
port_alloc, found = self._get_switch_port_allocation(switch_port_name)
if not found:
raise utils.HyperVException(
msg=_('Port Alloc not found: %s') % switch_port_name)
msg=_('Port Allocation not found: %s') % switch_port_name)
vs_man_svc = self._conn.Msvm_VirtualSystemManagementService()[0]
vlan_settings = self._get_vlan_setting_data_from_port_alloc(port_alloc)
@ -196,3 +217,173 @@ class HyperVUtilsV2(utils.HyperVUtils):
acl.Action = self._ACL_ACTION_METER
acl.Applicability = self._ACL_APPLICABILITY_LOCAL
self._add_virt_feature(port, acl)
def create_security_rule(self, switch_port_name, direction, acl_type,
local_port, protocol, remote_address):
port, found = self._get_switch_port_allocation(switch_port_name, False)
if not found:
return
# Add the ACLs only if they don't already exist
acls = port.associators(wmi_result_class=self._PORT_EXT_ACL_SET_DATA)
weight = self._get_new_weight(acls)
self._bind_security_rule(
port, direction, acl_type, self._ACL_ACTION_ALLOW, local_port,
protocol, remote_address, weight)
def remove_security_rule(self, switch_port_name, direction, acl_type,
local_port, protocol, remote_address):
port, found = self._get_switch_port_allocation(switch_port_name, False)
if not found:
# Port not found. It happens when the VM was already deleted.
return
acls = port.associators(wmi_result_class=self._PORT_EXT_ACL_SET_DATA)
filtered_acls = self._filter_security_acls(
acls, self._ACL_ACTION_ALLOW, direction, acl_type, local_port,
protocol, remote_address)
for acl in filtered_acls:
self._remove_virt_feature(acl)
def create_default_reject_all_rules(self, switch_port_name):
port, found = self._get_switch_port_allocation(switch_port_name, False)
if not found:
raise utils.HyperVException(
msg=_('Port Allocation not found: %s') % switch_port_name)
acls = port.associators(wmi_result_class=self._PORT_EXT_ACL_SET_DATA)
filtered_acls = [v for v in acls if v.Action == self._ACL_ACTION_DENY]
# 2 directions x 2 address types x 2 protocols = 8 ACLs
if len(filtered_acls) >= 8:
return
for acl in filtered_acls:
self._remove_virt_feature(acl)
weight = 0
ipv4_pair = (self._ACL_TYPE_IPV4, self._IPV4_ANY)
ipv6_pair = (self._ACL_TYPE_IPV6, self._IPV6_ANY)
for direction in [self._ACL_DIR_IN, self._ACL_DIR_OUT]:
for acl_type, address in [ipv4_pair, ipv6_pair]:
for protocol in [self._TCP_PROTOCOL, self._UDP_PROTOCOL]:
self._bind_security_rule(
port, direction, acl_type, self._ACL_ACTION_DENY,
self._ACL_DEFAULT, protocol, address, weight)
weight += 1
def _bind_security_rule(self, port, direction, acl_type, action,
local_port, protocol, remote_address, weight):
acls = port.associators(wmi_result_class=self._PORT_EXT_ACL_SET_DATA)
filtered_acls = self._filter_security_acls(
acls, action, direction, acl_type, local_port, protocol,
remote_address)
for acl in filtered_acls:
self._remove_virt_feature(acl)
acl = self._create_security_acl(
direction, acl_type, action, local_port, protocol, remote_address,
weight)
self._add_virt_feature(port, acl)
def _create_acl(self, direction, acl_type, action):
acl = self._get_default_setting_data(self._PORT_ALLOC_ACL_SET_DATA)
acl.set(Direction=direction,
AclType=acl_type,
Action=action,
Applicability=self._ACL_APPLICABILITY_LOCAL)
return acl
def _create_security_acl(self, direction, acl_type, action, local_port,
protocol, remote_ip_address, weight):
acl = self._create_acl(direction, acl_type, action)
(remote_address, remote_prefix_length) = remote_ip_address.split('/')
acl.set(Applicability=self._ACL_APPLICABILITY_REMOTE,
RemoteAddress=remote_address,
RemoteAddressPrefixLength=remote_prefix_length)
return acl
def _filter_acls(self, acls, action, direction, acl_type, remote_addr=""):
return [v for v in acls
if v.Action == action and
v.Direction == direction and
v.AclType == acl_type and
v.RemoteAddress == remote_addr]
def _filter_security_acls(self, acls, acl_action, direction, acl_type,
local_port, protocol, remote_addr=""):
(remote_address, remote_prefix_length) = remote_addr.split('/')
remote_prefix_length = int(remote_prefix_length)
return [v for v in acls
if v.Direction == direction and
v.Action in [self._ACL_ACTION_ALLOW, self._ACL_ACTION_DENY] and
v.AclType == acl_type and
v.RemoteAddress == remote_address and
v.RemoteAddressPrefixLength == remote_prefix_length]
def _get_new_weight(self, acls):
return 0
class HyperVUtilsV2R2(HyperVUtilsV2):
_PORT_EXT_ACL_SET_DATA = 'Msvm_EthernetSwitchPortExtendedAclSettingData'
_MAX_WEIGHT = 65500
def create_security_rule(self, switch_port_name, direction, acl_type,
local_port, protocol, remote_address):
protocols = [protocol]
if protocol is self._ACL_DEFAULT:
protocols = [self._TCP_PROTOCOL, self._UDP_PROTOCOL]
for proto in protocols:
super(HyperVUtilsV2R2, self).create_security_rule(
switch_port_name, direction, acl_type, local_port,
proto, remote_address)
def remove_security_rule(self, switch_port_name, direction, acl_type,
local_port, protocol, remote_address):
protocols = [protocol]
if protocol is self._ACL_DEFAULT:
protocols = ['tcp', 'udp']
for proto in protocols:
super(HyperVUtilsV2R2, self).remove_security_rule(
switch_port_name, direction, acl_type,
local_port, proto, remote_address)
def _create_security_acl(self, direction, acl_type, action, local_port,
protocol, remote_addr, weight):
acl = self._get_default_setting_data(self._PORT_EXT_ACL_SET_DATA)
acl.set(Direction=direction,
Action=action,
LocalPort=str(local_port),
Protocol=protocol,
RemoteIPAddress=remote_addr,
IdleSessionTimeout=0,
Weight=weight)
return acl
def _filter_security_acls(self, acls, action, direction, acl_type,
local_port, protocol, remote_addr=""):
return [v for v in acls
if v.Action == action and
v.Direction == direction and
v.LocalPort in [str(local_port), self._ACL_DEFAULT] and
v.Protocol in [protocol] and
v.RemoteIPAddress == remote_addr]
def _get_new_weight(self, acls):
if not acls:
return self._MAX_WEIGHT - 1
weights = [a.Weight for a in acls]
min_weight = min(weights)
for weight in range(min_weight, self._MAX_WEIGHT):
if weight not in weights:
return weight
return min_weight - 1

View File

@ -57,6 +57,7 @@ class TestHyperVNeutronAgent(base.BaseTestCase):
self.agent = hyperv_neutron_agent.HyperVNeutronAgent()
self.agent.plugin_rpc = mock.Mock()
self.agent.sec_groups_agent = mock.MagicMock()
self.agent.context = mock.Mock()
self.agent.agent_id = mock.Mock()

View File

@ -0,0 +1,176 @@
# Copyright 2014 Cloudbase Solutions SRL
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
# @author: Claudiu Belu, Cloudbase Solutions Srl
"""
Unit tests for the Hyper-V Security Groups Driver.
"""
import mock
from oslo.config import cfg
from neutron.plugins.hyperv.agent import security_groups_driver as sg_driver
from neutron.plugins.hyperv.agent import utilsfactory
from neutron.tests import base
CONF = cfg.CONF
class TestHyperVSecurityGroupsDriver(base.BaseTestCase):
_FAKE_DEVICE = 'fake_device'
_FAKE_ID = 'fake_id'
_FAKE_DIRECTION = 'ingress'
_FAKE_ETHERTYPE = 'IPv4'
_FAKE_ETHERTYPE_IPV6 = 'IPv6'
_FAKE_DEST_IP_PREFIX = 'fake_dest_ip_prefix'
_FAKE_SOURCE_IP_PREFIX = 'fake_source_ip_prefix'
_FAKE_PARAM_NAME = 'fake_param_name'
_FAKE_PARAM_VALUE = 'fake_param_value'
_FAKE_PORT_MIN = 9001
_FAKE_PORT_MAX = 9011
def setUp(self):
super(TestHyperVSecurityGroupsDriver, self).setUp()
self._mock_windows_version = mock.patch.object(utilsfactory,
'get_hypervutils')
self._mock_windows_version.start()
self.addCleanup(mock.patch.stopall)
self._driver = sg_driver.HyperVSecurityGroupsDriver()
self._driver._utils = mock.MagicMock()
@mock.patch('neutron.plugins.hyperv.agent.security_groups_driver'
'.HyperVSecurityGroupsDriver._create_port_rules')
def test_prepare_port_filter(self, mock_create_rules):
mock_port = self._get_port()
mock_utils_method = self._driver._utils.create_default_reject_all_rules
self._driver.prepare_port_filter(mock_port)
self.assertEqual(mock_port,
self._driver._security_ports[self._FAKE_DEVICE])
mock_utils_method.assert_called_once_with(self._FAKE_ID)
self._driver._create_port_rules.assert_called_once_with(
self._FAKE_ID, mock_port['security_group_rules'])
def test_update_port_filter(self):
mock_port = self._get_port()
new_mock_port = self._get_port()
new_mock_port['id'] += '2'
new_mock_port['security_group_rules'][0]['ethertype'] += "2"
self._driver._security_ports[mock_port['device']] = mock_port
self._driver._create_port_rules = mock.MagicMock()
self._driver._remove_port_rules = mock.MagicMock()
self._driver.update_port_filter(new_mock_port)
self._driver._remove_port_rules.assert_called_once_with(
mock_port['id'], mock_port['security_group_rules'])
self._driver._create_port_rules.assert_called_once_with(
new_mock_port['id'], new_mock_port['security_group_rules'])
self.assertEqual(new_mock_port,
self._driver._security_ports[new_mock_port['device']])
@mock.patch('neutron.plugins.hyperv.agent.security_groups_driver'
'.HyperVSecurityGroupsDriver.prepare_port_filter')
def test_update_port_filter_new_port(self, mock_method):
mock_port = self._get_port()
self._driver.prepare_port_filter = mock.MagicMock()
self._driver.update_port_filter(mock_port)
self._driver.prepare_port_filter.assert_called_once_with(mock_port)
def test_remove_port_filter(self):
mock_port = self._get_port()
self._driver._security_ports[mock_port['device']] = mock_port
self._driver.remove_port_filter(mock_port)
self.assertFalse(mock_port['device'] in self._driver._security_ports)
def test_create_port_rules_exception(self):
fake_rule = self._create_security_rule()
self._driver._utils.create_security_rule.side_effect = Exception(
'Generated Exception for testing.')
self._driver._create_port_rules(self._FAKE_ID, [fake_rule])
def test_create_param_map(self):
fake_rule = self._create_security_rule()
self._driver._get_rule_remote_address = mock.MagicMock(
return_value=self._FAKE_SOURCE_IP_PREFIX)
actual = self._driver._create_param_map(fake_rule)
expected = {
'direction': self._driver._ACL_PROP_MAP[
'direction'][self._FAKE_DIRECTION],
'acl_type': self._driver._ACL_PROP_MAP[
'ethertype'][self._FAKE_ETHERTYPE],
'local_port': '%s-%s' % (self._FAKE_PORT_MIN, self._FAKE_PORT_MAX),
'protocol': self._driver._ACL_PROP_MAP['default'],
'remote_address': self._FAKE_SOURCE_IP_PREFIX
}
self.assertEqual(expected, actual)
@mock.patch('neutron.plugins.hyperv.agent.security_groups_driver'
'.HyperVSecurityGroupsDriver._create_param_map')
def test_create_port_rules(self, mock_method):
fake_rule = self._create_security_rule()
mock_method.return_value = {
self._FAKE_PARAM_NAME: self._FAKE_PARAM_VALUE}
self._driver._create_port_rules(self._FAKE_ID, [fake_rule])
self._driver._utils.create_security_rule.assert_called_once_with(
self._FAKE_ID, fake_param_name=self._FAKE_PARAM_VALUE)
def test_convert_any_address_to_same_ingress(self):
rule = self._create_security_rule()
actual = self._driver._get_rule_remote_address(rule)
self.assertEqual(self._FAKE_SOURCE_IP_PREFIX, actual)
def test_convert_any_address_to_same_egress(self):
rule = self._create_security_rule()
rule['direction'] += '2'
actual = self._driver._get_rule_remote_address(rule)
self.assertEqual(self._FAKE_DEST_IP_PREFIX, actual)
def test_convert_any_address_to_ipv4(self):
rule = self._create_security_rule()
del rule['source_ip_prefix']
actual = self._driver._get_rule_remote_address(rule)
self.assertEqual(self._driver._ACL_PROP_MAP['address_default']['IPv4'],
actual)
def test_convert_any_address_to_ipv6(self):
rule = self._create_security_rule()
del rule['source_ip_prefix']
rule['ethertype'] = self._FAKE_ETHERTYPE_IPV6
actual = self._driver._get_rule_remote_address(rule)
self.assertEqual(self._driver._ACL_PROP_MAP['address_default']['IPv6'],
actual)
def _get_port(self):
return {
'device': self._FAKE_DEVICE,
'id': self._FAKE_ID,
'security_group_rules': [self._create_security_rule()]
}
def _create_security_rule(self):
return {
'direction': self._FAKE_DIRECTION,
'ethertype': self._FAKE_ETHERTYPE,
'dest_ip_prefix': self._FAKE_DEST_IP_PREFIX,
'source_ip_prefix': self._FAKE_SOURCE_IP_PREFIX,
'port_range_min': self._FAKE_PORT_MIN,
'port_range_max': self._FAKE_PORT_MAX
}

View File

@ -34,6 +34,9 @@ CONF = cfg.CONF
class TestHyperVUtilsFactory(base.BaseTestCase):
def test_get_hypervutils_v2_r2(self):
self._test_returned_class(utilsv2.HyperVUtilsV2R2, True, '6.3.0')
def test_get_hypervutils_v2(self):
self._test_returned_class(utilsv2.HyperVUtilsV2, False, '6.2.0')

View File

@ -41,6 +41,14 @@ class TestHyperVUtilsV2(base.BaseTestCase):
_FAKE_CLASS_NAME = "fake_class_name"
_FAKE_ELEMENT_NAME = "fake_element_name"
_FAKE_ACL_ACT = 'fake_acl_action'
_FAKE_ACL_DIR = 'fake_acl_dir'
_FAKE_ACL_TYPE = 'fake_acl_type'
_FAKE_LOCAL_PORT = 'fake_local_port'
_FAKE_PROTOCOL = 'fake_port_protocol'
_FAKE_REMOTE_ADDR = '0.0.0.0/0'
_FAKE_WEIGHT = 'fake_weight'
def setUp(self):
super(TestHyperVUtilsV2, self).setUp()
self._utils = utilsv2.HyperVUtilsV2()
@ -144,6 +152,20 @@ class TestHyperVUtilsV2(base.BaseTestCase):
mock_svc.RemoveResourceSettings.assert_called_with(
ResourceSettings=[self._FAKE_RES_PATH])
@mock.patch('neutron.plugins.hyperv.agent.utilsv2.HyperVUtilsV2'
'._check_job_status')
def test_remove_virt_feature(self, mock_check_job_status):
mock_svc = self._utils._conn.Msvm_VirtualSystemManagementService()[0]
mock_svc.RemoveFeatureSettings.return_value = (self._FAKE_JOB_PATH,
self._FAKE_RET_VAL)
mock_res_setting_data = mock.MagicMock()
mock_res_setting_data.path_.return_value = self._FAKE_RES_PATH
self._utils._remove_virt_feature(mock_res_setting_data)
mock_svc.RemoveFeatureSettings.assert_called_with(
FeatureSettings=[self._FAKE_RES_PATH])
def test_disconnect_switch_port_delete_port(self):
self._test_disconnect_switch_port(True)
@ -249,3 +271,136 @@ class TestHyperVUtilsV2(base.BaseTestCase):
self.assertEqual(4, len(self._utils._add_virt_feature.mock_calls))
self._utils._add_virt_feature.assert_called_with(
mock_port, mock_acl)
@mock.patch('neutron.plugins.hyperv.agent.utilsv2.HyperVUtilsV2'
'._remove_virt_feature')
@mock.patch('neutron.plugins.hyperv.agent.utilsv2.HyperVUtilsV2'
'._bind_security_rule')
def test_create_default_reject_all_rules(self, mock_bind, mock_remove):
(m_port, m_acl) = self._setup_security_rule_test()
m_acl.Action = self._utils._ACL_ACTION_DENY
self._utils.create_default_reject_all_rules(self._FAKE_PORT_NAME)
calls = []
ipv4_pair = (self._utils._ACL_TYPE_IPV4, self._utils._IPV4_ANY)
ipv6_pair = (self._utils._ACL_TYPE_IPV6, self._utils._IPV6_ANY)
for direction in [self._utils._ACL_DIR_IN, self._utils._ACL_DIR_OUT]:
for acl_type, address in [ipv4_pair, ipv6_pair]:
for protocol in [self._utils._TCP_PROTOCOL,
self._utils._UDP_PROTOCOL]:
calls.append(mock.call(m_port, direction, acl_type,
self._utils._ACL_ACTION_DENY,
self._utils._ACL_DEFAULT,
protocol, address, mock.ANY))
self._utils._remove_virt_feature.assert_called_once_with(m_acl)
self._utils._bind_security_rule.assert_has_calls(calls)
@mock.patch('neutron.plugins.hyperv.agent.utilsv2.HyperVUtilsV2'
'._remove_virt_feature')
@mock.patch('neutron.plugins.hyperv.agent.utilsv2.HyperVUtilsV2'
'._add_virt_feature')
@mock.patch('neutron.plugins.hyperv.agent.utilsv2.HyperVUtilsV2'
'._create_security_acl')
def test_bind_security_rule(self, mock_create_acl, mock_add, mock_remove):
(m_port, m_acl) = self._setup_security_rule_test()
mock_create_acl.return_value = m_acl
self._utils._bind_security_rule(
m_port, self._FAKE_ACL_DIR, self._FAKE_ACL_TYPE,
self._FAKE_ACL_ACT, self._FAKE_LOCAL_PORT, self._FAKE_PROTOCOL,
self._FAKE_REMOTE_ADDR, self._FAKE_WEIGHT)
self._utils._add_virt_feature.assert_called_once_with(m_port, m_acl)
@mock.patch('neutron.plugins.hyperv.agent.utilsv2.HyperVUtilsV2'
'._remove_virt_feature')
def test_remove_security_rule(self, mock_remove_feature):
mock_acl = self._setup_security_rule_test()[1]
self._utils.remove_security_rule(
self._FAKE_PORT_NAME, self._FAKE_ACL_DIR, self._FAKE_ACL_TYPE,
self._FAKE_LOCAL_PORT, self._FAKE_PROTOCOL, self._FAKE_REMOTE_ADDR)
self._utils._remove_virt_feature.assert_called_once_with(mock_acl)
def _setup_security_rule_test(self):
mock_port = mock.MagicMock()
mock_acl = mock.MagicMock()
mock_port.associators.return_value = [mock_acl]
self._utils._get_switch_port_allocation = mock.MagicMock(return_value=(
mock_port, True))
self._utils._filter_security_acls = mock.MagicMock(
return_value=[mock_acl])
return (mock_port, mock_acl)
def test_filter_acls(self):
mock_acl = mock.MagicMock()
mock_acl.Action = self._FAKE_ACL_ACT
mock_acl.Applicability = self._utils._ACL_APPLICABILITY_LOCAL
mock_acl.Direction = self._FAKE_ACL_DIR
mock_acl.AclType = self._FAKE_ACL_TYPE
mock_acl.RemoteAddress = self._FAKE_REMOTE_ADDR
acls = [mock_acl, mock_acl]
good_acls = self._utils._filter_acls(
acls, self._FAKE_ACL_ACT, self._FAKE_ACL_DIR,
self._FAKE_ACL_TYPE, self._FAKE_REMOTE_ADDR)
bad_acls = self._utils._filter_acls(
acls, self._FAKE_ACL_ACT, self._FAKE_ACL_DIR, self._FAKE_ACL_TYPE)
self.assertEqual(acls, good_acls)
self.assertEqual([], bad_acls)
class TestHyperVUtilsV2R2(base.BaseTestCase):
_FAKE_ACL_ACT = 'fake_acl_action'
_FAKE_ACL_DIR = 'fake_direction'
_FAKE_ACL_TYPE = 'fake_acl_type'
_FAKE_LOCAL_PORT = 'fake_local_port'
_FAKE_PROTOCOL = 'fake_port_protocol'
_FAKE_REMOTE_ADDR = '10.0.0.0/0'
def setUp(self):
super(TestHyperVUtilsV2R2, self).setUp()
self._utils = utilsv2.HyperVUtilsV2R2()
def test_filter_security_acls(self):
self._test_filter_security_acls(
self._FAKE_LOCAL_PORT, self._FAKE_PROTOCOL, self._FAKE_REMOTE_ADDR)
def test_filter_security_acls_default(self):
default = self._utils._ACL_DEFAULT
self._test_filter_security_acls(
default, default, self._FAKE_REMOTE_ADDR)
def _test_filter_security_acls(self, local_port, protocol, remote_addr):
mock_acl = mock.MagicMock()
mock_acl.Action = self._utils._ACL_ACTION_ALLOW
mock_acl.Direction = self._FAKE_ACL_DIR
mock_acl.LocalPort = local_port
mock_acl.Protocol = protocol
mock_acl.RemoteIPAddress = remote_addr
acls = [mock_acl, mock_acl]
good_acls = self._utils._filter_security_acls(
acls, mock_acl.Action, self._FAKE_ACL_DIR, self._FAKE_ACL_TYPE,
local_port, protocol, remote_addr)
bad_acls = self._utils._filter_security_acls(
acls, self._FAKE_ACL_ACT, self._FAKE_ACL_DIR, self._FAKE_ACL_TYPE,
local_port, protocol, remote_addr)
self.assertEqual(acls, good_acls)
self.assertEqual([], bad_acls)
def test_get_new_weight(self):
mockacl1 = mock.MagicMock()
mockacl1.Weight = self._utils._MAX_WEIGHT - 1
mockacl2 = mock.MagicMock()
mockacl2.Weight = self._utils._MAX_WEIGHT - 3
self.assertEqual(self._utils._MAX_WEIGHT - 2,
self._utils._get_new_weight([mockacl1, mockacl2]))
def test_get_new_weight_no_acls(self):
self.assertEqual(self._utils._MAX_WEIGHT - 1,
self._utils._get_new_weight([]))