enable to set subprotocol
This commit is contained in:
Hiroki Ohtani
2014-10-14 09:17:35 +09:00
parent b2292a9804
commit dd377dc4c7
3 changed files with 40 additions and 17 deletions

View File

@@ -43,7 +43,8 @@ class WebSocketApp(object):
on_open=None, on_message=None, on_error=None, on_open=None, on_message=None, on_error=None,
on_close=None, on_ping=None, on_pong=None, on_close=None, on_ping=None, on_pong=None,
on_cont_message=None, on_cont_message=None,
keep_running=True, get_mask_key=None, cookie=None): keep_running=True, get_mask_key=None, cookie=None,
subprotocols=None):
""" """
url: websocket url. url: websocket url.
header: custom header for websocket handshake. header: custom header for websocket handshake.
@@ -68,6 +69,7 @@ class WebSocketApp(object):
keep running, defaults to True keep running, defaults to True
get_mask_key: a callable to produce new mask keys, see the WebSocket.set_mask_key's get_mask_key: a callable to produce new mask keys, see the WebSocket.set_mask_key's
docstring for more information docstring for more information
subprotocols: array of available sub protocols. default is None.
""" """
self.url = url self.url = url
self.header = header self.header = header
@@ -83,6 +85,7 @@ class WebSocketApp(object):
self.get_mask_key = get_mask_key self.get_mask_key = get_mask_key
self.sock = None self.sock = None
self.last_ping_tm = 0 self.last_ping_tm = 0
self.subprotocols =subprotocols
def send(self, data, opcode=ABNF.OPCODE_TEXT): def send(self, data, opcode=ABNF.OPCODE_TEXT):
""" """
@@ -138,7 +141,8 @@ class WebSocketApp(object):
fire_cont_frame=self.on_cont_message and True or False) fire_cont_frame=self.on_cont_message and True or False)
self.sock.settimeout(getdefaulttimeout()) self.sock.settimeout(getdefaulttimeout())
self.sock.connect(self.url, header=self.header, cookie=self.cookie, self.sock.connect(self.url, header=self.header, cookie=self.cookie,
http_proxy_host=http_proxy_host, http_proxy_port=http_proxy_port) http_proxy_host=http_proxy_host, http_proxy_port=http_proxy_port,
subprotocols=self.subprotocols)
self._callback(self.on_open) self._callback(self.on_open)
if ping_interval: if ping_interval:
@@ -158,7 +162,7 @@ class WebSocketApp(object):
if r: if r:
op_code, frame = self.sock.recv_data_frame(True) op_code, frame = self.sock.recv_data_frame(True)
if op_code == ABNF.OPCODE_CLOSE: if op_code == ABNF.OPCODE_CLOSE:
close_frmae = frame close_frame = frame
break break
elif op_code == ABNF.OPCODE_PING: elif op_code == ABNF.OPCODE_PING:
self._callback(self.on_ping, frame.data) self._callback(self.on_ping, frame.data)

View File

