@@ -43,7 +43,8 @@ class WebSocketApp(object):
|
||||
on_open=None, on_message=None, on_error=None,
|
||||
on_close=None, on_ping=None, on_pong=None,
|
||||
on_cont_message=None,
|
||||
keep_running=True, get_mask_key=None, cookie=None):
|
||||
keep_running=True, get_mask_key=None, cookie=None,
|
||||
subprotocols=None):
|
||||
"""
|
||||
url: websocket url.
|
||||
header: custom header for websocket handshake.
|
||||
@@ -53,21 +54,22 @@ class WebSocketApp(object):
|
||||
on_message has 2 arguments.
|
||||
The 1st arugment is this class object.
|
||||
The passing 2nd arugment is utf-8 string which we get from the server.
|
||||
on_error: callable object which is called when we get error.
|
||||
on_error: callable object which is called when we get error.
|
||||
on_error has 2 arguments.
|
||||
The 1st arugment is this class object.
|
||||
The passing 2nd arugment is exception object.
|
||||
on_close: callable object which is called when closed the connection.
|
||||
on_close: callable object which is called when closed the connection.
|
||||
this function has one argument. The arugment is this class object.
|
||||
on_cont_message: callback object which is called when recieve continued frame data.
|
||||
on_cont_message: callback object which is called when recieve continued frame data.
|
||||
on_message has 3 arguments.
|
||||
The 1st arugment is this class object.
|
||||
The passing 2nd arugment is utf-8 string which we get from the server.
|
||||
The 3rd arugment is continue flag. if 0, the data continue to next frame data
|
||||
keep_running: a boolean flag indicating whether the app's main loop should
|
||||
keep_running: a boolean flag indicating whether the app's main loop should
|
||||
keep running, defaults to True
|
||||
get_mask_key: a callable to produce new mask keys, see the WebSocket.set_mask_key's
|
||||
get_mask_key: a callable to produce new mask keys, see the WebSocket.set_mask_key's
|
||||
docstring for more information
|
||||
subprotocols: array of available sub protocols. default is None.
|
||||
"""
|
||||
self.url = url
|
||||
self.header = header
|
||||
@@ -83,6 +85,7 @@ class WebSocketApp(object):
|
||||
self.get_mask_key = get_mask_key
|
||||
self.sock = None
|
||||
self.last_ping_tm = 0
|
||||
self.subprotocols =subprotocols
|
||||
|
||||
def send(self, data, opcode=ABNF.OPCODE_TEXT):
|
||||
"""
|
||||
@@ -138,7 +141,8 @@ class WebSocketApp(object):
|
||||
fire_cont_frame=self.on_cont_message and True or False)
|
||||
self.sock.settimeout(getdefaulttimeout())
|
||||
self.sock.connect(self.url, header=self.header, cookie=self.cookie,
|
||||
http_proxy_host=http_proxy_host, http_proxy_port=http_proxy_port)
|
||||
http_proxy_host=http_proxy_host, http_proxy_port=http_proxy_port,
|
||||
subprotocols=self.subprotocols)
|
||||
self._callback(self.on_open)
|
||||
|
||||
if ping_interval:
|
||||
@@ -158,7 +162,7 @@ class WebSocketApp(object):
|
||||
if r:
|
||||
op_code, frame = self.sock.recv_data_frame(True)
|
||||
if op_code == ABNF.OPCODE_CLOSE:
|
||||
close_frmae = frame
|
||||
close_frame = frame
|
||||
break
|
||||
elif op_code == ABNF.OPCODE_PING:
|
||||
self._callback(self.on_ping, frame.data)
|
||||
|
@@ -208,6 +208,7 @@ def create_connection(url, timeout=None, **options):
|
||||
"enable_multithread" -> enable lock for multithread.
|
||||
"sockopt" -> socket options
|
||||
"sslopt" -> ssl option
|
||||
"subprotocols" - array of available sub protocols. default is None.
|
||||
"""
|
||||
sockopt = options.get("sockopt", [])
|
||||
sslopt = options.get("sslopt", {})
|
||||
@@ -410,6 +411,7 @@ class WebSocket(object):
|
||||
"cookie" -> cookie value.
|
||||
"http_proxy_host" - http proxy host name.
|
||||
"http_proxy_port" - http proxy port. If not set, set to 80.
|
||||
"subprotocols" - array of available sub protocols. default is None.
|
||||
|
||||
"""
|
||||
|
||||
@@ -509,6 +511,10 @@ class WebSocket(object):
|
||||
headers.append("Sec-WebSocket-Key: %s" % key)
|
||||
headers.append("Sec-WebSocket-Version: %s" % VERSION)
|
||||
|
||||
subprotocols = options.get("subprotocols")
|
||||
if subprotocols:
|
||||
headers.append("Sec-WebSocket-Protocol: %s" % ",".join(subprotocols))
|
||||
|
||||
if "header" in options:
|
||||
headers.extend(options["header"])
|
||||
|
||||
@@ -530,14 +536,14 @@ class WebSocket(object):
|
||||
_dump("request header", header_str)
|
||||
|
||||
resp_headers = self._get_resp_headers()
|
||||
success = self._validate_header(resp_headers, key)
|
||||
success = self._validate_header(resp_headers, key, options.get("subprotocols"))
|
||||
if not success:
|
||||
self.close()
|
||||
raise WebSocketException("Invalid WebSocket Header")
|
||||
|
||||
self.connected = True
|
||||
|
||||
def _validate_header(self, headers, key):
|
||||
def _validate_header(self, headers, key, subprotocols):
|
||||
for k, v in _HEADERS_TO_CHECK.items():
|
||||
r = headers.get(k, None)
|
||||
if not r:
|
||||
@@ -545,6 +551,13 @@ class WebSocket(object):
|
||||
r = r.lower()
|
||||
if v != r:
|
||||
return False
|
||||
|
||||
if subprotocols:
|
||||
subproto = headers.get("sec-websocket-protocol", None)
|
||||
if not subproto or subproto not in subprotocols:
|
||||
logger.error("Invalid subprotocol: " + str(subprotocols))
|
||||
return False
|
||||
|
||||
|
||||
result = headers.get("sec-websocket-accept", None)
|
||||
if not result:
|
||||
|
@@ -184,25 +184,31 @@ class WebSocketTest(unittest.TestCase):
|
||||
"connection": "upgrade",
|
||||
"sec-websocket-accept": "Kxep+hNu9n51529fGidYu7a3wO0=",
|
||||
}
|
||||
self.assertEqual(sock._validate_header(required_header, key), True)
|
||||
self.assertEqual(sock._validate_header(required_header, key, None), True)
|
||||
|
||||
header = required_header.copy()
|
||||
header["upgrade"] = "http"
|
||||
self.assertEqual(sock._validate_header(header, key), False)
|
||||
self.assertEqual(sock._validate_header(header, key, None), False)
|
||||
del header["upgrade"]
|
||||
self.assertEqual(sock._validate_header(header, key), False)
|
||||
self.assertEqual(sock._validate_header(header, key, None), False)
|
||||
|
||||
header = required_header.copy()
|
||||
header["connection"] = "something"
|
||||
self.assertEqual(sock._validate_header(header, key), False)
|
||||
self.assertEqual(sock._validate_header(header, key, None), False)
|
||||
del header["connection"]
|
||||
self.assertEqual(sock._validate_header(header, key), False)
|
||||
self.assertEqual(sock._validate_header(header, key, None), False)
|
||||
|
||||
header = required_header.copy()
|
||||
header["sec-websocket-accept"] = "something"
|
||||
self.assertEqual(sock._validate_header(header, key), False)
|
||||
self.assertEqual(sock._validate_header(header, key, None), False)
|
||||
del header["sec-websocket-accept"]
|
||||
self.assertEqual(sock._validate_header(header, key), False)
|
||||
self.assertEqual(sock._validate_header(header, key, None), False)
|
||||
|
||||
|
||||
header = required_header.copy()
|
||||
header["sec-websocket-protocol"] = "sub1"
|
||||
self.assertEqual(sock._validate_header(header, key, ["sub1", "sub2"]), True)
|
||||
self.assertEqual(sock._validate_header(header, key, ["sub2", "sub3"]), False)
|
||||
|
||||
def testReadHeader(self):
|
||||
sock = ws.WebSocket()
|
||||
|
Reference in New Issue
Block a user