Merge pull request #36 from dwelch91/master

Fix for sslopt and some other improvements
This commit is contained in:
liris
2013-07-16 16:53:03 -07:00

View File

@@ -21,10 +21,13 @@ Copyright (C) 2010 Hiroki Ohtani(liris)
import socket
try:
import ssl
except:
pass
HAVE_SSL = True
except ImportError:
HAVE_SSL = False
from urlparse import urlparse
import os
import array
@@ -163,7 +166,7 @@ def create_connection(url, timeout=None, **options):
Passing optional timeout parameter will set the timeout on the socket.
If no timeout is supplied, the global default timeout setting returned by getdefauttimeout() is used.
You can customize using 'options'.
If you set "header" dict object, you can set your own custom header.
If you set "header" list object, you can set your own custom header.
>>> conn = create_connection("ws://echo.websocket.org/",
... header=["User-Agent: MyProgram",
@@ -176,10 +179,10 @@ def create_connection(url, timeout=None, **options):
options: current support option is only "header".
if you set header as dict value, the custom HTTP headers are added.
"""
sockopt = options.get("sockopt", ())
sockopt = options.get("sockopt", [])
sslopt = options.get("sslopt", {})
websock = WebSocket(sockopt=sockopt)
websock.settimeout(timeout != None and timeout or default_timeout)
websock = WebSocket(sockopt=sockopt, sslopt=sslopt)
websock.settimeout(timeout if timeout is not None else default_timeout)
websock.connect(url, **options)
return websock
@@ -195,34 +198,29 @@ def _create_sec_websocket_key():
uid = uuid.uuid4()
return base64.encodestring(uid.bytes).strip()
_HEADERS_TO_CHECK = {
"upgrade": "websocket",
"connection": "upgrade",
}
class _SSLSocketWrapper(object):
def __init__(self, sock, sslopt={}):
self.ssl = ssl.wrap_socket(sock, **sslopt)
if HAVE_SSL:
class _SSLSocketWrapper(object):
def __init__(self, sock, sslopt=None):
if sslopt is None:
sslopt = {}
self.ssl = ssl.wrap_socket(sock, **sslopt)
def recv(self, bufsize):
return self.ssl.read(bufsize)
def recv(self, bufsize):
return self.ssl.read(bufsize)
def send(self, payload):
return self.ssl.write(payload)
def send(self, payload):
return self.ssl.write(payload)
def fileno(self):
return self.ssl.fileno()
def fileno(self):
return self.ssl.fileno()
_BOOL_VALUES = (0, 1)
def _is_bool(*values):
for v in values:
if v not in _BOOL_VALUES:
return False
return True
class ABNF(object):
@@ -257,8 +255,8 @@ class ABNF(object):
LENGTH_16 = 1 << 16
LENGTH_63 = 1 << 63
def __init__(self, fin = 0, rsv1 = 0, rsv2 = 0, rsv3 = 0,
opcode = OPCODE_TEXT, mask = 1, data = ""):
def __init__(self, fin=0, rsv1=0, rsv2=0, rsv3=0,
opcode=OPCODE_TEXT, mask=1, data=""):
"""
Constructor for ABNF.
please check RFC for arguments.
@@ -292,7 +290,7 @@ class ABNF(object):
"""
format this object to string(byte array) to send data to server.
"""
if not _is_bool(self.fin, self.rsv1, self.rsv2, self.rsv3):
if any(x not in (0, 1) for x in [self.fin, self.rsv1, self.rsv2, self.rsv3]):
raise ValueError("not 0 or 1")
if self.opcode not in ABNF.OPCODES:
raise ValueError("Invalid OPCODE")
@@ -363,10 +361,14 @@ class WebSocket(object):
sslopt: dict object for ssl socket option.
"""
def __init__(self, get_mask_key=None, sockopt=(), sslopt={}):
def __init__(self, get_mask_key=None, sockopt=None, sslopt=None):
"""
Initalize WebSocket object.
"""
if sockopt is None:
sockopt = []
if sslopt is None:
sslopt = {}
self.connected = False
self.io_sock = self.sock = socket.socket()
for opts in sockopt:
@@ -389,6 +391,12 @@ class WebSocket(object):
"""
self.get_mask_key = func
def gettimeout(self):
"""
Get the websocket timeout(second).
"""
return self.sock.gettimeout()
def settimeout(self, timeout):
"""
Set the timeout to the websocket.
@@ -397,11 +405,7 @@ class WebSocket(object):
"""
self.sock.settimeout(timeout)
def gettimeout(self):
"""
Get the websocket timeout(second).
"""
return self.sock.gettimeout()
timeout = property(gettimeout, settimeout)
def connect(self, url, **options):
"""
@@ -427,7 +431,11 @@ class WebSocket(object):
# TODO: we need to support proxy
self.sock.connect((hostname, port))
if is_secure:
self.io_sock = _SSLSocketWrapper(self.sock, self.sslopt)
if HAVE_SSL:
self.io_sock = _SSLSocketWrapper(self.sock, self.sslopt)
else:
raise WebSocketException("SSL not available.")
self._handshake(hostname, port, resource, **options)
def _handshake(self, host, port, resource, **options):
@@ -522,7 +530,7 @@ class WebSocket(object):
return status, headers
def send(self, payload, opcode = ABNF.OPCODE_TEXT):
def send(self, payload, opcode=ABNF.OPCODE_TEXT):
"""
Send the data as string.
@@ -542,7 +550,10 @@ class WebSocket(object):
if traceEnabled:
logger.debug("send: " + repr(data))
def ping(self, payload = ""):
def send_binary(self, payload):
return self.send(payload, ABNF.OPCODE_BINARY)
def ping(self, payload=""):
"""
send ping data.
@@ -595,7 +606,7 @@ class WebSocket(object):
"""
header_bytes = self._recv_strict(2)
if not header_bytes:
return None
return
b1 = ord(header_bytes[0])
fin = b1 >> 7 & 1
rsv1 = b1 >> 6 & 1
@@ -628,7 +639,7 @@ class WebSocket(object):
frame = ABNF(fin, rsv1, rsv2, rsv3, opcode, mask, data)
return frame
def send_close(self, status = STATUS_NORMAL, reason = ""):
def send_close(self, status=STATUS_NORMAL, reason=""):
"""
send close data to the server.
@@ -640,7 +651,7 @@ class WebSocket(object):
raise ValueError("code is invalid range")
self.send(struct.pack('!H', status) + reason, ABNF.OPCODE_CLOSE)
def close(self, status = STATUS_NORMAL, reason = ""):
def close(self, status=STATUS_NORMAL, reason=""):
"""
Close Websocket object
@@ -739,7 +750,7 @@ class WebSocketApp(object):
self.sock = None
def send(self, data, opcode = ABNF.OPCODE_TEXT):
def send(self, data, opcode=ABNF.OPCODE_TEXT):
"""
send message.
data: message to send. If you set opcode to OPCODE_TEXT, data must be utf-8 string or unicode.
@@ -755,7 +766,7 @@ class WebSocketApp(object):
self.keep_running = False
self.sock.close()
def run_forever(self, sockopt=(), sslopt={}):
def run_forever(self, sockopt=None, sslopt=None):
"""
run event loop for WebSocket framework.
This loop is infinite loop and is alive during websocket is available.
@@ -763,6 +774,10 @@ class WebSocketApp(object):
sockopt must be tuple and each element is argument of sock.setscokopt.
sslopt: ssl socket optional dict.
"""
if sockopt is None:
sockopt = []
if sslopt is None:
sslopt = {}
if self.sock:
raise WebSocketException("socket is already opened")
try: