websockets: Add websockets13 support
This commit is contained in:

committed by
Sergey Shepelev

parent
252aeeded5
commit
849d45682f
@@ -1,13 +1,19 @@
|
|||||||
|
import base64
|
||||||
|
import codecs
|
||||||
import collections
|
import collections
|
||||||
import errno
|
import errno
|
||||||
|
from random import Random
|
||||||
import string
|
import string
|
||||||
import struct
|
import struct
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
from socket import error as SocketError
|
from socket import error as SocketError
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from hashlib import md5
|
from hashlib import md5, sha1
|
||||||
except ImportError: #pragma NO COVER
|
except ImportError: #pragma NO COVER
|
||||||
from md5 import md5
|
from md5 import md5
|
||||||
|
from sha import sha as sha1
|
||||||
|
|
||||||
import eventlet
|
import eventlet
|
||||||
from eventlet import semaphore
|
from eventlet import semaphore
|
||||||
@@ -15,9 +21,41 @@ from eventlet import wsgi
|
|||||||
from eventlet.green import socket
|
from eventlet.green import socket
|
||||||
from eventlet.support import get_errno
|
from eventlet.support import get_errno
|
||||||
|
|
||||||
|
# Python 2's utf8 decoding is more lenient than we'd like
|
||||||
|
# In order to pass autobahn's testsuite we need stricter validation
|
||||||
|
# if available...
|
||||||
|
for _mod in ('wsaccel.utf8validator', 'autobahn.utf8validator'):
|
||||||
|
# autobahn has it's own python-based validator. in newest versions
|
||||||
|
# this prefers to use wsaccel, a cython based implementation, if available.
|
||||||
|
# wsaccel may also be installed w/out autobahn, or with a earlier version.
|
||||||
|
try:
|
||||||
|
utf8validator = __import__(_mod, {}, {}, [''])
|
||||||
|
except ImportError:
|
||||||
|
utf8validator = None
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
ACCEPTABLE_CLIENT_ERRORS = set((errno.ECONNRESET, errno.EPIPE))
|
ACCEPTABLE_CLIENT_ERRORS = set((errno.ECONNRESET, errno.EPIPE))
|
||||||
|
|
||||||
__all__ = ["WebSocketWSGI", "WebSocket"]
|
__all__ = ["WebSocketWSGI", "WebSocket"]
|
||||||
|
PROTOCOL_GUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
|
||||||
|
VALID_CLOSE_STATUS = (range(1000, 1004)
|
||||||
|
+ range(1007, 1012)
|
||||||
|
# 3000-3999: reserved for use by libraries, frameworks,
|
||||||
|
# and applications
|
||||||
|
+ range(3000, 4000)
|
||||||
|
# 4000-4999: reserved for private use and thus can't
|
||||||
|
# be registered
|
||||||
|
+ range(4000, 5000))
|
||||||
|
|
||||||
|
|
||||||
|
class BadRequest(Exception):
|
||||||
|
def __init__(self, status='400 Bad Request', body=None, headers=None):
|
||||||
|
super(Exception, self).__init__()
|
||||||
|
self.status = status
|
||||||
|
self.body = body
|
||||||
|
self.headers = headers
|
||||||
|
|
||||||
|
|
||||||
class WebSocketWSGI(object):
|
class WebSocketWSGI(object):
|
||||||
"""Wraps a websocket handler function in a WSGI application.
|
"""Wraps a websocket handler function in a WSGI application.
|
||||||
@@ -37,29 +75,70 @@ class WebSocketWSGI(object):
|
|||||||
def __init__(self, handler):
|
def __init__(self, handler):
|
||||||
self.handler = handler
|
self.handler = handler
|
||||||
self.protocol_version = None
|
self.protocol_version = None
|
||||||
|
self.support_legacy_versions = True
|
||||||
|
self.supported_protocols = []
|
||||||
|
self.origin_checker = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def configured(cls,
|
||||||
|
handler=None,
|
||||||
|
supported_protocols=None,
|
||||||
|
origin_checker=None,
|
||||||
|
support_legacy_versions=False):
|
||||||
|
def decorator(handler):
|
||||||
|
inst = cls(handler)
|
||||||
|
inst.support_legacy_versions = support_legacy_versions
|
||||||
|
inst.origin_checker = origin_checker
|
||||||
|
if supported_protocols:
|
||||||
|
inst.supported_protocols = supported_protocols
|
||||||
|
return inst
|
||||||
|
if handler is None:
|
||||||
|
return decorator
|
||||||
|
return decorator(handler)
|
||||||
|
|
||||||
def __call__(self, environ, start_response):
|
def __call__(self, environ, start_response):
|
||||||
if not (environ.get('HTTP_CONNECTION') == 'Upgrade' and
|
if not (environ.get('HTTP_CONNECTION') == 'Upgrade' and
|
||||||
environ.get('HTTP_UPGRADE').lower() == 'websocket'):
|
environ.get('HTTP_UPGRADE').lower() == 'websocket'):
|
||||||
# need to check a few more things here for true compliance
|
# need to check a few more things here for true compliance
|
||||||
start_response('400 Bad Request', [('Connection','close')])
|
start_response('400 Bad Request', [('Connection', 'close')])
|
||||||
return []
|
return []
|
||||||
|
|
||||||
# See if they sent the new-format headers
|
try:
|
||||||
|
if 'HTTP_SEC_WEBSOCKET_VERSION' in environ:
|
||||||
|
ws = self._handle_hybi_request(environ)
|
||||||
|
elif self.support_legacy_versions:
|
||||||
|
ws = self._handle_legacy_request(environ)
|
||||||
|
else:
|
||||||
|
raise BadRequest()
|
||||||
|
except BadRequest, e:
|
||||||
|
status = e.status
|
||||||
|
body = e.body or ''
|
||||||
|
headers = e.headers or []
|
||||||
|
start_response(status,
|
||||||
|
[('Connection', 'close'), ] + headers)
|
||||||
|
return [body]
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.handler(ws)
|
||||||
|
except socket.error, e:
|
||||||
|
if get_errno(e) not in ACCEPTABLE_CLIENT_ERRORS:
|
||||||
|
raise
|
||||||
|
# Make sure we send the closing frame
|
||||||
|
ws._send_closing_frame(True)
|
||||||
|
# use this undocumented feature of eventlet.wsgi to ensure that it
|
||||||
|
# doesn't barf on the fact that we didn't call start_response
|
||||||
|
return wsgi.ALREADY_HANDLED
|
||||||
|
|
||||||
|
def _handle_legacy_request(self, environ):
|
||||||
|
sock = environ['eventlet.input'].get_socket()
|
||||||
|
|
||||||
if 'HTTP_SEC_WEBSOCKET_KEY1' in environ:
|
if 'HTTP_SEC_WEBSOCKET_KEY1' in environ:
|
||||||
self.protocol_version = 76
|
self.protocol_version = 76
|
||||||
if 'HTTP_SEC_WEBSOCKET_KEY2' not in environ:
|
if 'HTTP_SEC_WEBSOCKET_KEY2' not in environ:
|
||||||
# That's bad.
|
raise BadRequest()
|
||||||
start_response('400 Bad Request', [('Connection','close')])
|
|
||||||
return []
|
|
||||||
else:
|
else:
|
||||||
self.protocol_version = 75
|
self.protocol_version = 75
|
||||||
|
|
||||||
# Get the underlying socket and wrap a WebSocket class around it
|
|
||||||
sock = environ['eventlet.input'].get_socket()
|
|
||||||
ws = WebSocket(sock, environ, self.protocol_version)
|
|
||||||
|
|
||||||
# If it's new-version, we need to work out our challenge response
|
|
||||||
if self.protocol_version == 76:
|
if self.protocol_version == 76:
|
||||||
key1 = self._extract_number(environ['HTTP_SEC_WEBSOCKET_KEY1'])
|
key1 = self._extract_number(environ['HTTP_SEC_WEBSOCKET_KEY1'])
|
||||||
key2 = self._extract_number(environ['HTTP_SEC_WEBSOCKET_KEY2'])
|
key2 = self._extract_number(environ['HTTP_SEC_WEBSOCKET_KEY2'])
|
||||||
@@ -98,25 +177,56 @@ class WebSocketWSGI(object):
|
|||||||
"Sec-WebSocket-Origin: %s\r\n"
|
"Sec-WebSocket-Origin: %s\r\n"
|
||||||
"Sec-WebSocket-Protocol: %s\r\n"
|
"Sec-WebSocket-Protocol: %s\r\n"
|
||||||
"Sec-WebSocket-Location: %s\r\n"
|
"Sec-WebSocket-Location: %s\r\n"
|
||||||
"\r\n%s"% (
|
"\r\n%s" % (
|
||||||
environ.get('HTTP_ORIGIN'),
|
environ.get('HTTP_ORIGIN'),
|
||||||
environ.get('HTTP_SEC_WEBSOCKET_PROTOCOL', 'default'),
|
environ.get('HTTP_SEC_WEBSOCKET_PROTOCOL', 'default'),
|
||||||
location,
|
location,
|
||||||
response))
|
response))
|
||||||
else: #pragma NO COVER
|
else: #pragma NO COVER
|
||||||
raise ValueError("Unknown WebSocket protocol version.")
|
raise ValueError("Unknown WebSocket protocol version.")
|
||||||
|
|
||||||
sock.sendall(handshake_reply)
|
sock.sendall(handshake_reply)
|
||||||
try:
|
return WebSocket(sock, environ, self.protocol_version)
|
||||||
self.handler(ws)
|
|
||||||
except socket.error as e:
|
def _handle_hybi_request(self, environ):
|
||||||
if get_errno(e) not in ACCEPTABLE_CLIENT_ERRORS:
|
sock = environ['eventlet.input'].get_socket()
|
||||||
raise
|
hybi_version = environ['HTTP_SEC_WEBSOCKET_VERSION']
|
||||||
# Make sure we send the closing frame
|
if hybi_version not in ('8', '13', ):
|
||||||
ws._send_closing_frame(True)
|
raise BadRequest(status='426 Upgrade Required',
|
||||||
# use this undocumented feature of eventlet.wsgi to ensure that it
|
headers=[('Sec-WebSocket-Version', '8, 13')])
|
||||||
# doesn't barf on the fact that we didn't call start_response
|
self.protocol_version = int(hybi_version)
|
||||||
return wsgi.ALREADY_HANDLED
|
if 'HTTP_SEC_WEBSOCKET_KEY' not in environ:
|
||||||
|
# That's bad.
|
||||||
|
raise BadRequest()
|
||||||
|
origin = environ.get(
|
||||||
|
'HTTP_ORIGIN',
|
||||||
|
(environ.get('HTTP_SEC_WEBSOCKET_ORIGIN', '')
|
||||||
|
if self.protocol_version <= 8 else ''))
|
||||||
|
if self.origin_checker is not None:
|
||||||
|
if not self.origin_checker(environ.get('HTTP_HOST'), origin):
|
||||||
|
raise BadRequest(status='403 Forbidden')
|
||||||
|
protocols = environ.get('HTTP_SEC_WEBSOCKET_PROTOCOL', None)
|
||||||
|
negotiated_protocol = None
|
||||||
|
if protocols:
|
||||||
|
for p in (i.strip() for i in protocols.split(',')):
|
||||||
|
if p in self.supported_protocols:
|
||||||
|
negotiated_protocol = p
|
||||||
|
break
|
||||||
|
#extensions = environ.get('HTTP_SEC_WEBSOCKET_EXTENSIONS', None)
|
||||||
|
#if extensions:
|
||||||
|
# extensions = [i.strip() for i in extensions.split(',')]
|
||||||
|
|
||||||
|
key = environ['HTTP_SEC_WEBSOCKET_KEY']
|
||||||
|
response = base64.b64encode(sha1(key + PROTOCOL_GUID).digest())
|
||||||
|
handshake_reply = ["HTTP/1.1 101 Switching Protocols",
|
||||||
|
"Upgrade: websocket",
|
||||||
|
"Connection: Upgrade",
|
||||||
|
"Sec-WebSocket-Accept: %s" % (response, )]
|
||||||
|
if negotiated_protocol:
|
||||||
|
handshake_reply.append("Sec-WebSocket-Protocol: %s"
|
||||||
|
% (negotiated_protocol, ))
|
||||||
|
sock.sendall('\r\n'.join(handshake_reply) + '\r\n\r\n')
|
||||||
|
return RFC6455WebSocket(sock, environ, self.protocol_version,
|
||||||
|
protocol=negotiated_protocol)
|
||||||
|
|
||||||
def _extract_number(self, value):
|
def _extract_number(self, value):
|
||||||
"""
|
"""
|
||||||
@@ -265,3 +375,279 @@ class WebSocket(object):
|
|||||||
self.socket.shutdown(True)
|
self.socket.shutdown(True)
|
||||||
self.socket.close()
|
self.socket.close()
|
||||||
|
|
||||||
|
|
||||||
|
class ConnectionClosedError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class FailedConnectionError(Exception):
|
||||||
|
def __init__(self, status, message):
|
||||||
|
super(FailedConnectionError, self).__init__(status, message)
|
||||||
|
self.message = message
|
||||||
|
self.status = status
|
||||||
|
|
||||||
|
|
||||||
|
class ProtocolError(ValueError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class RFC6455WebSocket(WebSocket):
|
||||||
|
def __init__(self, sock, environ, version=13, protocol=None, client=False):
|
||||||
|
super(RFC6455WebSocket, self).__init__(sock, environ, version)
|
||||||
|
self.iterator = self._iter_frames()
|
||||||
|
self.client = client
|
||||||
|
self.protocol = protocol
|
||||||
|
|
||||||
|
class UTF8Decoder(object):
|
||||||
|
def __init__(self):
|
||||||
|
if utf8validator:
|
||||||
|
self.validator = utf8validator.Utf8Validator()
|
||||||
|
else:
|
||||||
|
self.validator = None
|
||||||
|
decoderclass = codecs.getincrementaldecoder('utf8')
|
||||||
|
self.decoder = decoderclass()
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
if self.validator:
|
||||||
|
self.validator.reset()
|
||||||
|
self.decoder.reset()
|
||||||
|
|
||||||
|
def decode(self, data, final=False):
|
||||||
|
if self.validator:
|
||||||
|
valid, eocp, c_i, t_i = self.validator.validate(data)
|
||||||
|
if not valid:
|
||||||
|
raise ValueError('Data is not valid unicode')
|
||||||
|
return self.decoder.decode(data, final)
|
||||||
|
|
||||||
|
def _get_bytes(self, numbytes):
|
||||||
|
data = ''
|
||||||
|
while len(data) < numbytes:
|
||||||
|
d = self.socket.recv(numbytes - len(data))
|
||||||
|
if not d:
|
||||||
|
raise ConnectionClosedError()
|
||||||
|
data = data + d
|
||||||
|
return data
|
||||||
|
|
||||||
|
class Message(object):
|
||||||
|
def __init__(self, opcode, decoder=None):
|
||||||
|
self.decoder = decoder
|
||||||
|
self.data = []
|
||||||
|
self.finished = False
|
||||||
|
self.opcode = opcode
|
||||||
|
|
||||||
|
def push(self, data, final=False):
|
||||||
|
if self.decoder:
|
||||||
|
data = self.decoder.decode(data, final=final)
|
||||||
|
self.finished = final
|
||||||
|
self.data.append(data)
|
||||||
|
|
||||||
|
def getvalue(self):
|
||||||
|
return ''.join(self.data)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _apply_mask(data, mask, length=None, offset=0):
|
||||||
|
if length is None:
|
||||||
|
length = len(data)
|
||||||
|
cnt = xrange(length)
|
||||||
|
return ''.join(chr(ord(data[i]) ^ mask[(offset + i) % 4]) for i in cnt)
|
||||||
|
|
||||||
|
def _handle_control_frame(self, opcode, data):
|
||||||
|
if opcode == 8: # connection close
|
||||||
|
if not data:
|
||||||
|
status = 1000
|
||||||
|
elif len(data) > 1:
|
||||||
|
status = struct.unpack_from('!H', data)[0]
|
||||||
|
if not status or status not in VALID_CLOSE_STATUS:
|
||||||
|
raise FailedConnectionError(
|
||||||
|
1002,
|
||||||
|
"Unexpected close status code.")
|
||||||
|
try:
|
||||||
|
data = self.UTF8Decoder().decode(data[2:], True)
|
||||||
|
except (UnicodeDecodeError, ValueError):
|
||||||
|
raise FailedConnectionError(
|
||||||
|
1002,
|
||||||
|
"Close message data should be valid UTF-8.")
|
||||||
|
else:
|
||||||
|
status = 1002
|
||||||
|
self.close(close_data=(status, ''))
|
||||||
|
raise ConnectionClosedError()
|
||||||
|
elif opcode == 9: # ping
|
||||||
|
self.send(data, control_code=0xA)
|
||||||
|
elif opcode == 0xA: # pong
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
raise FailedConnectionError(
|
||||||
|
1002, "Unknown control frame received.")
|
||||||
|
|
||||||
|
def _iter_frames(self):
|
||||||
|
fragmented_message = None
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
message = self._recv_frame(message=fragmented_message)
|
||||||
|
if message.opcode & 8:
|
||||||
|
self._handle_control_frame(
|
||||||
|
message.opcode, message.getvalue())
|
||||||
|
continue
|
||||||
|
if fragmented_message and message is not fragmented_message:
|
||||||
|
raise RuntimeError('Unexpected message change.')
|
||||||
|
fragmented_message = message
|
||||||
|
if message.finished:
|
||||||
|
data = fragmented_message.getvalue()
|
||||||
|
fragmented_message = None
|
||||||
|
yield data
|
||||||
|
except FailedConnectionError:
|
||||||
|
exc_typ, exc_val, exc_tb = sys.exc_info()
|
||||||
|
self.close(close_data=(exc_val.status, exc_val.message))
|
||||||
|
except ConnectionClosedError:
|
||||||
|
return
|
||||||
|
except Exception:
|
||||||
|
self.close(close_data=(1011, 'Internal Server Error'))
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _recv_frame(self, message=None):
|
||||||
|
recv = self._get_bytes
|
||||||
|
header = recv(2)
|
||||||
|
a, b = struct.unpack('!BB', header)
|
||||||
|
finished = a >> 7 == 1
|
||||||
|
rsv123 = a >> 4 & 7
|
||||||
|
if rsv123:
|
||||||
|
# must be zero
|
||||||
|
raise FailedConnectionError(
|
||||||
|
1002,
|
||||||
|
"RSV1, RSV2, RSV3: MUST be 0 unless an extension is"
|
||||||
|
" negotiated that defines meanings for non-zero values.")
|
||||||
|
opcode = a & 15
|
||||||
|
if opcode not in (0, 1, 2, 8, 9, 0xA):
|
||||||
|
raise FailedConnectionError(1002, "Unknown opcode received.")
|
||||||
|
masked = b & 128 == 128
|
||||||
|
if not masked and not self.client:
|
||||||
|
raise FailedConnectionError(1002, "A client MUST mask all frames"
|
||||||
|
" that it sends to the server")
|
||||||
|
length = b & 127
|
||||||
|
if opcode & 8:
|
||||||
|
if not finished:
|
||||||
|
raise FailedConnectionError(1002, "Control frames must not"
|
||||||
|
" be fragmented.")
|
||||||
|
if length > 125:
|
||||||
|
raise FailedConnectionError(
|
||||||
|
1002,
|
||||||
|
"All control frames MUST have a payload length of 125"
|
||||||
|
" bytes or less")
|
||||||
|
elif opcode and message:
|
||||||
|
raise FailedConnectionError(
|
||||||
|
1002,
|
||||||
|
"Received a non-continuation opcode within"
|
||||||
|
" fragmented message.")
|
||||||
|
elif not opcode and not message:
|
||||||
|
raise FailedConnectionError(
|
||||||
|
1002,
|
||||||
|
"Received continuation opcode with no previous"
|
||||||
|
" fragments received.")
|
||||||
|
if length == 126:
|
||||||
|
length = struct.unpack('!H', recv(2))[0]
|
||||||
|
elif length == 127:
|
||||||
|
length = struct.unpack('!Q', recv(8))[0]
|
||||||
|
if masked:
|
||||||
|
mask = struct.unpack('!BBBB', recv(4))
|
||||||
|
received = 0
|
||||||
|
if not message or opcode & 8:
|
||||||
|
decoder = self.UTF8Decoder() if opcode == 1 else None
|
||||||
|
message = self.Message(opcode, decoder=decoder)
|
||||||
|
if not length:
|
||||||
|
message.push('', final=finished)
|
||||||
|
else:
|
||||||
|
while received < length:
|
||||||
|
d = self.socket.recv(length - received)
|
||||||
|
if not d:
|
||||||
|
raise ConnectionClosedError()
|
||||||
|
dlen = len(d)
|
||||||
|
if masked:
|
||||||
|
d = self._apply_mask(d, mask, length=dlen, offset=received)
|
||||||
|
received = received + dlen
|
||||||
|
try:
|
||||||
|
message.push(d, final=finished)
|
||||||
|
except (UnicodeDecodeError, ValueError):
|
||||||
|
raise FailedConnectionError(
|
||||||
|
1007, "Text data must be valid utf-8")
|
||||||
|
return message
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _pack_message(message, masked=False,
|
||||||
|
continuation=False, final=True, control_code=None):
|
||||||
|
is_text = False
|
||||||
|
if isinstance(message, unicode):
|
||||||
|
message = message.encode('utf-8')
|
||||||
|
is_text = True
|
||||||
|
length = len(message)
|
||||||
|
if not length:
|
||||||
|
# no point masking empty data
|
||||||
|
masked = False
|
||||||
|
if control_code:
|
||||||
|
if control_code not in (8, 9, 0xA):
|
||||||
|
raise ProtocolError('Unknown control opcode.')
|
||||||
|
if continuation or not final:
|
||||||
|
raise ProtocolError('Control frame cannot be a fragment.')
|
||||||
|
if length > 125:
|
||||||
|
raise ProtocolError('Control frame data too large (>125).')
|
||||||
|
header = struct.pack('!B', control_code | 1 << 7)
|
||||||
|
else:
|
||||||
|
opcode = 0 if continuation else (1 if is_text else 2)
|
||||||
|
header = struct.pack('!B', opcode | (1 << 7 if final else 0))
|
||||||
|
lengthdata = 1 << 7 if masked else 0
|
||||||
|
if length > 65535:
|
||||||
|
lengthdata = struct.pack('!BQ', lengthdata | 127, length)
|
||||||
|
elif length > 125:
|
||||||
|
lengthdata = struct.pack('!BH', lengthdata | 126, length)
|
||||||
|
else:
|
||||||
|
lengthdata = struct.pack('!B', lengthdata | length)
|
||||||
|
if masked:
|
||||||
|
# NOTE: RFC6455 states:
|
||||||
|
# A server MUST NOT mask any frames that it sends to the client
|
||||||
|
rand = Random(time.time())
|
||||||
|
mask = map(rand.getrandbits, (8, ) * 4)
|
||||||
|
message = RFC6455WebSocket._apply_mask(message, mask, length)
|
||||||
|
maskdata = struct.pack('!BBBB', *mask)
|
||||||
|
else:
|
||||||
|
maskdata = ''
|
||||||
|
return ''.join((header, lengthdata, maskdata, message))
|
||||||
|
|
||||||
|
def wait(self):
|
||||||
|
for i in self.iterator:
|
||||||
|
return i
|
||||||
|
|
||||||
|
def _send(self, frame):
|
||||||
|
self._sendlock.acquire()
|
||||||
|
try:
|
||||||
|
self.socket.sendall(frame)
|
||||||
|
finally:
|
||||||
|
self._sendlock.release()
|
||||||
|
|
||||||
|
def send(self, message, **kw):
|
||||||
|
kw['masked'] = self.client
|
||||||
|
payload = self._pack_message(message, **kw)
|
||||||
|
self._send(payload)
|
||||||
|
|
||||||
|
def _send_closing_frame(self, ignore_send_errors=False, close_data=None):
|
||||||
|
if self.version in (8, 13) and not self.websocket_closed:
|
||||||
|
if close_data is not None:
|
||||||
|
status, msg = close_data
|
||||||
|
if isinstance(msg, unicode):
|
||||||
|
msg = msg.encode('utf-8')
|
||||||
|
data = struct.pack('!H', status) + msg
|
||||||
|
else:
|
||||||
|
data = ''
|
||||||
|
try:
|
||||||
|
self.send(data, control_code=8)
|
||||||
|
except SocketError:
|
||||||
|
# Sometimes, like when the remote side cuts off the connection,
|
||||||
|
# we don't care about this.
|
||||||
|
if not ignore_send_errors: # pragma NO COVER
|
||||||
|
raise
|
||||||
|
self.websocket_closed = True
|
||||||
|
|
||||||
|
def close(self, close_data=None):
|
||||||
|
"""Forcibly close the websocket; generally it is preferable to
|
||||||
|
return from the handler method."""
|
||||||
|
self._send_closing_frame(close_data=close_data)
|
||||||
|
self.socket.shutdown(socket.SHUT_WR)
|
||||||
|
self.socket.close()
|
||||||
|
207
tests/websocket_new_test.py
Normal file
207
tests/websocket_new_test.py
Normal file
@@ -0,0 +1,207 @@
|
|||||||
|
import errno
|
||||||
|
import struct
|
||||||
|
|
||||||
|
import eventlet
|
||||||
|
from eventlet import event
|
||||||
|
from eventlet.green import httplib
|
||||||
|
from eventlet.green import socket
|
||||||
|
from eventlet import websocket
|
||||||
|
|
||||||
|
from tests.wsgi_test import _TestBase
|
||||||
|
|
||||||
|
|
||||||
|
# demo app
|
||||||
|
def handle(ws):
|
||||||
|
if ws.path == '/echo':
|
||||||
|
while True:
|
||||||
|
m = ws.wait()
|
||||||
|
if m is None:
|
||||||
|
break
|
||||||
|
ws.send(m)
|
||||||
|
elif ws.path == '/range':
|
||||||
|
for i in xrange(10):
|
||||||
|
ws.send("msg %d" % i)
|
||||||
|
eventlet.sleep(0.01)
|
||||||
|
elif ws.path == '/error':
|
||||||
|
# some random socket error that we shouldn't normally get
|
||||||
|
raise socket.error(errno.ENOTSOCK)
|
||||||
|
else:
|
||||||
|
ws.close()
|
||||||
|
|
||||||
|
wsapp = websocket.WebSocketWSGI(handle)
|
||||||
|
|
||||||
|
|
||||||
|
class TestWebSocket(_TestBase):
|
||||||
|
TEST_TIMEOUT = 5
|
||||||
|
|
||||||
|
def set_site(self):
|
||||||
|
self.site = wsapp
|
||||||
|
|
||||||
|
def test_incomplete_headers_13(self):
|
||||||
|
headers = dict(kv.split(': ') for kv in [
|
||||||
|
"Upgrade: websocket",
|
||||||
|
# NOTE: intentionally no connection header
|
||||||
|
"Host: localhost:%s" % self.port,
|
||||||
|
"Origin: http://localhost:%s" % self.port,
|
||||||
|
"Sec-WebSocket-Version: 13", ])
|
||||||
|
http = httplib.HTTPConnection('localhost', self.port)
|
||||||
|
http.request("GET", "/echo", headers=headers)
|
||||||
|
resp = http.getresponse()
|
||||||
|
|
||||||
|
self.assertEqual(resp.status, 400)
|
||||||
|
self.assertEqual(resp.getheader('connection'), 'close')
|
||||||
|
self.assertEqual(resp.read(), '')
|
||||||
|
|
||||||
|
# Now, miss off key
|
||||||
|
headers = dict(kv.split(': ') for kv in [
|
||||||
|
"Upgrade: websocket",
|
||||||
|
"Connection: Upgrade",
|
||||||
|
"Host: localhost:%s" % self.port,
|
||||||
|
"Origin: http://localhost:%s" % self.port,
|
||||||
|
"Sec-WebSocket-Version: 13", ])
|
||||||
|
http = httplib.HTTPConnection('localhost', self.port)
|
||||||
|
http.request("GET", "/echo", headers=headers)
|
||||||
|
resp = http.getresponse()
|
||||||
|
|
||||||
|
self.assertEqual(resp.status, 400)
|
||||||
|
self.assertEqual(resp.getheader('connection'), 'close')
|
||||||
|
self.assertEqual(resp.read(), '')
|
||||||
|
|
||||||
|
def test_correct_upgrade_request_13(self):
|
||||||
|
connect = [
|
||||||
|
"GET /echo HTTP/1.1",
|
||||||
|
"Upgrade: websocket",
|
||||||
|
"Connection: Upgrade",
|
||||||
|
"Host: localhost:%s" % self.port,
|
||||||
|
"Origin: http://localhost:%s" % self.port,
|
||||||
|
"Sec-WebSocket-Version: 13",
|
||||||
|
"Sec-WebSocket-Key: d9MXuOzlVQ0h+qRllvSCIg==", ]
|
||||||
|
sock = eventlet.connect(
|
||||||
|
('localhost', self.port))
|
||||||
|
|
||||||
|
sock.sendall('\r\n'.join(connect) + '\r\n\r\n')
|
||||||
|
result = sock.recv(1024)
|
||||||
|
## The server responds the correct Websocket handshake
|
||||||
|
self.assertEqual(result,
|
||||||
|
'\r\n'.join(['HTTP/1.1 101 Switching Protocols',
|
||||||
|
'Upgrade: websocket',
|
||||||
|
'Connection: Upgrade',
|
||||||
|
'Sec-WebSocket-Accept: ywSyWXCPNsDxLrQdQrn5RFNRfBU=\r\n\r\n', ]))
|
||||||
|
|
||||||
|
def test_send_recv_13(self):
|
||||||
|
connect = [
|
||||||
|
"GET /echo HTTP/1.1",
|
||||||
|
"Upgrade: websocket",
|
||||||
|
"Connection: Upgrade",
|
||||||
|
"Host: localhost:%s" % self.port,
|
||||||
|
"Origin: http://localhost:%s" % self.port,
|
||||||
|
"Sec-WebSocket-Version: 13",
|
||||||
|
"Sec-WebSocket-Key: d9MXuOzlVQ0h+qRllvSCIg==", ]
|
||||||
|
sock = eventlet.connect(
|
||||||
|
('localhost', self.port))
|
||||||
|
|
||||||
|
sock.sendall('\r\n'.join(connect) + '\r\n\r\n')
|
||||||
|
first_resp = sock.recv(1024)
|
||||||
|
ws = websocket.RFC6455WebSocket(sock, {}, client=True)
|
||||||
|
ws.send('hello')
|
||||||
|
assert ws.wait() == 'hello'
|
||||||
|
ws.send('hello world!\x01')
|
||||||
|
ws.send(u'hello world again!')
|
||||||
|
assert ws.wait() == 'hello world!\x01'
|
||||||
|
assert ws.wait() == u'hello world again!'
|
||||||
|
ws.close()
|
||||||
|
eventlet.sleep(0.01)
|
||||||
|
|
||||||
|
def test_breaking_the_connection_13(self):
|
||||||
|
error_detected = [False]
|
||||||
|
done_with_request = event.Event()
|
||||||
|
site = self.site
|
||||||
|
def error_detector(environ, start_response):
|
||||||
|
try:
|
||||||
|
try:
|
||||||
|
return site(environ, start_response)
|
||||||
|
except:
|
||||||
|
error_detected[0] = True
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
done_with_request.send(True)
|
||||||
|
self.site = error_detector
|
||||||
|
self.spawn_server()
|
||||||
|
connect = [
|
||||||
|
"GET /echo HTTP/1.1",
|
||||||
|
"Upgrade: websocket",
|
||||||
|
"Connection: Upgrade",
|
||||||
|
"Host: localhost:%s" % self.port,
|
||||||
|
"Origin: http://localhost:%s" % self.port,
|
||||||
|
"Sec-WebSocket-Version: 13",
|
||||||
|
"Sec-WebSocket-Key: d9MXuOzlVQ0h+qRllvSCIg==", ]
|
||||||
|
sock = eventlet.connect(
|
||||||
|
('localhost', self.port))
|
||||||
|
sock.sendall('\r\n'.join(connect) + '\r\n\r\n')
|
||||||
|
resp = sock.recv(1024) # get the headers
|
||||||
|
sock.close() # close while the app is running
|
||||||
|
done_with_request.wait()
|
||||||
|
self.assert_(not error_detected[0])
|
||||||
|
|
||||||
|
def test_client_closing_connection_13(self):
|
||||||
|
error_detected = [False]
|
||||||
|
done_with_request = event.Event()
|
||||||
|
site = self.site
|
||||||
|
def error_detector(environ, start_response):
|
||||||
|
try:
|
||||||
|
try:
|
||||||
|
return site(environ, start_response)
|
||||||
|
except:
|
||||||
|
error_detected[0] = True
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
done_with_request.send(True)
|
||||||
|
self.site = error_detector
|
||||||
|
self.spawn_server()
|
||||||
|
connect = [
|
||||||
|
"GET /echo HTTP/1.1",
|
||||||
|
"Upgrade: websocket",
|
||||||
|
"Connection: Upgrade",
|
||||||
|
"Host: localhost:%s" % self.port,
|
||||||
|
"Origin: http://localhost:%s" % self.port,
|
||||||
|
"Sec-WebSocket-Version: 13",
|
||||||
|
"Sec-WebSocket-Key: d9MXuOzlVQ0h+qRllvSCIg==", ]
|
||||||
|
sock = eventlet.connect(
|
||||||
|
('localhost', self.port))
|
||||||
|
sock.sendall('\r\n'.join(connect) + '\r\n\r\n')
|
||||||
|
resp = sock.recv(1024) # get the headers
|
||||||
|
closeframe = struct.pack('!BBIH', 1 << 7 | 8, 1 << 7 | 2, 0, 1000)
|
||||||
|
sock.sendall(closeframe) # "Close the connection" packet.
|
||||||
|
done_with_request.wait()
|
||||||
|
self.assert_(not error_detected[0])
|
||||||
|
|
||||||
|
def test_client_invalid_packet_13(self):
|
||||||
|
error_detected = [False]
|
||||||
|
done_with_request = event.Event()
|
||||||
|
site = self.site
|
||||||
|
def error_detector(environ, start_response):
|
||||||
|
try:
|
||||||
|
try:
|
||||||
|
return site(environ, start_response)
|
||||||
|
except:
|
||||||
|
error_detected[0] = True
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
done_with_request.send(True)
|
||||||
|
self.site = error_detector
|
||||||
|
self.spawn_server()
|
||||||
|
connect = [
|
||||||
|
"GET /echo HTTP/1.1",
|
||||||
|
"Upgrade: websocket",
|
||||||
|
"Connection: Upgrade",
|
||||||
|
"Host: localhost:%s" % self.port,
|
||||||
|
"Origin: http://localhost:%s" % self.port,
|
||||||
|
"Sec-WebSocket-Version: 13",
|
||||||
|
"Sec-WebSocket-Key: d9MXuOzlVQ0h+qRllvSCIg==", ]
|
||||||
|
sock = eventlet.connect(
|
||||||
|
('localhost', self.port))
|
||||||
|
sock.sendall('\r\n'.join(connect) + '\r\n\r\n')
|
||||||
|
resp = sock.recv(1024) # get the headers
|
||||||
|
sock.sendall('\x07\xff') # Weird packet.
|
||||||
|
done_with_request.wait()
|
||||||
|
self.assert_(not error_detected[0])
|
Reference in New Issue
Block a user