- support wss
This commit is contained in:
22
test.py
22
test.py
@@ -49,34 +49,44 @@ class WebSocketTest(unittest.TestCase):
|
||||
self.assertEquals(p[0], "www.example.com")
|
||||
self.assertEquals(p[1], 80)
|
||||
self.assertEquals(p[2], "/r")
|
||||
self.assertEquals(p[3], False)
|
||||
|
||||
p = ws._parse_url("ws://www.example.com/")
|
||||
self.assertEquals(p[0], "www.example.com")
|
||||
self.assertEquals(p[1], 80)
|
||||
self.assertEquals(p[2], "/")
|
||||
self.assertEquals(p[3], False)
|
||||
|
||||
p = ws._parse_url("ws://www.example.com")
|
||||
self.assertEquals(p[0], "www.example.com")
|
||||
self.assertEquals(p[1], 80)
|
||||
self.assertEquals(p[2], "/")
|
||||
self.assertEquals(p[3], False)
|
||||
|
||||
p = ws._parse_url("ws://www.example.com:8080/r")
|
||||
self.assertEquals(p[0], "www.example.com")
|
||||
self.assertEquals(p[1], 8080)
|
||||
self.assertEquals(p[2], "/r")
|
||||
self.assertEquals(p[3], False)
|
||||
|
||||
p = ws._parse_url("ws://www.example.com:8080/")
|
||||
self.assertEquals(p[0], "www.example.com")
|
||||
self.assertEquals(p[1], 8080)
|
||||
self.assertEquals(p[2], "/")
|
||||
self.assertEquals(p[3], False)
|
||||
|
||||
p = ws._parse_url("ws://www.example.com:8080")
|
||||
self.assertEquals(p[0], "www.example.com")
|
||||
self.assertEquals(p[1], 8080)
|
||||
self.assertEquals(p[2], "/")
|
||||
self.assertEquals(p[3], False)
|
||||
|
||||
p = ws._parse_url("wss://www.example.com:8080/r")
|
||||
self.assertEquals(p[0], "www.example.com")
|
||||
self.assertEquals(p[1], 8080)
|
||||
self.assertEquals(p[2], "/r")
|
||||
self.assertEquals(p[3], True)
|
||||
|
||||
# we do not support wss for a while
|
||||
self.assertRaises(ValueError, ws._parse_url, "wss://www.example.com/r")
|
||||
self.assertRaises(ValueError, ws._parse_url, "http://www.example.com/r")
|
||||
|
||||
def testWSKey(self):
|
||||
@@ -127,7 +137,7 @@ class WebSocketTest(unittest.TestCase):
|
||||
|
||||
def testReadHeader(self):
|
||||
sock = ws.WebSocket()
|
||||
sock.sock = HeaderSockMock("data/header01.txt")
|
||||
sock.io_sock = sock.sock = HeaderSockMock("data/header01.txt")
|
||||
status, header = sock._read_headers()
|
||||
self.assertEquals(status, 101)
|
||||
self.assertEquals(header["connection"], "upgrade")
|
||||
@@ -135,12 +145,12 @@ class WebSocketTest(unittest.TestCase):
|
||||
|
||||
self.assertEquals(sock._get_resp(), "ssssss\r\naaaaaaaa")
|
||||
|
||||
sock.sock = HeaderSockMock("data/header02.txt")
|
||||
sock.io_sock = sock.sock = HeaderSockMock("data/header02.txt")
|
||||
self.assertRaises(ws.WebSocketException, sock._read_headers)
|
||||
|
||||
def testSend(self):
|
||||
sock = ws.WebSocket()
|
||||
s = sock.sock = HeaderSockMock("data/header01.txt")
|
||||
s = sock.io_sock = sock.sock = HeaderSockMock("data/header01.txt")
|
||||
sock.send("Hello")
|
||||
self.assertEquals(s.sent[0], "\x00Hello\xff")
|
||||
sock.send("こんにちは")
|
||||
@@ -150,7 +160,7 @@ class WebSocketTest(unittest.TestCase):
|
||||
|
||||
def testRecv(self):
|
||||
sock = ws.WebSocket()
|
||||
s = sock.sock = StringSockMock()
|
||||
s = sock.io_sock = sock.sock = StringSockMock()
|
||||
s.set_data("\x00こんにちは\xff")
|
||||
data = sock.recv()
|
||||
self.assertEquals(data, "こんにちは")
|
||||
|
51
websocket.py
51
websocket.py
@@ -33,20 +33,27 @@ def getdefaulttimeout():
|
||||
return default_timeout
|
||||
|
||||
def _parse_url(url):
|
||||
"""
|
||||
parse url and the result is tuple of
|
||||
(hostname, port, resource path and the flag of secure mode)
|
||||
"""
|
||||
parsed = urlparse(url)
|
||||
if parsed.hostname:
|
||||
hostname = parsed.hostname
|
||||
else:
|
||||
raise ValueError("hostname is invalid")
|
||||
|
||||
port = 0
|
||||
if parsed.port:
|
||||
port = parsed.port
|
||||
|
||||
is_secure = False
|
||||
if parsed.scheme == "ws":
|
||||
if parsed.port:
|
||||
port = parsed.port
|
||||
else:
|
||||
if not port:
|
||||
port = 80
|
||||
elif parsed.scheme == "wss":
|
||||
# TODO: support wss
|
||||
raise ValueError("scheme wss is not supported")
|
||||
is_secure = True
|
||||
if not port:
|
||||
port = 443
|
||||
else:
|
||||
raise ValueError("scheme %s is invalid" % parsed.scheme)
|
||||
|
||||
@@ -55,7 +62,7 @@ def _parse_url(url):
|
||||
else:
|
||||
resource = "/"
|
||||
|
||||
return (hostname, port, resource)
|
||||
return (hostname, port, resource, is_secure)
|
||||
|
||||
|
||||
def create_connection(url, timeout=None, **options):
|
||||
@@ -111,7 +118,16 @@ HEADERS_TO_EXIST_FOR_HIXIE75 = [
|
||||
"websocket-origin",
|
||||
"websocket-location",
|
||||
]
|
||||
|
||||
class SSLSocketWrapper(object):
|
||||
def __init__(self, sock):
|
||||
self.ssl = socket.ssl(sock)
|
||||
|
||||
def recv(self, bufsize):
|
||||
return self.ssl.read(bufsize)
|
||||
|
||||
def send(self, payload):
|
||||
return self.ssl.write(payload)
|
||||
|
||||
class WebSocket(object):
|
||||
def __init__(self):
|
||||
@@ -119,7 +135,7 @@ class WebSocket(object):
|
||||
Initalize WebSocket object.
|
||||
"""
|
||||
self.connected = False
|
||||
self.sock = socket.socket()
|
||||
self.io_sock = self.sock = socket.socket()
|
||||
|
||||
def settimeout(self, timeout):
|
||||
"""
|
||||
@@ -137,13 +153,15 @@ class WebSocket(object):
|
||||
"""
|
||||
Connect to url. url is websocket url scheme. ie. ws://host:port/resource
|
||||
"""
|
||||
hostname, port, resource = _parse_url(url)
|
||||
hostname, port, resource, is_secure = _parse_url(url)
|
||||
# TODO: we need to support proxy
|
||||
self.sock.connect((hostname, port))
|
||||
if is_secure:
|
||||
self.io_sock = SSLSocketWrapper(self.sock)
|
||||
self._handshake(hostname, port, resource, **options)
|
||||
|
||||
def _handshake(self, host, port, resource, **options):
|
||||
sock = self.sock
|
||||
sock = self.io_sock
|
||||
headers = []
|
||||
if "header" in options:
|
||||
headers.extend(options["header"])
|
||||
@@ -175,17 +193,17 @@ class WebSocket(object):
|
||||
|
||||
status, resp_headers = self._read_headers()
|
||||
if status != 101:
|
||||
self.sock.close()
|
||||
self.close()
|
||||
raise WebSocketException("Handshake Status %d" % status)
|
||||
success, secure = self._validate_header(resp_headers)
|
||||
if not success:
|
||||
self.sock.close()
|
||||
self.close()
|
||||
raise WebSocketException("Invalid WebSocket Header")
|
||||
|
||||
if secure:
|
||||
resp = self._get_resp()
|
||||
if not self._validate_resp(number_1, number_2, key3, resp):
|
||||
self.sock.close()
|
||||
self.close()
|
||||
raise WebSocketException("challenge-response error")
|
||||
|
||||
self.connected = True
|
||||
@@ -268,7 +286,7 @@ class WebSocket(object):
|
||||
"""
|
||||
if isinstance(payload, unicode):
|
||||
payload = payload.encode("utf-8")
|
||||
self.sock.send("".join(["\x00", payload, "\xff"]))
|
||||
self.io_sock.send("".join(["\x00", payload, "\xff"]))
|
||||
|
||||
def recv(self):
|
||||
"""
|
||||
@@ -307,16 +325,17 @@ class WebSocket(object):
|
||||
"""
|
||||
if self.connected:
|
||||
try:
|
||||
self.sock.send("\xff\x00")
|
||||
self.io_sock.send("\xff\x00")
|
||||
result = self._recv(2)
|
||||
if result != "\xff\x00":
|
||||
logger.error("bad closing Handshake")
|
||||
except:
|
||||
pass
|
||||
self.sock.close()
|
||||
self.io_sock = self.sock
|
||||
|
||||
def _recv(self, bufsize):
|
||||
bytes = self.sock.recv(bufsize)
|
||||
bytes = self.io_sock.recv(bufsize)
|
||||
if not bytes:
|
||||
raise ConnectionClosedException()
|
||||
return bytes
|
||||
|
Reference in New Issue
Block a user