Initial implemenation of using ssl module instead of pyOpenSSL.

This commit is contained in:
Ryan Williams
2009-11-21 18:02:16 -05:00
parent 87c3d5200a
commit 8e27d09a32
8 changed files with 423 additions and 58 deletions

View File

@@ -58,9 +58,8 @@ def ssl_listener(address, certificate, private_key):
spawns greenlets for each incoming connection. spawns greenlets for each incoming connection.
""" """
from eventlet import util from eventlet import util
socket = util.wrap_ssl(util.tcp_socket(), certificate, private_key) socket = util.wrap_ssl(util.tcp_socket(), certificate, private_key, True)
util.socket_bind_and_listen(socket, address) util.socket_bind_and_listen(socket, address)
socket.is_secure = True
return socket return socket
def connect_tcp(address, localaddr=None): def connect_tcp(address, localaddr=None):

295
eventlet/green/ssl.py Normal file
View File

@@ -0,0 +1,295 @@
__ssl = __import__('ssl')
for attr in dir(__ssl):
exec "%s = __ssl.%s" % (attr, attr)
import errno
import time
from eventlet.api import trampoline, getcurrent
from thread import get_ident
from eventlet.greenio import set_nonblocking, GreenSocket, GreenSSLObject, SOCKET_CLOSED, CONNECT_ERR, CONNECT_SUCCESS
orig_socket = __import__('socket')
socket = orig_socket.socket
class GreenSSLSocket(__ssl.SSLSocket):
""" This is a green version of the SSLSocket class from the ssl module added
in 2.6. For documentation on it, please see the Python standard
documentation."""
# we are inheriting from SSLSocket because its constructor calls
# do_handshake whose behavior we wish to override
def __init__(self, sock, *args, **kw):
if not isinstance(sock, GreenSocket):
sock = GreenSocket(sock)
self.act_non_blocking = sock.act_non_blocking
self.timeout = sock.timeout
super(GreenSSLSocket, self).__init__(sock.fd, *args, **kw)
del sock
# the superclass initializer trashes the methods so...
self.send = lambda data, flags=0: GreenSSLSocket.send(self, data, flags)
self.sendto = lambda data, addr, flags=0: GreenSSLSocket.sendto(self, data, addr, flags)
self.recv = lambda buflen=1024, flags=0: GreenSSLSocket.recv(self, buflen, flags)
self.recvfrom = lambda addr, buflen=1024, flags=0: GreenSSLSocket.recvfrom(self, addr, buflen, flags)
self.recv_into = lambda buffer, nbytes=None, flags=0: GreenSSLSocket.recv_into(self, buffer, nbytes, flags)
self.recvfrom_into = lambda buffer, nbytes=None, flags=0: GreenSSLSocket.recvfrom_into(self, buffer, nbytes, flags)
def settimeout(self, timeout):
self.timeout = timeout
def gettimeout(self):
return self.timeout
setblocking = GreenSocket.setblocking
def _call_trampolining(self, func, *a, **kw):
if self.act_non_blocking:
return func(*a, **kw)
else:
while True:
try:
return func(*a, **kw)
except SSLError, exc:
if exc[0] == SSL_ERROR_WANT_READ:
trampoline(self.fileno(),
read=True,
timeout=self.gettimeout(),
timeout_exc=SSLError)
elif exc[0] == SSL_ERROR_WANT_WRITE:
trampoline(self.fileno(),
write=True,
timeout=self.gettimeout(),
timeout_exc=SSLError)
else:
raise
def write(self, data):
"""Write DATA to the underlying SSL channel. Returns
number of bytes of DATA actually transmitted."""
return self._call_trampolining(
super(GreenSSLSocket, self).write, data)
def read(self, len=1024):
"""Read up to LEN bytes and return them.
Return zero-length string on EOF."""
return self._call_trampolining(
super(GreenSSLSocket, self).read,len)
def send (self, data, flags=0):
# *NOTE: gross, copied code from ssl.py becase it's not factored well enough to be used as-is
if self._sslobj:
if flags != 0:
raise ValueError(
"non-zero flags not allowed in calls to send() on %s" %
self.__class__)
while True:
try:
v = self._sslobj.write(data)
except SSLError, x:
if x.args[0] == SSL_ERROR_WANT_READ:
return 0
elif x.args[0] == SSL_ERROR_WANT_WRITE:
return 0
else:
raise
else:
return v
else:
while True:
try:
return socket.send(self, data, flags)
except orig_socket.error, e:
if self.act_non_blocking:
raise
if e[0] == errno.EWOULDBLOCK or \
e[0] == errno.ENOTCONN:
return 0
raise
def sendto (self, data, addr, flags=0):
# *NOTE: gross, copied code from ssl.py becase it's not factored well enough to be used as-is
if self._sslobj:
raise ValueError("sendto not allowed on instances of %s" %
self.__class__)
else:
trampoline(self.fileno(), write=True, timeout_exc=orig_socket.timeout)
return socket.sendto(self, data, addr, flags)
def sendall (self, data, flags=0):
# *NOTE: gross, copied code from ssl.py becase it's not factored well enough to be used as-is
if self._sslobj:
if flags != 0:
raise ValueError(
"non-zero flags not allowed in calls to sendall() on %s" %
self.__class__)
amount = len(data)
count = 0
while (count < amount):
v = self.send(data[count:])
count += v
return amount
else:
while True:
try:
return socket.sendall(self, buflen, flags)
except orig_socket.error, e:
if self.act_non_blocking:
raise
if e[0] == errno.EWOULDBLOCK:
trampoline(self.fileno(), write=True,
timeout=self.gettimeout(), timeout_exc=orig_socket.timeout)
if e[0] in SOCKET_CLOSED:
return ''
raise
def recv(self, buflen=1024, flags=0):
# *NOTE: gross, copied code from ssl.py becase it's not factored well enough to be used as-is
if self._sslobj:
if flags != 0:
raise ValueError(
"non-zero flags not allowed in calls to recv() on %s" %
self.__class__)
read = self.read(buflen)
return read
else:
while True:
try:
return socket.recv(self, buflen, flags)
except orig_socket.error, e:
if self.act_non_blocking:
raise
if e[0] == errno.EWOULDBLOCK:
trampoline(self.fileno(), read=True,
timeout=self.gettimeout(), timeout_exc=orig_socket.timeout)
if e[0] in SOCKET_CLOSED:
return ''
raise
def recv_into (self, buffer, nbytes=None, flags=0):
if not self.act_non_blocking:
trampoline(self.fileno(), read=True, timeout=self.gettimeout(), timeout_exc=orig_socket.timeout)
return super(GreenSSLSocket, self).recv_into(buffer, nbytes, flags)
def recvfrom (self, addr, buflen=1024, flags=0):
if not self.act_non_blocking:
trampoline(self.fileno(), read=True, timeout=self.gettimeout(), timeout_exc=orig_socket.timeout)
return super(GreenSSLSocket, self).recvfrom(addr, buflen, flags)
def recvfrom_into (self, buffer, nbytes=None, flags=0):
if not self.act_non_blocking:
trampoline(self.fileno(), read=True, timeout=self.gettimeout(), timeout_exc=orig_socket.timeout)
return super(GreenSSLSocket, self).recvfrom_into(buffer, nbytes, flags)
def unwrap(self):
return GreenSocket(super(GreenSSLSocket, self).unwrap())
def do_handshake(self):
"""Perform a TLS/SSL handshake."""
return self._call_trampolining(
super(GreenSSLSocket, self).do_handshake)
def _socket_connect(self, addr):
real_connect = socket.connect
if self.act_non_blocking:
return real_connect(self, addr)
else:
# *NOTE: gross, copied code from greenio because it's not factored
# well enough to reuse
if self.gettimeout() is None:
while True:
try:
return real_connect(self, addr)
except orig_socket.error, exc:
if exc[0] in CONNECT_ERR:
trampoline(self.fileno(), write=True)
elif exc[0] in CONNECT_SUCCESS:
return
else:
raise
else:
end = time.time() + self.gettimeout()
while True:
try:
real_connect(self, addr)
except orig_socket.error, exc:
if exc[0] in CONNECT_ERR:
trampoline(self.fileno(), write=True,
timeout=end-time.time(), timeout_exc=orig_socket.timeout)
elif exc[0] in CONNECT_SUCCESS:
return
else:
raise
if time.time() >= end:
raise orig_socket.timeout
def connect(self, addr):
"""Connects to remote ADDR, and then wraps the connection in
an SSL channel."""
# *NOTE: grrrrr copied this code from ssl.py because of the reference
# to socket.connect which we don't want to call directly
if self._sslobj:
raise ValueError("attempt to connect already-connected SSLSocket!")
self._socket_connect(addr)
self._sslobj = _ssl.sslwrap(self._sock, False, self.keyfile, self.certfile,
self.cert_reqs, self.ssl_version,
self.ca_certs)
if self.do_handshake_on_connect:
self.do_handshake()
def accept(self):
"""Accepts a new connection from a remote client, and returns
a tuple containing that new connection wrapped with a server-side
SSL channel, and the address of the remote client."""
# RDW grr duplication of code from greenio
if self.act_non_blocking:
newsock, addr = socket.accept(self)
else:
while True:
try:
newsock, addr = socket.accept(self)
set_nonblocking(newsock)
break
except orig_socket.error, e:
if e[0] != errno.EWOULDBLOCK:
raise
trampoline(self.fileno(), read=True, timeout=self.gettimeout(),
timeout_exc=orig_socket.timeout)
new_ssl = type(self)(newsock,
keyfile=self.keyfile,
certfile=self.certfile,
server_side=True,
cert_reqs=self.cert_reqs,
ssl_version=self.ssl_version,
ca_certs=self.ca_certs,
do_handshake_on_connect=self.do_handshake_on_connect,
suppress_ragged_eofs=self.suppress_ragged_eofs)
return (new_ssl, addr)
SSLSocket = GreenSSLSocket
def wrap_socket(sock, keyfile=None, certfile=None,
server_side=False, cert_reqs=CERT_NONE,
ssl_version=PROTOCOL_SSLv23, ca_certs=None,
do_handshake_on_connect=True,
suppress_ragged_eofs=True):
return GreenSSLSocket(sock, keyfile=keyfile, certfile=certfile,
server_side=server_side, cert_reqs=cert_reqs,
ssl_version=ssl_version, ca_certs=ca_certs,
do_handshake_on_connect=do_handshake_on_connect,
suppress_ragged_eofs=suppress_ragged_eofs)
def sslwrap_simple(sock, keyfile=None, certfile=None):
"""A replacement for the old socket.ssl function. Designed
for compability with Python 2.5 and earlier. Will disappear in
Python 3.0."""
ssl_sock = GreenSSLSocket(sock, 0, keyfile, certfile, CERT_NONE,
PROTOCOL_SSLv23, None)
return GreenSSLObject(ssl_sock)

