diff --git a/tests/test_websocket.py b/tests/test_websocket.py index 8852c6e..3df82c1 100644 --- a/tests/test_websocket.py +++ b/tests/test_websocket.py @@ -23,6 +23,9 @@ import websocket as ws # Skip test to access the internet. TEST_WITH_INTERNET = False +# Skip Secure WebSocket test. +TEST_SECURE_WS = False + TRACABLE = False @@ -57,8 +60,8 @@ class HeaderSockMock(SockMock): def __init__(self, fname): SockMock.__init__(self) path = os.path.join(os.path.dirname(__file__), fname) - self.add_packet(open(path).read().encode('utf-8')) - + with open(path, "rb") as f: + self.add_packet(f.read()) class WebSocketTest(unittest.TestCase): def setUp(self): @@ -68,96 +71,96 @@ class WebSocketTest(unittest.TestCase): pass def testDefaultTimeout(self): - self.assertEquals(ws.getdefaulttimeout(), None) + self.assertEqual(ws.getdefaulttimeout(), None) ws.setdefaulttimeout(10) - self.assertEquals(ws.getdefaulttimeout(), 10) + self.assertEqual(ws.getdefaulttimeout(), 10) ws.setdefaulttimeout(None) def testParseUrl(self): p = ws._parse_url("ws://www.example.com/r") - self.assertEquals(p[0], "www.example.com") - self.assertEquals(p[1], 80) - self.assertEquals(p[2], "/r") - self.assertEquals(p[3], False) + self.assertEqual(p[0], "www.example.com") + self.assertEqual(p[1], 80) + self.assertEqual(p[2], "/r") + self.assertEqual(p[3], False) p = ws._parse_url("ws://www.example.com/r/") - self.assertEquals(p[0], "www.example.com") - self.assertEquals(p[1], 80) - self.assertEquals(p[2], "/r/") - self.assertEquals(p[3], False) + self.assertEqual(p[0], "www.example.com") + self.assertEqual(p[1], 80) + self.assertEqual(p[2], "/r/") + self.assertEqual(p[3], False) p = ws._parse_url("ws://www.example.com/") - self.assertEquals(p[0], "www.example.com") - self.assertEquals(p[1], 80) - self.assertEquals(p[2], "/") - self.assertEquals(p[3], False) + self.assertEqual(p[0], "www.example.com") + self.assertEqual(p[1], 80) + self.assertEqual(p[2], "/") + self.assertEqual(p[3], False) p = ws._parse_url("ws://www.example.com") - self.assertEquals(p[0], "www.example.com") - self.assertEquals(p[1], 80) - self.assertEquals(p[2], "/") - self.assertEquals(p[3], False) + self.assertEqual(p[0], "www.example.com") + self.assertEqual(p[1], 80) + self.assertEqual(p[2], "/") + self.assertEqual(p[3], False) p = ws._parse_url("ws://www.example.com:8080/r") - self.assertEquals(p[0], "www.example.com") - self.assertEquals(p[1], 8080) - self.assertEquals(p[2], "/r") - self.assertEquals(p[3], False) + self.assertEqual(p[0], "www.example.com") + self.assertEqual(p[1], 8080) + self.assertEqual(p[2], "/r") + self.assertEqual(p[3], False) p = ws._parse_url("ws://www.example.com:8080/") - self.assertEquals(p[0], "www.example.com") - self.assertEquals(p[1], 8080) - self.assertEquals(p[2], "/") - self.assertEquals(p[3], False) + self.assertEqual(p[0], "www.example.com") + self.assertEqual(p[1], 8080) + self.assertEqual(p[2], "/") + self.assertEqual(p[3], False) p = ws._parse_url("ws://www.example.com:8080") - self.assertEquals(p[0], "www.example.com") - self.assertEquals(p[1], 8080) - self.assertEquals(p[2], "/") - self.assertEquals(p[3], False) + self.assertEqual(p[0], "www.example.com") + self.assertEqual(p[1], 8080) + self.assertEqual(p[2], "/") + self.assertEqual(p[3], False) p = ws._parse_url("wss://www.example.com:8080/r") - self.assertEquals(p[0], "www.example.com") - self.assertEquals(p[1], 8080) - self.assertEquals(p[2], "/r") - self.assertEquals(p[3], True) + self.assertEqual(p[0], "www.example.com") + self.assertEqual(p[1], 8080) + self.assertEqual(p[2], "/r") + self.assertEqual(p[3], True) p = ws._parse_url("wss://www.example.com:8080/r?key=value") - self.assertEquals(p[0], "www.example.com") - self.assertEquals(p[1], 8080) - self.assertEquals(p[2], "/r?key=value") - self.assertEquals(p[3], True) + self.assertEqual(p[0], "www.example.com") + self.assertEqual(p[1], 8080) + self.assertEqual(p[2], "/r?key=value") + self.assertEqual(p[3], True) self.assertRaises(ValueError, ws._parse_url, "http://www.example.com/r") p = ws._parse_url("ws://[2a03:4000:123:83::3]/r") - self.assertEquals(p[0], "2a03:4000:123:83::3") - self.assertEquals(p[1], 80) - self.assertEquals(p[2], "/r") - self.assertEquals(p[3], False) + self.assertEqual(p[0], "2a03:4000:123:83::3") + self.assertEqual(p[1], 80) + self.assertEqual(p[2], "/r") + self.assertEqual(p[3], False) p = ws._parse_url("ws://[2a03:4000:123:83::3]:8080/r") - self.assertEquals(p[0], "2a03:4000:123:83::3") - self.assertEquals(p[1], 8080) - self.assertEquals(p[2], "/r") - self.assertEquals(p[3], False) + self.assertEqual(p[0], "2a03:4000:123:83::3") + self.assertEqual(p[1], 8080) + self.assertEqual(p[2], "/r") + self.assertEqual(p[3], False) p = ws._parse_url("wss://[2a03:4000:123:83::3]/r") - self.assertEquals(p[0], "2a03:4000:123:83::3") - self.assertEquals(p[1], 443) - self.assertEquals(p[2], "/r") - self.assertEquals(p[3], True) + self.assertEqual(p[0], "2a03:4000:123:83::3") + self.assertEqual(p[1], 443) + self.assertEqual(p[2], "/r") + self.assertEqual(p[3], True) p = ws._parse_url("wss://[2a03:4000:123:83::3]:8080/r") - self.assertEquals(p[0], "2a03:4000:123:83::3") - self.assertEquals(p[1], 8080) - self.assertEquals(p[2], "/r") - self.assertEquals(p[3], True) + self.assertEqual(p[0], "2a03:4000:123:83::3") + self.assertEqual(p[1], 8080) + self.assertEqual(p[2], "/r") + self.assertEqual(p[3], True) def testWSKey(self): key = ws._create_sec_websocket_key() - self.assert_(key != 24) - self.assert_(six.u("¥n") not in key) + self.assertTrue(key != 24) + self.assertTrue(six.u("¥n") not in key) def testWsUtils(self): sock = ws.WebSocket() @@ -168,32 +171,32 @@ class WebSocketTest(unittest.TestCase): "connection": "upgrade", "sec-websocket-accept": "Kxep+hNu9n51529fGidYu7a3wO0=", } - self.assertEquals(sock._validate_header(required_header, key), True) + self.assertEqual(sock._validate_header(required_header, key), True) header = required_header.copy() header["upgrade"] = "http" - self.assertEquals(sock._validate_header(header, key), False) + self.assertEqual(sock._validate_header(header, key), False) del header["upgrade"] - self.assertEquals(sock._validate_header(header, key), False) + self.assertEqual(sock._validate_header(header, key), False) header = required_header.copy() header["connection"] = "something" - self.assertEquals(sock._validate_header(header, key), False) + self.assertEqual(sock._validate_header(header, key), False) del header["connection"] - self.assertEquals(sock._validate_header(header, key), False) + self.assertEqual(sock._validate_header(header, key), False) header = required_header.copy() header["sec-websocket-accept"] = "something" - self.assertEquals(sock._validate_header(header, key), False) + self.assertEqual(sock._validate_header(header, key), False) del header["sec-websocket-accept"] - self.assertEquals(sock._validate_header(header, key), False) + self.assertEqual(sock._validate_header(header, key), False) def testReadHeader(self): sock = ws.WebSocket() sock.sock = HeaderSockMock("data/header01.txt") status, header = sock._read_headers() - self.assertEquals(status, 101) - self.assertEquals(header["connection"], "upgrade") + self.assertEqual(status, 101) + self.assertEqual(header["connection"], "upgrade") sock.sock = HeaderSockMock("data/header02.txt") self.assertRaises(ws.WebSocketException, sock._read_headers) @@ -204,13 +207,13 @@ class WebSocketTest(unittest.TestCase): sock.set_mask_key(create_mask_key) s = sock.sock = HeaderSockMock("data/header01.txt") sock.send("Hello") - self.assertEquals(s.sent[0], six.b("\x81\x85abcd)\x07\x0f\x08\x0e")) + self.assertEqual(s.sent[0], six.b("\x81\x85abcd)\x07\x0f\x08\x0e")) sock.send("こんにちは") - self.assertEquals(s.sent[1], six.b("\x81\x8fabcd\x82\xe3\xf0\x87\xe3\xf1\x80\xe5\xca\x81\xe2\xc5\x82\xe3\xcc")) + self.assertEqual(s.sent[1], six.b("\x81\x8fabcd\x82\xe3\xf0\x87\xe3\xf1\x80\xe5\xca\x81\xe2\xc5\x82\xe3\xcc")) sock.send(u"こんにちは") - self.assertEquals(s.sent[1], six.b("\x81\x8fabcd\x82\xe3\xf0\x87\xe3\xf1\x80\xe5\xca\x81\xe2\xc5\x82\xe3\xcc")) + self.assertEqual(s.sent[1], six.b("\x81\x8fabcd\x82\xe3\xf0\x87\xe3\xf1\x80\xe5\xca\x81\xe2\xc5\x82\xe3\xcc")) def testRecv(self): # TODO: add longer frame data @@ -219,11 +222,11 @@ class WebSocketTest(unittest.TestCase): something = six.b("\x81\x8fabcd\x82\xe3\xf0\x87\xe3\xf1\x80\xe5\xca\x81\xe2\xc5\x82\xe3\xcc") s.add_packet(something) data = sock.recv() - self.assertEquals(data, "こんにちは") + self.assertEqual(data, "こんにちは") s.add_packet(six.b("\x81\x85abcd)\x07\x0f\x08\x0e")) data = sock.recv() - self.assertEquals(data, "Hello") + self.assertEqual(data, "Hello") def testInternalRecvStrict(self): sock = ws.WebSocket() @@ -238,7 +241,7 @@ class WebSocketTest(unittest.TestCase): with self.assertRaises(SSLError): data = sock._recv_strict(9) data = sock._recv_strict(9) - self.assertEquals(data, six.b("foobarbaz")) + self.assertEqual(data, six.b("foobarbaz")) with self.assertRaises(ws.WebSocketConnectionClosedException): data = sock._recv_strict(1) @@ -255,7 +258,7 @@ class WebSocketTest(unittest.TestCase): with self.assertRaises(ws.WebSocketTimeoutException): data = sock.recv() data = sock.recv() - self.assertEquals(data, "Hello, World!") + self.assertEqual(data, "Hello, World!") with self.assertRaises(ws.WebSocketConnectionClosedException): data = sock.recv() @@ -318,37 +321,38 @@ class WebSocketTest(unittest.TestCase): @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") def testWebSocket(self): s = ws.create_connection("ws://echo.websocket.org/") - self.assertNotEquals(s, None) + self.assertNotEqual(s, None) s.send("Hello, World") result = s.recv() - self.assertEquals(result, "Hello, World") + self.assertEqual(result, "Hello, World") s.send(u"こにゃにゃちは、世界") result = s.recv() - self.assertEquals(result, "こにゃにゃちは、世界") + self.assertEqual(result, "こにゃにゃちは、世界") s.close() @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") def testPingPong(self): s = ws.create_connection("ws://echo.websocket.org/") - self.assertNotEquals(s, None) + self.assertNotEqual(s, None) s.ping("Hello") s.pong("Hi") s.close() @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") + @unittest.skipUnless(TEST_SECURE_WS, "wss://echo.websocket.org doesn't work well.") def testSecureWebSocket(self): if 1: import ssl s = ws.create_connection("wss://echo.websocket.org/") - self.assertNotEquals(s, None) - self.assert_(isinstance(s.sock, ssl.SSLSocket)) + self.assertNotEqual(s, None) + self.assertTrue(isinstance(s.sock, ssl.SSLSocket)) s.send("Hello, World") result = s.recv() - self.assertEquals(result, "Hello, World") + self.assertEqual(result, "Hello, World") s.send(u"こにゃにゃちは、世界") result = s.recv() - self.assertEquals(result, "こにゃにゃちは、世界") + self.assertEqual(result, "こにゃにゃちは、世界") s.close() #except: # pass @@ -357,17 +361,17 @@ class WebSocketTest(unittest.TestCase): def testWebSocketWihtCustomHeader(self): s = ws.create_connection("ws://echo.websocket.org/", headers={"User-Agent": "PythonWebsocketClient"}) - self.assertNotEquals(s, None) + self.assertNotEqual(s, None) s.send("Hello, World") result = s.recv() - self.assertEquals(result, "Hello, World") + self.assertEqual(result, "Hello, World") s.close() @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") def testAfterClose(self): from socket import error s = ws.create_connection("ws://echo.websocket.org/") - self.assertNotEquals(s, None) + self.assertNotEqual(s, None) s.close() self.assertRaises(error, s.send, "Hello") self.assertRaises(error, s.recv) @@ -377,7 +381,7 @@ class WebSocketTest(unittest.TestCase): """ key = ws._create_sec_websocket_key() u = uuid.UUID(bytes=base64.b64decode(key)) - self.assertEquals(4, u.version) + self.assertEqual(4, u.version) class WebSocketAppTest(unittest.TestCase): @@ -425,8 +429,8 @@ class WebSocketAppTest(unittest.TestCase): self.assertFalse(isinstance(WebSocketAppTest.keep_running_close, WebSocketAppTest.NotSetYet)) - self.assertEquals(True, WebSocketAppTest.keep_running_open) - self.assertEquals(False, WebSocketAppTest.keep_running_close) + self.assertEqual(True, WebSocketAppTest.keep_running_open) + self.assertEqual(False, WebSocketAppTest.keep_running_close) @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") def testSockMaskKey(self): @@ -448,7 +452,7 @@ class WebSocketAppTest(unittest.TestCase): app.run_forever() # Note: We can't use 'is' for comparing the functions directly, need to use 'id'. - self.assertEquals(WebSocketAppTest.get_mask_key_id, id(my_mask_key_func)) + self.assertEqual(WebSocketAppTest.get_mask_key_id, id(my_mask_key_func)) class SockOptTest(unittest.TestCase): @@ -456,9 +460,9 @@ class SockOptTest(unittest.TestCase): def testSockOpt(self): sockopt = ((socket.IPPROTO_TCP, socket.TCP_NODELAY, 1),) s = ws.WebSocket(sockopt=sockopt) - self.assertNotEquals(s.sock.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY), 0) + self.assertNotEqual(s.sock.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY), 0) s = ws.create_connection("ws://echo.websocket.org", sockopt=sockopt) - self.assertNotEquals(s.sock.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY), 0) + self.assertNotEqual(s.sock.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY), 0) if __name__ == "__main__": diff --git a/websocket/__init__.py b/websocket/__init__.py index 3c821ee..9747b2b 100644 --- a/websocket/__init__.py +++ b/websocket/__init__.py @@ -41,13 +41,16 @@ except ImportError: HAVE_SSL = False from six.moves.urllib.parse import urlparse +if six.PY3: + from base64 import encodebytes as base64encode +else: + from base64 import encodestring as base64encode import os import array import struct import uuid import hashlib -import base64 import threading import time import logging @@ -231,7 +234,7 @@ _MAX_CHAR_BYTE = (1<<8) -1 def _create_sec_websocket_key(): uid = uuid.uuid4() - return base64.encodestring(uid.bytes).decode('utf-8').strip() + return base64encode(uid.bytes).decode('utf-8').strip() _HEADERS_TO_CHECK = { @@ -373,7 +376,10 @@ class ABNF(object): for i in range(len(_d)): _d[i] ^= _m[i % 4] - return _d.tostring() + if six.PY3: + return _d.tobytes() + else: + return _d.tostring() class WebSocket(object): @@ -557,7 +563,7 @@ class WebSocket(object): result = result.encode('utf-8') value = (key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11").encode('utf-8') - hashed = base64.encodestring(hashlib.sha1(value).digest()).strip().lower() + hashed = base64encode(hashlib.sha1(value).digest()).strip().lower() return hashed == result def _read_headers(self):