This commit is contained in:
liris
2016-05-10 08:52:30 +09:00
18 changed files with 276 additions and 219 deletions

View File

@@ -2,7 +2,7 @@ ChangeLog
============ ============
- 0.37.0 - 0.37.0
- fixed fialer that `websocket.create_connection` does not accept `origin` as a parameter (#246 ) - fixed failure that `websocket.create_connection` does not accept `origin` as a parameter (#246 )
- 0.36.0 - 0.36.0
- added support for using custom connection class (#235) - added support for using custom connection class (#235)
@@ -90,7 +90,7 @@ ChangeLog
- 0.24.0 - 0.24.0
- Supporting http-basic auth in WebSocketApp (#143) - Supporting http-basic auth in WebSocketApp (#143)
- fix failer of test.testInternalRecvStrict(#141) - fix failure of test.testInternalRecvStrict(#141)
- skip utf8 validation by skip_utf8_validation argument (#137) - skip utf8 validation by skip_utf8_validation argument (#137)
- WebsocketProxyException will be raised if we got error about proxy.(#138) - WebsocketProxyException will be raised if we got error about proxy.(#138)

View File

@@ -2,15 +2,18 @@
import argparse import argparse
import code import code
import six
import sys import sys
import threading import threading
import time import time
import websocket
import six
from six.moves.urllib.parse import urlparse from six.moves.urllib.parse import urlparse
import websocket
try: try:
import readline import readline
except: except ImportError:
pass pass
@@ -27,15 +30,17 @@ ENCODING = get_encoding()
class VAction(argparse.Action): class VAction(argparse.Action):
def __call__(self, parser, args, values, option_string=None): def __call__(self, parser, args, values, option_string=None):
if values==None: if values is None:
values = "1" values = "1"
try: try:
values = int(values) values = int(values)
except ValueError: except ValueError:
values = values.count("v")+1 values = values.count("v") + 1
setattr(args, self.dest, values) setattr(args, self.dest, values)
def parse_args(): def parse_args():
parser = argparse.ArgumentParser(description="WebSocket Simple Dump Tool") parser = argparse.ArgumentParser(description="WebSocket Simple Dump Tool")
parser.add_argument("url", metavar="ws_url", parser.add_argument("url", metavar="ws_url",
@@ -63,7 +68,9 @@ def parse_args():
return parser.parse_args() return parser.parse_args()
class RawInput():
class RawInput:
def raw_input(self, prompt): def raw_input(self, prompt):
if six.PY3: if six.PY3:
line = input(prompt) line = input(prompt)
@@ -77,7 +84,9 @@ class RawInput():
return line return line
class InteractiveConsole(RawInput, code.InteractiveConsole): class InteractiveConsole(RawInput, code.InteractiveConsole):
def write(self, data): def write(self, data):
sys.stdout.write("\033[2K\033[E") sys.stdout.write("\033[2K\033[E")
# sys.stdout.write("\n") # sys.stdout.write("\n")
@@ -88,7 +97,9 @@ class InteractiveConsole(RawInput, code.InteractiveConsole):
def read(self): def read(self):
return self.raw_input("> ") return self.raw_input("> ")
class NonInteractive(RawInput): class NonInteractive(RawInput):
def write(self, data): def write(self, data):
sys.stdout.write(data) sys.stdout.write(data)
sys.stdout.write("\n") sys.stdout.write("\n")
@@ -97,23 +108,24 @@ class NonInteractive(RawInput):
def read(self): def read(self):
return self.raw_input("") return self.raw_input("")
def main(): def main():
start_time = time.time() start_time = time.time()
args = parse_args() args = parse_args()
if args.verbose > 1: if args.verbose > 1:
websocket.enableTrace(True) websocket.enableTrace(True)
options = {} options = {}
if (args.proxy): if args.proxy:
p = urlparse(args.proxy) p = urlparse(args.proxy)
options["http_proxy_host"] = p.hostname options["http_proxy_host"] = p.hostname
options["http_proxy_port"] = p.port options["http_proxy_port"] = p.port
if (args.origin): if args.origin:
options["origin"] = args.origin options["origin"] = args.origin
if (args.subprotocols): if args.subprotocols:
options["subprotocols"] = args.subprotocols options["subprotocols"] = args.subprotocols
opts = {} opts = {}
if (args.nocert): if args.nocert:
opts = { "cert_reqs": websocket.ssl.CERT_NONE, "check_hostname": False } opts = {"cert_reqs": websocket.ssl.CERT_NONE, "check_hostname": False}
ws = websocket.create_connection(args.url, sslopt=opts, **options) ws = websocket.create_connection(args.url, sslopt=opts, **options)
if args.raw: if args.raw:
console = NonInteractive() console = NonInteractive()
@@ -125,21 +137,20 @@ def main():
try: try:
frame = ws.recv_frame() frame = ws.recv_frame()
except websocket.WebSocketException: except websocket.WebSocketException:
return (websocket.ABNF.OPCODE_CLOSE, None) return websocket.ABNF.OPCODE_CLOSE, None
if not frame: if not frame:
raise websocket.WebSocketException("Not a valid frame %s" % frame) raise websocket.WebSocketException("Not a valid frame %s" % frame)
elif frame.opcode in OPCODE_DATA: elif frame.opcode in OPCODE_DATA:
return (frame.opcode, frame.data) return frame.opcode, frame.data
elif frame.opcode == websocket.ABNF.OPCODE_CLOSE: elif frame.opcode == websocket.ABNF.OPCODE_CLOSE:
ws.send_close() ws.send_close()
return (frame.opcode, None) return frame.opcode, None
elif frame.opcode == websocket.ABNF.OPCODE_PING: elif frame.opcode == websocket.ABNF.OPCODE_PING:
ws.pong(frame.data) ws.pong(frame.data)
return frame.opcode, frame.data return frame.opcode, frame.data
return frame.opcode, frame.data return frame.opcode, frame.data
def recv_ws(): def recv_ws():
while True: while True:
opcode, data = recv() opcode, data = recv()
@@ -152,7 +163,7 @@ def main():
msg = "%s: %s" % (websocket.ABNF.OPCODE_MAP.get(opcode), data) msg = "%s: %s" % (websocket.ABNF.OPCODE_MAP.get(opcode), data)
if msg is not None: if msg is not None:
if (args.timings): if args.timings:
console.write(str(time.time() - start_time) + ": " + msg) console.write(str(time.time() - start_time) + ": " + msg)
else: else:
console.write(msg) console.write(msg)

View File

@@ -1,13 +1,9 @@
#!/usr/bin/env python #!/usr/bin/env python
import websocket
import json import json
import traceback import traceback
import six
import websocket
SERVER = 'ws://127.0.0.1:8642' SERVER = 'ws://127.0.0.1:8642'
AGENT = 'py-websockets-client' AGENT = 'py-websockets-client'
@@ -19,28 +15,28 @@ ws.close()
for case in range(1, count+1): for case in range(1, count+1):
url = SERVER + '/runCase?case={0}&agent={1}'.format(case, AGENT) url = SERVER + '/runCase?case={0}&agent={1}'.format(case, AGENT)
status = websocket.STATUS_NORMAL status = websocket.STATUS_NORMAL
try: try:
ws = websocket.create_connection(url) ws = websocket.create_connection(url)
while True: while True:
opcode, msg = ws.recv_data() opcode, msg = ws.recv_data()
if opcode == websocket.ABNF.OPCODE_TEXT: if opcode == websocket.ABNF.OPCODE_TEXT:
msg.decode("utf-8") msg.decode("utf-8")
if opcode in (websocket.ABNF.OPCODE_TEXT, websocket.ABNF.OPCODE_BINARY): if opcode in (websocket.ABNF.OPCODE_TEXT, websocket.ABNF.OPCODE_BINARY):
ws.send(msg, opcode) ws.send(msg, opcode)
except UnicodeDecodeError: except UnicodeDecodeError:
# this case is ok. # this case is ok.
status = websocket.STATUS_PROTOCOL_ERROR status = websocket.STATUS_PROTOCOL_ERROR
except websocket.WebSocketProtocolException: except websocket.WebSocketProtocolException:
status = websocket.STATUS_PROTOCOL_ERROR status = websocket.STATUS_PROTOCOL_ERROR
except websocket.WebSocketPayloadException: except websocket.WebSocketPayloadException:
status = websocket.STATUS_INVALID_PAYLOAD status = websocket.STATUS_INVALID_PAYLOAD
except Exception as e: except Exception as e:
# status = websocket.STATUS_PROTOCOL_ERROR # status = websocket.STATUS_PROTOCOL_ERROR
print(traceback.format_exc()) print(traceback.format_exc())
finally: finally:
ws.close(status) ws.close(status)
print("Ran {} test cases.".format(case)) print("Ran {} test cases.".format(case))
url = SERVER + '/updateReports?agent={0}'.format(AGENT) url = SERVER + '/updateReports?agent={0}'.format(AGENT)

View File

@@ -1,7 +1,7 @@
import websocket import websocket
try: try:
import thread import thread
except ImportError: #TODO use Threading instead of _thread in python3 except ImportError: # TODO use Threading instead of _thread in python3
import _thread as thread import _thread as thread
import time import time
import sys import sys
@@ -41,8 +41,8 @@ if __name__ == "__main__":
else: else:
host = sys.argv[1] host = sys.argv[1]
ws = websocket.WebSocketApp(host, ws = websocket.WebSocketApp(host,
on_message = on_message, on_message=on_message,
on_error = on_error, on_error=on_error,
on_close = on_close) on_close=on_close)
ws.on_open = on_open ws.on_open = on_open
ws.run_forever() ws.run_forever()

View File

@@ -2,12 +2,12 @@ from setuptools import setup
import sys import sys
VERSION = "0.37.0" VERSION = "0.37.0"
NAME="websocket_client" NAME = "websocket_client"
install_requires = ["six"] install_requires = ["six"]
tests_require = [] tests_require = []
if sys.version_info[0] == 2: if sys.version_info[0] == 2:
if sys.version_info[1] < 7 or (sys.version_info[1] == 7 and sys.version_info[2]< 9): if sys.version_info[1] < 7 or (sys.version_info[1] == 7 and sys.version_info[2] < 9):
install_requires.append('backports.ssl_match_hostname') install_requires.append('backports.ssl_match_hostname')
if sys.version_info[1] < 7: if sys.version_info[1] < 7:
tests_require.append('unittest2==0.8.0') tests_require.append('unittest2==0.8.0')

View File

@@ -19,7 +19,11 @@ Copyright (C) 2010 Hiroki Ohtani(liris)
Boston, MA 02110-1335 USA Boston, MA 02110-1335 USA
""" """
from ._core import * from ._abnf import *
from ._app import WebSocketApp from ._app import WebSocketApp
from ._core import *
from ._exceptions import *
from ._logging import *
from ._socket import *
__version__ = "0.37.0" __version__ = "0.37.0"

View File

@@ -19,10 +19,12 @@ Copyright (C) 2010 Hiroki Ohtani(liris)
Boston, MA 02110-1335 USA Boston, MA 02110-1335 USA
""" """
import six
import array import array
import struct
import os import os
import struct
import six
from ._exceptions import * from ._exceptions import *
from ._utils import validate_utf8 from ._utils import validate_utf8
@@ -44,6 +46,22 @@ except ImportError:
else: else:
return _d.tostring() return _d.tostring()
__all__ = [
'ABNF', 'continuous_frame', 'frame_buffer',
'STATUS_NORMAL',
'STATUS_GOING_AWAY',
'STATUS_PROTOCOL_ERROR',
'STATUS_UNSUPPORTED_DATA_TYPE',
'STATUS_STATUS_NOT_AVAILABLE',
'STATUS_ABNORMAL_CLOSED',
'STATUS_INVALID_PAYLOAD',
'STATUS_POLICY_VIOLATION',
'STATUS_MESSAGE_TOO_BIG',
'STATUS_INVALID_EXTENSION',
'STATUS_UNEXPECTED_CONDITION',
'STATUS_TLS_HANDSHAKE_ERROR',
]
# closing frame status codes. # closing frame status codes.
STATUS_NORMAL = 1000 STATUS_NORMAL = 1000
STATUS_GOING_AWAY = 1001 STATUS_GOING_AWAY = 1001
@@ -68,7 +86,8 @@ VALID_CLOSE_STATUS = (
STATUS_MESSAGE_TOO_BIG, STATUS_MESSAGE_TOO_BIG,
STATUS_INVALID_EXTENSION, STATUS_INVALID_EXTENSION,
STATUS_UNEXPECTED_CONDITION, STATUS_UNEXPECTED_CONDITION,
) )
class ABNF(object): class ABNF(object):
""" """
@@ -78,16 +97,16 @@ class ABNF(object):
""" """
# operation code values. # operation code values.
OPCODE_CONT = 0x0 OPCODE_CONT = 0x0
OPCODE_TEXT = 0x1 OPCODE_TEXT = 0x1
OPCODE_BINARY = 0x2 OPCODE_BINARY = 0x2
OPCODE_CLOSE = 0x8 OPCODE_CLOSE = 0x8
OPCODE_PING = 0x9 OPCODE_PING = 0x9
OPCODE_PONG = 0xa OPCODE_PONG = 0xa
# available operation code value tuple # available operation code value tuple
OPCODES = (OPCODE_CONT, OPCODE_TEXT, OPCODE_BINARY, OPCODE_CLOSE, OPCODES = (OPCODE_CONT, OPCODE_TEXT, OPCODE_BINARY, OPCODE_CLOSE,
OPCODE_PING, OPCODE_PONG) OPCODE_PING, OPCODE_PONG)
# opcode human readable string # opcode human readable string
OPCODE_MAP = { OPCODE_MAP = {
@@ -97,10 +116,10 @@ class ABNF(object):
OPCODE_CLOSE: "close", OPCODE_CLOSE: "close",
OPCODE_PING: "ping", OPCODE_PING: "ping",
OPCODE_PONG: "pong" OPCODE_PONG: "pong"
} }
# data length threshold. # data length threshold.
LENGTH_7 = 0x7e LENGTH_7 = 0x7e
LENGTH_16 = 1 << 16 LENGTH_16 = 1 << 16
LENGTH_63 = 1 << 63 LENGTH_63 = 1 << 63
@@ -116,7 +135,7 @@ class ABNF(object):
self.rsv3 = rsv3 self.rsv3 = rsv3
self.opcode = opcode self.opcode = opcode
self.mask = mask self.mask = mask
if data == None: if data is None:
data = "" data = ""
self.data = data self.data = data
self.get_mask_key = os.urandom self.get_mask_key = os.urandom
@@ -144,17 +163,19 @@ class ABNF(object):
if l > 2 and not skip_utf8_validation and not validate_utf8(self.data[2:]): if l > 2 and not skip_utf8_validation and not validate_utf8(self.data[2:]):
raise WebSocketProtocolException("Invalid close frame.") raise WebSocketProtocolException("Invalid close frame.")
code = 256*six.byte2int(self.data[0:1]) + six.byte2int(self.data[1:2]) code = 256 * \
six.byte2int(self.data[0:1]) + six.byte2int(self.data[1:2])
if not self._is_valid_close_status(code): if not self._is_valid_close_status(code):
raise WebSocketProtocolException("Invalid close opcode.") raise WebSocketProtocolException("Invalid close opcode.")
def _is_valid_close_status(self, code): @staticmethod
return code in VALID_CLOSE_STATUS or (3000 <= code <5000) def _is_valid_close_status(code):
return code in VALID_CLOSE_STATUS or (3000 <= code < 5000)
def __str__(self): def __str__(self):
return "fin=" + str(self.fin) \ return "fin=" + str(self.fin) \
+ " opcode=" + str(self.opcode) \ + " opcode=" + str(self.opcode) \
+ " data=" + str(self.data) + " data=" + str(self.data)
@staticmethod @staticmethod
def create_frame(data, opcode, fin=1): def create_frame(data, opcode, fin=1):
@@ -224,7 +245,7 @@ class ABNF(object):
data: data to mask/unmask. data: data to mask/unmask.
""" """
if data == None: if data is None:
data = "" data = ""
if isinstance(mask_key, six.text_type): if isinstance(mask_key, six.text_type):
@@ -237,9 +258,10 @@ class ABNF(object):
_d = array.array("B", data) _d = array.array("B", data)
return _mask(_m, _d) return _mask(_m, _d)
class frame_buffer(object): class frame_buffer(object):
_HEADER_MASK_INDEX = 5 _HEADER_MASK_INDEX = 5
_HEADER_LENGHT_INDEX = 6 _HEADER_LENGTH_INDEX = 6
def __init__(self, recv_fn, skip_utf8_validation): def __init__(self, recv_fn, skip_utf8_validation):
self.recv = recv_fn self.recv = recv_fn
@@ -255,7 +277,7 @@ class frame_buffer(object):
self.mask = None self.mask = None
def has_received_header(self): def has_received_header(self):
return self.header is None return self.header is None
def recv_header(self): def recv_header(self):
header = self.recv_strict(2) header = self.recv_strict(2)
@@ -284,12 +306,11 @@ class frame_buffer(object):
return False return False
return self.header[frame_buffer._HEADER_MASK_INDEX] return self.header[frame_buffer._HEADER_MASK_INDEX]
def has_received_length(self): def has_received_length(self):
return self.length is None return self.length is None
def recv_length(self): def recv_length(self):
bits = self.header[frame_buffer._HEADER_LENGHT_INDEX] bits = self.header[frame_buffer._HEADER_LENGTH_INDEX]
length_bits = bits & 0x7f length_bits = bits & 0x7f
if length_bits == 0x7e: if length_bits == 0x7e:
v = self.recv_strict(2) v = self.recv_strict(2)
@@ -342,10 +363,11 @@ class frame_buffer(object):
# fragmenting the heap -- the number of bytes recv() actually # fragmenting the heap -- the number of bytes recv() actually
# reads is limited by socket buffer and is relatively small, # reads is limited by socket buffer and is relatively small,
# yet passing large numbers repeatedly causes lots of large # yet passing large numbers repeatedly causes lots of large
# buffers allocated and then shrunk, which results in fragmentation. # buffers allocated and then shrunk, which results in
bytes = self.recv(min(16384, shortage)) # fragmentation.
self.recv_buffer.append(bytes) bytes_ = self.recv(min(16384, shortage))
shortage -= len(bytes) self.recv_buffer.append(bytes_)
shortage -= len(bytes_)
unified = six.b("").join(self.recv_buffer) unified = six.b("").join(self.recv_buffer)
@@ -358,6 +380,7 @@ class frame_buffer(object):
class continuous_frame(object): class continuous_frame(object):
def __init__(self, fire_cont_frame, skip_utf8_validation): def __init__(self, fire_cont_frame, skip_utf8_validation):
self.fire_cont_frame = fire_cont_frame self.fire_cont_frame = fire_cont_frame
self.skip_utf8_validation = skip_utf8_validation self.skip_utf8_validation = skip_utf8_validation
@@ -367,7 +390,8 @@ class continuous_frame(object):
def validate(self, frame): def validate(self, frame):
if not self.recving_frames and frame.opcode == ABNF.OPCODE_CONT: if not self.recving_frames and frame.opcode == ABNF.OPCODE_CONT:
raise WebSocketProtocolException("Illegal frame") raise WebSocketProtocolException("Illegal frame")
if self.recving_frames and frame.opcode in (ABNF.OPCODE_TEXT, ABNF.OPCODE_BINARY): if self.recving_frames and \
frame.opcode in (ABNF.OPCODE_TEXT, ABNF.OPCODE_BINARY):
raise WebSocketProtocolException("Illegal frame") raise WebSocketProtocolException("Illegal frame")
def add(self, frame): def add(self, frame):
@@ -389,6 +413,7 @@ class continuous_frame(object):
self.cont_data = None self.cont_data = None
frame.data = data[1] frame.data = data[1]
if not self.fire_cont_frame and data[0] == ABNF.OPCODE_TEXT and not self.skip_utf8_validation and not validate_utf8(frame.data): if not self.fire_cont_frame and data[0] == ABNF.OPCODE_TEXT and not self.skip_utf8_validation and not validate_utf8(frame.data):
raise WebSocketPayloadException("cannot decode: " + repr(frame.data)) raise WebSocketPayloadException(
"cannot decode: " + repr(frame.data))
return [data[0], frame] return [data[0], frame]

View File

@@ -23,17 +23,18 @@ Copyright (C) 2010 Hiroki Ohtani(liris)
""" """
WebSocketApp provides higher level APIs. WebSocketApp provides higher level APIs.
""" """
import select
import sys
import threading import threading
import time import time
import traceback import traceback
import sys
import select
import six import six
from ._abnf import ABNF
from ._core import WebSocket, getdefaulttimeout from ._core import WebSocket, getdefaulttimeout
from ._exceptions import * from ._exceptions import *
from ._logging import * from ._logging import *
from ._abnf import ABNF
__all__ = ["WebSocketApp"] __all__ = ["WebSocketApp"]
@@ -43,7 +44,8 @@ class WebSocketApp(object):
Higher level of APIs are provided. Higher level of APIs are provided.
The interface is like JavaScript WebSocket object. The interface is like JavaScript WebSocket object.
""" """
def __init__(self, url, header=[],
def __init__(self, url, header=None,
on_open=None, on_message=None, on_error=None, on_open=None, on_message=None, on_error=None,
on_close=None, on_ping=None, on_pong=None, on_close=None, on_ping=None, on_pong=None,
on_cont_message=None, on_cont_message=None,
@@ -87,7 +89,7 @@ class WebSocketApp(object):
subprotocols: array of available sub protocols. default is None. subprotocols: array of available sub protocols. default is None.
""" """
self.url = url self.url = url
self.header = header self.header = header if header is not None else []
self.cookie = cookie self.cookie = cookie
self.on_open = on_open self.on_open = on_open
self.on_message = on_message self.on_message = on_message
@@ -113,7 +115,8 @@ class WebSocketApp(object):
""" """
if not self.sock or self.sock.send(data, opcode) == 0: if not self.sock or self.sock.send(data, opcode) == 0:
raise WebSocketConnectionClosedException("Connection is already closed.") raise WebSocketConnectionClosedException(
"Connection is already closed.")
def close(self): def close(self):
""" """
@@ -168,27 +171,29 @@ class WebSocketApp(object):
close_frame = None close_frame = None
try: try:
self.sock = WebSocket(self.get_mask_key, self.sock = WebSocket(
sockopt=sockopt, sslopt=sslopt, self.get_mask_key, sockopt=sockopt, sslopt=sslopt,
fire_cont_frame=self.on_cont_message and True or False, fire_cont_frame=self.on_cont_message and True or False,
skip_utf8_validation=skip_utf8_validation) skip_utf8_validation=skip_utf8_validation)
self.sock.settimeout(getdefaulttimeout()) self.sock.settimeout(getdefaulttimeout())
self.sock.connect(self.url, header=self.header, cookie=self.cookie, self.sock.connect(
self.url, header=self.header, cookie=self.cookie,
http_proxy_host=http_proxy_host, http_proxy_host=http_proxy_host,
http_proxy_port=http_proxy_port, http_proxy_port=http_proxy_port, http_no_proxy=http_no_proxy,
http_no_proxy=http_no_proxy, http_proxy_auth=http_proxy_auth, http_proxy_auth=http_proxy_auth, subprotocols=self.subprotocols,
subprotocols=self.subprotocols,
host=host, origin=origin) host=host, origin=origin)
self._callback(self.on_open) self._callback(self.on_open)
if ping_interval: if ping_interval:
event = threading.Event() event = threading.Event()
thread = threading.Thread(target=self._send_ping, args=(ping_interval, event)) thread = threading.Thread(
target=self._send_ping, args=(ping_interval, event))
thread.setDaemon(True) thread.setDaemon(True)
thread.start() thread.start()
while self.sock.connected: while self.sock.connected:
r, w, e = select.select((self.sock.sock, ), (), (), ping_timeout) r, w, e = select.select(
(self.sock.sock, ), (), (), ping_timeout)
if not self.keep_running: if not self.keep_running:
break break
@@ -203,8 +208,10 @@ class WebSocketApp(object):
self.last_pong_tm = time.time() self.last_pong_tm = time.time()
self._callback(self.on_pong, frame.data) self._callback(self.on_pong, frame.data)
elif op_code == ABNF.OPCODE_CONT and self.on_cont_message: elif op_code == ABNF.OPCODE_CONT and self.on_cont_message:
self._callback(self.on_data, data, frame.opcode, frame.fin) self._callback(self.on_data, data,
self._callback(self.on_cont_message, frame.data, frame.fin) frame.opcode, frame.fin)
self._callback(self.on_cont_message,
frame.data, frame.fin)
else: else:
data = frame.data data = frame.data
if six.PY3 and frame.opcode == ABNF.OPCODE_TEXT: if six.PY3 and frame.opcode == ABNF.OPCODE_TEXT:
@@ -227,8 +234,9 @@ class WebSocketApp(object):
thread.join() thread.join()
self.keep_running = False self.keep_running = False
self.sock.close() self.sock.close()
self._callback(self.on_close, close_args = self._get_close_args(
*self._get_close_args(close_frame.data if close_frame else None)) close_frame.data if close_frame else None)
self._callback(self.on_close, *close_args)
self.sock = None self.sock = None
def _get_close_args(self, data): def _get_close_args(self, data):
@@ -244,7 +252,7 @@ class WebSocketApp(object):
return [] return []
if data and len(data) >= 2: if data and len(data) >= 2:
code = 256*six.byte2int(data[0:1]) + six.byte2int(data[1:2]) code = 256 * six.byte2int(data[0:1]) + six.byte2int(data[1:2])
reason = data[2:].decode('utf-8') reason = data[2:].decode('utf-8')
return [code, reason] return [code, reason]

View File

@@ -21,28 +21,22 @@ Copyright (C) 2010 Hiroki Ohtani(liris)
""" """
from __future__ import print_function from __future__ import print_function
import six
import socket import socket
if six.PY3:
from base64 import encodebytes as base64encode
else:
from base64 import encodestring as base64encode
import struct import struct
import threading import threading
import six
# websocket modules # websocket modules
from ._exceptions import *
from ._abnf import * from ._abnf import *
from ._exceptions import *
from ._handshake import *
from ._http import *
from ._logging import *
from ._socket import * from ._socket import *
from ._utils import * from ._utils import *
from ._url import *
from ._logging import * __all__ = ['WebSocket', 'create_connection']
from ._http import *
from ._handshake import *
from ._ssl_compat import *
""" """
websocket python client. websocket python client.
@@ -83,7 +77,7 @@ class WebSocket(object):
def __init__(self, get_mask_key=None, sockopt=None, sslopt=None, def __init__(self, get_mask_key=None, sockopt=None, sslopt=None,
fire_cont_frame=False, enable_multithread=False, fire_cont_frame=False, enable_multithread=False,
skip_utf8_validation=False, **options): skip_utf8_validation=False, **_):
""" """
Initialize WebSocket object. Initialize WebSocket object.
""" """
@@ -95,7 +89,8 @@ class WebSocket(object):
self.get_mask_key = get_mask_key self.get_mask_key = get_mask_key
# These buffer over the build-up of a single frame. # These buffer over the build-up of a single frame.
self.frame_buffer = frame_buffer(self._recv, skip_utf8_validation) self.frame_buffer = frame_buffer(self._recv, skip_utf8_validation)
self.cont_frame = continuous_frame(fire_cont_frame, skip_utf8_validation) self.cont_frame = continuous_frame(
fire_cont_frame, skip_utf8_validation)
if enable_multithread: if enable_multithread:
self.lock = threading.Lock() self.lock = threading.Lock()
@@ -329,7 +324,8 @@ class WebSocket(object):
if not frame: if not frame:
# handle error: # handle error:
# 'NoneType' object has no attribute 'opcode' # 'NoneType' object has no attribute 'opcode'
raise WebSocketProtocolException("Not a valid frame %s" % frame) raise WebSocketProtocolException(
"Not a valid frame %s" % frame)
elif frame.opcode in (ABNF.OPCODE_TEXT, ABNF.OPCODE_BINARY, ABNF.OPCODE_CONT): elif frame.opcode in (ABNF.OPCODE_TEXT, ABNF.OPCODE_BINARY, ABNF.OPCODE_CONT):
self.cont_frame.validate(frame) self.cont_frame.validate(frame)
self.cont_frame.add(frame) self.cont_frame.add(frame)
@@ -339,17 +335,18 @@ class WebSocket(object):
elif frame.opcode == ABNF.OPCODE_CLOSE: elif frame.opcode == ABNF.OPCODE_CLOSE:
self.send_close() self.send_close()
return (frame.opcode, frame) return frame.opcode, frame
elif frame.opcode == ABNF.OPCODE_PING: elif frame.opcode == ABNF.OPCODE_PING:
if len(frame.data) < 126: if len(frame.data) < 126:
self.pong(frame.data) self.pong(frame.data)
else: else:
raise WebSocketProtocolException("Ping message is too long") raise WebSocketProtocolException(
"Ping message is too long")
if control_frame: if control_frame:
return (frame.opcode, frame) return frame.opcode, frame
elif frame.opcode == ABNF.OPCODE_PONG: elif frame.opcode == ABNF.OPCODE_PONG:
if control_frame: if control_frame:
return (frame.opcode, frame) return frame.opcode, frame
def recv_frame(self): def recv_frame(self):
""" """
@@ -389,7 +386,8 @@ class WebSocket(object):
try: try:
self.connected = False self.connected = False
self.send(struct.pack('!H', status) + reason, ABNF.OPCODE_CLOSE) self.send(struct.pack('!H', status) +
reason, ABNF.OPCODE_CLOSE)
sock_timeout = self.sock.gettimeout() sock_timeout = self.sock.gettimeout()
self.sock.settimeout(timeout) self.sock.settimeout(timeout)
try: try:
@@ -415,7 +413,7 @@ class WebSocket(object):
self.sock.shutdown(socket.SHUT_RDWR) self.sock.shutdown(socket.SHUT_RDWR)
def shutdown(self): def shutdown(self):
"close socket, immediately." """close socket, immediately."""
if self.sock: if self.sock:
self.sock.close() self.sock.close()
self.sock = None self.sock = None

View File

@@ -25,6 +25,7 @@ Copyright (C) 2010 Hiroki Ohtani(liris)
define websocket exceptions define websocket exceptions
""" """
class WebSocketException(Exception): class WebSocketException(Exception):
""" """
websocket exception class. websocket exception class.
@@ -72,6 +73,8 @@ class WebSocketBadStatusException(WebSocketException):
""" """
WebSocketBadStatusException will be raised when we get bad handshake status code. WebSocketBadStatusException will be raised when we get bad handshake status code.
""" """
def __init__(self, message, status_code): def __init__(self, message, status_code):
super(WebSocketBadStatusException, self).__init__(message % status_code) super(WebSocketBadStatusException, self).__init__(
message % status_code)
self.status_code = status_code self.status_code = status_code

View File

@@ -19,25 +19,22 @@ Copyright (C) 2010 Hiroki Ohtani(liris)
Boston, MA 02110-1335 USA Boston, MA 02110-1335 USA
""" """
import hashlib
import hmac
import os
import six import six
from ._exceptions import *
from ._http import *
from ._logging import *
from ._socket import *
if six.PY3: if six.PY3:
from base64 import encodebytes as base64encode from base64 import encodebytes as base64encode
else: else:
from base64 import encodestring as base64encode from base64 import encodestring as base64encode
import uuid
import hashlib
import hmac
import os
import sys
from ._logging import *
from ._url import *
from ._socket import*
from ._http import *
from ._exceptions import *
__all__ = ["handshake_response", "handshake"] __all__ = ["handshake_response", "handshake"]
if hasattr(hmac, "compare_digest"): if hasattr(hmac, "compare_digest"):
@@ -51,6 +48,7 @@ VERSION = 13
class handshake_response(object): class handshake_response(object):
def __init__(self, status, headers, subprotocol): def __init__(self, status, headers, subprotocol):
self.status = status self.status = status
self.headers = headers self.headers = headers
@@ -73,10 +71,11 @@ def handshake(sock, hostname, port, resource, **options):
def _get_handshake_headers(resource, host, port, options): def _get_handshake_headers(resource, host, port, options):
headers = [] headers = [
headers.append("GET %s HTTP/1.1" % resource) "GET %s HTTP/1.1" % resource,
headers.append("Upgrade: websocket") "Upgrade: websocket",
headers.append("Connection: Upgrade") "Connection: Upgrade"
]
if port == 80 or port == 443: if port == 80 or port == 443:
hostport = host hostport = host
else: else:
@@ -126,7 +125,7 @@ def _get_resp_headers(sock, success_status=101):
_HEADERS_TO_CHECK = { _HEADERS_TO_CHECK = {
"upgrade": "websocket", "upgrade": "websocket",
"connection": "upgrade", "connection": "upgrade",
} }
def _validate(headers, key, subprotocols): def _validate(headers, key, subprotocols):

View File

@@ -19,45 +19,49 @@ Copyright (C) 2010 Hiroki Ohtani(liris)
Boston, MA 02110-1335 USA Boston, MA 02110-1335 USA
""" """
import six
import socket
import errno import errno
import os import os
import socket
import sys import sys
import six
from ._exceptions import *
from ._logging import *
from ._socket import*
from ._ssl_compat import *
from ._url import *
if six.PY3: if six.PY3:
from base64 import encodebytes as base64encode from base64 import encodebytes as base64encode
else: else:
from base64 import encodestring as base64encode from base64 import encodestring as base64encode
from ._logging import *
from ._url import *
from ._socket import*
from ._exceptions import *
from ._ssl_compat import *
__all__ = ["proxy_info", "connect", "read_headers"] __all__ = ["proxy_info", "connect", "read_headers"]
class proxy_info(object): class proxy_info(object):
def __init__(self, **options): def __init__(self, **options):
self.host = options.get("http_proxy_host", None) self.host = options.get("http_proxy_host", None)
if self.host: if self.host:
self.port = options.get("http_proxy_port", 0) self.port = options.get("http_proxy_port", 0)
self.auth = options.get("http_proxy_auth", None) self.auth = options.get("http_proxy_auth", None)
self.no_proxy = options.get("http_no_proxy", None) self.no_proxy = options.get("http_no_proxy", None)
else: else:
self.port = 0 self.port = 0
self.auth = None self.auth = None
self.no_proxy = None self.no_proxy = None
def connect(url, options, proxy, socket): def connect(url, options, proxy, socket):
hostname, port, resource, is_secure = parse_url(url) hostname, port, resource, is_secure = parse_url(url)
if socket: if socket:
return socket, (hostname, port, resource) return socket, (hostname, port, resource)
addrinfo_list, need_tunnel, auth = _get_addrinfo_list(hostname, port, is_secure, proxy) addrinfo_list, need_tunnel, auth = _get_addrinfo_list(
hostname, port, is_secure, proxy)
if not addrinfo_list: if not addrinfo_list:
raise WebSocketException( raise WebSocketException(
"Host not found.: " + hostname + ":" + str(port)) "Host not found.: " + hostname + ":" + str(port))
@@ -82,10 +86,11 @@ def connect(url, options, proxy, socket):
def _get_addrinfo_list(hostname, port, is_secure, proxy): def _get_addrinfo_list(hostname, port, is_secure, proxy):
phost, pport, pauth = get_proxy_info(hostname, is_secure, phost, pport, pauth = get_proxy_info(
proxy.host, proxy.port, proxy.auth, proxy.no_proxy) hostname, is_secure, proxy.host, proxy.port, proxy.auth, proxy.no_proxy)
if not phost: if not phost:
addrinfo_list = socket.getaddrinfo(hostname, port, 0, 0, socket.SOL_TCP) addrinfo_list = socket.getaddrinfo(
hostname, port, 0, 0, socket.SOL_TCP)
return addrinfo_list, False, None return addrinfo_list, False, None
else: else:
pport = pport and pport or 80 pport = pport and pport or 80
@@ -137,14 +142,15 @@ def _wrap_sni_socket(sock, sslopt, hostname, check_hostname):
sslopt.get('keyfile', None), sslopt.get('keyfile', None),
sslopt.get('password', None), sslopt.get('password', None),
) )
# see https://github.com/liris/websocket-client/commit/b96a2e8fa765753e82eea531adb19716b52ca3ca#commitcomment-10803153 # see
# https://github.com/liris/websocket-client/commit/b96a2e8fa765753e82eea531adb19716b52ca3ca#commitcomment-10803153
context.verify_mode = sslopt['cert_reqs'] context.verify_mode = sslopt['cert_reqs']
if HAVE_CONTEXT_CHECK_HOSTNAME: if HAVE_CONTEXT_CHECK_HOSTNAME:
context.check_hostname = check_hostname context.check_hostname = check_hostname
if 'ciphers' in sslopt: if 'ciphers' in sslopt:
context.set_ciphers(sslopt['ciphers']) context.set_ciphers(sslopt['ciphers'])
if 'cert_chain' in sslopt : if 'cert_chain' in sslopt:
certfile,keyfile,password = sslopt['cert_chain'] certfile, keyfile, password = sslopt['cert_chain']
context.load_cert_chain(certfile, keyfile, password) context.load_cert_chain(certfile, keyfile, password)
return context.wrap_socket( return context.wrap_socket(
@@ -158,12 +164,13 @@ def _wrap_sni_socket(sock, sslopt, hostname, check_hostname):
def _ssl_socket(sock, user_sslopt, hostname): def _ssl_socket(sock, user_sslopt, hostname):
sslopt = dict(cert_reqs=ssl.CERT_REQUIRED) sslopt = dict(cert_reqs=ssl.CERT_REQUIRED)
sslopt.update(user_sslopt) sslopt.update(user_sslopt)
certPath = os.path.join( certPath = os.path.join(
os.path.dirname(__file__), "cacert.pem") os.path.dirname(__file__), "cacert.pem")
if os.path.isfile(certPath) and user_sslopt.get('ca_certs', None) == None: if os.path.isfile(certPath) and user_sslopt.get('ca_certs', None) is None:
sslopt['ca_certs'] = certPath sslopt['ca_certs'] = certPath
check_hostname = sslopt["cert_reqs"] != ssl.CERT_NONE and sslopt.pop('check_hostname', True) check_hostname = sslopt["cert_reqs"] != ssl.CERT_NONE and sslopt.pop(
'check_hostname', True)
if _can_use_sni(): if _can_use_sni():
sock = _wrap_sni_socket(sock, sslopt, hostname, check_hostname) sock = _wrap_sni_socket(sock, sslopt, hostname, check_hostname)
@@ -176,6 +183,7 @@ def _ssl_socket(sock, user_sslopt, hostname):
return sock return sock
def _tunnel(sock, host, port, auth): def _tunnel(sock, host, port, auth):
debug("Connecting proxy...") debug("Connecting proxy...")
connect_header = "CONNECT %s:%d HTTP/1.0\r\n" % (host, port) connect_header = "CONNECT %s:%d HTTP/1.0\r\n" % (host, port)
@@ -199,9 +207,10 @@ def _tunnel(sock, host, port, auth):
if status != 200: if status != 200:
raise WebSocketProxyException( raise WebSocketProxyException(
"failed CONNECT via proxy status: %r" % status) "failed CONNECT via proxy status: %r" % status)
return sock return sock
def read_headers(sock): def read_headers(sock):
status = None status = None
headers = {} headers = {}

View File

@@ -19,7 +19,6 @@ Copyright (C) 2010 Hiroki Ohtani(liris)
Boston, MA 02110-1335 USA Boston, MA 02110-1335 USA
""" """
import logging import logging
_logger = logging.getLogger('websocket') _logger = logging.getLogger('websocket')
@@ -29,15 +28,15 @@ __all__ = ["enableTrace", "dump", "error", "debug", "trace",
"isEnabledForError", "isEnabledForDebug"] "isEnabledForError", "isEnabledForDebug"]
def enableTrace(tracable): def enableTrace(traceable):
""" """
turn on/off the tracability. turn on/off the traceability.
tracable: boolean value. if set True, tracability is enabled. traceable: boolean value. if set True, traceability is enabled.
""" """
global _traceEnabled global _traceEnabled
_traceEnabled = tracable _traceEnabled = traceable
if tracable: if traceable:
if not _logger.handlers: if not _logger.handlers:
_logger.addHandler(logging.StreamHandler()) _logger.addHandler(logging.StreamHandler())
_logger.setLevel(logging.DEBUG) _logger.setLevel(logging.DEBUG)

View File

@@ -19,13 +19,13 @@ Copyright (C) 2010 Hiroki Ohtani(liris)
Boston, MA 02110-1335 USA Boston, MA 02110-1335 USA
""" """
import socket import socket
import six import six
from ._exceptions import * from ._exceptions import *
from ._utils import *
from ._ssl_compat import * from ._ssl_compat import *
from ._utils import *
DEFAULT_SOCKET_OPTION = [(socket.SOL_TCP, socket.TCP_NODELAY, 1)] DEFAULT_SOCKET_OPTION = [(socket.SOL_TCP, socket.TCP_NODELAY, 1)]
if hasattr(socket, "SO_KEEPALIVE"): if hasattr(socket, "SO_KEEPALIVE"):
@@ -42,7 +42,9 @@ _default_timeout = None
__all__ = ["DEFAULT_SOCKET_OPTION", "sock_opt", "setdefaulttimeout", "getdefaulttimeout", __all__ = ["DEFAULT_SOCKET_OPTION", "sock_opt", "setdefaulttimeout", "getdefaulttimeout",
"recv", "recv_line", "send"] "recv", "recv_line", "send"]
class sock_opt(object): class sock_opt(object):
def __init__(self, sockopt, sslopt): def __init__(self, sockopt, sslopt):
if sockopt is None: if sockopt is None:
sockopt = [] sockopt = []
@@ -52,6 +54,7 @@ class sock_opt(object):
self.sslopt = sslopt self.sslopt = sslopt
self.timeout = None self.timeout = None
def setdefaulttimeout(timeout): def setdefaulttimeout(timeout):
""" """
Set the global timeout setting to connect. Set the global timeout setting to connect.
@@ -74,7 +77,7 @@ def recv(sock, bufsize):
raise WebSocketConnectionClosedException("socket is already closed.") raise WebSocketConnectionClosedException("socket is already closed.")
try: try:
bytes = sock.recv(bufsize) bytes_ = sock.recv(bufsize)
except socket.timeout as e: except socket.timeout as e:
message = extract_err_message(e) message = extract_err_message(e)
raise WebSocketTimeoutException(message) raise WebSocketTimeoutException(message)
@@ -85,10 +88,11 @@ def recv(sock, bufsize):
else: else:
raise raise
if not bytes: if not bytes_:
raise WebSocketConnectionClosedException("Connection is already closed.") raise WebSocketConnectionClosedException(
"Connection is already closed.")
return bytes return bytes_
def recv_line(sock): def recv_line(sock):

View File

@@ -19,7 +19,6 @@ Copyright (C) 2010 Hiroki Ohtani(liris)
Boston, MA 02110-1335 USA Boston, MA 02110-1335 USA
""" """
__all__ = ["HAVE_SSL", "ssl", "SSLError"] __all__ = ["HAVE_SSL", "ssl", "SSLError"]
try: try:

View File

@@ -19,9 +19,9 @@ Copyright (C) 2010 Hiroki Ohtani(liris)
Boston, MA 02110-1335 USA Boston, MA 02110-1335 USA
""" """
import os
from six.moves.urllib.parse import urlparse from six.moves.urllib.parse import urlparse
import os
__all__ = ["parse_url", "get_proxy_info"] __all__ = ["parse_url", "get_proxy_info"]
@@ -66,7 +66,7 @@ def parse_url(url):
if parsed.query: if parsed.query:
resource += "?" + parsed.query resource += "?" + parsed.query
return (hostname, port, resource, is_secure) return hostname, port, resource, is_secure
DEFAULT_NO_PROXY_HOST = ["localhost", "127.0.0.1"] DEFAULT_NO_PROXY_HOST = ["localhost", "127.0.0.1"]
@@ -82,8 +82,9 @@ def _is_no_proxy_host(hostname, no_proxy):
return hostname in no_proxy return hostname in no_proxy
def get_proxy_info(hostname, is_secure, def get_proxy_info(
proxy_host=None, proxy_port=0, proxy_auth=None, no_proxy=None): hostname, is_secure, proxy_host=None, proxy_port=0, proxy_auth=None,
no_proxy=None):
""" """
try to retrieve proxy host and port from environment try to retrieve proxy host and port from environment
if not provided in options. if not provided in options.

View File

@@ -19,16 +19,17 @@ Copyright (C) 2010 Hiroki Ohtani(liris)
Boston, MA 02110-1335 USA Boston, MA 02110-1335 USA
""" """
import six import six
__all__ = ["NoLock", "validate_utf8", "extract_err_message"] __all__ = ["NoLock", "validate_utf8", "extract_err_message"]
class NoLock(object): class NoLock(object):
def __enter__(self): def __enter__(self):
pass pass
def __exit__(self, type, value, traceback): def __exit__(self, exc_type, exc_value, traceback):
pass pass
try: try:
@@ -69,10 +70,11 @@ except ImportError:
def _decode(state, codep, ch): def _decode(state, codep, ch):
tp = _UTF8D[ch] tp = _UTF8D[ch]
codep = (ch & 0x3f ) | (codep << 6) if (state != _UTF8_ACCEPT) else (0xff >> tp) & (ch) codep = (ch & 0x3f) | (codep << 6) if (
state != _UTF8_ACCEPT) else (0xff >> tp) & ch
state = _UTF8D[256 + state + tp] state = _UTF8D[256 + state + tp]
return state, codep; return state, codep
def _validate_utf8(utfbytes): def _validate_utf8(utfbytes):
state = _UTF8_ACCEPT state = _UTF8_ACCEPT
@@ -86,6 +88,7 @@ except ImportError:
return True return True
def validate_utf8(utfbytes): def validate_utf8(utfbytes):
""" """
validate utf8 byte string. validate utf8 byte string.
@@ -94,6 +97,7 @@ def validate_utf8(utfbytes):
""" """
return _validate_utf8(utfbytes) return _validate_utf8(utfbytes)
def extract_err_message(exception): def extract_err_message(exception):
if exception.args: if exception.args:
return exception.args[0] return exception.args[0]

View File

@@ -1,14 +1,33 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
import six
import sys import sys
sys.path[0:0] = [""] sys.path[0:0] = [""]
import os import os
import os.path import os.path
import base64
import socket import socket
import six
# websocket-client
import websocket as ws
from websocket._handshake import _create_sec_websocket_key, \
_validate as _validate_header
from websocket._http import read_headers
from websocket._url import get_proxy_info, parse_url
from websocket._utils import validate_utf8
if six.PY3:
from base64 import decodebytes as base64decode
else:
from base64 import decodestring as base64decode
if sys.version_info[0] == 2 and sys.version_info[1] < 7:
import unittest2 as unittest
else:
import unittest
try: try:
from ssl import SSLError from ssl import SSLError
except ImportError: except ImportError:
@@ -16,37 +35,15 @@ except ImportError:
class SSLError(Exception): class SSLError(Exception):
pass pass
if sys.version_info[0] == 2 and sys.version_info[1] < 7:
import unittest2 as unittest
else:
import unittest
import uuid
if six.PY3:
from base64 import decodebytes as base64decode
else:
from base64 import decodestring as base64decode
# websocket-client
import websocket as ws
from websocket._handshake import _create_sec_websocket_key
from websocket._url import parse_url, get_proxy_info
from websocket._utils import validate_utf8
from websocket._handshake import _validate as _validate_header
from websocket._http import read_headers
# Skip test to access the internet. # Skip test to access the internet.
TEST_WITH_INTERNET = os.environ.get('TEST_WITH_INTERNET', '0') == '1' TEST_WITH_INTERNET = os.environ.get('TEST_WITH_INTERNET', '0') == '1'
# Skip Secure WebSocket test. # Skip Secure WebSocket test.
TEST_SECURE_WS = True TEST_SECURE_WS = True
TRACABLE = False TRACEABLE = False
def create_mask_key(n): def create_mask_key(_):
return "abcd" return "abcd"
@@ -86,7 +83,7 @@ class HeaderSockMock(SockMock):
class WebSocketTest(unittest.TestCase): class WebSocketTest(unittest.TestCase):
def setUp(self): def setUp(self):
ws.enableTrace(TRACABLE) ws.enableTrace(TRACEABLE)
def tearDown(self): def tearDown(self):
pass pass
@@ -263,7 +260,7 @@ class WebSocketTest(unittest.TestCase):
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def testIter(self): def testIter(self):
count = 2 count = 2
for rsvp in ws.create_connection('ws://stream.meetup.com/2/rsvps'): for _ in ws.create_connection('ws://stream.meetup.com/2/rsvps'):
count -= 1 count -= 1
if count == 0: if count == 0:
break break
@@ -282,7 +279,7 @@ class WebSocketTest(unittest.TestCase):
# s.add_packet(SSLError("The read operation timed out")) # s.add_packet(SSLError("The read operation timed out"))
s.add_packet(six.b("baz")) s.add_packet(six.b("baz"))
with self.assertRaises(ws.WebSocketTimeoutException): with self.assertRaises(ws.WebSocketTimeoutException):
data = sock.frame_buffer.recv_strict(9) sock.frame_buffer.recv_strict(9)
# if six.PY2: # if six.PY2:
# with self.assertRaises(ws.WebSocketTimeoutException): # with self.assertRaises(ws.WebSocketTimeoutException):
# data = sock._recv_strict(9) # data = sock._recv_strict(9)
@@ -292,7 +289,7 @@ class WebSocketTest(unittest.TestCase):
data = sock.frame_buffer.recv_strict(9) data = sock.frame_buffer.recv_strict(9)
self.assertEqual(data, six.b("foobarbaz")) self.assertEqual(data, six.b("foobarbaz"))
with self.assertRaises(ws.WebSocketConnectionClosedException): with self.assertRaises(ws.WebSocketConnectionClosedException):
data = sock.frame_buffer.recv_strict(1) sock.frame_buffer.recv_strict(1)
def testRecvTimeout(self): def testRecvTimeout(self):
sock = ws.WebSocket() sock = ws.WebSocket()
@@ -303,13 +300,13 @@ class WebSocketTest(unittest.TestCase):
s.add_packet(socket.timeout()) s.add_packet(socket.timeout())
s.add_packet(six.b("\x4e\x43\x33\x0e\x10\x0f\x00\x40")) s.add_packet(six.b("\x4e\x43\x33\x0e\x10\x0f\x00\x40"))
with self.assertRaises(ws.WebSocketTimeoutException): with self.assertRaises(ws.WebSocketTimeoutException):
data = sock.recv() sock.recv()
with self.assertRaises(ws.WebSocketTimeoutException): with self.assertRaises(ws.WebSocketTimeoutException):
data = sock.recv() sock.recv()
data = sock.recv() data = sock.recv()
self.assertEqual(data, "Hello, World!") self.assertEqual(data, "Hello, World!")
with self.assertRaises(ws.WebSocketConnectionClosedException): with self.assertRaises(ws.WebSocketConnectionClosedException):
data = sock.recv() sock.recv()
def testRecvWithSimpleFragmentation(self): def testRecvWithSimpleFragmentation(self):
sock = ws.WebSocket() sock = ws.WebSocket()
@@ -374,10 +371,10 @@ class WebSocketTest(unittest.TestCase):
sock = ws.WebSocket() sock = ws.WebSocket()
s = sock.sock = SockMock() s = sock.sock = SockMock()
# OPCODE=TEXT, FIN=0, MSG="Once more unto the breach, " # OPCODE=TEXT, FIN=0, MSG="Once more unto the breach, "
s.add_packet(six.b("\x01\x9babcd.\x0c\x00\x01A\x0f\x0c\x16\x04B\x16\n\x15" \ s.add_packet(six.b("\x01\x9babcd.\x0c\x00\x01A\x0f\x0c\x16\x04B\x16\n\x15"
"\rC\x10\t\x07C\x06\x13\x07\x02\x07\tNC")) "\rC\x10\t\x07C\x06\x13\x07\x02\x07\tNC"))
# OPCODE=CONT, FIN=0, MSG="dear friends, " # OPCODE=CONT, FIN=0, MSG="dear friends, "
s.add_packet(six.b("\x00\x8eabcd\x05\x07\x02\x16A\x04\x11\r\x04\x0c\x07" \ s.add_packet(six.b("\x00\x8eabcd\x05\x07\x02\x16A\x04\x11\r\x04\x0c\x07"
"\x17MB")) "\x17MB"))
# OPCODE=CONT, FIN=1, MSG="once more" # OPCODE=CONT, FIN=1, MSG="once more"
s.add_packet(six.b("\x80\x89abcd\x0e\x0c\x00\x01A\x0f\x0c\x16\x04")) s.add_packet(six.b("\x80\x89abcd\x0e\x0c\x00\x01A\x0f\x0c\x16\x04"))
@@ -397,7 +394,7 @@ class WebSocketTest(unittest.TestCase):
# OPCODE=PING, FIN=1, MSG="Please PONG this" # OPCODE=PING, FIN=1, MSG="Please PONG this"
s.add_packet(six.b("\x89\x90abcd1\x0e\x06\x05\x12\x07C4.,$D\x15\n\n\x17")) s.add_packet(six.b("\x89\x90abcd1\x0e\x06\x05\x12\x07C4.,$D\x15\n\n\x17"))
# OPCODE=CONT, FIN=1, MSG="of a good thing" # OPCODE=CONT, FIN=1, MSG="of a good thing"
s.add_packet(six.b("\x80\x8fabcd\x0e\x04C\x05A\x05\x0c\x0b\x05B\x17\x0c" \ s.add_packet(six.b("\x80\x8fabcd\x0e\x04C\x05A\x05\x0c\x0b\x05B\x17\x0c"
"\x08\x0c\x04")) "\x08\x0c\x04"))
data = sock.recv() data = sock.recv()
self.assertEqual(data, "Too much of a good thing") self.assertEqual(data, "Too much of a good thing")
@@ -479,7 +476,7 @@ class WebSocketAppTest(unittest.TestCase):
""" """
def setUp(self): def setUp(self):
ws.enableTrace(TRACABLE) ws.enableTrace(TRACEABLE)
WebSocketAppTest.keep_running_open = WebSocketAppTest.NotSetYet() WebSocketAppTest.keep_running_open = WebSocketAppTest.NotSetYet()
WebSocketAppTest.keep_running_close = WebSocketAppTest.NotSetYet() WebSocketAppTest.keep_running_close = WebSocketAppTest.NotSetYet()