Merge "Add ObjectListBase concat methods"
This commit is contained in:
@@ -880,6 +880,28 @@ class ObjectListBase(collections.Sequence):
|
||||
changes.add('objects')
|
||||
return changes
|
||||
|
||||
def __add__(self, other):
|
||||
# Handling arbitrary fields may not make sense if those fields are not
|
||||
# all concatenatable. Only concatenate if the base 'objects' field is
|
||||
# the only one and the classes match.
|
||||
if (self.__class__ == other.__class__ and
|
||||
list(self.__class__.fields.keys()) == ['objects']):
|
||||
return self.__class__(objects=self.objects + other.objects)
|
||||
else:
|
||||
raise TypeError("List Objects should be of the same type and only "
|
||||
"have an 'objects' field")
|
||||
|
||||
def __radd__(self, other):
|
||||
if (self.__class__ == other.__class__ and
|
||||
list(self.__class__.fields.keys()) == ['objects']):
|
||||
# This should never be run in practice. If the above condition is
|
||||
# met then __add__ would have been run.
|
||||
raise NotImplementedError('__radd__ is not implemented for '
|
||||
'objects of the same type')
|
||||
else:
|
||||
raise TypeError("List Objects should be of the same type and only "
|
||||
"have an 'objects' field")
|
||||
|
||||
|
||||
class VersionedObjectSerializer(messaging.NoOpSerializer):
|
||||
"""A VersionedObject-aware Serializer.
|
||||
|
||||
@@ -2251,3 +2251,92 @@ class TestUtilityMethods(test.TestCase):
|
||||
'TestChild': '2.34',
|
||||
'TestChildTwo': '4.56'},
|
||||
tree)
|
||||
|
||||
|
||||
class TestListObjectConcat(test.TestCase):
|
||||
def test_list_object_concat(self):
|
||||
@base.VersionedObjectRegistry.register_if(False)
|
||||
class MyList(base.ObjectListBase, base.VersionedObject):
|
||||
fields = {'objects': fields.ListOfObjectsField('MyOwnedObject')}
|
||||
|
||||
values = [1, 2, 42]
|
||||
|
||||
list1 = MyList(objects=[MyOwnedObject(baz=values[0]),
|
||||
MyOwnedObject(baz=values[1])])
|
||||
list2 = MyList(objects=[MyOwnedObject(baz=values[2])])
|
||||
|
||||
concat_list = list1 + list2
|
||||
for idx, obj in enumerate(concat_list):
|
||||
self.assertEqual(values[idx], obj.baz)
|
||||
|
||||
# Assert that the original lists are unmodified
|
||||
self.assertEqual(2, len(list1.objects))
|
||||
self.assertEqual(1, list1.objects[0].baz)
|
||||
self.assertEqual(2, list1.objects[1].baz)
|
||||
self.assertEqual(1, len(list2.objects))
|
||||
self.assertEqual(42, list2.objects[0].baz)
|
||||
|
||||
def test_list_object_concat_fails_different_objects(self):
|
||||
@base.VersionedObjectRegistry.register_if(False)
|
||||
class MyList(base.ObjectListBase, base.VersionedObject):
|
||||
fields = {'objects': fields.ListOfObjectsField('MyOwnedObject')}
|
||||
|
||||
@base.VersionedObjectRegistry.register_if(False)
|
||||
class MyList2(base.ObjectListBase, base.VersionedObject):
|
||||
fields = {'objects': fields.ListOfObjectsField('MyOwnedObject')}
|
||||
|
||||
list1 = MyList(objects=[MyOwnedObject(baz=1)])
|
||||
list2 = MyList2(objects=[MyOwnedObject(baz=2)])
|
||||
|
||||
def add(x, y):
|
||||
return x + y
|
||||
|
||||
self.assertRaises(TypeError, add, list1, list2)
|
||||
# Assert that the original lists are unmodified
|
||||
self.assertEqual(1, len(list1.objects))
|
||||
self.assertEqual(1, len(list2.objects))
|
||||
self.assertEqual(1, list1.objects[0].baz)
|
||||
self.assertEqual(2, list2.objects[0].baz)
|
||||
|
||||
def test_list_object_concat_fails_extra_fields(self):
|
||||
@base.VersionedObjectRegistry.register_if(False)
|
||||
class MyList(base.ObjectListBase, base.VersionedObject):
|
||||
fields = {'objects': fields.ListOfObjectsField('MyOwnedObject'),
|
||||
'foo': fields.IntegerField(nullable=True)}
|
||||
|
||||
list1 = MyList(objects=[MyOwnedObject(baz=1)])
|
||||
list2 = MyList(objects=[MyOwnedObject(baz=2)])
|
||||
|
||||
def add(x, y):
|
||||
return x + y
|
||||
|
||||
self.assertRaises(TypeError, add, list1, list2)
|
||||
# Assert that the original lists are unmodified
|
||||
self.assertEqual(1, len(list1.objects))
|
||||
self.assertEqual(1, len(list2.objects))
|
||||
self.assertEqual(1, list1.objects[0].baz)
|
||||
self.assertEqual(2, list2.objects[0].baz)
|
||||
|
||||
def test_builtin_list_add_fails(self):
|
||||
@base.VersionedObjectRegistry.register_if(False)
|
||||
class MyList(base.ObjectListBase, base.VersionedObject):
|
||||
fields = {'objects': fields.ListOfObjectsField('MyOwnedObject')}
|
||||
|
||||
list1 = MyList(objects=[MyOwnedObject(baz=1)])
|
||||
|
||||
def add(obj):
|
||||
return obj + []
|
||||
|
||||
self.assertRaises(TypeError, add, list1)
|
||||
|
||||
def test_builtin_list_radd_fails(self):
|
||||
@base.VersionedObjectRegistry.register_if(False)
|
||||
class MyList(base.ObjectListBase, base.VersionedObject):
|
||||
fields = {'objects': fields.ListOfObjectsField('MyOwnedObject')}
|
||||
|
||||
list1 = MyList(objects=[MyOwnedObject(baz=1)])
|
||||
|
||||
def add(obj):
|
||||
return [] + obj
|
||||
|
||||
self.assertRaises(TypeError, add, list1)
|
||||
|
||||
Reference in New Issue
Block a user