From 931f79d25ecaf728f792f12194d5624264f2401c Mon Sep 17 00:00:00 2001 From: liris Date: Wed, 15 Oct 2014 08:25:11 +0900 Subject: [PATCH] refs #117 improve websocket specification conformance --- websocket/_abnf.py | 11 ++++-- websocket/_core.py | 59 ++++++++++--------------------- websocket/tests/test_websocket.py | 5 ++- 3 files changed, 30 insertions(+), 45 deletions(-) diff --git a/websocket/_abnf.py b/websocket/_abnf.py index 8252f4c..24cf28b 100644 --- a/websocket/_abnf.py +++ b/websocket/_abnf.py @@ -22,7 +22,7 @@ import six import array import struct import os - +from ._exceptions import * @@ -75,9 +75,16 @@ class ABNF(object): self.data = data self.get_mask_key = os.urandom - if rsv1 or rsv2 or rsv3: + def validate(self): + """ + validate the ABNF frame. + """ + if self.rsv1 or self.rsv2 or self.rsv3: raise NotImplementedError("rsv is not implemented, yet") + if self.opcode == ABNF.OPCODE_PING and not self.fin: + raise WebSocketException("Invalid ping frame.") + def __str__(self): return "fin=" + str(self.fin) \ + " opcode=" + str(self.opcode) \ diff --git a/websocket/_core.py b/websocket/_core.py index 49e8d49..ef1247f 100644 --- a/websocket/_core.py +++ b/websocket/_core.py @@ -694,42 +694,8 @@ class WebSocket(object): return value: tuple of operation code and string(byte array) value. """ - while True: - frame = self.recv_frame() - if not frame: - # handle error: - # 'NoneType' object has no attribute 'opcode' - raise WebSocketException("Not a valid frame %s" % frame) - elif frame.opcode in (ABNF.OPCODE_TEXT, ABNF.OPCODE_BINARY, ABNF.OPCODE_CONT): - if frame.opcode == ABNF.OPCODE_CONT and not self._recving_frames: - raise WebSocketException("Illegal frame") - - if self._cont_data: - self._cont_data[1] += frame.data - else: - if frame.opcode in (ABNF.OPCODE_TEXT, ABNF.OPCODE_BINARY): - self._recving_frames = frame.opcode - self._cont_data = [frame.opcode, frame.data] - - if frame.fin: - self._recving_frames = None - - if frame.fin or self.fire_cont_frame: - data = self._cont_data - self._cont_data = None - return data - - elif frame.opcode == ABNF.OPCODE_CLOSE: - self.send_close() - return (frame.opcode, frame.data) - elif frame.opcode == ABNF.OPCODE_PING: - if len(frame.data) < 126: - self.pong(frame.data) - if control_frame: - return (frame.opcode, frame.data) - elif frame.opcode == ABNF.OPCODE_PONG: - if control_frame: - return (frame.opcode, frame.data) + opcode, frame = self.recv_data_frame(control_frame) + return opcode, frame.data def recv_data_frame(self, control_frame=False): """ @@ -747,7 +713,9 @@ class WebSocket(object): # 'NoneType' object has no attribute 'opcode' raise WebSocketException("Not a valid frame %s" % frame) elif frame.opcode in (ABNF.OPCODE_TEXT, ABNF.OPCODE_BINARY, ABNF.OPCODE_CONT): - if frame.opcode == ABNF.OPCODE_CONT and not self._recving_frames: + if not self._recving_frames and frame.opcode == ABNF.OPCODE_CONT: + raise WebSocketException("Illegal frame") + if self._recving_frames and frame.opcode in (ABNF.OPCODE_TEXT, ABNF.OPCODE_BINARY): raise WebSocketException("Illegal frame") if self._cont_data: @@ -763,13 +731,15 @@ class WebSocket(object): if frame.fin or self.fire_cont_frame: data = self._cont_data self._cont_data = None - return data + frame.data = data[1] + return [data[0], frame] elif frame.opcode == ABNF.OPCODE_CLOSE: self.send_close() return (frame.opcode, frame) elif frame.opcode == ABNF.OPCODE_PING: - self.pong(frame.data) + if len(frame.data) < 126: + self.pong(frame.data) if control_frame: return (frame.opcode, frame) elif frame.opcode == ABNF.OPCODE_PONG: @@ -806,7 +776,10 @@ class WebSocket(object): # Reset for next frame frame_buffer.clear() - return ABNF(fin, rsv1, rsv2, rsv3, opcode, has_mask, payload) + frame = ABNF(fin, rsv1, rsv2, rsv3, opcode, has_mask, payload) + frame.validate() + + return frame def send_close(self, status=STATUS_NORMAL, reason=six.b("")): @@ -864,6 +837,9 @@ class WebSocket(object): if isinstance(data, six.text_type): data = data.encode('utf-8') + if not self.sock: + raise WebSocketConnectionClosedException("socket is already closed.") + try: return self.sock.send(data) except socket.timeout as e: @@ -877,6 +853,9 @@ class WebSocket(object): raise def _recv(self, bufsize): + if not self.sock: + raise WebSocketConnectionClosedException("socket is already closed.") + try: bytes = self.sock.recv(bufsize) except socket.timeout as e: diff --git a/websocket/tests/test_websocket.py b/websocket/tests/test_websocket.py index 0495721..84ea5ca 100644 --- a/websocket/tests/test_websocket.py +++ b/websocket/tests/test_websocket.py @@ -430,12 +430,11 @@ class WebSocketTest(unittest.TestCase): @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") def testAfterClose(self): - from socket import error s = ws.create_connection("ws://echo.websocket.org/") self.assertNotEqual(s, None) s.close() - self.assertRaises(error, s.send, "Hello") - self.assertRaises(error, s.recv) + self.assertRaises(ws.WebSocketConnectionClosedException, s.send, "Hello") + self.assertRaises(ws.WebSocketConnectionClosedException, s.recv) def testUUID4(self): """ WebSocket key should be a UUID4.