@@ -208,6 +208,7 @@ def create_connection(url, timeout=None, **options):
"enable_multithread" -> enable lock for multithread. "enable_multithread" -> enable lock for multithread.
"sockopt" -> socket options "sockopt" -> socket options
"sslopt" -> ssl option "sslopt" -> ssl option
"subprotocols" - array of available sub protocols. default is None.
""" """
sockopt = options.get("sockopt", []) sockopt = options.get("sockopt", [])
sslopt = options.get("sslopt", {}) sslopt = options.get("sslopt", {})
@@ -410,6 +411,7 @@ class WebSocket(object):
"cookie" -> cookie value. "cookie" -> cookie value.
"http_proxy_host" - http proxy host name. "http_proxy_host" - http proxy host name.
"http_proxy_port" - http proxy port. If not set, set to 80. "http_proxy_port" - http proxy port. If not set, set to 80.
"subprotocols" - array of available sub protocols. default is None.
""" """
@@ -509,6 +511,10 @@ class WebSocket(object):
headers.append("Sec-WebSocket-Key: %s" % key) headers.append("Sec-WebSocket-Key: %s" % key)
headers.append("Sec-WebSocket-Version: %s" % VERSION) headers.append("Sec-WebSocket-Version: %s" % VERSION)
subprotocols = options.get("subprotocols")
if subprotocols:
headers.append("Sec-WebSocket-Protocol: %s" % ",".join(subprotocols))
if "header" in options: if "header" in options:
headers.extend(options["header"]) headers.extend(options["header"])
@@ -530,14 +536,14 @@ class WebSocket(object):
_dump("request header", header_str) _dump("request header", header_str)
resp_headers = self._get_resp_headers() resp_headers = self._get_resp_headers()
success = self._validate_header(resp_headers, key) success = self._validate_header(resp_headers, key, options.get("subprotocols"))
if not success: if not success:
self.close() self.close()
raise WebSocketException("Invalid WebSocket Header") raise WebSocketException("Invalid WebSocket Header")
self.connected = True self.connected = True
def _validate_header(self, headers, key): def _validate_header(self, headers, key, subprotocols):
for k, v in _HEADERS_TO_CHECK.items(): for k, v in _HEADERS_TO_CHECK.items():
r = headers.get(k, None) r = headers.get(k, None)
if not r: if not r:
@@ -546,6 +552,13 @@ class WebSocket(object):
if v != r: if v != r:
return False return False
if subprotocols:
subproto = headers.get("sec-websocket-protocol", None)
if not subproto or subproto not in subprotocols:
logger.error("Invalid subprotocol: " + str(subprotocols))
return False
result = headers.get("sec-websocket-accept", None) result = headers.get("sec-websocket-accept", None)
if not result: if not result:
return False return False

View File

@@ -184,25 +184,31 @@ class WebSocketTest(unittest.TestCase):
"connection": "upgrade", "connection": "upgrade",
"sec-websocket-accept": "Kxep+hNu9n51529fGidYu7a3wO0=", "sec-websocket-accept": "Kxep+hNu9n51529fGidYu7a3wO0=",
} }
self.assertEqual(sock._validate_header(required_header, key), True) self.assertEqual(sock._validate_header(required_header, key, None), True)
header = required_header.copy() header = required_header.copy()
header["upgrade"] = "http" header["upgrade"] = "http"
self.assertEqual(sock._validate_header(header, key), False) self.assertEqual(sock._validate_header(header, key, None), False)
del header["upgrade"] del header["upgrade"]
self.assertEqual(sock._validate_header(header, key), False) self.assertEqual(sock._validate_header(header, key, None), False)
header = required_header.copy() header = required_header.copy()
header["connection"] = "something" header["connection"] = "something"
self.assertEqual(sock._validate_header(header, key), False) self.assertEqual(sock._validate_header(header, key, None), False)
del header["connection"] del header["connection"]
self.assertEqual(sock._validate_header(header, key), False) self.assertEqual(sock._validate_header(header, key, None), False)
header = required_header.copy() header = required_header.copy()
header["sec-websocket-accept"] = "something" header["sec-websocket-accept"] = "something"
self.assertEqual(sock._validate_header(header, key), False) self.assertEqual(sock._validate_header(header, key, None), False)
del header["sec-websocket-accept"] del header["sec-websocket-accept"]
self.assertEqual(sock._validate_header(header, key), False) self.assertEqual(sock._validate_header(header, key, None), False)
header = required_header.copy()
header["sec-websocket-protocol"] = "sub1"
self.assertEqual(sock._validate_header(header, key, ["sub1", "sub2"]), True)
self.assertEqual(sock._validate_header(header, key, ["sub2", "sub3"]), False)
def testReadHeader(self): def testReadHeader(self):
sock = ws.WebSocket() sock = ws.WebSocket()