tiny modification

This commit is contained in:
liris
2013-02-12 15:59:11 +09:00
parent 46624e2ac6
commit d433ca86f3
2 changed files with 45 additions and 36 deletions

View File

@@ -138,7 +138,6 @@ class WebSocketTest(unittest.TestCase):
del header["connection"] del header["connection"]
self.assertEquals(sock._validate_header(header, key), False) self.assertEquals(sock._validate_header(header, key), False)
header = required_header.copy() header = required_header.copy()
header["sec-websocket-accept"] = "something" header["sec-websocket-accept"] = "something"
self.assertEquals(sock._validate_header(header, key), False) self.assertEquals(sock._validate_header(header, key), False)
@@ -182,7 +181,7 @@ class WebSocketTest(unittest.TestCase):
self.assertEquals(data, "Hello") self.assertEquals(data, "Hello")
def testWebSocket(self): def testWebSocket(self):
s = ws.create_connection("ws://echo.websocket.org/") #ws://localhost:8080/echo") s = ws.create_connection("ws://echo.websocket.org/")
self.assertNotEquals(s, None) self.assertNotEquals(s, None)
s.send("Hello, World") s.send("Hello, World")
result = s.recv() result = s.recv()
@@ -236,6 +235,7 @@ class WebSocketTest(unittest.TestCase):
u = uuid.UUID(bytes=base64.b64decode(key)) u = uuid.UUID(bytes=base64.b64decode(key))
self.assertEquals(4, u.version) self.assertEquals(4, u.version)
class WebSocketAppTest(unittest.TestCase): class WebSocketAppTest(unittest.TestCase):
class NotSetYet(object): class NotSetYet(object):
@@ -250,7 +250,6 @@ class WebSocketAppTest(unittest.TestCase):
WebSocketAppTest.get_mask_key_id = WebSocketAppTest.NotSetYet() WebSocketAppTest.get_mask_key_id = WebSocketAppTest.NotSetYet()
def tearDown(self): def tearDown(self):
WebSocketAppTest.keep_running_open = WebSocketAppTest.NotSetYet() WebSocketAppTest.keep_running_open = WebSocketAppTest.NotSetYet()
WebSocketAppTest.keep_running_close = WebSocketAppTest.NotSetYet() WebSocketAppTest.keep_running_close = WebSocketAppTest.NotSetYet()
WebSocketAppTest.get_mask_key_id = WebSocketAppTest.NotSetYet() WebSocketAppTest.get_mask_key_id = WebSocketAppTest.NotSetYet()
@@ -305,7 +304,5 @@ class WebSocketAppTest(unittest.TestCase):
# Note: We can't use 'is' for comparing the functions directly, need to use 'id'. # Note: We can't use 'is' for comparing the functions directly, need to use 'id'.
self.assertEquals(WebSocketAppTest.get_mask_key_id, id(my_mask_key_func)) self.assertEquals(WebSocketAppTest.get_mask_key_id, id(my_mask_key_func))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@@ -58,12 +58,14 @@ STATUS_TLS_HANDSHAKE_ERROR = 1015
logger = logging.getLogger() logger = logging.getLogger()
class WebSocketException(Exception): class WebSocketException(Exception):
""" """
websocket exeception class. websocket exeception class.
""" """
pass pass
class WebSocketConnectionClosedException(WebSocketException): class WebSocketConnectionClosedException(WebSocketException):
""" """
If remote host closed the connection or some network error happened, If remote host closed the connection or some network error happened,
@@ -74,6 +76,7 @@ class WebSocketConnectionClosedException(WebSocketException):
default_timeout = None default_timeout = None
traceEnabled = False traceEnabled = False
def enableTrace(tracable): def enableTrace(tracable):
""" """
turn on/off the tracability. turn on/off the tracability.
@@ -87,6 +90,7 @@ def enableTrace(tracable):
logger.addHandler(logging.StreamHandler()) logger.addHandler(logging.StreamHandler())
logger.setLevel(logging.DEBUG) logger.setLevel(logging.DEBUG)
def setdefaulttimeout(timeout): def setdefaulttimeout(timeout):
""" """
Set the global timeout setting to connect. Set the global timeout setting to connect.
@@ -96,12 +100,14 @@ def setdefaulttimeout(timeout):
global default_timeout global default_timeout
default_timeout = timeout default_timeout = timeout
def getdefaulttimeout(): def getdefaulttimeout():
""" """
Return the global timeout setting(second) to connect. Return the global timeout setting(second) to connect.
""" """
return default_timeout return default_timeout
def _parse_url(url): def _parse_url(url):
""" """
parse url and the result is tuple of parse url and the result is tuple of
@@ -144,6 +150,7 @@ def _parse_url(url):
return (hostname, port, resource, is_secure) return (hostname, port, resource, is_secure)
def create_connection(url, timeout=None, **options): def create_connection(url, timeout=None, **options):
""" """
connect to url and return websocket object. connect to url and return websocket object.
@@ -177,6 +184,7 @@ _MAX_CHAR_BYTE = (1<<8) -1
# ref. Websocket gets an update, and it breaks stuff. # ref. Websocket gets an update, and it breaks stuff.
# http://axod.blogspot.com/2010/06/websocket-gets-update-and-it-breaks.html # http://axod.blogspot.com/2010/06/websocket-gets-update-and-it-breaks.html
def _create_sec_websocket_key(): def _create_sec_websocket_key():
uid = uuid.uuid4() uid = uuid.uuid4()
return base64.encodestring(uid.bytes).strip() return base64.encodestring(uid.bytes).strip()
@@ -186,6 +194,7 @@ _HEADERS_TO_CHECK = {
"connection": "upgrade", "connection": "upgrade",
} }
class _SSLSocketWrapper(object): class _SSLSocketWrapper(object):
def __init__(self, sock): def __init__(self, sock):
self.ssl = socket.ssl(sock) self.ssl = socket.ssl(sock)
@@ -197,6 +206,8 @@ class _SSLSocketWrapper(object):
return self.ssl.write(payload) return self.ssl.write(payload)
_BOOL_VALUES = (0, 1) _BOOL_VALUES = (0, 1)
def _is_bool(*values): def _is_bool(*values):
for v in values: for v in values:
if v not in _BOOL_VALUES: if v not in _BOOL_VALUES:
@@ -204,6 +215,7 @@ def _is_bool(*values):
return True return True
class ABNF(object): class ABNF(object):
""" """
ABNF frame class. ABNF frame class.
@@ -316,6 +328,7 @@ class ABNF(object):
_d[i] ^= _m[i % 4] _d[i] ^= _m[i % 4]
return _d.tostring() return _d.tostring()
class WebSocket(object): class WebSocket(object):
""" """
Low level WebSocket interface. Low level WebSocket interface.
@@ -337,6 +350,7 @@ class WebSocket(object):
get_mask_key: a callable to produce new mask keys, see the set_mask_key get_mask_key: a callable to produce new mask keys, see the set_mask_key
function's docstring for more details function's docstring for more details
""" """
def __init__(self, get_mask_key = None): def __init__(self, get_mask_key = None):
""" """
Initalize WebSocket object. Initalize WebSocket object.
@@ -423,8 +437,8 @@ class WebSocket(object):
header_str = "\r\n".join(headers) header_str = "\r\n".join(headers)
sock.send(header_str) sock.send(header_str)
if traceEnabled: if traceEnabled:
logger.debug( "--- request header ---") logger.debug("--- request header ---")
logger.debug( header_str) logger.debug(header_str)
logger.debug("-----------------------") logger.debug("-----------------------")
status, resp_headers = self._read_headers() status, resp_headers = self._read_headers()
@@ -549,7 +563,6 @@ class WebSocket(object):
elif frame.opcode == ABNF.OPCODE_PING: elif frame.opcode == ABNF.OPCODE_PING:
self.pong(frame.data) self.pong(frame.data)
def recv_frame(self): def recv_frame(self):
""" """
recieve data as frame from server. recieve data as frame from server.
@@ -603,8 +616,6 @@ class WebSocket(object):
raise ValueError("code is invalid range") raise ValueError("code is invalid range")
self.send(struct.pack('!H', status) + reason, ABNF.OPCODE_CLOSE) self.send(struct.pack('!H', status) + reason, ABNF.OPCODE_CLOSE)
def close(self, status = STATUS_NORMAL, reason = ""): def close(self, status = STATUS_NORMAL, reason = ""):
""" """
Close Websocket object Close Websocket object
@@ -662,6 +673,7 @@ class WebSocket(object):
break break
return "".join(line) return "".join(line)
class WebSocketApp(object): class WebSocketApp(object):
""" """
Higher level of APIs are provided. Higher level of APIs are provided.