tiny modification

This commit is contained in:
liris
2013-02-12 15:59:11 +09:00
parent 46624e2ac6
commit d433ca86f3
2 changed files with 45 additions and 36 deletions

View File

@@ -138,7 +138,6 @@ class WebSocketTest(unittest.TestCase):
del header["connection"]
self.assertEquals(sock._validate_header(header, key), False)
header = required_header.copy()
header["sec-websocket-accept"] = "something"
self.assertEquals(sock._validate_header(header, key), False)
@@ -151,7 +150,7 @@ class WebSocketTest(unittest.TestCase):
status, header = sock._read_headers()
self.assertEquals(status, 101)
self.assertEquals(header["connection"], "upgrade")
sock.io_sock = sock.sock = HeaderSockMock("data/header02.txt")
self.assertRaises(ws.WebSocketException, sock._read_headers)
@@ -176,13 +175,13 @@ class WebSocketTest(unittest.TestCase):
s.set_data("\x81\x8fabcd\x82\xe3\xf0\x87\xe3\xf1\x80\xe5\xca\x81\xe2\xc5\x82\xe3\xcc")
data = sock.recv()
self.assertEquals(data, "こんにちは")
s.set_data("\x81\x85abcd)\x07\x0f\x08\x0e")
data = sock.recv()
self.assertEquals(data, "Hello")
def testWebSocket(self):
s = ws.create_connection("ws://echo.websocket.org/") #ws://localhost:8080/echo")
s = ws.create_connection("ws://echo.websocket.org/")
self.assertNotEquals(s, None)
s.send("Hello, World")
result = s.recv()
@@ -194,14 +193,14 @@ class WebSocketTest(unittest.TestCase):
s.close()
def testPingPong(self):
s = ws.create_connection("ws://echo.websocket.org/")
s = ws.create_connection("ws://echo.websocket.org/")
self.assertNotEquals(s, None)
s.ping("Hello")
s.pong("Hi")
s.close()
def testSecureWebSocket(self):
s = ws.create_connection("wss://echo.websocket.org/")
s = ws.create_connection("wss://echo.websocket.org/")
self.assertNotEquals(s, None)
self.assert_(isinstance(s.io_sock, ws._SSLSocketWrapper))
s.send("Hello, World")
@@ -213,7 +212,7 @@ class WebSocketTest(unittest.TestCase):
s.close()
def testWebSocketWihtCustomHeader(self):
s = ws.create_connection("ws://echo.websocket.org/",
s = ws.create_connection("ws://echo.websocket.org/",
headers={"User-Agent": "PythonWebsocketClient"})
self.assertNotEquals(s, None)
s.send("Hello, World")
@@ -223,12 +222,12 @@ class WebSocketTest(unittest.TestCase):
def testAfterClose(self):
from socket import error
s = ws.create_connection("ws://echo.websocket.org/")
s = ws.create_connection("ws://echo.websocket.org/")
self.assertNotEquals(s, None)
s.close()
self.assertRaises(error, s.send, "Hello")
self.assertRaises(error, s.recv)
def testUUID4(self):
""" WebSocket key should be a UUID4.
"""
@@ -236,25 +235,25 @@ class WebSocketTest(unittest.TestCase):
u = uuid.UUID(bytes=base64.b64decode(key))
self.assertEquals(4, u.version)
class WebSocketAppTest(unittest.TestCase):
class NotSetYet(object):
""" A marker class for signalling that a value hasn't been set yet.
"""
def setUp(self):
ws.enableTrace(TRACABLE)
WebSocketAppTest.keep_running_open = WebSocketAppTest.NotSetYet()
WebSocketAppTest.keep_running_close = WebSocketAppTest.NotSetYet()
WebSocketAppTest.get_mask_key_id = WebSocketAppTest.NotSetYet()
def tearDown(self):
WebSocketAppTest.keep_running_open = WebSocketAppTest.NotSetYet()
WebSocketAppTest.keep_running_close = WebSocketAppTest.NotSetYet()
WebSocketAppTest.get_mask_key_id = WebSocketAppTest.NotSetYet()
def testKeepRunning(self):
""" A WebSocketApp should keep running as long as its self.keep_running
is not False (in the boolean context).
@@ -266,46 +265,44 @@ class WebSocketAppTest(unittest.TestCase):
"""
WebSocketAppTest.keep_running_open = self.keep_running
self.close()
def on_close(self, *args, **kwargs):
""" Set the keep_running flag for the test to use.
"""
WebSocketAppTest.keep_running_close = self.keep_running
app = ws.WebSocketApp('ws://echo.websocket.org/', on_open=on_open, on_close=on_close)
app.run_forever()
self.assertFalse(isinstance(WebSocketAppTest.keep_running_open,
self.assertFalse(isinstance(WebSocketAppTest.keep_running_open,
WebSocketAppTest.NotSetYet))
self.assertFalse(isinstance(WebSocketAppTest.keep_running_close,
self.assertFalse(isinstance(WebSocketAppTest.keep_running_close,
WebSocketAppTest.NotSetYet))
self.assertEquals(True, WebSocketAppTest.keep_running_open)
self.assertEquals(False, WebSocketAppTest.keep_running_close)
def testSockMaskKey(self):
""" A WebSocketApp should forward the received mask_key function down
to the actual socket.
"""
def my_mask_key_func():
pass
def on_open(self, *args, **kwargs):
""" Set the value so the test can use it later on and immediately
""" Set the value so the test can use it later on and immediately
close the connection.
"""
WebSocketAppTest.get_mask_key_id = id(self.get_mask_key)
self.close()
app = ws.WebSocketApp('ws://echo.websocket.org/', on_open=on_open, get_mask_key=my_mask_key_func)
app.run_forever()
# Note: We can't use 'is' for comparing the functions directly, need to use 'id'.
self.assertEquals(WebSocketAppTest.get_mask_key_id, id(my_mask_key_func))
if __name__ == "__main__":
unittest.main()

View File

@@ -58,12 +58,14 @@ STATUS_TLS_HANDSHAKE_ERROR = 1015
logger = logging.getLogger()
class WebSocketException(Exception):
"""
websocket exeception class.
"""
pass
class WebSocketConnectionClosedException(WebSocketException):
"""
If remote host closed the connection or some network error happened,
@@ -74,6 +76,7 @@ class WebSocketConnectionClosedException(WebSocketException):
default_timeout = None
traceEnabled = False
def enableTrace(tracable):
"""
turn on/off the tracability.
@@ -87,6 +90,7 @@ def enableTrace(tracable):
logger.addHandler(logging.StreamHandler())
logger.setLevel(logging.DEBUG)
def setdefaulttimeout(timeout):
"""
Set the global timeout setting to connect.
@@ -96,12 +100,14 @@ def setdefaulttimeout(timeout):
global default_timeout
default_timeout = timeout
def getdefaulttimeout():
"""
Return the global timeout setting(second) to connect.
"""
return default_timeout
def _parse_url(url):
"""
parse url and the result is tuple of
@@ -130,7 +136,7 @@ def _parse_url(url):
elif scheme == "wss":
is_secure = True
if not port:
port = 443
port = 443
else:
raise ValueError("scheme %s is invalid" % scheme)
@@ -144,6 +150,7 @@ def _parse_url(url):
return (hostname, port, resource, is_secure)
def create_connection(url, timeout=None, **options):
"""
connect to url and return websocket object.
@@ -177,6 +184,7 @@ _MAX_CHAR_BYTE = (1<<8) -1
# ref. Websocket gets an update, and it breaks stuff.
# http://axod.blogspot.com/2010/06/websocket-gets-update-and-it-breaks.html
def _create_sec_websocket_key():
uid = uuid.uuid4()
return base64.encodestring(uid.bytes).strip()
@@ -186,6 +194,7 @@ _HEADERS_TO_CHECK = {
"connection": "upgrade",
}
class _SSLSocketWrapper(object):
def __init__(self, sock):
self.ssl = socket.ssl(sock)
@@ -197,6 +206,8 @@ class _SSLSocketWrapper(object):
return self.ssl.write(payload)
_BOOL_VALUES = (0, 1)
def _is_bool(*values):
for v in values:
if v not in _BOOL_VALUES:
@@ -204,6 +215,7 @@ def _is_bool(*values):
return True
class ABNF(object):
"""
ABNF frame class.
@@ -316,6 +328,7 @@ class ABNF(object):
_d[i] ^= _m[i % 4]
return _d.tostring()
class WebSocket(object):
"""
Low level WebSocket interface.
@@ -337,6 +350,7 @@ class WebSocket(object):
get_mask_key: a callable to produce new mask keys, see the set_mask_key
function's docstring for more details
"""
def __init__(self, get_mask_key = None):
"""
Initalize WebSocket object.
@@ -423,8 +437,8 @@ class WebSocket(object):
header_str = "\r\n".join(headers)
sock.send(header_str)
if traceEnabled:
logger.debug( "--- request header ---")
logger.debug( header_str)
logger.debug("--- request header ---")
logger.debug(header_str)
logger.debug("-----------------------")
status, resp_headers = self._read_headers()
@@ -549,7 +563,6 @@ class WebSocket(object):
elif frame.opcode == ABNF.OPCODE_PING:
self.pong(frame.data)
def recv_frame(self):
"""
recieve data as frame from server.
@@ -603,8 +616,6 @@ class WebSocket(object):
raise ValueError("code is invalid range")
self.send(struct.pack('!H', status) + reason, ABNF.OPCODE_CLOSE)
def close(self, status = STATUS_NORMAL, reason = ""):
"""
Close Websocket object
@@ -662,6 +673,7 @@ class WebSocket(object):
break
return "".join(line)
class WebSocketApp(object):
"""
Higher level of APIs are provided.
@@ -700,7 +712,7 @@ class WebSocketApp(object):
def send(self, data, opcode = ABNF.OPCODE_TEXT):
"""
send message.
send message.
data: message to send. If you set opcode to OPCODE_TEXT, data must be utf-8 string or unicode.
opcode: operation code of data. default is OPCODE_TEXT.
"""
@@ -753,6 +765,6 @@ if __name__ == "__main__":
ws.send("Hello, World")
print "Sent"
print "Receiving..."
result = ws.recv()
result = ws.recv()
print "Received '%s'" % result
ws.close()