twistedutil fix: SpawnFactory could spawn handler with unconnected transport (connectionMade); changed Protocol not to hold references to gtransport but send transport via _queue

This commit is contained in:
Denis Bilenko
2008-12-19 15:18:47 +06:00
parent 27693e7215
commit 765afcedb7
3 changed files with 59 additions and 25 deletions

View File

@@ -3,6 +3,7 @@ import sys
from twisted.internet.protocol import Protocol as twistedProtocol
from twisted.internet.error import ConnectionDone
from twisted.internet.protocol import Factory, ClientFactory
from twisted.python import failure
from eventlet.api import spawn
from eventlet.coros import queue, event
@@ -35,7 +36,7 @@ class GreenTransportBase(object):
def build_protocol(self):
# note to subclassers: self._queue must have send and send_exception that never block
self._queue = queue()
protocol = self.protocol_class(self, self._queue)
protocol = self.protocol_class(self._queue)
return protocol
def _wait(self):
@@ -70,29 +71,28 @@ class GreenTransportBase(object):
if self.paused==1:
self.transport.pauseProducing()
def init_transport_producer(self, transport):
transport.pauseProducing()
def _init_transport_producer(self):
self.transport.pauseProducing()
self.paused = 1
def init_transport(self, transport):
def _init_transport(self):
transport = self._queue.wait()
self.transport = transport
if self.transportBufferSize is not None:
transport.bufferSize = self.transportBufferSize
self.init_transport_producer(transport)
self._init_transport_producer()
if self.write_event is None:
self.write_event = event()
self.write_event.send(1)
transport.registerProducer(Producer2Event(self.write_event), True)
self.transport = transport
class Protocol(twistedProtocol):
def __init__(self, gtransport, queue):
self.gtransport = gtransport
def __init__(self, queue):
self._queue = queue
def connectionMade(self):
self.gtransport.init_transport(self.transport)
del self.gtransport
self._queue.send(self.transport)
def dataReceived(self, data):
self._queue.send(data)
@@ -232,7 +232,6 @@ class GreenInstanceFactory(ClientFactory):
self.event = event
def buildProtocol(self, addr):
self.event.send(self.instance)
return self.instance
def clientConnectionFailed(self, connector, reason):
@@ -255,31 +254,31 @@ class GreenClientCreator(object):
def _make_transport_and_factory(self):
gtransport = self.gtransport_class(*self.args, **self.kwargs)
protocol = gtransport.build_protocol()
factory = GreenInstanceFactory(protocol, event())
factory = GreenInstanceFactory(protocol, gtransport._queue)
return gtransport, factory
def connectTCP(self, host, port, *args, **kwargs):
gtransport, factory = self._make_transport_and_factory()
self.reactor.connectTCP(host, port, factory, *args, **kwargs)
factory.event.wait()
gtransport._init_transport()
return gtransport
def connectSSL(self, host, port, *args, **kwargs):
gtransport, factory = self._make_transport_and_factory()
self.reactor.connectSSL(host, port, factory, *args, **kwargs)
factory.event.wait()
gtransport._init_transport()
return gtransport
def connectTLS(self, host, port, *args, **kwargs):
gtransport, factory = self._make_transport_and_factory()
self.reactor.connectTLS(host, port, factory, *args, **kwargs)
factory.event.wait()
gtransport._init_transport()
return gtransport
def connectUNIX(self, address, *args, **kwargs):
gtransport, factory = self._make_transport_and_factory()
self.reactor.connectUNIX(address, factory, *args, **kwargs)
factory.event.wait()
gtransport._init_transport()
return gtransport
def connectSRV(self, service, domain, *args, **kwargs):
@@ -289,7 +288,7 @@ class GreenClientCreator(object):
gtransport, factory = self._make_transport_and_factory()
c = SRVConnector(self.reactor, service, domain, factory, *args, **kwargs)
c.connect()
factory.event.wait()
gtransport._init_transport()
return gtransport
def connect(self, required_args, ConnectorClass, *rest_args, **rest_kwargs):
@@ -316,5 +315,18 @@ class SpawnFactory(Factory):
gtransport = self.gtransport_class(*self.args, **self.kwargs)
protocol = gtransport.build_protocol()
protocol.factory = self
spawn(self.handler, gtransport)
spawn(self._spawn, gtransport, protocol)
return protocol
def _spawn(self, gtransport, protocol):
try:
gtransport._init_transport()
except Exception:
self._log_error(failure.Failure(), gtransport, protocol)
else:
spawn(self.handler, gtransport)
def _log_error(self, failure, gtransport, protocol):
from twisted.python import log
log.msg('%s: %s' % (protocol.transport.getPeer(), failure.getErrorMessage()))

View File

@@ -4,13 +4,11 @@ from eventlet.twistedutil.protocol import GreenTransportBase
class LineOnlyReceiver(basic.LineOnlyReceiver):
def __init__(self, gtransport, queue):
self.gtransport = gtransport
def __init__(self, queue):
self._queue = queue
def connectionMade(self):
self.gtransport.init_transport(self.transport)
del self.gtransport
self._queue.send(self.transport)
def lineReceived(self, line):
self._queue.send(line)

View File

@@ -4,12 +4,12 @@ exit_unless_twisted()
import sys
import unittest
from twisted.internet.error import ConnectionLost, ConnectionDone
from twisted.python import failure
from twisted.internet.error import ConnectionDone
import eventlet.twistedutil.protocol as pr
from eventlet.twistedutil.protocols.basic import LineOnlyReceiverTransport
from eventlet.api import spawn, sleep, with_timeout, call_after
from eventlet.coros import event
from eventlet.green import socket
DELAY=0.01
@@ -181,6 +181,30 @@ class TestGreenTransportError(TestCase):
# self.assertEqual('', self.conn.recv())
#
class TestTLSError(unittest.TestCase):
def test_server_connectionMade_never_called(self):
# trigger case when protocol instance is created,
# but it's connectionMade is never called
from gnutls.interfaces.twisted import X509Credentials
from gnutls.errors import GNUTLSError
cred = X509Credentials(None, None)
ev = event()
def handle(conn):
ev.send("handle must not be called")
s = reactor.listenTLS(0, pr.SpawnFactory(handle, LineOnlyReceiverTransport), cred)
creator = pr.GreenClientCreator(reactor, LineOnlyReceiverTransport)
try:
conn = creator.connectTLS('127.0.0.1', s.getHost().port, cred)
except GNUTLSError:
pass
assert ev.poll() is None, repr(ev.poll())
try:
import gnutls.interfaces.twisted
except ImportError:
del TestTLSError
if __name__=='__main__':
unittest.main()