diff --git a/quark/db/api.py b/quark/db/api.py index fe59d69..188e6b7 100644 --- a/quark/db/api.py +++ b/quark/db/api.py @@ -30,10 +30,10 @@ from sqlalchemy.orm import class_mapper from quark.db import models from quark import network_strategy -from quark import port_vlan_id from quark import protocols +from quark import tags - +PORT_TAG_REGISTRY = tags.PORT_TAG_REGISTRY STRATEGY = network_strategy.STRATEGY LOG = logging.getLogger(__name__) CONF = cfg.CONF @@ -236,7 +236,7 @@ def port_create(context, **port_dict): port["tenant_id"] = context.tenant_id if "addresses" in port_dict: port["ip_addresses"].extend(port_dict["addresses"]) - _port_store_vlan_id(port, **port_dict) + PORT_TAG_REGISTRY.set_all(port, **port_dict) context.session.add(port) return port @@ -285,20 +285,10 @@ def update_port_associations_for_ip(context, ports, address): assoc_ports - new_ports, new_address) -def _port_store_vlan_id(port, **kwargs): - if "vlan_id" in kwargs: - try: - port_vlan_id.store_vlan_id(port, kwargs.pop("vlan_id")) - except Exception as e: - LOG.error("Exception occurred while trying to store VLAN ID on " - "port '%(port_id)d': %(message)s", - {'port_id': port.id, 'message': e.message}) - - def port_update(context, port, **kwargs): if "addresses" in kwargs: port["ip_addresses"] = kwargs.pop("addresses") - _port_store_vlan_id(port, **kwargs) + PORT_TAG_REGISTRY.set_all(port, **kwargs) port.update(kwargs) context.session.add(port) return port diff --git a/quark/plugin_views.py b/quark/plugin_views.py index e13ea76..35e6797 100644 --- a/quark/plugin_views.py +++ b/quark/plugin_views.py @@ -23,13 +23,14 @@ from oslo_log import log as logging from quark.db import ip_types from quark import network_strategy -from quark import port_vlan_id from quark import protocols +from quark import tags CONF = cfg.CONF LOG = logging.getLogger(__name__) STRATEGY = network_strategy.STRATEGY +PORT_TAG_REGISTRY = tags.PORT_TAG_REGISTRY quark_view_opts = [ cfg.BoolOpt('show_allocation_pools', @@ -185,9 +186,15 @@ def _port_dict(port, fields=None): # are not eager loaded. According to mdietz this be a small impact on # performance, but if the tag system gets used more on ports, we may # want to eager load the tags. - vlan_id = port_vlan_id.retrieve_vlan_id(port) - if vlan_id: - res["vlan_id"] = vlan_id + try: + t = PORT_TAG_REGISTRY.get_all(port) + res.update(t) + except Exception as e: + # NOTE(morgabra) We really don't want to break port-listing if + # this goes sideways here, so we pass. + msg = ("Unknown error loading tags for port %s: %s" + % (port["id"], e)) + LOG.exception(msg) return res diff --git a/quark/port_vlan_id.py b/quark/port_vlan_id.py deleted file mode 100644 index 78ad549..0000000 --- a/quark/port_vlan_id.py +++ /dev/null @@ -1,121 +0,0 @@ -# Copyright 2015 Rackspace -# All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. - -VLAN_TAG_PREFIX = "VLAN_ID:" -MIN_VLAN_ID = 1 -MAX_VLAN_ID = 4096 - - -class InvalidVlanIdError(Exception): - """Raised if an invalid VLAN ID is detected.""" - def __init__(self, vlan_id): - self.vlan_id = vlan_id - self.message = ("Invalid VLAN ID detected. Got '%(vlan_id)s'. " - "Integer conversion yields: '%(vlan_id_int)d'. " - "VLAN ID should be between %(min)d and %(max)d " - "inclusive." % {'vlan_id': vlan_id, - 'vlan_id_int': int(vlan_id), - 'min': MIN_VLAN_ID, - 'max': MAX_VLAN_ID}) - - -def _validate_vlan_id(vlan_id): - """Validates a VLAN ID. - - :param vlan_id: The VLAN ID to validate against. - :raises InvalidVlanIdError: Raised if the VLAN ID is invalid. - """ - vlan_id_int = int(vlan_id) - if vlan_id_int < MIN_VLAN_ID or vlan_id_int > MAX_VLAN_ID: - raise InvalidVlanIdError(vlan_id) - - -def _build_vlan_tag_string(vlan_id): - """Builds a VLAN ID tag string. - - :param vlan_id: The VLAN ID as a string. - :returns: The VLAN ID string as appropriate for a port tag. - """ - return "%s%d" % (VLAN_TAG_PREFIX, int(vlan_id)) - - -def store_vlan_id(port, vlan_id): - """Stores a VLAN ID on a specified port. - - :param port: The port object on which to store the VLAN ID. - :param vlan_id: The VLAN ID as a string. - - :raises InvalidVlanIdError: If the vlan_id is invalid, this exception - is raised. - """ - _validate_vlan_id(vlan_id) - port.tags.append(_build_vlan_tag_string(vlan_id)) - - -def retrieve_vlan_id(port): - """Retrieves the VLAN ID associated with the given port, if it exists. - - :param port: The port object. - :returns: The VLAN ID as an integer, if the port has one attached. - Otherwise returns None. - - :raises InvalidVlanIdError: This exception is raised if the retrieved - VLAN ID is invalid. - """ - for tag in port.tags: - if is_vlan_id_tag(tag): - vlan_id = _extract_vlan_id_from_tag(tag) - _validate_vlan_id(vlan_id) - return vlan_id - - return None - - -def _extract_vlan_id_from_tag(tag): - """Extracts the VLAN ID from a given tag, if possible. - - Assumes the tag argument is definitely a VLAN ID tag as identified by - is_vlan_id_tag(tag). - - :param tag: The tag object. - :returns: The VLAN ID as an integer if extraction is successful - Otherwise returns None. - """ - try: - vlan_id = int(tag[len(VLAN_TAG_PREFIX):]) - except Exception: - return None - return vlan_id - - -def is_vlan_id_tag(tag): - """Determines if the given tag is a VLAN tag. - - :param tag: A tag model object. - :returns: True if the tag is a VLAN ID tag. False otherwise. - """ - return tag[0:len(VLAN_TAG_PREFIX)] == VLAN_TAG_PREFIX - - -def has_vlan_id(port): - """Determines if the specified port has a VLAN ID attached. - - :param port: The port object. - :returns: True if the port has an associated VLAN ID, False otherwise. - """ - for tag in port.tags: - if is_vlan_id_tag(tag): - return True - return False diff --git a/quark/tags.py b/quark/tags.py new file mode 100644 index 0000000..d1d959b --- /dev/null +++ b/quark/tags.py @@ -0,0 +1,176 @@ +# Copyright 2015 Rackspace +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from neutron.common import exceptions + + +class TagValidationError(Exception): + def __init__(self, value, message): + self.value = value + self.message = message + + +class Tag(object): + + @classmethod + def get_name(cls): + """API name of the tag.""" + if not hasattr(cls, 'NAME'): + raise NotImplementedError() + return cls.NAME + + @classmethod + def get_prefix(cls): + """Tag 'key', saved in the database as :""" + return "%s:" % cls.get_name().upper() + + def serialize(self, value): + return "%s%s" % (self.get_prefix(), value) + + def deserialize(self, tag): + if self.is_tag(tag): + try: + return tag[len(self.get_prefix()):] + except Exception: + pass + return None + + def validate(self, value): + raise NotImplementedError + + def set(self, model, value): + """Set tag on model object.""" + self.validate(value) + self._pop(model) + value = self.serialize(value) + model.tags.append(value) + + def get(self, model): + """Get a matching valid tag off the model.""" + for tag in model.tags: + if self.is_tag(tag): + value = self.deserialize(tag) + try: + self.validate(value) + return value + except TagValidationError: + continue + return None + + def _pop(self, model): + """Pop all matching tags off the model and return them.""" + tags = [] + + # collect any exsiting tags with matching prefix + for tag in model.tags: + if self.is_tag(tag): + tags.append(tag) + + # remove collected tags from model + if tags: + for tag in tags: + model.tags.remove(tag) + + return tags + + def pop(self, model): + """Pop all matching tags off the port, return a valid one.""" + tags = self._pop(model) + if tags: + for tag in tags: + value = self.deserialize(tag) + try: + self.validate(value) + return value + except TagValidationError: + continue + + def is_tag(self, tag): + """Is a given tag this type?""" + return tag[0:len(self.get_prefix())] == self.get_prefix() + + def has_tag(self, model): + """Does the given port have this tag?""" + for tag in model.tags: + if self.is_tag(tag): + return True + return False + + +class VlanTag(Tag): + + NAME = "vlan_id" + MIN_VLAN_ID = 1 + MAX_VLAN_ID = 4096 + + def validate(self, value): + """Validates a VLAN ID. + + :param value: The VLAN ID to validate against. + :raises TagValidationError: Raised if the VLAN ID is invalid. + """ + try: + vlan_id_int = int(value) + assert vlan_id_int >= self.MIN_VLAN_ID + assert vlan_id_int <= self.MAX_VLAN_ID + except Exception: + msg = ("Invalid vlan_id. Got '%(vlan_id)s'. " + "vlan_id should be an integer between %(min)d and %(max)d " + "inclusive." % {'vlan_id': value, + 'min': self.MIN_VLAN_ID, + 'max': self.MAX_VLAN_ID}) + raise TagValidationError(value, msg) + return True + + +class TagRegistry(object): + + tags = {} + + def get_all(self, model): + """Get all known tags from a model. + + Returns a dict of {:}. + """ + tags = {} + for name, tag in self.tags.items(): + for mtag in model.tags: + if tag.is_tag(mtag): + tags[name] = tag.get(model) + return tags + + def set_all(self, model, **tags): + """Validate and set all known tags on a port.""" + for name, tag in self.tags.items(): + if name in tags: + value = tags.pop(name) + if value: + try: + tag.set(model, value) + except TagValidationError as e: + raise exceptions.BadRequest( + resource="tags", + msg="%s" % (e.message)) + + +class PortTagRegistry(TagRegistry): + + def __init__(self): + self.tags = { + VlanTag.get_name(): VlanTag() + } + + +PORT_TAG_REGISTRY = PortTagRegistry() diff --git a/quark/tests/plugin_modules/test_ports.py b/quark/tests/plugin_modules/test_ports.py index 612f921..c9db930 100644 --- a/quark/tests/plugin_modules/test_ports.py +++ b/quark/tests/plugin_modules/test_ports.py @@ -26,7 +26,7 @@ from quark.db import models from quark import exceptions as q_exc from quark import network_strategy from quark.plugin_modules import ports as quark_ports -from quark import port_vlan_id +from quark import tags from quark.tests import test_quark_plugin @@ -131,7 +131,7 @@ class TestQuarkGetPorts(test_quark_plugin.TestQuarkPlugin): def test_port_show_vlan_id(self): """Prove VLAN IDs are included in port information when available.""" - port_tags = [port_vlan_id._build_vlan_tag_string("5")] + port_tags = [tags.VlanTag().serialize(5)] port = dict(mac_address=int('AABBCCDDEEFF', 16), network_id=1, tenant_id=self.context.tenant_id, device_id=2, tags=port_tags) @@ -143,7 +143,26 @@ class TestQuarkGetPorts(test_quark_plugin.TestQuarkPlugin): 'admin_state_up': None, 'fixed_ips': [], 'device_id': 2, - 'vlan_id': 5} + 'vlan_id': '5'} + with self._stubs(ports=port): + result = self.plugin.get_port(self.context, 1) + for key in expected.keys(): + self.assertEqual(result[key], expected[key]) + + def test_port_show_invalid_vlan_id(self): + """Prove VLAN IDs are included in port information when available.""" + port_tags = [tags.VlanTag().serialize('invalid')] + port = dict(mac_address=int('AABBCCDDEEFF', 16), network_id=1, + tenant_id=self.context.tenant_id, device_id=2, + tags=port_tags) + expected = {'status': "ACTIVE", + 'device_owner': None, + 'mac_address': 'AA:BB:CC:DD:EE:FF', + 'network_id': 1, + 'tenant_id': self.context.tenant_id, + 'admin_state_up': None, + 'fixed_ips': [], + 'device_id': 2} with self._stubs(ports=port): result = self.plugin.get_port(self.context, 1) for key in expected.keys(): diff --git a/quark/tests/test_db_api.py b/quark/tests/test_db_api.py index 50b2a83..cf92d1f 100644 --- a/quark/tests/test_db_api.py +++ b/quark/tests/test_db_api.py @@ -19,6 +19,7 @@ from oslo_log import log as logging from quark.db import api as db_api from quark.db import models +from quark import tags from quark.tests.functional.base import BaseFunctionalTest LOG = logging.getLogger(__name__) @@ -291,3 +292,15 @@ class TestDBAPI(BaseFunctionalTest): set([mock_ports[0], mock_ports[3]]), mock_new_address) + + def test_update_port_sets_vlan_tag(self): + self.context.session.add = mock.Mock() + mock_port = models.Port(id=1, network_id="2", ip_addresses=[], tags=[]) + db_api.port_update(self.context, mock_port, vlan_id=1) + self.assertEqual(mock_port.tags, [tags.VlanTag().serialize(1)]) + + def test_create_port_sets_vlan_tag(self): + self.context.session.add = mock.Mock() + port_req = {"id": 1, "network_id": "2", "vlan_id": 1} + new_port = db_api.port_create(self.context, **port_req) + self.assertEqual(new_port.tags, [tags.VlanTag().serialize(1)]) diff --git a/quark/tests/test_port_vlan_id.py b/quark/tests/test_port_vlan_id.py deleted file mode 100644 index 70024a7..0000000 --- a/quark/tests/test_port_vlan_id.py +++ /dev/null @@ -1,170 +0,0 @@ -# Copyright 2015 Rackspace -# All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. - -import random -import string - -from quark.db import models -from quark import port_vlan_id -from quark.tests import test_base - -MAX_VLAN_ID = port_vlan_id.MAX_VLAN_ID -MIN_VLAN_ID = port_vlan_id.MIN_VLAN_ID - - -class TestPortVlanId(test_base.TestBase): - def setUp(self): - super(TestPortVlanId, self).setUp() - self.port_with_vlan = ( - self._create_port_with_vlan_id( - 0, random.randrange(MIN_VLAN_ID, MAX_VLAN_ID))) - self.port_without_vlan = self._create_test_port(1) - - def _create_test_port_with_lots_of_tags(self, port_id, vlan_id): - port = self._create_test_port(port_id) - port.tags.append("One weird olde tag.") - port.tags.append("Yet another tag.") - if vlan_id is not None: - port_vlan_id.store_vlan_id(port, vlan_id) - port.tags.append("One final tag") - return port - - def _create_test_port(self, port_id): - port = models.Port(id=port_id, network_id="1", ip_addresses=[], - tags=[]) - return port - - def _create_port_with_vlan_id(self, port_id, vlan_id): - port = self._create_test_port(port_id) - tag_contents = "%s%d" % (port_vlan_id.VLAN_TAG_PREFIX, vlan_id) - port.tags.append(tag_contents) - return port - - def test__validate_vlan_id(self): - valid_ids = [MIN_VLAN_ID, MAX_VLAN_ID] - for n in range(0, 10): - valid_ids.append(random.randrange(MIN_VLAN_ID, MAX_VLAN_ID)) - for vlan_id in valid_ids: - try: - port_vlan_id._validate_vlan_id(vlan_id) - except port_vlan_id.InvalidVlanIdError as e: - self.assertFalse(True, - "_validate_vlan_id raised an exception on " - "what should be a valid VLAN ID. Exception " - "message: %s" % (e.message)) - - invalid_ids = [MIN_VLAN_ID - 1, MAX_VLAN_ID + 1] - for n in range(0, 5): - valid_ids.append(random.randrange(MAX_VLAN_ID + 1, - MAX_VLAN_ID + 100)) - valid_ids.append(random.randrange(MIN_VLAN_ID - 100, - MIN_VLAN_ID - 1)) - for vlan_id in invalid_ids: - # _validate_vlan_id should raise on invalid IDs. - self.assertRaises(port_vlan_id.InvalidVlanIdError, - port_vlan_id._validate_vlan_id, - vlan_id) - - def test_store_vlan_id_vlan(self): - # Test against a valid VLAN ID - port = self.port_without_vlan - port_vlan_id.store_vlan_id(port, MIN_VLAN_ID) - self.assertTrue(len(port.tags) == 1) - vlan_tag = port.tags[0] - self.assertTrue(string.find(vlan_tag, - port_vlan_id.VLAN_TAG_PREFIX) == 0, - "Couldn't find the VLAN tag prefix in the vlan tag!") - self.assertTrue(string.find(vlan_tag, str(MIN_VLAN_ID)) != -1, - "The VLAN ID was not stored in the VLAN tag!") - - # Also test against an invalid ID - self.port_without_vlan = self._create_test_port(2) - port = self.port_without_vlan - self.assertRaises(port_vlan_id.InvalidVlanIdError, - port_vlan_id.store_vlan_id, - port, port_vlan_id.MIN_VLAN_ID - 1) - self.assertTrue(len(port.tags) == 0, - "The port has a new tag, even though the VLAN ID was " - "invalid!") - - def test_retrieve_vlan_id(self): - # VLAN ID exists - port = self.port_with_vlan - vlan_id = port_vlan_id.retrieve_vlan_id(port) - self.assertIsNotNone(vlan_id, - "VLAN ID returned by retrieve_vlan_id " - "is None despite having stored the VLAN ID on " - "this port earlier.") - - # VLAN ID is absent - port = self.port_without_vlan - vlan_id = port_vlan_id.retrieve_vlan_id(port) - self.assertIsNone(vlan_id, - "VLAN ID is not None, even though the port does " - "not have a VLAN ID stored.") - - # Other tags are present on the port. - port = self._create_test_port_with_lots_of_tags(3, 5) - vlan_id = port_vlan_id.retrieve_vlan_id(port) - self.assertEqual(5, vlan_id, - "Retrieved VLAN ID did not match expectations with " - "another tag present.") - - def test_is_vlan_id_tag(self): - # Test some good cases, note that is_vlan_id_tag doesn't validate - # the VLAN_ID itself, as it should've been validated before it was - # stored to the port model. - test_ids = [-1, 2, 3, 4, 100, 1000, 5234, "puppy", "dog"] - good_tags = [("%s%s" % (port_vlan_id.VLAN_TAG_PREFIX, vlan_id)) - for vlan_id in str(test_ids)] - for tag in good_tags: - self.assertTrue(port_vlan_id.is_vlan_id_tag(tag), - "A known good tag was not recognized as one by " - "is_vlan_id_tag. Tag: '%(tag)s'" % {'tag': tag}) - - # Test some bad ones - bad_tags = ["", "snake:50", "cipher", "zero", "[]asdrf897y", - port_vlan_id.VLAN_TAG_PREFIX[:-2], - "some_other_key_value_pair:234"] - for tag in bad_tags: - self.assertFalse(port_vlan_id.is_vlan_id_tag(tag), - "A known bad tag was recognized as a VLAN ID tag " - "by is_vlan_id_tag. Tag: '%(tag)s'" % - {'tag': tag}) - - def test_has_vlan_id(self): - # Test port with VLAN ID, but no other tags - port = self.port_with_vlan - self.assertTrue(port_vlan_id.has_vlan_id(port), - "has_vlan_id returned False even though the port is " - "known to have a valid VLAN ID tag.") - - # Test port without, no tags - port = self.port_without_vlan - self.assertFalse(port_vlan_id.has_vlan_id(port), - "has_vlan_id returned True even though the port " - "doesn't have a VLAN ID tag.") - - # Test port with VLAN ID, and several tags - port = self._create_test_port_with_lots_of_tags(5, 1337) - self.assertTrue(port_vlan_id.has_vlan_id(port), - "has_vlan_id returned False even though the port " - "has a VLAN ID tag.") - - # Test port without VLAN ID, but with several tags - port = self._create_test_port_with_lots_of_tags(5, None) - self.assertFalse(port_vlan_id.has_vlan_id(port), - "has_vlan_id returned True even though the port " - "does not have a VLAN ID tag.") diff --git a/quark/tests/test_tags.py b/quark/tests/test_tags.py new file mode 100644 index 0000000..271491b --- /dev/null +++ b/quark/tests/test_tags.py @@ -0,0 +1,247 @@ +# Copyright 2015 Rackspace +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from neutron.common import exceptions + +from quark.db import models +from quark import tags +from quark.tests import test_base + + +class FakeTag(tags.Tag): + + NAME = "fake_tag" + + def validate(self, value): + if value == "invalid": + raise tags.TagValidationError(value, "FakeTag value is 'invalid'.") + return True + + +class FooTag(tags.Tag): + + NAME = "foo_tag" + + def validate(self, value): + if value != "foo": + raise tags.TagValidationError(value, "FooTag value is not 'foo'.") + return True + + +class FakeTagRegistry(tags.PortTagRegistry): + def __init__(self, tags=None): + if tags: + self.tags = tags + + +class TestTagBase(test_base.TestBase): + + def setUp(self, tag=None, value=None, value2=None, invalid_value=None): + super(TestTagBase, self).setUp() + self.tag = tag if tag else FakeTag() + self.value = value if value else "first" + self.value2 = value2 if value2 else "second" + self.invalid_value = invalid_value if invalid_value else "invalid" + + self.foo_tag = FooTag() + self.existing_tags = ["EXISTING_TAG:already exists", + "random tag", + "123"] + tags = { + self.tag.get_name(): self.tag, + self.foo_tag.get_name(): self.foo_tag + } + self.registry = FakeTagRegistry(tags=tags) + + def _create_test_model(self, id, tags=None): + tags = tags if tags else [] + tags = self.existing_tags + tags + model = models.Port(id=id, network_id="1", ip_addresses=[], + tags=tags) + return model + + def _assert_tags(self, model, tags=None): + """Assert given tags and already existing tags are present.""" + tags = tags if tags else [] + expected_tags = (self.existing_tags + tags) + self.assertEqual(sorted(model.tags), + sorted(expected_tags)) + + def test_tag_registry_get_all(self): + model = self._create_test_model(1, tags=[]) + self.tag.set(model, self.value) + self.foo_tag.set(model, "foo") + expected_tags = [ + self.tag.serialize(self.value), + self.foo_tag.serialize("foo") + ] + self._assert_tags(model, tags=expected_tags) + + tags = self.registry.get_all(model) + self.assertEqual(tags, {self.tag.get_name(): str(self.value), + 'foo_tag': 'foo'}) + + def test_tag_registry_set_all(self): + model = self._create_test_model(1, tags=[]) + self._assert_tags(model, tags=[]) + + kwargs = {self.foo_tag.get_name(): "foo", + self.tag.get_name(): self.value} + self.registry.set_all(model, **kwargs) + + expected_tags = [ + self.tag.serialize(self.value), + self.foo_tag.serialize("foo") + ] + self._assert_tags(model, tags=expected_tags) + + def test_tag_registry_set_all_invalid_raises(self): + model = self._create_test_model(1, tags=[]) + self._assert_tags(model, tags=[]) + + kwargs = {self.foo_tag.get_name(): "foo", + self.tag.get_name(): self.invalid_value} + + with self.assertRaises(exceptions.BadRequest): + self.registry.set_all(model, **kwargs) + + def test_tag_get(self): + tags = [ + self.tag.serialize(self.value) + ] + model = self._create_test_model(1, tags=tags) + self._assert_tags(model, tags=tags) + + self.assertEqual(self.tag.get(model), str(self.value)) + self._assert_tags(model, tags=tags) + + self.assertEqual(self.tag.get(model), str(self.value)) + self._assert_tags(model, tags=tags) + + def test_tag_get_invalid(self): + tags = [ + self.foo_tag.serialize(self.invalid_value) + ] + model = self._create_test_model(1, tags=tags) + self._assert_tags(model, tags=tags) + + self.assertEqual(self.tag.get(model), None) + self._assert_tags(model, tags=tags) + + def test_tag_set(self): + model = self._create_test_model(1, tags=[]) + self._assert_tags(model, tags=[]) + + expected_tags = [ + self.tag.serialize(self.value) + ] + + self.tag.set(model, self.value) + self._assert_tags( + model, tags=expected_tags) + + self.tag.set(model, self.value) + self._assert_tags( + model, tags=expected_tags) + + def test_tag_set_existing(self): + tags = [ + self.tag.serialize(self.value) + ] + model = self._create_test_model(1, tags=tags) + self._assert_tags(model, tags=tags) + + self.tag.set(model, self.value2) + self._assert_tags( + model, tags=[self.tag.serialize(self.value2)]) + + def test_tag_set_invalid(self): + model = self._create_test_model(1, tags=[]) + + with self.assertRaises(tags.TagValidationError): + self.tag.set(model, self.invalid_value) + + self._assert_tags(model, tags=[]) + + def test_pop(self): + tags = [ + self.tag.serialize(self.value) + ] + model = self._create_test_model(1, tags=tags) + self._assert_tags(model, tags=tags) + + self.assertEqual(self.tag.pop(model), str(self.value)) + self._assert_tags(model, tags=[]) + + self.assertEqual(self.tag.pop(model), None) + self._assert_tags(model, tags=[]) + + def test_pop_invalid(self): + tags = [ + self.tag.serialize(self.invalid_value), + self.tag.serialize(self.value), + self.tag.serialize(self.value2) + ] + model = self._create_test_model(1, tags=tags) + self._assert_tags(model, tags=tags) + + self.assertTrue(self.tag.pop(model) in + [str(self.value), str(self.value2)]) + self._assert_tags(model, tags=[]) + + self.assertEqual(self.tag.pop(model), None) + self._assert_tags(model, tags=[]) + + +class TestVlanTag(TestTagBase): + + def setUp(self): + tag = tags.VlanTag() + value = 50 + value2 = 100 + invalid_value = 5000 + super(TestVlanTag, self).setUp( + tag=tag, value=value, value2=value2, invalid_value=invalid_value) + + def test_vlan_validation(self): + model = self._create_test_model(1, tags=[]) + + with self.assertRaises(tags.TagValidationError): + self.tag.set(model, self.tag.MIN_VLAN_ID - 1) + self._assert_tags(model, tags=[]) + + with self.assertRaises(tags.TagValidationError): + self.tag.set(model, self.tag.MAX_VLAN_ID + 1) + self._assert_tags(model, tags=[]) + + with self.assertRaises(tags.TagValidationError): + self.tag.set(model, 'three') + self._assert_tags(model, tags=[]) + + self.tag.set(model, self.tag.MIN_VLAN_ID) + self._assert_tags( + model, tags=[self.tag.serialize(self.tag.MIN_VLAN_ID)]) + + self.tag.set(model, self.tag.MAX_VLAN_ID) + self._assert_tags( + model, tags=[self.tag.serialize(self.tag.MAX_VLAN_ID)]) + + self.tag.set(model, str(self.tag.MIN_VLAN_ID)) + self._assert_tags( + model, tags=[self.tag.serialize(self.tag.MIN_VLAN_ID)]) + + self.tag.set(model, str(self.tag.MAX_VLAN_ID)) + self._assert_tags( + model, tags=[self.tag.serialize(self.tag.MAX_VLAN_ID)])