diff --git a/taskflow/tests/unit/test_types.py b/taskflow/tests/unit/test_types.py index c69d9ccb..cbbe2939 100644 --- a/taskflow/tests/unit/test_types.py +++ b/taskflow/tests/unit/test_types.py @@ -18,6 +18,7 @@ import time import networkx as nx import six +from six.moves import cPickle as pickle from taskflow import exceptions as excp from taskflow import test @@ -498,6 +499,15 @@ class FSMTest(test.TestCase): class OrderedSetTest(test.TestCase): + def test_pickleable(self): + items = [10, 9, 8, 7] + s = sets.OrderedSet(items) + self.assertEqual(items, list(s)) + s_bin = pickle.dumps(s) + s2 = pickle.loads(s_bin) + self.assertEqual(s, s2) + self.assertEqual(items, list(s2)) + def test_retain_ordering(self): items = [10, 9, 8, 7] s = sets.OrderedSet(iter(items)) diff --git a/taskflow/types/sets.py b/taskflow/types/sets.py index 4ab99da5..527ad2a7 100644 --- a/taskflow/types/sets.py +++ b/taskflow/types/sets.py @@ -66,6 +66,12 @@ class OrderedSet(collections.Set, collections.Hashable): for value in six.iterkeys(self._data): yield value + def __setstate__(self, items): + self.__init__(iterable=iter(items)) + + def __getstate__(self): + return tuple(self) + def __repr__(self): return "%s(%s)" % (type(self).__name__, list(self))