Tweak functor used to find flatteners/storage routines

Make both of these lookup functions use the shared helper that
the utility module now provides, since the type-matching logic
they previously duplicated can now be shared.

Change-Id: Ib941b99945d42f5c0d791e9b2a0696d0e62a2388
This commit is contained in:
Joshua Harlow
2015-03-02 18:49:17 -08:00
parent d4f02f6fea
commit b598897d15
5 changed files with 34 additions and 36 deletions

View File

@@ -20,7 +20,6 @@ import threading
from taskflow import exceptions as exc from taskflow import exceptions as exc
from taskflow import flow from taskflow import flow
from taskflow import logging from taskflow import logging
from taskflow import retry
from taskflow import task from taskflow import task
from taskflow.types import graph as gr from taskflow.types import graph as gr
from taskflow.types import tree as tr from taskflow.types import tree as tr
@@ -281,34 +280,22 @@ class PatternCompiler(object):
self._freeze = freeze self._freeze = freeze
self._lock = threading.Lock() self._lock = threading.Lock()
self._compilation = None self._compilation = None
self._flatten_matchers = [
((flow.Flow,), self._flatten_flow),
((task.BaseTask,), self._flatten_task),
]
def _flatten(self, item, parent): def _flatten(self, item, parent):
"""Flattens a item (pattern, task) into a graph + tree node.""" """Flattens a item (pattern, task) into a graph + tree node."""
functor = self._find_flattener(item, parent) functor = misc.match_type_handler(item, self._flatten_matchers)
if not functor:
raise TypeError("Unknown item '%s' (%s) requested to flatten"
% (item, type(item)))
self._pre_item_flatten(item) self._pre_item_flatten(item)
graph, node = functor(item, parent) graph, node = functor(item, parent)
self._post_item_flatten(item, graph, node) self._post_item_flatten(item, graph, node)
return graph, node return graph, node
def _find_flattener(self, item, parent):
"""Locates the flattening function to use to flatten the given item."""
if isinstance(item, flow.Flow):
return self._flatten_flow
elif isinstance(item, task.BaseTask):
return self._flatten_task
elif isinstance(item, retry.Retry):
if parent is None:
raise TypeError("Retry controller '%s' (%s) must only be used"
" as a flow constructor parameter and not as a"
" root component" % (item, type(item)))
else:
raise TypeError("Retry controller '%s' (%s) must only be used"
" as a flow constructor parameter and not as a"
" flow added component" % (item, type(item)))
else:
raise TypeError("Unknown item '%s' (%s) requested to flatten"
% (item, type(item)))
def _connect_retry(self, retry, graph): def _connect_retry(self, retry, graph):
graph.add_node(retry) graph.add_node(retry)

View File

@@ -132,6 +132,10 @@ class Storage(object):
self._transients = {} self._transients = {}
self._injected_args = {} self._injected_args = {}
self._lock = lock_utils.ReaderWriterLock() self._lock = lock_utils.ReaderWriterLock()
self._ensure_matchers = [
((task.BaseTask,), self._ensure_task),
((retry.Retry,), self._ensure_retry),
]
# NOTE(imelnikov): failure serialization loses information, # NOTE(imelnikov): failure serialization loses information,
# so we cache failures here, in atom name -> failure mapping. # so we cache failures here, in atom name -> failure mapping.
@@ -168,17 +172,14 @@ class Storage(object):
Returns uuid for the atomdetail that is/was created. Returns uuid for the atomdetail that is/was created.
""" """
if isinstance(atom, task.BaseTask): functor = misc.match_type_handler(atom, self._ensure_matchers)
return self._ensure_task(atom.name, if not functor:
misc.get_version_string(atom), raise TypeError("Unknown item '%s' (%s) requested to ensure"
atom.save_as) % (atom, type(atom)))
elif isinstance(atom, retry.Retry):
return self._ensure_retry(atom.name,
misc.get_version_string(atom),
atom.save_as)
else: else:
raise TypeError("Object of type 'atom' expected not" return functor(atom.name,
" '%s' (%s)" % (atom, type(atom))) misc.get_version_string(atom),
atom.save_as)
def _ensure_task(self, task_name, task_version, result_mapping): def _ensure_task(self, task_name, task_version, result_mapping):
"""Ensures there is a taskdetail that corresponds to the task info. """Ensures there is a taskdetail that corresponds to the task info.

View File

@@ -34,9 +34,7 @@ class PatternCompileTest(test.TestCase):
def test_retry(self): def test_retry(self):
r = retry.AlwaysRevert('r1') r = retry.AlwaysRevert('r1')
msg_regex = "^Retry controller .* must only be used .*" self.assertRaises(TypeError, compiler.PatternCompiler(r).compile)
self.assertRaisesRegexp(TypeError, msg_regex,
compiler.PatternCompiler(r).compile)
def test_wrong_object(self): def test_wrong_object(self):
msg_regex = '^Unknown item .* requested to flatten' msg_regex = '^Unknown item .* requested to flatten'

View File

@@ -377,12 +377,12 @@ class RetryTest(utils.EngineTestBase):
def test_run_just_retry(self): def test_run_just_retry(self):
flow = utils.OneReturnRetry(provides='x') flow = utils.OneReturnRetry(provides='x')
engine = self._make_engine(flow) engine = self._make_engine(flow)
self.assertRaisesRegexp(TypeError, 'Retry controller', engine.run) self.assertRaises(TypeError, engine.run)
def test_use_retry_as_a_task(self): def test_use_retry_as_a_task(self):
flow = lf.Flow('test').add(utils.OneReturnRetry(provides='x')) flow = lf.Flow('test').add(utils.OneReturnRetry(provides='x'))
engine = self._make_engine(flow) engine = self._make_engine(flow)
self.assertRaisesRegexp(TypeError, 'Retry controller', engine.run) self.assertRaises(TypeError, engine.run)
def test_resume_flow_that_had_been_interrupted_during_retrying(self): def test_resume_flow_that_had_been_interrupted_during_retrying(self):
flow = lf.Flow('flow-1', retry.Times(3, 'r1')).add( flow = lf.Flow('flow-1', retry.Times(3, 'r1')).add(

View File

@@ -87,6 +87,18 @@ def find_monotonic(allow_time_time=False):
return None return None
def match_type_handler(item, type_handlers):
    """Match an item's type against ``(types, handler)`` pairs.

    :param item: object whose type will be checked.
    :param type_handlers: iterable of ``(match_types, handler)`` tuples,
        where ``match_types`` is a type or tuple of types in the form
        accepted by :func:`isinstance`.
    :returns: the handler from the first pair whose types match ``item``,
        or ``None`` when no pair matches.
    """
    for (match_types, handler_func) in type_handlers:
        if isinstance(item, match_types):
            return handler_func
    # No registered types matched the item's type.
    return None
def countdown_iter(start_at, decr=1): def countdown_iter(start_at, decr=1):
"""Generator that decrements after each generation until <= zero. """Generator that decrements after each generation until <= zero.