- support wss

This commit is contained in:
Hiroki Ohtani
2011-01-05 09:19:19 +09:00
parent 2be0632fd7
commit 5f615f6949
2 changed files with 51 additions and 22 deletions

22
test.py
View File

@@ -49,34 +49,44 @@ class WebSocketTest(unittest.TestCase):
self.assertEquals(p[0], "www.example.com")
self.assertEquals(p[1], 80)
self.assertEquals(p[2], "/r")
self.assertEquals(p[3], False)
p = ws._parse_url("ws://www.example.com/")
self.assertEquals(p[0], "www.example.com")
self.assertEquals(p[1], 80)
self.assertEquals(p[2], "/")
self.assertEquals(p[3], False)
p = ws._parse_url("ws://www.example.com")
self.assertEquals(p[0], "www.example.com")
self.assertEquals(p[1], 80)
self.assertEquals(p[2], "/")
self.assertEquals(p[3], False)
p = ws._parse_url("ws://www.example.com:8080/r")
self.assertEquals(p[0], "www.example.com")
self.assertEquals(p[1], 8080)
self.assertEquals(p[2], "/r")
self.assertEquals(p[3], False)
p = ws._parse_url("ws://www.example.com:8080/")
self.assertEquals(p[0], "www.example.com")
self.assertEquals(p[1], 8080)
self.assertEquals(p[2], "/")
self.assertEquals(p[3], False)
p = ws._parse_url("ws://www.example.com:8080")
self.assertEquals(p[0], "www.example.com")
self.assertEquals(p[1], 8080)
self.assertEquals(p[2], "/")
self.assertEquals(p[3], False)
p = ws._parse_url("wss://www.example.com:8080/r")
self.assertEquals(p[0], "www.example.com")
self.assertEquals(p[1], 8080)
self.assertEquals(p[2], "/r")
self.assertEquals(p[3], True)
# we do not support wss for a while
self.assertRaises(ValueError, ws._parse_url, "wss://www.example.com/r")
self.assertRaises(ValueError, ws._parse_url, "http://www.example.com/r")
def testWSKey(self):
@@ -127,7 +137,7 @@ class WebSocketTest(unittest.TestCase):
def testReadHeader(self):
sock = ws.WebSocket()
sock.sock = HeaderSockMock("data/header01.txt")
sock.io_sock = sock.sock = HeaderSockMock("data/header01.txt")
status, header = sock._read_headers()
self.assertEquals(status, 101)
self.assertEquals(header["connection"], "upgrade")
@@ -135,12 +145,12 @@ class WebSocketTest(unittest.TestCase):
self.assertEquals(sock._get_resp(), "ssssss\r\naaaaaaaa")
sock.sock = HeaderSockMock("data/header02.txt")
sock.io_sock = sock.sock = HeaderSockMock("data/header02.txt")
self.assertRaises(ws.WebSocketException, sock._read_headers)
def testSend(self):
sock = ws.WebSocket()
s = sock.sock = HeaderSockMock("data/header01.txt")
s = sock.io_sock = sock.sock = HeaderSockMock("data/header01.txt")
sock.send("Hello")
self.assertEquals(s.sent[0], "\x00Hello\xff")
sock.send("こんにちは")
@@ -150,7 +160,7 @@ class WebSocketTest(unittest.TestCase):
def testRecv(self):
sock = ws.WebSocket()
s = sock.sock = StringSockMock()
s = sock.io_sock = sock.sock = StringSockMock()
s.set_data("\x00こんにちは\xff")
data = sock.recv()
self.assertEquals(data, "こんにちは")

View File

@@ -33,20 +33,27 @@ def getdefaulttimeout():
return default_timeout
def _parse_url(url):
"""
parse url and the result is tuple of
(hostname, port, resource path and the flag of secure mode)
"""
parsed = urlparse(url)
if parsed.hostname:
hostname = parsed.hostname
else:
raise ValueError("hostname is invalid")
port = 0
if parsed.port:
port = parsed.port
is_secure = False
if parsed.scheme == "ws":
if parsed.port:
port = parsed.port
else:
if not port:
port = 80
elif parsed.scheme == "wss":
# TODO: support wss
raise ValueError("scheme wss is not supported")
is_secure = True
if not port:
port = 443
else:
raise ValueError("scheme %s is invalid" % parsed.scheme)
@@ -55,7 +62,7 @@ def _parse_url(url):
else:
resource = "/"
return (hostname, port, resource)
return (hostname, port, resource, is_secure)
def create_connection(url, timeout=None, **options):
@@ -111,7 +118,16 @@ HEADERS_TO_EXIST_FOR_HIXIE75 = [
"websocket-origin",
"websocket-location",
]
class SSLSocketWrapper(object):
def __init__(self, sock):
self.ssl = socket.ssl(sock)
def recv(self, bufsize):
return self.ssl.read(bufsize)
def send(self, payload):
return self.ssl.write(payload)
class WebSocket(object):
def __init__(self):
@@ -119,7 +135,7 @@ class WebSocket(object):
Initalize WebSocket object.
"""
self.connected = False
self.sock = socket.socket()
self.io_sock = self.sock = socket.socket()
def settimeout(self, timeout):
"""
@@ -137,13 +153,15 @@ class WebSocket(object):
"""
Connect to url. url is websocket url scheme. ie. ws://host:port/resource
"""
hostname, port, resource = _parse_url(url)
hostname, port, resource, is_secure = _parse_url(url)
# TODO: we need to support proxy
self.sock.connect((hostname, port))
if is_secure:
self.io_sock = SSLSocketWrapper(self.sock)
self._handshake(hostname, port, resource, **options)
def _handshake(self, host, port, resource, **options):
sock = self.sock
sock = self.io_sock
headers = []
if "header" in options:
headers.extend(options["header"])
@@ -175,17 +193,17 @@ class WebSocket(object):
status, resp_headers = self._read_headers()
if status != 101:
self.sock.close()
self.close()
raise WebSocketException("Handshake Status %d" % status)
success, secure = self._validate_header(resp_headers)
if not success:
self.sock.close()
self.close()
raise WebSocketException("Invalid WebSocket Header")
if secure:
resp = self._get_resp()
if not self._validate_resp(number_1, number_2, key3, resp):
self.sock.close()
self.close()
raise WebSocketException("challenge-response error")
self.connected = True
@@ -268,7 +286,7 @@ class WebSocket(object):
"""
if isinstance(payload, unicode):
payload = payload.encode("utf-8")
self.sock.send("".join(["\x00", payload, "\xff"]))
self.io_sock.send("".join(["\x00", payload, "\xff"]))
def recv(self):
"""
@@ -307,16 +325,17 @@ class WebSocket(object):
"""
if self.connected:
try:
self.sock.send("\xff\x00")
self.io_sock.send("\xff\x00")
result = self._recv(2)
if result != "\xff\x00":
logger.error("bad closing Handshake")
except:
pass
self.sock.close()
self.io_sock = self.sock
def _recv(self, bufsize):
bytes = self.sock.recv(bufsize)
bytes = self.io_sock.recv(bufsize)
if not bytes:
raise ConnectionClosedException()
return bytes