From 8e27d09a3258a4f128c32b910fc19f9a56e06ed8 Mon Sep 17 00:00:00 2001 From: Ryan Williams Date: Sat, 21 Nov 2009 18:02:16 -0500 Subject: [PATCH] Initial implemenation of using ssl module instead of pyOpenSSL. --- eventlet/api.py | 3 +- eventlet/green/ssl.py | 295 ++++++++++++++++++++++++++++++++++++++++++ eventlet/greenio.py | 49 ++++++- eventlet/util.py | 84 +++++++----- eventlet/wsgi.py | 9 +- tests/api_test.py | 16 ++- tests/greenio_test.py | 18 ++- tests/wsgi_test.py | 7 +- 8 files changed, 423 insertions(+), 58 deletions(-) create mode 100644 eventlet/green/ssl.py diff --git a/eventlet/api.py b/eventlet/api.py index 480b33f..0f2d538 100644 --- a/eventlet/api.py +++ b/eventlet/api.py @@ -58,9 +58,8 @@ def ssl_listener(address, certificate, private_key): spawns greenlets for each incoming connection. """ from eventlet import util - socket = util.wrap_ssl(util.tcp_socket(), certificate, private_key) + socket = util.wrap_ssl(util.tcp_socket(), certificate, private_key, True) util.socket_bind_and_listen(socket, address) - socket.is_secure = True return socket def connect_tcp(address, localaddr=None): diff --git a/eventlet/green/ssl.py b/eventlet/green/ssl.py new file mode 100644 index 0000000..969c564 --- /dev/null +++ b/eventlet/green/ssl.py @@ -0,0 +1,295 @@ +__ssl = __import__('ssl') + +for attr in dir(__ssl): + exec "%s = __ssl.%s" % (attr, attr) + +import errno +import time + +from eventlet.api import trampoline, getcurrent +from thread import get_ident +from eventlet.greenio import set_nonblocking, GreenSocket, GreenSSLObject, SOCKET_CLOSED, CONNECT_ERR, CONNECT_SUCCESS +orig_socket = __import__('socket') +socket = orig_socket.socket + + +class GreenSSLSocket(__ssl.SSLSocket): + """ This is a green version of the SSLSocket class from the ssl module added + in 2.6. For documentation on it, please see the Python standard + documentation.""" + # we are inheriting from SSLSocket because its constructor calls + # do_handshake whose behavior we wish to override + def __init__(self, sock, *args, **kw): + if not isinstance(sock, GreenSocket): + sock = GreenSocket(sock) + + self.act_non_blocking = sock.act_non_blocking + self.timeout = sock.timeout + super(GreenSSLSocket, self).__init__(sock.fd, *args, **kw) + del sock + + # the superclass initializer trashes the methods so... + self.send = lambda data, flags=0: GreenSSLSocket.send(self, data, flags) + self.sendto = lambda data, addr, flags=0: GreenSSLSocket.sendto(self, data, addr, flags) + self.recv = lambda buflen=1024, flags=0: GreenSSLSocket.recv(self, buflen, flags) + self.recvfrom = lambda addr, buflen=1024, flags=0: GreenSSLSocket.recvfrom(self, addr, buflen, flags) + self.recv_into = lambda buffer, nbytes=None, flags=0: GreenSSLSocket.recv_into(self, buffer, nbytes, flags) + self.recvfrom_into = lambda buffer, nbytes=None, flags=0: GreenSSLSocket.recvfrom_into(self, buffer, nbytes, flags) + + def settimeout(self, timeout): + self.timeout = timeout + + def gettimeout(self): + return self.timeout + + setblocking = GreenSocket.setblocking + + def _call_trampolining(self, func, *a, **kw): + if self.act_non_blocking: + return func(*a, **kw) + else: + while True: + try: + return func(*a, **kw) + except SSLError, exc: + if exc[0] == SSL_ERROR_WANT_READ: + trampoline(self.fileno(), + read=True, + timeout=self.gettimeout(), + timeout_exc=SSLError) + elif exc[0] == SSL_ERROR_WANT_WRITE: + trampoline(self.fileno(), + write=True, + timeout=self.gettimeout(), + timeout_exc=SSLError) + else: + raise + + + def write(self, data): + """Write DATA to the underlying SSL channel. Returns + number of bytes of DATA actually transmitted.""" + return self._call_trampolining( + super(GreenSSLSocket, self).write, data) + + def read(self, len=1024): + """Read up to LEN bytes and return them. + Return zero-length string on EOF.""" + return self._call_trampolining( + super(GreenSSLSocket, self).read,len) + + def send (self, data, flags=0): + # *NOTE: gross, copied code from ssl.py becase it's not factored well enough to be used as-is + if self._sslobj: + if flags != 0: + raise ValueError( + "non-zero flags not allowed in calls to send() on %s" % + self.__class__) + while True: + try: + v = self._sslobj.write(data) + except SSLError, x: + if x.args[0] == SSL_ERROR_WANT_READ: + return 0 + elif x.args[0] == SSL_ERROR_WANT_WRITE: + return 0 + else: + raise + else: + return v + else: + while True: + try: + return socket.send(self, data, flags) + except orig_socket.error, e: + if self.act_non_blocking: + raise + if e[0] == errno.EWOULDBLOCK or \ + e[0] == errno.ENOTCONN: + return 0 + raise + + def sendto (self, data, addr, flags=0): + # *NOTE: gross, copied code from ssl.py becase it's not factored well enough to be used as-is + if self._sslobj: + raise ValueError("sendto not allowed on instances of %s" % + self.__class__) + else: + trampoline(self.fileno(), write=True, timeout_exc=orig_socket.timeout) + return socket.sendto(self, data, addr, flags) + + def sendall (self, data, flags=0): + # *NOTE: gross, copied code from ssl.py becase it's not factored well enough to be used as-is + if self._sslobj: + if flags != 0: + raise ValueError( + "non-zero flags not allowed in calls to sendall() on %s" % + self.__class__) + amount = len(data) + count = 0 + while (count < amount): + v = self.send(data[count:]) + count += v + return amount + else: + while True: + try: + return socket.sendall(self, buflen, flags) + except orig_socket.error, e: + if self.act_non_blocking: + raise + if e[0] == errno.EWOULDBLOCK: + trampoline(self.fileno(), write=True, + timeout=self.gettimeout(), timeout_exc=orig_socket.timeout) + if e[0] in SOCKET_CLOSED: + return '' + raise + + def recv(self, buflen=1024, flags=0): + # *NOTE: gross, copied code from ssl.py becase it's not factored well enough to be used as-is + if self._sslobj: + if flags != 0: + raise ValueError( + "non-zero flags not allowed in calls to recv() on %s" % + self.__class__) + read = self.read(buflen) + return read + else: + while True: + try: + return socket.recv(self, buflen, flags) + except orig_socket.error, e: + if self.act_non_blocking: + raise + if e[0] == errno.EWOULDBLOCK: + trampoline(self.fileno(), read=True, + timeout=self.gettimeout(), timeout_exc=orig_socket.timeout) + if e[0] in SOCKET_CLOSED: + return '' + raise + + + def recv_into (self, buffer, nbytes=None, flags=0): + if not self.act_non_blocking: + trampoline(self.fileno(), read=True, timeout=self.gettimeout(), timeout_exc=orig_socket.timeout) + return super(GreenSSLSocket, self).recv_into(buffer, nbytes, flags) + + def recvfrom (self, addr, buflen=1024, flags=0): + if not self.act_non_blocking: + trampoline(self.fileno(), read=True, timeout=self.gettimeout(), timeout_exc=orig_socket.timeout) + return super(GreenSSLSocket, self).recvfrom(addr, buflen, flags) + + def recvfrom_into (self, buffer, nbytes=None, flags=0): + if not self.act_non_blocking: + trampoline(self.fileno(), read=True, timeout=self.gettimeout(), timeout_exc=orig_socket.timeout) + return super(GreenSSLSocket, self).recvfrom_into(buffer, nbytes, flags) + + def unwrap(self): + return GreenSocket(super(GreenSSLSocket, self).unwrap()) + + def do_handshake(self): + """Perform a TLS/SSL handshake.""" + return self._call_trampolining( + super(GreenSSLSocket, self).do_handshake) + + def _socket_connect(self, addr): + real_connect = socket.connect + if self.act_non_blocking: + return real_connect(self, addr) + else: + # *NOTE: gross, copied code from greenio because it's not factored + # well enough to reuse + if self.gettimeout() is None: + while True: + try: + return real_connect(self, addr) + except orig_socket.error, exc: + if exc[0] in CONNECT_ERR: + trampoline(self.fileno(), write=True) + elif exc[0] in CONNECT_SUCCESS: + return + else: + raise + else: + end = time.time() + self.gettimeout() + while True: + try: + real_connect(self, addr) + except orig_socket.error, exc: + if exc[0] in CONNECT_ERR: + trampoline(self.fileno(), write=True, + timeout=end-time.time(), timeout_exc=orig_socket.timeout) + elif exc[0] in CONNECT_SUCCESS: + return + else: + raise + if time.time() >= end: + raise orig_socket.timeout + + + def connect(self, addr): + """Connects to remote ADDR, and then wraps the connection in + an SSL channel.""" + # *NOTE: grrrrr copied this code from ssl.py because of the reference + # to socket.connect which we don't want to call directly + if self._sslobj: + raise ValueError("attempt to connect already-connected SSLSocket!") + self._socket_connect(addr) + self._sslobj = _ssl.sslwrap(self._sock, False, self.keyfile, self.certfile, + self.cert_reqs, self.ssl_version, + self.ca_certs) + if self.do_handshake_on_connect: + self.do_handshake() + + def accept(self): + """Accepts a new connection from a remote client, and returns + a tuple containing that new connection wrapped with a server-side + SSL channel, and the address of the remote client.""" + # RDW grr duplication of code from greenio + if self.act_non_blocking: + newsock, addr = socket.accept(self) + else: + while True: + try: + newsock, addr = socket.accept(self) + set_nonblocking(newsock) + break + except orig_socket.error, e: + if e[0] != errno.EWOULDBLOCK: + raise + trampoline(self.fileno(), read=True, timeout=self.gettimeout(), + timeout_exc=orig_socket.timeout) + + new_ssl = type(self)(newsock, + keyfile=self.keyfile, + certfile=self.certfile, + server_side=True, + cert_reqs=self.cert_reqs, + ssl_version=self.ssl_version, + ca_certs=self.ca_certs, + do_handshake_on_connect=self.do_handshake_on_connect, + suppress_ragged_eofs=self.suppress_ragged_eofs) + return (new_ssl, addr) + + +SSLSocket = GreenSSLSocket + +def wrap_socket(sock, keyfile=None, certfile=None, + server_side=False, cert_reqs=CERT_NONE, + ssl_version=PROTOCOL_SSLv23, ca_certs=None, + do_handshake_on_connect=True, + suppress_ragged_eofs=True): + return GreenSSLSocket(sock, keyfile=keyfile, certfile=certfile, + server_side=server_side, cert_reqs=cert_reqs, + ssl_version=ssl_version, ca_certs=ca_certs, + do_handshake_on_connect=do_handshake_on_connect, + suppress_ragged_eofs=suppress_ragged_eofs) + + +def sslwrap_simple(sock, keyfile=None, certfile=None): + """A replacement for the old socket.ssl function. Designed + for compability with Python 2.5 and earlier. Will disappear in + Python 3.0.""" + ssl_sock = GreenSSLSocket(sock, 0, keyfile, certfile, CERT_NONE, + PROTOCOL_SSLv23, None) + return GreenSSLObject(ssl_sock) diff --git a/eventlet/greenio.py b/eventlet/greenio.py index 3320967..27858eb 100644 --- a/eventlet/greenio.py +++ b/eventlet/greenio.py @@ -146,7 +146,6 @@ def set_nonblocking(fd): class GreenSocket(object): - is_secure = False timeout = None def __init__(self, family_or_realsock=socket.AF_INET, *args, **kwargs): if isinstance(family_or_realsock, (int, long)): @@ -689,23 +688,59 @@ class GreenSSL(GreenSocket): def want_write(self, *args, **kw): fn = self.want_write = self.fd.want_write return fn(*args, **kw) - - + + +def shutdown_safe(sock): + """ Shuts down the socket. This is a convenience method for + code that wants to gracefully handle regular sockets, SSL.Connection + sockets from PyOpenSSL and ssl.SSLSocket objects from Python 2.6 + interchangeably. Both types of ssl socket require a shutdown() before + close, but they have different arity on their shutdown method. + + Regular sockets don't need a shutdown before close, but it doesn't hurt. + """ + try: + try: + # socket, ssl.SSLSocket + return sock.shutdown(socket.SHUT_RDWR) + except TypeError: + # SSL.Connection + return sock.shutdown() + except socket.error, e: + # we don't care if the socket is already closed; + # this will often be the case in an http server context + if e[0] != errno.ENOTCONN: + raise + + def _convert_to_sslerror(ex): """ Transliterates SSL.SysCallErrors to socket.sslerrors""" return socket.sslerror((ex[0], ex[1])) + class GreenSSLObject(object): """ Wrapper object around the SSLObjects returned by socket.ssl, which have a slightly different interface from SSL.Connection objects. """ def __init__(self, green_ssl_obj): """ Should only be called by a 'green' socket.ssl """ - assert isinstance(green_ssl_obj, GreenSSL) + try: + from eventlet.green.ssl import GreenSSLSocket + except ImportError: + class GreenSSLSocket(object): + pass + + assert isinstance(green_ssl_obj, (GreenSSL, GreenSSLSocket)) self.connection = green_ssl_obj try: - self.connection.do_handshake() - except SSL.SysCallError, e: - raise _convert_to_sslerror(e) + # if it's already connected, do the handshake + self.connection.getpeername() + except: + pass + else: + try: + self.connection.do_handshake() + except SSL.SysCallError, e: + raise _convert_to_sslerror(e) def read(self, n=None): """If n is provided, read n bytes from the SSL connection, otherwise read diff --git a/eventlet/util.py b/eventlet/util.py index 7ac9032..86d34d1 100644 --- a/eventlet/util.py +++ b/eventlet/util.py @@ -35,39 +35,49 @@ def tcp_socket(): s = __original_socket__(socket.AF_INET, socket.SOCK_STREAM) return s - try: - try: - import ssl - __original_ssl__ = ssl.wrap_socket - except ImportError: - __original_ssl__ = socket.ssl -except AttributeError: - __original_ssl__ = None + # if ssl is available, use eventlet.green.ssl for our ssl implementation + import ssl as _ssl + def wrap_ssl(sock, certificate=None, private_key=None, server_side=False): + from eventlet.green import ssl + return ssl.wrap_socket(sock, + keyfile=private_key, certfile=certificate, + server_side=server_side, cert_reqs=ssl.CERT_NONE, + ssl_version=ssl.PROTOCOL_SSLv23, ca_certs=None, + do_handshake_on_connect=True, + suppress_ragged_eofs=True) - -def wrap_ssl(sock, certificate=None, private_key=None): - from OpenSSL import SSL - from eventlet import greenio - context = SSL.Context(SSL.SSLv23_METHOD) - if certificate is not None: - context.use_certificate_file(certificate) - if private_key is not None: - context.use_privatekey_file(private_key) - context.set_verify(SSL.VERIFY_NONE, lambda *x: True) - - ## TODO only do this on client sockets? how? - connection = SSL.Connection(context, sock) - connection.set_connect_state() - return greenio.GreenSSL(connection) - - -def wrap_ssl_obj(sock, certificate=None, private_key=None): - """ For 100% compatibility with the socket module, this wraps and handshakes an - open connection, returning a SSLObject.""" - from eventlet import greenio - wrapped = wrap_ssl(sock, certificate, private_key) - return greenio.GreenSSLObject(wrapped) + def wrap_ssl_obj(sock, certificate=None, private_key=None): + from eventlet import ssl + warnings.warn("socket.ssl() is deprecated. Use ssl.wrap_socket() instead.", + DeprecationWarning, stacklevel=2) + return ssl.sslwrap_simple(sock, keyfile, certfile) + +except ImportError: + # if ssl is not available, use PyOpenSSL + def wrap_ssl(sock, certificate=None, private_key=None, server_side=False): + from OpenSSL import SSL + from eventlet import greenio + context = SSL.Context(SSL.SSLv23_METHOD) + if certificate is not None: + context.use_certificate_file(certificate) + if private_key is not None: + context.use_privatekey_file(private_key) + context.set_verify(SSL.VERIFY_NONE, lambda *x: True) + + connection = SSL.Connection(context, sock) + if server_side: + connection.set_accept_state() + else: + connection.set_connect_state() + return greenio.GreenSSL(connection) + + def wrap_ssl_obj(sock, certificate=None, private_key=None): + """ For 100% compatibility with the socket module, this wraps and handshakes an + open connection, returning a SSLObject.""" + from eventlet import greenio + wrapped = wrap_ssl(sock, certificate, private_key) + return greenio.GreenSSLObject(wrapped) socket_already_wrapped = False def wrap_socket_with_coroutine_socket(use_thread_pool=True): @@ -80,8 +90,13 @@ def wrap_socket_with_coroutine_socket(use_thread_pool=True): return greenio.GreenSocket(__original_socket__(*args, **kw)) socket.socket = new_socket - # for 100% compatibility, return a GreenSSLObject - socket.ssl = wrap_ssl_obj + socket.ssl = wrap_ssl_obj + try: + import ssl as _ssl + from eventlet.green import ssl + _ssl.wrap_socket = ssl.wrap_socket + except ImportError: + pass if use_thread_pool: try: @@ -254,8 +269,7 @@ def set_reuse_addr(descriptor): descriptor.setsockopt( socket.SOL_SOCKET, socket.SO_REUSEADDR, - descriptor.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR) | 1, - ) + descriptor.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR) | 1) except socket.error: pass diff --git a/eventlet/wsgi.py b/eventlet/wsgi.py index b1cf368..eea2786 100644 --- a/eventlet/wsgi.py +++ b/eventlet/wsgi.py @@ -29,7 +29,6 @@ def format_date_time(timestamp): _weekdayname[wd], day, _monthname[month], year, hh, mm, ss ) - class Input(object): def __init__(self, rfile, @@ -349,8 +348,7 @@ class HttpProtocol(BaseHTTPServer.BaseHTTPRequestHandler): def finish(self): BaseHTTPServer.BaseHTTPRequestHandler.finish(self) - if self.connection.is_secure: - self.connection.shutdown() + greenio.shutdown_safe(self.connection) self.connection.close() @@ -429,7 +427,7 @@ def server(sock, site, try: host, port = sock.getsockname() port = ':%s' % (port, ) - if sock.is_secure: + if hasattr(sock, 'do_handshake'): scheme = 'https' if port == ':443': port = '' @@ -452,8 +450,7 @@ def server(sock, site, break finally: try: - if sock.is_secure: - sock.shutdown() + greenio.shutdown_safe(sock) sock.close() except socket.error, e: if e[0] != errno.EPIPE: diff --git a/tests/api_test.py b/tests/api_test.py index 9a49745..9bf6258 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -65,10 +65,10 @@ class TestApi(TestCase): try: conn, addr = listenfd.accept() conn.write('hello\r\n') - conn.shutdown() + greenio.shutdown_safe(conn) conn.close() finally: - listenfd.shutdown() + greenio.shutdown_safe(listenfd) listenfd.close() server = api.ssl_listener(('0.0.0.0', 0), @@ -76,13 +76,17 @@ class TestApi(TestCase): self.private_key_file) api.spawn(accept_once, server) - client = util.wrap_ssl( - api.connect_tcp(('127.0.0.1', server.getsockname()[1]))) + raw_client = api.connect_tcp(('127.0.0.1', server.getsockname()[1])) + client = util.wrap_ssl(raw_client) fd = socket._fileobject(client, 'rb', 8192) assert fd.readline() == 'hello\r\n' - self.assertRaises(greenio.SSL.ZeroReturnError, fd.read) - client.shutdown() + try: + self.assertEquals('', fd.read(10)) + except greenio.SSL.ZeroReturnError: + # if it's a GreenSSL object it'll do this + pass + greenio.shutdown_safe(client) client.close() check_hub() diff --git a/tests/greenio_test.py b/tests/greenio_test.py index 64863dc..38777c3 100644 --- a/tests/greenio_test.py +++ b/tests/greenio_test.py @@ -275,7 +275,7 @@ class SSLTest(LimitedTestCase): def serve(listener): sock, addr = listener.accept() sock.write('content') - sock.shutdown() + greenio.shutdown_safe(sock) sock.close() listener = api.ssl_listener(('', 0), self.certificate_file, @@ -285,6 +285,22 @@ class SSLTest(LimitedTestCase): client = greenio.GreenSSLObject(client) self.assertEquals(client.read(1024), 'content') self.assertEquals(client.read(1024), '') + + def test_ssl_close(self): + def serve(listener): + sock, addr = listener.accept() + stuff = sock.read(8192) + empt = sock.read(8192) + + sock = api.ssl_listener(('127.0.0.1', 0), self.certificate_file, self.private_key_file) + server_coro = coros.execute(serve, sock) + + raw_client = api.connect_tcp(('127.0.0.1', sock.getsockname()[1])) + client = util.wrap_ssl(raw_client) + client.write('X') + greenio.shutdown_safe(client) + client.close() + server_coro.wait() if __name__ == '__main__': main() diff --git a/tests/wsgi_test.py b/tests/wsgi_test.py index e50f7a0..bbc9631 100644 --- a/tests/wsgi_test.py +++ b/tests/wsgi_test.py @@ -1,10 +1,12 @@ import cgi import os +import socket from tests import skipped, LimitedTestCase from unittest import main from eventlet import api from eventlet import util +from eventlet import greenio from eventlet import wsgi from eventlet import processes @@ -368,6 +370,8 @@ class TestHttpd(LimitedTestCase): serv.process_request(client_socket) return True except: + import traceback + traceback.print_exc() return False def wsgi_app(environ, start_response): @@ -385,7 +389,8 @@ class TestHttpd(LimitedTestCase): client = api.connect_tcp(('localhost', sock.getsockname()[1])) client = util.wrap_ssl(client) client.write('X') # non-empty payload so that SSL handshake occurs - client.shutdown() + greenio.shutdown_safe(client) + client.close() success = server_coro.wait() self.assert_(success)