GreenTransportBase: implement half-closability and write() that blocks until data is sent.

* add loseWriteConnection() method
* add async_write() that does not wait (synonim for transport.write())
* GreenTransport's protocols now required to implement IHalfClosable interface
* GreenTransport use PullProducer instead of PushProducer
* add test for loseWriteConnection and remove test that does not work anymore
This commit is contained in:
Denis Bilenko
2009-02-09 19:25:29 +06:00
parent 30e594475a
commit 16cbfca0a0
6 changed files with 230 additions and 101 deletions

View File

@@ -20,18 +20,62 @@
# THE SOFTWARE. # THE SOFTWARE.
"""Basic twisted protocols converted to synchronous mode""" """Basic twisted protocols converted to synchronous mode"""
import sys
from twisted.internet.protocol import Protocol as twistedProtocol from twisted.internet.protocol import Protocol as twistedProtocol
from twisted.internet.error import ConnectionDone from twisted.internet.error import ConnectionDone
from twisted.internet.protocol import Factory, ClientFactory from twisted.internet.protocol import Factory, ClientFactory
from twisted.python import failure
from eventlet.api import spawn from twisted.internet.interfaces import IHalfCloseableProtocol
from zope.interface import implements
from eventlet import proc
from eventlet.api import getcurrent
from eventlet.coros import queue, event from eventlet.coros import queue, event
class ReadConnectionDone(ConnectionDone):
pass
class WriteConnectionDone(ConnectionDone):
pass
class ValueQueue(queue):
def wait(self):
"""The difference from queue.wait: if there is an only item in the
queue and it is an exception, raise it, but keep it in the queue, so
that future calls to wait() will raise it again.
"""
self.sem.acquire()
if self.has_final_error():
# the last item, which is an exception, raise without emptying the queue
self.sem.release()
getcurrent().throw(*self.items[0][1])
else:
result, exc = self.items.popleft()
if exc is not None:
getcurrent().throw(*exc)
return result
def has_final_error(self):
return len(self.items)==1 and self.items[0][1] is not None
class Event(event):
def send(self, value, exc=None):
if self.ready():
self.reset()
return event.send(self, value, exc)
def send_exception(self, *throw_args):
if self.ready():
self.reset()
return event.send_exception(self, *throw_args)
class Producer2Event(object): class Producer2Event(object):
# implements IPushProducer # implements IPullProducer
def __init__(self, event): def __init__(self, event):
self.event = event self.event = event
@@ -39,48 +83,78 @@ class Producer2Event(object):
def resumeProducing(self): def resumeProducing(self):
self.event.send(1) self.event.send(1)
def pauseProducing(self):
self.event.reset()
def stopProducing(self): def stopProducing(self):
del self.event del self.event
class GreenTransportBase(object): class GreenTransportBase(object):
write_event = None
transportBufferSize = None transportBufferSize = None
def __init__(self, transportBufferSize=None): def __init__(self, transportBufferSize=None):
if transportBufferSize is not None: if transportBufferSize is not None:
self.transportBufferSize = transportBufferSize self.transportBufferSize = transportBufferSize
self._error_event = event() self._queue = queue()
self._read_disconnected_event = Event()
self._write_disconnected_event = Event()
self._write_event = Event()
def build_protocol(self): def build_protocol(self):
# note to subclassers: self._queue must have send and send_exception that never block protocol = self.protocol_class(self)
self._queue = queue()
protocol = self.protocol_class(self._queue)
return protocol return protocol
def _got_transport(self, transport):
self._queue.send(transport)
def _got_data(self, data):
self._queue.send(data)
def _connectionLost(self, reason):
self._read_disconnected_event.send(reason.value)
self._write_disconnected_event.send(reason.value)
self._queue.send_exception(reason.value)
self._write_event.send_exception(reason.value)
def _readConnectionLost(self):
self._read_disconnected_event.send(ReadConnectionDone)
self._queue.send_exception(ReadConnectionDone)
def _writeConnectionLost(self):
self._write_disconnected_event.send(WriteConnectionDone)
self._write_event.send_exception(WriteConnectionDone)
def _wait(self): def _wait(self):
if self._read_disconnected_event.ready():
if self._queue:
return self._queue.wait()
else:
raise self._read_disconnected_event.wait()
self.resumeProducing() self.resumeProducing()
try: try:
return self._queue.wait() return self._queue.wait()
except:
if self._error_event is not None:
self._error_event.send(None)
raise
finally: finally:
self.pauseProducing() self.pauseProducing()
def write(self, data): def write(self, data):
if self._write_disconnected_event.ready():
raise self._write_disconnected_event.wait()
self._write_event.reset()
self.transport.write(data)
self._write_event.wait()
def async_write(self, data):
self.transport.write(data) self.transport.write(data)
if self.write_event is not None:
self.write_event.wait()
def loseConnection(self): def loseConnection(self):
self.transport.unregisterProducer() self.transport.unregisterProducer()
self.transport.loseConnection() self.transport.loseConnection()
self._error_event.wait() self._read_disconnected_event.wait()
self._write_disconnected_event.wait()
def loseWriteConnection(self):
self.transport.unregisterProducer()
self.transport.loseWriteConnection()
self._write_disconnected_event.wait()
def __getattr__(self, item): def __getattr__(self, item):
if item=='transport': if item=='transport':
@@ -116,25 +190,30 @@ class GreenTransportBase(object):
if self.transportBufferSize is not None: if self.transportBufferSize is not None:
transport.bufferSize = self.transportBufferSize transport.bufferSize = self.transportBufferSize
self._init_transport_producer() self._init_transport_producer()
if self.write_event is None: transport.registerProducer(Producer2Event(self._write_event), False)
self.write_event = event()
self.write_event.send(1)
transport.registerProducer(Producer2Event(self.write_event), True)
class Protocol(twistedProtocol): class Protocol(twistedProtocol):
def __init__(self, queue): implements(IHalfCloseableProtocol)
self._queue = queue
def __init__(self, recepient):
self._recepient = recepient
def connectionMade(self): def connectionMade(self):
self._queue.send(self.transport) self._recepient._got_transport(self.transport)
def dataReceived(self, data): def dataReceived(self, data):
self._queue.send(data) self._recepient._got_data(data)
def connectionLost(self, reason): def connectionLost(self, reason):
self._queue.send_exception(reason.type, reason.value, reason.tb) self._recepient._connectionLost(reason)
del self._queue
def readConnectionLost(self):
self._recepient._readConnectionLost()
def writeConnectionLost(self):
self._recepient._writeConnectionLost()
class UnbufferedTransport(GreenTransportBase): class UnbufferedTransport(GreenTransportBase):
@@ -148,16 +227,12 @@ class UnbufferedTransport(GreenTransportBase):
Return '' if connection was closed cleanly, raise the exception if it was closed Return '' if connection was closed cleanly, raise the exception if it was closed
in a non clean fashion. After that all successive calls return ''. in a non clean fashion. After that all successive calls return ''.
""" """
if self._queue is None: if self._read_disconnected_event.ready():
return '' return ''
try: try:
return self._wait() return self._wait()
except ConnectionDone: except ConnectionDone:
self._queue = None
return '' return ''
except:
self._queue = None
raise
def read(self): def read(self):
"""Read the data from the socket until the connection is closed cleanly. """Read the data from the socket until the connection is closed cleanly.
@@ -192,45 +267,31 @@ class GreenTransport(GreenTransportBase):
_buffer = '' _buffer = ''
_error = None _error = None
def _wait(self):
# don't pause/resume producer here; read and recv methods will do it themselves
try:
return self._queue.wait()
except:
self._error_event.send(None)
raise
def read(self, size=-1): def read(self, size=-1):
"""Read size bytes or until EOF""" """Read size bytes or until EOF"""
if self._queue is not None: if not self._read_disconnected_event.ready():
resumed = False
try: try:
try: while len(self._buffer) < size or size < 0:
while len(self._buffer) < size or size < 0: self._buffer += self._wait()
if not resumed: except ConnectionDone:
self.resumeProducing() pass
resumed = True except:
self._buffer += self._wait() if not self._read_disconnected_event.has_exception():
except ConnectionDone: raise
self._queue = None
except:
self._queue = None
self._error = sys.exc_info()
finally:
if resumed:
self.pauseProducing()
if size>=0: if size>=0:
result, self._buffer = self._buffer[:size], self._buffer[size:] result, self._buffer = self._buffer[:size], self._buffer[size:]
else: else:
result, self._buffer = self._buffer, '' result, self._buffer = self._buffer, ''
if not result and self._error is not None: if not result and self._read_disconnected_event.has_exception():
error, self._error = self._error, None try:
raise error[0], error[1], error[2] self._read_disconnected_event.wait()
except ConnectionDone:
pass
return result return result
def recv(self, buflen=None): def recv(self, buflen=None):
"""Receive a single chunk of undefined size but no bigger than buflen""" """Receive a single chunk of undefined size but no bigger than buflen"""
if self._queue is not None and not self._buffer: if not self._read_disconnected_event.ready():
self.resumeProducing() self.resumeProducing()
try: try:
try: try:
@@ -238,20 +299,21 @@ class GreenTransport(GreenTransportBase):
#print 'received %r' % recvd #print 'received %r' % recvd
self._buffer += recvd self._buffer += recvd
except ConnectionDone: except ConnectionDone:
self._queue = None pass
except: except:
self._queue = None if not self._read_disconnected_event.has_exception():
self._error = sys.exc_info() raise
finally: finally:
self.pauseProducing() self.pauseProducing()
if buflen is None: if buflen is None:
result, self._buffer = self._buffer, '' result, self._buffer = self._buffer, ''
else: else:
result, self._buffer = self._buffer[:buflen], self._buffer[buflen:] result, self._buffer = self._buffer[:buflen], self._buffer[buflen:]
if not result and self._error is not None: if not result and self._read_disconnected_event.has_exception():
error = self._error try:
self._error = None self._read_disconnected_event.wait()
raise error[0], error[1], error[2] except ConnectionDone:
pass
return result return result
# iterator protocol: # iterator protocol:

