@@ -43,7 +43,8 @@ class WebSocketApp(object):
|
|||||||
on_open=None, on_message=None, on_error=None,
|
on_open=None, on_message=None, on_error=None,
|
||||||
on_close=None, on_ping=None, on_pong=None,
|
on_close=None, on_ping=None, on_pong=None,
|
||||||
on_cont_message=None,
|
on_cont_message=None,
|
||||||
keep_running=True, get_mask_key=None, cookie=None):
|
keep_running=True, get_mask_key=None, cookie=None,
|
||||||
|
subprotocols=None):
|
||||||
"""
|
"""
|
||||||
url: websocket url.
|
url: websocket url.
|
||||||
header: custom header for websocket handshake.
|
header: custom header for websocket handshake.
|
||||||
@@ -53,21 +54,22 @@ class WebSocketApp(object):
|
|||||||
on_message has 2 arguments.
|
on_message has 2 arguments.
|
||||||
The 1st arugment is this class object.
|
The 1st arugment is this class object.
|
||||||
The passing 2nd arugment is utf-8 string which we get from the server.
|
The passing 2nd arugment is utf-8 string which we get from the server.
|
||||||
on_error: callable object which is called when we get error.
|
on_error: callable object which is called when we get error.
|
||||||
on_error has 2 arguments.
|
on_error has 2 arguments.
|
||||||
The 1st arugment is this class object.
|
The 1st arugment is this class object.
|
||||||
The passing 2nd arugment is exception object.
|
The passing 2nd arugment is exception object.
|
||||||
on_close: callable object which is called when closed the connection.
|
on_close: callable object which is called when closed the connection.
|
||||||
this function has one argument. The arugment is this class object.
|
this function has one argument. The arugment is this class object.
|
||||||
on_cont_message: callback object which is called when recieve continued frame data.
|
on_cont_message: callback object which is called when recieve continued frame data.
|
||||||
on_message has 3 arguments.
|
on_message has 3 arguments.
|
||||||
The 1st arugment is this class object.
|
The 1st arugment is this class object.
|
||||||
The passing 2nd arugment is utf-8 string which we get from the server.
|
The passing 2nd arugment is utf-8 string which we get from the server.
|
||||||
The 3rd arugment is continue flag. if 0, the data continue to next frame data
|
The 3rd arugment is continue flag. if 0, the data continue to next frame data
|
||||||
keep_running: a boolean flag indicating whether the app's main loop should
|
keep_running: a boolean flag indicating whether the app's main loop should
|
||||||
keep running, defaults to True
|
keep running, defaults to True
|
||||||
get_mask_key: a callable to produce new mask keys, see the WebSocket.set_mask_key's
|
get_mask_key: a callable to produce new mask keys, see the WebSocket.set_mask_key's
|
||||||
docstring for more information
|
docstring for more information
|
||||||
|
subprotocols: array of available sub protocols. default is None.
|
||||||
"""
|
"""
|
||||||
self.url = url
|
self.url = url
|
||||||
self.header = header
|
self.header = header
|
||||||
@@ -83,6 +85,7 @@ class WebSocketApp(object):
|
|||||||
self.get_mask_key = get_mask_key
|
self.get_mask_key = get_mask_key
|
||||||
self.sock = None
|
self.sock = None
|
||||||
self.last_ping_tm = 0
|
self.last_ping_tm = 0
|
||||||
|
self.subprotocols =subprotocols
|
||||||
|
|
||||||
def send(self, data, opcode=ABNF.OPCODE_TEXT):
|
def send(self, data, opcode=ABNF.OPCODE_TEXT):
|
||||||
"""
|
"""
|
||||||
@@ -138,7 +141,8 @@ class WebSocketApp(object):
|
|||||||
fire_cont_frame=self.on_cont_message and True or False)
|
fire_cont_frame=self.on_cont_message and True or False)
|
||||||
self.sock.settimeout(getdefaulttimeout())
|
self.sock.settimeout(getdefaulttimeout())
|
||||||
self.sock.connect(self.url, header=self.header, cookie=self.cookie,
|
self.sock.connect(self.url, header=self.header, cookie=self.cookie,
|
||||||
http_proxy_host=http_proxy_host, http_proxy_port=http_proxy_port)
|
http_proxy_host=http_proxy_host, http_proxy_port=http_proxy_port,
|
||||||
|
subprotocols=self.subprotocols)
|
||||||
self._callback(self.on_open)
|
self._callback(self.on_open)
|
||||||
|
|
||||||
if ping_interval:
|
if ping_interval:
|
||||||
@@ -158,7 +162,7 @@ class WebSocketApp(object):
|
|||||||
if r:
|
if r:
|
||||||
op_code, frame = self.sock.recv_data_frame(True)
|
op_code, frame = self.sock.recv_data_frame(True)
|
||||||
if op_code == ABNF.OPCODE_CLOSE:
|
if op_code == ABNF.OPCODE_CLOSE:
|
||||||
close_frmae = frame
|
close_frame = frame
|
||||||
break
|
break
|
||||||
elif op_code == ABNF.OPCODE_PING:
|
elif op_code == ABNF.OPCODE_PING:
|
||||||
self._callback(self.on_ping, frame.data)
|
self._callback(self.on_ping, frame.data)
|
||||||
|
@@ -208,6 +208,7 @@ def create_connection(url, timeout=None, **options):
|
|||||||
"enable_multithread" -> enable lock for multithread.
|
"enable_multithread" -> enable lock for multithread.
|
||||||
"sockopt" -> socket options
|
"sockopt" -> socket options
|
||||||
"sslopt" -> ssl option
|
"sslopt" -> ssl option
|
||||||
|
"subprotocols" - array of available sub protocols. default is None.
|
||||||
"""
|
"""
|
||||||
sockopt = options.get("sockopt", [])
|
sockopt = options.get("sockopt", [])
|
||||||
sslopt = options.get("sslopt", {})
|
sslopt = options.get("sslopt", {})
|
||||||
@@ -410,6 +411,7 @@ class WebSocket(object):
|
|||||||
"cookie" -> cookie value.
|
"cookie" -> cookie value.
|
||||||
"http_proxy_host" - http proxy host name.
|
"http_proxy_host" - http proxy host name.
|
||||||
"http_proxy_port" - http proxy port. If not set, set to 80.
|
"http_proxy_port" - http proxy port. If not set, set to 80.
|
||||||
|
"subprotocols" - array of available sub protocols. default is None.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -509,6 +511,10 @@ class WebSocket(object):
|
|||||||
headers.append("Sec-WebSocket-Key: %s" % key)
|
headers.append("Sec-WebSocket-Key: %s" % key)
|
||||||
headers.append("Sec-WebSocket-Version: %s" % VERSION)
|
headers.append("Sec-WebSocket-Version: %s" % VERSION)
|
||||||
|
|
||||||
|
subprotocols = options.get("subprotocols")
|
||||||
|
if subprotocols:
|
||||||
|
headers.append("Sec-WebSocket-Protocol: %s" % ",".join(subprotocols))
|
||||||
|
|
||||||
if "header" in options:
|
if "header" in options:
|
||||||
headers.extend(options["header"])
|
headers.extend(options["header"])
|
||||||
|
|
||||||
@@ -530,14 +536,14 @@ class WebSocket(object):
|
|||||||
_dump("request header", header_str)
|
_dump("request header", header_str)
|
||||||
|
|
||||||
resp_headers = self._get_resp_headers()
|
resp_headers = self._get_resp_headers()
|
||||||
success = self._validate_header(resp_headers, key)
|
success = self._validate_header(resp_headers, key, options.get("subprotocols"))
|
||||||
if not success:
|
if not success:
|
||||||
self.close()
|
self.close()
|
||||||
raise WebSocketException("Invalid WebSocket Header")
|
raise WebSocketException("Invalid WebSocket Header")
|
||||||
|
|
||||||
self.connected = True
|
self.connected = True
|
||||||
|
|
||||||
def _validate_header(self, headers, key):
|
def _validate_header(self, headers, key, subprotocols):
|
||||||
for k, v in _HEADERS_TO_CHECK.items():
|
for k, v in _HEADERS_TO_CHECK.items():
|
||||||
r = headers.get(k, None)
|
r = headers.get(k, None)
|
||||||
if not r:
|
if not r:
|
||||||
@@ -545,6 +551,13 @@ class WebSocket(object):
|
|||||||
r = r.lower()
|
r = r.lower()
|
||||||
if v != r:
|
if v != r:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
if subprotocols:
|
||||||
|
subproto = headers.get("sec-websocket-protocol", None)
|
||||||
|
if not subproto or subproto not in subprotocols:
|
||||||
|
logger.error("Invalid subprotocol: " + str(subprotocols))
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
result = headers.get("sec-websocket-accept", None)
|
result = headers.get("sec-websocket-accept", None)
|
||||||
if not result:
|
if not result:
|
||||||
|
@@ -184,25 +184,31 @@ class WebSocketTest(unittest.TestCase):
|
|||||||
"connection": "upgrade",
|
"connection": "upgrade",
|
||||||
"sec-websocket-accept": "Kxep+hNu9n51529fGidYu7a3wO0=",
|
"sec-websocket-accept": "Kxep+hNu9n51529fGidYu7a3wO0=",
|
||||||
}
|
}
|
||||||
self.assertEqual(sock._validate_header(required_header, key), True)
|
self.assertEqual(sock._validate_header(required_header, key, None), True)
|
||||||
|
|
||||||
header = required_header.copy()
|
header = required_header.copy()
|
||||||
header["upgrade"] = "http"
|
header["upgrade"] = "http"
|
||||||
self.assertEqual(sock._validate_header(header, key), False)
|
self.assertEqual(sock._validate_header(header, key, None), False)
|
||||||
del header["upgrade"]
|
del header["upgrade"]
|
||||||
self.assertEqual(sock._validate_header(header, key), False)
|
self.assertEqual(sock._validate_header(header, key, None), False)
|
||||||
|
|
||||||
header = required_header.copy()
|
header = required_header.copy()
|
||||||
header["connection"] = "something"
|
header["connection"] = "something"
|
||||||
self.assertEqual(sock._validate_header(header, key), False)
|
self.assertEqual(sock._validate_header(header, key, None), False)
|
||||||
del header["connection"]
|
del header["connection"]
|
||||||
self.assertEqual(sock._validate_header(header, key), False)
|
self.assertEqual(sock._validate_header(header, key, None), False)
|
||||||
|
|
||||||
header = required_header.copy()
|
header = required_header.copy()
|
||||||
header["sec-websocket-accept"] = "something"
|
header["sec-websocket-accept"] = "something"
|
||||||
self.assertEqual(sock._validate_header(header, key), False)
|
self.assertEqual(sock._validate_header(header, key, None), False)
|
||||||
del header["sec-websocket-accept"]
|
del header["sec-websocket-accept"]
|
||||||
self.assertEqual(sock._validate_header(header, key), False)
|
self.assertEqual(sock._validate_header(header, key, None), False)
|
||||||
|
|
||||||
|
|
||||||
|
header = required_header.copy()
|
||||||
|
header["sec-websocket-protocol"] = "sub1"
|
||||||
|
self.assertEqual(sock._validate_header(header, key, ["sub1", "sub2"]), True)
|
||||||
|
self.assertEqual(sock._validate_header(header, key, ["sub2", "sub3"]), False)
|
||||||
|
|
||||||
def testReadHeader(self):
|
def testReadHeader(self):
|
||||||
sock = ws.WebSocket()
|
sock = ws.WebSocket()
|
||||||
|
Reference in New Issue
Block a user