This commit is contained in:
liris
2016-05-10 08:52:30 +09:00
18 changed files with 276 additions and 219 deletions

View File

@@ -2,7 +2,7 @@ ChangeLog
============
- 0.37.0
- fixed fialer that `websocket.create_connection` does not accept `origin` as a parameter (#246 )
- fixed failure that `websocket.create_connection` does not accept `origin` as a parameter (#246 )
- 0.36.0
- added support for using custom connection class (#235)
@@ -90,7 +90,7 @@ ChangeLog
- 0.24.0
- Supporting http-basic auth in WebSocketApp (#143)
- fix failer of test.testInternalRecvStrict(#141)
- fix failure of test.testInternalRecvStrict(#141)
- skip utf8 validation by skip_utf8_validation argument (#137)
- WebsocketProxyException will be raised if we got error about proxy.(#138)

View File

@@ -2,15 +2,18 @@
import argparse
import code
import six
import sys
import threading
import time
import websocket
import six
from six.moves.urllib.parse import urlparse
import websocket
try:
import readline
except:
except ImportError:
pass
@@ -27,15 +30,17 @@ ENCODING = get_encoding()
class VAction(argparse.Action):
def __call__(self, parser, args, values, option_string=None):
if values==None:
if values is None:
values = "1"
try:
values = int(values)
except ValueError:
values = values.count("v")+1
values = values.count("v") + 1
setattr(args, self.dest, values)
def parse_args():
parser = argparse.ArgumentParser(description="WebSocket Simple Dump Tool")
parser.add_argument("url", metavar="ws_url",
@@ -63,7 +68,9 @@ def parse_args():
return parser.parse_args()
class RawInput():
class RawInput:
def raw_input(self, prompt):
if six.PY3:
line = input(prompt)
@@ -77,7 +84,9 @@ class RawInput():
return line
class InteractiveConsole(RawInput, code.InteractiveConsole):
def write(self, data):
sys.stdout.write("\033[2K\033[E")
# sys.stdout.write("\n")
@@ -88,7 +97,9 @@ class InteractiveConsole(RawInput, code.InteractiveConsole):
def read(self):
return self.raw_input("> ")
class NonInteractive(RawInput):
def write(self, data):
sys.stdout.write(data)
sys.stdout.write("\n")
@@ -97,23 +108,24 @@ class NonInteractive(RawInput):
def read(self):
return self.raw_input("")
def main():
start_time = time.time()
args = parse_args()
if args.verbose > 1:
websocket.enableTrace(True)
options = {}
if (args.proxy):
if args.proxy:
p = urlparse(args.proxy)
options["http_proxy_host"] = p.hostname
options["http_proxy_port"] = p.port
if (args.origin):
if args.origin:
options["origin"] = args.origin
if (args.subprotocols):
if args.subprotocols:
options["subprotocols"] = args.subprotocols
opts = {}
if (args.nocert):
opts = { "cert_reqs": websocket.ssl.CERT_NONE, "check_hostname": False }
if args.nocert:
opts = {"cert_reqs": websocket.ssl.CERT_NONE, "check_hostname": False}
ws = websocket.create_connection(args.url, sslopt=opts, **options)
if args.raw:
console = NonInteractive()
@@ -125,21 +137,20 @@ def main():
try:
frame = ws.recv_frame()
except websocket.WebSocketException:
return (websocket.ABNF.OPCODE_CLOSE, None)
return websocket.ABNF.OPCODE_CLOSE, None
if not frame:
raise websocket.WebSocketException("Not a valid frame %s" % frame)
elif frame.opcode in OPCODE_DATA:
return (frame.opcode, frame.data)
return frame.opcode, frame.data
elif frame.opcode == websocket.ABNF.OPCODE_CLOSE:
ws.send_close()
return (frame.opcode, None)
return frame.opcode, None
elif frame.opcode == websocket.ABNF.OPCODE_PING:
ws.pong(frame.data)
return frame.opcode, frame.data
return frame.opcode, frame.data
def recv_ws():
while True:
opcode, data = recv()
@@ -152,7 +163,7 @@ def main():
msg = "%s: %s" % (websocket.ABNF.OPCODE_MAP.get(opcode), data)
if msg is not None:
if (args.timings):
if args.timings:
console.write(str(time.time() - start_time) + ": " + msg)
else:
console.write(msg)

View File

@@ -1,13 +1,9 @@
#!/usr/bin/env python
import websocket
import json
import traceback
import six
import websocket
SERVER = 'ws://127.0.0.1:8642'
AGENT = 'py-websockets-client'
@@ -19,28 +15,28 @@ ws.close()
for case in range(1, count+1):
url = SERVER + '/runCase?case={0}&agent={1}'.format(case, AGENT)
status = websocket.STATUS_NORMAL
try:
ws = websocket.create_connection(url)
while True:
opcode, msg = ws.recv_data()
if opcode == websocket.ABNF.OPCODE_TEXT:
msg.decode("utf-8")
if opcode in (websocket.ABNF.OPCODE_TEXT, websocket.ABNF.OPCODE_BINARY):
ws.send(msg, opcode)
except UnicodeDecodeError:
# this case is ok.
status = websocket.STATUS_PROTOCOL_ERROR
except websocket.WebSocketProtocolException:
status = websocket.STATUS_PROTOCOL_ERROR
except websocket.WebSocketPayloadException:
status = websocket.STATUS_INVALID_PAYLOAD
except Exception as e:
# status = websocket.STATUS_PROTOCOL_ERROR
print(traceback.format_exc())
finally:
ws.close(status)
url = SERVER + '/runCase?case={0}&agent={1}'.format(case, AGENT)
status = websocket.STATUS_NORMAL
try:
ws = websocket.create_connection(url)
while True:
opcode, msg = ws.recv_data()
if opcode == websocket.ABNF.OPCODE_TEXT:
msg.decode("utf-8")
if opcode in (websocket.ABNF.OPCODE_TEXT, websocket.ABNF.OPCODE_BINARY):
ws.send(msg, opcode)
except UnicodeDecodeError:
# this case is ok.
status = websocket.STATUS_PROTOCOL_ERROR
except websocket.WebSocketProtocolException:
status = websocket.STATUS_PROTOCOL_ERROR
except websocket.WebSocketPayloadException:
status = websocket.STATUS_INVALID_PAYLOAD
except Exception as e:
# status = websocket.STATUS_PROTOCOL_ERROR
print(traceback.format_exc())
finally:
ws.close(status)
print("Ran {} test cases.".format(case))
url = SERVER + '/updateReports?agent={0}'.format(AGENT)

View File

@@ -1,7 +1,7 @@
import websocket
try:
import thread
except ImportError: #TODO use Threading instead of _thread in python3
except ImportError: # TODO use Threading instead of _thread in python3
import _thread as thread
import time
import sys
@@ -41,8 +41,8 @@ if __name__ == "__main__":
else:
host = sys.argv[1]
ws = websocket.WebSocketApp(host,
on_message = on_message,
on_error = on_error,
on_close = on_close)
on_message=on_message,
on_error=on_error,
on_close=on_close)
ws.on_open = on_open
ws.run_forever()

View File

@@ -2,12 +2,12 @@ from setuptools import setup
import sys
VERSION = "0.37.0"
NAME="websocket_client"
NAME = "websocket_client"
install_requires = ["six"]
tests_require = []
if sys.version_info[0] == 2:
if sys.version_info[1] < 7 or (sys.version_info[1] == 7 and sys.version_info[2]< 9):
if sys.version_info[1] < 7 or (sys.version_info[1] == 7 and sys.version_info[2] < 9):
install_requires.append('backports.ssl_match_hostname')
if sys.version_info[1] < 7:
tests_require.append('unittest2==0.8.0')

View File

@@ -19,7 +19,11 @@ Copyright (C) 2010 Hiroki Ohtani(liris)
Boston, MA 02110-1335 USA
"""
from ._core import *
from ._abnf import *
from ._app import WebSocketApp
from ._core import *
from ._exceptions import *
from ._logging import *
from ._socket import *
__version__ = "0.37.0"

View File

@@ -19,10 +19,12 @@ Copyright (C) 2010 Hiroki Ohtani(liris)
Boston, MA 02110-1335 USA
"""
import six
import array
import struct
import os
import struct
import six
from ._exceptions import *
from ._utils import validate_utf8
@@ -44,6 +46,22 @@ except ImportError:
else:
return _d.tostring()
__all__ = [
'ABNF', 'continuous_frame', 'frame_buffer',
'STATUS_NORMAL',
'STATUS_GOING_AWAY',
'STATUS_PROTOCOL_ERROR',
'STATUS_UNSUPPORTED_DATA_TYPE',
'STATUS_STATUS_NOT_AVAILABLE',
'STATUS_ABNORMAL_CLOSED',
'STATUS_INVALID_PAYLOAD',
'STATUS_POLICY_VIOLATION',
'STATUS_MESSAGE_TOO_BIG',
'STATUS_INVALID_EXTENSION',
'STATUS_UNEXPECTED_CONDITION',
'STATUS_TLS_HANDSHAKE_ERROR',
]
# closing frame status codes.
STATUS_NORMAL = 1000
STATUS_GOING_AWAY = 1001
@@ -68,7 +86,8 @@ VALID_CLOSE_STATUS = (
STATUS_MESSAGE_TOO_BIG,
STATUS_INVALID_EXTENSION,
STATUS_UNEXPECTED_CONDITION,
)
)
class ABNF(object):
"""
@@ -78,16 +97,16 @@ class ABNF(object):
"""
# operation code values.
OPCODE_CONT = 0x0
OPCODE_TEXT = 0x1
OPCODE_CONT = 0x0
OPCODE_TEXT = 0x1
OPCODE_BINARY = 0x2
OPCODE_CLOSE = 0x8
OPCODE_PING = 0x9
OPCODE_PONG = 0xa
OPCODE_CLOSE = 0x8
OPCODE_PING = 0x9
OPCODE_PONG = 0xa
# available operation code value tuple
OPCODES = (OPCODE_CONT, OPCODE_TEXT, OPCODE_BINARY, OPCODE_CLOSE,
OPCODE_PING, OPCODE_PONG)
OPCODE_PING, OPCODE_PONG)
# opcode human readable string
OPCODE_MAP = {
@@ -97,10 +116,10 @@ class ABNF(object):
OPCODE_CLOSE: "close",
OPCODE_PING: "ping",
OPCODE_PONG: "pong"
}
}
# data length threshold.
LENGTH_7 = 0x7e
LENGTH_7 = 0x7e
LENGTH_16 = 1 << 16
LENGTH_63 = 1 << 63
@@ -116,7 +135,7 @@ class ABNF(object):
self.rsv3 = rsv3
self.opcode = opcode
self.mask = mask
if data == None:
if data is None:
data = ""
self.data = data
self.get_mask_key = os.urandom
@@ -144,17 +163,19 @@ class ABNF(object):
if l > 2 and not skip_utf8_validation and not validate_utf8(self.data[2:]):
raise WebSocketProtocolException("Invalid close frame.")
code = 256*six.byte2int(self.data[0:1]) + six.byte2int(self.data[1:2])
code = 256 * \
six.byte2int(self.data[0:1]) + six.byte2int(self.data[1:2])
if not self._is_valid_close_status(code):
raise WebSocketProtocolException("Invalid close opcode.")
def _is_valid_close_status(self, code):
return code in VALID_CLOSE_STATUS or (3000 <= code <5000)
@staticmethod
def _is_valid_close_status(code):
return code in VALID_CLOSE_STATUS or (3000 <= code < 5000)
def __str__(self):
return "fin=" + str(self.fin) \
+ " opcode=" + str(self.opcode) \
+ " data=" + str(self.data)
+ " opcode=" + str(self.opcode) \
+ " data=" + str(self.data)
@staticmethod
def create_frame(data, opcode, fin=1):
@@ -224,7 +245,7 @@ class ABNF(object):
data: data to mask/unmask.
"""
if data == None:
if data is None:
data = ""
if isinstance(mask_key, six.text_type):
@@ -237,9 +258,10 @@ class ABNF(object):
_d = array.array("B", data)
return _mask(_m, _d)
class frame_buffer(object):
_HEADER_MASK_INDEX = 5
_HEADER_LENGHT_INDEX = 6
_HEADER_LENGTH_INDEX = 6
def __init__(self, recv_fn, skip_utf8_validation):
self.recv = recv_fn
@@ -255,7 +277,7 @@ class frame_buffer(object):
self.mask = None
def has_received_header(self):
return self.header is None
return self.header is None
def recv_header(self):
header = self.recv_strict(2)
@@ -284,12 +306,11 @@ class frame_buffer(object):
return False
return self.header[frame_buffer._HEADER_MASK_INDEX]
def has_received_length(self):
return self.length is None
def recv_length(self):
bits = self.header[frame_buffer._HEADER_LENGHT_INDEX]
bits = self.header[frame_buffer._HEADER_LENGTH_INDEX]
length_bits = bits & 0x7f
if length_bits == 0x7e:
v = self.recv_strict(2)
@@ -342,10 +363,11 @@ class frame_buffer(object):
# fragmenting the heap -- the number of bytes recv() actually
# reads is limited by socket buffer and is relatively small,
# yet passing large numbers repeatedly causes lots of large
# buffers allocated and then shrunk, which results in fragmentation.
bytes = self.recv(min(16384, shortage))
self.recv_buffer.append(bytes)
shortage -= len(bytes)
# buffers allocated and then shrunk, which results in
# fragmentation.
bytes_ = self.recv(min(16384, shortage))
self.recv_buffer.append(bytes_)
shortage -= len(bytes_)
unified = six.b("").join(self.recv_buffer)
@@ -358,6 +380,7 @@ class frame_buffer(object):
class continuous_frame(object):
def __init__(self, fire_cont_frame, skip_utf8_validation):
self.fire_cont_frame = fire_cont_frame
self.skip_utf8_validation = skip_utf8_validation
@@ -367,7 +390,8 @@ class continuous_frame(object):
def validate(self, frame):
if not self.recving_frames and frame.opcode == ABNF.OPCODE_CONT:
raise WebSocketProtocolException("Illegal frame")
if self.recving_frames and frame.opcode in (ABNF.OPCODE_TEXT, ABNF.OPCODE_BINARY):
if self.recving_frames and \
frame.opcode in (ABNF.OPCODE_TEXT, ABNF.OPCODE_BINARY):
raise WebSocketProtocolException("Illegal frame")
def add(self, frame):
@@ -389,6 +413,7 @@ class continuous_frame(object):
self.cont_data = None
frame.data = data[1]
if not self.fire_cont_frame and data[0] == ABNF.OPCODE_TEXT and not self.skip_utf8_validation and not validate_utf8(frame.data):
raise WebSocketPayloadException("cannot decode: " + repr(frame.data))
raise WebSocketPayloadException(
"cannot decode: " + repr(frame.data))
return [data[0], frame]

View File

@@ -23,17 +23,18 @@ Copyright (C) 2010 Hiroki Ohtani(liris)
"""
WebSocketApp provides higher level APIs.
"""
import select
import sys
import threading
import time
import traceback
import sys
import select
import six
from ._abnf import ABNF
from ._core import WebSocket, getdefaulttimeout
from ._exceptions import *
from ._logging import *
from ._abnf import ABNF
__all__ = ["WebSocketApp"]
@@ -43,7 +44,8 @@ class WebSocketApp(object):
Higher level of APIs are provided.
The interface is like JavaScript WebSocket object.
"""
def __init__(self, url, header=[],
def __init__(self, url, header=None,
on_open=None, on_message=None, on_error=None,
on_close=None, on_ping=None, on_pong=None,
on_cont_message=None,
@@ -87,7 +89,7 @@ class WebSocketApp(object):
subprotocols: array of available sub protocols. default is None.
"""
self.url = url
self.header = header
self.header = header if header is not None else []
self.cookie = cookie
self.on_open = on_open
self.on_message = on_message
@@ -113,7 +115,8 @@ class WebSocketApp(object):
"""
if not self.sock or self.sock.send(data, opcode) == 0:
raise WebSocketConnectionClosedException("Connection is already closed.")
raise WebSocketConnectionClosedException(
"Connection is already closed.")
def close(self):
"""
@@ -168,27 +171,29 @@ class WebSocketApp(object):
close_frame = None
try:
self.sock = WebSocket(self.get_mask_key,
sockopt=sockopt, sslopt=sslopt,
self.sock = WebSocket(
self.get_mask_key, sockopt=sockopt, sslopt=sslopt,
fire_cont_frame=self.on_cont_message and True or False,
skip_utf8_validation=skip_utf8_validation)
self.sock.settimeout(getdefaulttimeout())
self.sock.connect(self.url, header=self.header, cookie=self.cookie,
self.sock.connect(
self.url, header=self.header, cookie=self.cookie,
http_proxy_host=http_proxy_host,
http_proxy_port=http_proxy_port,
http_no_proxy=http_no_proxy, http_proxy_auth=http_proxy_auth,
subprotocols=self.subprotocols,
http_proxy_port=http_proxy_port, http_no_proxy=http_no_proxy,
http_proxy_auth=http_proxy_auth, subprotocols=self.subprotocols,
host=host, origin=origin)
self._callback(self.on_open)
if ping_interval:
event = threading.Event()
thread = threading.Thread(target=self._send_ping, args=(ping_interval, event))
thread = threading.Thread(
target=self._send_ping, args=(ping_interval, event))
thread.setDaemon(True)
thread.start()
while self.sock.connected:
r, w, e = select.select((self.sock.sock, ), (), (), ping_timeout)
r, w, e = select.select(
(self.sock.sock, ), (), (), ping_timeout)
if not self.keep_running:
break
@@ -203,8 +208,10 @@ class WebSocketApp(object):
self.last_pong_tm = time.time()
self._callback(self.on_pong, frame.data)
elif op_code == ABNF.OPCODE_CONT and self.on_cont_message:
self._callback(self.on_data, data, frame.opcode, frame.fin)
self._callback(self.on_cont_message, frame.data, frame.fin)
self._callback(self.on_data, data,
frame.opcode, frame.fin)
self._callback(self.on_cont_message,
frame.data, frame.fin)
else:
data = frame.data
if six.PY3 and frame.opcode == ABNF.OPCODE_TEXT:
@@ -227,8 +234,9 @@ class WebSocketApp(object):
thread.join()
self.keep_running = False
self.sock.close()
self._callback(self.on_close,
*self._get_close_args(close_frame.data if close_frame else None))
close_args = self._get_close_args(
close_frame.data if close_frame else None)
self._callback(self.on_close, *close_args)
self.sock = None
def _get_close_args(self, data):
@@ -244,7 +252,7 @@ class WebSocketApp(object):
return []
if data and len(data) >= 2:
code = 256*six.byte2int(data[0:1]) + six.byte2int(data[1:2])
code = 256 * six.byte2int(data[0:1]) + six.byte2int(data[1:2])
reason = data[2:].decode('utf-8')
return [code, reason]

View File

@@ -21,28 +21,22 @@ Copyright (C) 2010 Hiroki Ohtani(liris)
"""
from __future__ import print_function
import six
import socket
if six.PY3:
from base64 import encodebytes as base64encode
else:
from base64 import encodestring as base64encode
import struct
import threading
import six
# websocket modules
from ._exceptions import *
from ._abnf import *
from ._exceptions import *
from ._handshake import *
from ._http import *
from ._logging import *
from ._socket import *
from ._utils import *
from ._url import *
from ._logging import *
from ._http import *
from ._handshake import *
from ._ssl_compat import *
__all__ = ['WebSocket', 'create_connection']
"""
websocket python client.
@@ -83,7 +77,7 @@ class WebSocket(object):
def __init__(self, get_mask_key=None, sockopt=None, sslopt=None,
fire_cont_frame=False, enable_multithread=False,
skip_utf8_validation=False, **options):
skip_utf8_validation=False, **_):
"""
Initialize WebSocket object.
"""
@@ -95,7 +89,8 @@ class WebSocket(object):
self.get_mask_key = get_mask_key
# These buffer over the build-up of a single frame.
self.frame_buffer = frame_buffer(self._recv, skip_utf8_validation)
self.cont_frame = continuous_frame(fire_cont_frame, skip_utf8_validation)
self.cont_frame = continuous_frame(
fire_cont_frame, skip_utf8_validation)
if enable_multithread:
self.lock = threading.Lock()
@@ -329,7 +324,8 @@ class WebSocket(object):
if not frame:
# handle error:
# 'NoneType' object has no attribute 'opcode'
raise WebSocketProtocolException("Not a valid frame %s" % frame)
raise WebSocketProtocolException(
"Not a valid frame %s" % frame)
elif frame.opcode in (ABNF.OPCODE_TEXT, ABNF.OPCODE_BINARY, ABNF.OPCODE_CONT):
self.cont_frame.validate(frame)
self.cont_frame.add(frame)
@@ -339,17 +335,18 @@ class WebSocket(object):
elif frame.opcode == ABNF.OPCODE_CLOSE:
self.send_close()
return (frame.opcode, frame)
return frame.opcode, frame
elif frame.opcode == ABNF.OPCODE_PING:
if len(frame.data) < 126:
self.pong(frame.data)
else:
raise WebSocketProtocolException("Ping message is too long")
raise WebSocketProtocolException(
"Ping message is too long")
if control_frame:
return (frame.opcode, frame)
return frame.opcode, frame
elif frame.opcode == ABNF.OPCODE_PONG:
if control_frame:
return (frame.opcode, frame)
return frame.opcode, frame
def recv_frame(self):
"""
@@ -389,7 +386,8 @@ class WebSocket(object):
try:
self.connected = False
self.send(struct.pack('!H', status) + reason, ABNF.OPCODE_CLOSE)
self.send(struct.pack('!H', status) +
reason, ABNF.OPCODE_CLOSE)
sock_timeout = self.sock.gettimeout()
self.sock.settimeout(timeout)
try:
@@ -415,7 +413,7 @@ class WebSocket(object):
self.sock.shutdown(socket.SHUT_RDWR)
def shutdown(self):
"close socket, immediately."
"""close socket, immediately."""
if self.sock:
self.sock.close()
self.sock = None

View File

@@ -25,6 +25,7 @@ Copyright (C) 2010 Hiroki Ohtani(liris)
define websocket exceptions
"""
class WebSocketException(Exception):
"""
websocket exception class.
@@ -72,6 +73,8 @@ class WebSocketBadStatusException(WebSocketException):
"""
WebSocketBadStatusException will be raised when we get bad handshake status code.
"""
def __init__(self, message, status_code):
super(WebSocketBadStatusException, self).__init__(message % status_code)
super(WebSocketBadStatusException, self).__init__(
message % status_code)
self.status_code = status_code

View File

@@ -19,25 +19,22 @@ Copyright (C) 2010 Hiroki Ohtani(liris)
Boston, MA 02110-1335 USA
"""
import hashlib
import hmac
import os
import six
from ._exceptions import *
from ._http import *
from ._logging import *
from ._socket import *
if six.PY3:
from base64 import encodebytes as base64encode
else:
from base64 import encodestring as base64encode
import uuid
import hashlib
import hmac
import os
import sys
from ._logging import *
from ._url import *
from ._socket import*
from ._http import *
from ._exceptions import *
__all__ = ["handshake_response", "handshake"]
if hasattr(hmac, "compare_digest"):
@@ -51,6 +48,7 @@ VERSION = 13
class handshake_response(object):
def __init__(self, status, headers, subprotocol):
self.status = status
self.headers = headers
@@ -73,10 +71,11 @@ def handshake(sock, hostname, port, resource, **options):
def _get_handshake_headers(resource, host, port, options):
headers = []
headers.append("GET %s HTTP/1.1" % resource)
headers.append("Upgrade: websocket")
headers.append("Connection: Upgrade")
headers = [
"GET %s HTTP/1.1" % resource,
"Upgrade: websocket",
"Connection: Upgrade"
]
if port == 80 or port == 443:
hostport = host
else:
@@ -126,7 +125,7 @@ def _get_resp_headers(sock, success_status=101):
_HEADERS_TO_CHECK = {
"upgrade": "websocket",
"connection": "upgrade",
}
}
def _validate(headers, key, subprotocols):

View File

@@ -19,45 +19,49 @@ Copyright (C) 2010 Hiroki Ohtani(liris)
Boston, MA 02110-1335 USA
"""
import six
import socket
import errno
import os
import socket
import sys
import six
from ._exceptions import *
from ._logging import *
from ._socket import*
from ._ssl_compat import *
from ._url import *
if six.PY3:
from base64 import encodebytes as base64encode
else:
from base64 import encodestring as base64encode
from ._logging import *
from ._url import *
from ._socket import*
from ._exceptions import *
from ._ssl_compat import *
__all__ = ["proxy_info", "connect", "read_headers"]
class proxy_info(object):
def __init__(self, **options):
self.host = options.get("http_proxy_host", None)
if self.host:
self.port = options.get("http_proxy_port", 0)
self.auth = options.get("http_proxy_auth", None)
self.auth = options.get("http_proxy_auth", None)
self.no_proxy = options.get("http_no_proxy", None)
else:
self.port = 0
self.auth = None
self.no_proxy = None
def connect(url, options, proxy, socket):
hostname, port, resource, is_secure = parse_url(url)
if socket:
return socket, (hostname, port, resource)
addrinfo_list, need_tunnel, auth = _get_addrinfo_list(hostname, port, is_secure, proxy)
addrinfo_list, need_tunnel, auth = _get_addrinfo_list(
hostname, port, is_secure, proxy)
if not addrinfo_list:
raise WebSocketException(
"Host not found.: " + hostname + ":" + str(port))
@@ -82,10 +86,11 @@ def connect(url, options, proxy, socket):
def _get_addrinfo_list(hostname, port, is_secure, proxy):
phost, pport, pauth = get_proxy_info(hostname, is_secure,
proxy.host, proxy.port, proxy.auth, proxy.no_proxy)
phost, pport, pauth = get_proxy_info(
hostname, is_secure, proxy.host, proxy.port, proxy.auth, proxy.no_proxy)
if not phost:
addrinfo_list = socket.getaddrinfo(hostname, port, 0, 0, socket.SOL_TCP)
addrinfo_list = socket.getaddrinfo(
hostname, port, 0, 0, socket.SOL_TCP)
return addrinfo_list, False, None
else:
pport = pport and pport or 80
@@ -137,14 +142,15 @@ def _wrap_sni_socket(sock, sslopt, hostname, check_hostname):
sslopt.get('keyfile', None),
sslopt.get('password', None),
)
# see https://github.com/liris/websocket-client/commit/b96a2e8fa765753e82eea531adb19716b52ca3ca#commitcomment-10803153
# see
# https://github.com/liris/websocket-client/commit/b96a2e8fa765753e82eea531adb19716b52ca3ca#commitcomment-10803153
context.verify_mode = sslopt['cert_reqs']
if HAVE_CONTEXT_CHECK_HOSTNAME:
context.check_hostname = check_hostname
if 'ciphers' in sslopt:
context.set_ciphers(sslopt['ciphers'])
if 'cert_chain' in sslopt :
certfile,keyfile,password = sslopt['cert_chain']
if 'cert_chain' in sslopt:
certfile, keyfile, password = sslopt['cert_chain']
context.load_cert_chain(certfile, keyfile, password)
return context.wrap_socket(
@@ -158,12 +164,13 @@ def _wrap_sni_socket(sock, sslopt, hostname, check_hostname):
def _ssl_socket(sock, user_sslopt, hostname):
sslopt = dict(cert_reqs=ssl.CERT_REQUIRED)
sslopt.update(user_sslopt)
certPath = os.path.join(
os.path.dirname(__file__), "cacert.pem")
if os.path.isfile(certPath) and user_sslopt.get('ca_certs', None) == None:
if os.path.isfile(certPath) and user_sslopt.get('ca_certs', None) is None:
sslopt['ca_certs'] = certPath
check_hostname = sslopt["cert_reqs"] != ssl.CERT_NONE and sslopt.pop('check_hostname', True)
check_hostname = sslopt["cert_reqs"] != ssl.CERT_NONE and sslopt.pop(
'check_hostname', True)
if _can_use_sni():
sock = _wrap_sni_socket(sock, sslopt, hostname, check_hostname)
@@ -176,6 +183,7 @@ def _ssl_socket(sock, user_sslopt, hostname):
return sock
def _tunnel(sock, host, port, auth):
debug("Connecting proxy...")
connect_header = "CONNECT %s:%d HTTP/1.0\r\n" % (host, port)
@@ -199,9 +207,10 @@ def _tunnel(sock, host, port, auth):
if status != 200:
raise WebSocketProxyException(
"failed CONNECT via proxy status: %r" % status)
return sock
def read_headers(sock):
status = None
headers = {}

View File

@@ -19,7 +19,6 @@ Copyright (C) 2010 Hiroki Ohtani(liris)
Boston, MA 02110-1335 USA
"""
import logging
_logger = logging.getLogger('websocket')
@@ -29,15 +28,15 @@ __all__ = ["enableTrace", "dump", "error", "debug", "trace",
"isEnabledForError", "isEnabledForDebug"]
def enableTrace(tracable):
def enableTrace(traceable):
"""
turn on/off the tracability.
turn on/off the traceability.
tracable: boolean value. if set True, tracability is enabled.
traceable: boolean value. if set True, traceability is enabled.
"""
global _traceEnabled
_traceEnabled = tracable
if tracable:
_traceEnabled = traceable
if traceable:
if not _logger.handlers:
_logger.addHandler(logging.StreamHandler())
_logger.setLevel(logging.DEBUG)

View File

@@ -19,13 +19,13 @@ Copyright (C) 2010 Hiroki Ohtani(liris)
Boston, MA 02110-1335 USA
"""
import socket
import six
from ._exceptions import *
from ._utils import *
from ._ssl_compat import *
from ._utils import *
DEFAULT_SOCKET_OPTION = [(socket.SOL_TCP, socket.TCP_NODELAY, 1)]
if hasattr(socket, "SO_KEEPALIVE"):
@@ -42,7 +42,9 @@ _default_timeout = None
__all__ = ["DEFAULT_SOCKET_OPTION", "sock_opt", "setdefaulttimeout", "getdefaulttimeout",
"recv", "recv_line", "send"]
class sock_opt(object):
def __init__(self, sockopt, sslopt):
if sockopt is None:
sockopt = []
@@ -52,6 +54,7 @@ class sock_opt(object):
self.sslopt = sslopt
self.timeout = None
def setdefaulttimeout(timeout):
"""
Set the global timeout setting to connect.
@@ -74,7 +77,7 @@ def recv(sock, bufsize):
raise WebSocketConnectionClosedException("socket is already closed.")
try:
bytes = sock.recv(bufsize)
bytes_ = sock.recv(bufsize)
except socket.timeout as e:
message = extract_err_message(e)
raise WebSocketTimeoutException(message)
@@ -85,10 +88,11 @@ def recv(sock, bufsize):
else:
raise
if not bytes:
raise WebSocketConnectionClosedException("Connection is already closed.")
if not bytes_:
raise WebSocketConnectionClosedException(
"Connection is already closed.")
return bytes
return bytes_
def recv_line(sock):

View File

@@ -19,7 +19,6 @@ Copyright (C) 2010 Hiroki Ohtani(liris)
Boston, MA 02110-1335 USA
"""
__all__ = ["HAVE_SSL", "ssl", "SSLError"]
try:

View File

@@ -19,9 +19,9 @@ Copyright (C) 2010 Hiroki Ohtani(liris)
Boston, MA 02110-1335 USA
"""
import os
from six.moves.urllib.parse import urlparse
import os
__all__ = ["parse_url", "get_proxy_info"]
@@ -66,7 +66,7 @@ def parse_url(url):
if parsed.query:
resource += "?" + parsed.query
return (hostname, port, resource, is_secure)
return hostname, port, resource, is_secure
DEFAULT_NO_PROXY_HOST = ["localhost", "127.0.0.1"]
@@ -82,8 +82,9 @@ def _is_no_proxy_host(hostname, no_proxy):
return hostname in no_proxy
def get_proxy_info(hostname, is_secure,
proxy_host=None, proxy_port=0, proxy_auth=None, no_proxy=None):
def get_proxy_info(
hostname, is_secure, proxy_host=None, proxy_port=0, proxy_auth=None,
no_proxy=None):
"""
try to retrieve proxy host and port from environment
if not provided in options.

View File

@@ -19,16 +19,17 @@ Copyright (C) 2010 Hiroki Ohtani(liris)
Boston, MA 02110-1335 USA
"""
import six
__all__ = ["NoLock", "validate_utf8", "extract_err_message"]
class NoLock(object):
def __enter__(self):
pass
def __exit__(self, type, value, traceback):
def __exit__(self, exc_type, exc_value, traceback):
pass
try:
@@ -69,10 +70,11 @@ except ImportError:
def _decode(state, codep, ch):
tp = _UTF8D[ch]
codep = (ch & 0x3f ) | (codep << 6) if (state != _UTF8_ACCEPT) else (0xff >> tp) & (ch)
codep = (ch & 0x3f) | (codep << 6) if (
state != _UTF8_ACCEPT) else (0xff >> tp) & ch
state = _UTF8D[256 + state + tp]
return state, codep;
return state, codep
def _validate_utf8(utfbytes):
state = _UTF8_ACCEPT
@@ -86,6 +88,7 @@ except ImportError:
return True
def validate_utf8(utfbytes):
"""
validate utf8 byte string.
@@ -94,6 +97,7 @@ def validate_utf8(utfbytes):
"""
return _validate_utf8(utfbytes)
def extract_err_message(exception):
if exception.args:
return exception.args[0]

View File

@@ -1,14 +1,33 @@
# -*- coding: utf-8 -*-
#
import six
import sys
sys.path[0:0] = [""]
import os
import os.path
import base64
import socket
import six
# websocket-client
import websocket as ws
from websocket._handshake import _create_sec_websocket_key, \
_validate as _validate_header
from websocket._http import read_headers
from websocket._url import get_proxy_info, parse_url
from websocket._utils import validate_utf8
if six.PY3:
from base64 import decodebytes as base64decode
else:
from base64 import decodestring as base64decode
if sys.version_info[0] == 2 and sys.version_info[1] < 7:
import unittest2 as unittest
else:
import unittest
try:
from ssl import SSLError
except ImportError:
@@ -16,37 +35,15 @@ except ImportError:
class SSLError(Exception):
pass
if sys.version_info[0] == 2 and sys.version_info[1] < 7:
import unittest2 as unittest
else:
import unittest
import uuid
if six.PY3:
from base64 import decodebytes as base64decode
else:
from base64 import decodestring as base64decode
# websocket-client
import websocket as ws
from websocket._handshake import _create_sec_websocket_key
from websocket._url import parse_url, get_proxy_info
from websocket._utils import validate_utf8
from websocket._handshake import _validate as _validate_header
from websocket._http import read_headers
# Skip test to access the internet.
TEST_WITH_INTERNET = os.environ.get('TEST_WITH_INTERNET', '0') == '1'
# Skip Secure WebSocket test.
TEST_SECURE_WS = True
TRACABLE = False
TRACEABLE = False
def create_mask_key(n):
def create_mask_key(_):
return "abcd"
@@ -86,7 +83,7 @@ class HeaderSockMock(SockMock):
class WebSocketTest(unittest.TestCase):
def setUp(self):
ws.enableTrace(TRACABLE)
ws.enableTrace(TRACEABLE)
def tearDown(self):
pass
@@ -263,7 +260,7 @@ class WebSocketTest(unittest.TestCase):
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def testIter(self):
count = 2
for rsvp in ws.create_connection('ws://stream.meetup.com/2/rsvps'):
for _ in ws.create_connection('ws://stream.meetup.com/2/rsvps'):
count -= 1
if count == 0:
break
@@ -282,7 +279,7 @@ class WebSocketTest(unittest.TestCase):
# s.add_packet(SSLError("The read operation timed out"))
s.add_packet(six.b("baz"))
with self.assertRaises(ws.WebSocketTimeoutException):
data = sock.frame_buffer.recv_strict(9)
sock.frame_buffer.recv_strict(9)
# if six.PY2:
# with self.assertRaises(ws.WebSocketTimeoutException):
# data = sock._recv_strict(9)
@@ -292,7 +289,7 @@ class WebSocketTest(unittest.TestCase):
data = sock.frame_buffer.recv_strict(9)
self.assertEqual(data, six.b("foobarbaz"))
with self.assertRaises(ws.WebSocketConnectionClosedException):
data = sock.frame_buffer.recv_strict(1)
sock.frame_buffer.recv_strict(1)
def testRecvTimeout(self):
sock = ws.WebSocket()
@@ -303,13 +300,13 @@ class WebSocketTest(unittest.TestCase):
s.add_packet(socket.timeout())
s.add_packet(six.b("\x4e\x43\x33\x0e\x10\x0f\x00\x40"))
with self.assertRaises(ws.WebSocketTimeoutException):
data = sock.recv()
sock.recv()
with self.assertRaises(ws.WebSocketTimeoutException):
data = sock.recv()
sock.recv()
data = sock.recv()
self.assertEqual(data, "Hello, World!")
with self.assertRaises(ws.WebSocketConnectionClosedException):
data = sock.recv()
sock.recv()
def testRecvWithSimpleFragmentation(self):
sock = ws.WebSocket()
@@ -374,10 +371,10 @@ class WebSocketTest(unittest.TestCase):
sock = ws.WebSocket()
s = sock.sock = SockMock()
# OPCODE=TEXT, FIN=0, MSG="Once more unto the breach, "
s.add_packet(six.b("\x01\x9babcd.\x0c\x00\x01A\x0f\x0c\x16\x04B\x16\n\x15" \
s.add_packet(six.b("\x01\x9babcd.\x0c\x00\x01A\x0f\x0c\x16\x04B\x16\n\x15"
"\rC\x10\t\x07C\x06\x13\x07\x02\x07\tNC"))
# OPCODE=CONT, FIN=0, MSG="dear friends, "
s.add_packet(six.b("\x00\x8eabcd\x05\x07\x02\x16A\x04\x11\r\x04\x0c\x07" \
s.add_packet(six.b("\x00\x8eabcd\x05\x07\x02\x16A\x04\x11\r\x04\x0c\x07"
"\x17MB"))
# OPCODE=CONT, FIN=1, MSG="once more"
s.add_packet(six.b("\x80\x89abcd\x0e\x0c\x00\x01A\x0f\x0c\x16\x04"))
@@ -397,7 +394,7 @@ class WebSocketTest(unittest.TestCase):
# OPCODE=PING, FIN=1, MSG="Please PONG this"
s.add_packet(six.b("\x89\x90abcd1\x0e\x06\x05\x12\x07C4.,$D\x15\n\n\x17"))
# OPCODE=CONT, FIN=1, MSG="of a good thing"
s.add_packet(six.b("\x80\x8fabcd\x0e\x04C\x05A\x05\x0c\x0b\x05B\x17\x0c" \
s.add_packet(six.b("\x80\x8fabcd\x0e\x04C\x05A\x05\x0c\x0b\x05B\x17\x0c"
"\x08\x0c\x04"))
data = sock.recv()
self.assertEqual(data, "Too much of a good thing")
@@ -479,7 +476,7 @@ class WebSocketAppTest(unittest.TestCase):
"""
def setUp(self):
ws.enableTrace(TRACABLE)
ws.enableTrace(TRACEABLE)
WebSocketAppTest.keep_running_open = WebSocketAppTest.NotSetYet()
WebSocketAppTest.keep_running_close = WebSocketAppTest.NotSetYet()