diff --git a/taskflow/tests/unit/test_types.py b/taskflow/tests/unit/test_types.py index 8aae36fa..c13b263b 100644 --- a/taskflow/tests/unit/test_types.py +++ b/taskflow/tests/unit/test_types.py @@ -217,6 +217,23 @@ class StopWatchTest(test.TestCase): watch.start() self.assertEqual(0, len(watch.splits)) + def test_elapsed_maximum(self): + watch = tt.StopWatch() + watch.start() + + timeutils.advance_time_seconds(1) + self.assertEqual(1, watch.elapsed()) + + timeutils.advance_time_seconds(10) + self.assertEqual(11, watch.elapsed()) + self.assertEqual(1, watch.elapsed(maximum=1)) + + watch.stop() + self.assertEqual(11, watch.elapsed()) + timeutils.advance_time_seconds(10) + self.assertEqual(11, watch.elapsed()) + self.assertEqual(0, watch.elapsed(maximum=-1)) + class TableTest(test.TestCase): def test_create_valid_no_rows(self): diff --git a/taskflow/types/timing.py b/taskflow/types/timing.py index e7fa7d45..8f868431 100644 --- a/taskflow/types/timing.py +++ b/taskflow/types/timing.py @@ -137,17 +137,20 @@ class StopWatch(object): self.start() return self - def elapsed(self): + def elapsed(self, maximum=None): """Returns how many seconds have elapsed.""" - if self._state == self._STOPPED: - return max(0.0, float(timeutils.delta_seconds(self._started_at, - self._stopped_at))) - elif self._state == self._STARTED: - return max(0.0, float(timeutils.delta_seconds(self._started_at, - timeutils.utcnow()))) - else: + if self._state not in (self._STOPPED, self._STARTED): raise RuntimeError("Can not get the elapsed time of a stopwatch" " if it has not been started/stopped") + if self._state == self._STOPPED: + elapsed = max(0.0, float(timeutils.delta_seconds( + self._started_at, self._stopped_at))) + else: + elapsed = max(0.0, float(timeutils.delta_seconds( + self._started_at, timeutils.utcnow()))) + if maximum is not None and elapsed > maximum: + elapsed = max(0.0, maximum) + return elapsed def __enter__(self): """Starts the watch."""