View File

@@ -146,7 +146,6 @@ def set_nonblocking(fd):
class GreenSocket(object): class GreenSocket(object):
is_secure = False
timeout = None timeout = None
def __init__(self, family_or_realsock=socket.AF_INET, *args, **kwargs): def __init__(self, family_or_realsock=socket.AF_INET, *args, **kwargs):
if isinstance(family_or_realsock, (int, long)): if isinstance(family_or_realsock, (int, long)):
@@ -689,23 +688,59 @@ class GreenSSL(GreenSocket):
def want_write(self, *args, **kw): def want_write(self, *args, **kw):
fn = self.want_write = self.fd.want_write fn = self.want_write = self.fd.want_write
return fn(*args, **kw) return fn(*args, **kw)
def shutdown_safe(sock):
""" Shuts down the socket. This is a convenience method for
code that wants to gracefully handle regular sockets, SSL.Connection
sockets from PyOpenSSL and ssl.SSLSocket objects from Python 2.6
interchangeably. Both types of ssl socket require a shutdown() before
close, but they have different arity on their shutdown method.
Regular sockets don't need a shutdown before close, but it doesn't hurt.
"""
try:
try:
# socket, ssl.SSLSocket
return sock.shutdown(socket.SHUT_RDWR)
except TypeError:
# SSL.Connection
return sock.shutdown()
except socket.error, e:
# we don't care if the socket is already closed;
# this will often be the case in an http server context
if e[0] != errno.ENOTCONN:
raise
def _convert_to_sslerror(ex): def _convert_to_sslerror(ex):
""" Transliterates SSL.SysCallErrors to socket.sslerrors""" """ Transliterates SSL.SysCallErrors to socket.sslerrors"""
return socket.sslerror((ex[0], ex[1])) return socket.sslerror((ex[0], ex[1]))
class GreenSSLObject(object): class GreenSSLObject(object):
""" Wrapper object around the SSLObjects returned by socket.ssl, which have a """ Wrapper object around the SSLObjects returned by socket.ssl, which have a
slightly different interface from SSL.Connection objects. """ slightly different interface from SSL.Connection objects. """
def __init__(self, green_ssl_obj): def __init__(self, green_ssl_obj):
""" Should only be called by a 'green' socket.ssl """ """ Should only be called by a 'green' socket.ssl """
assert isinstance(green_ssl_obj, GreenSSL) try:
from eventlet.green.ssl import GreenSSLSocket
except ImportError:
class GreenSSLSocket(object):
pass
assert isinstance(green_ssl_obj, (GreenSSL, GreenSSLSocket))
self.connection = green_ssl_obj self.connection = green_ssl_obj
try: try:
self.connection.do_handshake() # if it's already connected, do the handshake
except SSL.SysCallError, e: self.connection.getpeername()
raise _convert_to_sslerror(e) except:
pass
else:
try:
self.connection.do_handshake()
except SSL.SysCallError, e:
raise _convert_to_sslerror(e)
def read(self, n=None): def read(self, n=None):
"""If n is provided, read n bytes from the SSL connection, otherwise read """If n is provided, read n bytes from the SSL connection, otherwise read

