From 6275781edf211ff6a0f6f55cb8317c2280e900e2 Mon Sep 17 00:00:00 2001 From: dwelch91 Date: Sun, 7 Jul 2013 15:53:21 -0700 Subject: [PATCH] Various fixes and improvements. --- websocket.py | 97 ++++++++++++++++++++++++++++++---------------------- 1 file changed, 56 insertions(+), 41 deletions(-) diff --git a/websocket.py b/websocket.py index fb70f45..baf4d72 100644 --- a/websocket.py +++ b/websocket.py @@ -21,10 +21,13 @@ Copyright (C) 2010 Hiroki Ohtani(liris) import socket + try: import ssl -except: - pass + HAVE_SSL = True +except ImportError: + HAVE_SSL = False + from urlparse import urlparse import os import array @@ -163,7 +166,7 @@ def create_connection(url, timeout=None, **options): Passing optional timeout parameter will set the timeout on the socket. If no timeout is supplied, the global default timeout setting returned by getdefauttimeout() is used. You can customize using 'options'. - If you set "header" dict object, you can set your own custom header. + If you set "header" list object, you can set your own custom header. >>> conn = create_connection("ws://echo.websocket.org/", ... header=["User-Agent: MyProgram", @@ -176,10 +179,10 @@ def create_connection(url, timeout=None, **options): options: current support option is only "header". if you set header as dict value, the custom HTTP headers are added. """ - sockopt = options.get("sockopt", ()) + sockopt = options.get("sockopt", []) sslopt = options.get("sslopt", {}) - websock = WebSocket(sockopt=sockopt) - websock.settimeout(timeout != None and timeout or default_timeout) + websock = WebSocket(sockopt=sockopt, sslopt=sslopt) + websock.settimeout(timeout if timeout is not None else default_timeout) websock.connect(url, **options) return websock @@ -195,34 +198,29 @@ def _create_sec_websocket_key(): uid = uuid.uuid4() return base64.encodestring(uid.bytes).strip() + _HEADERS_TO_CHECK = { "upgrade": "websocket", "connection": "upgrade", } -class _SSLSocketWrapper(object): - def __init__(self, sock, sslopt={}): - self.ssl = ssl.wrap_socket(sock, **sslopt) +if HAVE_SSL: + class _SSLSocketWrapper(object): + def __init__(self, sock, sslopt=None): + if sslopt is None: + sslopt = {} + self.ssl = ssl.wrap_socket(sock, **sslopt) - def recv(self, bufsize): - return self.ssl.read(bufsize) + def recv(self, bufsize): + return self.ssl.read(bufsize) - def send(self, payload): - return self.ssl.write(payload) + def send(self, payload): + return self.ssl.write(payload) - def fileno(self): - return self.ssl.fileno() + def fileno(self): + return self.ssl.fileno() -_BOOL_VALUES = (0, 1) - - -def _is_bool(*values): - for v in values: - if v not in _BOOL_VALUES: - return False - - return True class ABNF(object): @@ -257,8 +255,8 @@ class ABNF(object): LENGTH_16 = 1 << 16 LENGTH_63 = 1 << 63 - def __init__(self, fin = 0, rsv1 = 0, rsv2 = 0, rsv3 = 0, - opcode = OPCODE_TEXT, mask = 1, data = ""): + def __init__(self, fin=0, rsv1=0, rsv2=0, rsv3=0, + opcode=OPCODE_TEXT, mask=1, data=""): """ Constructor for ABNF. please check RFC for arguments. @@ -292,7 +290,7 @@ class ABNF(object): """ format this object to string(byte array) to send data to server. """ - if not _is_bool(self.fin, self.rsv1, self.rsv2, self.rsv3): + if any(x not in (0, 1) for x in [self.fin, self.rsv1, self.rsv2, self.rsv3]): raise ValueError("not 0 or 1") if self.opcode not in ABNF.OPCODES: raise ValueError("Invalid OPCODE") @@ -363,10 +361,14 @@ class WebSocket(object): sslopt: dict object for ssl socket option. """ - def __init__(self, get_mask_key=None, sockopt=(), sslopt={}): + def __init__(self, get_mask_key=None, sockopt=None, sslopt=None): """ Initalize WebSocket object. """ + if sockopt is None: + sockopt = [] + if sslopt is None: + sslopt = {} self.connected = False self.io_sock = self.sock = socket.socket() for opts in sockopt: @@ -389,6 +391,12 @@ class WebSocket(object): """ self.get_mask_key = func + def gettimeout(self): + """ + Get the websocket timeout(second). + """ + return self.sock.gettimeout() + def settimeout(self, timeout): """ Set the timeout to the websocket. @@ -397,11 +405,7 @@ class WebSocket(object): """ self.sock.settimeout(timeout) - def gettimeout(self): - """ - Get the websocket timeout(second). - """ - return self.sock.gettimeout() + timeout = property(gettimeout, settimeout) def connect(self, url, **options): """ @@ -427,7 +431,11 @@ class WebSocket(object): # TODO: we need to support proxy self.sock.connect((hostname, port)) if is_secure: - self.io_sock = _SSLSocketWrapper(self.sock, self.sslopt) + if HAVE_SSL: + self.io_sock = _SSLSocketWrapper(self.sock, self.sslopt) + else: + raise WebSocketException("SSL not available.") + self._handshake(hostname, port, resource, **options) def _handshake(self, host, port, resource, **options): @@ -522,7 +530,7 @@ class WebSocket(object): return status, headers - def send(self, payload, opcode = ABNF.OPCODE_TEXT): + def send(self, payload, opcode=ABNF.OPCODE_TEXT): """ Send the data as string. @@ -542,7 +550,10 @@ class WebSocket(object): if traceEnabled: logger.debug("send: " + repr(data)) - def ping(self, payload = ""): + def send_binary(self, payload): + return self.send(payload, ABNF.OPCODE_BINARY) + + def ping(self, payload=""): """ send ping data. @@ -595,7 +606,7 @@ class WebSocket(object): """ header_bytes = self._recv_strict(2) if not header_bytes: - return None + return b1 = ord(header_bytes[0]) fin = b1 >> 7 & 1 rsv1 = b1 >> 6 & 1 @@ -628,7 +639,7 @@ class WebSocket(object): frame = ABNF(fin, rsv1, rsv2, rsv3, opcode, mask, data) return frame - def send_close(self, status = STATUS_NORMAL, reason = ""): + def send_close(self, status=STATUS_NORMAL, reason=""): """ send close data to the server. @@ -640,7 +651,7 @@ class WebSocket(object): raise ValueError("code is invalid range") self.send(struct.pack('!H', status) + reason, ABNF.OPCODE_CLOSE) - def close(self, status = STATUS_NORMAL, reason = ""): + def close(self, status=STATUS_NORMAL, reason=""): """ Close Websocket object @@ -739,7 +750,7 @@ class WebSocketApp(object): self.sock = None - def send(self, data, opcode = ABNF.OPCODE_TEXT): + def send(self, data, opcode=ABNF.OPCODE_TEXT): """ send message. data: message to send. If you set opcode to OPCODE_TEXT, data must be utf-8 string or unicode. @@ -755,7 +766,7 @@ class WebSocketApp(object): self.keep_running = False self.sock.close() - def run_forever(self, sockopt=(), sslopt={}): + def run_forever(self, sockopt=None, sslopt=None): """ run event loop for WebSocket framework. This loop is infinite loop and is alive during websocket is available. @@ -763,6 +774,10 @@ class WebSocketApp(object): sockopt must be tuple and each element is argument of sock.setscokopt. sslopt: ssl socket optional dict. """ + if sockopt is None: + sockopt = [] + if sslopt is None: + sslopt = {} if self.sock: raise WebSocketException("socket is already opened") try: