GreenTransportBase: implement half-closability and write() that blocks until data is sent.

* add loseWriteConnection() method
* add async_write() that does not wait (synonim for transport.write())
* GreenTransport's protocols now required to implement IHalfClosable interface
* GreenTransport use PullProducer instead of PushProducer
* add test for loseWriteConnection and remove test that does not work anymore
This commit is contained in:
Denis Bilenko
2009-02-09 19:25:29 +06:00
parent 30e594475a
commit 16cbfca0a0
6 changed files with 230 additions and 101 deletions

View File

@@ -20,18 +20,62 @@
# THE SOFTWARE.
"""Basic twisted protocols converted to synchronous mode"""
import sys
from twisted.internet.protocol import Protocol as twistedProtocol
from twisted.internet.error import ConnectionDone
from twisted.internet.protocol import Factory, ClientFactory
from twisted.python import failure
from eventlet.api import spawn
from twisted.internet.interfaces import IHalfCloseableProtocol
from zope.interface import implements
from eventlet import proc
from eventlet.api import getcurrent
from eventlet.coros import queue, event
class ReadConnectionDone(ConnectionDone):
pass
class WriteConnectionDone(ConnectionDone):
pass
class ValueQueue(queue):
def wait(self):
"""The difference from queue.wait: if there is an only item in the
queue and it is an exception, raise it, but keep it in the queue, so
that future calls to wait() will raise it again.
"""
self.sem.acquire()
if self.has_final_error():
# the last item, which is an exception, raise without emptying the queue
self.sem.release()
getcurrent().throw(*self.items[0][1])
else:
result, exc = self.items.popleft()
if exc is not None:
getcurrent().throw(*exc)
return result
def has_final_error(self):
return len(self.items)==1 and self.items[0][1] is not None
class Event(event):
def send(self, value, exc=None):
if self.ready():
self.reset()
return event.send(self, value, exc)
def send_exception(self, *throw_args):
if self.ready():
self.reset()
return event.send_exception(self, *throw_args)
class Producer2Event(object):
# implements IPushProducer
# implements IPullProducer
def __init__(self, event):
self.event = event
@@ -39,48 +83,78 @@ class Producer2Event(object):
def resumeProducing(self):
self.event.send(1)
def pauseProducing(self):
self.event.reset()
def stopProducing(self):
del self.event
class GreenTransportBase(object):
write_event = None
transportBufferSize = None
def __init__(self, transportBufferSize=None):
if transportBufferSize is not None:
self.transportBufferSize = transportBufferSize
self._error_event = event()
self._queue = queue()
self._read_disconnected_event = Event()
self._write_disconnected_event = Event()
self._write_event = Event()
def build_protocol(self):
# note to subclassers: self._queue must have send and send_exception that never block
self._queue = queue()
protocol = self.protocol_class(self._queue)
protocol = self.protocol_class(self)
return protocol
def _got_transport(self, transport):
self._queue.send(transport)
def _got_data(self, data):
self._queue.send(data)
def _connectionLost(self, reason):
self._read_disconnected_event.send(reason.value)
self._write_disconnected_event.send(reason.value)
self._queue.send_exception(reason.value)
self._write_event.send_exception(reason.value)
def _readConnectionLost(self):
self._read_disconnected_event.send(ReadConnectionDone)
self._queue.send_exception(ReadConnectionDone)
def _writeConnectionLost(self):
self._write_disconnected_event.send(WriteConnectionDone)
self._write_event.send_exception(WriteConnectionDone)
def _wait(self):
if self._read_disconnected_event.ready():
if self._queue:
return self._queue.wait()
else:
raise self._read_disconnected_event.wait()
self.resumeProducing()
try:
return self._queue.wait()
except:
if self._error_event is not None:
self._error_event.send(None)
raise
finally:
self.pauseProducing()
def write(self, data):
if self._write_disconnected_event.ready():
raise self._write_disconnected_event.wait()
self._write_event.reset()
self.transport.write(data)
self._write_event.wait()
def async_write(self, data):
self.transport.write(data)
if self.write_event is not None:
self.write_event.wait()
def loseConnection(self):
self.transport.unregisterProducer()
self.transport.loseConnection()
self._error_event.wait()
self._read_disconnected_event.wait()
self._write_disconnected_event.wait()
def loseWriteConnection(self):
self.transport.unregisterProducer()
self.transport.loseWriteConnection()
self._write_disconnected_event.wait()
def __getattr__(self, item):
if item=='transport':
@@ -116,25 +190,30 @@ class GreenTransportBase(object):
if self.transportBufferSize is not None:
transport.bufferSize = self.transportBufferSize
self._init_transport_producer()
if self.write_event is None:
self.write_event = event()
self.write_event.send(1)
transport.registerProducer(Producer2Event(self.write_event), True)
transport.registerProducer(Producer2Event(self._write_event), False)
class Protocol(twistedProtocol):
def __init__(self, queue):
self._queue = queue
implements(IHalfCloseableProtocol)
def __init__(self, recepient):
self._recepient = recepient
def connectionMade(self):
self._queue.send(self.transport)
self._recepient._got_transport(self.transport)
def dataReceived(self, data):
self._queue.send(data)
self._recepient._got_data(data)
def connectionLost(self, reason):
self._queue.send_exception(reason.type, reason.value, reason.tb)
del self._queue
self._recepient._connectionLost(reason)
def readConnectionLost(self):
self._recepient._readConnectionLost()
def writeConnectionLost(self):
self._recepient._writeConnectionLost()
class UnbufferedTransport(GreenTransportBase):
@@ -148,16 +227,12 @@ class UnbufferedTransport(GreenTransportBase):
Return '' if connection was closed cleanly, raise the exception if it was closed
in a non clean fashion. After that all successive calls return ''.
"""
if self._queue is None:
if self._read_disconnected_event.ready():
return ''
try:
return self._wait()
except ConnectionDone:
self._queue = None
return ''
except:
self._queue = None
raise
def read(self):
"""Read the data from the socket until the connection is closed cleanly.
@@ -192,45 +267,31 @@ class GreenTransport(GreenTransportBase):
_buffer = ''
_error = None
def _wait(self):
# don't pause/resume producer here; read and recv methods will do it themselves
try:
return self._queue.wait()
except:
self._error_event.send(None)
raise
def read(self, size=-1):
"""Read size bytes or until EOF"""
if self._queue is not None:
resumed = False
if not self._read_disconnected_event.ready():
try:
try:
while len(self._buffer) < size or size < 0:
if not resumed:
self.resumeProducing()
resumed = True
self._buffer += self._wait()
except ConnectionDone:
self._queue = None
except:
self._queue = None
self._error = sys.exc_info()
finally:
if resumed:
self.pauseProducing()
while len(self._buffer) < size or size < 0:
self._buffer += self._wait()
except ConnectionDone:
pass
except:
if not self._read_disconnected_event.has_exception():
raise
if size>=0:
result, self._buffer = self._buffer[:size], self._buffer[size:]
else:
result, self._buffer = self._buffer, ''
if not result and self._error is not None:
error, self._error = self._error, None
raise error[0], error[1], error[2]
if not result and self._read_disconnected_event.has_exception():
try:
self._read_disconnected_event.wait()
except ConnectionDone:
pass
return result
def recv(self, buflen=None):
"""Receive a single chunk of undefined size but no bigger than buflen"""
if self._queue is not None and not self._buffer:
if not self._read_disconnected_event.ready():
self.resumeProducing()
try:
try:
@@ -238,20 +299,21 @@ class GreenTransport(GreenTransportBase):
#print 'received %r' % recvd
self._buffer += recvd
except ConnectionDone:
self._queue = None
pass
except:
self._queue = None
self._error = sys.exc_info()
if not self._read_disconnected_event.has_exception():
raise
finally:
self.pauseProducing()
if buflen is None:
result, self._buffer = self._buffer, ''
else:
result, self._buffer = self._buffer[:buflen], self._buffer[buflen:]
if not result and self._error is not None:
error = self._error
self._error = None
raise error[0], error[1], error[2]
if not result and self._read_disconnected_event.has_exception():
try:
self._read_disconnected_event.wait()
except ConnectionDone:
pass
return result
# iterator protocol:

