reimplement BoundedSemaphore using two Semaphores; this resulted in a much more readable implementation

- queue(0) should work well as a channel now
- add 'balance' property, similar to that of channel's
This commit is contained in:
Denis Bilenko
2009-03-29 23:02:52 +07:00
parent e12619f438
commit ada0811e7a
2 changed files with 26 additions and 49 deletions

View File

@@ -335,74 +335,51 @@ class BoundedSemaphore(object):
if count > limit: if count > limit:
# accidentally, this also catches the case when limit is None # accidentally, this also catches the case when limit is None
raise ValueError("'count' cannot be more than 'limit'") raise ValueError("'count' cannot be more than 'limit'")
self.counter = count self.lower_bound = Semaphore(count)
self.limit = limit self.upper_bound = Semaphore(limit-count)
self._acquire_waiters = {}
self._release_waiters = {}
def __str__(self): def __str__(self):
params = (self.__class__.__name__, hex(id(self)), self.counter, self.limit) params = (self.__class__.__name__, hex(id(self)), self.lower_bound.counter, self.upper_bound.counter)
return '<%s at %s %r/%r>' % params return '<%s at %s %r/%r>' % params
def locked(self): def locked(self):
return self.counter <= 0 return self.lower_bound.locked()
def bounded(self): def bounded(self):
return self.counter >= self.limit return self.upper_bound.locked()
def acquire(self, blocking=True): def acquire(self, blocking=True):
if not blocking and self.locked(): if not blocking and self.locked():
return False return False
if self.counter<=0: self.upper_bound.release()
if self._release_waiters: try:
api.get_hub().schedule_call_global(0, self._do_unlock) return self.lower_bound.acquire()
self._acquire_waiters[api.getcurrent()] = None except:
try: self.upper_bound.counter -= 1
api.get_hub().switch() # using counter directly means that it can be less than zero.
finally: # however I certainly don't need to wait here and I don't seem to have
self._acquire_waiters.pop(api.getcurrent(), None) # a need to care about such inconsistency
self.counter -= 1 raise
if self._release_waiters and self.counter < self.limit:
api.get_hub().schedule_call_global(0, self._do_release)
return True
__enter__ = acquire __enter__ = acquire
def _do_unlock(self):
if self._release_waiters and self._acquire_waiters:
waiter, _unused = self._release_waiters.popitem()
waiter.switch()
self._do_acquire()
def _do_release(self):
if self._release_waiters and self.counter<self.limit:
waiter, _unused = self._release_waiters.popitem()
waiter.switch()
def _do_acquire(self):
if self._acquire_waiters and self.counter>0:
waiter, _unused = self._acquire_waiters.popitem()
waiter.switch()
def release(self, blocking=True): def release(self, blocking=True):
if not blocking and self.bounded(): if not blocking and self.bounded():
return False return False
if self.counter>=self.limit: self.lower_bound.release()
if self._acquire_waiters: try:
api.get_hub().schedule_call_global(0, self._do_unlock) return self.upper_bound.acquire()
self._release_waiters[api.getcurrent()] = None except:
try: self.lower_bound.counter -= 1
api.get_hub().switch() raise
finally:
self._release_waiters.pop(api.getcurrent(), None)
self.counter += 1
if self._acquire_waiters and self.counter > 0:
api.get_hub().schedule_call_global(0, self._do_acquire)
return True
def __exit__(self, typ, val, tb): def __exit__(self, typ, val, tb):
self.release() self.release()
@property
def balance(self):
return self.lower_bound.counter - self.upper_bound.counter
def semaphore(count=0, limit=None): def semaphore(count=0, limit=None):
if limit is None: if limit is None:

View File

@@ -32,13 +32,13 @@ class TestSemaphore(LimitedTestCase):
self.assertEqual(sem.acquire(), True) self.assertEqual(sem.acquire(), True)
api.spawn(sem.release) api.spawn(sem.release)
self.assertEqual(sem.acquire(), True) self.assertEqual(sem.acquire(), True)
self.assertEqual(0, sem.counter) self.assertEqual(-3, sem.balance)
sem.release() sem.release()
sem.release() sem.release()
sem.release() sem.release()
api.spawn(sem.acquire) api.spawn(sem.acquire)
sem.release() sem.release()
self.assertEqual(3, sem.counter) self.assertEqual(3, sem.balance)
def test_bounded_with_zero_limit(self): def test_bounded_with_zero_limit(self):
sem = coros.semaphore(0, 0) sem = coros.semaphore(0, 0)