improve websocket specification conformance
This commit is contained in:
liris
2014-10-15 08:25:11 +09:00
parent 572367d0c1
commit 931f79d25e
3 changed files with 30 additions and 45 deletions

View File

@@ -22,7 +22,7 @@ import six
import array import array
import struct import struct
import os import os
from ._exceptions import *
@@ -75,9 +75,16 @@ class ABNF(object):
self.data = data self.data = data
self.get_mask_key = os.urandom self.get_mask_key = os.urandom
if rsv1 or rsv2 or rsv3: def validate(self):
"""
validate the ABNF frame.
"""
if self.rsv1 or self.rsv2 or self.rsv3:
raise NotImplementedError("rsv is not implemented, yet") raise NotImplementedError("rsv is not implemented, yet")
if self.opcode == ABNF.OPCODE_PING and not self.fin:
raise WebSocketException("Invalid ping frame.")
def __str__(self): def __str__(self):
return "fin=" + str(self.fin) \ return "fin=" + str(self.fin) \
+ " opcode=" + str(self.opcode) \ + " opcode=" + str(self.opcode) \

View File

@@ -694,42 +694,8 @@ class WebSocket(object):
return value: tuple of operation code and string(byte array) value. return value: tuple of operation code and string(byte array) value.
""" """
while True: opcode, frame = self.recv_data_frame(control_frame)
frame = self.recv_frame() return opcode, frame.data
if not frame:
# handle error:
# 'NoneType' object has no attribute 'opcode'
raise WebSocketException("Not a valid frame %s" % frame)
elif frame.opcode in (ABNF.OPCODE_TEXT, ABNF.OPCODE_BINARY, ABNF.OPCODE_CONT):
if frame.opcode == ABNF.OPCODE_CONT and not self._recving_frames:
raise WebSocketException("Illegal frame")
if self._cont_data:
self._cont_data[1] += frame.data
else:
if frame.opcode in (ABNF.OPCODE_TEXT, ABNF.OPCODE_BINARY):
self._recving_frames = frame.opcode
self._cont_data = [frame.opcode, frame.data]
if frame.fin:
self._recving_frames = None
if frame.fin or self.fire_cont_frame:
data = self._cont_data
self._cont_data = None
return data
elif frame.opcode == ABNF.OPCODE_CLOSE:
self.send_close()
return (frame.opcode, frame.data)
elif frame.opcode == ABNF.OPCODE_PING:
if len(frame.data) < 126:
self.pong(frame.data)
if control_frame:
return (frame.opcode, frame.data)
elif frame.opcode == ABNF.OPCODE_PONG:
if control_frame:
return (frame.opcode, frame.data)
def recv_data_frame(self, control_frame=False): def recv_data_frame(self, control_frame=False):
""" """
@@ -747,7 +713,9 @@ class WebSocket(object):
# 'NoneType' object has no attribute 'opcode' # 'NoneType' object has no attribute 'opcode'
raise WebSocketException("Not a valid frame %s" % frame) raise WebSocketException("Not a valid frame %s" % frame)
elif frame.opcode in (ABNF.OPCODE_TEXT, ABNF.OPCODE_BINARY, ABNF.OPCODE_CONT): elif frame.opcode in (ABNF.OPCODE_TEXT, ABNF.OPCODE_BINARY, ABNF.OPCODE_CONT):
if frame.opcode == ABNF.OPCODE_CONT and not self._recving_frames: if not self._recving_frames and frame.opcode == ABNF.OPCODE_CONT:
raise WebSocketException("Illegal frame")
if self._recving_frames and frame.opcode in (ABNF.OPCODE_TEXT, ABNF.OPCODE_BINARY):
raise WebSocketException("Illegal frame") raise WebSocketException("Illegal frame")
if self._cont_data: if self._cont_data:
@@ -763,12 +731,14 @@ class WebSocket(object):
if frame.fin or self.fire_cont_frame: if frame.fin or self.fire_cont_frame:
data = self._cont_data data = self._cont_data
self._cont_data = None self._cont_data = None
return data frame.data = data[1]
return [data[0], frame]
elif frame.opcode == ABNF.OPCODE_CLOSE: elif frame.opcode == ABNF.OPCODE_CLOSE:
self.send_close() self.send_close()
return (frame.opcode, frame) return (frame.opcode, frame)
elif frame.opcode == ABNF.OPCODE_PING: elif frame.opcode == ABNF.OPCODE_PING:
if len(frame.data) < 126:
self.pong(frame.data) self.pong(frame.data)
if control_frame: if control_frame:
return (frame.opcode, frame) return (frame.opcode, frame)
@@ -806,7 +776,10 @@ class WebSocket(object):
# Reset for next frame # Reset for next frame
frame_buffer.clear() frame_buffer.clear()
return ABNF(fin, rsv1, rsv2, rsv3, opcode, has_mask, payload) frame = ABNF(fin, rsv1, rsv2, rsv3, opcode, has_mask, payload)
frame.validate()
return frame
def send_close(self, status=STATUS_NORMAL, reason=six.b("")): def send_close(self, status=STATUS_NORMAL, reason=six.b("")):
@@ -864,6 +837,9 @@ class WebSocket(object):
if isinstance(data, six.text_type): if isinstance(data, six.text_type):
data = data.encode('utf-8') data = data.encode('utf-8')
if not self.sock:
raise WebSocketConnectionClosedException("socket is already closed.")
try: try:
return self.sock.send(data) return self.sock.send(data)
except socket.timeout as e: except socket.timeout as e:
@@ -877,6 +853,9 @@ class WebSocket(object):
raise raise
def _recv(self, bufsize): def _recv(self, bufsize):
if not self.sock:
raise WebSocketConnectionClosedException("socket is already closed.")
try: try:
bytes = self.sock.recv(bufsize) bytes = self.sock.recv(bufsize)
except socket.timeout as e: except socket.timeout as e:

View File

@@ -430,12 +430,11 @@ class WebSocketTest(unittest.TestCase):
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def testAfterClose(self): def testAfterClose(self):
from socket import error
s = ws.create_connection("ws://echo.websocket.org/") s = ws.create_connection("ws://echo.websocket.org/")
self.assertNotEqual(s, None) self.assertNotEqual(s, None)
s.close() s.close()
self.assertRaises(error, s.send, "Hello") self.assertRaises(ws.WebSocketConnectionClosedException, s.send, "Hello")
self.assertRaises(error, s.recv) self.assertRaises(ws.WebSocketConnectionClosedException, s.recv)
def testUUID4(self): def testUUID4(self):
""" WebSocket key should be a UUID4. """ WebSocket key should be a UUID4.