From f3519480d04ae1e7f52c8e3b5ff7e0c6678d2da6 Mon Sep 17 00:00:00 2001 From: Andrew Laski Date: Thu, 8 Sep 2016 13:13:14 -0400 Subject: [PATCH] Add ObjectListBase concat methods This adds support for using + to concatenate two objects derived from ObjectListBase. The original objects are not modified, a new list object is returned. The requirements for this are that the two objects are of the same class, and there is only an 'objects' field. Change-Id: I4610db6b52f3f576a6d0c2e64af39927077586cd --- oslo_versionedobjects/base.py | 22 +++++ oslo_versionedobjects/tests/test_objects.py | 89 +++++++++++++++++++++ 2 files changed, 111 insertions(+) diff --git a/oslo_versionedobjects/base.py b/oslo_versionedobjects/base.py index f396e770..551cc515 100644 --- a/oslo_versionedobjects/base.py +++ b/oslo_versionedobjects/base.py @@ -880,6 +880,28 @@ class ObjectListBase(collections.Sequence): changes.add('objects') return changes + def __add__(self, other): + # Handling arbitrary fields may not make sense if those fields are not + # all concatenatable. Only concatenate if the base 'objects' field is + # the only one and the classes match. + if (self.__class__ == other.__class__ and + list(self.__class__.fields.keys()) == ['objects']): + return self.__class__(objects=self.objects + other.objects) + else: + raise TypeError("List Objects should be of the same type and only " + "have an 'objects' field") + + def __radd__(self, other): + if (self.__class__ == other.__class__ and + list(self.__class__.fields.keys()) == ['objects']): + # This should never be run in practice. If the above condition is + # met then __add__ would have been run. + raise NotImplementedError('__radd__ is not implemented for ' + 'objects of the same type') + else: + raise TypeError("List Objects should be of the same type and only " + "have an 'objects' field") + class VersionedObjectSerializer(messaging.NoOpSerializer): """A VersionedObject-aware Serializer. diff --git a/oslo_versionedobjects/tests/test_objects.py b/oslo_versionedobjects/tests/test_objects.py index c3b0ab32..addcb716 100644 --- a/oslo_versionedobjects/tests/test_objects.py +++ b/oslo_versionedobjects/tests/test_objects.py @@ -2251,3 +2251,92 @@ class TestUtilityMethods(test.TestCase): 'TestChild': '2.34', 'TestChildTwo': '4.56'}, tree) + + +class TestListObjectConcat(test.TestCase): + def test_list_object_concat(self): + @base.VersionedObjectRegistry.register_if(False) + class MyList(base.ObjectListBase, base.VersionedObject): + fields = {'objects': fields.ListOfObjectsField('MyOwnedObject')} + + values = [1, 2, 42] + + list1 = MyList(objects=[MyOwnedObject(baz=values[0]), + MyOwnedObject(baz=values[1])]) + list2 = MyList(objects=[MyOwnedObject(baz=values[2])]) + + concat_list = list1 + list2 + for idx, obj in enumerate(concat_list): + self.assertEqual(values[idx], obj.baz) + + # Assert that the original lists are unmodified + self.assertEqual(2, len(list1.objects)) + self.assertEqual(1, list1.objects[0].baz) + self.assertEqual(2, list1.objects[1].baz) + self.assertEqual(1, len(list2.objects)) + self.assertEqual(42, list2.objects[0].baz) + + def test_list_object_concat_fails_different_objects(self): + @base.VersionedObjectRegistry.register_if(False) + class MyList(base.ObjectListBase, base.VersionedObject): + fields = {'objects': fields.ListOfObjectsField('MyOwnedObject')} + + @base.VersionedObjectRegistry.register_if(False) + class MyList2(base.ObjectListBase, base.VersionedObject): + fields = {'objects': fields.ListOfObjectsField('MyOwnedObject')} + + list1 = MyList(objects=[MyOwnedObject(baz=1)]) + list2 = MyList2(objects=[MyOwnedObject(baz=2)]) + + def add(x, y): + return x + y + + self.assertRaises(TypeError, add, list1, list2) + # Assert that the original lists are unmodified + self.assertEqual(1, len(list1.objects)) + self.assertEqual(1, len(list2.objects)) + self.assertEqual(1, list1.objects[0].baz) + self.assertEqual(2, list2.objects[0].baz) + + def test_list_object_concat_fails_extra_fields(self): + @base.VersionedObjectRegistry.register_if(False) + class MyList(base.ObjectListBase, base.VersionedObject): + fields = {'objects': fields.ListOfObjectsField('MyOwnedObject'), + 'foo': fields.IntegerField(nullable=True)} + + list1 = MyList(objects=[MyOwnedObject(baz=1)]) + list2 = MyList(objects=[MyOwnedObject(baz=2)]) + + def add(x, y): + return x + y + + self.assertRaises(TypeError, add, list1, list2) + # Assert that the original lists are unmodified + self.assertEqual(1, len(list1.objects)) + self.assertEqual(1, len(list2.objects)) + self.assertEqual(1, list1.objects[0].baz) + self.assertEqual(2, list2.objects[0].baz) + + def test_builtin_list_add_fails(self): + @base.VersionedObjectRegistry.register_if(False) + class MyList(base.ObjectListBase, base.VersionedObject): + fields = {'objects': fields.ListOfObjectsField('MyOwnedObject')} + + list1 = MyList(objects=[MyOwnedObject(baz=1)]) + + def add(obj): + return obj + [] + + self.assertRaises(TypeError, add, list1) + + def test_builtin_list_radd_fails(self): + @base.VersionedObjectRegistry.register_if(False) + class MyList(base.ObjectListBase, base.VersionedObject): + fields = {'objects': fields.ListOfObjectsField('MyOwnedObject')} + + list1 = MyList(objects=[MyOwnedObject(baz=1)]) + + def add(obj): + return [] + obj + + self.assertRaises(TypeError, add, list1)