diff --git a/eventlet/twistedutil/protocol.py b/eventlet/twistedutil/protocol.py index 5a9af18..5ce0f30 100644 --- a/eventlet/twistedutil/protocol.py +++ b/eventlet/twistedutil/protocol.py @@ -20,18 +20,62 @@ # THE SOFTWARE. """Basic twisted protocols converted to synchronous mode""" -import sys from twisted.internet.protocol import Protocol as twistedProtocol from twisted.internet.error import ConnectionDone from twisted.internet.protocol import Factory, ClientFactory -from twisted.python import failure -from eventlet.api import spawn +from twisted.internet.interfaces import IHalfCloseableProtocol +from zope.interface import implements + +from eventlet import proc +from eventlet.api import getcurrent from eventlet.coros import queue, event + +class ReadConnectionDone(ConnectionDone): + pass + +class WriteConnectionDone(ConnectionDone): + pass + +class ValueQueue(queue): + + def wait(self): + """The difference from queue.wait: if there is an only item in the + queue and it is an exception, raise it, but keep it in the queue, so + that future calls to wait() will raise it again. + """ + self.sem.acquire() + if self.has_final_error(): + # the last item, which is an exception, raise without emptying the queue + self.sem.release() + getcurrent().throw(*self.items[0][1]) + else: + result, exc = self.items.popleft() + if exc is not None: + getcurrent().throw(*exc) + return result + + def has_final_error(self): + return len(self.items)==1 and self.items[0][1] is not None + + +class Event(event): + + def send(self, value, exc=None): + if self.ready(): + self.reset() + return event.send(self, value, exc) + + def send_exception(self, *throw_args): + if self.ready(): + self.reset() + return event.send_exception(self, *throw_args) + + class Producer2Event(object): - # implements IPushProducer + # implements IPullProducer def __init__(self, event): self.event = event @@ -39,48 +83,78 @@ class Producer2Event(object): def resumeProducing(self): self.event.send(1) - def pauseProducing(self): - self.event.reset() - def stopProducing(self): del self.event + class GreenTransportBase(object): - write_event = None transportBufferSize = None def __init__(self, transportBufferSize=None): if transportBufferSize is not None: self.transportBufferSize = transportBufferSize - self._error_event = event() + self._queue = queue() + self._read_disconnected_event = Event() + self._write_disconnected_event = Event() + self._write_event = Event() def build_protocol(self): - # note to subclassers: self._queue must have send and send_exception that never block - self._queue = queue() - protocol = self.protocol_class(self._queue) + protocol = self.protocol_class(self) return protocol + def _got_transport(self, transport): + self._queue.send(transport) + + def _got_data(self, data): + self._queue.send(data) + + def _connectionLost(self, reason): + self._read_disconnected_event.send(reason.value) + self._write_disconnected_event.send(reason.value) + self._queue.send_exception(reason.value) + self._write_event.send_exception(reason.value) + + def _readConnectionLost(self): + self._read_disconnected_event.send(ReadConnectionDone) + self._queue.send_exception(ReadConnectionDone) + + def _writeConnectionLost(self): + self._write_disconnected_event.send(WriteConnectionDone) + self._write_event.send_exception(WriteConnectionDone) + def _wait(self): + if self._read_disconnected_event.ready(): + if self._queue: + return self._queue.wait() + else: + raise self._read_disconnected_event.wait() self.resumeProducing() try: return self._queue.wait() - except: - if self._error_event is not None: - self._error_event.send(None) - raise finally: self.pauseProducing() def write(self, data): + if self._write_disconnected_event.ready(): + raise self._write_disconnected_event.wait() + self._write_event.reset() + self.transport.write(data) + self._write_event.wait() + + def async_write(self, data): self.transport.write(data) - if self.write_event is not None: - self.write_event.wait() def loseConnection(self): self.transport.unregisterProducer() self.transport.loseConnection() - self._error_event.wait() + self._read_disconnected_event.wait() + self._write_disconnected_event.wait() + + def loseWriteConnection(self): + self.transport.unregisterProducer() + self.transport.loseWriteConnection() + self._write_disconnected_event.wait() def __getattr__(self, item): if item=='transport': @@ -116,25 +190,30 @@ class GreenTransportBase(object): if self.transportBufferSize is not None: transport.bufferSize = self.transportBufferSize self._init_transport_producer() - if self.write_event is None: - self.write_event = event() - self.write_event.send(1) - transport.registerProducer(Producer2Event(self.write_event), True) + transport.registerProducer(Producer2Event(self._write_event), False) + class Protocol(twistedProtocol): - def __init__(self, queue): - self._queue = queue + implements(IHalfCloseableProtocol) + + def __init__(self, recepient): + self._recepient = recepient def connectionMade(self): - self._queue.send(self.transport) + self._recepient._got_transport(self.transport) def dataReceived(self, data): - self._queue.send(data) + self._recepient._got_data(data) def connectionLost(self, reason): - self._queue.send_exception(reason.type, reason.value, reason.tb) - del self._queue + self._recepient._connectionLost(reason) + + def readConnectionLost(self): + self._recepient._readConnectionLost() + + def writeConnectionLost(self): + self._recepient._writeConnectionLost() class UnbufferedTransport(GreenTransportBase): @@ -148,16 +227,12 @@ class UnbufferedTransport(GreenTransportBase): Return '' if connection was closed cleanly, raise the exception if it was closed in a non clean fashion. After that all successive calls return ''. """ - if self._queue is None: + if self._read_disconnected_event.ready(): return '' try: return self._wait() except ConnectionDone: - self._queue = None return '' - except: - self._queue = None - raise def read(self): """Read the data from the socket until the connection is closed cleanly. @@ -192,45 +267,31 @@ class GreenTransport(GreenTransportBase): _buffer = '' _error = None - def _wait(self): - # don't pause/resume producer here; read and recv methods will do it themselves - try: - return self._queue.wait() - except: - self._error_event.send(None) - raise - def read(self, size=-1): """Read size bytes or until EOF""" - if self._queue is not None: - resumed = False + if not self._read_disconnected_event.ready(): try: - try: - while len(self._buffer) < size or size < 0: - if not resumed: - self.resumeProducing() - resumed = True - self._buffer += self._wait() - except ConnectionDone: - self._queue = None - except: - self._queue = None - self._error = sys.exc_info() - finally: - if resumed: - self.pauseProducing() + while len(self._buffer) < size or size < 0: + self._buffer += self._wait() + except ConnectionDone: + pass + except: + if not self._read_disconnected_event.has_exception(): + raise if size>=0: result, self._buffer = self._buffer[:size], self._buffer[size:] else: result, self._buffer = self._buffer, '' - if not result and self._error is not None: - error, self._error = self._error, None - raise error[0], error[1], error[2] + if not result and self._read_disconnected_event.has_exception(): + try: + self._read_disconnected_event.wait() + except ConnectionDone: + pass return result def recv(self, buflen=None): """Receive a single chunk of undefined size but no bigger than buflen""" - if self._queue is not None and not self._buffer: + if not self._read_disconnected_event.ready(): self.resumeProducing() try: try: @@ -238,20 +299,21 @@ class GreenTransport(GreenTransportBase): #print 'received %r' % recvd self._buffer += recvd except ConnectionDone: - self._queue = None + pass except: - self._queue = None - self._error = sys.exc_info() + if not self._read_disconnected_event.has_exception(): + raise finally: self.pauseProducing() if buflen is None: result, self._buffer = self._buffer, '' else: result, self._buffer = self._buffer[:buflen], self._buffer[buflen:] - if not result and self._error is not None: - error = self._error - self._error = None - raise error[0], error[1], error[2] + if not result and self._read_disconnected_event.has_exception(): + try: + self._read_disconnected_event.wait() + except ConnectionDone: + pass return result # iterator protocol: diff --git a/eventlet/twistedutil/protocols/basic.py b/eventlet/twistedutil/protocols/basic.py index fea16b9..67ec918 100644 --- a/eventlet/twistedutil/protocols/basic.py +++ b/eventlet/twistedutil/protocols/basic.py @@ -23,19 +23,36 @@ from twisted.protocols import basic from twisted.internet.error import ConnectionDone from eventlet.twistedutil.protocol import GreenTransportBase +from twisted.internet.interfaces import IHalfCloseableProtocol +from zope.interface import implements + class LineOnlyReceiver(basic.LineOnlyReceiver): - def __init__(self, queue): - self._queue = queue + implements(IHalfCloseableProtocol) + + def __init__(self, recepient): + self._recepient = recepient def connectionMade(self): - self._queue.send(self.transport) - - def lineReceived(self, line): - self._queue.send(line) + #print '%r made' % self + self._recepient._got_transport(self.transport) def connectionLost(self, reason): - self._queue.send_exception(reason.type, reason.value, reason.tb) + #print '%r conn lost %r' % (self, reason) + self._recepient._connectionLost(reason) + + def readConnectionLost(self): + #print '%r read conn lost' % self + self._recepient._readConnectionLost() + + def writeConnectionLost(self): + #print '%r wr conn lost' % self + self._recepient._writeConnectionLost() + + def lineReceived(self, line): + #print '%r line received %r' % (self, line) + self._recepient._got_data(line) + class LineOnlyReceiverTransport(GreenTransportBase): diff --git a/examples/twisted_client.py b/examples/twisted_client.py index 3f112ca..8ba7749 100644 --- a/examples/twisted_client.py +++ b/examples/twisted_client.py @@ -12,11 +12,13 @@ from twisted.internet import reactor # read from TCP connection conn = GreenClientCreator(reactor).connectTCP('www.google.com', 80) conn.write('GET / HTTP/1.0\r\n\r\n') +conn.loseWriteConnection() print conn.read() # read from SSL connection line by line conn = GreenClientCreator(reactor, LineOnlyReceiverTransport).connectSSL('sf.net', 443, ssl.ClientContextFactory()) conn.write('GET / HTTP/1.0\r\n\r\n') +#conn.loseWriteConnection() try: for num, line in enumerate(conn): print '%3s %r' % (num, line) diff --git a/examples/twisted_server.py b/examples/twisted_server.py index eaf421e..a52895a 100644 --- a/examples/twisted_server.py +++ b/examples/twisted_server.py @@ -52,6 +52,7 @@ class Chat: else: print peer, 'connection done' finally: + conn.loseConnection() self.participants.remove(conn) print __doc__ diff --git a/examples/twisted_srvconnector.py b/examples/twisted_srvconnector.py index 498dc4f..1ce9e00 100644 --- a/examples/twisted_srvconnector.py +++ b/examples/twisted_srvconnector.py @@ -45,8 +45,11 @@ From-Path: msrps://alice.example.com:9892/98cjs;tcp -------49fh$ """.replace('\n', '\r\n') +print 'Sending:\n%s' % request conn.write(request) +#conn.loseWriteConnection() +print 'Received:' for x in conn: - print x + print repr(x) if '-------' in x: break diff --git a/greentest/test__twistedutil_protocol.py b/greentest/test__twistedutil_protocol.py index 7dcf069..88b5a7e 100644 --- a/greentest/test__twistedutil_protocol.py +++ b/greentest/test__twistedutil_protocol.py @@ -20,10 +20,9 @@ # THE SOFTWARE. from twisted.internet import reactor -from greentest import exit_unless_twisted +from greentest import exit_unless_twisted, LimitedTestCase exit_unless_twisted() -import sys import unittest from twisted.internet.error import ConnectionDone @@ -116,7 +115,6 @@ class TestGreenTransport(TestUnbufferedTransport): self.conn.write('hello\r\n') self.assertEqual(self.conn.read(9), 'you said ') self.assertEqual(self.conn.read(999), 'hello. BYE') - self.assertEqual(None, self.conn._queue) self.assertEqual(self.conn.read(9), '') self.assertEqual(self.conn.read(1), '') self.assertEqual(self.conn.recv(9), '') @@ -156,26 +154,25 @@ class TestGreenTransport(TestUnbufferedTransport): class TestGreenTransport_bufsize1(TestGreenTransport): transportBufferSize = 1 -class TestGreenTransportError(TestCase): - setup_server = setup_server_SpawnFactory - gtransportClass = pr.GreenTransport - - def test_read_error(self): - self.conn.write('hello\r\n') - sleep(DELAY*1.5) # make sure the rest of data arrives - try: - 1/0 - except: - #self.conn.loseConnection(failure.Failure()) # does not work, why? - spawn(self.conn._queue.send_exception, *sys.exc_info()) - self.assertEqual(self.conn.read(9), 'you said ') - self.assertEqual(self.conn.read(7), 'hello. ') - self.assertEqual(self.conn.read(9), 'BYE') - self.assertRaises(ZeroDivisionError, self.conn.read, 9) - self.assertEqual(None, self.conn._queue) - self.assertEqual(self.conn.read(1), '') - self.assertEqual(self.conn.read(1), '') - +# class TestGreenTransportError(TestCase): +# setup_server = setup_server_SpawnFactory +# gtransportClass = pr.GreenTransport +# +# def test_read_error(self): +# self.conn.write('hello\r\n') +# sleep(DELAY*1.5) # make sure the rest of data arrives +# try: +# 1/0 +# except: +# #self.conn.loseConnection(failure.Failure()) # does not work, why? +# spawn(self.conn._queue.send_exception, *sys.exc_info()) +# self.assertEqual(self.conn.read(9), 'you said ') +# self.assertEqual(self.conn.read(7), 'hello. ') +# self.assertEqual(self.conn.read(9), 'BYE') +# self.assertRaises(ZeroDivisionError, self.conn.read, 9) +# self.assertEqual(self.conn.read(1), '') +# self.assertEqual(self.conn.read(1), '') +# # def test_recv_error(self): # self.conn.write('hello') # self.assertEqual('you said hello. ', self.conn.recv()) @@ -187,11 +184,57 @@ class TestGreenTransportError(TestCase): # spawn(self.conn._queue.send_exception, *sys.exc_info()) # self.assertEqual('BYE', self.conn.recv()) # self.assertRaises(ZeroDivisionError, self.conn.recv, 9) -# self.assertEqual(None, self.conn._queue) # self.assertEqual('', self.conn.recv(1)) # self.assertEqual('', self.conn.recv()) # +class TestHalfClose_TCP(LimitedTestCase): + + def _test_server(self, conn): + conn.write('hello') + conn.loseWriteConnection() + self.assertRaises(pr.WriteConnectionDone, conn.write, 'hey') + data = conn.read() + self.assertEqual('bye', data) + conn.loseConnection() + self.assertRaises(ConnectionDone, conn._wait) + self.check.append('server') + + def setUp(self): + LimitedTestCase.setUp(self) + self.factory = pr.SpawnFactory(self._test_server) + self.port = reactor.listenTCP(0, self.factory) + self.conn = pr.GreenClientCreator(reactor).connectTCP('localhost', self.port.getHost().port) + self.port.stopListening() + self.check = [] + + def test(self): + conn = self.conn + data = conn.read() + self.assertEqual('hello', data) + conn.write('bye') + conn.loseWriteConnection() + self.assertRaises(pr.WriteConnectionDone, conn.write, 'hoy') + self.factory.waitall() + self.assertRaises(ConnectionDone, conn._wait) + assert self.check == ['server'] + +class TestHalfClose_TLS(TestHalfClose_TCP): + + def setUp(self): + LimitedTestCase.setUp(self) + from gnutls.crypto import X509PrivateKey, X509Certificate + from gnutls.interfaces.twisted import X509Credentials + cert = X509Certificate(open('gnutls_valid.crt').read()) + key = X509PrivateKey(open('gnutls_valid.key').read()) + server_credentials = X509Credentials(cert, key) + self.factory = pr.SpawnFactory(self._test_server) + self.port = reactor.listenTLS(0, self.factory, server_credentials) + self.conn = pr.GreenClientCreator(reactor).connectTLS('localhost', self.port.getHost().port, X509Credentials()) + self.port.stopListening() + self.check = [] + + if socket is not None: class TestUnbufferedTransport_socketserver(TestUnbufferedTransport): @@ -232,6 +275,7 @@ try: import gnutls.interfaces.twisted except ImportError: del TestTLSError + del TestHalfClose_TLS if __name__=='__main__': unittest.main()