- support wss

This commit is contained in:
Hiroki Ohtani
2011-01-05 09:19:19 +09:00
parent 2be0632fd7
commit 5f615f6949
2 changed files with 51 additions and 22 deletions

22
test.py
View File

@@ -49,34 +49,44 @@ class WebSocketTest(unittest.TestCase):
self.assertEquals(p[0], "www.example.com") self.assertEquals(p[0], "www.example.com")
self.assertEquals(p[1], 80) self.assertEquals(p[1], 80)
self.assertEquals(p[2], "/r") self.assertEquals(p[2], "/r")
self.assertEquals(p[3], False)
p = ws._parse_url("ws://www.example.com/") p = ws._parse_url("ws://www.example.com/")
self.assertEquals(p[0], "www.example.com") self.assertEquals(p[0], "www.example.com")
self.assertEquals(p[1], 80) self.assertEquals(p[1], 80)
self.assertEquals(p[2], "/") self.assertEquals(p[2], "/")
self.assertEquals(p[3], False)
p = ws._parse_url("ws://www.example.com") p = ws._parse_url("ws://www.example.com")
self.assertEquals(p[0], "www.example.com") self.assertEquals(p[0], "www.example.com")
self.assertEquals(p[1], 80) self.assertEquals(p[1], 80)
self.assertEquals(p[2], "/") self.assertEquals(p[2], "/")
self.assertEquals(p[3], False)
p = ws._parse_url("ws://www.example.com:8080/r") p = ws._parse_url("ws://www.example.com:8080/r")
self.assertEquals(p[0], "www.example.com") self.assertEquals(p[0], "www.example.com")
self.assertEquals(p[1], 8080) self.assertEquals(p[1], 8080)
self.assertEquals(p[2], "/r") self.assertEquals(p[2], "/r")
self.assertEquals(p[3], False)
p = ws._parse_url("ws://www.example.com:8080/") p = ws._parse_url("ws://www.example.com:8080/")
self.assertEquals(p[0], "www.example.com") self.assertEquals(p[0], "www.example.com")
self.assertEquals(p[1], 8080) self.assertEquals(p[1], 8080)
self.assertEquals(p[2], "/") self.assertEquals(p[2], "/")
self.assertEquals(p[3], False)
p = ws._parse_url("ws://www.example.com:8080") p = ws._parse_url("ws://www.example.com:8080")
self.assertEquals(p[0], "www.example.com") self.assertEquals(p[0], "www.example.com")
self.assertEquals(p[1], 8080) self.assertEquals(p[1], 8080)
self.assertEquals(p[2], "/") self.assertEquals(p[2], "/")
self.assertEquals(p[3], False)
p = ws._parse_url("wss://www.example.com:8080/r")
self.assertEquals(p[0], "www.example.com")
self.assertEquals(p[1], 8080)
self.assertEquals(p[2], "/r")
self.assertEquals(p[3], True)
# we do not support wss for a while
self.assertRaises(ValueError, ws._parse_url, "wss://www.example.com/r")
self.assertRaises(ValueError, ws._parse_url, "http://www.example.com/r") self.assertRaises(ValueError, ws._parse_url, "http://www.example.com/r")
def testWSKey(self): def testWSKey(self):
@@ -127,7 +137,7 @@ class WebSocketTest(unittest.TestCase):
def testReadHeader(self): def testReadHeader(self):
sock = ws.WebSocket() sock = ws.WebSocket()
sock.sock = HeaderSockMock("data/header01.txt") sock.io_sock = sock.sock = HeaderSockMock("data/header01.txt")
status, header = sock._read_headers() status, header = sock._read_headers()
self.assertEquals(status, 101) self.assertEquals(status, 101)
self.assertEquals(header["connection"], "upgrade") self.assertEquals(header["connection"], "upgrade")
@@ -135,12 +145,12 @@ class WebSocketTest(unittest.TestCase):
self.assertEquals(sock._get_resp(), "ssssss\r\naaaaaaaa") self.assertEquals(sock._get_resp(), "ssssss\r\naaaaaaaa")
sock.sock = HeaderSockMock("data/header02.txt") sock.io_sock = sock.sock = HeaderSockMock("data/header02.txt")
self.assertRaises(ws.WebSocketException, sock._read_headers) self.assertRaises(ws.WebSocketException, sock._read_headers)
def testSend(self): def testSend(self):
sock = ws.WebSocket() sock = ws.WebSocket()
s = sock.sock = HeaderSockMock("data/header01.txt") s = sock.io_sock = sock.sock = HeaderSockMock("data/header01.txt")
sock.send("Hello") sock.send("Hello")
self.assertEquals(s.sent[0], "\x00Hello\xff") self.assertEquals(s.sent[0], "\x00Hello\xff")
sock.send("こんにちは") sock.send("こんにちは")
@@ -150,7 +160,7 @@ class WebSocketTest(unittest.TestCase):
def testRecv(self): def testRecv(self):
sock = ws.WebSocket() sock = ws.WebSocket()
s = sock.sock = StringSockMock() s = sock.io_sock = sock.sock = StringSockMock()
s.set_data("\x00こんにちは\xff") s.set_data("\x00こんにちは\xff")
data = sock.recv() data = sock.recv()
self.assertEquals(data, "こんにちは") self.assertEquals(data, "こんにちは")

