This commit is contained in:
liris
2015-03-20 09:13:35 +09:00
10 changed files with 393 additions and 253 deletions

View File

@@ -1,6 +1,10 @@
ChangeLog ChangeLog
============ ============
- 0.27.0
- refactoring.
- 0.26.0 - 0.26.0
- all WebSocketException provide message string (#152) - all WebSocketException provide message string (#152)

View File

@@ -35,6 +35,8 @@ from ._exceptions import *
from ._logging import * from ._logging import *
from websocket._abnf import ABNF from websocket._abnf import ABNF
__all__ = ["WebSocketApp"]
class WebSocketApp(object): class WebSocketApp(object):
""" """

View File

@@ -50,8 +50,6 @@ else:
import os import os
import errno import errno
import struct import struct
import uuid
import hashlib
import threading import threading
# websocket modules # websocket modules
@@ -61,6 +59,8 @@ from ._socket import *
from ._utils import * from ._utils import *
from ._url import * from ._url import *
from ._logging import * from ._logging import *
from ._http import *
from ._handshake import *
""" """
websocket python client. websocket python client.
@@ -71,10 +71,6 @@ Please see http://tools.ietf.org/html/rfc6455 for protocol.
""" """
# websocket supported version.
VERSION = 13
def create_connection(url, timeout=None, **options): def create_connection(url, timeout=None, **options):
""" """
connect to url and return websocket object. connect to url and return websocket object.
@@ -125,17 +121,6 @@ def create_connection(url, timeout=None, **options):
return websock return websock
def _create_sec_websocket_key():
uid = uuid.uuid4()
return base64encode(uid.bytes).decode('utf-8').strip()
_HEADERS_TO_CHECK = {
"upgrade": "websocket",
"connection": "upgrade",
}
class WebSocket(object): class WebSocket(object):
""" """
Low level WebSocket interface. Low level WebSocket interface.
@@ -257,200 +242,17 @@ class WebSocket(object):
default is None. default is None.
""" """
if "sockopt" in options:
hostname, port, resource, is_secure = parse_url(url) del options["sockopt"]
proxy_host, proxy_port, proxy_auth = get_proxy_info(hostname, is_secure, **options) self.sock, addrs = connect(url, self.sockopt, self.sslopt, self.timeout, **options)
if not proxy_host:
addrinfo_list = socket.getaddrinfo(hostname, port, 0, 0, socket.SOL_TCP)
else:
proxy_port = proxy_port and proxy_port or 80
addrinfo_list = socket.getaddrinfo(proxy_host, proxy_port, 0, 0, socket.SOL_TCP)
if not addrinfo_list:
raise WebSocketException("Host not found.: " + hostname + ":" + str(port))
err = None
for addrinfo in addrinfo_list:
family = addrinfo[0]
self.sock = socket.socket(family)
self.sock.settimeout(self.timeout)
for opts in DEFAULT_SOCKET_OPTION:
self.sock.setsockopt(*opts)
for opts in self.sockopt:
self.sock.setsockopt(*opts)
address = addrinfo[4]
try:
self.sock.connect(address)
except socket.error as error:
error.remote_ip = str(address[0])
if error.errno in (errno.ECONNREFUSED, ):
err = error
continue
else:
raise
else:
break
else:
raise err
if proxy_host:
self._tunnel(hostname, port, proxy_auth)
if is_secure:
if HAVE_SSL:
sslopt = dict(cert_reqs=ssl.CERT_REQUIRED)
certPath = os.path.join(
os.path.dirname(__file__), "cacert.pem")
if os.path.isfile(certPath):
sslopt['ca_certs'] = certPath
sslopt.update(self.sslopt)
check_hostname = sslopt.pop('check_hostname', True)
self.sock = ssl.wrap_socket(self.sock, **sslopt)
if (sslopt["cert_reqs"] != ssl.CERT_NONE
and check_hostname):
match_hostname(self.sock.getpeercert(), hostname)
else:
raise WebSocketException("SSL not available.")
self._handshake(hostname, port, resource, **options)
def _tunnel(self, host, port, auth):
debug("Connecting proxy...")
connect_header = "CONNECT %s:%d HTTP/1.0\r\n" % (host, port)
# TODO: support digest auth.
if auth and auth[0]:
auth_str = auth[0]
if auth[1]:
auth_str += ":" + auth[1]
encoded_str = base64encode(auth_str.encode()).strip().decode()
connect_header += "Proxy-Authorization: Basic %s\r\n" % encoded_str
connect_header += "\r\n"
dump("request header", connect_header)
self._send(connect_header)
try: try:
status, resp_headers = self._read_headers() self.subprotocol = handshake(self.sock, *addrs, **options)
except Exepiton as e: self.connected = True
raise WebSocketProxyException(str(e)) except:
self.sock.close()
if status != 200: self.sock = None
raise WebSocketProxyException("failed CONNECT via proxy status: " + str(status)) raise
def _get_resp_headers(self, success_status=101):
status, resp_headers = self._read_headers()
if status != success_status:
self.close()
raise WebSocketException("Handshake status %d" % status)
return resp_headers
def _get_handshake_headers(self, resource, host, port, options):
headers = []
headers.append("GET %s HTTP/1.1" % resource)
headers.append("Upgrade: websocket")
headers.append("Connection: Upgrade")
if port == 80:
hostport = host
else:
hostport = "%s:%d" % (host, port)
headers.append("Host: %s" % hostport)
if "origin" in options:
headers.append("Origin: %s" % options["origin"])
else:
headers.append("Origin: http://%s" % hostport)
key = _create_sec_websocket_key()
headers.append("Sec-WebSocket-Key: %s" % key)
headers.append("Sec-WebSocket-Version: %s" % VERSION)
subprotocols = options.get("subprotocols")
if subprotocols:
headers.append("Sec-WebSocket-Protocol: %s" % ",".join(subprotocols))
if "header" in options:
headers.extend(options["header"])
cookie = options.get("cookie", None)
if cookie:
headers.append("Cookie: %s" % cookie)
headers.append("")
headers.append("")
return headers, key
def _handshake(self, host, port, resource, **options):
headers, key = self._get_handshake_headers(resource, host, port, options)
header_str = "\r\n".join(headers)
self._send(header_str)
dump("request header", header_str)
resp_headers = self._get_resp_headers()
success = self._validate_header(resp_headers, key, options.get("subprotocols"))
if not success:
self.close()
raise WebSocketException("Invalid WebSocket Header")
self.connected = True
def _validate_header(self, headers, key, subprotocols):
for k, v in _HEADERS_TO_CHECK.items():
r = headers.get(k, None)
if not r:
return False
r = r.lower()
if v != r:
return False
if subprotocols:
subproto = headers.get("sec-websocket-protocol", None)
if not subproto or subproto not in subprotocols:
error("Invalid subprotocol: " + str(subprotocols))
return False
self.subprotocol = subproto
result = headers.get("sec-websocket-accept", None)
if not result:
return False
result = result.lower()
if isinstance(result, six.text_type):
result = result.encode('utf-8')
value = (key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11").encode('utf-8')
hashed = base64encode(hashlib.sha1(value).digest()).strip().lower()
return hashed == result
def _read_headers(self):
status = None
headers = {}
trace("--- response header ---")
while True:
line = self._recv_line()
line = line.decode('utf-8').strip()
if not line:
break
trace(line)
if not status:
status_info = line.split(" ", 2)
status = int(status_info[1])
else:
kv = line.split(":", 1)
if len(kv) == 2:
key, value = kv
headers[key.lower()] = value.strip().lower()
else:
raise WebSocketException("Invalid header")
trace("-----------------------")
return status, headers
def send(self, payload, opcode=ABNF.OPCODE_TEXT): def send(self, payload, opcode=ABNF.OPCODE_TEXT):
""" """
@@ -720,15 +522,3 @@ class WebSocket(object):
else: else:
self._recv_buffer = [unified[bufsize:]] self._recv_buffer = [unified[bufsize:]]
return unified[:bufsize] return unified[:bufsize]
def _recv_line(self):
try:
return recv_line(self.sock)
except WebSocketConnectionClosedException:
if self.sock:
self.sock.close()
self.sock = None
self.connected = False
raise
except:
raise

143
websocket/_handshake.py Normal file
View File

@@ -0,0 +1,143 @@
"""
websocket - WebSocket client library for Python
Copyright (C) 2010 Hiroki Ohtani(liris)
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor,
Boston, MA 02110-1335 USA
"""
import six
if six.PY3:
from base64 import encodebytes as base64encode
else:
from base64 import encodestring as base64encode
import uuid
import hashlib
from ._logging import *
from ._url import *
from ._socket import*
from ._http import *
__all__ = ["handshake"]
# websocket supported version.
VERSION = 13
def handshake(sock, host, port, resource, **options):
headers, key = _get_handshake_headers(resource, host, port, options)
header_str = "\r\n".join(headers)
send(sock, header_str)
dump("request header", header_str)
resp = _get_resp_headers(sock)
success, subproto = _validate(resp, key, options.get("subprotocols"))
if not success:
raise WebSocketException("Invalid WebSocket Header")
return subproto
def _get_handshake_headers(resource, host, port, options):
headers = []
headers.append("GET %s HTTP/1.1" % resource)
headers.append("Upgrade: websocket")
headers.append("Connection: Upgrade")
if port == 80:
hostport = host
else:
hostport = "%s:%d" % (host, port)
headers.append("Host: %s" % hostport)
if "origin" in options:
headers.append("Origin: %s" % options["origin"])
else:
headers.append("Origin: http://%s" % hostport)
key = _create_sec_websocket_key()
headers.append("Sec-WebSocket-Key: %s" % key)
headers.append("Sec-WebSocket-Version: %s" % VERSION)
subprotocols = options.get("subprotocols")
if subprotocols:
headers.append("Sec-WebSocket-Protocol: %s" % ",".join(subprotocols))
if "header" in options:
headers.extend(options["header"])
cookie = options.get("cookie", None)
if cookie:
headers.append("Cookie: %s" % cookie)
headers.append("")
headers.append("")
return headers, key
def _get_resp_headers(sock, success_status=101):
status, resp_headers = read_headers(sock)
if status != success_status:
raise WebSocketException("Handshake status %d" % status)
return resp_headers
_HEADERS_TO_CHECK = {
"upgrade": "websocket",
"connection": "upgrade",
}
def _validate(headers, key, subprotocols):
subproto = None
for k, v in _HEADERS_TO_CHECK.items():
r = headers.get(k, None)
if not r:
return False, None
r = r.lower()
if v != r:
return False, None
if subprotocols:
subproto = headers.get("sec-websocket-protocol", None)
if not subproto or subproto not in subprotocols:
error("Invalid subprotocol: " + str(subprotocols))
return False, None
result = headers.get("sec-websocket-accept", None)
if not result:
return False, None
result = result.lower()
if isinstance(result, six.text_type):
result = result.encode('utf-8')
value = (key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11").encode('utf-8')
hashed = base64encode(hashlib.sha1(value).digest()).strip().lower()
success = (hashed == result)
if success:
return True, subproto
else:
return False, None
def _create_sec_websocket_key():
uid = uuid.uuid4()
return base64encode(uid.bytes).decode('utf-8').strip()

185
websocket/_http.py Normal file
View File

@@ -0,0 +1,185 @@
"""
websocket - WebSocket client library for Python
Copyright (C) 2010 Hiroki Ohtani(liris)
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor,
Boston, MA 02110-1335 USA
"""
import six
import socket
try:
import ssl
from ssl import SSLError
if hasattr(ssl, "match_hostname"):
from ssl import match_hostname
else:
from backports.ssl_match_hostname import match_hostname
HAVE_SSL = True
except ImportError:
# dummy class of SSLError for ssl none-support environment.
class SSLError(Exception):
pass
HAVE_SSL = False
if six.PY3:
from base64 import encodebytes as base64encode
else:
from base64 import encodestring as base64encode
from ._logging import *
from ._url import *
from ._socket import*
from ._exceptions import *
__all__ = ["connect", "read_headers"]
def connect(url, sockopt, sslopt, timeout, **options):
hostname, port, resource, is_secure = parse_url(url)
addrinfo_list, need_tunnel, auth = _get_addrinfo_list(hostname, port, is_secure, **options)
if not addrinfo_list:
raise WebSocketException(
"Host not found.: " + hostname + ":" + str(port))
sock = None
try:
sock = _open_socket(addrinfo_list, sockopt, timeout)
if need_tunnel:
sock = _tunnel(sock, hostname, port, auth)
if is_secure:
if HAVE_SSL:
sock = _ssl_socket(sock, sslopt)
else:
raise WebSocketException("SSL not available.")
return sock, (hostname, port, resource)
except:
if sock:
sock.close()
raise
def _get_addrinfo_list(hostname, port, is_secure, **options):
phost, pport, pauth = get_proxy_info(hostname, is_secure, **options)
if not phost:
addrinfo_list = socket.getaddrinfo(hostname, port, 0, 0, socket.SOL_TCP)
return addrinfo_list, False, None
else:
pport = pport and pport or 80
addrinfo_list = socket.getaddrinfo(phost, pport, 0, 0, socket.SOL_TCP)
return addrinfo_list, True, pauth
def _open_socket(addrinfo_list, sockopt, timeout):
err = None
for addrinfo in addrinfo_list:
family = addrinfo[0]
sock = socket.socket(family)
sock.settimeout(timeout)
for opts in DEFAULT_SOCKET_OPTION:
sock.setsockopt(*opts)
for opts in sockopt:
sock.setsockopt(*opts)
address = addrinfo[4]
try:
sock.connect(address)
except socket.error as error:
error.remote_ip = str(address[0])
if error.errno in (errno.ECONNREFUSED, ):
err = error
continue
else:
raise
else:
break
else:
raise err
return sock
def _ssl_socket(sock, sslopt):
sslopt = dict(cert_reqs=ssl.CERT_REQUIRED)
certPath = os.path.join(
os.path.dirname(__file__), "cacert.pem")
if os.path.isfile(certPath):
sslopt['ca_certs'] = certPath
sslopt.update(sslopt)
check_hostname = sslopt.pop('check_hostname', True)
sock = ssl.wrap_socket(sock, **sslopt)
if (sslopt["cert_reqs"] != ssl.CERT_NONE
and check_hostname):
match_hostname(sock.getpeercert(), hostname)
def _tunnel(sock, host, port, auth):
debug("Connecting proxy...")
connect_header = "CONNECT %s:%d HTTP/1.0\r\n" % (host, port)
# TODO: support digest auth.
if auth and auth[0]:
auth_str = auth[0]
if auth[1]:
auth_str += ":" + auth[1]
encoded_str = base64encode(auth_str.encode()).strip().decode()
connect_header += "Proxy-Authorization: Basic %s\r\n" % encoded_str
connect_header += "\r\n"
dump("request header", connect_header)
send(sock, connect_header)
try:
status, resp_headers = read_headers()
except Exepiton as e:
raise WebSocketProxyException(str(e))
if status != 200:
raise WebSocketProxyException(
"failed CONNECT via proxy status: " + str(status))
def read_headers(sock):
status = None
headers = {}
trace("--- response header ---")
while True:
line = recv_line(sock)
line = line.decode('utf-8').strip()
if not line:
break
trace(line)
if not status:
status_info = line.split(" ", 2)
status = int(status_info[1])
else:
kv = line.split(":", 1)
if len(kv) == 2:
key, value = kv
headers[key.lower()] = value.strip().lower()
else:
raise WebSocketException("Invalid header")
trace("-----------------------")
return status, headers

View File

@@ -25,6 +25,9 @@ import logging
_logger = logging.getLogger() _logger = logging.getLogger()
_traceEnabled = False _traceEnabled = False
__all__ = ["enableTrace", "dump", "error", "debug", "trace",
"isEnableForError", "isEnableForDebug"]
def enableTrace(tracable): def enableTrace(tracable):
""" """

View File

@@ -26,7 +26,7 @@ import six
from ._exceptions import * from ._exceptions import *
from ._utils import * from ._utils import *
DEFAULT_SOCKET_OPTION = [(socket.SOL_TCP, socket.TCP_NODELAY, 1),] DEFAULT_SOCKET_OPTION = [(socket.SOL_TCP, socket.TCP_NODELAY, 1)]
if hasattr(socket, "SO_KEEPALIVE"): if hasattr(socket, "SO_KEEPALIVE"):
DEFAULT_SOCKET_OPTION.append((socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)) DEFAULT_SOCKET_OPTION.append((socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1))
if hasattr(socket, "TCP_KEEPIDLE"): if hasattr(socket, "TCP_KEEPIDLE"):
@@ -38,6 +38,10 @@ if hasattr(socket, "TCP_KEEPCNT"):
_default_timeout = None _default_timeout = None
__all__ = ["DEFAULT_SOCKET_OPTION", "setdefaulttimeout", "getdefaulttimeout",
"recv", "recv_line", "send"]
def setdefaulttimeout(timeout): def setdefaulttimeout(timeout):
""" """
Set the global timeout setting to connect. Set the global timeout setting to connect.
@@ -73,9 +77,10 @@ def recv(sock, bufsize):
if not bytes: if not bytes:
raise WebSocketConnectionClosedException("Connection is already closed.") raise WebSocketConnectionClosedException("Connection is already closed.")
return bytes return bytes
def recv_line(sock): def recv_line(sock):
line = [] line = []
while True: while True:
@@ -89,10 +94,10 @@ def recv_line(sock):
def send(sock, data): def send(sock, data):
if isinstance(data, six.text_type): if isinstance(data, six.text_type):
data = data.encode('utf-8') data = data.encode('utf-8')
if not sock: if not sock:
raise WebSocketConnectionClosedException("socket is already closed.") raise WebSocketConnectionClosedException("socket is already closed.")
try: try:
return sock.send(data) return sock.send(data)
except socket.timeout as e: except socket.timeout as e:
@@ -104,4 +109,3 @@ def send(sock, data):
raise WebSocketTimeoutException(message) raise WebSocketTimeoutException(message)
else: else:
raise raise

View File

@@ -23,6 +23,9 @@ Copyright (C) 2010 Hiroki Ohtani(liris)
from six.moves.urllib.parse import urlparse from six.moves.urllib.parse import urlparse
import os import os
__all__ = ["parse_url", "get_proxy_info"]
def parse_url(url): def parse_url(url):
""" """
parse url and the result is tuple of parse url and the result is tuple of
@@ -68,6 +71,7 @@ def parse_url(url):
DEFAULT_NO_PROXY_HOST = ["localhost", "127.0.0.1"] DEFAULT_NO_PROXY_HOST = ["localhost", "127.0.0.1"]
def _is_no_proxy_host(hostname, no_proxy): def _is_no_proxy_host(hostname, no_proxy):
if not no_proxy: if not no_proxy:
v = os.environ.get("no_proxy", "").replace(" ", "") v = os.environ.get("no_proxy", "").replace(" ", "")
@@ -77,21 +81,26 @@ def _is_no_proxy_host(hostname, no_proxy):
return hostname in no_proxy return hostname in no_proxy
def get_proxy_info(hostname, is_secure, **options): def get_proxy_info(hostname, is_secure, **options):
""" """
try to retrieve proxy host and port from environment if not provided in options. try to retrieve proxy host and port from environment
if not provided in options.
result is (proxy_host, proxy_port, proxy_auth). result is (proxy_host, proxy_port, proxy_auth).
proxy_auth is tuple of username and password of proxy authentication information. proxy_auth is tuple of username and password
of proxy authentication information.
hostname: websocket server name. hostname: websocket server name.
is_secure: is the connection secure? (wss) is_secure: is the connection secure? (wss)
looks for "https_proxy" in env before falling back to "http_proxy" looks for "https_proxy" in env
before falling back to "http_proxy"
options: "http_proxy_host" - http proxy host name. options: "http_proxy_host" - http proxy host name.
"http_proxy_port" - http proxy port. "http_proxy_port" - http proxy port.
"http_no_proxy" - host names, which doesn't use proxy. "http_no_proxy" - host names, which doesn't use proxy.
"http_proxy_auth" - http proxy auth infomation. tuple of username and password. "http_proxy_auth" - http proxy auth infomation.
tuple of username and password.
defualt is None defualt is None
""" """
if _is_no_proxy_host(hostname, options.get("http_no_proxy", None)): if _is_no_proxy_host(hostname, options.get("http_no_proxy", None)):
@@ -99,7 +108,9 @@ def get_proxy_info(hostname, is_secure, **options):
http_proxy_host = options.get("http_proxy_host", None) http_proxy_host = options.get("http_proxy_host", None)
if http_proxy_host: if http_proxy_host:
return http_proxy_host, options.get("http_proxy_port", 0), options.get("http_proxy_auth", None) port = options.get("http_proxy_port", 0)
auth = options.get("http_proxy_auth", None)
return http_proxy_host, port, auth
env_keys = ["http_proxy"] env_keys = ["http_proxy"]
if is_secure: if is_secure:
@@ -113,6 +124,3 @@ def get_proxy_info(hostname, is_secure, **options):
return proxy.hostname, proxy.port, auth return proxy.hostname, proxy.port, auth
return None, 0, None return None, 0, None

View File

@@ -22,6 +22,8 @@ Copyright (C) 2010 Hiroki Ohtani(liris)
import six import six
__all__ = ["NoLock", "validate_utf8", "extract_err_message"]
class NoLock(object): class NoLock(object):
def __enter__(self): def __enter__(self):
pass pass

View File

@@ -25,13 +25,16 @@ import uuid
# websocket-client # websocket-client
import websocket as ws import websocket as ws
from websocket._core import _create_sec_websocket_key from websocket._handshake import _create_sec_websocket_key
from websocket._url import parse_url, get_proxy_info from websocket._url import parse_url, get_proxy_info
from websocket._utils import validate_utf8 from websocket._utils import validate_utf8
from websocket._handshake import _validate as _validate_header
from websocket._http import read_headers
# Skip test to access the internet. # Skip test to access the internet.
TEST_WITH_INTERNET = False TEST_WITH_INTERNET = False
# TEST_WITH_INTERNET = True
# Skip Secure WebSocket test. # Skip Secure WebSocket test.
TEST_SECURE_WS = False TEST_SECURE_WS = False
@@ -75,6 +78,7 @@ class HeaderSockMock(SockMock):
with open(path, "rb") as f: with open(path, "rb") as f:
self.add_packet(f.read()) self.add_packet(f.read())
class WebSocketTest(unittest.TestCase): class WebSocketTest(unittest.TestCase):
def setUp(self): def setUp(self):
ws.enableTrace(TRACABLE) ws.enableTrace(TRACABLE)
@@ -178,49 +182,44 @@ class WebSocketTest(unittest.TestCase):
self.assertTrue(six.u("¥n") not in key) self.assertTrue(six.u("¥n") not in key)
def testWsUtils(self): def testWsUtils(self):
sock = ws.WebSocket()
key = "c6b8hTg4EeGb2gQMztV1/g==" key = "c6b8hTg4EeGb2gQMztV1/g=="
required_header = { required_header = {
"upgrade": "websocket", "upgrade": "websocket",
"connection": "upgrade", "connection": "upgrade",
"sec-websocket-accept": "Kxep+hNu9n51529fGidYu7a3wO0=", "sec-websocket-accept": "Kxep+hNu9n51529fGidYu7a3wO0=",
} }
self.assertEqual(sock._validate_header(required_header, key, None), True) self.assertEqual(_validate_header(required_header, key, None), (True, None))
header = required_header.copy() header = required_header.copy()
header["upgrade"] = "http" header["upgrade"] = "http"
self.assertEqual(sock._validate_header(header, key, None), False) self.assertEqual(_validate_header(header, key, None), (False, None))
del header["upgrade"] del header["upgrade"]
self.assertEqual(sock._validate_header(header, key, None), False) self.assertEqual(_validate_header(header, key, None), (False, None))
header = required_header.copy() header = required_header.copy()
header["connection"] = "something" header["connection"] = "something"
self.assertEqual(sock._validate_header(header, key, None), False) self.assertEqual(_validate_header(header, key, None), (False, None))
del header["connection"] del header["connection"]
self.assertEqual(sock._validate_header(header, key, None), False) self.assertEqual(_validate_header(header, key, None), (False, None))
header = required_header.copy() header = required_header.copy()
header["sec-websocket-accept"] = "something" header["sec-websocket-accept"] = "something"
self.assertEqual(sock._validate_header(header, key, None), False) self.assertEqual(_validate_header(header, key, None), (False, None))
del header["sec-websocket-accept"] del header["sec-websocket-accept"]
self.assertEqual(sock._validate_header(header, key, None), False) self.assertEqual(_validate_header(header, key, None), (False, None))
header = required_header.copy() header = required_header.copy()
header["sec-websocket-protocol"] = "sub1" header["sec-websocket-protocol"] = "sub1"
self.assertEqual(sock._validate_header(header, key, ["sub1", "sub2"]), True) self.assertEqual(_validate_header(header, key, ["sub1", "sub2"]), (True, "sub1"))
self.assertEqual(sock._validate_header(header, key, ["sub2", "sub3"]), False) self.assertEqual(_validate_header(header, key, ["sub2", "sub3"]), (False, None))
def testReadHeader(self): def testReadHeader(self):
sock = ws.WebSocket() status, header = read_headers(HeaderSockMock("data/header01.txt"))
sock.sock = HeaderSockMock("data/header01.txt")
status, header = sock._read_headers()
self.assertEqual(status, 101) self.assertEqual(status, 101)
self.assertEqual(header["connection"], "upgrade") self.assertEqual(header["connection"], "upgrade")
sock.sock = HeaderSockMock("data/header02.txt") HeaderSockMock("data/header02.txt")
self.assertRaises(ws.WebSocketException, sock._read_headers) self.assertRaises(ws.WebSocketException, read_headers, HeaderSockMock("data/header02.txt"))
def testSend(self): def testSend(self):
# TODO: add longer frame data # TODO: add longer frame data