twistedutil fix: SpawnFactory could spawn handler with unconnected transport (connectionMade); changed Protocol not to hold references to gtransport but send transport via _queue

This commit is contained in:
Denis Bilenko
2008-12-19 15:18:47 +06:00
parent 27693e7215
commit 765afcedb7
3 changed files with 59 additions and 25 deletions

View File

@@ -3,6 +3,7 @@ import sys
from twisted.internet.protocol import Protocol as twistedProtocol from twisted.internet.protocol import Protocol as twistedProtocol
from twisted.internet.error import ConnectionDone from twisted.internet.error import ConnectionDone
from twisted.internet.protocol import Factory, ClientFactory from twisted.internet.protocol import Factory, ClientFactory
from twisted.python import failure
from eventlet.api import spawn from eventlet.api import spawn
from eventlet.coros import queue, event from eventlet.coros import queue, event
@@ -35,7 +36,7 @@ class GreenTransportBase(object):
def build_protocol(self): def build_protocol(self):
# note to subclassers: self._queue must have send and send_exception that never block # note to subclassers: self._queue must have send and send_exception that never block
self._queue = queue() self._queue = queue()
protocol = self.protocol_class(self, self._queue) protocol = self.protocol_class(self._queue)
return protocol return protocol
def _wait(self): def _wait(self):
@@ -70,29 +71,28 @@ class GreenTransportBase(object):
if self.paused==1: if self.paused==1:
self.transport.pauseProducing() self.transport.pauseProducing()
def init_transport_producer(self, transport): def _init_transport_producer(self):
transport.pauseProducing() self.transport.pauseProducing()
self.paused = 1 self.paused = 1
def init_transport(self, transport): def _init_transport(self):
transport = self._queue.wait()
self.transport = transport
if self.transportBufferSize is not None: if self.transportBufferSize is not None:
transport.bufferSize = self.transportBufferSize transport.bufferSize = self.transportBufferSize
self.init_transport_producer(transport) self._init_transport_producer()
if self.write_event is None: if self.write_event is None:
self.write_event = event() self.write_event = event()
self.write_event.send(1) self.write_event.send(1)
transport.registerProducer(Producer2Event(self.write_event), True) transport.registerProducer(Producer2Event(self.write_event), True)
self.transport = transport
class Protocol(twistedProtocol): class Protocol(twistedProtocol):
def __init__(self, gtransport, queue): def __init__(self, queue):
self.gtransport = gtransport
self._queue = queue self._queue = queue
def connectionMade(self): def connectionMade(self):
self.gtransport.init_transport(self.transport) self._queue.send(self.transport)
del self.gtransport
def dataReceived(self, data): def dataReceived(self, data):
self._queue.send(data) self._queue.send(data)
@@ -232,7 +232,6 @@ class GreenInstanceFactory(ClientFactory):
self.event = event self.event = event
def buildProtocol(self, addr): def buildProtocol(self, addr):
self.event.send(self.instance)
return self.instance return self.instance
def clientConnectionFailed(self, connector, reason): def clientConnectionFailed(self, connector, reason):
@@ -255,31 +254,31 @@ class GreenClientCreator(object):
def _make_transport_and_factory(self): def _make_transport_and_factory(self):
gtransport = self.gtransport_class(*self.args, **self.kwargs) gtransport = self.gtransport_class(*self.args, **self.kwargs)
protocol = gtransport.build_protocol() protocol = gtransport.build_protocol()
factory = GreenInstanceFactory(protocol, event()) factory = GreenInstanceFactory(protocol, gtransport._queue)
return gtransport, factory return gtransport, factory
def connectTCP(self, host, port, *args, **kwargs): def connectTCP(self, host, port, *args, **kwargs):
gtransport, factory = self._make_transport_and_factory() gtransport, factory = self._make_transport_and_factory()
self.reactor.connectTCP(host, port, factory, *args, **kwargs) self.reactor.connectTCP(host, port, factory, *args, **kwargs)
factory.event.wait() gtransport._init_transport()
return gtransport return gtransport
def connectSSL(self, host, port, *args, **kwargs): def connectSSL(self, host, port, *args, **kwargs):
gtransport, factory = self._make_transport_and_factory() gtransport, factory = self._make_transport_and_factory()
self.reactor.connectSSL(host, port, factory, *args, **kwargs) self.reactor.connectSSL(host, port, factory, *args, **kwargs)
factory.event.wait() gtransport._init_transport()
return gtransport return gtransport
def connectTLS(self, host, port, *args, **kwargs): def connectTLS(self, host, port, *args, **kwargs):
gtransport, factory = self._make_transport_and_factory() gtransport, factory = self._make_transport_and_factory()
self.reactor.connectTLS(host, port, factory, *args, **kwargs) self.reactor.connectTLS(host, port, factory, *args, **kwargs)
factory.event.wait() gtransport._init_transport()
return gtransport return gtransport
def connectUNIX(self, address, *args, **kwargs): def connectUNIX(self, address, *args, **kwargs):
gtransport, factory = self._make_transport_and_factory() gtransport, factory = self._make_transport_and_factory()
self.reactor.connectUNIX(address, factory, *args, **kwargs) self.reactor.connectUNIX(address, factory, *args, **kwargs)
factory.event.wait() gtransport._init_transport()
return gtransport return gtransport
def connectSRV(self, service, domain, *args, **kwargs): def connectSRV(self, service, domain, *args, **kwargs):
@@ -289,7 +288,7 @@ class GreenClientCreator(object):
gtransport, factory = self._make_transport_and_factory() gtransport, factory = self._make_transport_and_factory()
c = SRVConnector(self.reactor, service, domain, factory, *args, **kwargs) c = SRVConnector(self.reactor, service, domain, factory, *args, **kwargs)
c.connect() c.connect()
factory.event.wait() gtransport._init_transport()
return gtransport return gtransport
def connect(self, required_args, ConnectorClass, *rest_args, **rest_kwargs): def connect(self, required_args, ConnectorClass, *rest_args, **rest_kwargs):
@@ -316,5 +315,18 @@ class SpawnFactory(Factory):
gtransport = self.gtransport_class(*self.args, **self.kwargs) gtransport = self.gtransport_class(*self.args, **self.kwargs)
protocol = gtransport.build_protocol() protocol = gtransport.build_protocol()
protocol.factory = self protocol.factory = self
spawn(self.handler, gtransport) spawn(self._spawn, gtransport, protocol)
return protocol return protocol
def _spawn(self, gtransport, protocol):
try:
gtransport._init_transport()
except Exception:
self._log_error(failure.Failure(), gtransport, protocol)
else:
spawn(self.handler, gtransport)
def _log_error(self, failure, gtransport, protocol):
from twisted.python import log
log.msg('%s: %s' % (protocol.transport.getPeer(), failure.getErrorMessage()))

