Check for duplicate task names on flattening

Task names should be unique within flow. But complete set of tasks
is not generally known until flow flattening is run. So, we check
task names uniqueness as soon as we have execution graph.

Change-Id: I658b3ae606fc79e600d90b51f0b4ed4f4e7d511d
This commit is contained in:
Ivan A. Melnikov
2013-10-03 12:52:17 +04:00
parent 9545fc43b9
commit 91136532ed
3 changed files with 41 additions and 0 deletions

View File

@@ -20,6 +20,7 @@ import string
import networkx as nx
from taskflow import exceptions as exc
from taskflow.patterns import graph_flow as gf
from taskflow.patterns import linear_flow as lf
from taskflow.patterns import unordered_flow as uf
@@ -167,3 +168,20 @@ class FlattenTest(test.TestCase):
set(g_utils.get_no_predecessors(g)))
self.assertEquals(set([d]),
set(g_utils.get_no_successors(g)))
def test_flatten_checks_for_dups(self):
flo = gf.Flow("test").add(
t_utils.DummyTask(name="a"),
t_utils.DummyTask(name="a")
)
with self.assertRaisesRegexp(exc.InvariantViolationException,
'^Tasks with duplicate names'):
f_utils.flatten(flo)
def test_flatten_checks_for_dups_globally(self):
flo = gf.Flow("test").add(
gf.Flow("int1").add(t_utils.DummyTask(name="a")),
gf.Flow("int2").add(t_utils.DummyTask(name="a")))
with self.assertRaisesRegexp(exc.InvariantViolationException,
'^Tasks with duplicate names'):
f_utils.flatten(flo)

View File

@@ -18,11 +18,13 @@
import networkx as nx
from taskflow import exceptions
from taskflow.patterns import graph_flow as gf
from taskflow.patterns import linear_flow as lf
from taskflow.patterns import unordered_flow as uf
from taskflow import task
from taskflow.utils import graph_utils as gu
from taskflow.utils import misc
# Use the 'flatten' reason as the need to add an edge here, which is useful for
@@ -110,6 +112,14 @@ def _flatten(item, flattened):
def flatten(item, freeze=True):
graph = _flatten(item, set())
dup_names = misc.get_duplicate_keys(graph.nodes_iter(),
key=lambda node: node.name)
if dup_names:
raise exceptions.InvariantViolationException(
"Tasks with duplicate names found: %s"
% ', '.join(sorted(dup_names)))
if freeze:
# Frozen graph can't be modified...
return nx.freeze(graph)

View File

@@ -22,6 +22,7 @@ from distutils import version
import collections
import copy
import errno
import itertools
import logging
import os
import sys
@@ -51,6 +52,18 @@ def get_version_string(obj):
return obj_version
def get_duplicate_keys(iterable, key=None):
if key is not None:
iterable = itertools.imap(key, iterable)
keys = set()
duplicates = set()
for item in iterable:
if item in keys:
duplicates.add(item)
keys.add(item)
return duplicates
class ExponentialBackoff(object):
def __init__(self, attempts, exponent=2):
self.attempts = int(attempts)