diff --git a/websocket/_app.py b/websocket/_app.py index 5c8b9d5..0a08c64 100644 --- a/websocket/_app.py +++ b/websocket/_app.py @@ -43,7 +43,8 @@ class WebSocketApp(object): on_open=None, on_message=None, on_error=None, on_close=None, on_ping=None, on_pong=None, on_cont_message=None, - keep_running=True, get_mask_key=None, cookie=None): + keep_running=True, get_mask_key=None, cookie=None, + subprotocols=None): """ url: websocket url. header: custom header for websocket handshake. @@ -53,21 +54,22 @@ class WebSocketApp(object): on_message has 2 arguments. The 1st arugment is this class object. The passing 2nd arugment is utf-8 string which we get from the server. - on_error: callable object which is called when we get error. + on_error: callable object which is called when we get error. on_error has 2 arguments. The 1st arugment is this class object. The passing 2nd arugment is exception object. - on_close: callable object which is called when closed the connection. + on_close: callable object which is called when closed the connection. this function has one argument. The arugment is this class object. - on_cont_message: callback object which is called when recieve continued frame data. + on_cont_message: callback object which is called when recieve continued frame data. on_message has 3 arguments. The 1st arugment is this class object. The passing 2nd arugment is utf-8 string which we get from the server. The 3rd arugment is continue flag. if 0, the data continue to next frame data - keep_running: a boolean flag indicating whether the app's main loop should + keep_running: a boolean flag indicating whether the app's main loop should keep running, defaults to True - get_mask_key: a callable to produce new mask keys, see the WebSocket.set_mask_key's + get_mask_key: a callable to produce new mask keys, see the WebSocket.set_mask_key's docstring for more information + subprotocols: array of available sub protocols. default is None. """ self.url = url self.header = header @@ -83,6 +85,7 @@ class WebSocketApp(object): self.get_mask_key = get_mask_key self.sock = None self.last_ping_tm = 0 + self.subprotocols =subprotocols def send(self, data, opcode=ABNF.OPCODE_TEXT): """ @@ -138,7 +141,8 @@ class WebSocketApp(object): fire_cont_frame=self.on_cont_message and True or False) self.sock.settimeout(getdefaulttimeout()) self.sock.connect(self.url, header=self.header, cookie=self.cookie, - http_proxy_host=http_proxy_host, http_proxy_port=http_proxy_port) + http_proxy_host=http_proxy_host, http_proxy_port=http_proxy_port, + subprotocols=self.subprotocols) self._callback(self.on_open) if ping_interval: @@ -158,7 +162,7 @@ class WebSocketApp(object): if r: op_code, frame = self.sock.recv_data_frame(True) if op_code == ABNF.OPCODE_CLOSE: - close_frmae = frame + close_frame = frame break elif op_code == ABNF.OPCODE_PING: self._callback(self.on_ping, frame.data) diff --git a/websocket/_core.py b/websocket/_core.py index d703307..f53f79b 100644 --- a/websocket/_core.py +++ b/websocket/_core.py @@ -208,6 +208,7 @@ def create_connection(url, timeout=None, **options): "enable_multithread" -> enable lock for multithread. "sockopt" -> socket options "sslopt" -> ssl option + "subprotocols" - array of available sub protocols. default is None. """ sockopt = options.get("sockopt", []) sslopt = options.get("sslopt", {}) @@ -410,6 +411,7 @@ class WebSocket(object): "cookie" -> cookie value. "http_proxy_host" - http proxy host name. "http_proxy_port" - http proxy port. If not set, set to 80. + "subprotocols" - array of available sub protocols. default is None. """ @@ -509,6 +511,10 @@ class WebSocket(object): headers.append("Sec-WebSocket-Key: %s" % key) headers.append("Sec-WebSocket-Version: %s" % VERSION) + subprotocols = options.get("subprotocols") + if subprotocols: + headers.append("Sec-WebSocket-Protocol: %s" % ",".join(subprotocols)) + if "header" in options: headers.extend(options["header"]) @@ -530,14 +536,14 @@ class WebSocket(object): _dump("request header", header_str) resp_headers = self._get_resp_headers() - success = self._validate_header(resp_headers, key) + success = self._validate_header(resp_headers, key, options.get("subprotocols")) if not success: self.close() raise WebSocketException("Invalid WebSocket Header") self.connected = True - def _validate_header(self, headers, key): + def _validate_header(self, headers, key, subprotocols): for k, v in _HEADERS_TO_CHECK.items(): r = headers.get(k, None) if not r: @@ -545,6 +551,13 @@ class WebSocket(object): r = r.lower() if v != r: return False + + if subprotocols: + subproto = headers.get("sec-websocket-protocol", None) + if not subproto or subproto not in subprotocols: + logger.error("Invalid subprotocol: " + str(subprotocols)) + return False + result = headers.get("sec-websocket-accept", None) if not result: diff --git a/websocket/tests/test_websocket.py b/websocket/tests/test_websocket.py index 28e10f5..0495721 100644 --- a/websocket/tests/test_websocket.py +++ b/websocket/tests/test_websocket.py @@ -184,25 +184,31 @@ class WebSocketTest(unittest.TestCase): "connection": "upgrade", "sec-websocket-accept": "Kxep+hNu9n51529fGidYu7a3wO0=", } - self.assertEqual(sock._validate_header(required_header, key), True) + self.assertEqual(sock._validate_header(required_header, key, None), True) header = required_header.copy() header["upgrade"] = "http" - self.assertEqual(sock._validate_header(header, key), False) + self.assertEqual(sock._validate_header(header, key, None), False) del header["upgrade"] - self.assertEqual(sock._validate_header(header, key), False) + self.assertEqual(sock._validate_header(header, key, None), False) header = required_header.copy() header["connection"] = "something" - self.assertEqual(sock._validate_header(header, key), False) + self.assertEqual(sock._validate_header(header, key, None), False) del header["connection"] - self.assertEqual(sock._validate_header(header, key), False) + self.assertEqual(sock._validate_header(header, key, None), False) header = required_header.copy() header["sec-websocket-accept"] = "something" - self.assertEqual(sock._validate_header(header, key), False) + self.assertEqual(sock._validate_header(header, key, None), False) del header["sec-websocket-accept"] - self.assertEqual(sock._validate_header(header, key), False) + self.assertEqual(sock._validate_header(header, key, None), False) + + + header = required_header.copy() + header["sec-websocket-protocol"] = "sub1" + self.assertEqual(sock._validate_header(header, key, ["sub1", "sub2"]), True) + self.assertEqual(sock._validate_header(header, key, ["sub2", "sub3"]), False) def testReadHeader(self): sock = ws.WebSocket()