enable to set subprotocol
This commit is contained in:
Hiroki Ohtani
2014-10-14 09:17:35 +09:00
parent b2292a9804
commit dd377dc4c7
3 changed files with 40 additions and 17 deletions

View File

@@ -43,7 +43,8 @@ class WebSocketApp(object):
on_open=None, on_message=None, on_error=None, on_open=None, on_message=None, on_error=None,
on_close=None, on_ping=None, on_pong=None, on_close=None, on_ping=None, on_pong=None,
on_cont_message=None, on_cont_message=None,
keep_running=True, get_mask_key=None, cookie=None): keep_running=True, get_mask_key=None, cookie=None,
subprotocols=None):
""" """
url: websocket url. url: websocket url.
header: custom header for websocket handshake. header: custom header for websocket handshake.
@@ -53,21 +54,22 @@ class WebSocketApp(object):
on_message has 2 arguments. on_message has 2 arguments.
The 1st arugment is this class object. The 1st arugment is this class object.
The passing 2nd arugment is utf-8 string which we get from the server. The passing 2nd arugment is utf-8 string which we get from the server.
on_error: callable object which is called when we get error. on_error: callable object which is called when we get error.
on_error has 2 arguments. on_error has 2 arguments.
The 1st arugment is this class object. The 1st arugment is this class object.
The passing 2nd arugment is exception object. The passing 2nd arugment is exception object.
on_close: callable object which is called when closed the connection. on_close: callable object which is called when closed the connection.
this function has one argument. The arugment is this class object. this function has one argument. The arugment is this class object.
on_cont_message: callback object which is called when recieve continued frame data. on_cont_message: callback object which is called when recieve continued frame data.
on_message has 3 arguments. on_message has 3 arguments.
The 1st arugment is this class object. The 1st arugment is this class object.
The passing 2nd arugment is utf-8 string which we get from the server. The passing 2nd arugment is utf-8 string which we get from the server.
The 3rd arugment is continue flag. if 0, the data continue to next frame data The 3rd arugment is continue flag. if 0, the data continue to next frame data
keep_running: a boolean flag indicating whether the app's main loop should keep_running: a boolean flag indicating whether the app's main loop should
keep running, defaults to True keep running, defaults to True
get_mask_key: a callable to produce new mask keys, see the WebSocket.set_mask_key's get_mask_key: a callable to produce new mask keys, see the WebSocket.set_mask_key's
docstring for more information docstring for more information
subprotocols: array of available sub protocols. default is None.
""" """
self.url = url self.url = url
self.header = header self.header = header
@@ -83,6 +85,7 @@ class WebSocketApp(object):
self.get_mask_key = get_mask_key self.get_mask_key = get_mask_key
self.sock = None self.sock = None
self.last_ping_tm = 0 self.last_ping_tm = 0
self.subprotocols =subprotocols
def send(self, data, opcode=ABNF.OPCODE_TEXT): def send(self, data, opcode=ABNF.OPCODE_TEXT):
""" """
@@ -138,7 +141,8 @@ class WebSocketApp(object):
fire_cont_frame=self.on_cont_message and True or False) fire_cont_frame=self.on_cont_message and True or False)
self.sock.settimeout(getdefaulttimeout()) self.sock.settimeout(getdefaulttimeout())
self.sock.connect(self.url, header=self.header, cookie=self.cookie, self.sock.connect(self.url, header=self.header, cookie=self.cookie,
http_proxy_host=http_proxy_host, http_proxy_port=http_proxy_port) http_proxy_host=http_proxy_host, http_proxy_port=http_proxy_port,
subprotocols=self.subprotocols)
self._callback(self.on_open) self._callback(self.on_open)
if ping_interval: if ping_interval:
@@ -158,7 +162,7 @@ class WebSocketApp(object):
if r: if r:
op_code, frame = self.sock.recv_data_frame(True) op_code, frame = self.sock.recv_data_frame(True)
if op_code == ABNF.OPCODE_CLOSE: if op_code == ABNF.OPCODE_CLOSE:
close_frmae = frame close_frame = frame
break break
elif op_code == ABNF.OPCODE_PING: elif op_code == ABNF.OPCODE_PING:
self._callback(self.on_ping, frame.data) self._callback(self.on_ping, frame.data)

View File

