Initial implemenation of using ssl module instead of pyOpenSSL.

This commit is contained in:
Ryan Williams
2009-11-21 18:02:16 -05:00
parent 87c3d5200a
commit 8e27d09a32
8 changed files with 423 additions and 58 deletions

View File

@@ -58,9 +58,8 @@ def ssl_listener(address, certificate, private_key):
spawns greenlets for each incoming connection.
"""
from eventlet import util
socket = util.wrap_ssl(util.tcp_socket(), certificate, private_key)
socket = util.wrap_ssl(util.tcp_socket(), certificate, private_key, True)
util.socket_bind_and_listen(socket, address)
socket.is_secure = True
return socket
def connect_tcp(address, localaddr=None):

295
eventlet/green/ssl.py Normal file
View File

@@ -0,0 +1,295 @@
__ssl = __import__('ssl')
for attr in dir(__ssl):
exec "%s = __ssl.%s" % (attr, attr)
import errno
import time
from eventlet.api import trampoline, getcurrent
from thread import get_ident
from eventlet.greenio import set_nonblocking, GreenSocket, GreenSSLObject, SOCKET_CLOSED, CONNECT_ERR, CONNECT_SUCCESS
orig_socket = __import__('socket')
socket = orig_socket.socket
class GreenSSLSocket(__ssl.SSLSocket):
""" This is a green version of the SSLSocket class from the ssl module added
in 2.6. For documentation on it, please see the Python standard
documentation."""
# we are inheriting from SSLSocket because its constructor calls
# do_handshake whose behavior we wish to override
def __init__(self, sock, *args, **kw):
if not isinstance(sock, GreenSocket):
sock = GreenSocket(sock)
self.act_non_blocking = sock.act_non_blocking
self.timeout = sock.timeout
super(GreenSSLSocket, self).__init__(sock.fd, *args, **kw)
del sock
# the superclass initializer trashes the methods so...
self.send = lambda data, flags=0: GreenSSLSocket.send(self, data, flags)
self.sendto = lambda data, addr, flags=0: GreenSSLSocket.sendto(self, data, addr, flags)
self.recv = lambda buflen=1024, flags=0: GreenSSLSocket.recv(self, buflen, flags)
self.recvfrom = lambda addr, buflen=1024, flags=0: GreenSSLSocket.recvfrom(self, addr, buflen, flags)
self.recv_into = lambda buffer, nbytes=None, flags=0: GreenSSLSocket.recv_into(self, buffer, nbytes, flags)
self.recvfrom_into = lambda buffer, nbytes=None, flags=0: GreenSSLSocket.recvfrom_into(self, buffer, nbytes, flags)
def settimeout(self, timeout):
self.timeout = timeout
def gettimeout(self):
return self.timeout
setblocking = GreenSocket.setblocking
def _call_trampolining(self, func, *a, **kw):
if self.act_non_blocking:
return func(*a, **kw)
else:
while True:
try:
return func(*a, **kw)
except SSLError, exc:
if exc[0] == SSL_ERROR_WANT_READ:
trampoline(self.fileno(),
read=True,
timeout=self.gettimeout(),
timeout_exc=SSLError)
elif exc[0] == SSL_ERROR_WANT_WRITE:
trampoline(self.fileno(),
write=True,
timeout=self.gettimeout(),
timeout_exc=SSLError)
else:
raise
def write(self, data):
"""Write DATA to the underlying SSL channel. Returns
number of bytes of DATA actually transmitted."""
return self._call_trampolining(
super(GreenSSLSocket, self).write, data)
def read(self, len=1024):
"""Read up to LEN bytes and return them.
Return zero-length string on EOF."""
return self._call_trampolining(
super(GreenSSLSocket, self).read,len)
def send (self, data, flags=0):
# *NOTE: gross, copied code from ssl.py becase it's not factored well enough to be used as-is
if self._sslobj:
if flags != 0:
raise ValueError(
"non-zero flags not allowed in calls to send() on %s" %
self.__class__)
while True:
try:
v = self._sslobj.write(data)
except SSLError, x:
if x.args[0] == SSL_ERROR_WANT_READ:
return 0
elif x.args[0] == SSL_ERROR_WANT_WRITE:
return 0
else:
raise
else:
return v
else:
while True:
try:
return socket.send(self, data, flags)
except orig_socket.error, e:
if self.act_non_blocking:
raise
if e[0] == errno.EWOULDBLOCK or \
e[0] == errno.ENOTCONN:
return 0
raise
def sendto (self, data, addr, flags=0):
# *NOTE: gross, copied code from ssl.py becase it's not factored well enough to be used as-is
if self._sslobj:
raise ValueError("sendto not allowed on instances of %s" %
self.__class__)
else:
trampoline(self.fileno(), write=True, timeout_exc=orig_socket.timeout)
return socket.sendto(self, data, addr, flags)
def sendall (self, data, flags=0):
# *NOTE: gross, copied code from ssl.py becase it's not factored well enough to be used as-is
if self._sslobj:
if flags != 0:
raise ValueError(
"non-zero flags not allowed in calls to sendall() on %s" %
self.__class__)
amount = len(data)
count = 0
while (count < amount):
v = self.send(data[count:])
count += v
return amount
else:
while True:
try:
return socket.sendall(self, buflen, flags)
except orig_socket.error, e:
if self.act_non_blocking:
raise
if e[0] == errno.EWOULDBLOCK:
trampoline(self.fileno(), write=True,
timeout=self.gettimeout(), timeout_exc=orig_socket.timeout)
if e[0] in SOCKET_CLOSED:
return ''
raise
def recv(self, buflen=1024, flags=0):
# *NOTE: gross, copied code from ssl.py becase it's not factored well enough to be used as-is
if self._sslobj:
if flags != 0:
raise ValueError(
"non-zero flags not allowed in calls to recv() on %s" %
self.__class__)
read = self.read(buflen)
return read
else:
while True:
try:
return socket.recv(self, buflen, flags)
except orig_socket.error, e:
if self.act_non_blocking:
raise
if e[0] == errno.EWOULDBLOCK:
trampoline(self.fileno(), read=True,
timeout=self.gettimeout(), timeout_exc=orig_socket.timeout)
if e[0] in SOCKET_CLOSED:
return ''
raise
def recv_into (self, buffer, nbytes=None, flags=0):
if not self.act_non_blocking:
trampoline(self.fileno(), read=True, timeout=self.gettimeout(), timeout_exc=orig_socket.timeout)
return super(GreenSSLSocket, self).recv_into(buffer, nbytes, flags)
def recvfrom (self, addr, buflen=1024, flags=0):
if not self.act_non_blocking:
trampoline(self.fileno(), read=True, timeout=self.gettimeout(), timeout_exc=orig_socket.timeout)
return super(GreenSSLSocket, self).recvfrom(addr, buflen, flags)
def recvfrom_into (self, buffer, nbytes=None, flags=0):
if not self.act_non_blocking:
trampoline(self.fileno(), read=True, timeout=self.gettimeout(), timeout_exc=orig_socket.timeout)
return super(GreenSSLSocket, self).recvfrom_into(buffer, nbytes, flags)
def unwrap(self):
return GreenSocket(super(GreenSSLSocket, self).unwrap())
def do_handshake(self):
"""Perform a TLS/SSL handshake."""
return self._call_trampolining(
super(GreenSSLSocket, self).do_handshake)
def _socket_connect(self, addr):
real_connect = socket.connect
if self.act_non_blocking:
return real_connect(self, addr)
else:
# *NOTE: gross, copied code from greenio because it's not factored
# well enough to reuse
if self.gettimeout() is None:
while True:
try:
return real_connect(self, addr)
except orig_socket.error, exc:
if exc[0] in CONNECT_ERR:
trampoline(self.fileno(), write=True)
elif exc[0] in CONNECT_SUCCESS:
return
else:
raise
else:
end = time.time() + self.gettimeout()
while True:
try:
real_connect(self, addr)
except orig_socket.error, exc:
if exc[0] in CONNECT_ERR:
trampoline(self.fileno(), write=True,
timeout=end-time.time(), timeout_exc=orig_socket.timeout)
elif exc[0] in CONNECT_SUCCESS:
return
else:
raise
if time.time() >= end:
raise orig_socket.timeout
def connect(self, addr):
"""Connects to remote ADDR, and then wraps the connection in
an SSL channel."""
# *NOTE: grrrrr copied this code from ssl.py because of the reference
# to socket.connect which we don't want to call directly
if self._sslobj:
raise ValueError("attempt to connect already-connected SSLSocket!")
self._socket_connect(addr)
self._sslobj = _ssl.sslwrap(self._sock, False, self.keyfile, self.certfile,
self.cert_reqs, self.ssl_version,
self.ca_certs)
if self.do_handshake_on_connect:
self.do_handshake()
def accept(self):
"""Accepts a new connection from a remote client, and returns
a tuple containing that new connection wrapped with a server-side
SSL channel, and the address of the remote client."""
# RDW grr duplication of code from greenio
if self.act_non_blocking:
newsock, addr = socket.accept(self)
else:
while True:
try:
newsock, addr = socket.accept(self)
set_nonblocking(newsock)
break
except orig_socket.error, e:
if e[0] != errno.EWOULDBLOCK:
raise
trampoline(self.fileno(), read=True, timeout=self.gettimeout(),
timeout_exc=orig_socket.timeout)
new_ssl = type(self)(newsock,
keyfile=self.keyfile,
certfile=self.certfile,
server_side=True,
cert_reqs=self.cert_reqs,
ssl_version=self.ssl_version,
ca_certs=self.ca_certs,
do_handshake_on_connect=self.do_handshake_on_connect,
suppress_ragged_eofs=self.suppress_ragged_eofs)
return (new_ssl, addr)
SSLSocket = GreenSSLSocket
def wrap_socket(sock, keyfile=None, certfile=None,
server_side=False, cert_reqs=CERT_NONE,
ssl_version=PROTOCOL_SSLv23, ca_certs=None,
do_handshake_on_connect=True,
suppress_ragged_eofs=True):
return GreenSSLSocket(sock, keyfile=keyfile, certfile=certfile,
server_side=server_side, cert_reqs=cert_reqs,
ssl_version=ssl_version, ca_certs=ca_certs,
do_handshake_on_connect=do_handshake_on_connect,
suppress_ragged_eofs=suppress_ragged_eofs)
def sslwrap_simple(sock, keyfile=None, certfile=None):
"""A replacement for the old socket.ssl function. Designed
for compability with Python 2.5 and earlier. Will disappear in
Python 3.0."""
ssl_sock = GreenSSLSocket(sock, 0, keyfile, certfile, CERT_NONE,
PROTOCOL_SSLv23, None)
return GreenSSLObject(ssl_sock)

View File

@@ -146,7 +146,6 @@ def set_nonblocking(fd):
class GreenSocket(object):
is_secure = False
timeout = None
def __init__(self, family_or_realsock=socket.AF_INET, *args, **kwargs):
if isinstance(family_or_realsock, (int, long)):
@@ -689,23 +688,59 @@ class GreenSSL(GreenSocket):
def want_write(self, *args, **kw):
fn = self.want_write = self.fd.want_write
return fn(*args, **kw)
def shutdown_safe(sock):
""" Shuts down the socket. This is a convenience method for
code that wants to gracefully handle regular sockets, SSL.Connection
sockets from PyOpenSSL and ssl.SSLSocket objects from Python 2.6
interchangeably. Both types of ssl socket require a shutdown() before
close, but they have different arity on their shutdown method.
Regular sockets don't need a shutdown before close, but it doesn't hurt.
"""
try:
try:
# socket, ssl.SSLSocket
return sock.shutdown(socket.SHUT_RDWR)
except TypeError:
# SSL.Connection
return sock.shutdown()
except socket.error, e:
# we don't care if the socket is already closed;
# this will often be the case in an http server context
if e[0] != errno.ENOTCONN:
raise
def _convert_to_sslerror(ex):
""" Transliterates SSL.SysCallErrors to socket.sslerrors"""
return socket.sslerror((ex[0], ex[1]))
class GreenSSLObject(object):
""" Wrapper object around the SSLObjects returned by socket.ssl, which have a
slightly different interface from SSL.Connection objects. """
def __init__(self, green_ssl_obj):
""" Should only be called by a 'green' socket.ssl """
assert isinstance(green_ssl_obj, GreenSSL)
try:
from eventlet.green.ssl import GreenSSLSocket
except ImportError:
class GreenSSLSocket(object):
pass
assert isinstance(green_ssl_obj, (GreenSSL, GreenSSLSocket))
self.connection = green_ssl_obj
try:
self.connection.do_handshake()
except SSL.SysCallError, e:
raise _convert_to_sslerror(e)
# if it's already connected, do the handshake
self.connection.getpeername()
except:
pass
else:
try:
self.connection.do_handshake()
except SSL.SysCallError, e:
raise _convert_to_sslerror(e)
def read(self, n=None):
"""If n is provided, read n bytes from the SSL connection, otherwise read

View File

@@ -35,39 +35,49 @@ def tcp_socket():
s = __original_socket__(socket.AF_INET, socket.SOCK_STREAM)
return s
try:
try:
import ssl
__original_ssl__ = ssl.wrap_socket
except ImportError:
__original_ssl__ = socket.ssl
except AttributeError:
__original_ssl__ = None
# if ssl is available, use eventlet.green.ssl for our ssl implementation
import ssl as _ssl
def wrap_ssl(sock, certificate=None, private_key=None, server_side=False):
from eventlet.green import ssl
return ssl.wrap_socket(sock,
keyfile=private_key, certfile=certificate,
server_side=server_side, cert_reqs=ssl.CERT_NONE,
ssl_version=ssl.PROTOCOL_SSLv23, ca_certs=None,
do_handshake_on_connect=True,
suppress_ragged_eofs=True)
def wrap_ssl(sock, certificate=None, private_key=None):
from OpenSSL import SSL
from eventlet import greenio
context = SSL.Context(SSL.SSLv23_METHOD)
if certificate is not None:
context.use_certificate_file(certificate)
if private_key is not None:
context.use_privatekey_file(private_key)
context.set_verify(SSL.VERIFY_NONE, lambda *x: True)
## TODO only do this on client sockets? how?
connection = SSL.Connection(context, sock)
connection.set_connect_state()
return greenio.GreenSSL(connection)
def wrap_ssl_obj(sock, certificate=None, private_key=None):
""" For 100% compatibility with the socket module, this wraps and handshakes an
open connection, returning a SSLObject."""
from eventlet import greenio
wrapped = wrap_ssl(sock, certificate, private_key)
return greenio.GreenSSLObject(wrapped)
def wrap_ssl_obj(sock, certificate=None, private_key=None):
from eventlet import ssl
warnings.warn("socket.ssl() is deprecated. Use ssl.wrap_socket() instead.",
DeprecationWarning, stacklevel=2)
return ssl.sslwrap_simple(sock, keyfile, certfile)
except ImportError:
# if ssl is not available, use PyOpenSSL
def wrap_ssl(sock, certificate=None, private_key=None, server_side=False):
from OpenSSL import SSL
from eventlet import greenio
context = SSL.Context(SSL.SSLv23_METHOD)
if certificate is not None:
context.use_certificate_file(certificate)
if private_key is not None:
context.use_privatekey_file(private_key)
context.set_verify(SSL.VERIFY_NONE, lambda *x: True)
connection = SSL.Connection(context, sock)
if server_side:
connection.set_accept_state()
else:
connection.set_connect_state()
return greenio.GreenSSL(connection)
def wrap_ssl_obj(sock, certificate=None, private_key=None):
""" For 100% compatibility with the socket module, this wraps and handshakes an
open connection, returning a SSLObject."""
from eventlet import greenio
wrapped = wrap_ssl(sock, certificate, private_key)
return greenio.GreenSSLObject(wrapped)
socket_already_wrapped = False
def wrap_socket_with_coroutine_socket(use_thread_pool=True):
@@ -80,8 +90,13 @@ def wrap_socket_with_coroutine_socket(use_thread_pool=True):
return greenio.GreenSocket(__original_socket__(*args, **kw))
socket.socket = new_socket
# for 100% compatibility, return a GreenSSLObject
socket.ssl = wrap_ssl_obj
socket.ssl = wrap_ssl_obj
try:
import ssl as _ssl
from eventlet.green import ssl
_ssl.wrap_socket = ssl.wrap_socket
except ImportError:
pass
if use_thread_pool:
try:
@@ -254,8 +269,7 @@ def set_reuse_addr(descriptor):
descriptor.setsockopt(
socket.SOL_SOCKET,
socket.SO_REUSEADDR,
descriptor.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR) | 1,
)
descriptor.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR) | 1)
except socket.error:
pass

