diff --git a/taskflow/task.py b/taskflow/task.py index 11b9972d..ae947f1c 100644 --- a/taskflow/task.py +++ b/taskflow/task.py @@ -24,6 +24,7 @@ import six from taskflow import atom from taskflow import logging +from taskflow.utils import misc from taskflow.utils import reflection LOG = logging.getLogger(__name__) @@ -157,13 +158,12 @@ class BaseTask(atom.Atom): :param kwargs: any keyword arguments that are tied to the specific progress value. """ - if progress > 1.0: - LOG.warn("Progress must be <= 1.0, clamping to upper bound") - progress = 1.0 - if progress < 0.0: - LOG.warn("Progress must be >= 0.0, clamping to lower bound") - progress = 0.0 - self.trigger(EVENT_UPDATE_PROGRESS, progress, **kwargs) + def on_clamped(): + LOG.warn("Progress value must be greater or equal to 0.0 or less" + " than or equal to 1.0 instead of being '%s'", progress) + cleaned_progress = misc.clamp(progress, 0.0, 1.0, + on_clamped=on_clamped) + self.trigger(EVENT_UPDATE_PROGRESS, cleaned_progress, **kwargs) def trigger(self, event_name, *args, **kwargs): """Execute all callbacks registered for the given event type. diff --git a/taskflow/tests/unit/test_utils.py b/taskflow/tests/unit/test_utils.py index 66f08c09..56e39199 100644 --- a/taskflow/tests/unit/test_utils.py +++ b/taskflow/tests/unit/test_utils.py @@ -443,3 +443,32 @@ class TestSequenceMinus(test.TestCase): def test_equal_items_not_continious(self): result = misc.sequence_minus([1, 2, 3, 1], [1, 3]) self.assertEqual(result, [2, 1]) + + +class TestClamping(test.TestCase): + def test_simple_clamp(self): + result = misc.clamp(1.0, 2.0, 3.0) + self.assertEqual(result, 2.0) + result = misc.clamp(4.0, 2.0, 3.0) + self.assertEqual(result, 3.0) + result = misc.clamp(3.0, 4.0, 4.0) + self.assertEqual(result, 4.0) + + def test_invalid_clamp(self): + self.assertRaises(ValueError, misc.clamp, 0.0, 2.0, 1.0) + + def test_clamped_callback(self): + calls = [] + + def on_clamped(): + calls.append(True) + + misc.clamp(-1, 0.0, 1.0, on_clamped=on_clamped) + self.assertEqual(1, len(calls)) + calls.pop() + + misc.clamp(0.0, 0.0, 1.0, on_clamped=on_clamped) + self.assertEqual(0, len(calls)) + + misc.clamp(2, 0.0, 1.0, on_clamped=on_clamped) + self.assertEqual(1, len(calls)) diff --git a/taskflow/utils/misc.py b/taskflow/utils/misc.py index 88148bd7..34910064 100644 --- a/taskflow/utils/misc.py +++ b/taskflow/utils/misc.py @@ -100,6 +100,22 @@ def parse_uri(uri): query=split.query) +def clamp(value, minimum, maximum, on_clamped=None): + """Clamps a value to ensure its >= minimum and <= maximum.""" + if minimum > maximum: + raise ValueError("Provided minimum '%s' must be less than or equal to" + " the provided maximum '%s'" % (minimum, maximum)) + if value > maximum: + value = maximum + if on_clamped is not None: + on_clamped() + if value < minimum: + value = minimum + if on_clamped is not None: + on_clamped() + return value + + def binary_encode(text, encoding='utf-8'): """Converts a string of into a binary type using given encoding.