websockets: Add websockets13 support

This commit is contained in:
Edward George
2013-12-31 10:54:11 +01:00
committed by Sergey Shepelev
parent 252aeeded5
commit 849d45682f
2 changed files with 627 additions and 34 deletions

View File

@@ -1,13 +1,19 @@
import base64
import codecs
import collections import collections
import errno import errno
from random import Random
import string import string
import struct import struct
import sys
import time
from socket import error as SocketError from socket import error as SocketError
try: try:
from hashlib import md5 from hashlib import md5, sha1
except ImportError: #pragma NO COVER except ImportError: #pragma NO COVER
from md5 import md5 from md5 import md5
from sha import sha as sha1
import eventlet import eventlet
from eventlet import semaphore from eventlet import semaphore
@@ -15,9 +21,41 @@ from eventlet import wsgi
from eventlet.green import socket from eventlet.green import socket
from eventlet.support import get_errno from eventlet.support import get_errno
# Python 2's utf8 decoding is more lenient than we'd like
# In order to pass autobahn's testsuite we need stricter validation
# if available...
for _mod in ('wsaccel.utf8validator', 'autobahn.utf8validator'):
# autobahn has it's own python-based validator. in newest versions
# this prefers to use wsaccel, a cython based implementation, if available.
# wsaccel may also be installed w/out autobahn, or with a earlier version.
try:
utf8validator = __import__(_mod, {}, {}, [''])
except ImportError:
utf8validator = None
else:
break
ACCEPTABLE_CLIENT_ERRORS = set((errno.ECONNRESET, errno.EPIPE)) ACCEPTABLE_CLIENT_ERRORS = set((errno.ECONNRESET, errno.EPIPE))
__all__ = ["WebSocketWSGI", "WebSocket"] __all__ = ["WebSocketWSGI", "WebSocket"]
PROTOCOL_GUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
VALID_CLOSE_STATUS = (range(1000, 1004)
+ range(1007, 1012)
# 3000-3999: reserved for use by libraries, frameworks,
# and applications
+ range(3000, 4000)
# 4000-4999: reserved for private use and thus can't
# be registered
+ range(4000, 5000))
class BadRequest(Exception):
def __init__(self, status='400 Bad Request', body=None, headers=None):
super(Exception, self).__init__()
self.status = status
self.body = body
self.headers = headers
class WebSocketWSGI(object): class WebSocketWSGI(object):
"""Wraps a websocket handler function in a WSGI application. """Wraps a websocket handler function in a WSGI application.
@@ -37,6 +75,26 @@ class WebSocketWSGI(object):
def __init__(self, handler): def __init__(self, handler):
self.handler = handler self.handler = handler
self.protocol_version = None self.protocol_version = None
self.support_legacy_versions = True
self.supported_protocols = []
self.origin_checker = None
@classmethod
def configured(cls,
handler=None,
supported_protocols=None,
origin_checker=None,
support_legacy_versions=False):
def decorator(handler):
inst = cls(handler)
inst.support_legacy_versions = support_legacy_versions
inst.origin_checker = origin_checker
if supported_protocols:
inst.supported_protocols = supported_protocols
return inst
if handler is None:
return decorator
return decorator(handler)
def __call__(self, environ, start_response): def __call__(self, environ, start_response):
if not (environ.get('HTTP_CONNECTION') == 'Upgrade' and if not (environ.get('HTTP_CONNECTION') == 'Upgrade' and
@@ -45,21 +103,42 @@ class WebSocketWSGI(object):
start_response('400 Bad Request', [('Connection', 'close')]) start_response('400 Bad Request', [('Connection', 'close')])
return [] return []
# See if they sent the new-format headers try:
if 'HTTP_SEC_WEBSOCKET_VERSION' in environ:
ws = self._handle_hybi_request(environ)
elif self.support_legacy_versions:
ws = self._handle_legacy_request(environ)
else:
raise BadRequest()
except BadRequest, e:
status = e.status
body = e.body or ''
headers = e.headers or []
start_response(status,
[('Connection', 'close'), ] + headers)
return [body]
try:
self.handler(ws)
except socket.error, e:
if get_errno(e) not in ACCEPTABLE_CLIENT_ERRORS:
raise
# Make sure we send the closing frame
ws._send_closing_frame(True)
# use this undocumented feature of eventlet.wsgi to ensure that it
# doesn't barf on the fact that we didn't call start_response
return wsgi.ALREADY_HANDLED
def _handle_legacy_request(self, environ):
sock = environ['eventlet.input'].get_socket()
if 'HTTP_SEC_WEBSOCKET_KEY1' in environ: if 'HTTP_SEC_WEBSOCKET_KEY1' in environ:
self.protocol_version = 76 self.protocol_version = 76
if 'HTTP_SEC_WEBSOCKET_KEY2' not in environ: if 'HTTP_SEC_WEBSOCKET_KEY2' not in environ:
# That's bad. raise BadRequest()
start_response('400 Bad Request', [('Connection','close')])
return []
else: else:
self.protocol_version = 75 self.protocol_version = 75
# Get the underlying socket and wrap a WebSocket class around it
sock = environ['eventlet.input'].get_socket()
ws = WebSocket(sock, environ, self.protocol_version)
# If it's new-version, we need to work out our challenge response
if self.protocol_version == 76: if self.protocol_version == 76:
key1 = self._extract_number(environ['HTTP_SEC_WEBSOCKET_KEY1']) key1 = self._extract_number(environ['HTTP_SEC_WEBSOCKET_KEY1'])
key2 = self._extract_number(environ['HTTP_SEC_WEBSOCKET_KEY2']) key2 = self._extract_number(environ['HTTP_SEC_WEBSOCKET_KEY2'])
@@ -105,18 +184,49 @@ class WebSocketWSGI(object):
response)) response))
else: #pragma NO COVER else: #pragma NO COVER
raise ValueError("Unknown WebSocket protocol version.") raise ValueError("Unknown WebSocket protocol version.")
sock.sendall(handshake_reply) sock.sendall(handshake_reply)
try: return WebSocket(sock, environ, self.protocol_version)
self.handler(ws)
except socket.error as e: def _handle_hybi_request(self, environ):
if get_errno(e) not in ACCEPTABLE_CLIENT_ERRORS: sock = environ['eventlet.input'].get_socket()
raise hybi_version = environ['HTTP_SEC_WEBSOCKET_VERSION']
# Make sure we send the closing frame if hybi_version not in ('8', '13', ):
ws._send_closing_frame(True) raise BadRequest(status='426 Upgrade Required',
# use this undocumented feature of eventlet.wsgi to ensure that it headers=[('Sec-WebSocket-Version', '8, 13')])
# doesn't barf on the fact that we didn't call start_response self.protocol_version = int(hybi_version)
return wsgi.ALREADY_HANDLED if 'HTTP_SEC_WEBSOCKET_KEY' not in environ:
# That's bad.
raise BadRequest()
origin = environ.get(
'HTTP_ORIGIN',
(environ.get('HTTP_SEC_WEBSOCKET_ORIGIN', '')
if self.protocol_version <= 8 else ''))
if self.origin_checker is not None:
if not self.origin_checker(environ.get('HTTP_HOST'), origin):
raise BadRequest(status='403 Forbidden')
protocols = environ.get('HTTP_SEC_WEBSOCKET_PROTOCOL', None)
negotiated_protocol = None
if protocols:
for p in (i.strip() for i in protocols.split(',')):
if p in self.supported_protocols:
negotiated_protocol = p
break
#extensions = environ.get('HTTP_SEC_WEBSOCKET_EXTENSIONS', None)
#if extensions:
# extensions = [i.strip() for i in extensions.split(',')]
key = environ['HTTP_SEC_WEBSOCKET_KEY']
response = base64.b64encode(sha1(key + PROTOCOL_GUID).digest())
handshake_reply = ["HTTP/1.1 101 Switching Protocols",
"Upgrade: websocket",
"Connection: Upgrade",
"Sec-WebSocket-Accept: %s" % (response, )]
if negotiated_protocol:
handshake_reply.append("Sec-WebSocket-Protocol: %s"
% (negotiated_protocol, ))
sock.sendall('\r\n'.join(handshake_reply) + '\r\n\r\n')
return RFC6455WebSocket(sock, environ, self.protocol_version,
protocol=negotiated_protocol)
def _extract_number(self, value): def _extract_number(self, value):
""" """
@@ -265,3 +375,279 @@ class WebSocket(object):
self.socket.shutdown(True) self.socket.shutdown(True)
self.socket.close() self.socket.close()
class ConnectionClosedError(Exception):
pass
class FailedConnectionError(Exception):
def __init__(self, status, message):
super(FailedConnectionError, self).__init__(status, message)
self.message = message
self.status = status
class ProtocolError(ValueError):
pass
class RFC6455WebSocket(WebSocket):
def __init__(self, sock, environ, version=13, protocol=None, client=False):
super(RFC6455WebSocket, self).__init__(sock, environ, version)
self.iterator = self._iter_frames()
self.client = client
self.protocol = protocol
class UTF8Decoder(object):
def __init__(self):
if utf8validator:
self.validator = utf8validator.Utf8Validator()
else:
self.validator = None
decoderclass = codecs.getincrementaldecoder('utf8')
self.decoder = decoderclass()
def reset(self):
if self.validator:
self.validator.reset()
self.decoder.reset()
def decode(self, data, final=False):
if self.validator:
valid, eocp, c_i, t_i = self.validator.validate(data)
if not valid:
raise ValueError('Data is not valid unicode')
return self.decoder.decode(data, final)
def _get_bytes(self, numbytes):
data = ''
while len(data) < numbytes:
d = self.socket.recv(numbytes - len(data))
if not d:
raise ConnectionClosedError()
data = data + d
return data
class Message(object):
def __init__(self, opcode, decoder=None):
self.decoder = decoder
self.data = []
self.finished = False
self.opcode = opcode
def push(self, data, final=False):
if self.decoder:
data = self.decoder.decode(data, final=final)
self.finished = final
self.data.append(data)
def getvalue(self):
return ''.join(self.data)
@staticmethod
def _apply_mask(data, mask, length=None, offset=0):
if length is None:
length = len(data)
cnt = xrange(length)
return ''.join(chr(ord(data[i]) ^ mask[(offset + i) % 4]) for i in cnt)
def _handle_control_frame(self, opcode, data):
if opcode == 8: # connection close
if not data:
status = 1000
elif len(data) > 1:
status = struct.unpack_from('!H', data)[0]
if not status or status not in VALID_CLOSE_STATUS:
raise FailedConnectionError(
1002,
"Unexpected close status code.")
try:
data = self.UTF8Decoder().decode(data[2:], True)
except (UnicodeDecodeError, ValueError):
raise FailedConnectionError(
1002,
"Close message data should be valid UTF-8.")
else:
status = 1002
self.close(close_data=(status, ''))
raise ConnectionClosedError()
elif opcode == 9: # ping
self.send(data, control_code=0xA)
elif opcode == 0xA: # pong
pass
else:
raise FailedConnectionError(
1002, "Unknown control frame received.")
def _iter_frames(self):
fragmented_message = None
try:
while True:
message = self._recv_frame(message=fragmented_message)
if message.opcode & 8:
self._handle_control_frame(
message.opcode, message.getvalue())
continue
if fragmented_message and message is not fragmented_message:
raise RuntimeError('Unexpected message change.')
fragmented_message = message
if message.finished:
data = fragmented_message.getvalue()
fragmented_message = None
yield data
except FailedConnectionError:
exc_typ, exc_val, exc_tb = sys.exc_info()
self.close(close_data=(exc_val.status, exc_val.message))
except ConnectionClosedError:
return
except Exception:
self.close(close_data=(1011, 'Internal Server Error'))
raise
def _recv_frame(self, message=None):
recv = self._get_bytes
header = recv(2)
a, b = struct.unpack('!BB', header)
finished = a >> 7 == 1
rsv123 = a >> 4 & 7
if rsv123:
# must be zero
raise FailedConnectionError(
1002,
"RSV1, RSV2, RSV3: MUST be 0 unless an extension is"
" negotiated that defines meanings for non-zero values.")
opcode = a & 15
if opcode not in (0, 1, 2, 8, 9, 0xA):
raise FailedConnectionError(1002, "Unknown opcode received.")
masked = b & 128 == 128
if not masked and not self.client:
raise FailedConnectionError(1002, "A client MUST mask all frames"
" that it sends to the server")
length = b & 127
if opcode & 8:
if not finished:
raise FailedConnectionError(1002, "Control frames must not"
" be fragmented.")
if length > 125:
raise FailedConnectionError(
1002,
"All control frames MUST have a payload length of 125"
" bytes or less")
elif opcode and message:
raise FailedConnectionError(
1002,
"Received a non-continuation opcode within"
" fragmented message.")
elif not opcode and not message:
raise FailedConnectionError(
1002,
"Received continuation opcode with no previous"
" fragments received.")
if length == 126:
length = struct.unpack('!H', recv(2))[0]
elif length == 127:
length = struct.unpack('!Q', recv(8))[0]
if masked:
mask = struct.unpack('!BBBB', recv(4))
received = 0
if not message or opcode & 8:
decoder = self.UTF8Decoder() if opcode == 1 else None
message = self.Message(opcode, decoder=decoder)
if not length:
message.push('', final=finished)
else:
while received < length:
d = self.socket.recv(length - received)
if not d:
raise ConnectionClosedError()
dlen = len(d)
if masked:
d = self._apply_mask(d, mask, length=dlen, offset=received)
received = received + dlen
try:
message.push(d, final=finished)
except (UnicodeDecodeError, ValueError):
raise FailedConnectionError(
1007, "Text data must be valid utf-8")
return message
@staticmethod
def _pack_message(message, masked=False,
continuation=False, final=True, control_code=None):
is_text = False
if isinstance(message, unicode):
message = message.encode('utf-8')
is_text = True
length = len(message)
if not length:
# no point masking empty data
masked = False
if control_code:
if control_code not in (8, 9, 0xA):
raise ProtocolError('Unknown control opcode.')
if continuation or not final:
raise ProtocolError('Control frame cannot be a fragment.')
if length > 125:
raise ProtocolError('Control frame data too large (>125).')
header = struct.pack('!B', control_code | 1 << 7)
else:
opcode = 0 if continuation else (1 if is_text else 2)
header = struct.pack('!B', opcode | (1 << 7 if final else 0))
lengthdata = 1 << 7 if masked else 0
if length > 65535:
lengthdata = struct.pack('!BQ', lengthdata | 127, length)
elif length > 125:
lengthdata = struct.pack('!BH', lengthdata | 126, length)
else:
lengthdata = struct.pack('!B', lengthdata | length)
if masked:
# NOTE: RFC6455 states:
# A server MUST NOT mask any frames that it sends to the client
rand = Random(time.time())
mask = map(rand.getrandbits, (8, ) * 4)
message = RFC6455WebSocket._apply_mask(message, mask, length)
maskdata = struct.pack('!BBBB', *mask)
else:
maskdata = ''
return ''.join((header, lengthdata, maskdata, message))
def wait(self):
for i in self.iterator:
return i
def _send(self, frame):
self._sendlock.acquire()
try:
self.socket.sendall(frame)
finally:
self._sendlock.release()
def send(self, message, **kw):
kw['masked'] = self.client
payload = self._pack_message(message, **kw)
self._send(payload)
def _send_closing_frame(self, ignore_send_errors=False, close_data=None):
if self.version in (8, 13) and not self.websocket_closed:
if close_data is not None:
status, msg = close_data
if isinstance(msg, unicode):
msg = msg.encode('utf-8')
data = struct.pack('!H', status) + msg
else:
data = ''
try:
self.send(data, control_code=8)
except SocketError:
# Sometimes, like when the remote side cuts off the connection,
# we don't care about this.
if not ignore_send_errors: # pragma NO COVER
raise
self.websocket_closed = True
def close(self, close_data=None):
"""Forcibly close the websocket; generally it is preferable to
return from the handler method."""
self._send_closing_frame(close_data=close_data)
self.socket.shutdown(socket.SHUT_WR)
self.socket.close()

207
tests/websocket_new_test.py Normal file
View File

@@ -0,0 +1,207 @@
import errno
import struct
import eventlet
from eventlet import event
from eventlet.green import httplib
from eventlet.green import socket
from eventlet import websocket
from tests.wsgi_test import _TestBase
# demo app
def handle(ws):
if ws.path == '/echo':
while True:
m = ws.wait()
if m is None:
break
ws.send(m)
elif ws.path == '/range':
for i in xrange(10):
ws.send("msg %d" % i)
eventlet.sleep(0.01)
elif ws.path == '/error':
# some random socket error that we shouldn't normally get
raise socket.error(errno.ENOTSOCK)
else:
ws.close()
wsapp = websocket.WebSocketWSGI(handle)
class TestWebSocket(_TestBase):
TEST_TIMEOUT = 5
def set_site(self):
self.site = wsapp
def test_incomplete_headers_13(self):
headers = dict(kv.split(': ') for kv in [
"Upgrade: websocket",
# NOTE: intentionally no connection header
"Host: localhost:%s" % self.port,
"Origin: http://localhost:%s" % self.port,
"Sec-WebSocket-Version: 13", ])
http = httplib.HTTPConnection('localhost', self.port)
http.request("GET", "/echo", headers=headers)
resp = http.getresponse()
self.assertEqual(resp.status, 400)
self.assertEqual(resp.getheader('connection'), 'close')
self.assertEqual(resp.read(), '')
# Now, miss off key
headers = dict(kv.split(': ') for kv in [
"Upgrade: websocket",
"Connection: Upgrade",
"Host: localhost:%s" % self.port,
"Origin: http://localhost:%s" % self.port,
"Sec-WebSocket-Version: 13", ])
http = httplib.HTTPConnection('localhost', self.port)
http.request("GET", "/echo", headers=headers)
resp = http.getresponse()
self.assertEqual(resp.status, 400)
self.assertEqual(resp.getheader('connection'), 'close')
self.assertEqual(resp.read(), '')
def test_correct_upgrade_request_13(self):
connect = [
"GET /echo HTTP/1.1",
"Upgrade: websocket",
"Connection: Upgrade",
"Host: localhost:%s" % self.port,
"Origin: http://localhost:%s" % self.port,
"Sec-WebSocket-Version: 13",
"Sec-WebSocket-Key: d9MXuOzlVQ0h+qRllvSCIg==", ]
sock = eventlet.connect(
('localhost', self.port))
sock.sendall('\r\n'.join(connect) + '\r\n\r\n')
result = sock.recv(1024)
## The server responds the correct Websocket handshake
self.assertEqual(result,
'\r\n'.join(['HTTP/1.1 101 Switching Protocols',
'Upgrade: websocket',
'Connection: Upgrade',
'Sec-WebSocket-Accept: ywSyWXCPNsDxLrQdQrn5RFNRfBU=\r\n\r\n', ]))
def test_send_recv_13(self):
connect = [
"GET /echo HTTP/1.1",
"Upgrade: websocket",
"Connection: Upgrade",
"Host: localhost:%s" % self.port,
"Origin: http://localhost:%s" % self.port,
"Sec-WebSocket-Version: 13",
"Sec-WebSocket-Key: d9MXuOzlVQ0h+qRllvSCIg==", ]
sock = eventlet.connect(
('localhost', self.port))
sock.sendall('\r\n'.join(connect) + '\r\n\r\n')
first_resp = sock.recv(1024)
ws = websocket.RFC6455WebSocket(sock, {}, client=True)
ws.send('hello')
assert ws.wait() == 'hello'
ws.send('hello world!\x01')
ws.send(u'hello world again!')
assert ws.wait() == 'hello world!\x01'
assert ws.wait() == u'hello world again!'
ws.close()
eventlet.sleep(0.01)
def test_breaking_the_connection_13(self):
error_detected = [False]
done_with_request = event.Event()
site = self.site
def error_detector(environ, start_response):
try:
try:
return site(environ, start_response)
except:
error_detected[0] = True
raise
finally:
done_with_request.send(True)
self.site = error_detector
self.spawn_server()
connect = [
"GET /echo HTTP/1.1",
"Upgrade: websocket",
"Connection: Upgrade",
"Host: localhost:%s" % self.port,
"Origin: http://localhost:%s" % self.port,
"Sec-WebSocket-Version: 13",
"Sec-WebSocket-Key: d9MXuOzlVQ0h+qRllvSCIg==", ]
sock = eventlet.connect(
('localhost', self.port))
sock.sendall('\r\n'.join(connect) + '\r\n\r\n')
resp = sock.recv(1024) # get the headers
sock.close() # close while the app is running
done_with_request.wait()
self.assert_(not error_detected[0])
def test_client_closing_connection_13(self):
error_detected = [False]
done_with_request = event.Event()
site = self.site
def error_detector(environ, start_response):
try:
try:
return site(environ, start_response)
except:
error_detected[0] = True
raise
finally:
done_with_request.send(True)
self.site = error_detector
self.spawn_server()
connect = [
"GET /echo HTTP/1.1",
"Upgrade: websocket",
"Connection: Upgrade",
"Host: localhost:%s" % self.port,
"Origin: http://localhost:%s" % self.port,
"Sec-WebSocket-Version: 13",
"Sec-WebSocket-Key: d9MXuOzlVQ0h+qRllvSCIg==", ]
sock = eventlet.connect(
('localhost', self.port))
sock.sendall('\r\n'.join(connect) + '\r\n\r\n')
resp = sock.recv(1024) # get the headers
closeframe = struct.pack('!BBIH', 1 << 7 | 8, 1 << 7 | 2, 0, 1000)
sock.sendall(closeframe) # "Close the connection" packet.
done_with_request.wait()
self.assert_(not error_detected[0])
def test_client_invalid_packet_13(self):
error_detected = [False]
done_with_request = event.Event()
site = self.site
def error_detector(environ, start_response):
try:
try:
return site(environ, start_response)
except:
error_detected[0] = True
raise
finally:
done_with_request.send(True)
self.site = error_detector
self.spawn_server()
connect = [
"GET /echo HTTP/1.1",
"Upgrade: websocket",
"Connection: Upgrade",
"Host: localhost:%s" % self.port,
"Origin: http://localhost:%s" % self.port,
"Sec-WebSocket-Version: 13",
"Sec-WebSocket-Key: d9MXuOzlVQ0h+qRllvSCIg==", ]
sock = eventlet.connect(
('localhost', self.port))
sock.sendall('\r\n'.join(connect) + '\r\n\r\n')
resp = sock.recv(1024) # get the headers
sock.sendall('\x07\xff') # Weird packet.
done_with_request.wait()
self.assert_(not error_detected[0])