Various fixes and improvements.

This commit is contained in:
dwelch91
2013-07-07 15:53:21 -07:00
parent 8c9466e0ec
commit 6275781edf

View File

@@ -21,10 +21,13 @@ Copyright (C) 2010 Hiroki Ohtani(liris)
import socket
try:
import ssl
except:
pass
HAVE_SSL = True
except ImportError:
HAVE_SSL = False
from urlparse import urlparse
import os
import array
@@ -163,7 +166,7 @@ def create_connection(url, timeout=None, **options):
Passing optional timeout parameter will set the timeout on the socket.
If no timeout is supplied, the global default timeout setting returned by getdefauttimeout() is used.
You can customize using 'options'.
If you set "header" dict object, you can set your own custom header.
If you set "header" list object, you can set your own custom header.
>>> conn = create_connection("ws://echo.websocket.org/",
... header=["User-Agent: MyProgram",
@@ -176,10 +179,10 @@ def create_connection(url, timeout=None, **options):
options: current support option is only "header".
if you set header as dict value, the custom HTTP headers are added.
"""
sockopt = options.get("sockopt", ())
sockopt = options.get("sockopt", [])
sslopt = options.get("sslopt", {})
websock = WebSocket(sockopt=sockopt)
websock.settimeout(timeout != None and timeout or default_timeout)
websock = WebSocket(sockopt=sockopt, sslopt=sslopt)
websock.settimeout(timeout if timeout is not None else default_timeout)
websock.connect(url, **options)
return websock
@@ -195,14 +198,18 @@ def _create_sec_websocket_key():
uid = uuid.uuid4()
return base64.encodestring(uid.bytes).strip()
_HEADERS_TO_CHECK = {
"upgrade": "websocket",
"connection": "upgrade",
}
if HAVE_SSL:
class _SSLSocketWrapper(object):
def __init__(self, sock, sslopt={}):
def __init__(self, sock, sslopt=None):
if sslopt is None:
sslopt = {}
self.ssl = ssl.wrap_socket(sock, **sslopt)
def recv(self, bufsize):
@@ -214,15 +221,6 @@ class _SSLSocketWrapper(object):
def fileno(self):
return self.ssl.fileno()
_BOOL_VALUES = (0, 1)
def _is_bool(*values):
for v in values:
if v not in _BOOL_VALUES:
return False
return True
class ABNF(object):
@@ -292,7 +290,7 @@ class ABNF(object):
"""
format this object to string(byte array) to send data to server.
"""
if not _is_bool(self.fin, self.rsv1, self.rsv2, self.rsv3):
if any(x not in (0, 1) for x in [self.fin, self.rsv1, self.rsv2, self.rsv3]):
raise ValueError("not 0 or 1")
if self.opcode not in ABNF.OPCODES:
raise ValueError("Invalid OPCODE")
@@ -363,10 +361,14 @@ class WebSocket(object):
sslopt: dict object for ssl socket option.
"""
def __init__(self, get_mask_key=None, sockopt=(), sslopt={}):
def __init__(self, get_mask_key=None, sockopt=None, sslopt=None):
"""
Initalize WebSocket object.
"""
if sockopt is None:
sockopt = []
if sslopt is None:
sslopt = {}
self.connected = False
self.io_sock = self.sock = socket.socket()
for opts in sockopt:
@@ -389,6 +391,12 @@ class WebSocket(object):
"""
self.get_mask_key = func
def gettimeout(self):
"""
Get the websocket timeout(second).
"""
return self.sock.gettimeout()
def settimeout(self, timeout):
"""
Set the timeout to the websocket.
@@ -397,11 +405,7 @@ class WebSocket(object):
"""
self.sock.settimeout(timeout)
def gettimeout(self):
"""
Get the websocket timeout(second).
"""
return self.sock.gettimeout()
timeout = property(gettimeout, settimeout)
def connect(self, url, **options):
"""
@@ -427,7 +431,11 @@ class WebSocket(object):
# TODO: we need to support proxy
self.sock.connect((hostname, port))
if is_secure:
if HAVE_SSL:
self.io_sock = _SSLSocketWrapper(self.sock, self.sslopt)
else:
raise WebSocketException("SSL not available.")
self._handshake(hostname, port, resource, **options)
def _handshake(self, host, port, resource, **options):
@@ -542,6 +550,9 @@ class WebSocket(object):
if traceEnabled:
logger.debug("send: " + repr(data))
def send_binary(self, payload):
return self.send(payload, ABNF.OPCODE_BINARY)
def ping(self, payload=""):
"""
send ping data.
@@ -595,7 +606,7 @@ class WebSocket(object):
"""
header_bytes = self._recv_strict(2)
if not header_bytes:
return None
return
b1 = ord(header_bytes[0])
fin = b1 >> 7 & 1
rsv1 = b1 >> 6 & 1
@@ -755,7 +766,7 @@ class WebSocketApp(object):
self.keep_running = False
self.sock.close()
def run_forever(self, sockopt=(), sslopt={}):
def run_forever(self, sockopt=None, sslopt=None):
"""
run event loop for WebSocket framework.
This loop is infinite loop and is alive during websocket is available.
@@ -763,6 +774,10 @@ class WebSocketApp(object):
sockopt must be tuple and each element is argument of sock.setscokopt.
sslopt: ssl socket optional dict.
"""
if sockopt is None:
sockopt = []
if sslopt is None:
sslopt = {}
if self.sock:
raise WebSocketException("socket is already opened")
try: