diff --git a/neutron/services/trunk/rules.py b/neutron/services/trunk/rules.py index 80fb80cee31..9535e9f1ed7 100644 --- a/neutron/services/trunk/rules.py +++ b/neutron/services/trunk/rules.py @@ -12,6 +12,7 @@ # License for the specific language governing permissions and limitations # under the License. +from neutron_lib.api import converters from neutron_lib import constants as n_const from neutron_lib import exceptions as n_exc @@ -102,20 +103,26 @@ class SubPortsValidator(object): # figure out defaults when the time comes to support Ironic. # We can reasonably expect segmentation details to be provided # in all other cases for now. - segmentation_id = subport.get("segmentation_id") - segmentation_type = subport.get("segmentation_type") - if not segmentation_id or not segmentation_type: + try: + segmentation_type = subport["segmentation_type"] + segmentation_id = ( + converters.convert_to_int(subport["segmentation_id"])) + except KeyError: msg = _("Invalid subport details '%s': missing segmentation " "information. Must specify both segmentation_id and " "segmentation_type") % subport raise n_exc.InvalidInput(error_message=msg) + except n_exc.InvalidInput: + msg = _("Invalid subport details: segmentation_id '%s' is " + "not an integer") % subport["segmentation_id"] + raise n_exc.InvalidInput(error_message=msg) if segmentation_type not in self._segmentation_types: - msg = _("Invalid segmentation_type '%s'") % segmentation_type + msg = _("Unknown segmentation_type '%s'") % segmentation_type raise n_exc.InvalidInput(error_message=msg) if not self._segmentation_types[segmentation_type](segmentation_id): - msg = _("Invalid segmentation id '%s'") % segmentation_id + msg = _("Segmentation ID '%s' is not in range") % segmentation_id raise n_exc.InvalidInput(error_message=msg) trunk_validator = TrunkPortValidator(subport['port_id']) diff --git a/neutron/services/trunk/validators/vlan.py b/neutron/services/trunk/validators/vlan.py index 51f61abfa3f..a9ef8f89570 100644 --- a/neutron/services/trunk/validators/vlan.py +++ b/neutron/services/trunk/validators/vlan.py @@ -16,7 +16,7 @@ from oslo_log import log as logging from neutron.callbacks import events from neutron.callbacks import registry -from neutron.plugins.common import constants as common_consts +from neutron.plugins.common import utils from neutron.services.trunk import constants as trunk_consts LOG = logging.getLogger(__name__) @@ -29,10 +29,6 @@ def register(): def handler(resource, event, trigger, **kwargs): - trigger.add_segmentation_type(trunk_consts.VLAN, vlan_range) + trigger.add_segmentation_type( + trunk_consts.VLAN, utils.is_valid_vlan_tag) LOG.debug('Registration complete') - - -def vlan_range(segmentation_id): - min_vid, max_vid = common_consts.MIN_VLAN_TAG, common_consts.MAX_VLAN_TAG - return min_vid <= segmentation_id <= max_vid diff --git a/neutron/tests/unit/services/trunk/test_rules.py b/neutron/tests/unit/services/trunk/test_rules.py index ba205c691ee..6de285d4d43 100644 --- a/neutron/tests/unit/services/trunk/test_rules.py +++ b/neutron/tests/unit/services/trunk/test_rules.py @@ -20,11 +20,11 @@ from neutron_lib import exceptions as n_exc from oslo_utils import uuidutils from neutron import manager +from neutron.plugins.common import utils from neutron.services.trunk import constants from neutron.services.trunk import exceptions as trunk_exc from neutron.services.trunk import plugin as trunk_plugin from neutron.services.trunk import rules -from neutron.services.trunk.validators import vlan as vlan_driver from neutron.tests import base from neutron.tests.unit.plugins.ml2 import test_plugin @@ -33,7 +33,7 @@ class SubPortsValidatorTestCase(base.BaseTestCase): def setUp(self): super(SubPortsValidatorTestCase, self).setUp() - self.segmentation_types = {constants.VLAN: vlan_driver.vlan_range} + self.segmentation_types = {constants.VLAN: utils.is_valid_vlan_tag} self.context = mock.ANY def test_validate_subport_subport_and_trunk_shared_port_id(self): @@ -57,6 +57,26 @@ class SubPortsValidatorTestCase(base.BaseTestCase): validator.validate, self.context) + def test_validate_subport_vlan_id_not_an_int(self): + validator = rules.SubPortsValidator( + self.segmentation_types, + [{'port_id': uuidutils.generate_uuid(), + 'segmentation_type': 'vlan', + 'segmentation_id': 'IamNotAnumber'}]) + self.assertRaises(n_exc.InvalidInput, + validator.validate, + self.context) + + def test_validate_subport_valid_vlan_id_as_string(self): + validator = rules.SubPortsValidator( + self.segmentation_types, + [{'port_id': uuidutils.generate_uuid(), + 'segmentation_type': 'vlan', + 'segmentation_id': '2'}]) + with mock.patch.object(rules.TrunkPortValidator, 'validate') as f: + validator.validate(self.context) + f.assert_called_once_with(self.context) + def test_validate_subport_subport_invalid_segmenation_type(self): validator = rules.SubPortsValidator( self.segmentation_types, @@ -101,7 +121,7 @@ class TrunkPortValidatorTestCase(test_plugin.Ml2PluginV2TestCase): super(TrunkPortValidatorTestCase, self).setUp() self.trunk_plugin = trunk_plugin.TrunkPlugin() self.trunk_plugin.add_segmentation_type(constants.VLAN, - vlan_driver.vlan_range) + utils.is_valid_vlan_tag) def test_validate_port_parent_in_use_by_trunk(self): with self.port() as trunk_parent: