Merge "Add ObjectListBase concat methods"

This commit is contained in:
Jenkins
2016-10-01 11:37:10 +00:00
committed by Gerrit Code Review
2 changed files with 111 additions and 0 deletions

View File

@@ -880,6 +880,28 @@ class ObjectListBase(collections.Sequence):
changes.add('objects')
return changes
def __add__(self, other):
# Handling arbitrary fields may not make sense if those fields are not
# all concatenatable. Only concatenate if the base 'objects' field is
# the only one and the classes match.
if (self.__class__ == other.__class__ and
list(self.__class__.fields.keys()) == ['objects']):
return self.__class__(objects=self.objects + other.objects)
else:
raise TypeError("List Objects should be of the same type and only "
"have an 'objects' field")
def __radd__(self, other):
if (self.__class__ == other.__class__ and
list(self.__class__.fields.keys()) == ['objects']):
# This should never be run in practice. If the above condition is
# met then __add__ would have been run.
raise NotImplementedError('__radd__ is not implemented for '
'objects of the same type')
else:
raise TypeError("List Objects should be of the same type and only "
"have an 'objects' field")
class VersionedObjectSerializer(messaging.NoOpSerializer):
"""A VersionedObject-aware Serializer.

View File

@@ -2251,3 +2251,92 @@ class TestUtilityMethods(test.TestCase):
'TestChild': '2.34',
'TestChildTwo': '4.56'},
tree)
class TestListObjectConcat(test.TestCase):
def test_list_object_concat(self):
@base.VersionedObjectRegistry.register_if(False)
class MyList(base.ObjectListBase, base.VersionedObject):
fields = {'objects': fields.ListOfObjectsField('MyOwnedObject')}
values = [1, 2, 42]
list1 = MyList(objects=[MyOwnedObject(baz=values[0]),
MyOwnedObject(baz=values[1])])
list2 = MyList(objects=[MyOwnedObject(baz=values[2])])
concat_list = list1 + list2
for idx, obj in enumerate(concat_list):
self.assertEqual(values[idx], obj.baz)
# Assert that the original lists are unmodified
self.assertEqual(2, len(list1.objects))
self.assertEqual(1, list1.objects[0].baz)
self.assertEqual(2, list1.objects[1].baz)
self.assertEqual(1, len(list2.objects))
self.assertEqual(42, list2.objects[0].baz)
def test_list_object_concat_fails_different_objects(self):
@base.VersionedObjectRegistry.register_if(False)
class MyList(base.ObjectListBase, base.VersionedObject):
fields = {'objects': fields.ListOfObjectsField('MyOwnedObject')}
@base.VersionedObjectRegistry.register_if(False)
class MyList2(base.ObjectListBase, base.VersionedObject):
fields = {'objects': fields.ListOfObjectsField('MyOwnedObject')}
list1 = MyList(objects=[MyOwnedObject(baz=1)])
list2 = MyList2(objects=[MyOwnedObject(baz=2)])
def add(x, y):
return x + y
self.assertRaises(TypeError, add, list1, list2)
# Assert that the original lists are unmodified
self.assertEqual(1, len(list1.objects))
self.assertEqual(1, len(list2.objects))
self.assertEqual(1, list1.objects[0].baz)
self.assertEqual(2, list2.objects[0].baz)
def test_list_object_concat_fails_extra_fields(self):
@base.VersionedObjectRegistry.register_if(False)
class MyList(base.ObjectListBase, base.VersionedObject):
fields = {'objects': fields.ListOfObjectsField('MyOwnedObject'),
'foo': fields.IntegerField(nullable=True)}
list1 = MyList(objects=[MyOwnedObject(baz=1)])
list2 = MyList(objects=[MyOwnedObject(baz=2)])
def add(x, y):
return x + y
self.assertRaises(TypeError, add, list1, list2)
# Assert that the original lists are unmodified
self.assertEqual(1, len(list1.objects))
self.assertEqual(1, len(list2.objects))
self.assertEqual(1, list1.objects[0].baz)
self.assertEqual(2, list2.objects[0].baz)
def test_builtin_list_add_fails(self):
@base.VersionedObjectRegistry.register_if(False)
class MyList(base.ObjectListBase, base.VersionedObject):
fields = {'objects': fields.ListOfObjectsField('MyOwnedObject')}
list1 = MyList(objects=[MyOwnedObject(baz=1)])
def add(obj):
return obj + []
self.assertRaises(TypeError, add, list1)
def test_builtin_list_radd_fails(self):
@base.VersionedObjectRegistry.register_if(False)
class MyList(base.ObjectListBase, base.VersionedObject):
fields = {'objects': fields.ListOfObjectsField('MyOwnedObject')}
list1 = MyList(objects=[MyOwnedObject(baz=1)])
def add(obj):
return [] + obj
self.assertRaises(TypeError, add, list1)