View File

@@ -29,7 +29,6 @@ def format_date_time(timestamp):
_weekdayname[wd], day, _monthname[month], year, hh, mm, ss
)
class Input(object):
def __init__(self,
rfile,
@@ -349,8 +348,7 @@ class HttpProtocol(BaseHTTPServer.BaseHTTPRequestHandler):
def finish(self):
BaseHTTPServer.BaseHTTPRequestHandler.finish(self)
if self.connection.is_secure:
self.connection.shutdown()
greenio.shutdown_safe(self.connection)
self.connection.close()
@@ -429,7 +427,7 @@ def server(sock, site,
try:
host, port = sock.getsockname()
port = ':%s' % (port, )
if sock.is_secure:
if hasattr(sock, 'do_handshake'):
scheme = 'https'
if port == ':443':
port = ''
@@ -452,8 +450,7 @@ def server(sock, site,
break
finally:
try:
if sock.is_secure:
sock.shutdown()
greenio.shutdown_safe(sock)
sock.close()
except socket.error, e:
if e[0] != errno.EPIPE:

View File

@@ -65,10 +65,10 @@ class TestApi(TestCase):
try:
conn, addr = listenfd.accept()
conn.write('hello\r\n')
conn.shutdown()
greenio.shutdown_safe(conn)
conn.close()
finally:
listenfd.shutdown()
greenio.shutdown_safe(listenfd)
listenfd.close()
server = api.ssl_listener(('0.0.0.0', 0),
@@ -76,13 +76,17 @@ class TestApi(TestCase):
self.private_key_file)
api.spawn(accept_once, server)
client = util.wrap_ssl(
api.connect_tcp(('127.0.0.1', server.getsockname()[1])))
raw_client = api.connect_tcp(('127.0.0.1', server.getsockname()[1]))
client = util.wrap_ssl(raw_client)
fd = socket._fileobject(client, 'rb', 8192)
assert fd.readline() == 'hello\r\n'
self.assertRaises(greenio.SSL.ZeroReturnError, fd.read)
client.shutdown()
try:
self.assertEquals('', fd.read(10))
except greenio.SSL.ZeroReturnError:
# if it's a GreenSSL object it'll do this
pass
greenio.shutdown_safe(client)
client.close()
check_hub()

View File

@@ -275,7 +275,7 @@ class SSLTest(LimitedTestCase):
def serve(listener):
sock, addr = listener.accept()
sock.write('content')
sock.shutdown()
greenio.shutdown_safe(sock)
sock.close()
listener = api.ssl_listener(('', 0),
self.certificate_file,
@@ -285,6 +285,22 @@ class SSLTest(LimitedTestCase):
client = greenio.GreenSSLObject(client)
self.assertEquals(client.read(1024), 'content')
self.assertEquals(client.read(1024), '')
def test_ssl_close(self):
def serve(listener):
sock, addr = listener.accept()
stuff = sock.read(8192)
empt = sock.read(8192)
sock = api.ssl_listener(('127.0.0.1', 0), self.certificate_file, self.private_key_file)
server_coro = coros.execute(serve, sock)
raw_client = api.connect_tcp(('127.0.0.1', sock.getsockname()[1]))
client = util.wrap_ssl(raw_client)
client.write('X')
greenio.shutdown_safe(client)
client.close()
server_coro.wait()
if __name__ == '__main__':
main()

View File

@@ -1,10 +1,12 @@
import cgi
import os
import socket
from tests import skipped, LimitedTestCase
from unittest import main
from eventlet import api
from eventlet import util
from eventlet import greenio
from eventlet import wsgi
from eventlet import processes
@@ -368,6 +370,8 @@ class TestHttpd(LimitedTestCase):
serv.process_request(client_socket)
return True
except:
import traceback
traceback.print_exc()
return False
def wsgi_app(environ, start_response):
@@ -385,7 +389,8 @@ class TestHttpd(LimitedTestCase):
client = api.connect_tcp(('localhost', sock.getsockname()[1]))
client = util.wrap_ssl(client)
client.write('X') # non-empty payload so that SSL handshake occurs
client.shutdown()
greenio.shutdown_safe(client)
client.close()
success = server_coro.wait()
self.assert_(success)