View File

@@ -23,19 +23,36 @@ from twisted.protocols import basic
from twisted.internet.error import ConnectionDone from twisted.internet.error import ConnectionDone
from eventlet.twistedutil.protocol import GreenTransportBase from eventlet.twistedutil.protocol import GreenTransportBase
from twisted.internet.interfaces import IHalfCloseableProtocol
from zope.interface import implements
class LineOnlyReceiver(basic.LineOnlyReceiver): class LineOnlyReceiver(basic.LineOnlyReceiver):
def __init__(self, queue): implements(IHalfCloseableProtocol)
self._queue = queue
def __init__(self, recepient):
self._recepient = recepient
def connectionMade(self): def connectionMade(self):
self._queue.send(self.transport) #print '%r made' % self
self._recepient._got_transport(self.transport)
def lineReceived(self, line):
self._queue.send(line)
def connectionLost(self, reason): def connectionLost(self, reason):
self._queue.send_exception(reason.type, reason.value, reason.tb) #print '%r conn lost %r' % (self, reason)
self._recepient._connectionLost(reason)
def readConnectionLost(self):
#print '%r read conn lost' % self
self._recepient._readConnectionLost()
def writeConnectionLost(self):
#print '%r wr conn lost' % self
self._recepient._writeConnectionLost()
def lineReceived(self, line):
#print '%r line received %r' % (self, line)
self._recepient._got_data(line)
class LineOnlyReceiverTransport(GreenTransportBase): class LineOnlyReceiverTransport(GreenTransportBase):

