diff --git a/taskflow/tests/unit/test_flattening.py b/taskflow/tests/unit/test_flattening.py index e4f112a9..46b66272 100644 --- a/taskflow/tests/unit/test_flattening.py +++ b/taskflow/tests/unit/test_flattening.py @@ -20,6 +20,7 @@ import string import networkx as nx +from taskflow import exceptions as exc from taskflow.patterns import graph_flow as gf from taskflow.patterns import linear_flow as lf from taskflow.patterns import unordered_flow as uf @@ -167,3 +168,20 @@ class FlattenTest(test.TestCase): set(g_utils.get_no_predecessors(g))) self.assertEquals(set([d]), set(g_utils.get_no_successors(g))) + + def test_flatten_checks_for_dups(self): + flo = gf.Flow("test").add( + t_utils.DummyTask(name="a"), + t_utils.DummyTask(name="a") + ) + with self.assertRaisesRegexp(exc.InvariantViolationException, + '^Tasks with duplicate names'): + f_utils.flatten(flo) + + def test_flatten_checks_for_dups_globally(self): + flo = gf.Flow("test").add( + gf.Flow("int1").add(t_utils.DummyTask(name="a")), + gf.Flow("int2").add(t_utils.DummyTask(name="a"))) + with self.assertRaisesRegexp(exc.InvariantViolationException, + '^Tasks with duplicate names'): + f_utils.flatten(flo) diff --git a/taskflow/utils/flow_utils.py b/taskflow/utils/flow_utils.py index 18bab12a..c38ce32c 100644 --- a/taskflow/utils/flow_utils.py +++ b/taskflow/utils/flow_utils.py @@ -18,11 +18,13 @@ import networkx as nx +from taskflow import exceptions from taskflow.patterns import graph_flow as gf from taskflow.patterns import linear_flow as lf from taskflow.patterns import unordered_flow as uf from taskflow import task from taskflow.utils import graph_utils as gu +from taskflow.utils import misc # Use the 'flatten' reason as the need to add an edge here, which is useful for @@ -110,6 +112,14 @@ def _flatten(item, flattened): def flatten(item, freeze=True): graph = _flatten(item, set()) + + dup_names = misc.get_duplicate_keys(graph.nodes_iter(), + key=lambda node: node.name) + if dup_names: + raise exceptions.InvariantViolationException( + "Tasks with duplicate names found: %s" + % ', '.join(sorted(dup_names))) + if freeze: # Frozen graph can't be modified... return nx.freeze(graph) diff --git a/taskflow/utils/misc.py b/taskflow/utils/misc.py index fc076d12..df27edb3 100644 --- a/taskflow/utils/misc.py +++ b/taskflow/utils/misc.py @@ -22,6 +22,7 @@ from distutils import version import collections import copy import errno +import itertools import logging import os import sys @@ -51,6 +52,18 @@ def get_version_string(obj): return obj_version +def get_duplicate_keys(iterable, key=None): + if key is not None: + iterable = itertools.imap(key, iterable) + keys = set() + duplicates = set() + for item in iterable: + if item in keys: + duplicates.add(item) + keys.add(item) + return duplicates + + class ExponentialBackoff(object): def __init__(self, attempts, exponent=2): self.attempts = int(attempts)