View File

@@ -23,19 +23,36 @@ from twisted.protocols import basic
from twisted.internet.error import ConnectionDone
from eventlet.twistedutil.protocol import GreenTransportBase
from twisted.internet.interfaces import IHalfCloseableProtocol
from zope.interface import implements
class LineOnlyReceiver(basic.LineOnlyReceiver):
def __init__(self, queue):
self._queue = queue
implements(IHalfCloseableProtocol)
def __init__(self, recepient):
self._recepient = recepient
def connectionMade(self):
self._queue.send(self.transport)
def lineReceived(self, line):
self._queue.send(line)
#print '%r made' % self
self._recepient._got_transport(self.transport)
def connectionLost(self, reason):
self._queue.send_exception(reason.type, reason.value, reason.tb)
#print '%r conn lost %r' % (self, reason)
self._recepient._connectionLost(reason)
def readConnectionLost(self):
#print '%r read conn lost' % self
self._recepient._readConnectionLost()
def writeConnectionLost(self):
#print '%r wr conn lost' % self
self._recepient._writeConnectionLost()
def lineReceived(self, line):
#print '%r line received %r' % (self, line)
self._recepient._got_data(line)
class LineOnlyReceiverTransport(GreenTransportBase):

View File

@@ -12,11 +12,13 @@ from twisted.internet import reactor
# read from TCP connection
conn = GreenClientCreator(reactor).connectTCP('www.google.com', 80)
conn.write('GET / HTTP/1.0\r\n\r\n')
conn.loseWriteConnection()
print conn.read()
# read from SSL connection line by line
conn = GreenClientCreator(reactor, LineOnlyReceiverTransport).connectSSL('sf.net', 443, ssl.ClientContextFactory())
conn.write('GET / HTTP/1.0\r\n\r\n')
#conn.loseWriteConnection()
try:
for num, line in enumerate(conn):
print '%3s %r' % (num, line)