View File

@@ -12,11 +12,13 @@ from twisted.internet import reactor
# read from TCP connection # read from TCP connection
conn = GreenClientCreator(reactor).connectTCP('www.google.com', 80) conn = GreenClientCreator(reactor).connectTCP('www.google.com', 80)
conn.write('GET / HTTP/1.0\r\n\r\n') conn.write('GET / HTTP/1.0\r\n\r\n')
conn.loseWriteConnection()
print conn.read() print conn.read()
# read from SSL connection line by line # read from SSL connection line by line
conn = GreenClientCreator(reactor, LineOnlyReceiverTransport).connectSSL('sf.net', 443, ssl.ClientContextFactory()) conn = GreenClientCreator(reactor, LineOnlyReceiverTransport).connectSSL('sf.net', 443, ssl.ClientContextFactory())
conn.write('GET / HTTP/1.0\r\n\r\n') conn.write('GET / HTTP/1.0\r\n\r\n')
#conn.loseWriteConnection()
try: try:
for num, line in enumerate(conn): for num, line in enumerate(conn):
print '%3s %r' % (num, line) print '%3s %r' % (num, line)

View File

@@ -52,6 +52,7 @@ class Chat:
else: else:
print peer, 'connection done' print peer, 'connection done'
finally: finally:
conn.loseConnection()
self.participants.remove(conn) self.participants.remove(conn)
print __doc__ print __doc__

