diff --git a/doc/source/types.rst b/doc/source/types.rst
index 47ba7e48..57e10986 100644
--- a/doc/source/types.rst
+++ b/doc/source/types.rst
@@ -49,6 +49,11 @@ Periodic
 
 .. automodule:: taskflow.types.periodic
 
+Sets
+====
+
+.. automodule:: taskflow.types.sets
+
 Table
 =====
 
diff --git a/taskflow/tests/unit/test_types.py b/taskflow/tests/unit/test_types.py
index d927781a..a16bda0d 100644
--- a/taskflow/tests/unit/test_types.py
+++ b/taskflow/tests/unit/test_types.py
@@ -25,6 +25,7 @@ from taskflow.types import fsm
 from taskflow.types import graph
 from taskflow.types import latch
 from taskflow.types import periodic
+from taskflow.types import sets
 from taskflow.types import table
 from taskflow.types import timing as tt
 from taskflow.types import tree
@@ -608,6 +609,141 @@ class FSMTest(test.TestCase):
         self.assertRaises(ValueError, m.add_state, 'b', on_exit=2)
 
 
+class OrderedSetTest(test.TestCase):
+
+    def test_retain_ordering(self):
+        items = [10, 9, 8, 7]
+        s = sets.OrderedSet(iter(items))
+        self.assertEqual(items, list(s))
+
+    def test_retain_duplicate_ordering(self):
+        items = [10, 9, 10, 8, 9, 7, 8]
+        s = sets.OrderedSet(iter(items))
+        self.assertEqual([10, 9, 8, 7], list(s))
+
+    def test_length(self):
+        items = [10, 9, 8, 7]
+        s = sets.OrderedSet(iter(items))
+        self.assertEqual(4, len(s))
+
+    def test_duplicate_length(self):
+        items = [10, 9, 10, 8, 9, 7, 8]
+        s = sets.OrderedSet(iter(items))
+        self.assertEqual(4, len(s))
+
+    def test_contains(self):
+        items = [10, 9, 8, 7]
+        s = sets.OrderedSet(iter(items))
+        for i in items:
+            self.assertIn(i, s)
+
+    def test_copy(self):
+        items = [10, 9, 8, 7]
+        s = sets.OrderedSet(iter(items))
+        s2 = s.copy()
+        self.assertEqual(s, s2)
+        self.assertEqual(items, list(s2))
+
+    def test_empty_intersection(self):
+        s = sets.OrderedSet([1, 2, 3])
+
+        es = set(s)
+
+        self.assertEqual(es.intersection(), s.intersection())
+
+    def test_intersection(self):
+        s = sets.OrderedSet([1, 2, 3])
+        s2 = sets.OrderedSet([2, 3, 4, 5])
+
+        es = set(s)
+        es2 = set(s2)
+
+        self.assertEqual(es.intersection(es2), s.intersection(s2))
+        self.assertEqual(es2.intersection(s), s2.intersection(s))
+
+    def test_multi_intersection(self):
+        s = sets.OrderedSet([1, 2, 3])
+        s2 = sets.OrderedSet([2, 3, 4, 5])
+        s3 = sets.OrderedSet([1, 2])
+
+        es = set(s)
+        es2 = set(s2)
+        es3 = set(s3)
+
+        self.assertEqual(es.intersection(s2, s3), s.intersection(s2, s3))
+        self.assertEqual(es2.intersection(es3), s2.intersection(s3))
+
+    def test_superset(self):
+        s = sets.OrderedSet([1, 2, 3])
+        s2 = sets.OrderedSet([2, 3])
+        self.assertTrue(s.issuperset(s2))
+        self.assertFalse(s.issubset(s2))
+
+    def test_subset(self):
+        s = sets.OrderedSet([1, 2, 3])
+        s2 = sets.OrderedSet([2, 3])
+        self.assertTrue(s2.issubset(s))
+        self.assertFalse(s2.issuperset(s))
+
+    def test_empty_difference(self):
+        s = sets.OrderedSet([1, 2, 3])
+
+        es = set(s)
+
+        self.assertEqual(es.difference(), s.difference())
+
+    def test_difference(self):
+        s = sets.OrderedSet([1, 2, 3])
+        s2 = sets.OrderedSet([2, 3])
+
+        es = set(s)
+        es2 = set(s2)
+
+        self.assertEqual(es.difference(es2), s.difference(s2))
+        self.assertEqual(es2.difference(es), s2.difference(s))
+
+    def test_multi_difference(self):
+        s = sets.OrderedSet([1, 2, 3])
+        s2 = sets.OrderedSet([2, 3])
+        s3 = sets.OrderedSet([3, 4, 5])
+
+        es = set(s)
+        es2 = set(s2)
+        es3 = set(s3)
+
+        self.assertEqual(es3.difference(es), s3.difference(s))
+        self.assertEqual(es.difference(es3), s.difference(s3))
+        self.assertEqual(es2.difference(es, es3), s2.difference(s, s3))
+
+    def test_empty_union(self):
+        s = sets.OrderedSet([1, 2, 3])
+
+        es = set(s)
+
+        self.assertEqual(es.union(), s.union())
+
+    def test_union(self):
+        s = sets.OrderedSet([1, 2, 3])
+        s2 = sets.OrderedSet([2, 3, 4])
+
+        es = set(s)
+        es2 = set(s2)
+
+        self.assertEqual(es.union(es2), s.union(s2))
+        self.assertEqual(es2.union(es), s2.union(s))
+
+    def test_multi_union(self):
+        s = sets.OrderedSet([1, 2, 3])
+        s2 = sets.OrderedSet([2, 3, 4])
+        s3 = sets.OrderedSet([4, 5, 6])
+
+        es = set(s)
+        es2 = set(s2)
+        es3 = set(s3)
+
+        self.assertEqual(es.union(es2, es3), s.union(s2, s3))
+
+
 class PeriodicTest(test.TestCase):
 
     def test_invalid_periodic(self):
