- support wss
This commit is contained in:
22
test.py
22
test.py
@@ -49,34 +49,44 @@ class WebSocketTest(unittest.TestCase):
|
|||||||
self.assertEquals(p[0], "www.example.com")
|
self.assertEquals(p[0], "www.example.com")
|
||||||
self.assertEquals(p[1], 80)
|
self.assertEquals(p[1], 80)
|
||||||
self.assertEquals(p[2], "/r")
|
self.assertEquals(p[2], "/r")
|
||||||
|
self.assertEquals(p[3], False)
|
||||||
|
|
||||||
p = ws._parse_url("ws://www.example.com/")
|
p = ws._parse_url("ws://www.example.com/")
|
||||||
self.assertEquals(p[0], "www.example.com")
|
self.assertEquals(p[0], "www.example.com")
|
||||||
self.assertEquals(p[1], 80)
|
self.assertEquals(p[1], 80)
|
||||||
self.assertEquals(p[2], "/")
|
self.assertEquals(p[2], "/")
|
||||||
|
self.assertEquals(p[3], False)
|
||||||
|
|
||||||
p = ws._parse_url("ws://www.example.com")
|
p = ws._parse_url("ws://www.example.com")
|
||||||
self.assertEquals(p[0], "www.example.com")
|
self.assertEquals(p[0], "www.example.com")
|
||||||
self.assertEquals(p[1], 80)
|
self.assertEquals(p[1], 80)
|
||||||
self.assertEquals(p[2], "/")
|
self.assertEquals(p[2], "/")
|
||||||
|
self.assertEquals(p[3], False)
|
||||||
|
|
||||||
p = ws._parse_url("ws://www.example.com:8080/r")
|
p = ws._parse_url("ws://www.example.com:8080/r")
|
||||||
self.assertEquals(p[0], "www.example.com")
|
self.assertEquals(p[0], "www.example.com")
|
||||||
self.assertEquals(p[1], 8080)
|
self.assertEquals(p[1], 8080)
|
||||||
self.assertEquals(p[2], "/r")
|
self.assertEquals(p[2], "/r")
|
||||||
|
self.assertEquals(p[3], False)
|
||||||
|
|
||||||
p = ws._parse_url("ws://www.example.com:8080/")
|
p = ws._parse_url("ws://www.example.com:8080/")
|
||||||
self.assertEquals(p[0], "www.example.com")
|
self.assertEquals(p[0], "www.example.com")
|
||||||
self.assertEquals(p[1], 8080)
|
self.assertEquals(p[1], 8080)
|
||||||
self.assertEquals(p[2], "/")
|
self.assertEquals(p[2], "/")
|
||||||
|
self.assertEquals(p[3], False)
|
||||||
|
|
||||||
p = ws._parse_url("ws://www.example.com:8080")
|
p = ws._parse_url("ws://www.example.com:8080")
|
||||||
self.assertEquals(p[0], "www.example.com")
|
self.assertEquals(p[0], "www.example.com")
|
||||||
self.assertEquals(p[1], 8080)
|
self.assertEquals(p[1], 8080)
|
||||||
self.assertEquals(p[2], "/")
|
self.assertEquals(p[2], "/")
|
||||||
|
self.assertEquals(p[3], False)
|
||||||
|
|
||||||
|
p = ws._parse_url("wss://www.example.com:8080/r")
|
||||||
|
self.assertEquals(p[0], "www.example.com")
|
||||||
|
self.assertEquals(p[1], 8080)
|
||||||
|
self.assertEquals(p[2], "/r")
|
||||||
|
self.assertEquals(p[3], True)
|
||||||
|
|
||||||
# we do not support wss for a while
|
|
||||||
self.assertRaises(ValueError, ws._parse_url, "wss://www.example.com/r")
|
|
||||||
self.assertRaises(ValueError, ws._parse_url, "http://www.example.com/r")
|
self.assertRaises(ValueError, ws._parse_url, "http://www.example.com/r")
|
||||||
|
|
||||||
def testWSKey(self):
|
def testWSKey(self):
|
||||||
@@ -127,7 +137,7 @@ class WebSocketTest(unittest.TestCase):
|
|||||||
|
|
||||||
def testReadHeader(self):
|
def testReadHeader(self):
|
||||||
sock = ws.WebSocket()
|
sock = ws.WebSocket()
|
||||||
sock.sock = HeaderSockMock("data/header01.txt")
|
sock.io_sock = sock.sock = HeaderSockMock("data/header01.txt")
|
||||||
status, header = sock._read_headers()
|
status, header = sock._read_headers()
|
||||||
self.assertEquals(status, 101)
|
self.assertEquals(status, 101)
|
||||||
self.assertEquals(header["connection"], "upgrade")
|
self.assertEquals(header["connection"], "upgrade")
|
||||||
@@ -135,12 +145,12 @@ class WebSocketTest(unittest.TestCase):
|
|||||||
|
|
||||||
self.assertEquals(sock._get_resp(), "ssssss\r\naaaaaaaa")
|
self.assertEquals(sock._get_resp(), "ssssss\r\naaaaaaaa")
|
||||||
|
|
||||||
sock.sock = HeaderSockMock("data/header02.txt")
|
sock.io_sock = sock.sock = HeaderSockMock("data/header02.txt")
|
||||||
self.assertRaises(ws.WebSocketException, sock._read_headers)
|
self.assertRaises(ws.WebSocketException, sock._read_headers)
|
||||||
|
|
||||||
def testSend(self):
|
def testSend(self):
|
||||||
sock = ws.WebSocket()
|
sock = ws.WebSocket()
|
||||||
s = sock.sock = HeaderSockMock("data/header01.txt")
|
s = sock.io_sock = sock.sock = HeaderSockMock("data/header01.txt")
|
||||||
sock.send("Hello")
|
sock.send("Hello")
|
||||||
self.assertEquals(s.sent[0], "\x00Hello\xff")
|
self.assertEquals(s.sent[0], "\x00Hello\xff")
|
||||||
sock.send("こんにちは")
|
sock.send("こんにちは")
|
||||||
@@ -150,7 +160,7 @@ class WebSocketTest(unittest.TestCase):
|
|||||||
|
|
||||||
def testRecv(self):
|
def testRecv(self):
|
||||||
sock = ws.WebSocket()
|
sock = ws.WebSocket()
|
||||||
s = sock.sock = StringSockMock()
|
s = sock.io_sock = sock.sock = StringSockMock()
|
||||||
s.set_data("\x00こんにちは\xff")
|
s.set_data("\x00こんにちは\xff")
|
||||||
data = sock.recv()
|
data = sock.recv()
|
||||||
self.assertEquals(data, "こんにちは")
|
self.assertEquals(data, "こんにちは")
|
||||||
|
49
websocket.py
49
websocket.py
@@ -33,20 +33,27 @@ def getdefaulttimeout():
|
|||||||
return default_timeout
|
return default_timeout
|
||||||
|
|
||||||
def _parse_url(url):
|
def _parse_url(url):
|
||||||
|
"""
|
||||||
|
parse url and the result is tuple of
|
||||||
|
(hostname, port, resource path and the flag of secure mode)
|
||||||
|
"""
|
||||||
parsed = urlparse(url)
|
parsed = urlparse(url)
|
||||||
if parsed.hostname:
|
if parsed.hostname:
|
||||||
hostname = parsed.hostname
|
hostname = parsed.hostname
|
||||||
else:
|
else:
|
||||||
raise ValueError("hostname is invalid")
|
raise ValueError("hostname is invalid")
|
||||||
|
port = 0
|
||||||
if parsed.scheme == "ws":
|
|
||||||
if parsed.port:
|
if parsed.port:
|
||||||
port = parsed.port
|
port = parsed.port
|
||||||
else:
|
|
||||||
|
is_secure = False
|
||||||
|
if parsed.scheme == "ws":
|
||||||
|
if not port:
|
||||||
port = 80
|
port = 80
|
||||||
elif parsed.scheme == "wss":
|
elif parsed.scheme == "wss":
|
||||||
# TODO: support wss
|
is_secure = True
|
||||||
raise ValueError("scheme wss is not supported")
|
if not port:
|
||||||
|
port = 443
|
||||||
else:
|
else:
|
||||||
raise ValueError("scheme %s is invalid" % parsed.scheme)
|
raise ValueError("scheme %s is invalid" % parsed.scheme)
|
||||||
|
|
||||||
@@ -55,7 +62,7 @@ def _parse_url(url):
|
|||||||
else:
|
else:
|
||||||
resource = "/"
|
resource = "/"
|
||||||
|
|
||||||
return (hostname, port, resource)
|
return (hostname, port, resource, is_secure)
|
||||||
|
|
||||||
|
|
||||||
def create_connection(url, timeout=None, **options):
|
def create_connection(url, timeout=None, **options):
|
||||||
@@ -112,6 +119,15 @@ HEADERS_TO_EXIST_FOR_HIXIE75 = [
|
|||||||
"websocket-location",
|
"websocket-location",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
class SSLSocketWrapper(object):
|
||||||
|
def __init__(self, sock):
|
||||||
|
self.ssl = socket.ssl(sock)
|
||||||
|
|
||||||
|
def recv(self, bufsize):
|
||||||
|
return self.ssl.read(bufsize)
|
||||||
|
|
||||||
|
def send(self, payload):
|
||||||
|
return self.ssl.write(payload)
|
||||||
|
|
||||||
class WebSocket(object):
|
class WebSocket(object):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@@ -119,7 +135,7 @@ class WebSocket(object):
|
|||||||
Initalize WebSocket object.
|
Initalize WebSocket object.
|
||||||
"""
|
"""
|
||||||
self.connected = False
|
self.connected = False
|
||||||
self.sock = socket.socket()
|
self.io_sock = self.sock = socket.socket()
|
||||||
|
|
||||||
def settimeout(self, timeout):
|
def settimeout(self, timeout):
|
||||||
"""
|
"""
|
||||||
@@ -137,13 +153,15 @@ class WebSocket(object):
|
|||||||
"""
|
"""
|
||||||
Connect to url. url is websocket url scheme. ie. ws://host:port/resource
|
Connect to url. url is websocket url scheme. ie. ws://host:port/resource
|
||||||
"""
|
"""
|
||||||
hostname, port, resource = _parse_url(url)
|
hostname, port, resource, is_secure = _parse_url(url)
|
||||||
# TODO: we need to support proxy
|
# TODO: we need to support proxy
|
||||||
self.sock.connect((hostname, port))
|
self.sock.connect((hostname, port))
|
||||||
|
if is_secure:
|
||||||
|
self.io_sock = SSLSocketWrapper(self.sock)
|
||||||
self._handshake(hostname, port, resource, **options)
|
self._handshake(hostname, port, resource, **options)
|
||||||
|
|
||||||
def _handshake(self, host, port, resource, **options):
|
def _handshake(self, host, port, resource, **options):
|
||||||
sock = self.sock
|
sock = self.io_sock
|
||||||
headers = []
|
headers = []
|
||||||
if "header" in options:
|
if "header" in options:
|
||||||
headers.extend(options["header"])
|
headers.extend(options["header"])
|
||||||
@@ -175,17 +193,17 @@ class WebSocket(object):
|
|||||||
|
|
||||||
status, resp_headers = self._read_headers()
|
status, resp_headers = self._read_headers()
|
||||||
if status != 101:
|
if status != 101:
|
||||||
self.sock.close()
|
self.close()
|
||||||
raise WebSocketException("Handshake Status %d" % status)
|
raise WebSocketException("Handshake Status %d" % status)
|
||||||
success, secure = self._validate_header(resp_headers)
|
success, secure = self._validate_header(resp_headers)
|
||||||
if not success:
|
if not success:
|
||||||
self.sock.close()
|
self.close()
|
||||||
raise WebSocketException("Invalid WebSocket Header")
|
raise WebSocketException("Invalid WebSocket Header")
|
||||||
|
|
||||||
if secure:
|
if secure:
|
||||||
resp = self._get_resp()
|
resp = self._get_resp()
|
||||||
if not self._validate_resp(number_1, number_2, key3, resp):
|
if not self._validate_resp(number_1, number_2, key3, resp):
|
||||||
self.sock.close()
|
self.close()
|
||||||
raise WebSocketException("challenge-response error")
|
raise WebSocketException("challenge-response error")
|
||||||
|
|
||||||
self.connected = True
|
self.connected = True
|
||||||
@@ -268,7 +286,7 @@ class WebSocket(object):
|
|||||||
"""
|
"""
|
||||||
if isinstance(payload, unicode):
|
if isinstance(payload, unicode):
|
||||||
payload = payload.encode("utf-8")
|
payload = payload.encode("utf-8")
|
||||||
self.sock.send("".join(["\x00", payload, "\xff"]))
|
self.io_sock.send("".join(["\x00", payload, "\xff"]))
|
||||||
|
|
||||||
def recv(self):
|
def recv(self):
|
||||||
"""
|
"""
|
||||||
@@ -307,16 +325,17 @@ class WebSocket(object):
|
|||||||
"""
|
"""
|
||||||
if self.connected:
|
if self.connected:
|
||||||
try:
|
try:
|
||||||
self.sock.send("\xff\x00")
|
self.io_sock.send("\xff\x00")
|
||||||
result = self._recv(2)
|
result = self._recv(2)
|
||||||
if result != "\xff\x00":
|
if result != "\xff\x00":
|
||||||
logger.error("bad closing Handshake")
|
logger.error("bad closing Handshake")
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
self.sock.close()
|
self.sock.close()
|
||||||
|
self.io_sock = self.sock
|
||||||
|
|
||||||
def _recv(self, bufsize):
|
def _recv(self, bufsize):
|
||||||
bytes = self.sock.recv(bufsize)
|
bytes = self.io_sock.recv(bufsize)
|
||||||
if not bytes:
|
if not bytes:
|
||||||
raise ConnectionClosedException()
|
raise ConnectionClosedException()
|
||||||
return bytes
|
return bytes
|
||||||
|
Reference in New Issue
Block a user