refactoring. pack continuous frame info into one class.

This commit is contained in:
liris
2015-03-25 10:46:28 +09:00
parent 46f9266e6d
commit 5d04ab5bae
2 changed files with 75 additions and 55 deletions

View File

@@ -226,8 +226,9 @@ class frame_buffer(object):
_HEADER_MASK_INDEX = 5
_HEADER_LENGHT_INDEX = 6
def __init__(self, recv_fn):
def __init__(self, recv_fn, skip_utf8_validation):
self.recv = recv_fn
self.skip_utf8_validation = skip_utf8_validation
# Buffers over the packets from the layer beneath until desired amount
# bytes of bytes are received.
self.recv_buffer = []
@@ -290,6 +291,35 @@ class frame_buffer(object):
def recv_mask(self):
self.mask = self.recv_strict(4) if self.has_mask() else ""
def recv_frame(self):
# Header
if self.has_received_header():
self.recv_header()
(fin, rsv1, rsv2, rsv3, opcode, has_mask, _) = self.header
# Frame length
if self.has_received_length():
self.recv_length()
length = self.length
# Mask
if self.has_received_mask():
self.recv_mask()
mask = self.mask
# Payload
payload = self.recv_strict(length)
if has_mask:
payload = ABNF.mask(mask, payload)
# Reset for next frame
self.clear()
frame = ABNF(fin, rsv1, rsv2, rsv3, opcode, has_mask, payload)
frame.validate(self.skip_utf8_validation)
return frame
def recv_strict(self, bufsize):
shortage = bufsize - sum(len(x) for x in self.recv_buffer)
while shortage > 0:
@@ -305,3 +335,40 @@ class frame_buffer(object):
else:
self.recv_buffer = [unified[bufsize:]]
return unified[:bufsize]
class continuous_frame(object):
def __init__(self, fire_cont_frame, skip_utf8_validation):
self.fire_cont_frame = fire_cont_frame
self.skip_utf8_validation = skip_utf8_validation
self.cont_data = None
self.recving_frames = None
def validate(self, frame):
if not self.recving_frames and frame.opcode == ABNF.OPCODE_CONT:
raise WebSocketProtocolException("Illegal frame")
if self.recving_frames and frame.opcode in (ABNF.OPCODE_TEXT, ABNF.OPCODE_BINARY):
raise WebSocketProtocolException("Illegal frame")
def add(self, frame):
if self.cont_data:
self.cont_data[1] += frame.data
else:
if frame.opcode in (ABNF.OPCODE_TEXT, ABNF.OPCODE_BINARY):
self.recving_frames = frame.opcode
self.cont_data = [frame.opcode, frame.data]
if frame.fin:
self.recving_frames = None
def is_fire(self, frame):
return frame.fin or self.fire_cont_frame
def extract(self, frame):
data = self.cont_data
self.cont_data = None
frame.data = data[1]
if not self.fire_cont_frame and data[0] == ABNF.OPCODE_TEXT and not self.skip_utf8_validation and not validate_utf8(frame.data):
raise WebSocketPayloadException("cannot decode: " + repr(frame.data))
return [data[0], frame]

View File

@@ -160,12 +160,9 @@ class WebSocket(object):
self.connected = False
self.get_mask_key = get_mask_key
self.fire_cont_frame = fire_cont_frame
self.skip_utf8_validation = skip_utf8_validation
# These buffer over the build-up of a single frame.
self.frame_buffer = frame_buffer(self._recv)
self._cont_data = None
self._recving_frames = None
self.frame_buffer = frame_buffer(self._recv, skip_utf8_validation)
self.cont_frame = continuous_frame(fire_cont_frame, skip_utf8_validation)
if enable_multithread:
self.lock = threading.Lock()
@@ -384,28 +381,11 @@ class WebSocket(object):
# 'NoneType' object has no attribute 'opcode'
raise WebSocketProtocolException("Not a valid frame %s" % frame)
elif frame.opcode in (ABNF.OPCODE_TEXT, ABNF.OPCODE_BINARY, ABNF.OPCODE_CONT):
if not self._recving_frames and frame.opcode == ABNF.OPCODE_CONT:
raise WebSocketProtocolException("Illegal frame")
if self._recving_frames and frame.opcode in (ABNF.OPCODE_TEXT, ABNF.OPCODE_BINARY):
raise WebSocketProtocolException("Illegal frame")
self.cont_frame.validate(frame)
self.cont_frame.add(frame)
if self._cont_data:
self._cont_data[1] += frame.data
else:
if frame.opcode in (ABNF.OPCODE_TEXT, ABNF.OPCODE_BINARY):
self._recving_frames = frame.opcode
self._cont_data = [frame.opcode, frame.data]
if frame.fin:
self._recving_frames = None
if frame.fin or self.fire_cont_frame:
data = self._cont_data
self._cont_data = None
frame.data = data[1]
if not self.fire_cont_frame and data[0] == ABNF.OPCODE_TEXT and not self.skip_utf8_validation and not validate_utf8(frame.data):
raise WebSocketPayloadException("cannot decode: " + repr(frame.data))
return [data[0], frame]
if self.cont_frame.is_fire(frame):
return self.cont_frame.extract(frame)
elif frame.opcode == ABNF.OPCODE_CLOSE:
self.send_close()
@@ -427,34 +407,7 @@ class WebSocket(object):
return value: ABNF frame object.
"""
frame_buffer = self.frame_buffer
# Header
if frame_buffer.has_received_header():
frame_buffer.recv_header()
(fin, rsv1, rsv2, rsv3, opcode, has_mask, _) = frame_buffer.header
# Frame length
if frame_buffer.has_received_length():
frame_buffer.recv_length()
length = frame_buffer.length
# Mask
if frame_buffer.has_received_mask():
frame_buffer.recv_mask()
mask = frame_buffer.mask
# Payload
payload = frame_buffer.recv_strict(length)
if has_mask:
payload = ABNF.mask(mask, payload)
# Reset for next frame
frame_buffer.clear()
frame = ABNF(fin, rsv1, rsv2, rsv3, opcode, has_mask, payload)
frame.validate(self.skip_utf8_validation)
return frame
return self.frame_buffer.recv_frame()
def send_close(self, status=STATUS_NORMAL, reason=six.b("")):
"""