diff --git a/ChangeLog b/ChangeLog index c3141fc..d5b9592 100644 --- a/ChangeLog +++ b/ChangeLog @@ -2,7 +2,7 @@ ChangeLog ============ - 0.37.0 - - fixed fialer that `websocket.create_connection` does not accept `origin` as a parameter (#246 ) + - fixed failure that `websocket.create_connection` does not accept `origin` as a parameter (#246 ) - 0.36.0 - added support for using custom connection class (#235) @@ -90,7 +90,7 @@ ChangeLog - 0.24.0 - Supporting http-basic auth in WebSocketApp (#143) - - fix failer of test.testInternalRecvStrict(#141) + - fix failure of test.testInternalRecvStrict(#141) - skip utf8 validation by skip_utf8_validation argument (#137) - WebsocketProxyException will be raised if we got error about proxy.(#138) diff --git a/bin/wsdump.py b/bin/wsdump.py index c2fb49f..5af00ac 100755 --- a/bin/wsdump.py +++ b/bin/wsdump.py @@ -2,15 +2,18 @@ import argparse import code -import six import sys import threading import time -import websocket + +import six from six.moves.urllib.parse import urlparse + +import websocket + try: import readline -except: +except ImportError: pass @@ -27,15 +30,17 @@ ENCODING = get_encoding() class VAction(argparse.Action): + def __call__(self, parser, args, values, option_string=None): - if values==None: + if values is None: values = "1" try: values = int(values) except ValueError: - values = values.count("v")+1 + values = values.count("v") + 1 setattr(args, self.dest, values) + def parse_args(): parser = argparse.ArgumentParser(description="WebSocket Simple Dump Tool") parser.add_argument("url", metavar="ws_url", @@ -63,7 +68,9 @@ def parse_args(): return parser.parse_args() -class RawInput(): + +class RawInput: + def raw_input(self, prompt): if six.PY3: line = input(prompt) @@ -77,7 +84,9 @@ class RawInput(): return line + class InteractiveConsole(RawInput, code.InteractiveConsole): + def write(self, data): sys.stdout.write("\033[2K\033[E") # sys.stdout.write("\n") @@ -88,7 +97,9 @@ class InteractiveConsole(RawInput, code.InteractiveConsole): def read(self): return self.raw_input("> ") + class NonInteractive(RawInput): + def write(self, data): sys.stdout.write(data) sys.stdout.write("\n") @@ -97,23 +108,24 @@ class NonInteractive(RawInput): def read(self): return self.raw_input("") + def main(): start_time = time.time() args = parse_args() if args.verbose > 1: websocket.enableTrace(True) options = {} - if (args.proxy): + if args.proxy: p = urlparse(args.proxy) options["http_proxy_host"] = p.hostname options["http_proxy_port"] = p.port - if (args.origin): + if args.origin: options["origin"] = args.origin - if (args.subprotocols): + if args.subprotocols: options["subprotocols"] = args.subprotocols opts = {} - if (args.nocert): - opts = { "cert_reqs": websocket.ssl.CERT_NONE, "check_hostname": False } + if args.nocert: + opts = {"cert_reqs": websocket.ssl.CERT_NONE, "check_hostname": False} ws = websocket.create_connection(args.url, sslopt=opts, **options) if args.raw: console = NonInteractive() @@ -125,21 +137,20 @@ def main(): try: frame = ws.recv_frame() except websocket.WebSocketException: - return (websocket.ABNF.OPCODE_CLOSE, None) + return websocket.ABNF.OPCODE_CLOSE, None if not frame: raise websocket.WebSocketException("Not a valid frame %s" % frame) elif frame.opcode in OPCODE_DATA: - return (frame.opcode, frame.data) + return frame.opcode, frame.data elif frame.opcode == websocket.ABNF.OPCODE_CLOSE: ws.send_close() - return (frame.opcode, None) + return frame.opcode, None elif frame.opcode == websocket.ABNF.OPCODE_PING: ws.pong(frame.data) return frame.opcode, frame.data return frame.opcode, frame.data - def recv_ws(): while True: opcode, data = recv() @@ -152,7 +163,7 @@ def main(): msg = "%s: %s" % (websocket.ABNF.OPCODE_MAP.get(opcode), data) if msg is not None: - if (args.timings): + if args.timings: console.write(str(time.time() - start_time) + ": " + msg) else: console.write(msg) diff --git a/compliance/test_fuzzingclient.py b/compliance/test_fuzzingclient.py index 8235dfb..f4b0ff1 100644 --- a/compliance/test_fuzzingclient.py +++ b/compliance/test_fuzzingclient.py @@ -1,13 +1,9 @@ #!/usr/bin/env python -import websocket import json import traceback -import six - - - +import websocket SERVER = 'ws://127.0.0.1:8642' AGENT = 'py-websockets-client' @@ -19,28 +15,28 @@ ws.close() for case in range(1, count+1): - url = SERVER + '/runCase?case={0}&agent={1}'.format(case, AGENT) - status = websocket.STATUS_NORMAL - try: - ws = websocket.create_connection(url) - while True: - opcode, msg = ws.recv_data() - if opcode == websocket.ABNF.OPCODE_TEXT: - msg.decode("utf-8") - if opcode in (websocket.ABNF.OPCODE_TEXT, websocket.ABNF.OPCODE_BINARY): - ws.send(msg, opcode) - except UnicodeDecodeError: - # this case is ok. - status = websocket.STATUS_PROTOCOL_ERROR - except websocket.WebSocketProtocolException: - status = websocket.STATUS_PROTOCOL_ERROR - except websocket.WebSocketPayloadException: - status = websocket.STATUS_INVALID_PAYLOAD - except Exception as e: - # status = websocket.STATUS_PROTOCOL_ERROR - print(traceback.format_exc()) - finally: - ws.close(status) + url = SERVER + '/runCase?case={0}&agent={1}'.format(case, AGENT) + status = websocket.STATUS_NORMAL + try: + ws = websocket.create_connection(url) + while True: + opcode, msg = ws.recv_data() + if opcode == websocket.ABNF.OPCODE_TEXT: + msg.decode("utf-8") + if opcode in (websocket.ABNF.OPCODE_TEXT, websocket.ABNF.OPCODE_BINARY): + ws.send(msg, opcode) + except UnicodeDecodeError: + # this case is ok. + status = websocket.STATUS_PROTOCOL_ERROR + except websocket.WebSocketProtocolException: + status = websocket.STATUS_PROTOCOL_ERROR + except websocket.WebSocketPayloadException: + status = websocket.STATUS_INVALID_PAYLOAD + except Exception as e: + # status = websocket.STATUS_PROTOCOL_ERROR + print(traceback.format_exc()) + finally: + ws.close(status) print("Ran {} test cases.".format(case)) url = SERVER + '/updateReports?agent={0}'.format(AGENT) diff --git a/examples/echoapp_client.py b/examples/echoapp_client.py index 48850b9..c15b35f 100644 --- a/examples/echoapp_client.py +++ b/examples/echoapp_client.py @@ -1,7 +1,7 @@ import websocket try: import thread -except ImportError: #TODO use Threading instead of _thread in python3 +except ImportError: # TODO use Threading instead of _thread in python3 import _thread as thread import time import sys @@ -41,8 +41,8 @@ if __name__ == "__main__": else: host = sys.argv[1] ws = websocket.WebSocketApp(host, - on_message = on_message, - on_error = on_error, - on_close = on_close) + on_message=on_message, + on_error=on_error, + on_close=on_close) ws.on_open = on_open ws.run_forever() diff --git a/setup.py b/setup.py index 13778d7..c0d1654 100644 --- a/setup.py +++ b/setup.py @@ -2,12 +2,12 @@ from setuptools import setup import sys VERSION = "0.37.0" -NAME="websocket_client" +NAME = "websocket_client" install_requires = ["six"] tests_require = [] if sys.version_info[0] == 2: - if sys.version_info[1] < 7 or (sys.version_info[1] == 7 and sys.version_info[2]< 9): + if sys.version_info[1] < 7 or (sys.version_info[1] == 7 and sys.version_info[2] < 9): install_requires.append('backports.ssl_match_hostname') if sys.version_info[1] < 7: tests_require.append('unittest2==0.8.0') diff --git a/websocket/__init__.py b/websocket/__init__.py index a895023..644bb75 100644 --- a/websocket/__init__.py +++ b/websocket/__init__.py @@ -19,7 +19,11 @@ Copyright (C) 2010 Hiroki Ohtani(liris) Boston, MA 02110-1335 USA """ -from ._core import * +from ._abnf import * from ._app import WebSocketApp +from ._core import * +from ._exceptions import * +from ._logging import * +from ._socket import * __version__ = "0.37.0" diff --git a/websocket/_abnf.py b/websocket/_abnf.py index 2c7c4ab..5b1a82d 100644 --- a/websocket/_abnf.py +++ b/websocket/_abnf.py @@ -19,10 +19,12 @@ Copyright (C) 2010 Hiroki Ohtani(liris) Boston, MA 02110-1335 USA """ -import six import array -import struct import os +import struct + +import six + from ._exceptions import * from ._utils import validate_utf8 @@ -44,6 +46,22 @@ except ImportError: else: return _d.tostring() +__all__ = [ + 'ABNF', 'continuous_frame', 'frame_buffer', + 'STATUS_NORMAL', + 'STATUS_GOING_AWAY', + 'STATUS_PROTOCOL_ERROR', + 'STATUS_UNSUPPORTED_DATA_TYPE', + 'STATUS_STATUS_NOT_AVAILABLE', + 'STATUS_ABNORMAL_CLOSED', + 'STATUS_INVALID_PAYLOAD', + 'STATUS_POLICY_VIOLATION', + 'STATUS_MESSAGE_TOO_BIG', + 'STATUS_INVALID_EXTENSION', + 'STATUS_UNEXPECTED_CONDITION', + 'STATUS_TLS_HANDSHAKE_ERROR', +] + # closing frame status codes. STATUS_NORMAL = 1000 STATUS_GOING_AWAY = 1001 @@ -68,7 +86,8 @@ VALID_CLOSE_STATUS = ( STATUS_MESSAGE_TOO_BIG, STATUS_INVALID_EXTENSION, STATUS_UNEXPECTED_CONDITION, - ) +) + class ABNF(object): """ @@ -78,16 +97,16 @@ class ABNF(object): """ # operation code values. - OPCODE_CONT = 0x0 - OPCODE_TEXT = 0x1 + OPCODE_CONT = 0x0 + OPCODE_TEXT = 0x1 OPCODE_BINARY = 0x2 - OPCODE_CLOSE = 0x8 - OPCODE_PING = 0x9 - OPCODE_PONG = 0xa + OPCODE_CLOSE = 0x8 + OPCODE_PING = 0x9 + OPCODE_PONG = 0xa # available operation code value tuple OPCODES = (OPCODE_CONT, OPCODE_TEXT, OPCODE_BINARY, OPCODE_CLOSE, - OPCODE_PING, OPCODE_PONG) + OPCODE_PING, OPCODE_PONG) # opcode human readable string OPCODE_MAP = { @@ -97,10 +116,10 @@ class ABNF(object): OPCODE_CLOSE: "close", OPCODE_PING: "ping", OPCODE_PONG: "pong" - } + } # data length threshold. - LENGTH_7 = 0x7e + LENGTH_7 = 0x7e LENGTH_16 = 1 << 16 LENGTH_63 = 1 << 63 @@ -116,7 +135,7 @@ class ABNF(object): self.rsv3 = rsv3 self.opcode = opcode self.mask = mask - if data == None: + if data is None: data = "" self.data = data self.get_mask_key = os.urandom @@ -144,17 +163,19 @@ class ABNF(object): if l > 2 and not skip_utf8_validation and not validate_utf8(self.data[2:]): raise WebSocketProtocolException("Invalid close frame.") - code = 256*six.byte2int(self.data[0:1]) + six.byte2int(self.data[1:2]) + code = 256 * \ + six.byte2int(self.data[0:1]) + six.byte2int(self.data[1:2]) if not self._is_valid_close_status(code): raise WebSocketProtocolException("Invalid close opcode.") - def _is_valid_close_status(self, code): - return code in VALID_CLOSE_STATUS or (3000 <= code <5000) + @staticmethod + def _is_valid_close_status(code): + return code in VALID_CLOSE_STATUS or (3000 <= code < 5000) def __str__(self): return "fin=" + str(self.fin) \ - + " opcode=" + str(self.opcode) \ - + " data=" + str(self.data) + + " opcode=" + str(self.opcode) \ + + " data=" + str(self.data) @staticmethod def create_frame(data, opcode, fin=1): @@ -224,7 +245,7 @@ class ABNF(object): data: data to mask/unmask. """ - if data == None: + if data is None: data = "" if isinstance(mask_key, six.text_type): @@ -237,9 +258,10 @@ class ABNF(object): _d = array.array("B", data) return _mask(_m, _d) + class frame_buffer(object): _HEADER_MASK_INDEX = 5 - _HEADER_LENGHT_INDEX = 6 + _HEADER_LENGTH_INDEX = 6 def __init__(self, recv_fn, skip_utf8_validation): self.recv = recv_fn @@ -255,7 +277,7 @@ class frame_buffer(object): self.mask = None def has_received_header(self): - return self.header is None + return self.header is None def recv_header(self): header = self.recv_strict(2) @@ -284,12 +306,11 @@ class frame_buffer(object): return False return self.header[frame_buffer._HEADER_MASK_INDEX] - def has_received_length(self): return self.length is None def recv_length(self): - bits = self.header[frame_buffer._HEADER_LENGHT_INDEX] + bits = self.header[frame_buffer._HEADER_LENGTH_INDEX] length_bits = bits & 0x7f if length_bits == 0x7e: v = self.recv_strict(2) @@ -342,10 +363,11 @@ class frame_buffer(object): # fragmenting the heap -- the number of bytes recv() actually # reads is limited by socket buffer and is relatively small, # yet passing large numbers repeatedly causes lots of large - # buffers allocated and then shrunk, which results in fragmentation. - bytes = self.recv(min(16384, shortage)) - self.recv_buffer.append(bytes) - shortage -= len(bytes) + # buffers allocated and then shrunk, which results in + # fragmentation. + bytes_ = self.recv(min(16384, shortage)) + self.recv_buffer.append(bytes_) + shortage -= len(bytes_) unified = six.b("").join(self.recv_buffer) @@ -358,6 +380,7 @@ class frame_buffer(object): class continuous_frame(object): + def __init__(self, fire_cont_frame, skip_utf8_validation): self.fire_cont_frame = fire_cont_frame self.skip_utf8_validation = skip_utf8_validation @@ -367,7 +390,8 @@ class continuous_frame(object): def validate(self, frame): if not self.recving_frames and frame.opcode == ABNF.OPCODE_CONT: raise WebSocketProtocolException("Illegal frame") - if self.recving_frames and frame.opcode in (ABNF.OPCODE_TEXT, ABNF.OPCODE_BINARY): + if self.recving_frames and \ + frame.opcode in (ABNF.OPCODE_TEXT, ABNF.OPCODE_BINARY): raise WebSocketProtocolException("Illegal frame") def add(self, frame): @@ -389,6 +413,7 @@ class continuous_frame(object): self.cont_data = None frame.data = data[1] if not self.fire_cont_frame and data[0] == ABNF.OPCODE_TEXT and not self.skip_utf8_validation and not validate_utf8(frame.data): - raise WebSocketPayloadException("cannot decode: " + repr(frame.data)) + raise WebSocketPayloadException( + "cannot decode: " + repr(frame.data)) return [data[0], frame] diff --git a/websocket/_app.py b/websocket/_app.py index af67080..3ea942d 100644 --- a/websocket/_app.py +++ b/websocket/_app.py @@ -23,17 +23,18 @@ Copyright (C) 2010 Hiroki Ohtani(liris) """ WebSocketApp provides higher level APIs. """ +import select +import sys import threading import time import traceback -import sys -import select + import six +from ._abnf import ABNF from ._core import WebSocket, getdefaulttimeout from ._exceptions import * from ._logging import * -from ._abnf import ABNF __all__ = ["WebSocketApp"] @@ -43,7 +44,8 @@ class WebSocketApp(object): Higher level of APIs are provided. The interface is like JavaScript WebSocket object. """ - def __init__(self, url, header=[], + + def __init__(self, url, header=None, on_open=None, on_message=None, on_error=None, on_close=None, on_ping=None, on_pong=None, on_cont_message=None, @@ -87,7 +89,7 @@ class WebSocketApp(object): subprotocols: array of available sub protocols. default is None. """ self.url = url - self.header = header + self.header = header if header is not None else [] self.cookie = cookie self.on_open = on_open self.on_message = on_message @@ -113,7 +115,8 @@ class WebSocketApp(object): """ if not self.sock or self.sock.send(data, opcode) == 0: - raise WebSocketConnectionClosedException("Connection is already closed.") + raise WebSocketConnectionClosedException( + "Connection is already closed.") def close(self): """ @@ -168,27 +171,29 @@ class WebSocketApp(object): close_frame = None try: - self.sock = WebSocket(self.get_mask_key, - sockopt=sockopt, sslopt=sslopt, + self.sock = WebSocket( + self.get_mask_key, sockopt=sockopt, sslopt=sslopt, fire_cont_frame=self.on_cont_message and True or False, skip_utf8_validation=skip_utf8_validation) self.sock.settimeout(getdefaulttimeout()) - self.sock.connect(self.url, header=self.header, cookie=self.cookie, + self.sock.connect( + self.url, header=self.header, cookie=self.cookie, http_proxy_host=http_proxy_host, - http_proxy_port=http_proxy_port, - http_no_proxy=http_no_proxy, http_proxy_auth=http_proxy_auth, - subprotocols=self.subprotocols, + http_proxy_port=http_proxy_port, http_no_proxy=http_no_proxy, + http_proxy_auth=http_proxy_auth, subprotocols=self.subprotocols, host=host, origin=origin) self._callback(self.on_open) if ping_interval: event = threading.Event() - thread = threading.Thread(target=self._send_ping, args=(ping_interval, event)) + thread = threading.Thread( + target=self._send_ping, args=(ping_interval, event)) thread.setDaemon(True) thread.start() while self.sock.connected: - r, w, e = select.select((self.sock.sock, ), (), (), ping_timeout) + r, w, e = select.select( + (self.sock.sock, ), (), (), ping_timeout) if not self.keep_running: break @@ -203,8 +208,10 @@ class WebSocketApp(object): self.last_pong_tm = time.time() self._callback(self.on_pong, frame.data) elif op_code == ABNF.OPCODE_CONT and self.on_cont_message: - self._callback(self.on_data, data, frame.opcode, frame.fin) - self._callback(self.on_cont_message, frame.data, frame.fin) + self._callback(self.on_data, data, + frame.opcode, frame.fin) + self._callback(self.on_cont_message, + frame.data, frame.fin) else: data = frame.data if six.PY3 and frame.opcode == ABNF.OPCODE_TEXT: @@ -227,8 +234,9 @@ class WebSocketApp(object): thread.join() self.keep_running = False self.sock.close() - self._callback(self.on_close, - *self._get_close_args(close_frame.data if close_frame else None)) + close_args = self._get_close_args( + close_frame.data if close_frame else None) + self._callback(self.on_close, *close_args) self.sock = None def _get_close_args(self, data): @@ -244,7 +252,7 @@ class WebSocketApp(object): return [] if data and len(data) >= 2: - code = 256*six.byte2int(data[0:1]) + six.byte2int(data[1:2]) + code = 256 * six.byte2int(data[0:1]) + six.byte2int(data[1:2]) reason = data[2:].decode('utf-8') return [code, reason] diff --git a/websocket/_core.py b/websocket/_core.py index 13f7593..adcdb6b 100644 --- a/websocket/_core.py +++ b/websocket/_core.py @@ -21,28 +21,22 @@ Copyright (C) 2010 Hiroki Ohtani(liris) """ from __future__ import print_function - -import six import socket - -if six.PY3: - from base64 import encodebytes as base64encode -else: - from base64 import encodestring as base64encode - import struct import threading +import six + # websocket modules -from ._exceptions import * from ._abnf import * +from ._exceptions import * +from ._handshake import * +from ._http import * +from ._logging import * from ._socket import * from ._utils import * -from ._url import * -from ._logging import * -from ._http import * -from ._handshake import * -from ._ssl_compat import * + +__all__ = ['WebSocket', 'create_connection'] """ websocket python client. @@ -83,7 +77,7 @@ class WebSocket(object): def __init__(self, get_mask_key=None, sockopt=None, sslopt=None, fire_cont_frame=False, enable_multithread=False, - skip_utf8_validation=False, **options): + skip_utf8_validation=False, **_): """ Initialize WebSocket object. """ @@ -95,7 +89,8 @@ class WebSocket(object): self.get_mask_key = get_mask_key # These buffer over the build-up of a single frame. self.frame_buffer = frame_buffer(self._recv, skip_utf8_validation) - self.cont_frame = continuous_frame(fire_cont_frame, skip_utf8_validation) + self.cont_frame = continuous_frame( + fire_cont_frame, skip_utf8_validation) if enable_multithread: self.lock = threading.Lock() @@ -329,7 +324,8 @@ class WebSocket(object): if not frame: # handle error: # 'NoneType' object has no attribute 'opcode' - raise WebSocketProtocolException("Not a valid frame %s" % frame) + raise WebSocketProtocolException( + "Not a valid frame %s" % frame) elif frame.opcode in (ABNF.OPCODE_TEXT, ABNF.OPCODE_BINARY, ABNF.OPCODE_CONT): self.cont_frame.validate(frame) self.cont_frame.add(frame) @@ -339,17 +335,18 @@ class WebSocket(object): elif frame.opcode == ABNF.OPCODE_CLOSE: self.send_close() - return (frame.opcode, frame) + return frame.opcode, frame elif frame.opcode == ABNF.OPCODE_PING: if len(frame.data) < 126: self.pong(frame.data) else: - raise WebSocketProtocolException("Ping message is too long") + raise WebSocketProtocolException( + "Ping message is too long") if control_frame: - return (frame.opcode, frame) + return frame.opcode, frame elif frame.opcode == ABNF.OPCODE_PONG: if control_frame: - return (frame.opcode, frame) + return frame.opcode, frame def recv_frame(self): """ @@ -389,7 +386,8 @@ class WebSocket(object): try: self.connected = False - self.send(struct.pack('!H', status) + reason, ABNF.OPCODE_CLOSE) + self.send(struct.pack('!H', status) + + reason, ABNF.OPCODE_CLOSE) sock_timeout = self.sock.gettimeout() self.sock.settimeout(timeout) try: @@ -415,7 +413,7 @@ class WebSocket(object): self.sock.shutdown(socket.SHUT_RDWR) def shutdown(self): - "close socket, immediately." + """close socket, immediately.""" if self.sock: self.sock.close() self.sock = None diff --git a/websocket/_exceptions.py b/websocket/_exceptions.py index 7b3e508..9d1a3bc 100644 --- a/websocket/_exceptions.py +++ b/websocket/_exceptions.py @@ -25,6 +25,7 @@ Copyright (C) 2010 Hiroki Ohtani(liris) define websocket exceptions """ + class WebSocketException(Exception): """ websocket exception class. @@ -72,6 +73,8 @@ class WebSocketBadStatusException(WebSocketException): """ WebSocketBadStatusException will be raised when we get bad handshake status code. """ + def __init__(self, message, status_code): - super(WebSocketBadStatusException, self).__init__(message % status_code) + super(WebSocketBadStatusException, self).__init__( + message % status_code) self.status_code = status_code diff --git a/websocket/_handshake.py b/websocket/_handshake.py index dd52dd4..f2c5352 100644 --- a/websocket/_handshake.py +++ b/websocket/_handshake.py @@ -19,25 +19,22 @@ Copyright (C) 2010 Hiroki Ohtani(liris) Boston, MA 02110-1335 USA """ +import hashlib +import hmac +import os import six + +from ._exceptions import * +from ._http import * +from ._logging import * +from ._socket import * + if six.PY3: from base64 import encodebytes as base64encode else: from base64 import encodestring as base64encode -import uuid -import hashlib -import hmac -import os -import sys - -from ._logging import * -from ._url import * -from ._socket import* -from ._http import * -from ._exceptions import * - __all__ = ["handshake_response", "handshake"] if hasattr(hmac, "compare_digest"): @@ -51,6 +48,7 @@ VERSION = 13 class handshake_response(object): + def __init__(self, status, headers, subprotocol): self.status = status self.headers = headers @@ -73,10 +71,11 @@ def handshake(sock, hostname, port, resource, **options): def _get_handshake_headers(resource, host, port, options): - headers = [] - headers.append("GET %s HTTP/1.1" % resource) - headers.append("Upgrade: websocket") - headers.append("Connection: Upgrade") + headers = [ + "GET %s HTTP/1.1" % resource, + "Upgrade: websocket", + "Connection: Upgrade" + ] if port == 80 or port == 443: hostport = host else: @@ -126,7 +125,7 @@ def _get_resp_headers(sock, success_status=101): _HEADERS_TO_CHECK = { "upgrade": "websocket", "connection": "upgrade", - } +} def _validate(headers, key, subprotocols): diff --git a/websocket/_http.py b/websocket/_http.py index 63f3f83..88f313a 100644 --- a/websocket/_http.py +++ b/websocket/_http.py @@ -19,45 +19,49 @@ Copyright (C) 2010 Hiroki Ohtani(liris) Boston, MA 02110-1335 USA """ - -import six -import socket import errno import os +import socket import sys +import six + +from ._exceptions import * +from ._logging import * +from ._socket import* +from ._ssl_compat import * +from ._url import * + if six.PY3: from base64 import encodebytes as base64encode else: from base64 import encodestring as base64encode -from ._logging import * -from ._url import * -from ._socket import* -from ._exceptions import * -from ._ssl_compat import * - __all__ = ["proxy_info", "connect", "read_headers"] + class proxy_info(object): + def __init__(self, **options): self.host = options.get("http_proxy_host", None) if self.host: self.port = options.get("http_proxy_port", 0) - self.auth = options.get("http_proxy_auth", None) + self.auth = options.get("http_proxy_auth", None) self.no_proxy = options.get("http_no_proxy", None) else: self.port = 0 self.auth = None self.no_proxy = None + def connect(url, options, proxy, socket): hostname, port, resource, is_secure = parse_url(url) if socket: return socket, (hostname, port, resource) - addrinfo_list, need_tunnel, auth = _get_addrinfo_list(hostname, port, is_secure, proxy) + addrinfo_list, need_tunnel, auth = _get_addrinfo_list( + hostname, port, is_secure, proxy) if not addrinfo_list: raise WebSocketException( "Host not found.: " + hostname + ":" + str(port)) @@ -82,10 +86,11 @@ def connect(url, options, proxy, socket): def _get_addrinfo_list(hostname, port, is_secure, proxy): - phost, pport, pauth = get_proxy_info(hostname, is_secure, - proxy.host, proxy.port, proxy.auth, proxy.no_proxy) + phost, pport, pauth = get_proxy_info( + hostname, is_secure, proxy.host, proxy.port, proxy.auth, proxy.no_proxy) if not phost: - addrinfo_list = socket.getaddrinfo(hostname, port, 0, 0, socket.SOL_TCP) + addrinfo_list = socket.getaddrinfo( + hostname, port, 0, 0, socket.SOL_TCP) return addrinfo_list, False, None else: pport = pport and pport or 80 @@ -137,14 +142,15 @@ def _wrap_sni_socket(sock, sslopt, hostname, check_hostname): sslopt.get('keyfile', None), sslopt.get('password', None), ) - # see https://github.com/liris/websocket-client/commit/b96a2e8fa765753e82eea531adb19716b52ca3ca#commitcomment-10803153 + # see + # https://github.com/liris/websocket-client/commit/b96a2e8fa765753e82eea531adb19716b52ca3ca#commitcomment-10803153 context.verify_mode = sslopt['cert_reqs'] if HAVE_CONTEXT_CHECK_HOSTNAME: context.check_hostname = check_hostname if 'ciphers' in sslopt: context.set_ciphers(sslopt['ciphers']) - if 'cert_chain' in sslopt : - certfile,keyfile,password = sslopt['cert_chain'] + if 'cert_chain' in sslopt: + certfile, keyfile, password = sslopt['cert_chain'] context.load_cert_chain(certfile, keyfile, password) return context.wrap_socket( @@ -158,12 +164,13 @@ def _wrap_sni_socket(sock, sslopt, hostname, check_hostname): def _ssl_socket(sock, user_sslopt, hostname): sslopt = dict(cert_reqs=ssl.CERT_REQUIRED) sslopt.update(user_sslopt) - + certPath = os.path.join( os.path.dirname(__file__), "cacert.pem") - if os.path.isfile(certPath) and user_sslopt.get('ca_certs', None) == None: + if os.path.isfile(certPath) and user_sslopt.get('ca_certs', None) is None: sslopt['ca_certs'] = certPath - check_hostname = sslopt["cert_reqs"] != ssl.CERT_NONE and sslopt.pop('check_hostname', True) + check_hostname = sslopt["cert_reqs"] != ssl.CERT_NONE and sslopt.pop( + 'check_hostname', True) if _can_use_sni(): sock = _wrap_sni_socket(sock, sslopt, hostname, check_hostname) @@ -176,6 +183,7 @@ def _ssl_socket(sock, user_sslopt, hostname): return sock + def _tunnel(sock, host, port, auth): debug("Connecting proxy...") connect_header = "CONNECT %s:%d HTTP/1.0\r\n" % (host, port) @@ -199,9 +207,10 @@ def _tunnel(sock, host, port, auth): if status != 200: raise WebSocketProxyException( "failed CONNECT via proxy status: %r" % status) - + return sock + def read_headers(sock): status = None headers = {} diff --git a/websocket/_logging.py b/websocket/_logging.py index a77d999..d440bf7 100644 --- a/websocket/_logging.py +++ b/websocket/_logging.py @@ -19,7 +19,6 @@ Copyright (C) 2010 Hiroki Ohtani(liris) Boston, MA 02110-1335 USA """ - import logging _logger = logging.getLogger('websocket') @@ -29,15 +28,15 @@ __all__ = ["enableTrace", "dump", "error", "debug", "trace", "isEnabledForError", "isEnabledForDebug"] -def enableTrace(tracable): +def enableTrace(traceable): """ - turn on/off the tracability. + turn on/off the traceability. - tracable: boolean value. if set True, tracability is enabled. + traceable: boolean value. if set True, traceability is enabled. """ global _traceEnabled - _traceEnabled = tracable - if tracable: + _traceEnabled = traceable + if traceable: if not _logger.handlers: _logger.addHandler(logging.StreamHandler()) _logger.setLevel(logging.DEBUG) diff --git a/websocket/_socket.py b/websocket/_socket.py index b2aaa55..e2e1dd2 100644 --- a/websocket/_socket.py +++ b/websocket/_socket.py @@ -19,13 +19,13 @@ Copyright (C) 2010 Hiroki Ohtani(liris) Boston, MA 02110-1335 USA """ - import socket + import six from ._exceptions import * -from ._utils import * from ._ssl_compat import * +from ._utils import * DEFAULT_SOCKET_OPTION = [(socket.SOL_TCP, socket.TCP_NODELAY, 1)] if hasattr(socket, "SO_KEEPALIVE"): @@ -42,7 +42,9 @@ _default_timeout = None __all__ = ["DEFAULT_SOCKET_OPTION", "sock_opt", "setdefaulttimeout", "getdefaulttimeout", "recv", "recv_line", "send"] + class sock_opt(object): + def __init__(self, sockopt, sslopt): if sockopt is None: sockopt = [] @@ -52,6 +54,7 @@ class sock_opt(object): self.sslopt = sslopt self.timeout = None + def setdefaulttimeout(timeout): """ Set the global timeout setting to connect. @@ -74,7 +77,7 @@ def recv(sock, bufsize): raise WebSocketConnectionClosedException("socket is already closed.") try: - bytes = sock.recv(bufsize) + bytes_ = sock.recv(bufsize) except socket.timeout as e: message = extract_err_message(e) raise WebSocketTimeoutException(message) @@ -85,10 +88,11 @@ def recv(sock, bufsize): else: raise - if not bytes: - raise WebSocketConnectionClosedException("Connection is already closed.") + if not bytes_: + raise WebSocketConnectionClosedException( + "Connection is already closed.") - return bytes + return bytes_ def recv_line(sock): diff --git a/websocket/_ssl_compat.py b/websocket/_ssl_compat.py index d41ca79..0304816 100644 --- a/websocket/_ssl_compat.py +++ b/websocket/_ssl_compat.py @@ -19,7 +19,6 @@ Copyright (C) 2010 Hiroki Ohtani(liris) Boston, MA 02110-1335 USA """ - __all__ = ["HAVE_SSL", "ssl", "SSLError"] try: diff --git a/websocket/_url.py b/websocket/_url.py index dfd2c4b..11216b6 100644 --- a/websocket/_url.py +++ b/websocket/_url.py @@ -19,9 +19,9 @@ Copyright (C) 2010 Hiroki Ohtani(liris) Boston, MA 02110-1335 USA """ +import os from six.moves.urllib.parse import urlparse -import os __all__ = ["parse_url", "get_proxy_info"] @@ -66,7 +66,7 @@ def parse_url(url): if parsed.query: resource += "?" + parsed.query - return (hostname, port, resource, is_secure) + return hostname, port, resource, is_secure DEFAULT_NO_PROXY_HOST = ["localhost", "127.0.0.1"] @@ -82,8 +82,9 @@ def _is_no_proxy_host(hostname, no_proxy): return hostname in no_proxy -def get_proxy_info(hostname, is_secure, - proxy_host=None, proxy_port=0, proxy_auth=None, no_proxy=None): +def get_proxy_info( + hostname, is_secure, proxy_host=None, proxy_port=0, proxy_auth=None, + no_proxy=None): """ try to retrieve proxy host and port from environment if not provided in options. diff --git a/websocket/_utils.py b/websocket/_utils.py index 58f6950..399fb89 100644 --- a/websocket/_utils.py +++ b/websocket/_utils.py @@ -19,16 +19,17 @@ Copyright (C) 2010 Hiroki Ohtani(liris) Boston, MA 02110-1335 USA """ - import six __all__ = ["NoLock", "validate_utf8", "extract_err_message"] + class NoLock(object): + def __enter__(self): pass - def __exit__(self, type, value, traceback): + def __exit__(self, exc_type, exc_value, traceback): pass try: @@ -69,10 +70,11 @@ except ImportError: def _decode(state, codep, ch): tp = _UTF8D[ch] - codep = (ch & 0x3f ) | (codep << 6) if (state != _UTF8_ACCEPT) else (0xff >> tp) & (ch) + codep = (ch & 0x3f) | (codep << 6) if ( + state != _UTF8_ACCEPT) else (0xff >> tp) & ch state = _UTF8D[256 + state + tp] - return state, codep; + return state, codep def _validate_utf8(utfbytes): state = _UTF8_ACCEPT @@ -86,6 +88,7 @@ except ImportError: return True + def validate_utf8(utfbytes): """ validate utf8 byte string. @@ -94,6 +97,7 @@ def validate_utf8(utfbytes): """ return _validate_utf8(utfbytes) + def extract_err_message(exception): if exception.args: return exception.args[0] diff --git a/websocket/tests/test_websocket.py b/websocket/tests/test_websocket.py index 4573bf7..d2170ae 100644 --- a/websocket/tests/test_websocket.py +++ b/websocket/tests/test_websocket.py @@ -1,14 +1,33 @@ # -*- coding: utf-8 -*- # -import six import sys sys.path[0:0] = [""] import os import os.path -import base64 import socket + +import six + +# websocket-client +import websocket as ws +from websocket._handshake import _create_sec_websocket_key, \ + _validate as _validate_header +from websocket._http import read_headers +from websocket._url import get_proxy_info, parse_url +from websocket._utils import validate_utf8 + +if six.PY3: + from base64 import decodebytes as base64decode +else: + from base64 import decodestring as base64decode + +if sys.version_info[0] == 2 and sys.version_info[1] < 7: + import unittest2 as unittest +else: + import unittest + try: from ssl import SSLError except ImportError: @@ -16,37 +35,15 @@ except ImportError: class SSLError(Exception): pass -if sys.version_info[0] == 2 and sys.version_info[1] < 7: - import unittest2 as unittest -else: - import unittest - -import uuid - -if six.PY3: - from base64 import decodebytes as base64decode -else: - from base64 import decodestring as base64decode - - -# websocket-client -import websocket as ws -from websocket._handshake import _create_sec_websocket_key -from websocket._url import parse_url, get_proxy_info -from websocket._utils import validate_utf8 -from websocket._handshake import _validate as _validate_header -from websocket._http import read_headers - - # Skip test to access the internet. TEST_WITH_INTERNET = os.environ.get('TEST_WITH_INTERNET', '0') == '1' # Skip Secure WebSocket test. TEST_SECURE_WS = True -TRACABLE = False +TRACEABLE = False -def create_mask_key(n): +def create_mask_key(_): return "abcd" @@ -86,7 +83,7 @@ class HeaderSockMock(SockMock): class WebSocketTest(unittest.TestCase): def setUp(self): - ws.enableTrace(TRACABLE) + ws.enableTrace(TRACEABLE) def tearDown(self): pass @@ -263,7 +260,7 @@ class WebSocketTest(unittest.TestCase): @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") def testIter(self): count = 2 - for rsvp in ws.create_connection('ws://stream.meetup.com/2/rsvps'): + for _ in ws.create_connection('ws://stream.meetup.com/2/rsvps'): count -= 1 if count == 0: break @@ -282,7 +279,7 @@ class WebSocketTest(unittest.TestCase): # s.add_packet(SSLError("The read operation timed out")) s.add_packet(six.b("baz")) with self.assertRaises(ws.WebSocketTimeoutException): - data = sock.frame_buffer.recv_strict(9) + sock.frame_buffer.recv_strict(9) # if six.PY2: # with self.assertRaises(ws.WebSocketTimeoutException): # data = sock._recv_strict(9) @@ -292,7 +289,7 @@ class WebSocketTest(unittest.TestCase): data = sock.frame_buffer.recv_strict(9) self.assertEqual(data, six.b("foobarbaz")) with self.assertRaises(ws.WebSocketConnectionClosedException): - data = sock.frame_buffer.recv_strict(1) + sock.frame_buffer.recv_strict(1) def testRecvTimeout(self): sock = ws.WebSocket() @@ -303,13 +300,13 @@ class WebSocketTest(unittest.TestCase): s.add_packet(socket.timeout()) s.add_packet(six.b("\x4e\x43\x33\x0e\x10\x0f\x00\x40")) with self.assertRaises(ws.WebSocketTimeoutException): - data = sock.recv() + sock.recv() with self.assertRaises(ws.WebSocketTimeoutException): - data = sock.recv() + sock.recv() data = sock.recv() self.assertEqual(data, "Hello, World!") with self.assertRaises(ws.WebSocketConnectionClosedException): - data = sock.recv() + sock.recv() def testRecvWithSimpleFragmentation(self): sock = ws.WebSocket() @@ -374,10 +371,10 @@ class WebSocketTest(unittest.TestCase): sock = ws.WebSocket() s = sock.sock = SockMock() # OPCODE=TEXT, FIN=0, MSG="Once more unto the breach, " - s.add_packet(six.b("\x01\x9babcd.\x0c\x00\x01A\x0f\x0c\x16\x04B\x16\n\x15" \ + s.add_packet(six.b("\x01\x9babcd.\x0c\x00\x01A\x0f\x0c\x16\x04B\x16\n\x15" "\rC\x10\t\x07C\x06\x13\x07\x02\x07\tNC")) # OPCODE=CONT, FIN=0, MSG="dear friends, " - s.add_packet(six.b("\x00\x8eabcd\x05\x07\x02\x16A\x04\x11\r\x04\x0c\x07" \ + s.add_packet(six.b("\x00\x8eabcd\x05\x07\x02\x16A\x04\x11\r\x04\x0c\x07" "\x17MB")) # OPCODE=CONT, FIN=1, MSG="once more" s.add_packet(six.b("\x80\x89abcd\x0e\x0c\x00\x01A\x0f\x0c\x16\x04")) @@ -397,7 +394,7 @@ class WebSocketTest(unittest.TestCase): # OPCODE=PING, FIN=1, MSG="Please PONG this" s.add_packet(six.b("\x89\x90abcd1\x0e\x06\x05\x12\x07C4.,$D\x15\n\n\x17")) # OPCODE=CONT, FIN=1, MSG="of a good thing" - s.add_packet(six.b("\x80\x8fabcd\x0e\x04C\x05A\x05\x0c\x0b\x05B\x17\x0c" \ + s.add_packet(six.b("\x80\x8fabcd\x0e\x04C\x05A\x05\x0c\x0b\x05B\x17\x0c" "\x08\x0c\x04")) data = sock.recv() self.assertEqual(data, "Too much of a good thing") @@ -479,7 +476,7 @@ class WebSocketAppTest(unittest.TestCase): """ def setUp(self): - ws.enableTrace(TRACABLE) + ws.enableTrace(TRACEABLE) WebSocketAppTest.keep_running_open = WebSocketAppTest.NotSetYet() WebSocketAppTest.keep_running_close = WebSocketAppTest.NotSetYet()