View File

@@ -52,6 +52,7 @@ class Chat:
else:
print peer, 'connection done'
finally:
conn.loseConnection()
self.participants.remove(conn)
print __doc__

View File

@@ -45,8 +45,11 @@ From-Path: msrps://alice.example.com:9892/98cjs;tcp
-------49fh$
""".replace('\n', '\r\n')
print 'Sending:\n%s' % request
conn.write(request)
#conn.loseWriteConnection()
print 'Received:'
for x in conn:
print x
print repr(x)
if '-------' in x:
break

View File

@@ -20,10 +20,9 @@
# THE SOFTWARE.
from twisted.internet import reactor
from greentest import exit_unless_twisted
from greentest import exit_unless_twisted, LimitedTestCase
exit_unless_twisted()
import sys
import unittest
from twisted.internet.error import ConnectionDone
@@ -116,7 +115,6 @@ class TestGreenTransport(TestUnbufferedTransport):
self.conn.write('hello\r\n')
self.assertEqual(self.conn.read(9), 'you said ')
self.assertEqual(self.conn.read(999), 'hello. BYE')
self.assertEqual(None, self.conn._queue)
self.assertEqual(self.conn.read(9), '')
self.assertEqual(self.conn.read(1), '')
self.assertEqual(self.conn.recv(9), '')
@@ -156,26 +154,25 @@ class TestGreenTransport(TestUnbufferedTransport):
class TestGreenTransport_bufsize1(TestGreenTransport):
transportBufferSize = 1
class TestGreenTransportError(TestCase):
setup_server = setup_server_SpawnFactory
gtransportClass = pr.GreenTransport
def test_read_error(self):
self.conn.write('hello\r\n')
sleep(DELAY*1.5) # make sure the rest of data arrives
try:
1/0
except:
#self.conn.loseConnection(failure.Failure()) # does not work, why?
spawn(self.conn._queue.send_exception, *sys.exc_info())
self.assertEqual(self.conn.read(9), 'you said ')
self.assertEqual(self.conn.read(7), 'hello. ')
self.assertEqual(self.conn.read(9), 'BYE')
self.assertRaises(ZeroDivisionError, self.conn.read, 9)
self.assertEqual(None, self.conn._queue)
self.assertEqual(self.conn.read(1), '')
self.assertEqual(self.conn.read(1), '')
# class TestGreenTransportError(TestCase):
# setup_server = setup_server_SpawnFactory
# gtransportClass = pr.GreenTransport
#
# def test_read_error(self):
# self.conn.write('hello\r\n')
# sleep(DELAY*1.5) # make sure the rest of data arrives
# try:
# 1/0
# except:
# #self.conn.loseConnection(failure.Failure()) # does not work, why?
# spawn(self.conn._queue.send_exception, *sys.exc_info())
# self.assertEqual(self.conn.read(9), 'you said ')
# self.assertEqual(self.conn.read(7), 'hello. ')
# self.assertEqual(self.conn.read(9), 'BYE')
# self.assertRaises(ZeroDivisionError, self.conn.read, 9)
# self.assertEqual(self.conn.read(1), '')
# self.assertEqual(self.conn.read(1), '')
#
# def test_recv_error(self):
# self.conn.write('hello')
# self.assertEqual('you said hello. ', self.conn.recv())
@@ -187,11 +184,57 @@ class TestGreenTransportError(TestCase):
# spawn(self.conn._queue.send_exception, *sys.exc_info())
# self.assertEqual('BYE', self.conn.recv())
# self.assertRaises(ZeroDivisionError, self.conn.recv, 9)
# self.assertEqual(None, self.conn._queue)
# self.assertEqual('', self.conn.recv(1))
# self.assertEqual('', self.conn.recv())
#
class TestHalfClose_TCP(LimitedTestCase):
def _test_server(self, conn):
conn.write('hello')
conn.loseWriteConnection()
self.assertRaises(pr.WriteConnectionDone, conn.write, 'hey')
data = conn.read()
self.assertEqual('bye', data)
conn.loseConnection()
self.assertRaises(ConnectionDone, conn._wait)
self.check.append('server')
def setUp(self):
LimitedTestCase.setUp(self)
self.factory = pr.SpawnFactory(self._test_server)
self.port = reactor.listenTCP(0, self.factory)
self.conn = pr.GreenClientCreator(reactor).connectTCP('localhost', self.port.getHost().port)
self.port.stopListening()
self.check = []
def test(self):
conn = self.conn
data = conn.read()
self.assertEqual('hello', data)
conn.write('bye')
conn.loseWriteConnection()
self.assertRaises(pr.WriteConnectionDone, conn.write, 'hoy')
self.factory.waitall()
self.assertRaises(ConnectionDone, conn._wait)
assert self.check == ['server']
class TestHalfClose_TLS(TestHalfClose_TCP):
def setUp(self):
LimitedTestCase.setUp(self)
from gnutls.crypto import X509PrivateKey, X509Certificate
from gnutls.interfaces.twisted import X509Credentials
cert = X509Certificate(open('gnutls_valid.crt').read())
key = X509PrivateKey(open('gnutls_valid.key').read())
server_credentials = X509Credentials(cert, key)
self.factory = pr.SpawnFactory(self._test_server)
self.port = reactor.listenTLS(0, self.factory, server_credentials)
self.conn = pr.GreenClientCreator(reactor).connectTLS('localhost', self.port.getHost().port, X509Credentials())
self.port.stopListening()
self.check = []
if socket is not None:
class TestUnbufferedTransport_socketserver(TestUnbufferedTransport):
@@ -232,6 +275,7 @@ try:
import gnutls.interfaces.twisted
except ImportError:
del TestTLSError
del TestHalfClose_TLS
if __name__=='__main__':
unittest.main()