View File

@@ -33,20 +33,27 @@ def getdefaulttimeout():
return default_timeout return default_timeout
def _parse_url(url): def _parse_url(url):
"""
parse url and the result is tuple of
(hostname, port, resource path and the flag of secure mode)
"""
parsed = urlparse(url) parsed = urlparse(url)
if parsed.hostname: if parsed.hostname:
hostname = parsed.hostname hostname = parsed.hostname
else: else:
raise ValueError("hostname is invalid") raise ValueError("hostname is invalid")
port = 0
if parsed.scheme == "ws":
if parsed.port: if parsed.port:
port = parsed.port port = parsed.port
else:
is_secure = False
if parsed.scheme == "ws":
if not port:
port = 80 port = 80
elif parsed.scheme == "wss": elif parsed.scheme == "wss":
# TODO: support wss is_secure = True
raise ValueError("scheme wss is not supported") if not port:
port = 443
else: else:
raise ValueError("scheme %s is invalid" % parsed.scheme) raise ValueError("scheme %s is invalid" % parsed.scheme)
@@ -55,7 +62,7 @@ def _parse_url(url):
else: else:
resource = "/" resource = "/"
return (hostname, port, resource) return (hostname, port, resource, is_secure)
def create_connection(url, timeout=None, **options): def create_connection(url, timeout=None, **options):
@@ -112,6 +119,15 @@ HEADERS_TO_EXIST_FOR_HIXIE75 = [
"websocket-location", "websocket-location",
] ]
class SSLSocketWrapper(object):
def __init__(self, sock):
self.ssl = socket.ssl(sock)
def recv(self, bufsize):
return self.ssl.read(bufsize)
def send(self, payload):
return self.ssl.write(payload)
class WebSocket(object): class WebSocket(object):
def __init__(self): def __init__(self):
@@ -119,7 +135,7 @@ class WebSocket(object):
Initalize WebSocket object. Initalize WebSocket object.
""" """
self.connected = False self.connected = False
self.sock = socket.socket() self.io_sock = self.sock = socket.socket()
def settimeout(self, timeout): def settimeout(self, timeout):
""" """
@@ -137,13 +153,15 @@ class WebSocket(object):
""" """
Connect to url. url is websocket url scheme. ie. ws://host:port/resource Connect to url. url is websocket url scheme. ie. ws://host:port/resource
""" """
hostname, port, resource = _parse_url(url) hostname, port, resource, is_secure = _parse_url(url)
# TODO: we need to support proxy # TODO: we need to support proxy
self.sock.connect((hostname, port)) self.sock.connect((hostname, port))
if is_secure:
self.io_sock = SSLSocketWrapper(self.sock)
self._handshake(hostname, port, resource, **options) self._handshake(hostname, port, resource, **options)
def _handshake(self, host, port, resource, **options): def _handshake(self, host, port, resource, **options):
sock = self.sock sock = self.io_sock
headers = [] headers = []
if "header" in options: if "header" in options:
headers.extend(options["header"]) headers.extend(options["header"])
@@ -175,17 +193,17 @@ class WebSocket(object):
status, resp_headers = self._read_headers() status, resp_headers = self._read_headers()
if status != 101: if status != 101:
self.sock.close() self.close()
raise WebSocketException("Handshake Status %d" % status) raise WebSocketException("Handshake Status %d" % status)
success, secure = self._validate_header(resp_headers) success, secure = self._validate_header(resp_headers)
if not success: if not success:
self.sock.close() self.close()
raise WebSocketException("Invalid WebSocket Header") raise WebSocketException("Invalid WebSocket Header")
if secure: if secure:
resp = self._get_resp() resp = self._get_resp()
if not self._validate_resp(number_1, number_2, key3, resp): if not self._validate_resp(number_1, number_2, key3, resp):
self.sock.close() self.close()
raise WebSocketException("challenge-response error") raise WebSocketException("challenge-response error")
self.connected = True self.connected = True
@@ -268,7 +286,7 @@ class WebSocket(object):
""" """
if isinstance(payload, unicode): if isinstance(payload, unicode):
payload = payload.encode("utf-8") payload = payload.encode("utf-8")
self.sock.send("".join(["\x00", payload, "\xff"])) self.io_sock.send("".join(["\x00", payload, "\xff"]))
def recv(self): def recv(self):
""" """
@@ -307,16 +325,17 @@ class WebSocket(object):
""" """
if self.connected: if self.connected:
try: try:
self.sock.send("\xff\x00") self.io_sock.send("\xff\x00")
result = self._recv(2) result = self._recv(2)
if result != "\xff\x00": if result != "\xff\x00":
logger.error("bad closing Handshake") logger.error("bad closing Handshake")
except: except:
pass pass
self.sock.close() self.sock.close()
self.io_sock = self.sock
def _recv(self, bufsize): def _recv(self, bufsize):
bytes = self.sock.recv(bufsize) bytes = self.io_sock.recv(bufsize)
if not bytes: if not bytes:
raise ConnectionClosedException() raise ConnectionClosedException()
return bytes return bytes