View File

@@ -35,39 +35,49 @@ def tcp_socket():
s = __original_socket__(socket.AF_INET, socket.SOCK_STREAM) s = __original_socket__(socket.AF_INET, socket.SOCK_STREAM)
return s return s
try: try:
try: # if ssl is available, use eventlet.green.ssl for our ssl implementation
import ssl import ssl as _ssl
__original_ssl__ = ssl.wrap_socket def wrap_ssl(sock, certificate=None, private_key=None, server_side=False):
except ImportError: from eventlet.green import ssl
__original_ssl__ = socket.ssl return ssl.wrap_socket(sock,
except AttributeError: keyfile=private_key, certfile=certificate,
__original_ssl__ = None server_side=server_side, cert_reqs=ssl.CERT_NONE,
ssl_version=ssl.PROTOCOL_SSLv23, ca_certs=None,
do_handshake_on_connect=True,
suppress_ragged_eofs=True)
def wrap_ssl_obj(sock, certificate=None, private_key=None):
def wrap_ssl(sock, certificate=None, private_key=None): from eventlet import ssl
from OpenSSL import SSL warnings.warn("socket.ssl() is deprecated. Use ssl.wrap_socket() instead.",
from eventlet import greenio DeprecationWarning, stacklevel=2)
context = SSL.Context(SSL.SSLv23_METHOD) return ssl.sslwrap_simple(sock, keyfile, certfile)
if certificate is not None:
context.use_certificate_file(certificate) except ImportError:
if private_key is not None: # if ssl is not available, use PyOpenSSL
context.use_privatekey_file(private_key) def wrap_ssl(sock, certificate=None, private_key=None, server_side=False):
context.set_verify(SSL.VERIFY_NONE, lambda *x: True) from OpenSSL import SSL
from eventlet import greenio
## TODO only do this on client sockets? how? context = SSL.Context(SSL.SSLv23_METHOD)
connection = SSL.Connection(context, sock) if certificate is not None:
connection.set_connect_state() context.use_certificate_file(certificate)
return greenio.GreenSSL(connection) if private_key is not None:
context.use_privatekey_file(private_key)
context.set_verify(SSL.VERIFY_NONE, lambda *x: True)
def wrap_ssl_obj(sock, certificate=None, private_key=None):
""" For 100% compatibility with the socket module, this wraps and handshakes an connection = SSL.Connection(context, sock)
open connection, returning a SSLObject.""" if server_side:
from eventlet import greenio connection.set_accept_state()
wrapped = wrap_ssl(sock, certificate, private_key) else:
return greenio.GreenSSLObject(wrapped) connection.set_connect_state()
return greenio.GreenSSL(connection)
def wrap_ssl_obj(sock, certificate=None, private_key=None):
""" For 100% compatibility with the socket module, this wraps and handshakes an
open connection, returning a SSLObject."""
from eventlet import greenio
wrapped = wrap_ssl(sock, certificate, private_key)
return greenio.GreenSSLObject(wrapped)
socket_already_wrapped = False socket_already_wrapped = False
def wrap_socket_with_coroutine_socket(use_thread_pool=True): def wrap_socket_with_coroutine_socket(use_thread_pool=True):
@@ -80,8 +90,13 @@ def wrap_socket_with_coroutine_socket(use_thread_pool=True):
return greenio.GreenSocket(__original_socket__(*args, **kw)) return greenio.GreenSocket(__original_socket__(*args, **kw))
socket.socket = new_socket socket.socket = new_socket
# for 100% compatibility, return a GreenSSLObject socket.ssl = wrap_ssl_obj
socket.ssl = wrap_ssl_obj try:
import ssl as _ssl
from eventlet.green import ssl
_ssl.wrap_socket = ssl.wrap_socket
except ImportError:
pass
if use_thread_pool: if use_thread_pool:
try: try:
@@ -254,8 +269,7 @@ def set_reuse_addr(descriptor):
descriptor.setsockopt( descriptor.setsockopt(
socket.SOL_SOCKET, socket.SOL_SOCKET,
socket.SO_REUSEADDR, socket.SO_REUSEADDR,
descriptor.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR) | 1, descriptor.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR) | 1)
)
except socket.error: except socket.error:
pass pass

