tiny modification
This commit is contained in:
@@ -138,7 +138,6 @@ class WebSocketTest(unittest.TestCase):
|
|||||||
del header["connection"]
|
del header["connection"]
|
||||||
self.assertEquals(sock._validate_header(header, key), False)
|
self.assertEquals(sock._validate_header(header, key), False)
|
||||||
|
|
||||||
|
|
||||||
header = required_header.copy()
|
header = required_header.copy()
|
||||||
header["sec-websocket-accept"] = "something"
|
header["sec-websocket-accept"] = "something"
|
||||||
self.assertEquals(sock._validate_header(header, key), False)
|
self.assertEquals(sock._validate_header(header, key), False)
|
||||||
@@ -151,7 +150,7 @@ class WebSocketTest(unittest.TestCase):
|
|||||||
status, header = sock._read_headers()
|
status, header = sock._read_headers()
|
||||||
self.assertEquals(status, 101)
|
self.assertEquals(status, 101)
|
||||||
self.assertEquals(header["connection"], "upgrade")
|
self.assertEquals(header["connection"], "upgrade")
|
||||||
|
|
||||||
sock.io_sock = sock.sock = HeaderSockMock("data/header02.txt")
|
sock.io_sock = sock.sock = HeaderSockMock("data/header02.txt")
|
||||||
self.assertRaises(ws.WebSocketException, sock._read_headers)
|
self.assertRaises(ws.WebSocketException, sock._read_headers)
|
||||||
|
|
||||||
@@ -176,13 +175,13 @@ class WebSocketTest(unittest.TestCase):
|
|||||||
s.set_data("\x81\x8fabcd\x82\xe3\xf0\x87\xe3\xf1\x80\xe5\xca\x81\xe2\xc5\x82\xe3\xcc")
|
s.set_data("\x81\x8fabcd\x82\xe3\xf0\x87\xe3\xf1\x80\xe5\xca\x81\xe2\xc5\x82\xe3\xcc")
|
||||||
data = sock.recv()
|
data = sock.recv()
|
||||||
self.assertEquals(data, "こんにちは")
|
self.assertEquals(data, "こんにちは")
|
||||||
|
|
||||||
s.set_data("\x81\x85abcd)\x07\x0f\x08\x0e")
|
s.set_data("\x81\x85abcd)\x07\x0f\x08\x0e")
|
||||||
data = sock.recv()
|
data = sock.recv()
|
||||||
self.assertEquals(data, "Hello")
|
self.assertEquals(data, "Hello")
|
||||||
|
|
||||||
def testWebSocket(self):
|
def testWebSocket(self):
|
||||||
s = ws.create_connection("ws://echo.websocket.org/") #ws://localhost:8080/echo")
|
s = ws.create_connection("ws://echo.websocket.org/")
|
||||||
self.assertNotEquals(s, None)
|
self.assertNotEquals(s, None)
|
||||||
s.send("Hello, World")
|
s.send("Hello, World")
|
||||||
result = s.recv()
|
result = s.recv()
|
||||||
@@ -194,14 +193,14 @@ class WebSocketTest(unittest.TestCase):
|
|||||||
s.close()
|
s.close()
|
||||||
|
|
||||||
def testPingPong(self):
|
def testPingPong(self):
|
||||||
s = ws.create_connection("ws://echo.websocket.org/")
|
s = ws.create_connection("ws://echo.websocket.org/")
|
||||||
self.assertNotEquals(s, None)
|
self.assertNotEquals(s, None)
|
||||||
s.ping("Hello")
|
s.ping("Hello")
|
||||||
s.pong("Hi")
|
s.pong("Hi")
|
||||||
s.close()
|
s.close()
|
||||||
|
|
||||||
def testSecureWebSocket(self):
|
def testSecureWebSocket(self):
|
||||||
s = ws.create_connection("wss://echo.websocket.org/")
|
s = ws.create_connection("wss://echo.websocket.org/")
|
||||||
self.assertNotEquals(s, None)
|
self.assertNotEquals(s, None)
|
||||||
self.assert_(isinstance(s.io_sock, ws._SSLSocketWrapper))
|
self.assert_(isinstance(s.io_sock, ws._SSLSocketWrapper))
|
||||||
s.send("Hello, World")
|
s.send("Hello, World")
|
||||||
@@ -213,7 +212,7 @@ class WebSocketTest(unittest.TestCase):
|
|||||||
s.close()
|
s.close()
|
||||||
|
|
||||||
def testWebSocketWihtCustomHeader(self):
|
def testWebSocketWihtCustomHeader(self):
|
||||||
s = ws.create_connection("ws://echo.websocket.org/",
|
s = ws.create_connection("ws://echo.websocket.org/",
|
||||||
headers={"User-Agent": "PythonWebsocketClient"})
|
headers={"User-Agent": "PythonWebsocketClient"})
|
||||||
self.assertNotEquals(s, None)
|
self.assertNotEquals(s, None)
|
||||||
s.send("Hello, World")
|
s.send("Hello, World")
|
||||||
@@ -223,12 +222,12 @@ class WebSocketTest(unittest.TestCase):
|
|||||||
|
|
||||||
def testAfterClose(self):
|
def testAfterClose(self):
|
||||||
from socket import error
|
from socket import error
|
||||||
s = ws.create_connection("ws://echo.websocket.org/")
|
s = ws.create_connection("ws://echo.websocket.org/")
|
||||||
self.assertNotEquals(s, None)
|
self.assertNotEquals(s, None)
|
||||||
s.close()
|
s.close()
|
||||||
self.assertRaises(error, s.send, "Hello")
|
self.assertRaises(error, s.send, "Hello")
|
||||||
self.assertRaises(error, s.recv)
|
self.assertRaises(error, s.recv)
|
||||||
|
|
||||||
def testUUID4(self):
|
def testUUID4(self):
|
||||||
""" WebSocket key should be a UUID4.
|
""" WebSocket key should be a UUID4.
|
||||||
"""
|
"""
|
||||||
@@ -236,25 +235,25 @@ class WebSocketTest(unittest.TestCase):
|
|||||||
u = uuid.UUID(bytes=base64.b64decode(key))
|
u = uuid.UUID(bytes=base64.b64decode(key))
|
||||||
self.assertEquals(4, u.version)
|
self.assertEquals(4, u.version)
|
||||||
|
|
||||||
|
|
||||||
class WebSocketAppTest(unittest.TestCase):
|
class WebSocketAppTest(unittest.TestCase):
|
||||||
|
|
||||||
class NotSetYet(object):
|
class NotSetYet(object):
|
||||||
""" A marker class for signalling that a value hasn't been set yet.
|
""" A marker class for signalling that a value hasn't been set yet.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
ws.enableTrace(TRACABLE)
|
ws.enableTrace(TRACABLE)
|
||||||
|
|
||||||
WebSocketAppTest.keep_running_open = WebSocketAppTest.NotSetYet()
|
WebSocketAppTest.keep_running_open = WebSocketAppTest.NotSetYet()
|
||||||
WebSocketAppTest.keep_running_close = WebSocketAppTest.NotSetYet()
|
WebSocketAppTest.keep_running_close = WebSocketAppTest.NotSetYet()
|
||||||
WebSocketAppTest.get_mask_key_id = WebSocketAppTest.NotSetYet()
|
WebSocketAppTest.get_mask_key_id = WebSocketAppTest.NotSetYet()
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
|
|
||||||
WebSocketAppTest.keep_running_open = WebSocketAppTest.NotSetYet()
|
WebSocketAppTest.keep_running_open = WebSocketAppTest.NotSetYet()
|
||||||
WebSocketAppTest.keep_running_close = WebSocketAppTest.NotSetYet()
|
WebSocketAppTest.keep_running_close = WebSocketAppTest.NotSetYet()
|
||||||
WebSocketAppTest.get_mask_key_id = WebSocketAppTest.NotSetYet()
|
WebSocketAppTest.get_mask_key_id = WebSocketAppTest.NotSetYet()
|
||||||
|
|
||||||
def testKeepRunning(self):
|
def testKeepRunning(self):
|
||||||
""" A WebSocketApp should keep running as long as its self.keep_running
|
""" A WebSocketApp should keep running as long as its self.keep_running
|
||||||
is not False (in the boolean context).
|
is not False (in the boolean context).
|
||||||
@@ -266,46 +265,44 @@ class WebSocketAppTest(unittest.TestCase):
|
|||||||
"""
|
"""
|
||||||
WebSocketAppTest.keep_running_open = self.keep_running
|
WebSocketAppTest.keep_running_open = self.keep_running
|
||||||
self.close()
|
self.close()
|
||||||
|
|
||||||
def on_close(self, *args, **kwargs):
|
def on_close(self, *args, **kwargs):
|
||||||
""" Set the keep_running flag for the test to use.
|
""" Set the keep_running flag for the test to use.
|
||||||
"""
|
"""
|
||||||
WebSocketAppTest.keep_running_close = self.keep_running
|
WebSocketAppTest.keep_running_close = self.keep_running
|
||||||
|
|
||||||
app = ws.WebSocketApp('ws://echo.websocket.org/', on_open=on_open, on_close=on_close)
|
app = ws.WebSocketApp('ws://echo.websocket.org/', on_open=on_open, on_close=on_close)
|
||||||
app.run_forever()
|
app.run_forever()
|
||||||
|
|
||||||
self.assertFalse(isinstance(WebSocketAppTest.keep_running_open,
|
self.assertFalse(isinstance(WebSocketAppTest.keep_running_open,
|
||||||
WebSocketAppTest.NotSetYet))
|
WebSocketAppTest.NotSetYet))
|
||||||
|
|
||||||
self.assertFalse(isinstance(WebSocketAppTest.keep_running_close,
|
self.assertFalse(isinstance(WebSocketAppTest.keep_running_close,
|
||||||
WebSocketAppTest.NotSetYet))
|
WebSocketAppTest.NotSetYet))
|
||||||
|
|
||||||
self.assertEquals(True, WebSocketAppTest.keep_running_open)
|
self.assertEquals(True, WebSocketAppTest.keep_running_open)
|
||||||
self.assertEquals(False, WebSocketAppTest.keep_running_close)
|
self.assertEquals(False, WebSocketAppTest.keep_running_close)
|
||||||
|
|
||||||
def testSockMaskKey(self):
|
def testSockMaskKey(self):
|
||||||
""" A WebSocketApp should forward the received mask_key function down
|
""" A WebSocketApp should forward the received mask_key function down
|
||||||
to the actual socket.
|
to the actual socket.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def my_mask_key_func():
|
def my_mask_key_func():
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def on_open(self, *args, **kwargs):
|
def on_open(self, *args, **kwargs):
|
||||||
""" Set the value so the test can use it later on and immediately
|
""" Set the value so the test can use it later on and immediately
|
||||||
close the connection.
|
close the connection.
|
||||||
"""
|
"""
|
||||||
WebSocketAppTest.get_mask_key_id = id(self.get_mask_key)
|
WebSocketAppTest.get_mask_key_id = id(self.get_mask_key)
|
||||||
self.close()
|
self.close()
|
||||||
|
|
||||||
app = ws.WebSocketApp('ws://echo.websocket.org/', on_open=on_open, get_mask_key=my_mask_key_func)
|
app = ws.WebSocketApp('ws://echo.websocket.org/', on_open=on_open, get_mask_key=my_mask_key_func)
|
||||||
app.run_forever()
|
app.run_forever()
|
||||||
|
|
||||||
# Note: We can't use 'is' for comparing the functions directly, need to use 'id'.
|
# Note: We can't use 'is' for comparing the functions directly, need to use 'id'.
|
||||||
self.assertEquals(WebSocketAppTest.get_mask_key_id, id(my_mask_key_func))
|
self.assertEquals(WebSocketAppTest.get_mask_key_id, id(my_mask_key_func))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
||||||
|
28
websocket.py
28
websocket.py
@@ -58,12 +58,14 @@ STATUS_TLS_HANDSHAKE_ERROR = 1015
|
|||||||
|
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
|
|
||||||
|
|
||||||
class WebSocketException(Exception):
|
class WebSocketException(Exception):
|
||||||
"""
|
"""
|
||||||
websocket exeception class.
|
websocket exeception class.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class WebSocketConnectionClosedException(WebSocketException):
|
class WebSocketConnectionClosedException(WebSocketException):
|
||||||
"""
|
"""
|
||||||
If remote host closed the connection or some network error happened,
|
If remote host closed the connection or some network error happened,
|
||||||
@@ -74,6 +76,7 @@ class WebSocketConnectionClosedException(WebSocketException):
|
|||||||
default_timeout = None
|
default_timeout = None
|
||||||
traceEnabled = False
|
traceEnabled = False
|
||||||
|
|
||||||
|
|
||||||
def enableTrace(tracable):
|
def enableTrace(tracable):
|
||||||
"""
|
"""
|
||||||
turn on/off the tracability.
|
turn on/off the tracability.
|
||||||
@@ -87,6 +90,7 @@ def enableTrace(tracable):
|
|||||||
logger.addHandler(logging.StreamHandler())
|
logger.addHandler(logging.StreamHandler())
|
||||||
logger.setLevel(logging.DEBUG)
|
logger.setLevel(logging.DEBUG)
|
||||||
|
|
||||||
|
|
||||||
def setdefaulttimeout(timeout):
|
def setdefaulttimeout(timeout):
|
||||||
"""
|
"""
|
||||||
Set the global timeout setting to connect.
|
Set the global timeout setting to connect.
|
||||||
@@ -96,12 +100,14 @@ def setdefaulttimeout(timeout):
|
|||||||
global default_timeout
|
global default_timeout
|
||||||
default_timeout = timeout
|
default_timeout = timeout
|
||||||
|
|
||||||
|
|
||||||
def getdefaulttimeout():
|
def getdefaulttimeout():
|
||||||
"""
|
"""
|
||||||
Return the global timeout setting(second) to connect.
|
Return the global timeout setting(second) to connect.
|
||||||
"""
|
"""
|
||||||
return default_timeout
|
return default_timeout
|
||||||
|
|
||||||
|
|
||||||
def _parse_url(url):
|
def _parse_url(url):
|
||||||
"""
|
"""
|
||||||
parse url and the result is tuple of
|
parse url and the result is tuple of
|
||||||
@@ -130,7 +136,7 @@ def _parse_url(url):
|
|||||||
elif scheme == "wss":
|
elif scheme == "wss":
|
||||||
is_secure = True
|
is_secure = True
|
||||||
if not port:
|
if not port:
|
||||||
port = 443
|
port = 443
|
||||||
else:
|
else:
|
||||||
raise ValueError("scheme %s is invalid" % scheme)
|
raise ValueError("scheme %s is invalid" % scheme)
|
||||||
|
|
||||||
@@ -144,6 +150,7 @@ def _parse_url(url):
|
|||||||
|
|
||||||
return (hostname, port, resource, is_secure)
|
return (hostname, port, resource, is_secure)
|
||||||
|
|
||||||
|
|
||||||
def create_connection(url, timeout=None, **options):
|
def create_connection(url, timeout=None, **options):
|
||||||
"""
|
"""
|
||||||
connect to url and return websocket object.
|
connect to url and return websocket object.
|
||||||
@@ -177,6 +184,7 @@ _MAX_CHAR_BYTE = (1<<8) -1
|
|||||||
# ref. Websocket gets an update, and it breaks stuff.
|
# ref. Websocket gets an update, and it breaks stuff.
|
||||||
# http://axod.blogspot.com/2010/06/websocket-gets-update-and-it-breaks.html
|
# http://axod.blogspot.com/2010/06/websocket-gets-update-and-it-breaks.html
|
||||||
|
|
||||||
|
|
||||||
def _create_sec_websocket_key():
|
def _create_sec_websocket_key():
|
||||||
uid = uuid.uuid4()
|
uid = uuid.uuid4()
|
||||||
return base64.encodestring(uid.bytes).strip()
|
return base64.encodestring(uid.bytes).strip()
|
||||||
@@ -186,6 +194,7 @@ _HEADERS_TO_CHECK = {
|
|||||||
"connection": "upgrade",
|
"connection": "upgrade",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class _SSLSocketWrapper(object):
|
class _SSLSocketWrapper(object):
|
||||||
def __init__(self, sock):
|
def __init__(self, sock):
|
||||||
self.ssl = socket.ssl(sock)
|
self.ssl = socket.ssl(sock)
|
||||||
@@ -197,6 +206,8 @@ class _SSLSocketWrapper(object):
|
|||||||
return self.ssl.write(payload)
|
return self.ssl.write(payload)
|
||||||
|
|
||||||
_BOOL_VALUES = (0, 1)
|
_BOOL_VALUES = (0, 1)
|
||||||
|
|
||||||
|
|
||||||
def _is_bool(*values):
|
def _is_bool(*values):
|
||||||
for v in values:
|
for v in values:
|
||||||
if v not in _BOOL_VALUES:
|
if v not in _BOOL_VALUES:
|
||||||
@@ -204,6 +215,7 @@ def _is_bool(*values):
|
|||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
class ABNF(object):
|
class ABNF(object):
|
||||||
"""
|
"""
|
||||||
ABNF frame class.
|
ABNF frame class.
|
||||||
@@ -316,6 +328,7 @@ class ABNF(object):
|
|||||||
_d[i] ^= _m[i % 4]
|
_d[i] ^= _m[i % 4]
|
||||||
return _d.tostring()
|
return _d.tostring()
|
||||||
|
|
||||||
|
|
||||||
class WebSocket(object):
|
class WebSocket(object):
|
||||||
"""
|
"""
|
||||||
Low level WebSocket interface.
|
Low level WebSocket interface.
|
||||||
@@ -337,6 +350,7 @@ class WebSocket(object):
|
|||||||
get_mask_key: a callable to produce new mask keys, see the set_mask_key
|
get_mask_key: a callable to produce new mask keys, see the set_mask_key
|
||||||
function's docstring for more details
|
function's docstring for more details
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, get_mask_key = None):
|
def __init__(self, get_mask_key = None):
|
||||||
"""
|
"""
|
||||||
Initalize WebSocket object.
|
Initalize WebSocket object.
|
||||||
@@ -423,8 +437,8 @@ class WebSocket(object):
|
|||||||
header_str = "\r\n".join(headers)
|
header_str = "\r\n".join(headers)
|
||||||
sock.send(header_str)
|
sock.send(header_str)
|
||||||
if traceEnabled:
|
if traceEnabled:
|
||||||
logger.debug( "--- request header ---")
|
logger.debug("--- request header ---")
|
||||||
logger.debug( header_str)
|
logger.debug(header_str)
|
||||||
logger.debug("-----------------------")
|
logger.debug("-----------------------")
|
||||||
|
|
||||||
status, resp_headers = self._read_headers()
|
status, resp_headers = self._read_headers()
|
||||||
@@ -549,7 +563,6 @@ class WebSocket(object):
|
|||||||
elif frame.opcode == ABNF.OPCODE_PING:
|
elif frame.opcode == ABNF.OPCODE_PING:
|
||||||
self.pong(frame.data)
|
self.pong(frame.data)
|
||||||
|
|
||||||
|
|
||||||
def recv_frame(self):
|
def recv_frame(self):
|
||||||
"""
|
"""
|
||||||
recieve data as frame from server.
|
recieve data as frame from server.
|
||||||
@@ -603,8 +616,6 @@ class WebSocket(object):
|
|||||||
raise ValueError("code is invalid range")
|
raise ValueError("code is invalid range")
|
||||||
self.send(struct.pack('!H', status) + reason, ABNF.OPCODE_CLOSE)
|
self.send(struct.pack('!H', status) + reason, ABNF.OPCODE_CLOSE)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def close(self, status = STATUS_NORMAL, reason = ""):
|
def close(self, status = STATUS_NORMAL, reason = ""):
|
||||||
"""
|
"""
|
||||||
Close Websocket object
|
Close Websocket object
|
||||||
@@ -662,6 +673,7 @@ class WebSocket(object):
|
|||||||
break
|
break
|
||||||
return "".join(line)
|
return "".join(line)
|
||||||
|
|
||||||
|
|
||||||
class WebSocketApp(object):
|
class WebSocketApp(object):
|
||||||
"""
|
"""
|
||||||
Higher level of APIs are provided.
|
Higher level of APIs are provided.
|
||||||
@@ -700,7 +712,7 @@ class WebSocketApp(object):
|
|||||||
|
|
||||||
def send(self, data, opcode = ABNF.OPCODE_TEXT):
|
def send(self, data, opcode = ABNF.OPCODE_TEXT):
|
||||||
"""
|
"""
|
||||||
send message.
|
send message.
|
||||||
data: message to send. If you set opcode to OPCODE_TEXT, data must be utf-8 string or unicode.
|
data: message to send. If you set opcode to OPCODE_TEXT, data must be utf-8 string or unicode.
|
||||||
opcode: operation code of data. default is OPCODE_TEXT.
|
opcode: operation code of data. default is OPCODE_TEXT.
|
||||||
"""
|
"""
|
||||||
@@ -753,6 +765,6 @@ if __name__ == "__main__":
|
|||||||
ws.send("Hello, World")
|
ws.send("Hello, World")
|
||||||
print "Sent"
|
print "Sent"
|
||||||
print "Receiving..."
|
print "Receiving..."
|
||||||
result = ws.recv()
|
result = ws.recv()
|
||||||
print "Received '%s'" % result
|
print "Received '%s'" % result
|
||||||
ws.close()
|
ws.close()
|
||||||
|
Reference in New Issue
Block a user