View File

@@ -45,8 +45,11 @@ From-Path: msrps://alice.example.com:9892/98cjs;tcp
-------49fh$ -------49fh$
""".replace('\n', '\r\n') """.replace('\n', '\r\n')
print 'Sending:\n%s' % request
conn.write(request) conn.write(request)
#conn.loseWriteConnection()
print 'Received:'
for x in conn: for x in conn:
print x print repr(x)
if '-------' in x: if '-------' in x:
break break

View File

@@ -20,10 +20,9 @@
# THE SOFTWARE. # THE SOFTWARE.
from twisted.internet import reactor from twisted.internet import reactor
from greentest import exit_unless_twisted from greentest import exit_unless_twisted, LimitedTestCase
exit_unless_twisted() exit_unless_twisted()
import sys
import unittest import unittest
from twisted.internet.error import ConnectionDone from twisted.internet.error import ConnectionDone
@@ -116,7 +115,6 @@ class TestGreenTransport(TestUnbufferedTransport):
self.conn.write('hello\r\n') self.conn.write('hello\r\n')
self.assertEqual(self.conn.read(9), 'you said ') self.assertEqual(self.conn.read(9), 'you said ')
self.assertEqual(self.conn.read(999), 'hello. BYE') self.assertEqual(self.conn.read(999), 'hello. BYE')
self.assertEqual(None, self.conn._queue)
self.assertEqual(self.conn.read(9), '') self.assertEqual(self.conn.read(9), '')
self.assertEqual(self.conn.read(1), '') self.assertEqual(self.conn.read(1), '')
self.assertEqual(self.conn.recv(9), '') self.assertEqual(self.conn.recv(9), '')
@@ -156,26 +154,25 @@ class TestGreenTransport(TestUnbufferedTransport):
class TestGreenTransport_bufsize1(TestGreenTransport): class TestGreenTransport_bufsize1(TestGreenTransport):
transportBufferSize = 1 transportBufferSize = 1
class TestGreenTransportError(TestCase): # class TestGreenTransportError(TestCase):
setup_server = setup_server_SpawnFactory # setup_server = setup_server_SpawnFactory
gtransportClass = pr.GreenTransport # gtransportClass = pr.GreenTransport
#
def test_read_error(self): # def test_read_error(self):
self.conn.write('hello\r\n') # self.conn.write('hello\r\n')
sleep(DELAY*1.5) # make sure the rest of data arrives # sleep(DELAY*1.5) # make sure the rest of data arrives
try: # try:
1/0 # 1/0
except: # except:
#self.conn.loseConnection(failure.Failure()) # does not work, why? # #self.conn.loseConnection(failure.Failure()) # does not work, why?
spawn(self.conn._queue.send_exception, *sys.exc_info()) # spawn(self.conn._queue.send_exception, *sys.exc_info())
self.assertEqual(self.conn.read(9), 'you said ') # self.assertEqual(self.conn.read(9), 'you said ')
self.assertEqual(self.conn.read(7), 'hello. ') # self.assertEqual(self.conn.read(7), 'hello. ')
self.assertEqual(self.conn.read(9), 'BYE') # self.assertEqual(self.conn.read(9), 'BYE')
self.assertRaises(ZeroDivisionError, self.conn.read, 9) # self.assertRaises(ZeroDivisionError, self.conn.read, 9)
self.assertEqual(None, self.conn._queue) # self.assertEqual(self.conn.read(1), '')
self.assertEqual(self.conn.read(1), '') # self.assertEqual(self.conn.read(1), '')
self.assertEqual(self.conn.read(1), '') #
# def test_recv_error(self): # def test_recv_error(self):
# self.conn.write('hello') # self.conn.write('hello')
# self.assertEqual('you said hello. ', self.conn.recv()) # self.assertEqual('you said hello. ', self.conn.recv())
@@ -187,11 +184,57 @@ class TestGreenTransportError(TestCase):
# spawn(self.conn._queue.send_exception, *sys.exc_info()) # spawn(self.conn._queue.send_exception, *sys.exc_info())
# self.assertEqual('BYE', self.conn.recv()) # self.assertEqual('BYE', self.conn.recv())
# self.assertRaises(ZeroDivisionError, self.conn.recv, 9) # self.assertRaises(ZeroDivisionError, self.conn.recv, 9)
# self.assertEqual(None, self.conn._queue)
# self.assertEqual('', self.conn.recv(1)) # self.assertEqual('', self.conn.recv(1))
# self.assertEqual('', self.conn.recv()) # self.assertEqual('', self.conn.recv())
# #
class TestHalfClose_TCP(LimitedTestCase):
def _test_server(self, conn):
conn.write('hello')
conn.loseWriteConnection()
self.assertRaises(pr.WriteConnectionDone, conn.write, 'hey')
data = conn.read()
self.assertEqual('bye', data)
conn.loseConnection()
self.assertRaises(ConnectionDone, conn._wait)
self.check.append('server')
def setUp(self):
LimitedTestCase.setUp(self)
self.factory = pr.SpawnFactory(self._test_server)
self.port = reactor.listenTCP(0, self.factory)
self.conn = pr.GreenClientCreator(reactor).connectTCP('localhost', self.port.getHost().port)
self.port.stopListening()
self.check = []
def test(self):
conn = self.conn
data = conn.read()
self.assertEqual('hello', data)
conn.write('bye')
conn.loseWriteConnection()
self.assertRaises(pr.WriteConnectionDone, conn.write, 'hoy')
self.factory.waitall()
self.assertRaises(ConnectionDone, conn._wait)
assert self.check == ['server']
class TestHalfClose_TLS(TestHalfClose_TCP):
def setUp(self):
LimitedTestCase.setUp(self)
from gnutls.crypto import X509PrivateKey, X509Certificate
from gnutls.interfaces.twisted import X509Credentials
cert = X509Certificate(open('gnutls_valid.crt').read())
key = X509PrivateKey(open('gnutls_valid.key').read())
server_credentials = X509Credentials(cert, key)
self.factory = pr.SpawnFactory(self._test_server)
self.port = reactor.listenTLS(0, self.factory, server_credentials)
self.conn = pr.GreenClientCreator(reactor).connectTLS('localhost', self.port.getHost().port, X509Credentials())
self.port.stopListening()
self.check = []
if socket is not None: if socket is not None:
class TestUnbufferedTransport_socketserver(TestUnbufferedTransport): class TestUnbufferedTransport_socketserver(TestUnbufferedTransport):
@@ -232,6 +275,7 @@ try:
import gnutls.interfaces.twisted import gnutls.interfaces.twisted
except ImportError: except ImportError:
del TestTLSError del TestTLSError
del TestHalfClose_TLS
if __name__=='__main__': if __name__=='__main__':
unittest.main() unittest.main()