This commit is contained in:
liris
2015-03-18 10:39:40 +09:00
4 changed files with 141 additions and 106 deletions

View File

@@ -29,12 +29,13 @@ import traceback
import sys import sys
import select import select
import six import six
import logging
from ._core import WebSocket, getdefaulttimeout, logger from ._core import WebSocket, getdefaulttimeout
from ._exceptions import * from ._exceptions import *
from ._logging import *
from websocket._abnf import ABNF from websocket._abnf import ABNF
class WebSocketApp(object): class WebSocketApp(object):
""" """
Higher level of APIs are provided. Higher level of APIs are provided.
@@ -61,15 +62,17 @@ class WebSocketApp(object):
The passing 2nd arugment is exception object. The passing 2nd arugment is exception object.
on_close: callable object which is called when closed the connection. on_close: callable object which is called when closed the connection.
this function has one argument. The arugment is this class object. this function has one argument. The arugment is this class object.
on_cont_message: callback object which is called when recieve continued frame data. on_cont_message: callback object which is called when recieve continued
frame data.
on_message has 3 arguments. on_message has 3 arguments.
The 1st arugment is this class object. The 1st arugment is this class object.
The passing 2nd arugment is utf-8 string which we get from the server. The passing 2nd arugment is utf-8 string which we get from the server.
The 3rd arugment is continue flag. if 0, the data continue to next frame data The 3rd arugment is continue flag. if 0, the data continue
keep_running: a boolean flag indicating whether the app's main loop should to next frame data
keep running, defaults to True keep_running: a boolean flag indicating whether the app's main loop
get_mask_key: a callable to produce new mask keys, see the WebSocket.set_mask_key's should keep running, defaults to True
docstring for more information get_mask_key: a callable to produce new mask keys,
see the WebSocket.set_mask_key's docstring for more information
subprotocols: array of available sub protocols. default is None. subprotocols: array of available sub protocols. default is None.
""" """
self.url = url self.url = url
@@ -86,12 +89,13 @@ class WebSocketApp(object):
self.get_mask_key = get_mask_key self.get_mask_key = get_mask_key
self.sock = None self.sock = None
self.last_ping_tm = 0 self.last_ping_tm = 0
self.subprotocols =subprotocols self.subprotocols = subprotocols
def send(self, data, opcode=ABNF.OPCODE_TEXT): def send(self, data, opcode=ABNF.OPCODE_TEXT):
""" """
send message. send message.
data: message to send. If you set opcode to OPCODE_TEXT, data must be utf-8 string or unicode. data: message to send. If you set opcode to OPCODE_TEXT,
data must be utf-8 string or unicode.
opcode: operation code of data. default is OPCODE_TEXT. opcode: operation code of data. default is OPCODE_TEXT.
""" """
@@ -112,16 +116,20 @@ class WebSocketApp(object):
if self.sock: if self.sock:
self.sock.ping() self.sock.ping()
def run_forever(self, sockopt=None, sslopt=None, ping_interval=0, ping_timeout=None, def run_forever(self, sockopt=None, sslopt=None,
http_proxy_host=None, http_proxy_port=None, http_no_proxy=None, http_proxy_auth=None, ping_interval=0, ping_timeout=None,
skip_utf8_validation=False): http_proxy_host=None, http_proxy_port=None,
http_no_proxy=None, http_proxy_auth=None,
skip_utf8_validation=False):
""" """
run event loop for WebSocket framework. run event loop for WebSocket framework.
This loop is infinite loop and is alive during websocket is available. This loop is infinite loop and is alive during websocket is available.
sockopt: values for socket.setsockopt. sockopt: values for socket.setsockopt.
sockopt must be tuple and each element is argument of sock.setscokopt. sockopt must be tuple
and each element is argument of sock.setscokopt.
sslopt: ssl socket optional dict. sslopt: ssl socket optional dict.
ping_interval: automatically send "ping" command every specified period(second) ping_interval: automatically send "ping" command
every specified period(second)
if set to 0, not send automatically. if set to 0, not send automatically.
ping_timeout: timeout(second) if the pong message is not recieved. ping_timeout: timeout(second) if the pong message is not recieved.
http_proxy_host: http proxy host name. http_proxy_host: http proxy host name.
@@ -130,7 +138,7 @@ class WebSocketApp(object):
skip_utf8_validation: skip utf8 validation. skip_utf8_validation: skip utf8 validation.
""" """
if not ping_timeout or ping_timeout<=0: if not ping_timeout or ping_timeout <= 0:
ping_timeout = None ping_timeout = None
if sockopt is None: if sockopt is None:
sockopt = [] sockopt = []
@@ -142,18 +150,20 @@ class WebSocketApp(object):
close_frame = None close_frame = None
try: try:
self.sock = WebSocket(self.get_mask_key, sockopt=sockopt, sslopt=sslopt, self.sock = WebSocket(self.get_mask_key,
sockopt=sockopt, sslopt=sslopt,
fire_cont_frame=self.on_cont_message and True or False, fire_cont_frame=self.on_cont_message and True or False,
skip_utf8_validation=skip_utf8_validation) skip_utf8_validation=skip_utf8_validation)
self.sock.settimeout(getdefaulttimeout()) self.sock.settimeout(getdefaulttimeout())
self.sock.connect(self.url, header=self.header, cookie=self.cookie, self.sock.connect(self.url, header=self.header, cookie=self.cookie,
http_proxy_host=http_proxy_host, http_proxy_port=http_proxy_port, http_proxy_host=http_proxy_host,
http_proxy_port=http_proxy_port,
http_no_proxy=http_no_proxy, http_proxy_auth=http_proxy_auth, http_no_proxy=http_no_proxy, http_proxy_auth=http_proxy_auth,
subprotocols=self.subprotocols) subprotocols=self.subprotocols)
self._callback(self.on_open) self._callback(self.on_open)
if ping_interval: if ping_interval:
event = threading.Event() event = threading.Event()
thread = threading.Thread(target=self._send_ping, args=(ping_interval, event)) thread = threading.Thread(target=self._send_ping, args=(ping_interval, event))
thread.setDaemon(True) thread.setDaemon(True)
thread.start() thread.start()
@@ -194,7 +204,7 @@ class WebSocketApp(object):
*self._get_close_args(close_frame.data if close_frame else None)) *self._get_close_args(close_frame.data if close_frame else None))
self.sock = None self.sock = None
def _get_close_args(self,data): def _get_close_args(self, data):
""" this functions extracts the code, reason from the close body """ this functions extracts the code, reason from the close body
if they exists, and if the self.on_close except three arguments """ if they exists, and if the self.on_close except three arguments """
import inspect import inspect
@@ -202,19 +212,19 @@ class WebSocketApp(object):
if not self.on_close or len(inspect.getargspec(self.on_close).args) != 3: if not self.on_close or len(inspect.getargspec(self.on_close).args) != 3:
return [] return []
if data and len(data) >=2: if data and len(data) >= 2:
code = 256*six.byte2int(data[0:1]) + six.byte2int(data[1:2]) code = 256*six.byte2int(data[0:1]) + six.byte2int(data[1:2])
reason = data[2:].decode('utf-8') reason = data[2:].decode('utf-8')
return [code,reason] return [code, reason]
return [None,None] return [None, None]
def _callback(self, callback, *args): def _callback(self, callback, *args):
if callback: if callback:
try: try:
callback(self, *args) callback(self, *args)
except Exception as e: except Exception as e:
logger.error(e) error(e)
if logger.isEnabledFor(logging.DEBUG): if isEnableForDebug():
_, _, tb = sys.exc_info() _, _, tb = sys.exc_info()
traceback.print_tb(tb) traceback.print_tb(tb)