diff --git a/taskflow/types/sets.py b/taskflow/types/sets.py
new file mode 100644
index 00000000..4ab99da5
--- /dev/null
+++ b/taskflow/types/sets.py
@@ -0,0 +1,139 @@
+# -*- coding: utf-8 -*-
+
+#    Copyright (C) 2015 Yahoo! Inc. All Rights Reserved.
+#
+#    Licensed under the Apache License, Version 2.0 (the "License"); you may
+#    not use this file except in compliance with the License. You may obtain
+#    a copy of the License at
+#
+#         http://www.apache.org/licenses/LICENSE-2.0
+#
+#    Unless required by applicable law or agreed to in writing, software
+#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+#    License for the specific language governing permissions and limitations
+#    under the License.
+
+import collections
+import itertools
+
+try:
+    # The container ABCs live in ``collections.abc`` on python 3.3+ (the
+    # aliases kept directly in ``collections`` were removed in 3.10).
+    from collections import abc as collections_abc  # noqa
+except ImportError:
+    collections_abc = collections
+
+try:
+    from collections import OrderedDict  # noqa
+except ImportError:
+    from ordereddict import OrderedDict  # noqa
+
+import six
+
+
+# Used for values that don't matter in sets backed by dicts...
+_sentinel = object()
+
+
+def _merge_in(target, iterable=None, sentinel=_sentinel):
+    """Merges iterable into the target and returns the target.
+
+    Values already present in the target keep their existing position, so
+    only the first occurrence of a value determines its ordering.
+    """
+    if iterable is not None:
+        for value in iterable:
+            target.setdefault(value, sentinel)
+    return target
+
+
+class OrderedSet(collections_abc.Set, collections_abc.Hashable):
+    """A read-only hashable set that retains insertion/initial ordering.
+
+    It should work in all existing places that ``frozenset`` is used.
+
+    Equality and hashing come from the set abstract base class and so
+    ignore ordering (sets holding the same values are equal, and hash the
+    same, no matter what order those values were provided in).
+
+    See: https://mail.python.org/pipermail/python-ideas/2009-May/004567.html
+    for an idea thread that *may* eventually (*someday*) result in this (or
+    similar) code being included in the mainline python codebase (although
+    the end result of that thread is somewhat discouraging in that regard).
+    """
+
+    __slots__ = ['_data']
+
+    def __init__(self, iterable=None):
+        self._data = _merge_in(OrderedDict(), iterable)
+
+    def __hash__(self):
+        return self._hash()
+
+    def __contains__(self, value):
+        return value in self._data
+
+    def __len__(self):
+        return len(self._data)
+
+    def __iter__(self):
+        for value in six.iterkeys(self._data):
+            yield value
+
+    def __reduce__(self):
+        # Slotted classes have no ``__dict__`` for pickle to capture and
+        # pickle protocols 0 and 1 refuse to handle them without help, so
+        # rebuild from the ordered values instead.
+        return (type(self), (tuple(self),))
+
+    def __repr__(self):
+        return "%s(%s)" % (type(self).__name__, list(self))
+
+    def copy(self):
+        """Return a shallow copy of a set."""
+        return self._from_iterable(iter(self))
+
+    def intersection(self, *sets):
+        """Return the intersection of two or more sets as a new set.
+
+        (i.e. elements that are common to all of the sets.)
+
+        The result retains the ordering of this set.
+        """
+        def absorb_it(sets):
+            for value in iter(self):
+                if all(value in s for s in sets):
+                    yield value
+        return self._from_iterable(absorb_it(sets))
+
+    def issuperset(self, other):
+        """Report whether this set contains another set."""
+        return all(value in self for value in other)
+
+    def issubset(self, other):
+        """Report whether another set contains this set."""
+        return all(value in other for value in iter(self))
+
+    def difference(self, *sets):
+        """Return the difference of two or more sets as a new set.
+
+        (i.e. all elements that are in this set but not the others.)
+
+        The result retains the ordering of this set.
+        """
+        def absorb_it(sets):
+            for value in iter(self):
+                if not any(value in s for s in sets):
+                    yield value
+        return self._from_iterable(absorb_it(sets))
+
+    def union(self, *sets):
+        """Return the union of sets as a new set.
+
+        (i.e. all elements that are in either set.)
+
+        Values from this set come first, followed by any new values in
+        the order they are first seen in the provided sets.
+        """
+        return self._from_iterable(itertools.chain(iter(self), *sets))