Refactor and add IETF-07 protocol version support.

- Add initial IETF-07 (HyBi-07) protocol version support. This version
  still uses base64 encoding since the API for binary support is not
  yet finalized.

- Move socket send and recieve functions into the WebSocketServer
  class instead of having the sub-class do this. This simplifies
  sub-classes somewhat. The send_frame routine now returns the number
  of frames that were unable to be sent. If this value is non-zero
  then the sub-class should call again when the socket is ready until
  the pending frames count is 0.

- Do traffic reporting in the main class instead.

- When the client is HyBi style (i.e. IETF-07) then use the
  sub-protocol header to select whether to do base64 encoding or
  simply send the frame data raw (binary). Update include/websock.js
  to send a 'base64' protocol selector. Once the API support binary,
  then the client will need to detect this and set the protocol to
  'binary'.
This commit is contained in:
Joel Martin 2011-05-01 22:17:04 -05:00
parent f3bb09a7f2
commit b681bd8921
11 changed files with 498 additions and 467 deletions

View File

@ -252,23 +252,23 @@ function init() {
function open(uri) { function open(uri) {
init(); init();
websocket = new WebSocket(uri, 'websockify'); websocket = new WebSocket(uri, 'base64');
websocket.onmessage = recv_message; websocket.onmessage = recv_message;
websocket.onopen = function(e) { websocket.onopen = function() {
Util.Debug(">> WebSock.onopen"); Util.Debug(">> WebSock.onopen");
eventHandlers.open(); eventHandlers.open();
Util.Debug("<< WebSock.onopen"); Util.Debug("<< WebSock.onopen");
}; };
websocket.onclose = function(e) { websocket.onclose = function(e) {
Util.Debug(">> WebSock.onclose"); Util.Debug(">> WebSock.onclose");
eventHandlers.close(); eventHandlers.close(e);
Util.Debug("<< WebSock.onclose"); Util.Debug("<< WebSock.onclose");
}; };
websocket.onerror = function(e) { websocket.onerror = function(e) {
Util.Debug("<< WebSock.onerror: " + e); Util.Debug(">> WebSock.onerror: " + e);
eventHandlers.error(e); eventHandlers.error(e);
Util.Debug("<< WebSock.onerror: "); Util.Debug("<< WebSock.onerror");
}; };
} }

View File

@ -92,7 +92,7 @@
} }
uri = scheme + host + ":" + port; uri = scheme + host + ":" + port;
message("connecting to " + uri); message("connecting to " + uri);
ws = new WebSocket(uri); ws = new WebSocket(uri, "base64");
ws.onmessage = function(e) { ws.onmessage = function(e) {
//console.log(">> WebSockets.onmessage"); //console.log(">> WebSockets.onmessage");

View File

@ -8,84 +8,68 @@ Licensed under LGPL version 3 (see docs/LICENSE.LGPL-3)
You can make a cert/key with openssl using: You can make a cert/key with openssl using:
openssl req -new -x509 -days 365 -nodes -out self.pem -keyout self.pem openssl req -new -x509 -days 365 -nodes -out self.pem -keyout self.pem
as taken from http://docs.python.org/dev/library/ssl.html#certificates as taken from http://docs.python.org/dev/library/ssl.html#certificates
''' '''
import os, sys, socket, select import os, sys, select, optparse
sys.path.insert(0,os.path.dirname(__file__) + "/../") sys.path.insert(0,os.path.dirname(__file__) + "/../")
from websocket import WebSocketServer from websocket import WebSocketServer
class WebSocketEcho(WebSocketServer): class WebSocketEcho(WebSocketServer):
""" """
WebSockets server that echo back whatever is received from the WebSockets server that echos back whatever is received from the
client. All traffic to/from the client is base64 client. """
encoded/decoded.
"""
buffer_size = 8096 buffer_size = 8096
def new_client(self, client): def new_client(self):
""" """
Echo back whatever is received. Echo back whatever is received.
""" """
cqueue = [] cqueue = []
c_pend = 0
cpartial = "" cpartial = ""
rlist = [client] rlist = [self.client]
while True: while True:
wlist = [] wlist = []
if cqueue: wlist.append(client) if cqueue or c_pend: wlist.append(self.client)
ins, outs, excepts = select.select(rlist, wlist, [], 1) ins, outs, excepts = select.select(rlist, wlist, [], 1)
if excepts: raise Exception("Socket exception") if excepts: raise Exception("Socket exception")
if client in outs: if self.client in outs:
# Send queued target data to the client # Send queued target data to the client
dat = cqueue.pop(0) c_pend = self.send_frames(cqueue)
sent = client.send(dat) cqueue = []
self.vmsg("Sent %s/%s bytes of frame: '%s'" % (
sent, len(dat), self.decode(dat)[0]))
if sent != len(dat):
# requeue the remaining data
cqueue.insert(0, dat[sent:])
if self.client in ins:
if client in ins:
# Receive client data, decode it, and send it back # Receive client data, decode it, and send it back
buf = client.recv(self.buffer_size) frames, closed = self.recv_frames()
if len(buf) == 0: raise self.EClose("Client closed") cqueue.extend(frames)
if buf == '\xff\x00': if closed:
raise self.EClose("Client sent orderly close frame") self.send_close()
elif buf[-1] == '\xff': raise self.EClose(closed)
if cpartial:
# Prepend saved partial and decode frame(s)
frames = self.decode(cpartial + buf)
cpartial = ""
else:
# decode frame(s)
frames = self.decode(buf)
for frame in frames:
self.vmsg("Received frame: %s" % repr(frame))
cqueue.append(self.encode(frame))
else:
# Save off partial WebSockets frame
self.vmsg("Received partial frame")
cpartial = cpartial + buf
if __name__ == '__main__': if __name__ == '__main__':
try: parser = optparse.OptionParser(usage="%prog [options] listen_port")
if len(sys.argv) < 1: raise parser.add_option("--verbose", "-v", action="store_true",
listen_port = int(sys.argv[1]) help="verbose messages and per frame traffic")
except: parser.add_option("--cert", default="self.pem",
print "Usage: %s <listen_port>" % sys.argv[0] help="SSL certificate file")
sys.exit(1) parser.add_option("--key", default=None,
help="SSL key file (if separate from cert)")
parser.add_option("--ssl-only", action="store_true",
help="disallow non-encrypted connections")
(opts, args) = parser.parse_args()
server = WebSocketEcho( try:
listen_port=listen_port, if len(args) != 1: raise
#verbose=True, opts.listen_port = int(args[0])
cert='self.pem', except:
web='.') parser.error("Invalid arguments")
opts.web = "."
server = WebSocketEcho(**opts.__dict__)
server.start_server() server.start_server()

View File

@ -1,151 +0,0 @@
<html>
<head><title>WebSockets Test</title></head>
<body>
Host: <input id='host' style='width:100'>&nbsp;
Port: <input id='port' style='width:50'>&nbsp;
Encrypt: <input id='encrypt' type='checkbox'>&nbsp;
<input id='connectButton' type='button' value='Start' style='width:100px'
onclick="connect();">&nbsp;
<br>
Messages:<br>
<textarea id="messages" style="font-size: 9;" cols=80 rows=25></textarea>
</body>
<!-- Uncomment to activate firebug lite -->
<!--
<script type='text/javascript'
src='http://getfirebug.com/releases/lite/1.2/firebug-lite-compressed.js'></script>
-->
<script src="include/base64.js"></script>
<script src="include/util.js"></script>
<script src="include/webutil.js"></script>
<script>
var host = null, port = null;
var ws = null;
var VNC_native_ws = true;
function message(str) {
console.log(str);
cell = $D('messages');
cell.innerHTML += str + "\n";
cell.scrollTop = cell.scrollHeight;
}
function print_response(str) {
message("str.length: " + str.length);
for (i=0; i < str.length; i++) {
message(i + ": " + (str.charCodeAt(i) % 256));
}
}
function send() {
var str = "";
str = str + String.fromCharCode(0x81);
str = str + String.fromCharCode(0xff);
for (var i=0; i<256; i+=4) {
str = str + String.fromCharCode(i);
}
str = str + String.fromCharCode(0);
str = str + String.fromCharCode(0x40);
str = str + String.fromCharCode(0x41);
str = str + String.fromCharCode(0xff);
str = str + String.fromCharCode(0x81);
ws.send(str);
}
function init_ws() {
console.log(">> init_ws");
var scheme = "ws://";
if ($D('encrypt').checked) {
scheme = "wss://";
}
var uri = scheme + host + ":" + port;
console.log("connecting to " + uri);
ws = new WebSocket(uri);
ws.onmessage = function(e) {
console.log(">> WebSockets.onmessage");
print_response(e.data);
console.log("<< WebSockets.onmessage");
};
ws.onopen = function(e) {
console.log(">> WebSockets.onopen");
send();
console.log("<< WebSockets.onopen");
};
ws.onclose = function(e) {
console.log(">> WebSockets.onclose");
console.log("<< WebSockets.onclose");
};
ws.onerror = function(e) {
console.log(">> WebSockets.onerror");
console.log(" " + e);
console.log("<< WebSockets.onerror");
};
console.log("<< init_ws");
}
function connect() {
console.log(">> connect");
host = $D('host').value;
port = $D('port').value;
if ((!host) || (!port)) {
console.log("must set host and port");
return;
}
if (ws) {
ws.close();
}
init_ws();
$D('connectButton').value = "Stop";
$D('connectButton').onclick = disconnect;
console.log("<< connect");
}
function disconnect() {
console.log(">> disconnect");
if (ws) {
ws.close();
}
$D('connectButton').value = "Start";
$D('connectButton').onclick = connect;
console.log("<< disconnect");
}
/* If no builtin websockets then load web_socket.js */
if (! window.WebSocket) {
console.log("Loading web-socket-js flash bridge");
var extra = "<script src='include/web-socket-js/swfobject.js'><\/script>";
extra += "<script src='include/web-socket-js/FABridge.js'><\/script>";
extra += "<script src='include/web-socket-js/web_socket.js'><\/script>";
document.write(extra);
VNC_native_ws = false;
}
window.onload = function() {
console.log("onload");
if (! VNC_native_ws) {
WebSocket.__swfLocation = "include/web-socket-js/WebSocketMain.swf";
WebSocket.__initialize();
}
var url = document.location.href;
$D('host').value = (url.match(/host=([^&#]*)/) || ['',''])[1];
$D('port').value = (url.match(/port=([^&#]*)/) || ['',''])[1];
}
</script>
</html>

View File

@ -1,87 +0,0 @@
#!/usr/bin/env python
'''
WebSocket server-side load test program. Sends and receives traffic
that has a random payload (length and content) that is checksummed and
given a sequence number. Any errors are reported and counted.
'''
import sys, os, socket, ssl, time, traceback
import random, time
from base64 import b64encode, b64decode
from codecs import utf_8_encode, utf_8_decode
from select import select
sys.path.insert(0,os.path.dirname(__file__) + "/../")
from websocket import *
buffer_size = 65536
recv_cnt = send_cnt = 0
def check(buf):
if buf[0] != '\x00' or buf[-1] != '\xff':
raise Exception("Invalid WS packet")
for decoded in decode(buf):
nums = [ord(c) for c in decoded]
print "Received nums: ", nums
return
def responder(client):
cpartial = ""
socks = [client]
sent = False
received = False
while True:
ins, outs, excepts = select(socks, socks, socks, 1)
if excepts: raise Exception("Socket exception")
if client in ins:
buf = client.recv(buffer_size)
if len(buf) == 0: raise Exception("Client closed")
received = True
#print "Client recv: %s (%d)" % (repr(buf[1:-1]), len(buf))
if buf[-1] == '\xff':
if cpartial:
err = check(cpartial + buf)
cpartial = ""
else:
err = check(buf)
if err:
print err
else:
print "received partitial"
cpartial = cpartial + buf
if received and not sent and client in outs:
sent = True
#nums = "".join([unichr(c) for c in range(0,256)])
#nums = "".join([chr(c) for c in range(1,128)])
#nums = nums + chr(194) + chr(128) + chr(194) + chr(129)
#nums = "".join([chr(c) for c in range(0,256)])
nums = "\x81\xff"
nums = nums + "".join([chr(c) for c in range(0,256,4)])
nums = nums + "\x00\x40\x41\xff\x81"
# print nums
client.send(encode(nums))
# client.send("\x00" + nums + "\xff")
# print "Sent characters 0-255"
# #print "Client send: %s (%d)" % (repr(nums), len(nums))
if __name__ == '__main__':
try:
if len(sys.argv) < 2: raise
listen_port = int(sys.argv[1])
except:
print "Usage: <listen_port>"
sys.exit(1)
settings['listen_port'] = listen_port
settings['daemon'] = False
settings['handler'] = responder
start_server()

View File

@ -166,7 +166,7 @@
} }
timestamp = (new Date()).getTime(); timestamp = (new Date()).getTime();
arr.pushStr("^" + send_seq + ":" + timestamp + ":" + payload + "$") arr.pushStr("^" + send_seq + ":" + timestamp + ":" + payload + "$");
send_seq ++; send_seq ++;
ws.send(arr); ws.send(arr);
sent++; sent++;
@ -196,10 +196,10 @@
ws.maxBufferedAmount = 5000; ws.maxBufferedAmount = 5000;
ws.open(uri); ws.open(uri);
ws.on('message', function(e) { ws.on('message', function() {
recvMsg(); recvMsg();
}); });
ws.on('open', function(e) { ws.on('open', function() {
send_ref = setTimeout(sendMsg, sendDelay); send_ref = setTimeout(sendMsg, sendDelay);
}); });
ws.on('close', function(e) { ws.on('close', function(e) {

View File

@ -2,9 +2,10 @@
<head> <head>
<title>WebSockets Load Test</title> <title>WebSockets Load Test</title>
<script src="include/base64.js"></script>
<script src="include/util.js"></script> <script src="include/util.js"></script>
<script src="include/webutil.js"></script> <script src="include/webutil.js"></script>
<script src="include/base64.js"></script>
<script src="include/websock.js"></script>
<!-- Uncomment to activate firebug lite --> <!-- Uncomment to activate firebug lite -->
<!-- <!--
<script type='text/javascript' <script type='text/javascript'
@ -73,10 +74,9 @@
function check_respond(data) { function check_respond(data) {
//console.log(">> check_respond"); //console.log(">> check_respond");
var decoded, first, last, str, length, chksum, nums, arr; var first, last, str, length, chksum, nums, arr;
decoded = Base64.decode(data); first = String.fromCharCode(data.shift());
first = String.fromCharCode(decoded.shift()); last = String.fromCharCode(data.pop());
last = String.fromCharCode(decoded.pop());
if (first != "^") { if (first != "^") {
errors++; errors++;
@ -88,7 +88,7 @@
error("Packet missing end char '$'"); error("Packet missing end char '$'");
return; return;
} }
arr = decoded.map(function(num) { arr = data.map(function(num) {
return String.fromCharCode(num); return String.fromCharCode(num);
} ).join('').split(':'); } ).join('').split(':');
seq = arr[0]; seq = arr[0];
@ -125,10 +125,6 @@
} }
function send() { function send() {
if (ws.bufferedAmount > 0) {
console.log("Delaying send");
return;
}
var length = Math.floor(Math.random()*(max_send-9)) + 10; // 10 - max_send var length = Math.floor(Math.random()*(max_send-9)) + 10; // 10 - max_send
var numlist = [], arr = []; var numlist = [], arr = [];
for (var i=0; i < length; i++) { for (var i=0; i < length; i++) {
@ -142,7 +138,7 @@
var nums = numlist.join(''); var nums = numlist.join('');
arr.pushStr("^" + send_seq + ":" + length + ":" + chksum + ":" + nums + "$") arr.pushStr("^" + send_seq + ":" + length + ":" + chksum + ":" + nums + "$")
send_seq ++; send_seq ++;
ws.send(Base64.encode(arr)); ws.send(arr);
sent++; sent++;
} }
@ -160,28 +156,30 @@
} }
var uri = scheme + host + ":" + port; var uri = scheme + host + ":" + port;
console.log("connecting to " + uri); console.log("connecting to " + uri);
ws = new WebSocket(uri); ws = new Websock();
ws.open(uri);
ws.onmessage = function(e) { ws.on('message', function() {
//console.log(">> WebSockets.onmessage"); //console.log(">> WebSockets.onmessage");
check_respond(e.data); arr = ws.rQshiftBytes(ws.rQlen());
check_respond(arr);
//console.log("<< WebSockets.onmessage"); //console.log("<< WebSockets.onmessage");
}; });
ws.onopen = function(e) { ws.on('open', function() {
console.log(">> WebSockets.onopen"); console.log(">> WebSockets.onopen");
send_ref = setInterval(send, sendDelay); send_ref = setInterval(send, sendDelay);
console.log("<< WebSockets.onopen"); console.log("<< WebSockets.onopen");
}; });
ws.onclose = function(e) { ws.on('close', function(e) {
console.log(">> WebSockets.onclose"); console.log(">> WebSockets.onclose");
clearInterval(send_ref); clearInterval(send_ref);
console.log("<< WebSockets.onclose"); console.log("<< WebSockets.onclose");
}; });
ws.onerror = function(e) { ws.on('error', function(e) {
console.log(">> WebSockets.onerror"); console.log(">> WebSockets.onerror");
console.log(" " + e); console.log(" " + e);
console.log("<< WebSockets.onerror"); console.log("<< WebSockets.onerror");
}; });
console.log("<< init_ws"); console.log("<< init_ws");
} }

View File

@ -6,17 +6,14 @@ that has a random payload (length and content) that is checksummed and
given a sequence number. Any errors are reported and counted. given a sequence number. Any errors are reported and counted.
''' '''
import sys, os, socket, ssl, time, traceback import sys, os, select, random, time, optparse
import random, time
from select import select
sys.path.insert(0,os.path.dirname(__file__) + "/../") sys.path.insert(0,os.path.dirname(__file__) + "/../")
from websocket import WebSocketServer from websocket import WebSocketServer
class WebSocketLoad(WebSocketServer):
class WebSocketTest(WebSocketServer):
buffer_size = 65536 buffer_size = 65536
max_packet_size = 10000 max_packet_size = 10000
recv_cnt = 0 recv_cnt = 0
send_cnt = 0 send_cnt = 0
@ -32,54 +29,48 @@ class WebSocketTest(WebSocketServer):
WebSocketServer.__init__(self, *args, **kwargs) WebSocketServer.__init__(self, *args, **kwargs)
def new_client(self, client): def new_client(self):
self.send_cnt = 0 self.send_cnt = 0
self.recv_cnt = 0 self.recv_cnt = 0
try: try:
self.responder(client) self.responder(self.client)
except: except:
print "accumulated errors:", self.errors print "accumulated errors:", self.errors
self.errors = 0 self.errors = 0
raise raise
def responder(self, client): def responder(self, client):
c_pend = 0
cqueue = [] cqueue = []
cpartial = "" cpartial = ""
socks = [client] socks = [client]
last_send = time.time() * 1000 last_send = time.time() * 1000
while True: while True:
ins, outs, excepts = select(socks, socks, socks, 1) ins, outs, excepts = select.select(socks, socks, socks, 1)
if excepts: raise Exception("Socket exception") if excepts: raise Exception("Socket exception")
if client in ins: if client in ins:
buf = client.recv(self.buffer_size) frames, closed = self.recv_frames()
if len(buf) == 0:
raise self.EClose("Client closed") err = self.check(frames)
#print "Client recv: %s (%d)" % (repr(buf[1:-1]), len(buf)) if err:
if buf[-1] == '\xff': self.errors = self.errors + 1
if cpartial: print err
err = self.check(cpartial + buf)
cpartial = "" if closed:
else: self.send_close()
err = self.check(buf) raise self.EClose(closed)
if err:
self.traffic("}")
self.errors = self.errors + 1
print err
else:
self.traffic(">")
else:
self.traffic(".>")
cpartial = cpartial + buf
now = time.time() * 1000 now = time.time() * 1000
if client in outs and now > (last_send + self.delay): if client in outs:
last_send = now if c_pend:
#print "Client send: %s" % repr(cqueue[0]) last_send = now
client.send(self.generate()) c_pend = self.send_frames()
self.traffic("<") elif now > (last_send + self.delay):
last_send = now
c_pend = self.send_frames([self.generate()])
def generate(self): def generate(self):
length = random.randint(10, self.max_packet_size) length = random.randint(10, self.max_packet_size)
@ -93,18 +84,13 @@ class WebSocketTest(WebSocketServer):
data = "^%d:%d:%d:%s$" % (self.send_cnt, length, chksum, nums) data = "^%d:%d:%d:%s$" % (self.send_cnt, length, chksum, nums)
self.send_cnt += 1 self.send_cnt += 1
return WebSocketServer.encode(data) return data
def check(self, buf): def check(self, frames):
try:
data_list = WebSocketServer.decode(buf)
except:
print "\n<BOF>" + repr(buf) + "<EOF>"
return "Failed to decode"
err = "" err = ""
for data in data_list: for data in frames:
if data.count('$') > 1: if data.count('$') > 1:
raise Exception("Multiple parts within single packet") raise Exception("Multiple parts within single packet")
if len(data) == 0: if len(data) == 0:
@ -151,21 +137,31 @@ class WebSocketTest(WebSocketServer):
if __name__ == '__main__': if __name__ == '__main__':
try: parser = optparse.OptionParser(usage="%prog [options] listen_port")
if len(sys.argv) < 2: raise parser.add_option("--verbose", "-v", action="store_true",
listen_port = int(sys.argv[1]) help="verbose messages and per frame traffic")
if len(sys.argv) == 3: parser.add_option("--cert", default="self.pem",
delay = int(sys.argv[2]) help="SSL certificate file")
else: parser.add_option("--key", default=None,
delay = 10 help="SSL key file (if separate from cert)")
except: parser.add_option("--ssl-only", action="store_true",
print "Usage: %s <listen_port> [delay_ms]" % sys.argv[0] help="disallow non-encrypted connections")
sys.exit(1) (opts, args) = parser.parse_args()
server = WebSocketTest( try:
listen_port=listen_port, if len(args) != 1: raise
verbose=True, opts.listen_port = int(args[0])
cert='self.pem',
web='.', if len(args) not in [1,2]: raise
delay=delay) opts.listen_port = int(args[0])
if len(args) == 2:
opts.delay = int(args[1])
else:
opts.delay = 10
except:
parser.error("Invalid arguments")
opts.web = "."
server = WebSocketLoad(**opts.__dict__)
server.start_server() server.start_server()

View File

@ -6,14 +6,17 @@ Display UTF-8 encoding for 0-255.'''
import sys, os, socket, ssl, time, traceback import sys, os, socket, ssl, time, traceback
from select import select from select import select
sys.path.insert(0,os.path.dirname(__file__) + "/../utils/") sys.path.insert(0,os.path.dirname(__file__) + "/../")
from websocket import WebSocketServer from websocket import WebSocketServer
if __name__ == '__main__': if __name__ == '__main__':
print "val: hixie | hybi_base64 | hybi_binary"
for c in range(0, 256): for c in range(0, 256):
print "%d: %s" % (c, repr(WebSocketServer.encode(chr(c))[1:-1])) hixie = WebSocketServer.encode_hixie(chr(c))
#nums = "".join([chr(c) for c in range(0,256)]) hybi_base64 = WebSocketServer.encode_hybi(chr(c), opcode=1,
#for char in WebSocketServer.encode(nums): base64=True)
# print "%d" % ord(char), hybi_binary = WebSocketServer.encode_hybi(chr(c), opcode=2,
#print repr(WebSocketServer.encode(nums)) base64=False)
print "%d: %s | %s | %s" % (c, repr(hixie), repr(hybi_base64),
repr(hybi_binary))

View File

@ -5,6 +5,11 @@ Python WebSocket library with support for "wss://" encryption.
Copyright 2010 Joel Martin Copyright 2010 Joel Martin
Licensed under LGPL version 3 (see docs/LICENSE.LGPL-3) Licensed under LGPL version 3 (see docs/LICENSE.LGPL-3)
Supports following protocol versions:
- http://tools.ietf.org/html/draft-hixie-thewebsocketprotocol-75
- http://tools.ietf.org/html/draft-hixie-thewebsocketprotocol-76
- http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-07
You can make a cert/key with openssl using: You can make a cert/key with openssl using:
openssl req -new -x509 -days 365 -nodes -out self.pem -keyout self.pem openssl req -new -x509 -days 365 -nodes -out self.pem -keyout self.pem
as taken from http://docs.python.org/dev/library/ssl.html#certificates as taken from http://docs.python.org/dev/library/ssl.html#certificates
@ -17,9 +22,15 @@ from SimpleHTTPServer import SimpleHTTPRequestHandler
from cStringIO import StringIO from cStringIO import StringIO
from base64 import b64encode, b64decode from base64 import b64encode, b64decode
try: try:
from hashlib import md5 from hashlib import md5, sha1
except: except:
from md5 import md5 # Support python 2.4 # Support python 2.4
from md5 import md5
from sha import sha as sha1
try:
import numpy, ctypes
except:
numpy = ctypes = None
from urlparse import urlsplit from urlparse import urlsplit
from cgi import parse_qsl from cgi import parse_qsl
@ -29,14 +40,22 @@ class WebSocketServer(object):
Must be sub-classed with new_client method definition. Must be sub-classed with new_client method definition.
""" """
server_handshake = """HTTP/1.1 101 Web Socket Protocol Handshake\r buffer_size = 65536
server_handshake_hixie = """HTTP/1.1 101 Web Socket Protocol Handshake\r
Upgrade: WebSocket\r Upgrade: WebSocket\r
Connection: Upgrade\r Connection: Upgrade\r
%sWebSocket-Origin: %s\r %sWebSocket-Origin: %s\r
%sWebSocket-Location: %s://%s%s\r %sWebSocket-Location: %s://%s%s\r
%sWebSocket-Protocol: sample\r """
\r
%s""" server_handshake_hybi = """HTTP/1.1 101 Switching Protocols\r
Upgrade: websocket\r
Connection: Upgrade\r
Sec-WebSocket-Accept: %s\r
"""
GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
policy_response = """<cross-domain-policy><allow-access-from domain="*" to-ports="*" /></cross-domain-policy>\n""" policy_response = """<cross-domain-policy><allow-access-from domain="*" to-ports="*" /></cross-domain-policy>\n"""
@ -54,7 +73,6 @@ Connection: Upgrade\r
self.ssl_only = ssl_only self.ssl_only = ssl_only
self.daemon = daemon self.daemon = daemon
# Make paths settings absolute # Make paths settings absolute
self.cert = os.path.abspath(cert) self.cert = os.path.abspath(cert)
self.key = self.web = self.record = '' self.key = self.web = self.record = ''
@ -124,18 +142,129 @@ Connection: Upgrade\r
os.dup2(os.open(os.devnull, os.O_RDWR), sys.stderr.fileno()) os.dup2(os.open(os.devnull, os.O_RDWR), sys.stderr.fileno())
@staticmethod @staticmethod
def encode(buf): def encode_hybi(buf, opcode, base64=False):
""" Encode a WebSocket packet. """ """ Encode a HyBi style WebSocket frame.
buf = b64encode(buf) Optional opcode:
return "\x00%s\xff" % buf 0x0 - continuation
0x1 - text frame (base64 encode buf)
0x2 - binary frame (use raw buf)
0x8 - connection close
0x9 - ping
0xA - pong
"""
if base64:
buf = b64encode(buf)
b1 = 0x80 | (opcode & 0x0f) # FIN + opcode
payload_len = len(buf)
if payload_len <= 125:
header = struct.pack('>BB', b1, payload_len)
elif payload_len > 125 and payload_len <= 65536:
header = struct.pack('>BBH', b1, 126, payload_len)
elif payload_len >= 65536:
header = struct.pack('>BBQ', b1, 127, payload_len)
#print "Encoded: %s" % repr(header + buf)
return header + buf
@staticmethod @staticmethod
def decode(buf): def decode_hybi(buf, base64=False):
""" Decode WebSocket packets. """ """ Decode HyBi style WebSocket packets.
if buf.count('\xff') > 1: Returns:
return [b64decode(d[1:]) for d in buf.split('\xff')] {'fin' : 0_or_1,
'opcode' : number,
'mask' : 32_bit_number,
'length' : payload_bytes_number,
'payload' : decoded_buffer,
'left' : bytes_left_number}
"""
ret = {'fin' : 0,
'opcode' : 0,
'mask' : 0,
'length' : 0,
'payload' : None,
'left' : 0}
blen = len(buf)
ret['left'] = blen
header_len = 2
if blen < header_len:
return ret # Incomplete frame header
b1, b2 = struct.unpack_from(">BB", buf)
ret['opcode'] = b1 & 0x0f
ret['fin'] = (b1 & 0x80) >> 7
has_mask = (b2 & 0x80) >> 7
ret['length'] = b2 & 0x7f
if ret['length'] == 126:
header_len = 4
if blen < header_len:
return ret # Incomplete frame header
(ret['length'],) = struct.unpack_from('>xxH', buf)
elif ret['length'] == 127:
header_len = 10
if blen < header_len:
return ret # Incomplete frame header
(ret['length'],) = struct.unpack_from('>xxQ', buf)
full_len = header_len + has_mask * 4 + ret['length']
if blen < full_len: # Incomplete frame
return ret # Incomplete frame header
# Number of bytes that are part of the next frame(s)
ret['left'] = blen - full_len
# Process 1 frame
if has_mask:
# unmask payload
ret['mask'] = buf[header_len:header_len+4]
b = c = ''
if ret['length'] >= 4:
mask = numpy.frombuffer(buf, dtype=numpy.dtype('<L4'),
offset=header_len, count=1)
data = numpy.frombuffer(buf, dtype=numpy.dtype('<L4'),
offset=header_len + 4, count=int(ret['length'] / 4))
#b = numpy.bitwise_xor(data, mask).data
b = numpy.bitwise_xor(data, mask).tostring()
if ret['length'] % 4:
print "Partial unmask"
mask = numpy.frombuffer(buf, dtype=numpy.dtype('B'),
offset=header_len, count=(ret['length'] % 4))
data = numpy.frombuffer(buf, dtype=numpy.dtype('B'),
offset=full_len - (ret['length'] % 4),
count=(ret['length'] % 4))
c = numpy.bitwise_xor(data, mask).tostring()
ret['payload'] = b + c
else: else:
return [b64decode(buf[1:-1])] print "Unmasked frame:", repr(buf)
ret['payload'] = buf[(header_len + has_mask * 4):full_len]
if base64 and ret['opcode'] in [1, 2]:
try:
ret['payload'] = b64decode(ret['payload'])
except:
print "Exception while b64decoding buffer:", repr(buf)
raise
return ret
@staticmethod
def encode_hixie(buf):
return "\x00" + b64encode(buf) + "\xff"
@staticmethod
def decode_hixie(buf):
end = buf.find('\xff')
return {'payload': b64decode(buf[1:end]),
'left': len(buf) - (end + 1)}
@staticmethod @staticmethod
def parse_handshake(handshake): def parse_handshake(handshake):
@ -160,7 +289,7 @@ Connection: Upgrade\r
@staticmethod @staticmethod
def gen_md5(keys): def gen_md5(keys):
""" Generate hash value for WebSockets handshake v76. """ """ Generate hash value for WebSockets hixie-76. """
key1 = keys['Sec-WebSocket-Key1'] key1 = keys['Sec-WebSocket-Key1']
key2 = keys['Sec-WebSocket-Key2'] key2 = keys['Sec-WebSocket-Key2']
key3 = keys['key3'] key3 = keys['key3']
@ -171,7 +300,6 @@ Connection: Upgrade\r
return md5(struct.pack('>II8s', num1, num2, key3)).digest() return md5(struct.pack('>II8s', num1, num2, key3)).digest()
# #
# WebSocketServer logging/output functions # WebSocketServer logging/output functions
# #
@ -195,6 +323,125 @@ Connection: Upgrade\r
# #
# Main WebSocketServer methods # Main WebSocketServer methods
# #
def send_frames(self, bufs=None):
""" Encode and send WebSocket frames. Any frames already
queued will be sent first. If buf is not set then only queued
frames will be sent. Returns the number of pending frames that
could not be fully sent. If returned pending frames is greater
than 0, then the caller should call again when the socket is
ready. """
if bufs:
for buf in bufs:
if self.version.startswith("hybi"):
if self.base64:
self.send_parts.append(self.encode_hybi(buf,
opcode=1, base64=True))
else:
self.send_parts.append(self.encode_hybi(buf,
opcode=2, base64=False))
else:
self.send_parts.append(self.encode_hixie(buf))
while self.send_parts:
# Send pending frames
buf = self.send_parts.pop(0)
sent = self.client.send(buf)
if sent == len(buf):
self.traffic("<")
else:
self.traffic("<.")
self.send_parts.insert(0, buf[sent:])
break
return len(self.send_parts)
def recv_frames(self):
""" Receive and decode WebSocket frames.
Returns:
(bufs_list, closed_string)
"""
closed = False
bufs = []
buf = self.client.recv(self.buffer_size)
if len(buf) == 0:
closed = "Client closed abruptly"
return bufs, closed
if self.recv_part:
# Add partially received frames to current read buffer
buf = self.recv_part + buf
self.recv_part = None
while buf:
if self.version.startswith("hybi"):
frame = self.decode_hybi(buf, base64=self.base64)
#print "Received buf: %s, frame: %s" % (repr(buf), frame)
if frame['payload'] == None:
# Incomplete/partial frame
self.traffic("}.")
if frame['left'] > 0:
self.recv_part = buf[-frame['left']:]
break
else:
if frame['opcode'] == 0x8: # connection close
code, reason = struct.unpack_from(
">H%ds" % (frame['length']-2),
frame['payload'])
closed = "Client closed, reason: %s - %s" % (
code, reason)
break
else:
if buf[0:2] == '\xff\x00':
closed = "Client sent orderly close frame"
break
elif buf[0:2] == '\x00\xff':
buf = buf[2:]
continue # No-op
elif buf.count('\xff') == 0:
# Partial frame
self.traffic("}.")
self.recv_part = buf
break
frame = self.decode_hixie(buf)
self.traffic("}")
bufs.append(frame['payload'])
if frame['left']:
buf = buf[-frame['left']:]
else:
buf = ''
return bufs, closed
def send_close(self, code=None, reason=''):
""" Send a WebSocket orderly close frame. """
if self.version.startswith("hybi"):
msg = ''
if code != None:
msg = struct.pack(">H%ds" % (len(reason)), code)
buf = self.encode_hybi(msg, opcode=0x08, base64=False)
self.client.send(buf)
elif self.version == "hixie-76":
buf = self.encode_hixie('\xff\x00')
self.client.send(buf)
# No orderly close for 75
def do_handshake(self, sock, address): def do_handshake(self, sock, address):
""" """
@ -222,7 +469,7 @@ Connection: Upgrade\r
# Peek, but do not read the data so that we have a opportunity # Peek, but do not read the data so that we have a opportunity
# to SSL wrap the socket first # to SSL wrap the socket first
handshake = sock.recv(1024, socket.MSG_PEEK) handshake = sock.recv(1024, socket.MSG_PEEK)
#self.msg("Handshake [%s]" % repr(handshake)) #self.msg("Handshake [%s]" % handshake)
if handshake == "": if handshake == "":
raise self.EClose("ignoring empty handshake") raise self.EClose("ignoring empty handshake")
@ -268,8 +515,9 @@ Connection: Upgrade\r
raise self.EClose("Client closed during handshake") raise self.EClose("Client closed during handshake")
# Check for and handle normal web requests # Check for and handle normal web requests
if handshake.startswith('GET ') and \ if (handshake.startswith('GET ') and
handshake.find('Upgrade: WebSocket\r\n') == -1: handshake.find('Upgrade: WebSocket\r\n') == -1 and
handshake.find('Upgrade: websocket\r\n') == -1):
if not self.web: if not self.web:
raise self.EClose("Normal web request received but disallowed") raise self.EClose("Normal web request received but disallowed")
sh = SplitHTTPHandler(handshake, retsock, address) sh = SplitHTTPHandler(handshake, retsock, address)
@ -282,26 +530,73 @@ Connection: Upgrade\r
#self.msg("handshake: " + repr(handshake)) #self.msg("handshake: " + repr(handshake))
# Parse client WebSockets handshake # Parse client WebSockets handshake
self.headers = self.parse_handshake(handshake) h = self.headers = self.parse_handshake(handshake)
prot = 'WebSocket-Protocol'
protocols = h.get('Sec-'+prot, h.get(prot, '')).split(',')
ver = h.get('Sec-WebSocket-Version')
if ver:
# HyBi/IETF version of the protocol
if not numpy or not ctypes:
self.EClose("Python numpy and ctypes modules required for HyBi-07 or greater")
if ver == '7':
self.version = "hybi-07"
else:
raise self.EClose('Unsupported protocol version %s' % ver)
key = h['Sec-WebSocket-Key']
# Choose binary if client supports it
if 'binary' in protocols:
self.base64 = False
elif 'base64' in protocols:
self.base64 = True
else:
raise self.EClose("Client must support 'binary' or 'base64' protocol")
# Generate the hash value for the accept header
accept = b64encode(sha1(key + self.GUID).digest())
response = self.server_handshake_hybi % accept
if self.base64:
response += "Sec-WebSocket-Protocol: base64\r\n"
else:
response += "Sec-WebSocket-Protocol: binary\r\n"
response += "\r\n"
if self.headers.get('key3'):
trailer = self.gen_md5(self.headers)
pre = "Sec-"
ver = 76
else: else:
trailer = "" # Hixie version of the protocol (75 or 76)
pre = ""
ver = 75
self.msg("%s: %s WebSocket connection (version %s)" if h.get('key3'):
% (address[0], stype, ver)) trailer = self.gen_md5(h)
pre = "Sec-"
self.version = "hixie-76"
else:
trailer = ""
pre = ""
self.version = "hixie-75"
# We only support base64 in Hixie era
self.base64 = True
response = self.server_handshake_hixie % (pre,
h['Origin'], pre, scheme, h['Host'], h['path'])
if 'base64' in protocols:
response += "%sWebSocket-Protocol: base64\r\n" % pre
else:
self.msg("Warning: client does not report 'base64' protocol support")
response += "\r\n" + trailer
self.msg("%s: %s WebSocket connection" % (address[0], stype))
self.msg("%s: Version %s, base64: '%s'" % (address[0],
self.version, self.base64))
# Send server WebSockets handshake response # Send server WebSockets handshake response
response = self.server_handshake % (pre, #self.msg("sending response [%s]" % response)
self.headers['Origin'], pre, scheme,
self.headers['Host'], self.headers['path'], pre,
trailer)
#self.msg("sending response:", repr(response))
retsock.send(response) retsock.send(response)
# Return the WebSockets socket which may be SSL wrapped # Return the WebSockets socket which may be SSL wrapped
@ -368,7 +663,8 @@ Connection: Upgrade\r
while True: while True:
try: try:
try: try:
csock = startsock = None self.client = None
startsock = None
pid = err = 0 pid = err = 0
try: try:
@ -394,9 +690,14 @@ Connection: Upgrade\r
pid = os.fork() pid = os.fork()
if pid == 0: if pid == 0:
# Initialize per client settings
self.send_parts = []
self.recv_part = None
self.base64 = False
# handler process # handler process
csock = self.do_handshake(startsock, address) self.client = self.do_handshake(
self.new_client(csock) startsock, address)
self.new_client()
else: else:
# parent process # parent process
self.handler_id += 1 self.handler_id += 1
@ -413,8 +714,8 @@ Connection: Upgrade\r
self.msg(traceback.format_exc()) self.msg(traceback.format_exc())
finally: finally:
if csock and csock != startsock: if self.client and self.client != startsock:
csock.close() self.client.close()
if startsock: if startsock:
startsock.close() startsock.close()

View File

@ -133,7 +133,7 @@ Traffic Legend:
# will be run in a separate forked process for each connection. # will be run in a separate forked process for each connection.
# #
def new_client(self, client): def new_client(self):
""" """
Called after a new WebSocket connection has been established. Called after a new WebSocket connection has been established.
""" """
@ -156,9 +156,9 @@ Traffic Legend:
if self.verbose and not self.daemon: if self.verbose and not self.daemon:
print self.traffic_legend print self.traffic_legend
# Stat proxying # Start proxying
try: try:
self.do_proxy(client, tsock) self.do_proxy(tsock)
except: except:
if tsock: if tsock:
tsock.close() tsock.close()
@ -169,14 +169,14 @@ Traffic Legend:
self.rec.close() self.rec.close()
raise raise
def do_proxy(self, client, target): def do_proxy(self, target):
""" """
Proxy client WebSocket to normal target socket. Proxy client WebSocket to normal target socket.
""" """
cqueue = [] cqueue = []
cpartial = "" c_pend = 0
tqueue = [] tqueue = []
rlist = [client, target] rlist = [self.client, target]
tstart = int(time.time()*1000) tstart = int(time.time()*1000)
while True: while True:
@ -184,7 +184,7 @@ Traffic Legend:
tdelta = int(time.time()*1000) - tstart tdelta = int(time.time()*1000) - tstart
if tqueue: wlist.append(target) if tqueue: wlist.append(target)
if cqueue: wlist.append(client) if cqueue or c_pend: wlist.append(self.client)
ins, outs, excepts = select(rlist, wlist, [], 1) ins, outs, excepts = select(rlist, wlist, [], 1)
if excepts: raise Exception("Socket exception") if excepts: raise Exception("Socket exception")
@ -199,53 +199,40 @@ Traffic Legend:
tqueue.insert(0, dat[sent:]) tqueue.insert(0, dat[sent:])
self.traffic(".>") self.traffic(".>")
if client in outs:
# Send queued target data to the client
dat = cqueue.pop(0)
sent = client.send(dat)
if sent == len(dat):
self.traffic("<")
if self.rec:
self.rec.write("%s,\n" %
repr("{%s{" % tdelta + dat[1:-1]))
else:
cqueue.insert(0, dat[sent:])
self.traffic("<.")
if target in ins: if target in ins:
# Receive target data, encode it and queue for client # Receive target data, encode it and queue for client
buf = target.recv(self.buffer_size) buf = target.recv(self.buffer_size)
if len(buf) == 0: raise self.EClose("Target closed") if len(buf) == 0: raise self.EClose("Target closed")
cqueue.append(self.encode(buf)) cqueue.append(buf)
self.traffic("{") self.traffic("{")
if client in ins:
# Receive client data, decode it, and queue for target
buf = client.recv(self.buffer_size)
if len(buf) == 0: raise self.EClose("Client closed")
if buf == '\xff\x00': if self.client in outs:
raise self.EClose("Client sent orderly close frame") # Send queued target data to the client
elif buf[-1] == '\xff': c_pend = self.send_frames(cqueue)
if buf.count('\xff') > 1: cqueue = []
self.traffic(str(buf.count('\xff')))
self.traffic("}") #if self.rec:
if self.rec: # self.rec.write("%s,\n" %
self.rec.write("%s,\n" % # repr("{%s{" % tdelta + dat[1:-1]))
(repr("}%s}" % tdelta + buf[1:-1])))
if cpartial:
# Prepend saved partial and decode frame(s) if self.client in ins:
tqueue.extend(self.decode(cpartial + buf)) # Receive client data, decode it, and queue for target
cpartial = "" bufs, closed = self.recv_frames()
else: tqueue.extend(bufs)
# decode frame(s)
tqueue.extend(self.decode(buf)) #if self.rec:
else: # for b in bufs:
# Save off partial WebSockets frame # self.rec.write(
self.traffic(".}") # repr("}%s}%s" % (tdelta, b)) + ",\n")
cpartial = cpartial + buf
if closed:
# TODO: What about blocking on client socket?
self.send_close()
raise self.EClose(closed)
if __name__ == '__main__': if __name__ == '__main__':
usage = "\n %prog [options]" usage = "\n %prog [options]"