diff --git a/oslo_versionedobjects/exception.py b/oslo_versionedobjects/exception.py index 1abdf1c..ee0e9ae 100644 --- a/oslo_versionedobjects/exception.py +++ b/oslo_versionedobjects/exception.py @@ -164,3 +164,11 @@ class ReadOnlyFieldError(VersionedObjectsException): class UnsupportedObjectError(VersionedObjectsException): msg_fmt = _('Unsupported object type %(objtype)s') + + +class EnumRequiresValidValuesError(VersionedObjectsException): + msg_fmt = _('Enum fields require a list of valid_values') + + +class EnumValidValuesInvalidError(VersionedObjectsException): + msg_fmt = _('Enum valid values are not valid') diff --git a/oslo_versionedobjects/fields.py b/oslo_versionedobjects/fields.py index ef5e869..8d2963f 100755 --- a/oslo_versionedobjects/fields.py +++ b/oslo_versionedobjects/fields.py @@ -21,6 +21,7 @@ from oslo_utils import timeutils import six from oslo_versionedobjects._i18n import _ +from oslo_versionedobjects import exception class KeyTypeError(TypeError): @@ -247,6 +248,32 @@ class String(FieldType): return '\'%s\'' % value +class Enum(String): + def __init__(self, valid_values, **kwargs): + if not valid_values: + raise exception.EnumRequiresValidValuesError() + try: + # Test validity of the values + for value in valid_values: + super(Enum, self).coerce(None, 'init', value) + except (TypeError, ValueError): + raise exception.EnumValidValuesInvalidError() + self._valid_values = valid_values + super(Enum, self).__init__(**kwargs) + + def coerce(self, obj, attr, value): + if value not in self._valid_values: + msg = _("Field value %s is invalid") % value + raise ValueError(msg) + return super(Enum, self).coerce(obj, attr, value) + + def stringify(self, value): + if value not in self._valid_values: + msg = _("Field value %s is invalid") % value + raise ValueError(msg) + return super(Enum, self).stringify(value) + + class UUID(FieldType): @staticmethod def coerce(obj, attr, value): @@ -479,6 +506,23 @@ class StringField(AutoTypedField): AUTO_TYPE = String() +class EnumField(AutoTypedField): + def __init__(self, valid_values, **kwargs): + self.AUTO_TYPE = Enum(valid_values) + super(EnumField, self).__init__(**kwargs) + + def __repr__(self): + valid_values = self._type._valid_values + args = { + 'nullable': self._nullable, + 'default': self._default, + } + args.update({'valid_values': valid_values}) + return '%s(%s)' % (self._type.__class__.__name__, + ','.join(['%s=%s' % (k, v) + for k, v in sorted(args.items())])) + + class UUIDField(AutoTypedField): AUTO_TYPE = UUID() @@ -517,6 +561,23 @@ class ListOfStringsField(AutoTypedField): AUTO_TYPE = List(String()) +class ListOfEnumField(AutoTypedField): + def __init__(self, valid_values, **kwargs): + self.AUTO_TYPE = List(Enum(valid_values)) + super(ListOfEnumField, self).__init__(**kwargs) + + def __repr__(self): + valid_values = self._type._element_type._type._valid_values + args = { + 'nullable': self._nullable, + 'default': self._default, + } + args.update({'valid_values': valid_values}) + return '%s(%s)' % (self._type.__class__.__name__, + ','.join(['%s=%s' % (k, v) + for k, v in sorted(args.items())])) + + class SetOfIntegersField(AutoTypedField): AUTO_TYPE = Set(Integer()) diff --git a/oslo_versionedobjects/tests/test_fields.py b/oslo_versionedobjects/tests/test_fields.py index 34ebd5a..480036b 100755 --- a/oslo_versionedobjects/tests/test_fields.py +++ b/oslo_versionedobjects/tests/test_fields.py @@ -19,6 +19,7 @@ from oslo_utils import timeutils import six from oslo_versionedobjects import base as obj_base +from oslo_versionedobjects import exception from oslo_versionedobjects import fields from oslo_versionedobjects import test @@ -71,7 +72,7 @@ class TestField(test.TestCase): class TestString(TestField): def setUp(self): - super(TestField, self).setUp() + super(TestString, self).setUp() self.field = fields.StringField() self.coerce_good_values = [ ('foo', 'foo'), (1, '1'), (1.0, '1.0'), (True, 'True')] @@ -85,6 +86,42 @@ class TestString(TestField): self.assertEqual("'123'", self.field.stringify(123)) +class TestEnum(TestField): + def setUp(self): + super(TestEnum, self).setUp() + self.field = fields.EnumField( + valid_values=['foo', 'bar', 1, True]) + self.coerce_good_values = [('foo', 'foo'), (1, '1'), (True, 'True')] + self.coerce_bad_values = ['boo', 2, False] + self.to_primitive_values = self.coerce_good_values[0:1] + self.from_primitive_values = self.coerce_good_values[0:1] + + def test_stringify(self): + self.assertEqual("'foo'", self.field.stringify('foo')) + + def test_stringify_invalid(self): + self.assertRaises(ValueError, self.field.stringify, '123') + + def test_fingerprint(self): + # Notes(yjiang5): make sure changing valid_value will be detected + # in test_objects.test_versions + field1 = fields.EnumField(valid_values=['foo', 'bar']) + field2 = fields.EnumField(valid_values=['foo', 'bar1']) + self.assertNotEqual(str(field1), str(field2)) + + def test_missing_valid_values(self): + self.assertRaises(exception.EnumRequiresValidValuesError, + fields.EnumField, None) + + def test_empty_valid_values(self): + self.assertRaises(exception.EnumRequiresValidValuesError, + fields.EnumField, []) + + def test_non_iterable_valid_values(self): + self.assertRaises(exception.EnumValidValuesInvalidError, + fields.EnumField, True) + + class TestInteger(TestField): def setUp(self): super(TestField, self).setUp() @@ -267,6 +304,29 @@ class TestListOfStrings(TestField): self.assertEqual("['abc']", self.field.stringify(['abc'])) +class TestListOfEnum(TestField): + def setUp(self): + super(TestListOfEnum, self).setUp() + self.field = fields.ListOfEnumField(valid_values=['foo', 'bar']) + self.coerce_good_values = [(['foo', 'bar'], ['foo', 'bar'])] + self.coerce_bad_values = ['foo', ['foo', 'bar1']] + self.to_primitive_values = [(['foo'], ['foo'])] + self.from_primitive_values = [(['foo'], ['foo'])] + + def test_stringify(self): + self.assertEqual("['foo']", self.field.stringify(['foo'])) + + def test_stringify_invalid(self): + self.assertRaises(ValueError, self.field.stringify, '[abc]') + + def test_fingerprint(self): + # Notes(yjiang5): make sure changing valid_value will be detected + # in test_objects.test_versions + field1 = fields.ListOfEnumField(valid_values=['foo', 'bar']) + field2 = fields.ListOfEnumField(valid_values=['foo', 'bar1']) + self.assertNotEqual(str(field1), str(field2)) + + class TestSet(TestField): def setUp(self): super(TestSet, self).setUp()