- handshake is workable.

This commit is contained in:
liris
2012-01-10 09:19:31 +09:00
parent 0b8c7e0aac
commit 295a364532
5 changed files with 54 additions and 128 deletions

View File

@@ -1,9 +1,6 @@
HTTP/1.1 101 WebSocket Protocol Handshake HTTP/1.1 101 WebSocket Protocol Handshake
Connection: Upgrade Connection: Upgrade
Upgrade: WebSocket Upgrade: WebSocket
sec-websocket-location: http://localhost/r Sec-WebSocket-Accept: Kxep+hNu9n51529fGidYu7a3wO0=
sec-websocket-origin: http://localhost/r
some_header: something some_header: something
ssssss
aaaaaaaa

View File

@@ -1,9 +1,6 @@
HTTP/1.1 101 WebSocket Protocol Handshake HTTP/1.1 101 WebSocket Protocol Handshake
Connection: Upgrade Connection: Upgrade
Upgrade WebSocket Upgrade WebSocket
sec-websocket-location: http://localhost/r Sec-WebSocket-Accept: Kxep+hNu9n51529fGidYu7a3wO0=
sec-websocket-origin: http://localhost/r
some_header: something some_header: something
ssssss
aaaaaaaa

View File

@@ -1,6 +1,6 @@
from setuptools import setup from setuptools import setup
VERSION = "0.4.1" VERSION = "0.5.0"
setup( setup(

View File

@@ -92,50 +92,39 @@ class WebSocketTest(unittest.TestCase):
self.assertRaises(ValueError, ws._parse_url, "http://www.example.com/r") self.assertRaises(ValueError, ws._parse_url, "http://www.example.com/r")
def testWSKey(self): def testWSKey(self):
n, k = ws._create_sec_websocket_key() key = ws._create_sec_websocket_key()
self.assert_(0 < n < (1<<32)) self.assert_(key != 24)
self.assert_(len(k) > 0) self.assert_("¥n" not in key)
k3 = ws._create_key3()
self.assertEquals(len(k3), 8)
def testWsUtils(self): def testWsUtils(self):
sock = ws.WebSocket() sock = ws.WebSocket()
self.assertNotEquals(sock._validate_resp(1,2,"test", "fuga"), True)
hashed = '6\xa3p\xb6#\xac\xb9=\xec\x0e\x96\xb5\xc1@\x1d\x90'
self.assertEquals(sock._validate_resp(1,2,"test", hashed), True)
hibi_header = { key = "c6b8hTg4EeGb2gQMztV1/g=="
required_header = {
"upgrade": "websocket", "upgrade": "websocket",
"connection": "upgrade", "connection": "upgrade",
"sec-websocket-origin": "http://www.example.com", "sec-websocket-accept": "Kxep+hNu9n51529fGidYu7a3wO0=",
"sec-websocket-location": "http://www.example.com",
} }
self.assertEquals(sock._validate_header(hibi_header), (True, True)) self.assertEquals(sock._validate_header(required_header, key), True)
header = hibi_header.copy() header = required_header.copy()
header["upgrade"] = "http" header["upgrade"] = "http"
self.assertEquals(sock._validate_header(header), (False, False)) self.assertEquals(sock._validate_header(header, key), False)
del header["upgrade"] del header["upgrade"]
self.assertEquals(sock._validate_header(header), (False, False)) self.assertEquals(sock._validate_header(header, key), False)
header = hibi_header.copy() header = required_header.copy()
header["connection"] = "http" header["connection"] = "something"
self.assertEquals(sock._validate_header(header), (False, False)) self.assertEquals(sock._validate_header(header, key), False)
del header["connection"] del header["connection"]
self.assertEquals(sock._validate_header(header), (False, False)) self.assertEquals(sock._validate_header(header, key), False)
header = hibi_header.copy()
header["sec-websocket-origin"] = "somewhere origin"
self.assertEquals(sock._validate_header(header), (True, True))
del header["sec-websocket-origin"]
self.assertEquals(sock._validate_header(header), (False, True))
header = hibi_header.copy() header = required_header.copy()
header["sec-websocket-location"] = "somewhere location" header["sec-websocket-accept"] = "something"
self.assertEquals(sock._validate_header(header), (True, True)) self.assertEquals(sock._validate_header(header, key), False)
del header["sec-websocket-location"] del header["sec-websocket-accept"]
self.assertEquals(sock._validate_header(header), (False, True)) self.assertEquals(sock._validate_header(header, key), False)
def testReadHeader(self): def testReadHeader(self):
sock = ws.WebSocket() sock = ws.WebSocket()
@@ -143,10 +132,7 @@ class WebSocketTest(unittest.TestCase):
status, header = sock._read_headers() status, header = sock._read_headers()
self.assertEquals(status, 101) self.assertEquals(status, 101)
self.assertEquals(header["connection"], "upgrade") self.assertEquals(header["connection"], "upgrade")
self.assertEquals("ssssss" in header, False)
self.assertEquals(sock._get_resp(), "ssssss\r\naaaaaaaa")
sock.io_sock = sock.sock = HeaderSockMock("data/header02.txt") sock.io_sock = sock.sock = HeaderSockMock("data/header02.txt")
self.assertRaises(ws.WebSocketException, sock._read_headers) self.assertRaises(ws.WebSocketException, sock._read_headers)

View File

@@ -22,12 +22,14 @@ Copyright (C) 2010 Hiroki Ohtani(liris)
import socket import socket
from urlparse import urlparse from urlparse import urlparse
import random import uuid
import struct import sha
import hashlib import base64
import logging import logging
VERSION = 13
logger = logging.getLogger() logger = logging.getLogger()
class WebSocketException(Exception): class WebSocketException(Exception):
@@ -128,39 +130,14 @@ _MAX_CHAR_BYTE = (1<<8) -1
# http://axod.blogspot.com/2010/06/websocket-gets-update-and-it-breaks.html # http://axod.blogspot.com/2010/06/websocket-gets-update-and-it-breaks.html
def _create_sec_websocket_key(): def _create_sec_websocket_key():
spaces_n = random.randint(1, 12) uid = uuid.uuid1()
max_n = _MAX_INTEGER / spaces_n return base64.encodestring(uid.bytes).strip()
number_n = random.randint(0, max_n)
product_n = number_n * spaces_n
key_n = str(product_n)
for i in range(random.randint(1, 12)):
c = random.choice(_AVAILABLE_KEY_CHARS)
pos = random.randint(0, len(key_n))
key_n = key_n[0:pos] + chr(c) + key_n[pos:]
for i in range(spaces_n):
pos = random.randint(1, len(key_n)-1)
key_n = key_n[0:pos] + " " + key_n[pos:]
return number_n, key_n
def _create_key3():
return "".join([chr(random.randint(0, _MAX_CHAR_BYTE)) for i in range(8)])
HEADERS_TO_CHECK = { HEADERS_TO_CHECK = {
"upgrade": "websocket", "upgrade": "websocket",
"connection": "upgrade", "connection": "upgrade",
} }
HEADERS_TO_EXIST_FOR_HYBI00 = [
"sec-websocket-origin",
"sec-websocket-location",
]
HEADERS_TO_EXIST_FOR_HIXIE75 = [
"websocket-origin",
"websocket-location",
]
class _SSLSocketWrapper(object): class _SSLSocketWrapper(object):
def __init__(self, sock): def __init__(self, sock):
self.ssl = socket.ssl(sock) self.ssl = socket.ssl(sock)
@@ -229,7 +206,7 @@ class WebSocket(object):
sock = self.io_sock sock = self.io_sock
headers = [] headers = []
headers.append("GET %s HTTP/1.1" % resource) headers.append("GET %s HTTP/1.1" % resource)
headers.append("Upgrade: WebSocket") headers.append("Upgrade: websocket")
headers.append("Connection: Upgrade") headers.append("Connection: Upgrade")
if port == 80: if port == 80:
hostport = host hostport = host
@@ -238,16 +215,15 @@ class WebSocket(object):
headers.append("Host: %s" % hostport) headers.append("Host: %s" % hostport)
headers.append("Origin: %s" % hostport) headers.append("Origin: %s" % hostport)
number_1, key_1 = _create_sec_websocket_key() key = _create_sec_websocket_key()
headers.append("Sec-WebSocket-Key1: %s" % key_1) headers.append("Sec-WebSocket-Key: %s" % key)
number_2, key_2 = _create_sec_websocket_key() headers.append("Sec-WebSocket-Protocol: chat, superchat")
headers.append("Sec-WebSocket-Key2: %s" % key_2) headers.append("Sec-WebSocket-Version: %s" % VERSION)
if "header" in options: if "header" in options:
headers.extend(options["header"]) headers.extend(options["header"])
headers.append("") headers.append("")
key3 = _create_key3() headers.append("")
headers.append(key3)
header_str = "\r\n".join(headers) header_str = "\r\n".join(headers)
sock.send(header_str) sock.send(header_str)
@@ -260,61 +236,31 @@ class WebSocket(object):
if status != 101: if status != 101:
self.close() self.close()
raise WebSocketException("Handshake Status %d" % status) raise WebSocketException("Handshake Status %d" % status)
success, secure = self._validate_header(resp_headers)
success = self._validate_header(resp_headers, key)
if not success: if not success:
self.close() self.close()
raise WebSocketException("Invalid WebSocket Header") raise WebSocketException("Invalid WebSocket Header")
if secure:
resp = self._get_resp()
if not self._validate_resp(number_1, number_2, key3, resp):
self.close()
raise WebSocketException("challenge-response error")
self.connected = True self.connected = True
def _validate_resp(self, number_1, number_2, key3, resp): def _validate_header(self, headers, key):
challenge = struct.pack("!I", number_1) for k, v in HEADERS_TO_CHECK.iteritems():
challenge += struct.pack("!I", number_2) r = headers.get(k, None)
challenge += key3 if not r:
digest = hashlib.md5(challenge).digest() return False
r = r.lower()
if v != r:
return False
result = headers.get("sec-websocket-accept", None)
if not result:
return False
result = result.lower()
return resp == digest value = key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
hashed = base64.encodestring(sha.sha(value).digest()).strip().lower()
def _get_resp(self): return hashed == result
result = self._recv(16)
if traceEnabled:
logger.debug("--- challenge response result ---")
logger.debug(repr(result))
logger.debug("---------------------------------")
return result
def _validate_header(self, headers):
#TODO: check other headers
for key, value in HEADERS_TO_CHECK.iteritems():
v = headers.get(key, None)
if value != v:
return False, False
success = 0
for key in HEADERS_TO_EXIST_FOR_HYBI00:
if key in headers:
success += 1
if success == len(HEADERS_TO_EXIST_FOR_HYBI00):
return True, True
elif success != 0:
return False, True
success = 0
for key in HEADERS_TO_EXIST_FOR_HIXIE75:
if key in headers:
success += 1
if success == len(HEADERS_TO_EXIST_FOR_HIXIE75):
return True, False
return False, False
def _read_headers(self): def _read_headers(self):
status = None status = None