GreenTransportBase: implement half-closability and write() that blocks until data is sent.
* add loseWriteConnection() method * add async_write() that does not wait (synonim for transport.write()) * GreenTransport's protocols now required to implement IHalfClosable interface * GreenTransport use PullProducer instead of PushProducer * add test for loseWriteConnection and remove test that does not work anymore
This commit is contained in:
@@ -20,18 +20,62 @@
|
||||
# THE SOFTWARE.
|
||||
|
||||
"""Basic twisted protocols converted to synchronous mode"""
|
||||
import sys
|
||||
from twisted.internet.protocol import Protocol as twistedProtocol
|
||||
from twisted.internet.error import ConnectionDone
|
||||
from twisted.internet.protocol import Factory, ClientFactory
|
||||
from twisted.python import failure
|
||||
|
||||
from eventlet.api import spawn
|
||||
from twisted.internet.interfaces import IHalfCloseableProtocol
|
||||
from zope.interface import implements
|
||||
|
||||
from eventlet import proc
|
||||
from eventlet.api import getcurrent
|
||||
from eventlet.coros import queue, event
|
||||
|
||||
|
||||
class ReadConnectionDone(ConnectionDone):
|
||||
pass
|
||||
|
||||
class WriteConnectionDone(ConnectionDone):
|
||||
pass
|
||||
|
||||
class ValueQueue(queue):
|
||||
|
||||
def wait(self):
|
||||
"""The difference from queue.wait: if there is an only item in the
|
||||
queue and it is an exception, raise it, but keep it in the queue, so
|
||||
that future calls to wait() will raise it again.
|
||||
"""
|
||||
self.sem.acquire()
|
||||
if self.has_final_error():
|
||||
# the last item, which is an exception, raise without emptying the queue
|
||||
self.sem.release()
|
||||
getcurrent().throw(*self.items[0][1])
|
||||
else:
|
||||
result, exc = self.items.popleft()
|
||||
if exc is not None:
|
||||
getcurrent().throw(*exc)
|
||||
return result
|
||||
|
||||
def has_final_error(self):
|
||||
return len(self.items)==1 and self.items[0][1] is not None
|
||||
|
||||
|
||||
class Event(event):
|
||||
|
||||
def send(self, value, exc=None):
|
||||
if self.ready():
|
||||
self.reset()
|
||||
return event.send(self, value, exc)
|
||||
|
||||
def send_exception(self, *throw_args):
|
||||
if self.ready():
|
||||
self.reset()
|
||||
return event.send_exception(self, *throw_args)
|
||||
|
||||
|
||||
class Producer2Event(object):
|
||||
|
||||
# implements IPushProducer
|
||||
# implements IPullProducer
|
||||
|
||||
def __init__(self, event):
|
||||
self.event = event
|
||||
@@ -39,48 +83,78 @@ class Producer2Event(object):
|
||||
def resumeProducing(self):
|
||||
self.event.send(1)
|
||||
|
||||
def pauseProducing(self):
|
||||
self.event.reset()
|
||||
|
||||
def stopProducing(self):
|
||||
del self.event
|
||||
|
||||
|
||||
class GreenTransportBase(object):
|
||||
|
||||
write_event = None
|
||||
transportBufferSize = None
|
||||
|
||||
def __init__(self, transportBufferSize=None):
|
||||
if transportBufferSize is not None:
|
||||
self.transportBufferSize = transportBufferSize
|
||||
self._error_event = event()
|
||||
self._queue = queue()
|
||||
self._read_disconnected_event = Event()
|
||||
self._write_disconnected_event = Event()
|
||||
self._write_event = Event()
|
||||
|
||||
def build_protocol(self):
|
||||
# note to subclassers: self._queue must have send and send_exception that never block
|
||||
self._queue = queue()
|
||||
protocol = self.protocol_class(self._queue)
|
||||
protocol = self.protocol_class(self)
|
||||
return protocol
|
||||
|
||||
def _got_transport(self, transport):
|
||||
self._queue.send(transport)
|
||||
|
||||
def _got_data(self, data):
|
||||
self._queue.send(data)
|
||||
|
||||
def _connectionLost(self, reason):
|
||||
self._read_disconnected_event.send(reason.value)
|
||||
self._write_disconnected_event.send(reason.value)
|
||||
self._queue.send_exception(reason.value)
|
||||
self._write_event.send_exception(reason.value)
|
||||
|
||||
def _readConnectionLost(self):
|
||||
self._read_disconnected_event.send(ReadConnectionDone)
|
||||
self._queue.send_exception(ReadConnectionDone)
|
||||
|
||||
def _writeConnectionLost(self):
|
||||
self._write_disconnected_event.send(WriteConnectionDone)
|
||||
self._write_event.send_exception(WriteConnectionDone)
|
||||
|
||||
def _wait(self):
|
||||
if self._read_disconnected_event.ready():
|
||||
if self._queue:
|
||||
return self._queue.wait()
|
||||
else:
|
||||
raise self._read_disconnected_event.wait()
|
||||
self.resumeProducing()
|
||||
try:
|
||||
return self._queue.wait()
|
||||
except:
|
||||
if self._error_event is not None:
|
||||
self._error_event.send(None)
|
||||
raise
|
||||
finally:
|
||||
self.pauseProducing()
|
||||
|
||||
def write(self, data):
|
||||
if self._write_disconnected_event.ready():
|
||||
raise self._write_disconnected_event.wait()
|
||||
self._write_event.reset()
|
||||
self.transport.write(data)
|
||||
self._write_event.wait()
|
||||
|
||||
def async_write(self, data):
|
||||
self.transport.write(data)
|
||||
if self.write_event is not None:
|
||||
self.write_event.wait()
|
||||
|
||||
def loseConnection(self):
|
||||
self.transport.unregisterProducer()
|
||||
self.transport.loseConnection()
|
||||
self._error_event.wait()
|
||||
self._read_disconnected_event.wait()
|
||||
self._write_disconnected_event.wait()
|
||||
|
||||
def loseWriteConnection(self):
|
||||
self.transport.unregisterProducer()
|
||||
self.transport.loseWriteConnection()
|
||||
self._write_disconnected_event.wait()
|
||||
|
||||
def __getattr__(self, item):
|
||||
if item=='transport':
|
||||
@@ -116,25 +190,30 @@ class GreenTransportBase(object):
|
||||
if self.transportBufferSize is not None:
|
||||
transport.bufferSize = self.transportBufferSize
|
||||
self._init_transport_producer()
|
||||
if self.write_event is None:
|
||||
self.write_event = event()
|
||||
self.write_event.send(1)
|
||||
transport.registerProducer(Producer2Event(self.write_event), True)
|
||||
transport.registerProducer(Producer2Event(self._write_event), False)
|
||||
|
||||
|
||||
class Protocol(twistedProtocol):
|
||||
|
||||
def __init__(self, queue):
|
||||
self._queue = queue
|
||||
implements(IHalfCloseableProtocol)
|
||||
|
||||
def __init__(self, recepient):
|
||||
self._recepient = recepient
|
||||
|
||||
def connectionMade(self):
|
||||
self._queue.send(self.transport)
|
||||
self._recepient._got_transport(self.transport)
|
||||
|
||||
def dataReceived(self, data):
|
||||
self._queue.send(data)
|
||||
self._recepient._got_data(data)
|
||||
|
||||
def connectionLost(self, reason):
|
||||
self._queue.send_exception(reason.type, reason.value, reason.tb)
|
||||
del self._queue
|
||||
self._recepient._connectionLost(reason)
|
||||
|
||||
def readConnectionLost(self):
|
||||
self._recepient._readConnectionLost()
|
||||
|
||||
def writeConnectionLost(self):
|
||||
self._recepient._writeConnectionLost()
|
||||
|
||||
|
||||
class UnbufferedTransport(GreenTransportBase):
|
||||
@@ -148,16 +227,12 @@ class UnbufferedTransport(GreenTransportBase):
|
||||
Return '' if connection was closed cleanly, raise the exception if it was closed
|
||||
in a non clean fashion. After that all successive calls return ''.
|
||||
"""
|
||||
if self._queue is None:
|
||||
if self._read_disconnected_event.ready():
|
||||
return ''
|
||||
try:
|
||||
return self._wait()
|
||||
except ConnectionDone:
|
||||
self._queue = None
|
||||
return ''
|
||||
except:
|
||||
self._queue = None
|
||||
raise
|
||||
|
||||
def read(self):
|
||||
"""Read the data from the socket until the connection is closed cleanly.
|
||||
@@ -192,45 +267,31 @@ class GreenTransport(GreenTransportBase):
|
||||
_buffer = ''
|
||||
_error = None
|
||||
|
||||
def _wait(self):
|
||||
# don't pause/resume producer here; read and recv methods will do it themselves
|
||||
try:
|
||||
return self._queue.wait()
|
||||
except:
|
||||
self._error_event.send(None)
|
||||
raise
|
||||
|
||||
def read(self, size=-1):
|
||||
"""Read size bytes or until EOF"""
|
||||
if self._queue is not None:
|
||||
resumed = False
|
||||
if not self._read_disconnected_event.ready():
|
||||
try:
|
||||
try:
|
||||
while len(self._buffer) < size or size < 0:
|
||||
if not resumed:
|
||||
self.resumeProducing()
|
||||
resumed = True
|
||||
self._buffer += self._wait()
|
||||
except ConnectionDone:
|
||||
self._queue = None
|
||||
except:
|
||||
self._queue = None
|
||||
self._error = sys.exc_info()
|
||||
finally:
|
||||
if resumed:
|
||||
self.pauseProducing()
|
||||
while len(self._buffer) < size or size < 0:
|
||||
self._buffer += self._wait()
|
||||
except ConnectionDone:
|
||||
pass
|
||||
except:
|
||||
if not self._read_disconnected_event.has_exception():
|
||||
raise
|
||||
if size>=0:
|
||||
result, self._buffer = self._buffer[:size], self._buffer[size:]
|
||||
else:
|
||||
result, self._buffer = self._buffer, ''
|
||||
if not result and self._error is not None:
|
||||
error, self._error = self._error, None
|
||||
raise error[0], error[1], error[2]
|
||||
if not result and self._read_disconnected_event.has_exception():
|
||||
try:
|
||||
self._read_disconnected_event.wait()
|
||||
except ConnectionDone:
|
||||
pass
|
||||
return result
|
||||
|
||||
def recv(self, buflen=None):
|
||||
"""Receive a single chunk of undefined size but no bigger than buflen"""
|
||||
if self._queue is not None and not self._buffer:
|
||||
if not self._read_disconnected_event.ready():
|
||||
self.resumeProducing()
|
||||
try:
|
||||
try:
|
||||
@@ -238,20 +299,21 @@ class GreenTransport(GreenTransportBase):
|
||||
#print 'received %r' % recvd
|
||||
self._buffer += recvd
|
||||
except ConnectionDone:
|
||||
self._queue = None
|
||||
pass
|
||||
except:
|
||||
self._queue = None
|
||||
self._error = sys.exc_info()
|
||||
if not self._read_disconnected_event.has_exception():
|
||||
raise
|
||||
finally:
|
||||
self.pauseProducing()
|
||||
if buflen is None:
|
||||
result, self._buffer = self._buffer, ''
|
||||
else:
|
||||
result, self._buffer = self._buffer[:buflen], self._buffer[buflen:]
|
||||
if not result and self._error is not None:
|
||||
error = self._error
|
||||
self._error = None
|
||||
raise error[0], error[1], error[2]
|
||||
if not result and self._read_disconnected_event.has_exception():
|
||||
try:
|
||||
self._read_disconnected_event.wait()
|
||||
except ConnectionDone:
|
||||
pass
|
||||
return result
|
||||
|
||||
# iterator protocol:
|
||||
|
@@ -23,19 +23,36 @@ from twisted.protocols import basic
|
||||
from twisted.internet.error import ConnectionDone
|
||||
from eventlet.twistedutil.protocol import GreenTransportBase
|
||||
|
||||
from twisted.internet.interfaces import IHalfCloseableProtocol
|
||||
from zope.interface import implements
|
||||
|
||||
class LineOnlyReceiver(basic.LineOnlyReceiver):
|
||||
|
||||
def __init__(self, queue):
|
||||
self._queue = queue
|
||||
implements(IHalfCloseableProtocol)
|
||||
|
||||
def __init__(self, recepient):
|
||||
self._recepient = recepient
|
||||
|
||||
def connectionMade(self):
|
||||
self._queue.send(self.transport)
|
||||
|
||||
def lineReceived(self, line):
|
||||
self._queue.send(line)
|
||||
#print '%r made' % self
|
||||
self._recepient._got_transport(self.transport)
|
||||
|
||||
def connectionLost(self, reason):
|
||||
self._queue.send_exception(reason.type, reason.value, reason.tb)
|
||||
#print '%r conn lost %r' % (self, reason)
|
||||
self._recepient._connectionLost(reason)
|
||||
|
||||
def readConnectionLost(self):
|
||||
#print '%r read conn lost' % self
|
||||
self._recepient._readConnectionLost()
|
||||
|
||||
def writeConnectionLost(self):
|
||||
#print '%r wr conn lost' % self
|
||||
self._recepient._writeConnectionLost()
|
||||
|
||||
def lineReceived(self, line):
|
||||
#print '%r line received %r' % (self, line)
|
||||
self._recepient._got_data(line)
|
||||
|
||||
|
||||
class LineOnlyReceiverTransport(GreenTransportBase):
|
||||
|
||||
|
@@ -12,11 +12,13 @@ from twisted.internet import reactor
|
||||
# read from TCP connection
|
||||
conn = GreenClientCreator(reactor).connectTCP('www.google.com', 80)
|
||||
conn.write('GET / HTTP/1.0\r\n\r\n')
|
||||
conn.loseWriteConnection()
|
||||
print conn.read()
|
||||
|
||||
# read from SSL connection line by line
|
||||
conn = GreenClientCreator(reactor, LineOnlyReceiverTransport).connectSSL('sf.net', 443, ssl.ClientContextFactory())
|
||||
conn.write('GET / HTTP/1.0\r\n\r\n')
|
||||
#conn.loseWriteConnection()
|
||||
try:
|
||||
for num, line in enumerate(conn):
|
||||
print '%3s %r' % (num, line)
|
||||
|
@@ -52,6 +52,7 @@ class Chat:
|
||||
else:
|
||||
print peer, 'connection done'
|
||||
finally:
|
||||
conn.loseConnection()
|
||||
self.participants.remove(conn)
|
||||
|
||||
print __doc__
|
||||
|
@@ -45,8 +45,11 @@ From-Path: msrps://alice.example.com:9892/98cjs;tcp
|
||||
-------49fh$
|
||||
""".replace('\n', '\r\n')
|
||||
|
||||
print 'Sending:\n%s' % request
|
||||
conn.write(request)
|
||||
#conn.loseWriteConnection()
|
||||
print 'Received:'
|
||||
for x in conn:
|
||||
print x
|
||||
print repr(x)
|
||||
if '-------' in x:
|
||||
break
|
||||
|
@@ -20,10 +20,9 @@
|
||||
# THE SOFTWARE.
|
||||
|
||||
from twisted.internet import reactor
|
||||
from greentest import exit_unless_twisted
|
||||
from greentest import exit_unless_twisted, LimitedTestCase
|
||||
exit_unless_twisted()
|
||||
|
||||
import sys
|
||||
import unittest
|
||||
from twisted.internet.error import ConnectionDone
|
||||
|
||||
@@ -116,7 +115,6 @@ class TestGreenTransport(TestUnbufferedTransport):
|
||||
self.conn.write('hello\r\n')
|
||||
self.assertEqual(self.conn.read(9), 'you said ')
|
||||
self.assertEqual(self.conn.read(999), 'hello. BYE')
|
||||
self.assertEqual(None, self.conn._queue)
|
||||
self.assertEqual(self.conn.read(9), '')
|
||||
self.assertEqual(self.conn.read(1), '')
|
||||
self.assertEqual(self.conn.recv(9), '')
|
||||
@@ -156,26 +154,25 @@ class TestGreenTransport(TestUnbufferedTransport):
|
||||
class TestGreenTransport_bufsize1(TestGreenTransport):
|
||||
transportBufferSize = 1
|
||||
|
||||
class TestGreenTransportError(TestCase):
|
||||
setup_server = setup_server_SpawnFactory
|
||||
gtransportClass = pr.GreenTransport
|
||||
|
||||
def test_read_error(self):
|
||||
self.conn.write('hello\r\n')
|
||||
sleep(DELAY*1.5) # make sure the rest of data arrives
|
||||
try:
|
||||
1/0
|
||||
except:
|
||||
#self.conn.loseConnection(failure.Failure()) # does not work, why?
|
||||
spawn(self.conn._queue.send_exception, *sys.exc_info())
|
||||
self.assertEqual(self.conn.read(9), 'you said ')
|
||||
self.assertEqual(self.conn.read(7), 'hello. ')
|
||||
self.assertEqual(self.conn.read(9), 'BYE')
|
||||
self.assertRaises(ZeroDivisionError, self.conn.read, 9)
|
||||
self.assertEqual(None, self.conn._queue)
|
||||
self.assertEqual(self.conn.read(1), '')
|
||||
self.assertEqual(self.conn.read(1), '')
|
||||
|
||||
# class TestGreenTransportError(TestCase):
|
||||
# setup_server = setup_server_SpawnFactory
|
||||
# gtransportClass = pr.GreenTransport
|
||||
#
|
||||
# def test_read_error(self):
|
||||
# self.conn.write('hello\r\n')
|
||||
# sleep(DELAY*1.5) # make sure the rest of data arrives
|
||||
# try:
|
||||
# 1/0
|
||||
# except:
|
||||
# #self.conn.loseConnection(failure.Failure()) # does not work, why?
|
||||
# spawn(self.conn._queue.send_exception, *sys.exc_info())
|
||||
# self.assertEqual(self.conn.read(9), 'you said ')
|
||||
# self.assertEqual(self.conn.read(7), 'hello. ')
|
||||
# self.assertEqual(self.conn.read(9), 'BYE')
|
||||
# self.assertRaises(ZeroDivisionError, self.conn.read, 9)
|
||||
# self.assertEqual(self.conn.read(1), '')
|
||||
# self.assertEqual(self.conn.read(1), '')
|
||||
#
|
||||
# def test_recv_error(self):
|
||||
# self.conn.write('hello')
|
||||
# self.assertEqual('you said hello. ', self.conn.recv())
|
||||
@@ -187,11 +184,57 @@ class TestGreenTransportError(TestCase):
|
||||
# spawn(self.conn._queue.send_exception, *sys.exc_info())
|
||||
# self.assertEqual('BYE', self.conn.recv())
|
||||
# self.assertRaises(ZeroDivisionError, self.conn.recv, 9)
|
||||
# self.assertEqual(None, self.conn._queue)
|
||||
# self.assertEqual('', self.conn.recv(1))
|
||||
# self.assertEqual('', self.conn.recv())
|
||||
#
|
||||
|
||||
class TestHalfClose_TCP(LimitedTestCase):
|
||||
|
||||
def _test_server(self, conn):
|
||||
conn.write('hello')
|
||||
conn.loseWriteConnection()
|
||||
self.assertRaises(pr.WriteConnectionDone, conn.write, 'hey')
|
||||
data = conn.read()
|
||||
self.assertEqual('bye', data)
|
||||
conn.loseConnection()
|
||||
self.assertRaises(ConnectionDone, conn._wait)
|
||||
self.check.append('server')
|
||||
|
||||
def setUp(self):
|
||||
LimitedTestCase.setUp(self)
|
||||
self.factory = pr.SpawnFactory(self._test_server)
|
||||
self.port = reactor.listenTCP(0, self.factory)
|
||||
self.conn = pr.GreenClientCreator(reactor).connectTCP('localhost', self.port.getHost().port)
|
||||
self.port.stopListening()
|
||||
self.check = []
|
||||
|
||||
def test(self):
|
||||
conn = self.conn
|
||||
data = conn.read()
|
||||
self.assertEqual('hello', data)
|
||||
conn.write('bye')
|
||||
conn.loseWriteConnection()
|
||||
self.assertRaises(pr.WriteConnectionDone, conn.write, 'hoy')
|
||||
self.factory.waitall()
|
||||
self.assertRaises(ConnectionDone, conn._wait)
|
||||
assert self.check == ['server']
|
||||
|
||||
class TestHalfClose_TLS(TestHalfClose_TCP):
|
||||
|
||||
def setUp(self):
|
||||
LimitedTestCase.setUp(self)
|
||||
from gnutls.crypto import X509PrivateKey, X509Certificate
|
||||
from gnutls.interfaces.twisted import X509Credentials
|
||||
cert = X509Certificate(open('gnutls_valid.crt').read())
|
||||
key = X509PrivateKey(open('gnutls_valid.key').read())
|
||||
server_credentials = X509Credentials(cert, key)
|
||||
self.factory = pr.SpawnFactory(self._test_server)
|
||||
self.port = reactor.listenTLS(0, self.factory, server_credentials)
|
||||
self.conn = pr.GreenClientCreator(reactor).connectTLS('localhost', self.port.getHost().port, X509Credentials())
|
||||
self.port.stopListening()
|
||||
self.check = []
|
||||
|
||||
|
||||
if socket is not None:
|
||||
|
||||
class TestUnbufferedTransport_socketserver(TestUnbufferedTransport):
|
||||
@@ -232,6 +275,7 @@ try:
|
||||
import gnutls.interfaces.twisted
|
||||
except ImportError:
|
||||
del TestTLSError
|
||||
del TestHalfClose_TLS
|
||||
|
||||
if __name__=='__main__':
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user