From 5d04ab5bae60181601f44b368fd37bdeffb918e3 Mon Sep 17 00:00:00 2001 From: liris Date: Wed, 25 Mar 2015 10:46:28 +0900 Subject: [PATCH] refactoring. pack continuous frame info into one class. --- websocket/_abnf.py | 69 +++++++++++++++++++++++++++++++++++++++++++++- websocket/_core.py | 61 +++++----------------------------------- 2 files changed, 75 insertions(+), 55 deletions(-) diff --git a/websocket/_abnf.py b/websocket/_abnf.py index 554cefc..d97cfdd 100644 --- a/websocket/_abnf.py +++ b/websocket/_abnf.py @@ -226,8 +226,9 @@ class frame_buffer(object): _HEADER_MASK_INDEX = 5 _HEADER_LENGHT_INDEX = 6 - def __init__(self, recv_fn): + def __init__(self, recv_fn, skip_utf8_validation): self.recv = recv_fn + self.skip_utf8_validation = skip_utf8_validation # Buffers over the packets from the layer beneath until desired amount # bytes of bytes are received. self.recv_buffer = [] @@ -290,6 +291,35 @@ class frame_buffer(object): def recv_mask(self): self.mask = self.recv_strict(4) if self.has_mask() else "" + def recv_frame(self): + # Header + if self.has_received_header(): + self.recv_header() + (fin, rsv1, rsv2, rsv3, opcode, has_mask, _) = self.header + + # Frame length + if self.has_received_length(): + self.recv_length() + length = self.length + + # Mask + if self.has_received_mask(): + self.recv_mask() + mask = self.mask + + # Payload + payload = self.recv_strict(length) + if has_mask: + payload = ABNF.mask(mask, payload) + + # Reset for next frame + self.clear() + + frame = ABNF(fin, rsv1, rsv2, rsv3, opcode, has_mask, payload) + frame.validate(self.skip_utf8_validation) + + return frame + def recv_strict(self, bufsize): shortage = bufsize - sum(len(x) for x in self.recv_buffer) while shortage > 0: @@ -305,3 +335,40 @@ class frame_buffer(object): else: self.recv_buffer = [unified[bufsize:]] return unified[:bufsize] + + +class continuous_frame(object): + def __init__(self, fire_cont_frame, skip_utf8_validation): + self.fire_cont_frame = fire_cont_frame + self.skip_utf8_validation = skip_utf8_validation + self.cont_data = None + self.recving_frames = None + + def validate(self, frame): + if not self.recving_frames and frame.opcode == ABNF.OPCODE_CONT: + raise WebSocketProtocolException("Illegal frame") + if self.recving_frames and frame.opcode in (ABNF.OPCODE_TEXT, ABNF.OPCODE_BINARY): + raise WebSocketProtocolException("Illegal frame") + + def add(self, frame): + if self.cont_data: + self.cont_data[1] += frame.data + else: + if frame.opcode in (ABNF.OPCODE_TEXT, ABNF.OPCODE_BINARY): + self.recving_frames = frame.opcode + self.cont_data = [frame.opcode, frame.data] + + if frame.fin: + self.recving_frames = None + + def is_fire(self, frame): + return frame.fin or self.fire_cont_frame + + def extract(self, frame): + data = self.cont_data + self.cont_data = None + frame.data = data[1] + if not self.fire_cont_frame and data[0] == ABNF.OPCODE_TEXT and not self.skip_utf8_validation and not validate_utf8(frame.data): + raise WebSocketPayloadException("cannot decode: " + repr(frame.data)) + + return [data[0], frame] diff --git a/websocket/_core.py b/websocket/_core.py index b7814bb..53a25a5 100644 --- a/websocket/_core.py +++ b/websocket/_core.py @@ -160,12 +160,9 @@ class WebSocket(object): self.connected = False self.get_mask_key = get_mask_key - self.fire_cont_frame = fire_cont_frame - self.skip_utf8_validation = skip_utf8_validation # These buffer over the build-up of a single frame. - self.frame_buffer = frame_buffer(self._recv) - self._cont_data = None - self._recving_frames = None + self.frame_buffer = frame_buffer(self._recv, skip_utf8_validation) + self.cont_frame = continuous_frame(fire_cont_frame, skip_utf8_validation) if enable_multithread: self.lock = threading.Lock() @@ -384,28 +381,11 @@ class WebSocket(object): # 'NoneType' object has no attribute 'opcode' raise WebSocketProtocolException("Not a valid frame %s" % frame) elif frame.opcode in (ABNF.OPCODE_TEXT, ABNF.OPCODE_BINARY, ABNF.OPCODE_CONT): - if not self._recving_frames and frame.opcode == ABNF.OPCODE_CONT: - raise WebSocketProtocolException("Illegal frame") - if self._recving_frames and frame.opcode in (ABNF.OPCODE_TEXT, ABNF.OPCODE_BINARY): - raise WebSocketProtocolException("Illegal frame") + self.cont_frame.validate(frame) + self.cont_frame.add(frame) - if self._cont_data: - self._cont_data[1] += frame.data - else: - if frame.opcode in (ABNF.OPCODE_TEXT, ABNF.OPCODE_BINARY): - self._recving_frames = frame.opcode - self._cont_data = [frame.opcode, frame.data] - - if frame.fin: - self._recving_frames = None - - if frame.fin or self.fire_cont_frame: - data = self._cont_data - self._cont_data = None - frame.data = data[1] - if not self.fire_cont_frame and data[0] == ABNF.OPCODE_TEXT and not self.skip_utf8_validation and not validate_utf8(frame.data): - raise WebSocketPayloadException("cannot decode: " + repr(frame.data)) - return [data[0], frame] + if self.cont_frame.is_fire(frame): + return self.cont_frame.extract(frame) elif frame.opcode == ABNF.OPCODE_CLOSE: self.send_close() @@ -427,34 +407,7 @@ class WebSocket(object): return value: ABNF frame object. """ - frame_buffer = self.frame_buffer - # Header - if frame_buffer.has_received_header(): - frame_buffer.recv_header() - (fin, rsv1, rsv2, rsv3, opcode, has_mask, _) = frame_buffer.header - - # Frame length - if frame_buffer.has_received_length(): - frame_buffer.recv_length() - length = frame_buffer.length - - # Mask - if frame_buffer.has_received_mask(): - frame_buffer.recv_mask() - mask = frame_buffer.mask - - # Payload - payload = frame_buffer.recv_strict(length) - if has_mask: - payload = ABNF.mask(mask, payload) - - # Reset for next frame - frame_buffer.clear() - - frame = ABNF(fin, rsv1, rsv2, rsv3, opcode, has_mask, payload) - frame.validate(self.skip_utf8_validation) - - return frame + return self.frame_buffer.recv_frame() def send_close(self, status=STATUS_NORMAL, reason=six.b("")): """