GreenTransportBase: revert half-closability (does not work with TLS anyway and not worth the complexity)
- loseConnection() and write() have an optional 'sync' argument - twistedutil.protocol.ValueQueue improved
This commit is contained in:
@@ -24,21 +24,20 @@ from twisted.internet.protocol import Protocol as twistedProtocol
|
||||
from twisted.internet.error import ConnectionDone
|
||||
from twisted.internet.protocol import Factory, ClientFactory
|
||||
|
||||
from twisted.internet.interfaces import IHalfCloseableProtocol
|
||||
from zope.interface import implements
|
||||
|
||||
from eventlet import proc
|
||||
from eventlet.api import getcurrent
|
||||
from eventlet.coros import queue, event
|
||||
|
||||
|
||||
class ReadConnectionDone(ConnectionDone):
|
||||
pass
|
||||
|
||||
class WriteConnectionDone(ConnectionDone):
|
||||
pass
|
||||
|
||||
class ValueQueue(queue):
|
||||
"""Queue that keeps the last item forever in the queue if it's an exception.
|
||||
Useful if you send an exception over queue only once, and once sent it must be always
|
||||
available.
|
||||
"""
|
||||
|
||||
def send(self, value=None, exc=None):
|
||||
if exc is not None or not self.has_error():
|
||||
queue.send(self, value, exc)
|
||||
|
||||
def wait(self):
|
||||
"""The difference from queue.wait: if there is an only item in the
|
||||
@@ -46,7 +45,7 @@ class ValueQueue(queue):
|
||||
that future calls to wait() will raise it again.
|
||||
"""
|
||||
self.sem.acquire()
|
||||
if self.has_final_error():
|
||||
if self.has_error() and len(self.items)==1:
|
||||
# the last item, which is an exception, raise without emptying the queue
|
||||
self.sem.release()
|
||||
getcurrent().throw(*self.items[0][1])
|
||||
@@ -56,8 +55,8 @@ class ValueQueue(queue):
|
||||
getcurrent().throw(*exc)
|
||||
return result
|
||||
|
||||
def has_final_error(self):
|
||||
return len(self.items)==1 and self.items[0][1] is not None
|
||||
def has_error(self):
|
||||
return self.items and self.items[-1][1] is not None
|
||||
|
||||
|
||||
class Event(event):
|
||||
@@ -95,9 +94,8 @@ class GreenTransportBase(object):
|
||||
if transportBufferSize is not None:
|
||||
self.transportBufferSize = transportBufferSize
|
||||
self._queue = queue()
|
||||
self._read_disconnected_event = Event()
|
||||
self._write_disconnected_event = Event()
|
||||
self._write_event = Event()
|
||||
self._disconnected_event = Event()
|
||||
|
||||
def build_protocol(self):
|
||||
protocol = self.protocol_class(self)
|
||||
@@ -110,51 +108,37 @@ class GreenTransportBase(object):
|
||||
self._queue.send(data)
|
||||
|
||||
def _connectionLost(self, reason):
|
||||
self._read_disconnected_event.send(reason.value)
|
||||
self._write_disconnected_event.send(reason.value)
|
||||
self._disconnected_event.send(reason.value)
|
||||
self._queue.send_exception(reason.value)
|
||||
self._write_event.send_exception(reason.value)
|
||||
|
||||
def _readConnectionLost(self):
|
||||
self._read_disconnected_event.send(ReadConnectionDone)
|
||||
self._queue.send_exception(ReadConnectionDone)
|
||||
|
||||
def _writeConnectionLost(self):
|
||||
self._write_disconnected_event.send(WriteConnectionDone)
|
||||
self._write_event.send_exception(WriteConnectionDone)
|
||||
|
||||
def _wait(self):
|
||||
if self._read_disconnected_event.ready():
|
||||
if self._disconnected_event.ready():
|
||||
if self._queue:
|
||||
return self._queue.wait()
|
||||
else:
|
||||
raise self._read_disconnected_event.wait()
|
||||
raise self._disconnected_event.wait()
|
||||
self.resumeProducing()
|
||||
try:
|
||||
return self._queue.wait()
|
||||
finally:
|
||||
self.pauseProducing()
|
||||
|
||||
def write(self, data):
|
||||
if self._write_disconnected_event.ready():
|
||||
raise self._write_disconnected_event.wait()
|
||||
self._write_event.reset()
|
||||
self.transport.write(data)
|
||||
self._write_event.wait()
|
||||
def write(self, data, sync=True):
|
||||
if self._disconnected_event.ready():
|
||||
raise self._disconnected_event.wait()
|
||||
if sync:
|
||||
self._write_event.reset()
|
||||
self.transport.write(data)
|
||||
self._write_event.wait()
|
||||
else:
|
||||
self.transport.write(data)
|
||||
|
||||
def async_write(self, data):
|
||||
self.transport.write(data)
|
||||
|
||||
def loseConnection(self):
|
||||
def loseConnection(self, sync=True):
|
||||
self.transport.unregisterProducer()
|
||||
self.transport.loseConnection()
|
||||
self._read_disconnected_event.wait()
|
||||
self._write_disconnected_event.wait()
|
||||
|
||||
def loseWriteConnection(self):
|
||||
self.transport.unregisterProducer()
|
||||
self.transport.loseWriteConnection()
|
||||
self._write_disconnected_event.wait()
|
||||
if sync:
|
||||
self._disconnected_event.wait()
|
||||
|
||||
def __getattr__(self, item):
|
||||
if item=='transport':
|
||||
@@ -194,8 +178,6 @@ class GreenTransportBase(object):
|
||||
|
||||
class Protocol(twistedProtocol):
|
||||
|
||||
implements(IHalfCloseableProtocol)
|
||||
|
||||
def __init__(self, recepient):
|
||||
self._recepient = recepient
|
||||
|
||||
@@ -208,12 +190,6 @@ class Protocol(twistedProtocol):
|
||||
def connectionLost(self, reason):
|
||||
self._recepient._connectionLost(reason)
|
||||
|
||||
def readConnectionLost(self):
|
||||
self._recepient._readConnectionLost()
|
||||
|
||||
def writeConnectionLost(self):
|
||||
self._recepient._writeConnectionLost()
|
||||
|
||||
|
||||
class UnbufferedTransport(GreenTransportBase):
|
||||
"""A very simple implementation of a green transport without an additional buffer"""
|
||||
@@ -226,7 +202,7 @@ class UnbufferedTransport(GreenTransportBase):
|
||||
Return '' if connection was closed cleanly, raise the exception if it was closed
|
||||
in a non clean fashion. After that all successive calls return ''.
|
||||
"""
|
||||
if self._read_disconnected_event.ready():
|
||||
if self._disconnected_event.ready():
|
||||
return ''
|
||||
try:
|
||||
return self._wait()
|
||||
@@ -268,29 +244,29 @@ class GreenTransport(GreenTransportBase):
|
||||
|
||||
def read(self, size=-1):
|
||||
"""Read size bytes or until EOF"""
|
||||
if not self._read_disconnected_event.ready():
|
||||
if not self._disconnected_event.ready():
|
||||
try:
|
||||
while len(self._buffer) < size or size < 0:
|
||||
self._buffer += self._wait()
|
||||
except ConnectionDone:
|
||||
pass
|
||||
except:
|
||||
if not self._read_disconnected_event.has_exception():
|
||||
if not self._disconnected_event.has_exception():
|
||||
raise
|
||||
if size>=0:
|
||||
result, self._buffer = self._buffer[:size], self._buffer[size:]
|
||||
else:
|
||||
result, self._buffer = self._buffer, ''
|
||||
if not result and self._read_disconnected_event.has_exception():
|
||||
if not result and self._disconnected_event.has_exception():
|
||||
try:
|
||||
self._read_disconnected_event.wait()
|
||||
self._disconnected_event.wait()
|
||||
except ConnectionDone:
|
||||
pass
|
||||
return result
|
||||
|
||||
def recv(self, buflen=None):
|
||||
"""Receive a single chunk of undefined size but no bigger than buflen"""
|
||||
if not self._read_disconnected_event.ready():
|
||||
if not self._disconnected_event.ready():
|
||||
self.resumeProducing()
|
||||
try:
|
||||
try:
|
||||
@@ -300,7 +276,7 @@ class GreenTransport(GreenTransportBase):
|
||||
except ConnectionDone:
|
||||
pass
|
||||
except:
|
||||
if not self._read_disconnected_event.has_exception():
|
||||
if not self._disconnected_event.has_exception():
|
||||
raise
|
||||
finally:
|
||||
self.pauseProducing()
|
||||
@@ -308,9 +284,9 @@ class GreenTransport(GreenTransportBase):
|
||||
result, self._buffer = self._buffer, ''
|
||||
else:
|
||||
result, self._buffer = self._buffer[:buflen], self._buffer[buflen:]
|
||||
if not result and self._read_disconnected_event.has_exception():
|
||||
if not result and self._disconnected_event.has_exception():
|
||||
try:
|
||||
self._read_disconnected_event.wait()
|
||||
self._disconnected_event.wait()
|
||||
except ConnectionDone:
|
||||
pass
|
||||
return result
|
||||
|
||||
@@ -23,13 +23,9 @@ from twisted.protocols import basic
|
||||
from twisted.internet.error import ConnectionDone
|
||||
from eventlet.twistedutil.protocol import GreenTransportBase
|
||||
|
||||
from twisted.internet.interfaces import IHalfCloseableProtocol
|
||||
from zope.interface import implements
|
||||
|
||||
class LineOnlyReceiver(basic.LineOnlyReceiver):
|
||||
|
||||
implements(IHalfCloseableProtocol)
|
||||
|
||||
def __init__(self, recepient):
|
||||
self._recepient = recepient
|
||||
|
||||
@@ -41,14 +37,6 @@ class LineOnlyReceiver(basic.LineOnlyReceiver):
|
||||
#print '%r conn lost %r' % (self, reason)
|
||||
self._recepient._connectionLost(reason)
|
||||
|
||||
def readConnectionLost(self):
|
||||
#print '%r read conn lost' % self
|
||||
self._recepient._readConnectionLost()
|
||||
|
||||
def writeConnectionLost(self):
|
||||
#print '%r wr conn lost' % self
|
||||
self._recepient._writeConnectionLost()
|
||||
|
||||
def lineReceived(self, line):
|
||||
#print '%r line received %r' % (self, line)
|
||||
self._recepient._got_data(line)
|
||||
|
||||
@@ -18,7 +18,6 @@ print conn.read()
|
||||
# read from SSL connection line by line
|
||||
conn = GreenClientCreator(reactor, LineOnlyReceiverTransport).connectSSL('sf.net', 443, ssl.ClientContextFactory())
|
||||
conn.write('GET / HTTP/1.0\r\n\r\n')
|
||||
#conn.loseWriteConnection()
|
||||
try:
|
||||
for num, line in enumerate(conn):
|
||||
print '%3s %r' % (num, line)
|
||||
|
||||
@@ -47,7 +47,6 @@ From-Path: msrps://alice.example.com:9892/98cjs;tcp
|
||||
|
||||
print 'Sending:\n%s' % request
|
||||
conn.write(request)
|
||||
#conn.loseWriteConnection()
|
||||
print 'Received:'
|
||||
for x in conn:
|
||||
print repr(x)
|
||||
|
||||
@@ -188,52 +188,52 @@ class TestGreenTransport_bufsize1(TestGreenTransport):
|
||||
# self.assertEqual('', self.conn.recv())
|
||||
#
|
||||
|
||||
class TestHalfClose_TCP(LimitedTestCase):
|
||||
|
||||
def _test_server(self, conn):
|
||||
conn.write('hello')
|
||||
conn.loseWriteConnection()
|
||||
self.assertRaises(pr.WriteConnectionDone, conn.write, 'hey')
|
||||
data = conn.read()
|
||||
self.assertEqual('bye', data)
|
||||
conn.loseConnection()
|
||||
self.assertRaises(ConnectionDone, conn._wait)
|
||||
self.check.append('server')
|
||||
|
||||
def setUp(self):
|
||||
LimitedTestCase.setUp(self)
|
||||
self.factory = pr.SpawnFactory(self._test_server)
|
||||
self.port = reactor.listenTCP(0, self.factory)
|
||||
self.conn = pr.GreenClientCreator(reactor).connectTCP('localhost', self.port.getHost().port)
|
||||
self.port.stopListening()
|
||||
self.check = []
|
||||
|
||||
def test(self):
|
||||
conn = self.conn
|
||||
data = conn.read()
|
||||
self.assertEqual('hello', data)
|
||||
conn.write('bye')
|
||||
conn.loseWriteConnection()
|
||||
self.assertRaises(pr.WriteConnectionDone, conn.write, 'hoy')
|
||||
self.factory.waitall()
|
||||
self.assertRaises(ConnectionDone, conn._wait)
|
||||
assert self.check == ['server']
|
||||
|
||||
class TestHalfClose_TLS(TestHalfClose_TCP):
|
||||
|
||||
def setUp(self):
|
||||
LimitedTestCase.setUp(self)
|
||||
from gnutls.crypto import X509PrivateKey, X509Certificate
|
||||
from gnutls.interfaces.twisted import X509Credentials
|
||||
cert = X509Certificate(open('gnutls_valid.crt').read())
|
||||
key = X509PrivateKey(open('gnutls_valid.key').read())
|
||||
server_credentials = X509Credentials(cert, key)
|
||||
self.factory = pr.SpawnFactory(self._test_server)
|
||||
self.port = reactor.listenTLS(0, self.factory, server_credentials)
|
||||
self.conn = pr.GreenClientCreator(reactor).connectTLS('localhost', self.port.getHost().port, X509Credentials())
|
||||
self.port.stopListening()
|
||||
self.check = []
|
||||
|
||||
# class TestHalfClose_TCP(LimitedTestCase):
|
||||
#
|
||||
# def _test_server(self, conn):
|
||||
# conn.write('hello')
|
||||
# conn.loseWriteConnection()
|
||||
# self.assertRaises(pr.ConnectionDone, conn.write, 'hey')
|
||||
# data = conn.read()
|
||||
# self.assertEqual('bye', data)
|
||||
# conn.loseConnection()
|
||||
# self.assertRaises(ConnectionDone, conn._wait)
|
||||
# self.check.append('server')
|
||||
#
|
||||
# def setUp(self):
|
||||
# LimitedTestCase.setUp(self)
|
||||
# self.factory = pr.SpawnFactory(self._test_server)
|
||||
# self.port = reactor.listenTCP(0, self.factory)
|
||||
# self.conn = pr.GreenClientCreator(reactor).connectTCP('localhost', self.port.getHost().port)
|
||||
# self.port.stopListening()
|
||||
# self.check = []
|
||||
#
|
||||
# def test(self):
|
||||
# conn = self.conn
|
||||
# data = conn.read()
|
||||
# self.assertEqual('hello', data)
|
||||
# conn.write('bye')
|
||||
# conn.loseWriteConnection()
|
||||
# self.assertRaises(pr.ConnectionDone, conn.write, 'hoy')
|
||||
# self.factory.waitall()
|
||||
# self.assertRaises(ConnectionDone, conn._wait)
|
||||
# assert self.check == ['server']
|
||||
#
|
||||
# class TestHalfClose_TLS(TestHalfClose_TCP):
|
||||
#
|
||||
# def setUp(self):
|
||||
# LimitedTestCase.setUp(self)
|
||||
# from gnutls.crypto import X509PrivateKey, X509Certificate
|
||||
# from gnutls.interfaces.twisted import X509Credentials
|
||||
# cert = X509Certificate(open('gnutls_valid.crt').read())
|
||||
# key = X509PrivateKey(open('gnutls_valid.key').read())
|
||||
# server_credentials = X509Credentials(cert, key)
|
||||
# self.factory = pr.SpawnFactory(self._test_server)
|
||||
# self.port = reactor.listenTLS(0, self.factory, server_credentials)
|
||||
# self.conn = pr.GreenClientCreator(reactor).connectTLS('localhost', self.port.getHost().port, X509Credentials())
|
||||
# self.port.stopListening()
|
||||
# self.check = []
|
||||
#
|
||||
|
||||
if socket is not None:
|
||||
|
||||
|
||||
Reference in New Issue
Block a user