diff --git a/taskflow/tests/unit/test_utils_lock_utils.py b/taskflow/tests/unit/test_utils_lock_utils.py index bb334a2d..30c8c983 100644 --- a/taskflow/tests/unit/test_utils_lock_utils.py +++ b/taskflow/tests/unit/test_utils_lock_utils.py @@ -15,6 +15,7 @@ # under the License. import collections +import random import threading import time @@ -23,6 +24,7 @@ from concurrent import futures from taskflow import test from taskflow.test import mock from taskflow.tests import utils as test_utils +from taskflow.types import timing from taskflow.utils import lock_utils from taskflow.utils import misc from taskflow.utils import threading_utils @@ -330,6 +332,43 @@ class MultilockTest(test.TestCase): class ReadWriteLockTest(test.TestCase): + THREAD_COUNT = 20 + + def test_no_double_writers(self): + lock = lock_utils.ReaderWriterLock() + watch = timing.StopWatch(duration=5) + watch.start() + dups = collections.deque() + active = collections.deque() + + def acquire_check(me): + with lock.write_lock(): + if len(active) >= 1: + dups.append(me) + dups.extend(active) + active.append(me) + try: + time.sleep(random.random() / 100) + finally: + active.remove(me) + + def run(): + me = threading.current_thread() + while not watch.expired(): + acquire_check(me) + + threads = [] + for i in range(0, self.THREAD_COUNT): + t = threading_utils.daemon_thread(run) + threads.append(t) + t.start() + while threads: + t = threads.pop() + t.join() + + self.assertEqual([], list(dups)) + self.assertEqual([], list(active)) + def test_writer_abort(self): lock = lock_utils.ReaderWriterLock() self.assertFalse(lock.owner)