Move validation of compiled unit out of compiler
Instead of having the compiler validate the graph it has created, have the compiler just compile and have the engine that uses the compiled result perform any post-compilation validation. This makes it clearer that the compiler only compiles a flow (and its tasks and nested flows) into a graph, and that is all it does.

Change-Id: I96a35d732dc2be9fc8bc8dc6466256a19ac2df6d
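For context, a rough sketch (not part of the change set itself) of how the moved check surfaces to callers; the DummyTask class below is a hypothetical stand-in for the test utility task used in the updated tests further down:

    from taskflow import engines
    from taskflow.engines.action_engine import compiler
    from taskflow import exceptions as exc
    from taskflow.patterns import linear_flow as lf
    from taskflow import task


    class DummyTask(task.Task):
        # Hypothetical no-op task, used only for this illustration.
        def execute(self):
            pass


    flo = lf.Flow("test").add(DummyTask(name="a"), DummyTask(name="a"))

    # After this change the compiler only builds the graph; it no longer
    # raises on duplicate atom names.
    compilation = compiler.PatternCompiler(flo).compile()

    # The engine now performs the post-compilation check instead.
    try:
        engines.load(flo).compile()
    except exc.Duplicate as e:
        print("engine rejected compilation: %s" % e)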
@@ -17,16 +17,15 @@
 import threading
 
 import fasteners
+from oslo_utils import excutils
 import six
 
-from taskflow import exceptions as exc
 from taskflow import flow
 from taskflow import logging
 from taskflow import task
 from taskflow.types import graph as gr
 from taskflow.types import tree as tr
 from taskflow.utils import iter_utils
-from taskflow.utils import misc
 
 from taskflow.flow import (LINK_INVARIANT, LINK_RETRY)  # noqa
 
@@ -322,24 +321,26 @@ class PatternCompiler(object):
 
     def _post_compile(self, graph, node):
         """Called after the compilation of the root finishes successfully."""
-        dup_names = misc.get_duplicate_keys(
-            (node for node, node_attrs in graph.nodes_iter(data=True)
-             if node_attrs['kind'] in ATOMS),
-            key=lambda node: node.name)
-        if dup_names:
-            raise exc.Duplicate(
-                "Atoms with duplicate names found: %s" % (sorted(dup_names)))
         self._history.clear()
         self._level = 0
 
     @fasteners.locked
     def compile(self):
         """Compiles the contained item into a compiled equivalent."""
         if self._compilation is None:
             self._pre_compile()
-            graph, node = self._compile(self._root, parent=None)
-            self._post_compile(graph, node)
-            if self._freeze:
-                graph.freeze()
-                node.freeze()
-            self._compilation = Compilation(graph, node)
+            try:
+                graph, node = self._compile(self._root, parent=None)
+            except Exception:
+                with excutils.save_and_reraise_exception():
+                    # Always clear the history, to avoid retaining junk
+                    # in memory that isn't needed to be in memory if
+                    # compilation fails...
+                    self._history.clear()
+            else:
+                self._post_compile(graph, node)
+                if self._freeze:
+                    graph.freeze()
+                    node.freeze()
+                self._compilation = Compilation(graph, node)
         return self._compilation
 
@@ -222,6 +222,24 @@ class ActionEngine(base.Engine):
             six.itervalues(self.storage.get_revert_failures()))
         failure.Failure.reraise_if_any(it)
 
+    @staticmethod
+    def _check_compilation(compilation):
+        """Performs post compilation validation/checks."""
+        seen = set()
+        dups = set()
+        execution_graph = compilation.execution_graph
+        for node, node_attrs in execution_graph.nodes_iter(data=True):
+            if node_attrs['kind'] in compiler.ATOMS:
+                atom_name = node.name
+                if atom_name in seen:
+                    dups.add(atom_name)
+                else:
+                    seen.add(atom_name)
+        if dups:
+            raise exc.Duplicate(
+                "Atoms with duplicate names found: %s" % (sorted(dups)))
+        return compilation
+
     def _change_state(self, state):
         with self._state_lock:
             old_state = self.storage.get_flow_state()
@@ -318,8 +336,7 @@ class ActionEngine(base.Engine):
     def compile(self):
         if self._compiled:
             return
-        self._compilation = self._compiler.compile()
-
+        self._compilation = self._check_compilation(self._compiler.compile())
         self._runtime = runtime.Runtime(self._compilation,
                                         self.storage,
                                         self.atom_notifier,
@@ -14,6 +14,7 @@
 # License for the specific language governing permissions and limitations
 # under the License.
 
+from taskflow import engines
 from taskflow.engines.action_engine import compiler
 from taskflow import exceptions as exc
 from taskflow.patterns import graph_flow as gf
@@ -399,17 +400,19 @@ class PatternCompileTest(test.TestCase):
             test_utils.DummyTask(name="a"),
             test_utils.DummyTask(name="a")
         )
+        e = engines.load(flo)
         self.assertRaisesRegexp(exc.Duplicate,
                                 '^Atoms with duplicate names',
-                                compiler.PatternCompiler(flo).compile)
+                                e.compile)
 
     def test_checks_for_dups_globally(self):
         flo = gf.Flow("test").add(
             gf.Flow("int1").add(test_utils.DummyTask(name="a")),
             gf.Flow("int2").add(test_utils.DummyTask(name="a")))
+        e = engines.load(flo)
         self.assertRaisesRegexp(exc.Duplicate,
                                 '^Atoms with duplicate names',
-                                compiler.PatternCompiler(flo).compile)
+                                e.compile)
 
     def test_retry_in_linear_flow(self):
         flo = lf.Flow("test", retry.AlwaysRevert("c"))
@@ -35,7 +35,6 @@ from oslo_utils import importutils
 from oslo_utils import netutils
 from oslo_utils import reflection
 import six
-from six.moves import map as compat_map
 from six.moves import range as compat_range
 
 from taskflow.types import failure
@@ -453,18 +452,6 @@ def sequence_minus(seq1, seq2):
     return result
 
 
-def get_duplicate_keys(iterable, key=None):
-    if key is not None:
-        iterable = compat_map(key, iterable)
-    keys = set()
-    duplicates = set()
-    for item in iterable:
-        if item in keys:
-            duplicates.add(item)
-        keys.add(item)
-    return duplicates
-
-
 class ExponentialBackoff(object):
     """An iterable object that will yield back an exponential delay sequence.
 