Merge "Tweak functor used to find flatteners/storage routines"
commit 8a718aa03d
@@ -20,7 +20,6 @@ import threading
 from taskflow import exceptions as exc
 from taskflow import flow
 from taskflow import logging
-from taskflow import retry
 from taskflow import task
 from taskflow.types import graph as gr
 from taskflow.types import tree as tr
@@ -281,34 +280,22 @@ class PatternCompiler(object):
         self._freeze = freeze
         self._lock = threading.Lock()
         self._compilation = None
+        self._flatten_matchers = [
+            ((flow.Flow,), self._flatten_flow),
+            ((task.BaseTask,), self._flatten_task),
+        ]

     def _flatten(self, item, parent):
         """Flattens a item (pattern, task) into a graph + tree node."""
-        functor = self._find_flattener(item, parent)
+        functor = misc.match_type_handler(item, self._flatten_matchers)
+        if not functor:
+            raise TypeError("Unknown item '%s' (%s) requested to flatten"
+                            % (item, type(item)))
         self._pre_item_flatten(item)
         graph, node = functor(item, parent)
         self._post_item_flatten(item, graph, node)
         return graph, node

-    def _find_flattener(self, item, parent):
-        """Locates the flattening function to use to flatten the given item."""
-        if isinstance(item, flow.Flow):
-            return self._flatten_flow
-        elif isinstance(item, task.BaseTask):
-            return self._flatten_task
-        elif isinstance(item, retry.Retry):
-            if parent is None:
-                raise TypeError("Retry controller '%s' (%s) must only be used"
-                                " as a flow constructor parameter and not as a"
-                                " root component" % (item, type(item)))
-            else:
-                raise TypeError("Retry controller '%s' (%s) must only be used"
-                                " as a flow constructor parameter and not as a"
-                                " flow added component" % (item, type(item)))
-        else:
-            raise TypeError("Unknown item '%s' (%s) requested to flatten"
-                            % (item, type(item)))
-
     def _connect_retry(self, retry, graph):
         graph.add_node(retry)

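Note: the compiler hunk above swaps the removed _find_flattener isinstance chain for a declarative list of ((types,), bound handler) pairs that a shared lookup walks. A minimal sketch of that dispatch shape, using toy stand-in classes rather than taskflow's real Flow/BaseTask types:

    class Flow(object):
        pass


    class Task(object):
        pass


    class ToyCompiler(object):
        def __init__(self):
            # The type table replaces an explicit isinstance chain.
            self._flatten_matchers = [
                ((Flow,), self._flatten_flow),
                ((Task,), self._flatten_task),
            ]

        def _flatten_flow(self, item):
            return 'graph-for-%s' % type(item).__name__

        def _flatten_task(self, item):
            return 'node-for-%s' % type(item).__name__

        def flatten(self, item):
            # First pair whose types match wins; None means "unknown item".
            functor = None
            for match_types, handler in self._flatten_matchers:
                if isinstance(item, match_types):
                    functor = handler
                    break
            if not functor:
                raise TypeError("Unknown item '%s' (%s) requested to flatten"
                                % (item, type(item)))
            return functor(item)


    print(ToyCompiler().flatten(Task()))  # -> node-for-Task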
@@ -132,6 +132,10 @@ class Storage(object):
         self._transients = {}
         self._injected_args = {}
         self._lock = lock_utils.ReaderWriterLock()
+        self._ensure_matchers = [
+            ((task.BaseTask,), self._ensure_task),
+            ((retry.Retry,), self._ensure_retry),
+        ]

         # NOTE(imelnikov): failure serialization looses information,
         # so we cache failures here, in atom name -> failure mapping.
@@ -168,17 +172,14 @@ class Storage(object):

         Returns uuid for the atomdetail that is/was created.
         """
-        if isinstance(atom, task.BaseTask):
-            return self._ensure_task(atom.name,
-                                     misc.get_version_string(atom),
-                                     atom.save_as)
-        elif isinstance(atom, retry.Retry):
-            return self._ensure_retry(atom.name,
-                                      misc.get_version_string(atom),
-                                      atom.save_as)
+        functor = misc.match_type_handler(atom, self._ensure_matchers)
+        if not functor:
+            raise TypeError("Unknown item '%s' (%s) requested to ensure"
+                            % (atom, type(atom)))
         else:
-            raise TypeError("Object of type 'atom' expected not"
-                            " '%s' (%s)" % (atom, type(atom)))
+            return functor(atom.name,
+                           misc.get_version_string(atom),
+                           atom.save_as)

     def _ensure_task(self, task_name, task_version, result_mapping):
         """Ensures there is a taskdetail that corresponds to the task info.
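The storage hunks apply the same table, and it works there because _ensure_task and _ensure_retry share one (name, version, result_mapping) calling convention, so ensure_atom can invoke whatever handler the lookup returns without a per-type call site. A rough, self-contained sketch of that idea; the stand-in classes, the hard-coded version string, and the fake uuids below are illustrative, not taskflow's real behaviour:

    class Task(object):
        name = 't1'
        save_as = 'x'


    class Retry(object):
        name = 'r1'
        save_as = 'y'


    def ensure_task(name, version, result_mapping):
        # Stand-in: the real routine persists a task detail and returns its uuid.
        return 'task-detail-uuid-for-%s' % name


    def ensure_retry(name, version, result_mapping):
        # Stand-in: the real routine persists a retry detail and returns its uuid.
        return 'retry-detail-uuid-for-%s' % name


    ensure_matchers = [
        ((Task,), ensure_task),
        ((Retry,), ensure_retry),
    ]


    def ensure_atom(atom):
        functor = next((handler for types, handler in ensure_matchers
                        if isinstance(atom, types)), None)
        if not functor:
            raise TypeError("Unknown item '%s' (%s) requested to ensure"
                            % (atom, type(atom)))
        # Uniform signature: every handler accepts (name, version, result_mapping).
        return functor(atom.name, '1.0', atom.save_as)


    print(ensure_atom(Retry()))  # -> retry-detail-uuid-for-r1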
@@ -34,9 +34,7 @@ class PatternCompileTest(test.TestCase):

     def test_retry(self):
         r = retry.AlwaysRevert('r1')
-        msg_regex = "^Retry controller .* must only be used .*"
-        self.assertRaisesRegexp(TypeError, msg_regex,
-                                compiler.PatternCompiler(r).compile)
+        self.assertRaises(TypeError, compiler.PatternCompiler(r).compile)

     def test_wrong_object(self):
         msg_regex = '^Unknown item .* requested to flatten'
@@ -377,12 +377,12 @@ class RetryTest(utils.EngineTestBase):
     def test_run_just_retry(self):
         flow = utils.OneReturnRetry(provides='x')
         engine = self._make_engine(flow)
-        self.assertRaisesRegexp(TypeError, 'Retry controller', engine.run)
+        self.assertRaises(TypeError, engine.run)

     def test_use_retry_as_a_task(self):
         flow = lf.Flow('test').add(utils.OneReturnRetry(provides='x'))
         engine = self._make_engine(flow)
-        self.assertRaisesRegexp(TypeError, 'Retry controller', engine.run)
+        self.assertRaises(TypeError, engine.run)

     def test_resume_flow_that_had_been_interrupted_during_retrying(self):
         flow = lf.Flow('flow-1', retry.Times(3, 'r1')).add(
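The two test hunks above loosen assertRaisesRegexp to assertRaises, presumably because the retry-specific "Retry controller ..." messages disappear with _find_flattener, and an unmatched item now surfaces as the generic "Unknown item ..." TypeError that the old regexes would not match. A small standalone illustration of what each assertion style actually checks (the reject helper below is hypothetical):

    import unittest


    def reject(item):
        # Mirrors the new generic failure: no mention of "Retry controller".
        raise TypeError("Unknown item '%s' (%s) requested to flatten"
                        % (item, type(item)))


    class AssertionStyles(unittest.TestCase):
        def test_assert_raises(self):
            # Only the exception type is checked, so any TypeError passes.
            self.assertRaises(TypeError, reject, object())

        def test_assert_raises_regexp(self):
            # The message is checked too; a 'Retry controller' pattern would
            # now fail, while this one still matches.
            self.assertRaisesRegexp(TypeError, 'Unknown item',
                                    reject, object())


    if __name__ == '__main__':
        unittest.main()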
@@ -87,6 +87,18 @@ def find_monotonic(allow_time_time=False):
     return None


+def match_type_handler(item, type_handlers):
+    """Matches a given items type using the given match types + handlers.
+
+    Returns the handler if a type match occurs, otherwise none.
+    """
+    for (match_types, handler_func) in type_handlers:
+        if isinstance(item, match_types):
+            return handler_func
+    else:
+        return None
+
+
 def countdown_iter(start_at, decr=1):
     """Generator that decrements after each generation until <= zero.

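Assuming the helper behaves as the hunk above shows (the first (types, handler) pair whose types match the item wins, otherwise None), calling it looks roughly like the following. The helper is restated locally so the sketch runs without taskflow installed, and the example handlers are made up:

    def match_type_handler(item, type_handlers):
        # Equivalent to the helper added above (its for/else simply
        # returns None when nothing matches).
        for (match_types, handler_func) in type_handlers:
            if isinstance(item, match_types):
                return handler_func
        return None


    def on_int(value):
        return 'int: %s' % value


    def on_text(value):
        return 'text: %s' % value


    handlers = [
        ((int,), on_int),
        ((str, bytes), on_text),  # one entry may list several types
    ]

    functor = match_type_handler(5, handlers)
    print(functor(5))                         # -> int: 5
    print(match_type_handler(2.5, handlers))  # -> None (no float entry)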