View File

@@ -4,13 +4,11 @@ from eventlet.twistedutil.protocol import GreenTransportBase
class LineOnlyReceiver(basic.LineOnlyReceiver): class LineOnlyReceiver(basic.LineOnlyReceiver):
def __init__(self, gtransport, queue): def __init__(self, queue):
self.gtransport = gtransport
self._queue = queue self._queue = queue
def connectionMade(self): def connectionMade(self):
self.gtransport.init_transport(self.transport) self._queue.send(self.transport)
del self.gtransport
def lineReceived(self, line): def lineReceived(self, line):
self._queue.send(line) self._queue.send(line)

View File

@@ -4,12 +4,12 @@ exit_unless_twisted()
import sys import sys
import unittest import unittest
from twisted.internet.error import ConnectionLost, ConnectionDone from twisted.internet.error import ConnectionDone
from twisted.python import failure
import eventlet.twistedutil.protocol as pr import eventlet.twistedutil.protocol as pr
from eventlet.twistedutil.protocols.basic import LineOnlyReceiverTransport from eventlet.twistedutil.protocols.basic import LineOnlyReceiverTransport
from eventlet.api import spawn, sleep, with_timeout, call_after from eventlet.api import spawn, sleep, with_timeout, call_after
from eventlet.coros import event
from eventlet.green import socket from eventlet.green import socket
DELAY=0.01 DELAY=0.01
@@ -181,6 +181,30 @@ class TestGreenTransportError(TestCase):
# self.assertEqual('', self.conn.recv()) # self.assertEqual('', self.conn.recv())
# #
class TestTLSError(unittest.TestCase):
def test_server_connectionMade_never_called(self):
# trigger case when protocol instance is created,
# but it's connectionMade is never called
from gnutls.interfaces.twisted import X509Credentials
from gnutls.errors import GNUTLSError
cred = X509Credentials(None, None)
ev = event()
def handle(conn):
ev.send("handle must not be called")
s = reactor.listenTLS(0, pr.SpawnFactory(handle, LineOnlyReceiverTransport), cred)
creator = pr.GreenClientCreator(reactor, LineOnlyReceiverTransport)
try:
conn = creator.connectTLS('127.0.0.1', s.getHost().port, cred)
except GNUTLSError:
pass
assert ev.poll() is None, repr(ev.poll())
try:
import gnutls.interfaces.twisted
except ImportError:
del TestTLSError
if __name__=='__main__': if __name__=='__main__':
unittest.main() unittest.main()