View File

@@ -53,7 +53,6 @@ import struct
import uuid import uuid
import hashlib import hashlib
import threading import threading
import logging
# websocket modules # websocket modules
from ._exceptions import * from ._exceptions import *
@@ -61,6 +60,7 @@ from ._abnf import *
from ._socket import * from ._socket import *
from ._utils import * from ._utils import *
from ._url import * from ._url import *
from ._logging import *
""" """
websocket python client. websocket python client.
@@ -75,41 +75,14 @@ Please see http://tools.ietf.org/html/rfc6455 for protocol.
VERSION = 13 VERSION = 13
logger = logging.getLogger()
traceEnabled = False
def enableTrace(tracable):
"""
turn on/off the tracability.
tracable: boolean value. if set True, tracability is enabled.
"""
global traceEnabled
traceEnabled = tracable
if tracable:
if not logger.handlers:
logger.addHandler(logging.StreamHandler())
logger.setLevel(logging.DEBUG)
def _dump(title, message):
if traceEnabled:
logger.debug("--- " + title + " ---")
logger.debug(message)
logger.debug("-----------------------")
def create_connection(url, timeout=None, **options): def create_connection(url, timeout=None, **options):
""" """
connect to url and return websocket object. connect to url and return websocket object.
Connect to url and return the WebSocket object. Connect to url and return the WebSocket object.
Passing optional timeout parameter will set the timeout on the socket. Passing optional timeout parameter will set the timeout on the socket.
If no timeout is supplied, the global default timeout setting returned by getdefauttimeout() is used. If no timeout is supplied,
the global default timeout setting returned by getdefauttimeout() is used.
You can customize using 'options'. You can customize using 'options'.
If you set "header" list object, you can set your own custom header. If you set "header" list object, you can set your own custom header.
@@ -119,7 +92,8 @@ def create_connection(url, timeout=None, **options):
timeout: socket timeout time. This value is integer. timeout: socket timeout time. This value is integer.
if you set None for this value, it means "use default_timeout value" if you set None for this value,
it means "use default_timeout value"
options: "header" -> custom http header list. options: "header" -> custom http header list.
@@ -127,12 +101,14 @@ def create_connection(url, timeout=None, **options):
"http_proxy_host" - http proxy host name. "http_proxy_host" - http proxy host name.
"http_proxy_port" - http proxy port. If not set, set to 80. "http_proxy_port" - http proxy port. If not set, set to 80.
"http_no_proxy" - host names, which doesn't use proxy. "http_no_proxy" - host names, which doesn't use proxy.
"http_proxy_auth" - http proxy auth infomation. tuple of username and password. "http_proxy_auth" - http proxy auth infomation.
tuple of username and password.
defualt is None defualt is None
"enable_multithread" -> enable lock for multithread. "enable_multithread" -> enable lock for multithread.
"sockopt" -> socket options "sockopt" -> socket options
"sslopt" -> ssl option "sslopt" -> ssl option
"subprotocols" - array of available sub protocols. default is None. "subprotocols" - array of available sub protocols.
default is None.
"skip_utf8_validation" - skip utf8 validation. "skip_utf8_validation" - skip utf8 validation.
""" """
sockopt = options.get("sockopt", []) sockopt = options.get("sockopt", [])
@@ -141,8 +117,9 @@ def create_connection(url, timeout=None, **options):
enable_multithread = options.get("enable_multithread", False) enable_multithread = options.get("enable_multithread", False)
skip_utf8_validation = options.get("skip_utf8_validation", False) skip_utf8_validation = options.get("skip_utf8_validation", False)
websock = WebSocket(sockopt=sockopt, sslopt=sslopt, websock = WebSocket(sockopt=sockopt, sslopt=sslopt,
fire_cont_frame = fire_cont_frame, enable_multithread=enable_multithread, fire_cont_frame=fire_cont_frame,
skip_utf8_validation=skip_utf8_validation) enable_multithread=enable_multithread,
skip_utf8_validation=skip_utf8_validation)
websock.settimeout(timeout if timeout is not None else getdefaulttimeout()) websock.settimeout(timeout if timeout is not None else getdefaulttimeout())
websock.connect(url, **options) websock.connect(url, **options)
return websock return websock
@@ -188,7 +165,8 @@ class WebSocket(object):
""" """
def __init__(self, get_mask_key=None, sockopt=None, sslopt=None, def __init__(self, get_mask_key=None, sockopt=None, sslopt=None,
fire_cont_frame=False, enable_multithread=False, skip_utf8_validation=False): fire_cont_frame=False, enable_multithread=False,
skip_utf8_validation=False):
""" """
Initalize WebSocket object. Initalize WebSocket object.
""" """
@@ -253,7 +231,8 @@ class WebSocket(object):
def connect(self, url, **options): def connect(self, url, **options):
""" """
Connect to url. url is websocket url scheme. ie. ws://host:port/resource Connect to url. url is websocket url scheme.
ie. ws://host:port/resource
You can customize using 'options'. You can customize using 'options'.
If you set "header" list object, you can set your own custom header. If you set "header" list object, you can set your own custom header.
@@ -271,9 +250,11 @@ class WebSocket(object):
"http_proxy_host" - http proxy host name. "http_proxy_host" - http proxy host name.
"http_proxy_port" - http proxy port. If not set, set to 80. "http_proxy_port" - http proxy port. If not set, set to 80.
"http_no_proxy" - host names, which doesn't use proxy. "http_no_proxy" - host names, which doesn't use proxy.
"http_proxy_auth" - http proxy auth infomation. tuple of username and password. "http_proxy_auth" - http proxy auth infomation.
defualt is None tuple of username and password.
"subprotocols" - array of available sub protocols. default is None. defualt is None
"subprotocols" - array of available sub protocols.
default is None.
""" """
@@ -335,7 +316,7 @@ class WebSocket(object):
self._handshake(hostname, port, resource, **options) self._handshake(hostname, port, resource, **options)
def _tunnel(self, host, port, auth): def _tunnel(self, host, port, auth):
logger.debug("Connecting proxy...") debug("Connecting proxy...")
connect_header = "CONNECT %s:%d HTTP/1.0\r\n" % (host, port) connect_header = "CONNECT %s:%d HTTP/1.0\r\n" % (host, port)
# TODO: support digest auth. # TODO: support digest auth.
if auth and auth[0]: if auth and auth[0]:
@@ -345,7 +326,7 @@ class WebSocket(object):
encoded_str = base64encode(auth_str.encode()).strip().decode() encoded_str = base64encode(auth_str.encode()).strip().decode()
connect_header += "Proxy-Authorization: Basic %s\r\n" % encoded_str connect_header += "Proxy-Authorization: Basic %s\r\n" % encoded_str
connect_header += "\r\n" connect_header += "\r\n"
_dump("request header", connect_header) dump("request header", connect_header)
self._send(connect_header) self._send(connect_header)
@@ -357,7 +338,7 @@ class WebSocket(object):
if status != 200: if status != 200:
raise WebSocketProxyException("failed CONNECT via proxy status: " + str(status)) raise WebSocketProxyException("failed CONNECT via proxy status: " + str(status))
def _get_resp_headers(self, success_status = 101): def _get_resp_headers(self, success_status=101):
status, resp_headers = self._read_headers() status, resp_headers = self._read_headers()
if status != success_status: if status != success_status:
self.close() self.close()
@@ -394,7 +375,7 @@ class WebSocket(object):
cookie = options.get("cookie", None) cookie = options.get("cookie", None)
if cookie: if cookie:
headers.append("Cookie: %s" % cookie) headers.append("Cookie: %s" % cookie)
headers.append("") headers.append("")
headers.append("") headers.append("")
@@ -406,7 +387,7 @@ class WebSocket(object):
header_str = "\r\n".join(headers) header_str = "\r\n".join(headers)
self._send(header_str) self._send(header_str)
_dump("request header", header_str) dump("request header", header_str)
resp_headers = self._get_resp_headers() resp_headers = self._get_resp_headers()
success = self._validate_header(resp_headers, key, options.get("subprotocols")) success = self._validate_header(resp_headers, key, options.get("subprotocols"))
@@ -428,11 +409,10 @@ class WebSocket(object):
if subprotocols: if subprotocols:
subproto = headers.get("sec-websocket-protocol", None) subproto = headers.get("sec-websocket-protocol", None)
if not subproto or subproto not in subprotocols: if not subproto or subproto not in subprotocols:
logger.error("Invalid subprotocol: " + str(subprotocols)) error("Invalid subprotocol: " + str(subprotocols))
return False return False
self.subprotocol = subproto self.subprotocol = subproto
result = headers.get("sec-websocket-accept", None) result = headers.get("sec-websocket-accept", None)
if not result: if not result:
return False return False
@@ -448,17 +428,14 @@ class WebSocket(object):
def _read_headers(self): def _read_headers(self):
status = None status = None
headers = {} headers = {}
if traceEnabled: trace("--- response header ---")
logger.debug("--- response header ---")
while True: while True:
line = self._recv_line() line = self._recv_line()
line = line.decode('utf-8').strip() line = line.decode('utf-8').strip()
if not line: if not line:
break break
trace(line)
if traceEnabled:
logger.debug(line)
if not status: if not status:
status_info = line.split(" ", 2) status_info = line.split(" ", 2)
@@ -471,8 +448,7 @@ class WebSocket(object):
else: else:
raise WebSocketException("Invalid header") raise WebSocketException("Invalid header")
if traceEnabled: trace("-----------------------")
logger.debug("-----------------------")
return status, headers return status, headers
@@ -509,8 +485,7 @@ class WebSocket(object):
frame.get_mask_key = self.get_mask_key frame.get_mask_key = self.get_mask_key
data = frame.format() data = frame.format()
length = len(data) length = len(data)
if traceEnabled: trace("send: " + repr(data))
logger.debug("send: " + repr(data))
with self.lock: with self.lock:
while data: while data:
@@ -519,7 +494,6 @@ class WebSocket(object):
return length return length
def send_binary(self, payload): def send_binary(self, payload):
return self.send(payload, ABNF.OPCODE_BINARY) return self.send(payload, ABNF.OPCODE_BINARY)
@@ -657,7 +631,6 @@ class WebSocket(object):
return frame return frame
def send_close(self, status=STATUS_NORMAL, reason=six.b("")): def send_close(self, status=STATUS_NORMAL, reason=six.b("")):
""" """
send close data to the server. send close data to the server.
@@ -690,10 +663,10 @@ class WebSocket(object):
self.sock.settimeout(3) self.sock.settimeout(3)
try: try:
frame = self.recv_frame() frame = self.recv_frame()
if logger.isEnabledFor(logging.ERROR): if isEnableForError():
recv_status = struct.unpack("!H", frame.data)[0] recv_status = struct.unpack("!H", frame.data)[0]
if recv_status != STATUS_NORMAL: if recv_status != STATUS_NORMAL:
logger.error("close status: " + repr(recv_status)) error("close status: " + repr(recv_status))
except: except:
pass pass
self.sock.settimeout(timeout) self.sock.settimeout(timeout)
@@ -748,7 +721,6 @@ class WebSocket(object):
self._recv_buffer = [unified[bufsize:]] self._recv_buffer = [unified[bufsize:]]
return unified[:bufsize] return unified[:bufsize]
def _recv_line(self): def _recv_line(self):
try: try:
return recv_line(self.sock) return recv_line(self.sock)
@@ -760,17 +732,3 @@ class WebSocket(object):
raise raise
except: except:
raise raise
if __name__ == "__main__":
enableTrace(True)
ws = create_connection("ws://echo.websocket.org/")
print("Sending 'Hello, World'...")
ws.send("Hello, World")
print("Sent")
print("Receiving...")
result = ws.recv()
print("Received '%s'" % result)
ws.close()

