diff --git a/oslo_versionedobjects/base.py b/oslo_versionedobjects/base.py index c995bef7..44d7d731 100644 --- a/oslo_versionedobjects/base.py +++ b/oslo_versionedobjects/base.py @@ -562,6 +562,18 @@ class VersionedObject(object): return list(self.fields.keys()) + self.obj_extra_fields +class ComparableVersionedObject(object): + """Mix-in to provide comparason methods + + When objects are to be compared with each other (in tests for example), + this mixin can be used. + """ + def __eq__(self, obj): + # FIXME(inc0): this can return incorrect value if we consider partially + # loaded objects from db and fields which are dropped out differ + return self.obj_to_primitive() == obj.obj_to_primitive() + + class VersionedObjectDictCompat(object): """Mix-in to provide dictionary key access compat diff --git a/oslo_versionedobjects/fields.py b/oslo_versionedobjects/fields.py index ce891044..57c7d3d4 100755 --- a/oslo_versionedobjects/fields.py +++ b/oslo_versionedobjects/fields.py @@ -271,8 +271,11 @@ class Boolean(FieldType): class DateTime(FieldType): - @staticmethod - def coerce(obj, attr, value): + def __init__(self, tzinfo_aware=True, *args, **kwargs): + self.tzinfo_aware = tzinfo_aware + super(DateTime, self).__init__(*args, **kwargs) + + def coerce(self, obj, attr, value): if isinstance(value, six.string_types): # NOTE(danms): Being tolerant of isotime strings here will help us # during our objects transition @@ -280,11 +283,13 @@ class DateTime(FieldType): elif not isinstance(value, datetime.datetime): raise ValueError(_('A datetime.datetime is required here')) - if value.utcoffset() is None: + if value.utcoffset() is None and self.tzinfo_aware: # NOTE(danms): Legacy objects from sqlalchemy are stored in UTC, # but are returned without a timezone attached. # As a transitional aid, assume a tz-naive object is in UTC. value = value.replace(tzinfo=iso8601.iso8601.Utc()) + elif not self.tzinfo_aware: + value = value.replace(tzinfo=None) return value def from_primitive(self, obj, attr, value): @@ -490,7 +495,9 @@ class BooleanField(AutoTypedField): class DateTimeField(AutoTypedField): - AUTO_TYPE = DateTime() + def __init__(self, tzinfo_aware=True, **kwargs): + self.AUTO_TYPE = DateTime(tzinfo_aware=tzinfo_aware) + super(DateTimeField, self).__init__(**kwargs) class DictOfStringsField(AutoTypedField): diff --git a/oslo_versionedobjects/tests/test_fields.py b/oslo_versionedobjects/tests/test_fields.py index 21e7d1b2..48d4f2a4 100755 --- a/oslo_versionedobjects/tests/test_fields.py +++ b/oslo_versionedobjects/tests/test_fields.py @@ -135,6 +135,29 @@ class TestDateTime(TestField): tzinfo=iso8601.iso8601.Utc()))) +class TestDateTimeNoTzinfo(TestField): + def setUp(self): + super(TestDateTimeNoTzinfo, self).setUp() + self.dt = datetime.datetime(1955, 11, 5) + self.field = fields.DateTimeField(tzinfo_aware=False) + self.coerce_good_values = [(self.dt, self.dt), + (timeutils.isotime(self.dt), self.dt)] + self.coerce_bad_values = [1, 'foo'] + self.to_primitive_values = [(self.dt, timeutils.isotime(self.dt))] + self.from_primitive_values = [ + ( + timeutils.isotime(self.dt), + self.dt, + ) + ] + + def test_stringify(self): + self.assertEqual( + '1955-11-05T18:00:00Z', + self.field.stringify( + datetime.datetime(1955, 11, 5, 18, 0, 0))) + + class TestDict(TestField): def setUp(self): super(TestDict, self).setUp() diff --git a/oslo_versionedobjects/tests/test_objects.py b/oslo_versionedobjects/tests/test_objects.py index 48c8c6c3..51decd0d 100755 --- a/oslo_versionedobjects/tests/test_objects.py +++ b/oslo_versionedobjects/tests/test_objects.py @@ -116,6 +116,11 @@ class MyObj(base.VersionedObject, base.VersionedObjectDictCompat): primitive['bar'] = 'old%s' % primitive['bar'] +@base.VersionedObjectRegistry.register +class MyComparableObj(MyObj, base.ComparableVersionedObject): + pass + + @base.VersionedObjectRegistry.register class MyObjDiffVers(MyObj): VERSION = '1.5' @@ -837,6 +842,13 @@ class _TestObject(object): obj.obj_to_primitive('1.0') self.assertTrue(mock_mc.called) + def test_comparable_objects(self): + obj1 = MyComparableObj(foo=1) + obj2 = MyComparableObj(foo=1) + obj3 = MyComparableObj(foo=2) + self.assertTrue(obj1 == obj2) + self.assertFalse(obj1 == obj3) + class TestObject(_LocalTest, _TestObject): def test_set_defaults(self):