Merge branch 'master' of https://github.com/ralphbean/websocket-client into ralphbean-master

This commit is contained in:
liris 2014-04-22 07:55:43 +09:00
commit eab5169ec1
5 changed files with 127 additions and 63 deletions

@ -2,6 +2,7 @@
import argparse
import code
import six
import sys
import threading
import websocket
@ -45,11 +46,15 @@ class InteractiveConsole(code.InteractiveConsole):
sys.stdout.flush()
def raw_input(self, prompt):
line = raw_input(prompt)
if ENCODING and ENCODING != "utf-8" and not isinstance(line, unicode):
if six.PY3:
line = input(prompt)
else:
line = raw_input(prompt)
if ENCODING and ENCODING != "utf-8" and not isinstance(line, six.text_type):
line = line.decode(ENCODING).encode("utf-8")
elif isinstance(line, unicode):
line = encode("utf-8")
elif isinstance(line, six.text_type):
line = line.encode("utf-8")
return line

@ -1,3 +1,4 @@
from __future__ import print_function
import websocket
if __name__ == "__main__":

@ -25,7 +25,10 @@ setup(
],
keywords='websockets',
scripts=["bin/wsdump.py"],
install_requires=['backports.ssl_match_hostname'],
install_requires=[
'backports.ssl_match_hostname',
'six',
],
packages=["tests", "websocket"],
package_data={
'tests': ['data/*.txt'],

@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
#
import six
import sys
sys.path[0:0] = [""]
@ -63,7 +64,7 @@ class HeaderSockMock(SockMock):
def __init__(self, fname):
SockMock.__init__(self)
path = os.path.join(os.path.dirname(__file__), fname)
self.add_packet(open(path).read())
self.add_packet(open(path).read().encode('utf-8'))
class WebSocketTest(unittest.TestCase):
@ -163,7 +164,7 @@ class WebSocketTest(unittest.TestCase):
def testWSKey(self):
key = ws._create_sec_websocket_key()
self.assert_(key != 24)
self.assert_("¥n" not in key)
self.assert_(six.u("¥n") not in key)
def testWsUtils(self):
sock = ws.WebSocket()
@ -210,57 +211,59 @@ class WebSocketTest(unittest.TestCase):
sock.set_mask_key(create_mask_key)
s = sock.sock = HeaderSockMock("data/header01.txt")
sock.send("Hello")
self.assertEquals(s.sent[0], "\x81\x85abcd)\x07\x0f\x08\x0e")
self.assertEquals(s.sent[0], six.b("\x81\x85abcd)\x07\x0f\x08\x0e"))
sock.send("こんにちは")
self.assertEquals(s.sent[1], "\x81\x8fabcd\x82\xe3\xf0\x87\xe3\xf1\x80\xe5\xca\x81\xe2\xc5\x82\xe3\xcc")
self.assertEquals(s.sent[1], six.b("\x81\x8fabcd\x82\xe3\xf0\x87\xe3\xf1\x80\xe5\xca\x81\xe2\xc5\x82\xe3\xcc"))
sock.send(u"こんにちは")
self.assertEquals(s.sent[1], "\x81\x8fabcd\x82\xe3\xf0\x87\xe3\xf1\x80\xe5\xca\x81\xe2\xc5\x82\xe3\xcc")
self.assertEquals(s.sent[1], six.b("\x81\x8fabcd\x82\xe3\xf0\x87\xe3\xf1\x80\xe5\xca\x81\xe2\xc5\x82\xe3\xcc"))
def testRecv(self):
# TODO: add longer frame data
sock = ws.WebSocket()
s = sock.sock = SockMock()
s.add_packet("\x81\x8fabcd\x82\xe3\xf0\x87\xe3\xf1\x80\xe5\xca\x81\xe2\xc5\x82\xe3\xcc")
something = six.b("\x81\x8fabcd\x82\xe3\xf0\x87\xe3\xf1\x80\xe5\xca\x81\xe2\xc5\x82\xe3\xcc")
s.add_packet(something)
data = sock.recv()
self.assertEquals(data, "こんにちは")
data = data.decode('utf-8')
self.assertEquals(data, u"こんにちは")
s.add_packet("\x81\x85abcd)\x07\x0f\x08\x0e")
s.add_packet(six.b("\x81\x85abcd)\x07\x0f\x08\x0e"))
data = sock.recv()
self.assertEquals(data, "Hello")
self.assertEquals(data.decode('utf-8'), "Hello")
def testInternalRecvStrict(self):
sock = ws.WebSocket()
s = sock.sock = SockMock()
s.add_packet("foo")
s.add_packet(six.b("foo"))
s.add_packet(socket.timeout())
s.add_packet("bar")
s.add_packet(six.b("bar"))
s.add_packet(SSLError("The read operation timed out"))
s.add_packet("baz")
s.add_packet(six.b("baz"))
with self.assertRaises(ws.WebSocketTimeoutException):
data = sock._recv_strict(9)
with self.assertRaises(ws.WebSocketTimeoutException):
with self.assertRaises(SSLError):
data = sock._recv_strict(9)
data = sock._recv_strict(9)
self.assertEquals(data, "foobarbaz")
self.assertEquals(data.decode('utf-8'), "foobarbaz")
with self.assertRaises(ws.WebSocketConnectionClosedException):
data = sock._recv_strict(1)
def testRecvTimeout(self):
sock = ws.WebSocket()
s = sock.sock = SockMock()
s.add_packet("\x81")
s.add_packet(six.b("\x81"))
s.add_packet(socket.timeout())
s.add_packet("\x8dabcd\x29\x07\x0f\x08\x0e")
s.add_packet(six.b("\x8dabcd\x29\x07\x0f\x08\x0e"))
s.add_packet(socket.timeout())
s.add_packet("\x4e\x43\x33\x0e\x10\x0f\x00\x40")
s.add_packet(six.b("\x4e\x43\x33\x0e\x10\x0f\x00\x40"))
with self.assertRaises(ws.WebSocketTimeoutException):
data = sock.recv()
with self.assertRaises(ws.WebSocketTimeoutException):
data = sock.recv()
data = sock.recv()
self.assertEquals(data, "Hello, World!")
self.assertEquals(data.decode('utf-8'), "Hello, World!")
with self.assertRaises(ws.WebSocketConnectionClosedException):
data = sock.recv()
@ -268,11 +271,11 @@ class WebSocketTest(unittest.TestCase):
sock = ws.WebSocket()
s = sock.sock = SockMock()
# OPCODE=TEXT, FIN=0, MSG="Brevity is "
s.add_packet("\x01\x8babcd#\x10\x06\x12\x08\x16\x1aD\x08\x11C")
s.add_packet(six.b("\x01\x8babcd#\x10\x06\x12\x08\x16\x1aD\x08\x11C"))
# OPCODE=CONT, FIN=1, MSG="the soul of wit"
s.add_packet("\x80\x8fabcd\x15\n\x06D\x12\r\x16\x08A\r\x05D\x16\x0b\x17")
s.add_packet(six.b("\x80\x8fabcd\x15\n\x06D\x12\r\x16\x08A\r\x05D\x16\x0b\x17"))
data = sock.recv()
self.assertEqual(data, "Brevity is the soul of wit")
self.assertEqual(data.decode('utf-8'), "Brevity is the soul of wit")
with self.assertRaises(ws.WebSocketConnectionClosedException):
sock.recv()
@ -280,23 +283,24 @@ class WebSocketTest(unittest.TestCase):
sock = ws.WebSocket()
s = sock.sock = SockMock()
# OPCODE=CONT, FIN=1, MSG="the soul of wit"
s.add_packet("\x80\x8fabcd\x15\n\x06D\x12\r\x16\x08A\r\x05D\x16\x0b\x17")
s.add_packet(six.b("\x80\x8fabcd\x15\n\x06D\x12\r\x16\x08A\r\x05D\x16\x0b\x17"))
self.assertRaises(ws.WebSocketException, sock.recv)
def testRecvWithProlongedFragmentation(self):
sock = ws.WebSocket()
s = sock.sock = SockMock()
# OPCODE=TEXT, FIN=0, MSG="Once more unto the breach, "
s.add_packet("\x01\x9babcd.\x0c\x00\x01A\x0f\x0c\x16\x04B\x16\n\x15" \
"\rC\x10\t\x07C\x06\x13\x07\x02\x07\tNC")
s.add_packet(six.b("\x01\x9babcd.\x0c\x00\x01A\x0f\x0c\x16\x04B\x16\n\x15" \
"\rC\x10\t\x07C\x06\x13\x07\x02\x07\tNC"))
# OPCODE=CONT, FIN=0, MSG="dear friends, "
s.add_packet("\x00\x8eabcd\x05\x07\x02\x16A\x04\x11\r\x04\x0c\x07" \
"\x17MB")
s.add_packet(six.b("\x00\x8eabcd\x05\x07\x02\x16A\x04\x11\r\x04\x0c\x07" \
"\x17MB"))
# OPCODE=CONT, FIN=1, MSG="once more"
s.add_packet("\x80\x89abcd\x0e\x0c\x00\x01A\x0f\x0c\x16\x04")
s.add_packet(six.b("\x80\x89abcd\x0e\x0c\x00\x01A\x0f\x0c\x16\x04"))
data = sock.recv()
self.assertEqual(data, "Once more unto the breach, dear friends, " \
"once more")
self.assertEqual(
data.decode('utf-8'),
"Once more unto the breach, dear friends, once more")
with self.assertRaises(ws.WebSocketConnectionClosedException):
sock.recv()
@ -305,18 +309,19 @@ class WebSocketTest(unittest.TestCase):
sock.set_mask_key(create_mask_key)
s = sock.sock = SockMock()
# OPCODE=TEXT, FIN=0, MSG="Too much "
s.add_packet("\x01\x89abcd5\r\x0cD\x0c\x17\x00\x0cA")
s.add_packet(six.b("\x01\x89abcd5\r\x0cD\x0c\x17\x00\x0cA"))
# OPCODE=PING, FIN=1, MSG="Please PONG this"
s.add_packet("\x89\x90abcd1\x0e\x06\x05\x12\x07C4.,$D\x15\n\n\x17")
s.add_packet(six.b("\x89\x90abcd1\x0e\x06\x05\x12\x07C4.,$D\x15\n\n\x17"))
# OPCODE=CONT, FIN=1, MSG="of a good thing"
s.add_packet("\x80\x8fabcd\x0e\x04C\x05A\x05\x0c\x0b\x05B\x17\x0c" \
"\x08\x0c\x04")
s.add_packet(six.b("\x80\x8fabcd\x0e\x04C\x05A\x05\x0c\x0b\x05B\x17\x0c" \
"\x08\x0c\x04"))
data = sock.recv()
self.assertEqual(data, "Too much of a good thing")
self.assertEqual(data.decode('utf-8'), "Too much of a good thing")
with self.assertRaises(ws.WebSocketConnectionClosedException):
sock.recv()
self.assertEqual(s.sent[0], "\x8a\x90abcd1\x0e\x06\x05\x12\x07C4.,$D" \
"\x15\n\n\x17")
self.assertEqual(
s.sent[0],
six.b("\x8a\x90abcd1\x0e\x06\x05\x12\x07C4.,$D\x15\n\n\x17"))
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def testWebSocket(self):

@ -18,8 +18,10 @@ Copyright (C) 2010 Hiroki Ohtani(liris)
Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
"""
from __future__ import print_function
import six
import socket
try:
@ -34,7 +36,13 @@ except ImportError:
HAVE_SSL = False
from urlparse import urlparse
try:
# python 3
from urllib.parse import urlparse
except ImportError:
# python 2
from urlparse import urlparse
import os
import array
import struct
@ -215,7 +223,7 @@ def create_connection(url, timeout=None, **options):
return websock
_MAX_INTEGER = (1 << 32) -1
_AVAILABLE_KEY_CHARS = range(0x21, 0x2f + 1) + range(0x3a, 0x7e + 1)
_AVAILABLE_KEY_CHARS = list(range(0x21, 0x2f + 1)) + list(range(0x3a, 0x7e + 1))
_MAX_CHAR_BYTE = (1<<8) -1
# ref. Websocket gets an update, and it breaks stuff.
@ -224,7 +232,7 @@ _MAX_CHAR_BYTE = (1<<8) -1
def _create_sec_websocket_key():
uid = uuid.uuid4()
return base64.encodestring(uid.bytes).strip()
return base64.encodestring(uid.bytes).decode('utf-8').strip()
_HEADERS_TO_CHECK = {
@ -300,7 +308,7 @@ class ABNF(object):
fin: fin flag. if set to 0, create continue fragmentation.
"""
if opcode == ABNF.OPCODE_TEXT and isinstance(data, unicode):
if opcode == ABNF.OPCODE_TEXT and isinstance(data, six.text_type):
data = data.encode("utf-8")
# mask must be set if send data from client
return ABNF(fin, 0, 0, 0, opcode, 1, data)
@ -329,6 +337,8 @@ class ABNF(object):
frame_header += chr(self.mask << 7 | 0x7f)
frame_header += struct.pack("!Q", length)
frame_header = six.b(frame_header)
if not self.mask:
return frame_header + self.data
else:
@ -337,7 +347,11 @@ class ABNF(object):
def _get_masked(self, mask_key):
s = ABNF.mask(mask_key, self.data)
return mask_key + "".join(s)
if isinstance(mask_key, six.text_type):
mask_key = mask_key.encode('utf-8')
return mask_key + s
@staticmethod
def mask(mask_key, data):
@ -348,9 +362,16 @@ class ABNF(object):
data: data to mask/unmask.
"""
if isinstance(mask_key, six.text_type):
mask_key = six.b(mask_key)
if isinstance(data, six.text_type):
data = six.b(data)
_m = array.array("B", mask_key)
_d = array.array("B", data)
for i in xrange(len(_d)):
for i in range(len(_d)):
_d[i] ^= _m[i % 4]
return _d.tostring()
@ -519,7 +540,7 @@ class WebSocket(object):
self.connected = True
def _validate_header(self, headers, key):
for k, v in _HEADERS_TO_CHECK.iteritems():
for k, v in _HEADERS_TO_CHECK.items():
r = headers.get(k, None)
if not r:
return False
@ -532,7 +553,10 @@ class WebSocket(object):
return False
result = result.lower()
value = key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
if isinstance(result, six.text_type):
result = result.encode('utf-8')
value = (key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11").encode('utf-8')
hashed = base64.encodestring(hashlib.sha1(value).digest()).strip().lower()
return hashed == result
@ -544,7 +568,8 @@ class WebSocket(object):
while True:
line = self._recv_line()
if line == "\r\n":
line = line.decode('utf-8')
if line == "\r\n" or line == "\n":
break
line = line.strip()
if traceEnabled:
@ -575,6 +600,7 @@ class WebSocket(object):
opcode: operation code to send. Please see OPCODE_XXX.
"""
frame = ABNF.create_frame(payload, opcode)
return self.send_frame(frame)
@ -715,17 +741,28 @@ class WebSocket(object):
return value: ABNF frame object.
"""
# Header
if self._frame_header is None:
self._frame_header = self._recv_strict(2)
b1 = ord(self._frame_header[0])
b1 = self._frame_header[0]
if six.PY2:
b1 = ord(b1)
fin = b1 >> 7 & 1
rsv1 = b1 >> 6 & 1
rsv2 = b1 >> 5 & 1
rsv3 = b1 >> 4 & 1
opcode = b1 & 0xf
b2 = ord(self._frame_header[1])
b2 = self._frame_header[1]
if six.PY2:
b2 = ord(b2)
has_mask = b2 >> 7 & 1
# Frame length
if self._frame_length is None:
length_bits = b2 & 0x7f
@ -737,13 +774,16 @@ class WebSocket(object):
self._frame_length = struct.unpack("!Q", length_data)[0]
else:
self._frame_length = length_bits
# Mask
if self._frame_mask is None:
self._frame_mask = self._recv_strict(4) if has_mask else ""
# Payload
payload = self._recv_strict(self._frame_length)
if has_mask:
payload = ABNF.mask(self._frame_mask, payload)
# Reset for next frame
self._frame_header = None
self._frame_length = None
@ -798,13 +838,19 @@ class WebSocket(object):
self.sock.close()
def _send(self, data):
if isinstance(data, six.text_type):
data = data.encode('utf-8')
try:
return self.sock.send(data)
except socket.timeout as e:
raise WebSocketTimeoutException(e.message)
message = getattr(e, 'strerror', getattr(e, 'message', ''))
raise WebSocketTimeoutException(message)
except Exception as e:
if "timed out" in e.message:
raise WebSocketTimeoutException(e.message)
message = getattr(e, 'strerror', getattr(e, 'message', ''))
if "timed out" in message:
raise WebSocketTimeoutException(message)
else:
raise
@ -812,24 +858,27 @@ class WebSocket(object):
try:
bytes = self.sock.recv(bufsize)
except socket.timeout as e:
raise WebSocketTimeoutException(e.message)
message = getattr(e, 'strerror', getattr(e, 'message', ''))
raise WebSocketTimeoutException(message)
except SSLError as e:
if e.message == "The read operation timed out":
raise WebSocketTimeoutException(e.message)
message = getattr(e, 'strerror', getattr(e, 'message', ''))
if message == "The read operation timed out":
raise WebSocketTimeoutException(message)
else:
raise
if not bytes:
raise WebSocketConnectionClosedException()
return bytes
def _recv_strict(self, bufsize):
shortage = bufsize - sum(len(x) for x in self._recv_buffer)
while shortage > 0:
bytes = self._recv(shortage)
self._recv_buffer.append(bytes)
shortage -= len(bytes)
unified = "".join(self._recv_buffer)
unified = six.b("").join(self._recv_buffer)
if shortage == 0:
self._recv_buffer = []
return unified
@ -843,9 +892,9 @@ class WebSocket(object):
while True:
c = self._recv(1)
line.append(c)
if c == "\n":
if c == six.b("\n"):
break
return "".join(line)
return six.b("").join(line)
class WebSocketApp(object):
@ -903,6 +952,7 @@ class WebSocketApp(object):
data: message to send. If you set opcode to OPCODE_TEXT, data must be utf-8 string or unicode.
opcode: operation code of data. default is OPCODE_TEXT.
"""
if self.sock.send(data, opcode) == 0:
raise WebSocketConnectionClosedException()
@ -959,7 +1009,7 @@ class WebSocketApp(object):
if ping_timeout and self.last_ping_tm and time.time() - self.last_ping_tm > ping_timeout:
self.last_ping_tm = 0
raise WebSocketTimeoutException()
if r:
op_code, frame = self.sock.recv_data_frame(True)
if op_code == ABNF.OPCODE_CLOSE: