semaphore: support timeout for acquire

Fixes https://bitbucket.org/eventlet/eventlet/issue/147/semaphoresemaphore-should-support-a
This commit is contained in:
Justin Patrin
2013-06-28 16:56:53 +04:00
committed by Sergey Shepelev
parent 25812fca81
commit 747b753a20
2 changed files with 75 additions and 19 deletions

View File

@@ -1,7 +1,11 @@
from __future__ import with_statement
from eventlet import greenthread from eventlet import greenthread
from eventlet import hubs from eventlet import hubs
from eventlet.timeout import Timeout
class Semaphore(object): class Semaphore(object):
"""An unbounded semaphore. """An unbounded semaphore.
Optionally initialize with a resource *count*, then :meth:`acquire` and Optionally initialize with a resource *count*, then :meth:`acquire` and
:meth:`release` resources as needed. Attempting to :meth:`acquire` when :meth:`release` resources as needed. Attempting to :meth:`acquire` when
@@ -17,6 +21,13 @@ class Semaphore(object):
do_some_stuff() do_some_stuff()
If not specified, *value* defaults to 1. If not specified, *value* defaults to 1.
It is possible to limit acquire time::
sem = Semaphore()
ok = sem.acquire(timeout=0.1)
# True if acquired, False if timed out.
""" """
def __init__(self, value=1): def __init__(self, value=1):
@@ -36,15 +47,17 @@ class Semaphore(object):
return '<%s c=%s _w[%s]>' % params return '<%s c=%s _w[%s]>' % params
def locked(self): def locked(self):
"""Returns true if a call to acquire would block.""" """Returns true if a call to acquire would block.
"""
return self.counter <= 0 return self.counter <= 0
def bounded(self): def bounded(self):
"""Returns False; for consistency with """Returns False; for consistency with
:class:`~eventlet.semaphore.CappedSemaphore`.""" :class:`~eventlet.semaphore.CappedSemaphore`.
"""
return False return False
def acquire(self, blocking=True): def acquire(self, blocking=True, timeout=None):
"""Acquire a semaphore. """Acquire a semaphore.
When invoked without arguments: if the internal counter is larger than When invoked without arguments: if the internal counter is larger than
@@ -61,12 +74,24 @@ class Semaphore(object):
When invoked with blocking set to false, do not block. If a call without When invoked with blocking set to false, do not block. If a call without
an argument would block, return false immediately; otherwise, do the an argument would block, return false immediately; otherwise, do the
same thing as when called without arguments, and return true.""" same thing as when called without arguments, and return true.
"""
if not blocking and timeout is not None:
raise ValueError("can't specify timeout for non-blocking acquire")
if not blocking and self.locked(): if not blocking and self.locked():
return False return False
if self.counter <= 0: if self.counter <= 0:
self._waiters.add(greenthread.getcurrent()) self._waiters.add(greenthread.getcurrent())
try: try:
if timeout is not None:
ok = False
with Timeout(timeout, False):
while self.counter <= 0:
hubs.get_hub().switch()
ok = True
if not ok:
return False
else:
while self.counter <= 0: while self.counter <= 0:
hubs.get_hub().switch() hubs.get_hub().switch()
finally: finally:
@@ -83,14 +108,15 @@ class Semaphore(object):
larger than zero again, wake up that thread. larger than zero again, wake up that thread.
The *blocking* argument is for consistency with CappedSemaphore and is The *blocking* argument is for consistency with CappedSemaphore and is
ignored""" ignored
"""
self.counter += 1 self.counter += 1
if self._waiters: if self._waiters:
hubs.get_hub().schedule_call_global(0, self._do_acquire) hubs.get_hub().schedule_call_global(0, self._do_acquire)
return True return True
def _do_acquire(self): def _do_acquire(self):
if self._waiters and self.counter>0: if self._waiters and self.counter > 0:
waiter = self._waiters.pop() waiter = self._waiters.pop()
waiter.switch() waiter.switch()
@@ -115,11 +141,14 @@ class Semaphore(object):
class BoundedSemaphore(Semaphore): class BoundedSemaphore(Semaphore):
"""A bounded semaphore checks to make sure its current value doesn't exceed """A bounded semaphore checks to make sure its current value doesn't exceed
its initial value. If it does, ValueError is raised. In most situations its initial value. If it does, ValueError is raised. In most situations
semaphores are used to guard resources with limited capacity. If the semaphores are used to guard resources with limited capacity. If the
semaphore is released too many times it's a sign of a bug. If not given, semaphore is released too many times it's a sign of a bug. If not given,
*value* defaults to 1.""" *value* defaults to 1.
"""
def __init__(self, value=1): def __init__(self, value=1):
super(BoundedSemaphore, self).__init__(value) super(BoundedSemaphore, self).__init__(value)
self.original_counter = value self.original_counter = value
@@ -131,12 +160,15 @@ class BoundedSemaphore(Semaphore):
larger than zero again, wake up that thread. larger than zero again, wake up that thread.
The *blocking* argument is for consistency with :class:`CappedSemaphore` The *blocking* argument is for consistency with :class:`CappedSemaphore`
and is ignored""" and is ignored
"""
if self.counter >= self.original_counter: if self.counter >= self.original_counter:
raise ValueError, "Semaphore released too many times" raise ValueError, "Semaphore released too many times"
return super(BoundedSemaphore, self).release(blocking) return super(BoundedSemaphore, self).release(blocking)
class CappedSemaphore(object): class CappedSemaphore(object):
"""A blockingly bounded semaphore. """A blockingly bounded semaphore.
Optionally initialize with a resource *count*, then :meth:`acquire` and Optionally initialize with a resource *count*, then :meth:`acquire` and
@@ -158,6 +190,7 @@ class CappedSemaphore(object):
with sem: with sem:
do_some_stuff() do_some_stuff()
""" """
def __init__(self, count, limit): def __init__(self, count, limit):
if count < 0: if count < 0:
raise ValueError("CappedSemaphore must be initialized with a " raise ValueError("CappedSemaphore must be initialized with a "
@@ -166,7 +199,7 @@ class CappedSemaphore(object):
# accidentally, this also catches the case when limit is None # accidentally, this also catches the case when limit is None
raise ValueError("'count' cannot be more than 'limit'") raise ValueError("'count' cannot be more than 'limit'")
self.lower_bound = Semaphore(count) self.lower_bound = Semaphore(count)
self.upper_bound = Semaphore(limit-count) self.upper_bound = Semaphore(limit - count)
def __repr__(self): def __repr__(self):
params = (self.__class__.__name__, hex(id(self)), params = (self.__class__.__name__, hex(id(self)),
@@ -179,11 +212,13 @@ class CappedSemaphore(object):
return '<%s b=%s l=%s u=%s>' % params return '<%s b=%s l=%s u=%s>' % params
def locked(self): def locked(self):
"""Returns true if a call to acquire would block.""" """Returns true if a call to acquire would block.
"""
return self.lower_bound.locked() return self.lower_bound.locked()
def bounded(self): def bounded(self):
"""Returns true if a call to release would block.""" """Returns true if a call to release would block.
"""
return self.upper_bound.locked() return self.upper_bound.locked()
def acquire(self, blocking=True): def acquire(self, blocking=True):
@@ -203,7 +238,8 @@ class CappedSemaphore(object):
When invoked with blocking set to false, do not block. If a call without When invoked with blocking set to false, do not block. If a call without
an argument would block, return false immediately; otherwise, do the an argument would block, return false immediately; otherwise, do the
same thing as when called without arguments, and return true.""" same thing as when called without arguments, and return true.
"""
if not blocking and self.locked(): if not blocking and self.locked():
return False return False
self.upper_bound.release() self.upper_bound.release()
@@ -225,7 +261,8 @@ class CappedSemaphore(object):
Imagine the docs of :meth:`acquire` here, but with every direction Imagine the docs of :meth:`acquire` here, but with every direction
reversed. When calling this method, it will block if the internal reversed. When calling this method, it will block if the internal
counter is greater than or equal to *limit*.""" counter is greater than or equal to *limit*.
"""
if not blocking and self.bounded(): if not blocking and self.bounded():
return False return False
self.lower_bound.release() self.lower_bound.release()
@@ -247,5 +284,6 @@ class CappedSemaphore(object):
the negative of the number of releases that would be required in order the negative of the number of releases that would be required in order
to make the counter 0 again (one more release would push the counter to to make the counter 0 again (one more release would push the counter to
1 and unblock acquirers). It takes into account how many greenthreads 1 and unblock acquirers). It takes into account how many greenthreads
are currently blocking in :meth:`acquire` and :meth:`release`.""" are currently blocking in :meth:`acquire` and :meth:`release`.
"""
return self.lower_bound.balance - self.upper_bound.balance return self.lower_bound.balance - self.upper_bound.balance

View File

@@ -1,9 +1,13 @@
import time
import unittest import unittest
import eventlet import eventlet
from eventlet import semaphore from eventlet import semaphore
from tests import LimitedTestCase from tests import LimitedTestCase
class TestSemaphore(LimitedTestCase): class TestSemaphore(LimitedTestCase):
def test_bounded(self): def test_bounded(self):
sem = semaphore.CappedSemaphore(2, limit=3) sem = semaphore.CappedSemaphore(2, limit=3)
self.assertEqual(sem.acquire(), True) self.assertEqual(sem.acquire(), True)
@@ -26,6 +30,20 @@ class TestSemaphore(LimitedTestCase):
sem.release() sem.release()
gt.wait() gt.wait()
def test_non_blocking(self):
sem = semaphore.Semaphore(0)
self.assertEqual(sem.acquire(blocking=False), False)
if __name__=='__main__': def test_timeout(self):
sem = semaphore.Semaphore(0)
start = time.time()
self.assertEqual(sem.acquire(timeout=0.1), False)
self.assertTrue(time.time() - start >= 0.1)
def test_timeout_non_blocking(self):
sem = semaphore.Semaphore()
self.assertRaises(ValueError, sem.acquire, blocking=False, timeout=1)
if __name__ == '__main__':
unittest.main() unittest.main()