View File

@@ -29,7 +29,6 @@ def format_date_time(timestamp):
_weekdayname[wd], day, _monthname[month], year, hh, mm, ss _weekdayname[wd], day, _monthname[month], year, hh, mm, ss
) )
class Input(object): class Input(object):
def __init__(self, def __init__(self,
rfile, rfile,
@@ -349,8 +348,7 @@ class HttpProtocol(BaseHTTPServer.BaseHTTPRequestHandler):
def finish(self): def finish(self):
BaseHTTPServer.BaseHTTPRequestHandler.finish(self) BaseHTTPServer.BaseHTTPRequestHandler.finish(self)
if self.connection.is_secure: greenio.shutdown_safe(self.connection)
self.connection.shutdown()
self.connection.close() self.connection.close()
@@ -429,7 +427,7 @@ def server(sock, site,
try: try:
host, port = sock.getsockname() host, port = sock.getsockname()
port = ':%s' % (port, ) port = ':%s' % (port, )
if sock.is_secure: if hasattr(sock, 'do_handshake'):
scheme = 'https' scheme = 'https'
if port == ':443': if port == ':443':
port = '' port = ''
@@ -452,8 +450,7 @@ def server(sock, site,
break break
finally: finally:
try: try:
if sock.is_secure: greenio.shutdown_safe(sock)
sock.shutdown()
sock.close() sock.close()
except socket.error, e: except socket.error, e:
if e[0] != errno.EPIPE: if e[0] != errno.EPIPE:

View File

@@ -65,10 +65,10 @@ class TestApi(TestCase):
try: try:
conn, addr = listenfd.accept() conn, addr = listenfd.accept()
conn.write('hello\r\n') conn.write('hello\r\n')
conn.shutdown() greenio.shutdown_safe(conn)
conn.close() conn.close()
finally: finally:
listenfd.shutdown() greenio.shutdown_safe(listenfd)
listenfd.close() listenfd.close()
server = api.ssl_listener(('0.0.0.0', 0), server = api.ssl_listener(('0.0.0.0', 0),
@@ -76,13 +76,17 @@ class TestApi(TestCase):
self.private_key_file) self.private_key_file)
api.spawn(accept_once, server) api.spawn(accept_once, server)
client = util.wrap_ssl( raw_client = api.connect_tcp(('127.0.0.1', server.getsockname()[1]))
api.connect_tcp(('127.0.0.1', server.getsockname()[1]))) client = util.wrap_ssl(raw_client)
fd = socket._fileobject(client, 'rb', 8192) fd = socket._fileobject(client, 'rb', 8192)
assert fd.readline() == 'hello\r\n' assert fd.readline() == 'hello\r\n'
self.assertRaises(greenio.SSL.ZeroReturnError, fd.read) try:
client.shutdown() self.assertEquals('', fd.read(10))
except greenio.SSL.ZeroReturnError:
# if it's a GreenSSL object it'll do this
pass
greenio.shutdown_safe(client)
client.close() client.close()
check_hub() check_hub()

View File

@@ -275,7 +275,7 @@ class SSLTest(LimitedTestCase):
def serve(listener): def serve(listener):
sock, addr = listener.accept() sock, addr = listener.accept()
sock.write('content') sock.write('content')
sock.shutdown() greenio.shutdown_safe(sock)
sock.close() sock.close()
listener = api.ssl_listener(('', 0), listener = api.ssl_listener(('', 0),
self.certificate_file, self.certificate_file,
@@ -285,6 +285,22 @@ class SSLTest(LimitedTestCase):
client = greenio.GreenSSLObject(client) client = greenio.GreenSSLObject(client)
self.assertEquals(client.read(1024), 'content') self.assertEquals(client.read(1024), 'content')
self.assertEquals(client.read(1024), '') self.assertEquals(client.read(1024), '')
def test_ssl_close(self):
def serve(listener):
sock, addr = listener.accept()
stuff = sock.read(8192)
empt = sock.read(8192)
sock = api.ssl_listener(('127.0.0.1', 0), self.certificate_file, self.private_key_file)
server_coro = coros.execute(serve, sock)
raw_client = api.connect_tcp(('127.0.0.1', sock.getsockname()[1]))
client = util.wrap_ssl(raw_client)
client.write('X')
greenio.shutdown_safe(client)
client.close()
server_coro.wait()
if __name__ == '__main__': if __name__ == '__main__':
main() main()

View File

@@ -1,10 +1,12 @@
import cgi import cgi
import os import os
import socket
from tests import skipped, LimitedTestCase from tests import skipped, LimitedTestCase
from unittest import main from unittest import main
from eventlet import api from eventlet import api
from eventlet import util from eventlet import util
from eventlet import greenio
from eventlet import wsgi from eventlet import wsgi
from eventlet import processes from eventlet import processes
@@ -368,6 +370,8 @@ class TestHttpd(LimitedTestCase):
serv.process_request(client_socket) serv.process_request(client_socket)
return True return True
except: except:
import traceback
traceback.print_exc()
return False return False
def wsgi_app(environ, start_response): def wsgi_app(environ, start_response):
@@ -385,7 +389,8 @@ class TestHttpd(LimitedTestCase):
client = api.connect_tcp(('localhost', sock.getsockname()[1])) client = api.connect_tcp(('localhost', sock.getsockname()[1]))
client = util.wrap_ssl(client) client = util.wrap_ssl(client)
client.write('X') # non-empty payload so that SSL handshake occurs client.write('X') # non-empty payload so that SSL handshake occurs
client.shutdown() greenio.shutdown_safe(client)
client.close()
success = server_coro.wait() success = server_coro.wait()
self.assert_(success) self.assert_(success)