# -*- coding: utf-8 -*- # Copyright (C) 2014 Yahoo! Inc. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. You may obtain # a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. import collections import random import threading import time from concurrent import futures from taskflow import test from taskflow.test import mock from taskflow.tests import utils as test_utils from taskflow.types import timing from taskflow.utils import lock_utils from taskflow.utils import misc from taskflow.utils import threading_utils # NOTE(harlowja): Sleep a little so now() can not be the same (which will # cause false positives when our overlap detection code runs). If there are # real overlaps then they will still exist. NAPPY_TIME = 0.05 # We will spend this amount of time doing some "fake" work. WORK_TIMES = [(0.01 + x / 100.0) for x in range(0, 5)] # Try to use a more accurate time for overlap detection (one that should # never go backwards and cause false positives during overlap detection...). now = misc.find_monotonic(allow_time_time=True) def _find_overlaps(times, start, end): overlaps = 0 for (s, e) in times: if s >= start and e <= end: overlaps += 1 return overlaps def _spawn_variation(readers, writers, max_workers=None): start_stops = collections.deque() lock = lock_utils.ReaderWriterLock() def read_func(ident): with lock.read_lock(): # TODO(harlowja): sometime in the future use a monotonic clock here # to avoid problems that can be caused by ntpd resyncing the clock # while we are actively running. enter_time = now() time.sleep(WORK_TIMES[ident % len(WORK_TIMES)]) exit_time = now() start_stops.append((lock.READER, enter_time, exit_time)) time.sleep(NAPPY_TIME) def write_func(ident): with lock.write_lock(): enter_time = now() time.sleep(WORK_TIMES[ident % len(WORK_TIMES)]) exit_time = now() start_stops.append((lock.WRITER, enter_time, exit_time)) time.sleep(NAPPY_TIME) if max_workers is None: max_workers = max(0, readers) + max(0, writers) if max_workers > 0: with futures.ThreadPoolExecutor(max_workers=max_workers) as e: count = 0 for _i in range(0, readers): e.submit(read_func, count) count += 1 for _i in range(0, writers): e.submit(write_func, count) count += 1 writer_times = [] reader_times = [] for (lock_type, start, stop) in list(start_stops): if lock_type == lock.WRITER: writer_times.append((start, stop)) else: reader_times.append((start, stop)) return (writer_times, reader_times) class MultilockTest(test.TestCase): THREAD_COUNT = 20 def test_empty_error(self): self.assertRaises(ValueError, lock_utils.MultiLock, []) self.assertRaises(ValueError, lock_utils.MultiLock, ()) self.assertRaises(ValueError, lock_utils.MultiLock, iter([])) def test_creation(self): locks = [] for _i in range(0, 10): locks.append(threading.Lock()) n_lock = lock_utils.MultiLock(locks) self.assertEqual(0, n_lock.obtained) self.assertEqual(len(locks), len(n_lock)) def test_acquired(self): lock1 = threading.Lock() lock2 = threading.Lock() n_lock = lock_utils.MultiLock((lock1, lock2)) self.assertTrue(n_lock.acquire()) try: self.assertTrue(lock1.locked()) self.assertTrue(lock2.locked()) finally: n_lock.release() self.assertFalse(lock1.locked()) self.assertFalse(lock2.locked()) def test_acquired_context_manager(self): lock1 = threading.Lock() n_lock = lock_utils.MultiLock([lock1]) with n_lock as gotten: self.assertTrue(gotten) self.assertTrue(lock1.locked()) self.assertFalse(lock1.locked()) self.assertEqual(0, n_lock.obtained) def test_partial_acquired(self): lock1 = threading.Lock() lock2 = mock.create_autospec(threading.Lock()) lock2.acquire.return_value = False n_lock = lock_utils.MultiLock((lock1, lock2)) with n_lock as gotten: self.assertFalse(gotten) self.assertTrue(lock1.locked()) self.assertEqual(1, n_lock.obtained) self.assertEqual(2, len(n_lock)) self.assertEqual(0, n_lock.obtained) def test_partial_acquired_failure(self): lock1 = threading.Lock() lock2 = mock.create_autospec(threading.Lock()) lock2.acquire.side_effect = RuntimeError("Broke") n_lock = lock_utils.MultiLock((lock1, lock2)) self.assertRaises(threading.ThreadError, n_lock.acquire) self.assertEqual(1, n_lock.obtained) n_lock.release() def test_release_failure(self): lock1 = threading.Lock() lock2 = mock.create_autospec(threading.Lock()) lock2.acquire.return_value = True lock2.release.side_effect = RuntimeError("Broke") n_lock = lock_utils.MultiLock((lock1, lock2)) self.assertTrue(n_lock.acquire()) self.assertEqual(2, n_lock.obtained) self.assertRaises(threading.ThreadError, n_lock.release) self.assertEqual(2, n_lock.obtained) lock2.release.side_effect = None n_lock.release() self.assertEqual(0, n_lock.obtained) def test_release_partial_failure(self): lock1 = threading.Lock() lock2 = mock.create_autospec(threading.Lock()) lock2.acquire.return_value = True lock2.release.side_effect = RuntimeError("Broke") lock3 = threading.Lock() n_lock = lock_utils.MultiLock((lock1, lock2, lock3)) self.assertTrue(n_lock.acquire()) self.assertEqual(3, n_lock.obtained) self.assertRaises(threading.ThreadError, n_lock.release) self.assertEqual(2, n_lock.obtained) lock2.release.side_effect = None n_lock.release() self.assertEqual(0, n_lock.obtained) def test_acquired_pass(self): activated = collections.deque() acquires = collections.deque() lock1 = threading.Lock() lock2 = threading.Lock() n_lock = lock_utils.MultiLock((lock1, lock2)) def critical_section(): start = now() time.sleep(NAPPY_TIME) end = now() activated.append((start, end)) def run(): with n_lock as gotten: acquires.append(gotten) critical_section() threads = [] for _i in range(0, self.THREAD_COUNT): t = threading_utils.daemon_thread(run) threads.append(t) t.start() while threads: t = threads.pop() t.join() self.assertEqual(self.THREAD_COUNT, len(acquires)) self.assertTrue(all(acquires)) for (start, end) in activated: self.assertEqual(1, _find_overlaps(activated, start, end)) self.assertFalse(lock1.locked()) self.assertFalse(lock2.locked()) def test_acquired_fail(self): activated = collections.deque() acquires = collections.deque() lock1 = threading.Lock() lock2 = threading.Lock() n_lock = lock_utils.MultiLock((lock1, lock2)) def run(): with n_lock as gotten: acquires.append(gotten) start = now() time.sleep(NAPPY_TIME) end = now() activated.append((start, end)) def run_fail(): try: with n_lock as gotten: acquires.append(gotten) raise RuntimeError() except RuntimeError: pass threads = [] for i in range(0, self.THREAD_COUNT): if i % 2 == 1: target = run_fail else: target = run t = threading_utils.daemon_thread(target) threads.append(t) t.start() while threads: t = threads.pop() t.join() self.assertEqual(self.THREAD_COUNT, len(acquires)) self.assertTrue(all(acquires)) for (start, end) in activated: self.assertEqual(1, _find_overlaps(activated, start, end)) self.assertFalse(lock1.locked()) self.assertFalse(lock2.locked()) def test_double_acquire_single(self): activated = collections.deque() acquires = [] def run(): start = now() time.sleep(NAPPY_TIME) end = now() activated.append((start, end)) lock1 = threading.RLock() lock2 = threading.RLock() n_lock = lock_utils.MultiLock((lock1, lock2)) with n_lock as gotten: acquires.append(gotten) run() with n_lock as gotten: acquires.append(gotten) run() run() self.assertTrue(all(acquires)) self.assertEqual(2, len(acquires)) for (start, end) in activated: self.assertEqual(1, _find_overlaps(activated, start, end)) def test_double_acquire_many(self): activated = collections.deque() acquires = collections.deque() n_lock = lock_utils.MultiLock((threading.RLock(), threading.RLock())) def critical_section(): start = now() time.sleep(NAPPY_TIME) end = now() activated.append((start, end)) def run(): with n_lock as gotten: acquires.append(gotten) critical_section() with n_lock as gotten: acquires.append(gotten) critical_section() critical_section() threads = [] for i in range(0, self.THREAD_COUNT): t = threading_utils.daemon_thread(run) threads.append(t) t.start() while threads: t = threads.pop() t.join() self.assertTrue(all(acquires)) self.assertEqual(self.THREAD_COUNT * 2, len(acquires)) self.assertEqual(self.THREAD_COUNT * 3, len(activated)) for (start, end) in activated: self.assertEqual(1, _find_overlaps(activated, start, end)) def test_no_acquire_release(self): lock1 = threading.Lock() lock2 = threading.Lock() n_lock = lock_utils.MultiLock((lock1, lock2)) self.assertRaises(threading.ThreadError, n_lock.release) class ReadWriteLockTest(test.TestCase): THREAD_COUNT = 20 def test_no_double_writers(self): lock = lock_utils.ReaderWriterLock() watch = timing.StopWatch(duration=5) watch.start() dups = collections.deque() active = collections.deque() def acquire_check(me): with lock.write_lock(): if len(active) >= 1: dups.append(me) dups.extend(active) active.append(me) try: time.sleep(random.random() / 100) finally: active.remove(me) def run(): me = threading.current_thread() while not watch.expired(): acquire_check(me) threads = [] for i in range(0, self.THREAD_COUNT): t = threading_utils.daemon_thread(run) threads.append(t) t.start() while threads: t = threads.pop() t.join() self.assertEqual([], list(dups)) self.assertEqual([], list(active)) def test_no_concurrent_readers_writers(self): lock = lock_utils.ReaderWriterLock() watch = timing.StopWatch(duration=5) watch.start() dups = collections.deque() active = collections.deque() def acquire_check(me, reader): if reader: lock_func = lock.read_lock else: lock_func = lock.write_lock with lock_func(): if not reader: # There should be no-one else currently active, if there # is ensure we capture them so that we can later blow-up # the test. if len(active) >= 1: dups.append(me) dups.extend(active) active.append(me) try: time.sleep(random.random() / 100) finally: active.remove(me) def run(): me = threading.current_thread() while not watch.expired(): acquire_check(me, random.choice([True, False])) threads = [] for i in range(0, self.THREAD_COUNT): t = threading_utils.daemon_thread(run) threads.append(t) t.start() while threads: t = threads.pop() t.join() self.assertEqual([], list(dups)) self.assertEqual([], list(active)) def test_writer_abort(self): lock = lock_utils.ReaderWriterLock() self.assertFalse(lock.owner) def blow_up(): with lock.write_lock(): self.assertEqual(lock.WRITER, lock.owner) raise RuntimeError("Broken") self.assertRaises(RuntimeError, blow_up) self.assertFalse(lock.owner) def test_reader_abort(self): lock = lock_utils.ReaderWriterLock() self.assertFalse(lock.owner) def blow_up(): with lock.read_lock(): self.assertEqual(lock.READER, lock.owner) raise RuntimeError("Broken") self.assertRaises(RuntimeError, blow_up) self.assertFalse(lock.owner) def test_double_reader_abort(self): lock = lock_utils.ReaderWriterLock() activated = collections.deque() def double_bad_reader(): with lock.read_lock(): with lock.read_lock(): raise RuntimeError("Broken") def happy_writer(): with lock.write_lock(): activated.append(lock.owner) with futures.ThreadPoolExecutor(max_workers=20) as e: for i in range(0, 20): if i % 2 == 0: e.submit(double_bad_reader) else: e.submit(happy_writer) self.assertEqual(10, len([a for a in activated if a == 'w'])) def test_double_reader_writer(self): lock = lock_utils.ReaderWriterLock() activated = collections.deque() active = threading_utils.Event() def double_reader(): with lock.read_lock(): active.set() while not lock.has_pending_writers: time.sleep(0.001) with lock.read_lock(): activated.append(lock.owner) def happy_writer(): with lock.write_lock(): activated.append(lock.owner) reader = threading_utils.daemon_thread(double_reader) reader.start() self.assertTrue(active.wait(test_utils.WAIT_TIMEOUT)) writer = threading_utils.daemon_thread(happy_writer) writer.start() reader.join() writer.join() self.assertEqual(2, len(activated)) self.assertEqual(['r', 'w'], list(activated)) def test_reader_chaotic(self): lock = lock_utils.ReaderWriterLock() activated = collections.deque() def chaotic_reader(blow_up): with lock.read_lock(): if blow_up: raise RuntimeError("Broken") else: activated.append(lock.owner) def happy_writer(): with lock.write_lock(): activated.append(lock.owner) with futures.ThreadPoolExecutor(max_workers=20) as e: for i in range(0, 20): if i % 2 == 0: e.submit(chaotic_reader, blow_up=bool(i % 4 == 0)) else: e.submit(happy_writer) writers = [a for a in activated if a == 'w'] readers = [a for a in activated if a == 'r'] self.assertEqual(10, len(writers)) self.assertEqual(5, len(readers)) def test_writer_chaotic(self): lock = lock_utils.ReaderWriterLock() activated = collections.deque() def chaotic_writer(blow_up): with lock.write_lock(): if blow_up: raise RuntimeError("Broken") else: activated.append(lock.owner) def happy_reader(): with lock.read_lock(): activated.append(lock.owner) with futures.ThreadPoolExecutor(max_workers=20) as e: for i in range(0, 20): if i % 2 == 0: e.submit(chaotic_writer, blow_up=bool(i % 4 == 0)) else: e.submit(happy_reader) writers = [a for a in activated if a == 'w'] readers = [a for a in activated if a == 'r'] self.assertEqual(5, len(writers)) self.assertEqual(10, len(readers)) def test_single_reader_writer(self): results = [] lock = lock_utils.ReaderWriterLock() with lock.read_lock(): self.assertTrue(lock.is_reader()) self.assertEqual(0, len(results)) with lock.write_lock(): results.append(1) self.assertTrue(lock.is_writer()) with lock.read_lock(): self.assertTrue(lock.is_reader()) self.assertEqual(1, len(results)) self.assertFalse(lock.is_reader()) self.assertFalse(lock.is_writer()) def test_reader_to_writer(self): lock = lock_utils.ReaderWriterLock() def writer_func(): with lock.write_lock(): pass with lock.read_lock(): self.assertRaises(RuntimeError, writer_func) self.assertFalse(lock.is_writer()) self.assertFalse(lock.is_reader()) self.assertFalse(lock.is_writer()) def test_writer_to_reader(self): lock = lock_utils.ReaderWriterLock() def reader_func(): with lock.read_lock(): self.assertTrue(lock.is_writer()) self.assertTrue(lock.is_reader()) with lock.write_lock(): self.assertIsNone(reader_func()) self.assertFalse(lock.is_reader()) self.assertFalse(lock.is_reader()) self.assertFalse(lock.is_writer()) def test_double_writer(self): lock = lock_utils.ReaderWriterLock() with lock.write_lock(): self.assertFalse(lock.is_reader()) self.assertTrue(lock.is_writer()) with lock.write_lock(): self.assertTrue(lock.is_writer()) self.assertTrue(lock.is_writer()) self.assertFalse(lock.is_reader()) self.assertFalse(lock.is_writer()) def test_double_reader(self): lock = lock_utils.ReaderWriterLock() with lock.read_lock(): self.assertTrue(lock.is_reader()) self.assertFalse(lock.is_writer()) with lock.read_lock(): self.assertTrue(lock.is_reader()) self.assertTrue(lock.is_reader()) self.assertFalse(lock.is_reader()) self.assertFalse(lock.is_writer()) def test_multi_reader_multi_writer(self): writer_times, reader_times = _spawn_variation(10, 10) self.assertEqual(10, len(writer_times)) self.assertEqual(10, len(reader_times)) for (start, stop) in writer_times: self.assertEqual(0, _find_overlaps(reader_times, start, stop)) self.assertEqual(1, _find_overlaps(writer_times, start, stop)) for (start, stop) in reader_times: self.assertEqual(0, _find_overlaps(writer_times, start, stop)) def test_multi_reader_single_writer(self): writer_times, reader_times = _spawn_variation(9, 1) self.assertEqual(1, len(writer_times)) self.assertEqual(9, len(reader_times)) start, stop = writer_times[0] self.assertEqual(0, _find_overlaps(reader_times, start, stop)) def test_multi_writer(self): writer_times, reader_times = _spawn_variation(0, 10) self.assertEqual(10, len(writer_times)) self.assertEqual(0, len(reader_times)) for (start, stop) in writer_times: self.assertEqual(1, _find_overlaps(writer_times, start, stop))