Use a tiny clamp helper to clamp the 'on_progress' value

Add a misc.clamp function that will clamp a value to a given
range (it can also call a callback if clamping occurs). Use it
to clamp the progress value that was previously clamped with
a set of customized logic that can now be replaced with a more
generalized logic that can be shared.

Change-Id: I8369dbb61f73a60932d9e15c8b4d06db249ea38e
This commit is contained in:
Joshua Harlow
2014-12-12 23:03:12 -08:00
parent 97dd12e3e0
commit cdfd8ece61
3 changed files with 52 additions and 7 deletions

View File

@@ -24,6 +24,7 @@ import six
from taskflow import atom from taskflow import atom
from taskflow import logging from taskflow import logging
from taskflow.utils import misc
from taskflow.utils import reflection from taskflow.utils import reflection
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@@ -157,13 +158,12 @@ class BaseTask(atom.Atom):
:param kwargs: any keyword arguments that are tied to the specific :param kwargs: any keyword arguments that are tied to the specific
progress value. progress value.
""" """
if progress > 1.0: def on_clamped():
LOG.warn("Progress must be <= 1.0, clamping to upper bound") LOG.warn("Progress value must be greater or equal to 0.0 or less"
progress = 1.0 " than or equal to 1.0 instead of being '%s'", progress)
if progress < 0.0: cleaned_progress = misc.clamp(progress, 0.0, 1.0,
LOG.warn("Progress must be >= 0.0, clamping to lower bound") on_clamped=on_clamped)
progress = 0.0 self.trigger(EVENT_UPDATE_PROGRESS, cleaned_progress, **kwargs)
self.trigger(EVENT_UPDATE_PROGRESS, progress, **kwargs)
def trigger(self, event_name, *args, **kwargs): def trigger(self, event_name, *args, **kwargs):
"""Execute all callbacks registered for the given event type. """Execute all callbacks registered for the given event type.

View File

@@ -443,3 +443,32 @@ class TestSequenceMinus(test.TestCase):
def test_equal_items_not_continious(self): def test_equal_items_not_continious(self):
result = misc.sequence_minus([1, 2, 3, 1], [1, 3]) result = misc.sequence_minus([1, 2, 3, 1], [1, 3])
self.assertEqual(result, [2, 1]) self.assertEqual(result, [2, 1])
class TestClamping(test.TestCase):
def test_simple_clamp(self):
result = misc.clamp(1.0, 2.0, 3.0)
self.assertEqual(result, 2.0)
result = misc.clamp(4.0, 2.0, 3.0)
self.assertEqual(result, 3.0)
result = misc.clamp(3.0, 4.0, 4.0)
self.assertEqual(result, 4.0)
def test_invalid_clamp(self):
self.assertRaises(ValueError, misc.clamp, 0.0, 2.0, 1.0)
def test_clamped_callback(self):
calls = []
def on_clamped():
calls.append(True)
misc.clamp(-1, 0.0, 1.0, on_clamped=on_clamped)
self.assertEqual(1, len(calls))
calls.pop()
misc.clamp(0.0, 0.0, 1.0, on_clamped=on_clamped)
self.assertEqual(0, len(calls))
misc.clamp(2, 0.0, 1.0, on_clamped=on_clamped)
self.assertEqual(1, len(calls))

View File

@@ -100,6 +100,22 @@ def parse_uri(uri):
query=split.query) query=split.query)
def clamp(value, minimum, maximum, on_clamped=None):
"""Clamps a value to ensure its >= minimum and <= maximum."""
if minimum > maximum:
raise ValueError("Provided minimum '%s' must be less than or equal to"
" the provided maximum '%s'" % (minimum, maximum))
if value > maximum:
value = maximum
if on_clamped is not None:
on_clamped()
if value < minimum:
value = minimum
if on_clamped is not None:
on_clamped()
return value
def binary_encode(text, encoding='utf-8'): def binary_encode(text, encoding='utf-8'):
"""Converts a string of into a binary type using given encoding. """Converts a string of into a binary type using given encoding.