- refactoring.

- pack some fields to object.
This commit is contained in:
liris
2014-05-30 14:28:27 +09:00
parent 048d683da1
commit 22f92961d4

View File

@@ -225,6 +225,69 @@ _HEADERS_TO_CHECK = {
}
class _FrameBuffer(object):
_HEADER_MASK_INDEX = 5
_HEADER_LENGHT_INDEX = 6
def __init__(self):
self.clear()
def clear(self):
self.header = None
self.length = None
self.mask = None
def has_received_header(self):
return self.header is None
def recv_header(self, recv_fn):
header = recv_fn(2)
b1 = header[0]
if six.PY2:
b1 = ord(b1)
fin = b1 >> 7 & 1
rsv1 = b1 >> 6 & 1
rsv2 = b1 >> 5 & 1
rsv3 = b1 >> 4 & 1
opcode = b1 & 0xf
b2 = header[1]
if six.PY2:
b2 = ord(b2)
has_mask = b2 >> 7 & 1
length_bits = b2 & 0x7f
self.header = (fin, rsv1, rsv2, rsv3, opcode, has_mask, length_bits)
def has_mask(self):
if not self.header:
return False
return self.header[_FrameBuffer._HEADER_MASK_INDEX]
def has_received_length(self):
return self.length is None
def recv_length(self, recv_fn):
bits = self.header[_FrameBuffer._HEADER_LENGHT_INDEX]
length_bits = bits & 0x7f
if length_bits == 0x7e:
v = recv_fn(2)
self.length = struct.unpack("!H", v)[0]
elif length_bits == 0x7f:
v = recv_fn(8)
self.length = struct.unpack("!Q", v)[0]
else:
self.length = length_bits
def has_received_mask(self):
return self.mask is None
def recv_mask(self, recv_fn):
self.mask = recv_fn(4) if self.has_mask() else ""
class WebSocket(object):
@@ -273,9 +336,7 @@ class WebSocket(object):
# bytes of bytes are received.
self._recv_buffer = []
# These buffer over the build-up of a single frame.
self._frame_header = None
self._frame_length = None
self._frame_mask = None
self._frame_buffer = _FrameBuffer()
self._cont_data = None
def fileno(self):
@@ -635,53 +696,29 @@ class WebSocket(object):
return value: ABNF frame object.
"""
frame_buffer = self._frame_buffer
# Header
if self._frame_header is None:
self._frame_header = self._recv_strict(2)
b1 = self._frame_header[0]
if six.PY2:
b1 = ord(b1)
fin = b1 >> 7 & 1
rsv1 = b1 >> 6 & 1
rsv2 = b1 >> 5 & 1
rsv3 = b1 >> 4 & 1
opcode = b1 & 0xf
b2 = self._frame_header[1]
if six.PY2:
b2 = ord(b2)
has_mask = b2 >> 7 & 1
if frame_buffer.has_received_header():
frame_buffer.recv_header(self._recv_strict)
(fin, rsv1, rsv2, rsv3, opcode, has_mask, _) = frame_buffer.header
# Frame length
if self._frame_length is None:
length_bits = b2 & 0x7f
if length_bits == 0x7e:
length_data = self._recv_strict(2)
self._frame_length = struct.unpack("!H", length_data)[0]
elif length_bits == 0x7f:
length_data = self._recv_strict(8)
self._frame_length = struct.unpack("!Q", length_data)[0]
else:
self._frame_length = length_bits
if frame_buffer.has_received_length():
frame_buffer.recv_length(self._recv_strict)
length = frame_buffer.length
# Mask
if self._frame_mask is None:
self._frame_mask = self._recv_strict(4) if has_mask else ""
if frame_buffer.has_received_mask():
frame_buffer.recv_mask(self._recv_strict)
mask = frame_buffer.mask
# Payload
payload = self._recv_strict(self._frame_length)
payload = self._recv_strict(length)
if has_mask:
payload = ABNF.mask(self._frame_mask, payload)
payload = ABNF.mask(mask, payload)
# Reset for next frame
self._frame_header = None
self._frame_length = None
self._frame_mask = None
frame_buffer.clear()
return ABNF(fin, rsv1, rsv2, rsv3, opcode, has_mask, payload)