Check for duplicate task names on flattening
Task names should be unique within a flow, but the complete set of tasks is generally not known until flow flattening is run, so we check task name uniqueness as soon as we have the execution graph.

Change-Id: I658b3ae606fc79e600d90b51f0b4ed4f4e7d511d
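The reasoning above can be shown with a minimal standalone sketch (not taskflow code; all names here are illustrative): each sub-flow's task names may be unique on their own, and only the flattened, combined name set reveals a collision.

import collections

def duplicate_names(task_names):
    # Count every task name seen in the flattened flow and report repeats.
    counts = collections.Counter(task_names)
    return sorted(name for name, count in counts.items() if count > 1)

print(duplicate_names(["a", "b"]))            # [] -- first sub-flow alone is fine
print(duplicate_names(["a", "c"]))            # [] -- second sub-flow alone is fine
print(duplicate_names(["a", "b", "a", "c"]))  # ['a'] -- duplicate only visible after flattening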
@@ -20,6 +20,7 @@ import string

 import networkx as nx

+from taskflow import exceptions as exc
 from taskflow.patterns import graph_flow as gf
 from taskflow.patterns import linear_flow as lf
 from taskflow.patterns import unordered_flow as uf
@@ -167,3 +168,20 @@ class FlattenTest(test.TestCase):
                           set(g_utils.get_no_predecessors(g)))
         self.assertEquals(set([d]),
                           set(g_utils.get_no_successors(g)))
+
+    def test_flatten_checks_for_dups(self):
+        flo = gf.Flow("test").add(
+            t_utils.DummyTask(name="a"),
+            t_utils.DummyTask(name="a")
+        )
+        with self.assertRaisesRegexp(exc.InvariantViolationException,
+                                     '^Tasks with duplicate names'):
+            f_utils.flatten(flo)
+
+    def test_flatten_checks_for_dups_globally(self):
+        flo = gf.Flow("test").add(
+            gf.Flow("int1").add(t_utils.DummyTask(name="a")),
+            gf.Flow("int2").add(t_utils.DummyTask(name="a")))
+        with self.assertRaisesRegexp(exc.InvariantViolationException,
+                                     '^Tasks with duplicate names'):
+            f_utils.flatten(flo)
@@ -18,11 +18,13 @@

 import networkx as nx

+from taskflow import exceptions
 from taskflow.patterns import graph_flow as gf
 from taskflow.patterns import linear_flow as lf
 from taskflow.patterns import unordered_flow as uf
 from taskflow import task
 from taskflow.utils import graph_utils as gu
+from taskflow.utils import misc


 # Use the 'flatten' reason as the need to add an edge here, which is useful for
@@ -110,6 +112,14 @@ def _flatten(item, flattened):

 def flatten(item, freeze=True):
     graph = _flatten(item, set())
+
+    dup_names = misc.get_duplicate_keys(graph.nodes_iter(),
+                                        key=lambda node: node.name)
+    if dup_names:
+        raise exceptions.InvariantViolationException(
+            "Tasks with duplicate names found: %s"
+            % ', '.join(sorted(dup_names)))
+
     if freeze:
         # Frozen graph can't be modified...
         return nx.freeze(graph)
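A brief caller-side sketch of what this check means in practice. The flow_utils import path is an assumption (the diff does not name the file); everything else follows the code above: flattening up front now surfaces duplicate task names before anything runs.

from taskflow import exceptions
from taskflow.utils import flow_utils  # assumed import path for the flatten() shown above

def ensure_runnable(flow):
    # Flatten eagerly so a flow with duplicate task names is rejected before execution.
    try:
        return flow_utils.flatten(flow)
    except exceptions.InvariantViolationException as e:
        # e.g. "Tasks with duplicate names found: a"
        raise ValueError("flow cannot be run: %s" % e)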
@@ -22,6 +22,7 @@ from distutils import version
 import collections
 import copy
 import errno
+import itertools
 import logging
 import os
 import sys
@@ -51,6 +52,18 @@ def get_version_string(obj):
     return obj_version


+def get_duplicate_keys(iterable, key=None):
+    if key is not None:
+        iterable = itertools.imap(key, iterable)
+    keys = set()
+    duplicates = set()
+    for item in iterable:
+        if item in keys:
+            duplicates.add(item)
+        keys.add(item)
+    return duplicates
+
+
 class ExponentialBackoff(object):
     def __init__(self, attempts, exponent=2):
         self.attempts = int(attempts)
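A hedged usage sketch of the new helper. It is re-implemented here with a generator expression so the example also runs on Python 3, where itertools.imap does not exist; the behaviour matches the function added above.

def get_duplicate_keys(iterable, key=None):
    # Same logic as the helper above; a generator expression stands in for itertools.imap.
    if key is not None:
        iterable = (key(item) for item in iterable)
    seen = set()
    duplicates = set()
    for item in iterable:
        if item in seen:
            duplicates.add(item)
        seen.add(item)
    return duplicates

class Node(object):
    # Stand-in for a graph node that carries a .name attribute, like a task does.
    def __init__(self, name):
        self.name = name

nodes = [Node("a"), Node("b"), Node("a")]
print(get_duplicate_keys(nodes, key=lambda node: node.name))  # {'a'} (set(['a']) on Python 2)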