diff --git a/test_websocket.py b/test_websocket.py index ac67c3d..d849a2b 100644 --- a/test_websocket.py +++ b/test_websocket.py @@ -2,9 +2,10 @@ # import base64 -import uuid -import unittest import socket +import ssl +import unittest +import uuid # websocket-client import websocket as ws @@ -14,32 +15,34 @@ TRACABLE=False def create_mask_key(n): return "abcd" -class StringSockMock: +class SockMock(object): + def __init__(self): - self.set_data("") + self.data = [] self.sent = [] - def set_data(self, data): - self.data = data - self.pos = 0 - self.len = len(data) + def add_packet(self, data): + self.data.append(data) def recv(self, bufsize): - if self.len < self.pos: - return - buf = self.data[self.pos: self.pos + bufsize] - self.pos += bufsize - return buf + if self.data: + e = self.data.pop(0) + if isinstance(e, Exception): + raise e + if len(e) > bufsize: + self.data.insert(0, e[bufsize:]) + return e[:bufsize] def send(self, data): self.sent.append(data) return len(data) -class HeaderSockMock(StringSockMock): +class HeaderSockMock(SockMock): + def __init__(self, fname): - self.set_data(open(fname).read()) - self.sent = [] + SockMock.__init__(self) + self.add_packet(open(fname).read()) class WebSocketTest(unittest.TestCase): @@ -173,15 +176,49 @@ class WebSocketTest(unittest.TestCase): def testRecv(self): # TODO: add longer frame data sock = ws.WebSocket() - s = sock.sock = StringSockMock() - s.set_data("\x81\x8fabcd\x82\xe3\xf0\x87\xe3\xf1\x80\xe5\xca\x81\xe2\xc5\x82\xe3\xcc") + s = sock.sock = SockMock() + s.add_packet("\x81\x8fabcd\x82\xe3\xf0\x87\xe3\xf1\x80\xe5\xca\x81\xe2\xc5\x82\xe3\xcc") data = sock.recv() self.assertEquals(data, "こんにちは") - s.set_data("\x81\x85abcd)\x07\x0f\x08\x0e") + s.add_packet("\x81\x85abcd)\x07\x0f\x08\x0e") data = sock.recv() self.assertEquals(data, "Hello") + def testInternalRecvStrict(self): + sock = ws.WebSocket() + s = sock.sock = SockMock() + s.add_packet("foo") + s.add_packet(socket.timeout()) + s.add_packet("bar") + s.add_packet(ssl.SSLError("The read operation timed out")) + s.add_packet("baz") + with self.assertRaises(ws.WebSocketTimeoutException): + data = sock._recv_strict(9) + with self.assertRaises(ws.WebSocketTimeoutException): + data = sock._recv_strict(9) + data = sock._recv_strict(9) + self.assertEquals(data, "foobarbaz") + with self.assertRaises(ws.WebSocketConnectionClosedException): + data = sock._recv_strict(1) + + def testRecvTimeout(self): + sock = ws.WebSocket() + s = sock.sock = SockMock() + s.add_packet("\x81") + s.add_packet(socket.timeout()) + s.add_packet("\x8dabcd\x29\x07\x0f\x08\x0e") + s.add_packet(socket.timeout()) + s.add_packet("\x4e\x43\x33\x0e\x10\x0f\x00\x40") + with self.assertRaises(ws.WebSocketTimeoutException): + data = sock.recv() + with self.assertRaises(ws.WebSocketTimeoutException): + data = sock.recv() + data = sock.recv() + self.assertEquals(data, "Hello, World!") + with self.assertRaises(ws.WebSocketConnectionClosedException): + data = sock.recv() + def testWebSocket(self): s = ws.create_connection("ws://echo.websocket.org/") self.assertNotEquals(s, None)