@@ -208,6 +208,7 @@ def create_connection(url, timeout=None, **options):
"enable_multithread" -> enable lock for multithread. "enable_multithread" -> enable lock for multithread.
"sockopt" -> socket options "sockopt" -> socket options
"sslopt" -> ssl option "sslopt" -> ssl option
"subprotocols" - array of available sub protocols. default is None.
""" """
sockopt = options.get("sockopt", []) sockopt = options.get("sockopt", [])
sslopt = options.get("sslopt", {}) sslopt = options.get("sslopt", {})
@@ -410,6 +411,7 @@ class WebSocket(object):
"cookie" -> cookie value. "cookie" -> cookie value.
"http_proxy_host" - http proxy host name. "http_proxy_host" - http proxy host name.
"http_proxy_port" - http proxy port. If not set, set to 80. "http_proxy_port" - http proxy port. If not set, set to 80.
"subprotocols" - array of available sub protocols. default is None.
""" """
@@ -509,6 +511,10 @@ class WebSocket(object):
headers.append("Sec-WebSocket-Key: %s" % key) headers.append("Sec-WebSocket-Key: %s" % key)
headers.append("Sec-WebSocket-Version: %s" % VERSION) headers.append("Sec-WebSocket-Version: %s" % VERSION)
subprotocols = options.get("subprotocols")
if subprotocols:
headers.append("Sec-WebSocket-Protocol: %s" % ",".join(subprotocols))
if "header" in options: if "header" in options:
headers.extend(options["header"]) headers.extend(options["header"])
@@ -530,14 +536,14 @@ class WebSocket(object):
_dump("request header", header_str) _dump("request header", header_str)
resp_headers = self._get_resp_headers() resp_headers = self._get_resp_headers()
success = self._validate_header(resp_headers, key) success = self._validate_header(resp_headers, key, options.get("subprotocols"))
if not success: if not success:
self.close() self.close()
raise WebSocketException("Invalid WebSocket Header") raise WebSocketException("Invalid WebSocket Header")
self.connected = True self.connected = True
def _validate_header(self, headers, key): def _validate_header(self, headers, key, subprotocols):
for k, v in _HEADERS_TO_CHECK.items(): for k, v in _HEADERS_TO_CHECK.items():
r = headers.get(k, None) r = headers.get(k, None)
if not r: if not r:
@@ -545,6 +551,13 @@ class WebSocket(object):
r = r.lower() r = r.lower()
if v != r: if v != r:
return False return False
if subprotocols:
subproto = headers.get("sec-websocket-protocol", None)
if not subproto or subproto not in subprotocols:
logger.error("Invalid subprotocol: " + str(subprotocols))
return False
result = headers.get("sec-websocket-accept", None) result = headers.get("sec-websocket-accept", None)
if not result: if not result:

View File

@@ -184,25 +184,31 @@ class WebSocketTest(unittest.TestCase):
"connection": "upgrade", "connection": "upgrade",
"sec-websocket-accept": "Kxep+hNu9n51529fGidYu7a3wO0=", "sec-websocket-accept": "Kxep+hNu9n51529fGidYu7a3wO0=",
} }
self.assertEqual(sock._validate_header(required_header, key), True) self.assertEqual(sock._validate_header(required_header, key, None), True)
header = required_header.copy() header = required_header.copy()
header["upgrade"] = "http" header["upgrade"] = "http"
self.assertEqual(sock._validate_header(header, key), False) self.assertEqual(sock._validate_header(header, key, None), False)
del header["upgrade"] del header["upgrade"]
self.assertEqual(sock._validate_header(header, key), False) self.assertEqual(sock._validate_header(header, key, None), False)
header = required_header.copy() header = required_header.copy()
header["connection"] = "something" header["connection"] = "something"
self.assertEqual(sock._validate_header(header, key), False) self.assertEqual(sock._validate_header(header, key, None), False)
del header["connection"] del header["connection"]
self.assertEqual(sock._validate_header(header, key), False) self.assertEqual(sock._validate_header(header, key, None), False)
header = required_header.copy() header = required_header.copy()
header["sec-websocket-accept"] = "something" header["sec-websocket-accept"] = "something"
self.assertEqual(sock._validate_header(header, key), False) self.assertEqual(sock._validate_header(header, key, None), False)
del header["sec-websocket-accept"] del header["sec-websocket-accept"]
self.assertEqual(sock._validate_header(header, key), False) self.assertEqual(sock._validate_header(header, key, None), False)
header = required_header.copy()
header["sec-websocket-protocol"] = "sub1"
self.assertEqual(sock._validate_header(header, key, ["sub1", "sub2"]), True)
self.assertEqual(sock._validate_header(header, key, ["sub2", "sub3"]), False)
def testReadHeader(self): def testReadHeader(self):
sock = ws.WebSocket() sock = ws.WebSocket()