tiny modification
This commit is contained in:
@@ -138,7 +138,6 @@ class WebSocketTest(unittest.TestCase):
|
||||
del header["connection"]
|
||||
self.assertEquals(sock._validate_header(header, key), False)
|
||||
|
||||
|
||||
header = required_header.copy()
|
||||
header["sec-websocket-accept"] = "something"
|
||||
self.assertEquals(sock._validate_header(header, key), False)
|
||||
@@ -151,7 +150,7 @@ class WebSocketTest(unittest.TestCase):
|
||||
status, header = sock._read_headers()
|
||||
self.assertEquals(status, 101)
|
||||
self.assertEquals(header["connection"], "upgrade")
|
||||
|
||||
|
||||
sock.io_sock = sock.sock = HeaderSockMock("data/header02.txt")
|
||||
self.assertRaises(ws.WebSocketException, sock._read_headers)
|
||||
|
||||
@@ -176,13 +175,13 @@ class WebSocketTest(unittest.TestCase):
|
||||
s.set_data("\x81\x8fabcd\x82\xe3\xf0\x87\xe3\xf1\x80\xe5\xca\x81\xe2\xc5\x82\xe3\xcc")
|
||||
data = sock.recv()
|
||||
self.assertEquals(data, "こんにちは")
|
||||
|
||||
|
||||
s.set_data("\x81\x85abcd)\x07\x0f\x08\x0e")
|
||||
data = sock.recv()
|
||||
self.assertEquals(data, "Hello")
|
||||
|
||||
def testWebSocket(self):
|
||||
s = ws.create_connection("ws://echo.websocket.org/") #ws://localhost:8080/echo")
|
||||
s = ws.create_connection("ws://echo.websocket.org/")
|
||||
self.assertNotEquals(s, None)
|
||||
s.send("Hello, World")
|
||||
result = s.recv()
|
||||
@@ -194,14 +193,14 @@ class WebSocketTest(unittest.TestCase):
|
||||
s.close()
|
||||
|
||||
def testPingPong(self):
|
||||
s = ws.create_connection("ws://echo.websocket.org/")
|
||||
s = ws.create_connection("ws://echo.websocket.org/")
|
||||
self.assertNotEquals(s, None)
|
||||
s.ping("Hello")
|
||||
s.pong("Hi")
|
||||
s.close()
|
||||
|
||||
|
||||
def testSecureWebSocket(self):
|
||||
s = ws.create_connection("wss://echo.websocket.org/")
|
||||
s = ws.create_connection("wss://echo.websocket.org/")
|
||||
self.assertNotEquals(s, None)
|
||||
self.assert_(isinstance(s.io_sock, ws._SSLSocketWrapper))
|
||||
s.send("Hello, World")
|
||||
@@ -213,7 +212,7 @@ class WebSocketTest(unittest.TestCase):
|
||||
s.close()
|
||||
|
||||
def testWebSocketWihtCustomHeader(self):
|
||||
s = ws.create_connection("ws://echo.websocket.org/",
|
||||
s = ws.create_connection("ws://echo.websocket.org/",
|
||||
headers={"User-Agent": "PythonWebsocketClient"})
|
||||
self.assertNotEquals(s, None)
|
||||
s.send("Hello, World")
|
||||
@@ -223,12 +222,12 @@ class WebSocketTest(unittest.TestCase):
|
||||
|
||||
def testAfterClose(self):
|
||||
from socket import error
|
||||
s = ws.create_connection("ws://echo.websocket.org/")
|
||||
s = ws.create_connection("ws://echo.websocket.org/")
|
||||
self.assertNotEquals(s, None)
|
||||
s.close()
|
||||
self.assertRaises(error, s.send, "Hello")
|
||||
self.assertRaises(error, s.recv)
|
||||
|
||||
|
||||
def testUUID4(self):
|
||||
""" WebSocket key should be a UUID4.
|
||||
"""
|
||||
@@ -236,25 +235,25 @@ class WebSocketTest(unittest.TestCase):
|
||||
u = uuid.UUID(bytes=base64.b64decode(key))
|
||||
self.assertEquals(4, u.version)
|
||||
|
||||
|
||||
class WebSocketAppTest(unittest.TestCase):
|
||||
|
||||
class NotSetYet(object):
|
||||
""" A marker class for signalling that a value hasn't been set yet.
|
||||
"""
|
||||
|
||||
|
||||
def setUp(self):
|
||||
ws.enableTrace(TRACABLE)
|
||||
|
||||
|
||||
WebSocketAppTest.keep_running_open = WebSocketAppTest.NotSetYet()
|
||||
WebSocketAppTest.keep_running_close = WebSocketAppTest.NotSetYet()
|
||||
WebSocketAppTest.get_mask_key_id = WebSocketAppTest.NotSetYet()
|
||||
|
||||
|
||||
def tearDown(self):
|
||||
|
||||
WebSocketAppTest.keep_running_open = WebSocketAppTest.NotSetYet()
|
||||
WebSocketAppTest.keep_running_close = WebSocketAppTest.NotSetYet()
|
||||
WebSocketAppTest.get_mask_key_id = WebSocketAppTest.NotSetYet()
|
||||
|
||||
|
||||
def testKeepRunning(self):
|
||||
""" A WebSocketApp should keep running as long as its self.keep_running
|
||||
is not False (in the boolean context).
|
||||
@@ -266,46 +265,44 @@ class WebSocketAppTest(unittest.TestCase):
|
||||
"""
|
||||
WebSocketAppTest.keep_running_open = self.keep_running
|
||||
self.close()
|
||||
|
||||
|
||||
def on_close(self, *args, **kwargs):
|
||||
""" Set the keep_running flag for the test to use.
|
||||
"""
|
||||
WebSocketAppTest.keep_running_close = self.keep_running
|
||||
|
||||
|
||||
app = ws.WebSocketApp('ws://echo.websocket.org/', on_open=on_open, on_close=on_close)
|
||||
app.run_forever()
|
||||
|
||||
self.assertFalse(isinstance(WebSocketAppTest.keep_running_open,
|
||||
|
||||
self.assertFalse(isinstance(WebSocketAppTest.keep_running_open,
|
||||
WebSocketAppTest.NotSetYet))
|
||||
|
||||
self.assertFalse(isinstance(WebSocketAppTest.keep_running_close,
|
||||
|
||||
self.assertFalse(isinstance(WebSocketAppTest.keep_running_close,
|
||||
WebSocketAppTest.NotSetYet))
|
||||
|
||||
|
||||
self.assertEquals(True, WebSocketAppTest.keep_running_open)
|
||||
self.assertEquals(False, WebSocketAppTest.keep_running_close)
|
||||
|
||||
|
||||
def testSockMaskKey(self):
|
||||
""" A WebSocketApp should forward the received mask_key function down
|
||||
to the actual socket.
|
||||
"""
|
||||
|
||||
|
||||
def my_mask_key_func():
|
||||
pass
|
||||
|
||||
def on_open(self, *args, **kwargs):
|
||||
""" Set the value so the test can use it later on and immediately
|
||||
""" Set the value so the test can use it later on and immediately
|
||||
close the connection.
|
||||
"""
|
||||
WebSocketAppTest.get_mask_key_id = id(self.get_mask_key)
|
||||
self.close()
|
||||
|
||||
|
||||
app = ws.WebSocketApp('ws://echo.websocket.org/', on_open=on_open, get_mask_key=my_mask_key_func)
|
||||
app.run_forever()
|
||||
|
||||
# Note: We can't use 'is' for comparing the functions directly, need to use 'id'.
|
||||
self.assertEquals(WebSocketAppTest.get_mask_key_id, id(my_mask_key_func))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
|
28
websocket.py
28
websocket.py
@@ -58,12 +58,14 @@ STATUS_TLS_HANDSHAKE_ERROR = 1015
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
class WebSocketException(Exception):
|
||||
"""
|
||||
websocket exeception class.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class WebSocketConnectionClosedException(WebSocketException):
|
||||
"""
|
||||
If remote host closed the connection or some network error happened,
|
||||
@@ -74,6 +76,7 @@ class WebSocketConnectionClosedException(WebSocketException):
|
||||
default_timeout = None
|
||||
traceEnabled = False
|
||||
|
||||
|
||||
def enableTrace(tracable):
|
||||
"""
|
||||
turn on/off the tracability.
|
||||
@@ -87,6 +90,7 @@ def enableTrace(tracable):
|
||||
logger.addHandler(logging.StreamHandler())
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
|
||||
def setdefaulttimeout(timeout):
|
||||
"""
|
||||
Set the global timeout setting to connect.
|
||||
@@ -96,12 +100,14 @@ def setdefaulttimeout(timeout):
|
||||
global default_timeout
|
||||
default_timeout = timeout
|
||||
|
||||
|
||||
def getdefaulttimeout():
|
||||
"""
|
||||
Return the global timeout setting(second) to connect.
|
||||
"""
|
||||
return default_timeout
|
||||
|
||||
|
||||
def _parse_url(url):
|
||||
"""
|
||||
parse url and the result is tuple of
|
||||
@@ -130,7 +136,7 @@ def _parse_url(url):
|
||||
elif scheme == "wss":
|
||||
is_secure = True
|
||||
if not port:
|
||||
port = 443
|
||||
port = 443
|
||||
else:
|
||||
raise ValueError("scheme %s is invalid" % scheme)
|
||||
|
||||
@@ -144,6 +150,7 @@ def _parse_url(url):
|
||||
|
||||
return (hostname, port, resource, is_secure)
|
||||
|
||||
|
||||
def create_connection(url, timeout=None, **options):
|
||||
"""
|
||||
connect to url and return websocket object.
|
||||
@@ -177,6 +184,7 @@ _MAX_CHAR_BYTE = (1<<8) -1
|
||||
# ref. Websocket gets an update, and it breaks stuff.
|
||||
# http://axod.blogspot.com/2010/06/websocket-gets-update-and-it-breaks.html
|
||||
|
||||
|
||||
def _create_sec_websocket_key():
|
||||
uid = uuid.uuid4()
|
||||
return base64.encodestring(uid.bytes).strip()
|
||||
@@ -186,6 +194,7 @@ _HEADERS_TO_CHECK = {
|
||||
"connection": "upgrade",
|
||||
}
|
||||
|
||||
|
||||
class _SSLSocketWrapper(object):
|
||||
def __init__(self, sock):
|
||||
self.ssl = socket.ssl(sock)
|
||||
@@ -197,6 +206,8 @@ class _SSLSocketWrapper(object):
|
||||
return self.ssl.write(payload)
|
||||
|
||||
_BOOL_VALUES = (0, 1)
|
||||
|
||||
|
||||
def _is_bool(*values):
|
||||
for v in values:
|
||||
if v not in _BOOL_VALUES:
|
||||
@@ -204,6 +215,7 @@ def _is_bool(*values):
|
||||
|
||||
return True
|
||||
|
||||
|
||||
class ABNF(object):
|
||||
"""
|
||||
ABNF frame class.
|
||||
@@ -316,6 +328,7 @@ class ABNF(object):
|
||||
_d[i] ^= _m[i % 4]
|
||||
return _d.tostring()
|
||||
|
||||
|
||||
class WebSocket(object):
|
||||
"""
|
||||
Low level WebSocket interface.
|
||||
@@ -337,6 +350,7 @@ class WebSocket(object):
|
||||
get_mask_key: a callable to produce new mask keys, see the set_mask_key
|
||||
function's docstring for more details
|
||||
"""
|
||||
|
||||
def __init__(self, get_mask_key = None):
|
||||
"""
|
||||
Initalize WebSocket object.
|
||||
@@ -423,8 +437,8 @@ class WebSocket(object):
|
||||
header_str = "\r\n".join(headers)
|
||||
sock.send(header_str)
|
||||
if traceEnabled:
|
||||
logger.debug( "--- request header ---")
|
||||
logger.debug( header_str)
|
||||
logger.debug("--- request header ---")
|
||||
logger.debug(header_str)
|
||||
logger.debug("-----------------------")
|
||||
|
||||
status, resp_headers = self._read_headers()
|
||||
@@ -549,7 +563,6 @@ class WebSocket(object):
|
||||
elif frame.opcode == ABNF.OPCODE_PING:
|
||||
self.pong(frame.data)
|
||||
|
||||
|
||||
def recv_frame(self):
|
||||
"""
|
||||
recieve data as frame from server.
|
||||
@@ -603,8 +616,6 @@ class WebSocket(object):
|
||||
raise ValueError("code is invalid range")
|
||||
self.send(struct.pack('!H', status) + reason, ABNF.OPCODE_CLOSE)
|
||||
|
||||
|
||||
|
||||
def close(self, status = STATUS_NORMAL, reason = ""):
|
||||
"""
|
||||
Close Websocket object
|
||||
@@ -662,6 +673,7 @@ class WebSocket(object):
|
||||
break
|
||||
return "".join(line)
|
||||
|
||||
|
||||
class WebSocketApp(object):
|
||||
"""
|
||||
Higher level of APIs are provided.
|
||||
@@ -700,7 +712,7 @@ class WebSocketApp(object):
|
||||
|
||||
def send(self, data, opcode = ABNF.OPCODE_TEXT):
|
||||
"""
|
||||
send message.
|
||||
send message.
|
||||
data: message to send. If you set opcode to OPCODE_TEXT, data must be utf-8 string or unicode.
|
||||
opcode: operation code of data. default is OPCODE_TEXT.
|
||||
"""
|
||||
@@ -753,6 +765,6 @@ if __name__ == "__main__":
|
||||
ws.send("Hello, World")
|
||||
print "Sent"
|
||||
print "Receiving..."
|
||||
result = ws.recv()
|
||||
result = ws.recv()
|
||||
print "Received '%s'" % result
|
||||
ws.close()
|
||||
|
Reference in New Issue
Block a user