fixed #160 and some refactoring

This commit is contained in:
liris
2015-03-25 09:13:08 +09:00
parent 77faf8c6f8
commit 3495bbf7e5
7 changed files with 100 additions and 45 deletions

View File

@@ -5,6 +5,7 @@ ChangeLog
- remove unittest2 requirements for python 2.6 (#156) - remove unittest2 requirements for python 2.6 (#156)
- fixed subprotocol case during header validation (#158) - fixed subprotocol case during header validation (#158)
- get response status and headers (#160)
- refactoring. - refactoring.
- 0.26.0 - 0.26.0

View File

@@ -155,15 +155,11 @@ class WebSocket(object):
""" """
Initalize WebSocket object. Initalize WebSocket object.
""" """
if sockopt is None: self.sock_opt = sock_opt(sockopt, sslopt)
sockopt = [] self.handshake_response = None
if sslopt is None:
sslopt = {}
self.connected = False
self.sock = None self.sock = None
self._timeout = None
self.sockopt = sockopt self.connected = False
self.sslopt = sslopt
self.get_mask_key = get_mask_key self.get_mask_key = get_mask_key
self.fire_cont_frame = fire_cont_frame self.fire_cont_frame = fire_cont_frame
self.skip_utf8_validation = skip_utf8_validation self.skip_utf8_validation = skip_utf8_validation
@@ -174,13 +170,12 @@ class WebSocket(object):
self._frame_buffer = FrameBuffer() self._frame_buffer = FrameBuffer()
self._cont_data = None self._cont_data = None
self._recving_frames = None self._recving_frames = None
if enable_multithread: if enable_multithread:
self.lock = threading.Lock() self.lock = threading.Lock()
else: else:
self.lock = NoLock() self.lock = NoLock()
self.subprotocol = None
def fileno(self): def fileno(self):
return self.sock.fileno() return self.sock.fileno()
@@ -200,7 +195,7 @@ class WebSocket(object):
""" """
Get the websocket timeout(second). Get the websocket timeout(second).
""" """
return self._timeout return self.sock_opt.timeout
def settimeout(self, timeout): def settimeout(self, timeout):
""" """
@@ -208,12 +203,45 @@ class WebSocket(object):
timeout: timeout time(second). timeout: timeout time(second).
""" """
self._timeout = timeout self.sock_opt.timeout = timeout
if self.sock: if self.sock:
self.sock.settimeout(timeout) self.sock.settimeout(timeout)
timeout = property(gettimeout, settimeout) timeout = property(gettimeout, settimeout)
def getsubprotocol(self):
"""
get subprotocol
"""
if self.handshake_response:
return self.handshake_response.subprotocol
else:
return None
subprotocol = property(getsubprotocol)
def getstatus(self):
"""
get handshake status
"""
if self.handshake_response:
return self.handshake_response.status
else:
return None
status = property(getstatus)
def getheaders(self):
"""
get handshake response header
"""
if self.handshake_response:
return self.handshake_response.headers
else:
return None
headers = property(getheaders)
def connect(self, url, **options): def connect(self, url, **options):
""" """
Connect to url. url is websocket url scheme. Connect to url. url is websocket url scheme.
@@ -232,6 +260,7 @@ class WebSocket(object):
options: "header" -> custom http header list. options: "header" -> custom http header list.
"cookie" -> cookie value. "cookie" -> cookie value.
"origin" -> custom origin url.
"http_proxy_host" - http proxy host name. "http_proxy_host" - http proxy host name.
"http_proxy_port" - http proxy port. If not set, set to 80. "http_proxy_port" - http proxy port. If not set, set to 80.
"http_no_proxy" - host names, which doesn't use proxy. "http_no_proxy" - host names, which doesn't use proxy.
@@ -242,12 +271,10 @@ class WebSocket(object):
default is None. default is None.
""" """
if "sockopt" in options: self.sock, addrs = connect(url, self.sock_opt, proxy_info(**options))
del options["sockopt"]
self.sock, addrs = connect(url, self.sockopt, self.sslopt, self.timeout, **options)
try: try:
self.subprotocol = handshake(self.sock, *addrs, **options) self.handshake_response = handshake(self.sock, *addrs, **options)
self.connected = True self.connected = True
except: except:
self.sock.close() self.sock.close()

View File

@@ -35,12 +35,18 @@ from ._socket import*
from ._http import * from ._http import *
from ._exceptions import * from ._exceptions import *
__all__ = ["handshake"] __all__ = ["handshake_response", "handshake"]
# websocket supported version. # websocket supported version.
VERSION = 13 VERSION = 13
class handshake_response(object):
def __init__(self, status, headers, subprotocol):
self.status = status
self.headers = headers
self.subprotocol = subprotocol
def handshake(sock, host, port, resource, **options): def handshake(sock, host, port, resource, **options):
headers, key = _get_handshake_headers(resource, host, port, options) headers, key = _get_handshake_headers(resource, host, port, options)
@@ -48,12 +54,12 @@ def handshake(sock, host, port, resource, **options):
send(sock, header_str) send(sock, header_str)
dump("request header", header_str) dump("request header", header_str)
resp = _get_resp_headers(sock) status, resp = _get_resp_headers(sock)
success, subproto = _validate(resp, key, options.get("subprotocols")) success, subproto = _validate(resp, key, options.get("subprotocols"))
if not success: if not success:
raise WebSocketException("Invalid WebSocket Header") raise WebSocketException("Invalid WebSocket Header")
return subproto return handshake_response(status, resp, subproto)
def _get_handshake_headers(resource, host, port, options): def _get_handshake_headers(resource, host, port, options):
@@ -98,7 +104,7 @@ def _get_resp_headers(sock, success_status=101):
status, resp_headers = read_headers(sock) status, resp_headers = read_headers(sock)
if status != success_status: if status != success_status:
raise WebSocketException("Handshake status %d" % status) raise WebSocketException("Handshake status %d" % status)
return resp_headers return status, resp_headers
_HEADERS_TO_CHECK = { _HEADERS_TO_CHECK = {
"upgrade": "websocket", "upgrade": "websocket",

View File

@@ -49,25 +49,36 @@ from ._url import *
from ._socket import* from ._socket import*
from ._exceptions import * from ._exceptions import *
__all__ = ["connect", "read_headers"] __all__ = ["proxy_info", "connect", "read_headers"]
class proxy_info(object):
def __init__(self, **options):
self.host = options.get("http_proxy_host", None)
if self.host:
self.port = options.get("http_proxy_port", 0)
self.auth = options.get("http_proxy_auth", None)
self.no_proxy = options.get("http_no_proxy", None)
else:
self.port = 0
self.auth = None
self.no_proxy = None
def connect(url, sockopt, sslopt, timeout, **options): def connect(url, options, proxy):
hostname, port, resource, is_secure = parse_url(url) hostname, port, resource, is_secure = parse_url(url)
addrinfo_list, need_tunnel, auth = _get_addrinfo_list(hostname, port, is_secure, **options) addrinfo_list, need_tunnel, auth = _get_addrinfo_list(hostname, port, is_secure, proxy)
if not addrinfo_list: if not addrinfo_list:
raise WebSocketException( raise WebSocketException(
"Host not found.: " + hostname + ":" + str(port)) "Host not found.: " + hostname + ":" + str(port))
sock = None sock = None
try: try:
sock = _open_socket(addrinfo_list, sockopt, timeout) sock = _open_socket(addrinfo_list, options.sockopt, options.timeout)
if need_tunnel: if need_tunnel:
sock = _tunnel(sock, hostname, port, auth) sock = _tunnel(sock, hostname, port, auth)
if is_secure: if is_secure:
if HAVE_SSL: if HAVE_SSL:
sock = _ssl_socket(sock, sslopt) sock = _ssl_socket(sock, options.sslopt)
else: else:
raise WebSocketException("SSL not available.") raise WebSocketException("SSL not available.")
@@ -78,8 +89,9 @@ def connect(url, sockopt, sslopt, timeout, **options):
raise raise
def _get_addrinfo_list(hostname, port, is_secure, **options): def _get_addrinfo_list(hostname, port, is_secure, proxy):
phost, pport, pauth = get_proxy_info(hostname, is_secure, **options) phost, pport, pauth = get_proxy_info(hostname, is_secure,
proxy.host, proxy.port, proxy.auth, proxy.no_proxy)
if not phost: if not phost:
addrinfo_list = socket.getaddrinfo(hostname, port, 0, 0, socket.SOL_TCP) addrinfo_list = socket.getaddrinfo(hostname, port, 0, 0, socket.SOL_TCP)
return addrinfo_list, False, None return addrinfo_list, False, None

View File

@@ -38,9 +38,18 @@ if hasattr(socket, "TCP_KEEPCNT"):
_default_timeout = None _default_timeout = None
__all__ = ["DEFAULT_SOCKET_OPTION", "setdefaulttimeout", "getdefaulttimeout", __all__ = ["DEFAULT_SOCKET_OPTION", "sock_opt", "setdefaulttimeout", "getdefaulttimeout",
"recv", "recv_line", "send"] "recv", "recv_line", "send"]
class sock_opt(object):
def __init__(self, sockopt, sslopt):
if sockopt is None:
sockopt = []
if sslopt is None:
sslopt = {}
self.sockopt = sockopt
self.sslopt = sslopt
self.timeout = None
def setdefaulttimeout(timeout): def setdefaulttimeout(timeout):
""" """

View File

@@ -82,7 +82,8 @@ def _is_no_proxy_host(hostname, no_proxy):
return hostname in no_proxy return hostname in no_proxy
def get_proxy_info(hostname, is_secure, **options): def get_proxy_info(hostname, is_secure,
proxy_host=None, proxy_port=0, proxy_auth=None, no_proxy=None):
""" """
try to retrieve proxy host and port from environment try to retrieve proxy host and port from environment
if not provided in options. if not provided in options.
@@ -103,14 +104,13 @@ def get_proxy_info(hostname, is_secure, **options):
tuple of username and password. tuple of username and password.
defualt is None defualt is None
""" """
if _is_no_proxy_host(hostname, options.get("http_no_proxy", None)): if _is_no_proxy_host(hostname, no_proxy):
return None, 0, None return None, 0, None
http_proxy_host = options.get("http_proxy_host", None) if proxy_host:
if http_proxy_host: port = proxy_port
port = options.get("http_proxy_port", 0) auth = proxy_auth
auth = options.get("http_proxy_auth", None) return proxy_host, port, auth
return http_proxy_host, port, auth
env_keys = ["http_proxy"] env_keys = ["http_proxy"]
if is_secure: if is_secure:

View File

@@ -565,23 +565,23 @@ class ProxyInfoTest(unittest.TestCase):
def testProxyFromArgs(self): def testProxyFromArgs(self):
self.assertEqual(get_proxy_info("echo.websocket.org", False, http_proxy_host="localhost"), ("localhost", 0, None)) self.assertEqual(get_proxy_info("echo.websocket.org", False, proxy_host="localhost"), ("localhost", 0, None))
self.assertEqual(get_proxy_info("echo.websocket.org", False, http_proxy_host="localhost", http_proxy_port=3128), ("localhost", 3128, None)) self.assertEqual(get_proxy_info("echo.websocket.org", False, proxy_host="localhost", proxy_port=3128), ("localhost", 3128, None))
self.assertEqual(get_proxy_info("echo.websocket.org", True, http_proxy_host="localhost"), ("localhost", 0, None)) self.assertEqual(get_proxy_info("echo.websocket.org", True, proxy_host="localhost"), ("localhost", 0, None))
self.assertEqual(get_proxy_info("echo.websocket.org", True, http_proxy_host="localhost", http_proxy_port=3128), ("localhost", 3128, None)) self.assertEqual(get_proxy_info("echo.websocket.org", True, proxy_host="localhost", proxy_port=3128), ("localhost", 3128, None))
self.assertEqual(get_proxy_info("echo.websocket.org", False, http_proxy_host="localhost", http_proxy_auth=("a", "b")), self.assertEqual(get_proxy_info("echo.websocket.org", False, proxy_host="localhost", proxy_auth=("a", "b")),
("localhost", 0, ("a", "b"))) ("localhost", 0, ("a", "b")))
self.assertEqual(get_proxy_info("echo.websocket.org", False, http_proxy_host="localhost", http_proxy_port=3128, http_proxy_auth=("a", "b")), self.assertEqual(get_proxy_info("echo.websocket.org", False, proxy_host="localhost", proxy_port=3128, proxy_auth=("a", "b")),
("localhost", 3128, ("a", "b"))) ("localhost", 3128, ("a", "b")))
self.assertEqual(get_proxy_info("echo.websocket.org", True, http_proxy_host="localhost", http_proxy_auth=("a", "b")), self.assertEqual(get_proxy_info("echo.websocket.org", True, proxy_host="localhost", proxy_auth=("a", "b")),
("localhost", 0, ("a", "b"))) ("localhost", 0, ("a", "b")))
self.assertEqual(get_proxy_info("echo.websocket.org", True, http_proxy_host="localhost", http_proxy_port=3128, http_proxy_auth=("a", "b")), self.assertEqual(get_proxy_info("echo.websocket.org", True, proxy_host="localhost", proxy_port=3128, proxy_auth=("a", "b")),
("localhost", 3128, ("a", "b"))) ("localhost", 3128, ("a", "b")))
self.assertEqual(get_proxy_info("echo.websocket.org", True, http_proxy_host="localhost", http_proxy_port=3128, http_no_proxy=["example.com"], http_proxy_auth=("a", "b")), self.assertEqual(get_proxy_info("echo.websocket.org", True, proxy_host="localhost", proxy_port=3128, no_proxy=["example.com"], proxy_auth=("a", "b")),
("localhost", 3128, ("a", "b"))) ("localhost", 3128, ("a", "b")))
self.assertEqual(get_proxy_info("echo.websocket.org", True, http_proxy_host="localhost", http_proxy_port=3128, http_no_proxy=["echo.websocket.org"], http_proxy_auth=("a", "b")), self.assertEqual(get_proxy_info("echo.websocket.org", True, proxy_host="localhost", proxy_port=3128, no_proxy=["echo.websocket.org"], proxy_auth=("a", "b")),
(None, 0, None)) (None, 0, None))