68
websocket/_logging.py Normal file
View File

@@ -0,0 +1,68 @@
"""
websocket - WebSocket client library for Python
Copyright (C) 2010 Hiroki Ohtani(liris)
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor,
Boston, MA 02110-1335 USA
"""
import logging
_logger = logging.getLogger()
_traceEnabled = False
def enableTrace(tracable):
"""
turn on/off the tracability.
tracable: boolean value. if set True, tracability is enabled.
"""
global _traceEnabled
_traceEnabled = tracable
if tracable:
if not _logger.handlers:
_logger.addHandler(logging.StreamHandler())
_logger.setLevel(logging.DEBUG)
def dump(title, message):
if _traceEnabled:
_logger.debug("--- " + title + " ---")
_logger.debug(message)
_logger.debug("-----------------------")
def error(msg):
_logger.error(msg)
def debug(msg):
_logger.debug(msg)
def trace(msg):
if _traceEnabled:
_logger.debug(msg)
def isEnableForError():
return _logger.isEnableFor(logging.ERROR)
def isEnableForDebug():
return _logger.isEnableFor(logging.DEBUG)

View File

@@ -32,7 +32,6 @@ from websocket._utils import validate_utf8
# Skip test to access the internet. # Skip test to access the internet.
TEST_WITH_INTERNET = False TEST_WITH_INTERNET = False
# TEST_WITH_INTERNET = True
# Skip Secure WebSocket test. # Skip Secure WebSocket test.
TEST_SECURE_WS = False TEST_SECURE_WS = False