tiny modification

This commit is contained in:
liris
2013-02-12 15:59:11 +09:00
parent 46624e2ac6
commit d433ca86f3
2 changed files with 45 additions and 36 deletions

View File

@@ -138,7 +138,6 @@ class WebSocketTest(unittest.TestCase):
del header["connection"] del header["connection"]
self.assertEquals(sock._validate_header(header, key), False) self.assertEquals(sock._validate_header(header, key), False)
header = required_header.copy() header = required_header.copy()
header["sec-websocket-accept"] = "something" header["sec-websocket-accept"] = "something"
self.assertEquals(sock._validate_header(header, key), False) self.assertEquals(sock._validate_header(header, key), False)
@@ -151,7 +150,7 @@ class WebSocketTest(unittest.TestCase):
status, header = sock._read_headers() status, header = sock._read_headers()
self.assertEquals(status, 101) self.assertEquals(status, 101)
self.assertEquals(header["connection"], "upgrade") self.assertEquals(header["connection"], "upgrade")
sock.io_sock = sock.sock = HeaderSockMock("data/header02.txt") sock.io_sock = sock.sock = HeaderSockMock("data/header02.txt")
self.assertRaises(ws.WebSocketException, sock._read_headers) self.assertRaises(ws.WebSocketException, sock._read_headers)
@@ -176,13 +175,13 @@ class WebSocketTest(unittest.TestCase):
s.set_data("\x81\x8fabcd\x82\xe3\xf0\x87\xe3\xf1\x80\xe5\xca\x81\xe2\xc5\x82\xe3\xcc") s.set_data("\x81\x8fabcd\x82\xe3\xf0\x87\xe3\xf1\x80\xe5\xca\x81\xe2\xc5\x82\xe3\xcc")
data = sock.recv() data = sock.recv()
self.assertEquals(data, "こんにちは") self.assertEquals(data, "こんにちは")
s.set_data("\x81\x85abcd)\x07\x0f\x08\x0e") s.set_data("\x81\x85abcd)\x07\x0f\x08\x0e")
data = sock.recv() data = sock.recv()
self.assertEquals(data, "Hello") self.assertEquals(data, "Hello")
def testWebSocket(self): def testWebSocket(self):
s = ws.create_connection("ws://echo.websocket.org/") #ws://localhost:8080/echo") s = ws.create_connection("ws://echo.websocket.org/")
self.assertNotEquals(s, None) self.assertNotEquals(s, None)
s.send("Hello, World") s.send("Hello, World")
result = s.recv() result = s.recv()
@@ -194,14 +193,14 @@ class WebSocketTest(unittest.TestCase):
s.close() s.close()
def testPingPong(self): def testPingPong(self):
s = ws.create_connection("ws://echo.websocket.org/") s = ws.create_connection("ws://echo.websocket.org/")
self.assertNotEquals(s, None) self.assertNotEquals(s, None)
s.ping("Hello") s.ping("Hello")
s.pong("Hi") s.pong("Hi")
s.close() s.close()
def testSecureWebSocket(self): def testSecureWebSocket(self):
s = ws.create_connection("wss://echo.websocket.org/") s = ws.create_connection("wss://echo.websocket.org/")
self.assertNotEquals(s, None) self.assertNotEquals(s, None)
self.assert_(isinstance(s.io_sock, ws._SSLSocketWrapper)) self.assert_(isinstance(s.io_sock, ws._SSLSocketWrapper))
s.send("Hello, World") s.send("Hello, World")
@@ -213,7 +212,7 @@ class WebSocketTest(unittest.TestCase):
s.close() s.close()
def testWebSocketWihtCustomHeader(self): def testWebSocketWihtCustomHeader(self):
s = ws.create_connection("ws://echo.websocket.org/", s = ws.create_connection("ws://echo.websocket.org/",
headers={"User-Agent": "PythonWebsocketClient"}) headers={"User-Agent": "PythonWebsocketClient"})
self.assertNotEquals(s, None) self.assertNotEquals(s, None)
s.send("Hello, World") s.send("Hello, World")
@@ -223,12 +222,12 @@ class WebSocketTest(unittest.TestCase):
def testAfterClose(self): def testAfterClose(self):
from socket import error from socket import error
s = ws.create_connection("ws://echo.websocket.org/") s = ws.create_connection("ws://echo.websocket.org/")
self.assertNotEquals(s, None) self.assertNotEquals(s, None)
s.close() s.close()
self.assertRaises(error, s.send, "Hello") self.assertRaises(error, s.send, "Hello")
self.assertRaises(error, s.recv) self.assertRaises(error, s.recv)
def testUUID4(self): def testUUID4(self):
""" WebSocket key should be a UUID4. """ WebSocket key should be a UUID4.
""" """
@@ -236,25 +235,25 @@ class WebSocketTest(unittest.TestCase):
u = uuid.UUID(bytes=base64.b64decode(key)) u = uuid.UUID(bytes=base64.b64decode(key))
self.assertEquals(4, u.version) self.assertEquals(4, u.version)
class WebSocketAppTest(unittest.TestCase): class WebSocketAppTest(unittest.TestCase):
class NotSetYet(object): class NotSetYet(object):
""" A marker class for signalling that a value hasn't been set yet. """ A marker class for signalling that a value hasn't been set yet.
""" """
def setUp(self): def setUp(self):
ws.enableTrace(TRACABLE) ws.enableTrace(TRACABLE)
WebSocketAppTest.keep_running_open = WebSocketAppTest.NotSetYet() WebSocketAppTest.keep_running_open = WebSocketAppTest.NotSetYet()
WebSocketAppTest.keep_running_close = WebSocketAppTest.NotSetYet() WebSocketAppTest.keep_running_close = WebSocketAppTest.NotSetYet()
WebSocketAppTest.get_mask_key_id = WebSocketAppTest.NotSetYet() WebSocketAppTest.get_mask_key_id = WebSocketAppTest.NotSetYet()
def tearDown(self): def tearDown(self):
WebSocketAppTest.keep_running_open = WebSocketAppTest.NotSetYet() WebSocketAppTest.keep_running_open = WebSocketAppTest.NotSetYet()
WebSocketAppTest.keep_running_close = WebSocketAppTest.NotSetYet() WebSocketAppTest.keep_running_close = WebSocketAppTest.NotSetYet()
WebSocketAppTest.get_mask_key_id = WebSocketAppTest.NotSetYet() WebSocketAppTest.get_mask_key_id = WebSocketAppTest.NotSetYet()
def testKeepRunning(self): def testKeepRunning(self):
""" A WebSocketApp should keep running as long as its self.keep_running """ A WebSocketApp should keep running as long as its self.keep_running
is not False (in the boolean context). is not False (in the boolean context).
@@ -266,46 +265,44 @@ class WebSocketAppTest(unittest.TestCase):
""" """
WebSocketAppTest.keep_running_open = self.keep_running WebSocketAppTest.keep_running_open = self.keep_running
self.close() self.close()
def on_close(self, *args, **kwargs): def on_close(self, *args, **kwargs):
""" Set the keep_running flag for the test to use. """ Set the keep_running flag for the test to use.
""" """
WebSocketAppTest.keep_running_close = self.keep_running WebSocketAppTest.keep_running_close = self.keep_running
app = ws.WebSocketApp('ws://echo.websocket.org/', on_open=on_open, on_close=on_close) app = ws.WebSocketApp('ws://echo.websocket.org/', on_open=on_open, on_close=on_close)
app.run_forever() app.run_forever()
self.assertFalse(isinstance(WebSocketAppTest.keep_running_open, self.assertFalse(isinstance(WebSocketAppTest.keep_running_open,
WebSocketAppTest.NotSetYet)) WebSocketAppTest.NotSetYet))
self.assertFalse(isinstance(WebSocketAppTest.keep_running_close, self.assertFalse(isinstance(WebSocketAppTest.keep_running_close,
WebSocketAppTest.NotSetYet)) WebSocketAppTest.NotSetYet))
self.assertEquals(True, WebSocketAppTest.keep_running_open) self.assertEquals(True, WebSocketAppTest.keep_running_open)
self.assertEquals(False, WebSocketAppTest.keep_running_close) self.assertEquals(False, WebSocketAppTest.keep_running_close)
def testSockMaskKey(self): def testSockMaskKey(self):
""" A WebSocketApp should forward the received mask_key function down """ A WebSocketApp should forward the received mask_key function down
to the actual socket. to the actual socket.
""" """
def my_mask_key_func(): def my_mask_key_func():
pass pass
def on_open(self, *args, **kwargs): def on_open(self, *args, **kwargs):
""" Set the value so the test can use it later on and immediately """ Set the value so the test can use it later on and immediately
close the connection. close the connection.
""" """
WebSocketAppTest.get_mask_key_id = id(self.get_mask_key) WebSocketAppTest.get_mask_key_id = id(self.get_mask_key)
self.close() self.close()
app = ws.WebSocketApp('ws://echo.websocket.org/', on_open=on_open, get_mask_key=my_mask_key_func) app = ws.WebSocketApp('ws://echo.websocket.org/', on_open=on_open, get_mask_key=my_mask_key_func)
app.run_forever() app.run_forever()
# Note: We can't use 'is' for comparing the functions directly, need to use 'id'. # Note: We can't use 'is' for comparing the functions directly, need to use 'id'.
self.assertEquals(WebSocketAppTest.get_mask_key_id, id(my_mask_key_func)) self.assertEquals(WebSocketAppTest.get_mask_key_id, id(my_mask_key_func))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@@ -58,12 +58,14 @@ STATUS_TLS_HANDSHAKE_ERROR = 1015
logger = logging.getLogger() logger = logging.getLogger()
class WebSocketException(Exception): class WebSocketException(Exception):
""" """
websocket exeception class. websocket exeception class.
""" """
pass pass
class WebSocketConnectionClosedException(WebSocketException): class WebSocketConnectionClosedException(WebSocketException):
""" """
If remote host closed the connection or some network error happened, If remote host closed the connection or some network error happened,
@@ -74,6 +76,7 @@ class WebSocketConnectionClosedException(WebSocketException):
default_timeout = None default_timeout = None
traceEnabled = False traceEnabled = False
def enableTrace(tracable): def enableTrace(tracable):
""" """
turn on/off the tracability. turn on/off the tracability.
@@ -87,6 +90,7 @@ def enableTrace(tracable):
logger.addHandler(logging.StreamHandler()) logger.addHandler(logging.StreamHandler())
logger.setLevel(logging.DEBUG) logger.setLevel(logging.DEBUG)
def setdefaulttimeout(timeout): def setdefaulttimeout(timeout):
""" """
Set the global timeout setting to connect. Set the global timeout setting to connect.
@@ -96,12 +100,14 @@ def setdefaulttimeout(timeout):
global default_timeout global default_timeout
default_timeout = timeout default_timeout = timeout
def getdefaulttimeout(): def getdefaulttimeout():
""" """
Return the global timeout setting(second) to connect. Return the global timeout setting(second) to connect.
""" """
return default_timeout return default_timeout
def _parse_url(url): def _parse_url(url):
""" """
parse url and the result is tuple of parse url and the result is tuple of
@@ -130,7 +136,7 @@ def _parse_url(url):
elif scheme == "wss": elif scheme == "wss":
is_secure = True is_secure = True
if not port: if not port:
port = 443 port = 443
else: else:
raise ValueError("scheme %s is invalid" % scheme) raise ValueError("scheme %s is invalid" % scheme)
@@ -144,6 +150,7 @@ def _parse_url(url):
return (hostname, port, resource, is_secure) return (hostname, port, resource, is_secure)
def create_connection(url, timeout=None, **options): def create_connection(url, timeout=None, **options):
""" """
connect to url and return websocket object. connect to url and return websocket object.
@@ -177,6 +184,7 @@ _MAX_CHAR_BYTE = (1<<8) -1
# ref. Websocket gets an update, and it breaks stuff. # ref. Websocket gets an update, and it breaks stuff.
# http://axod.blogspot.com/2010/06/websocket-gets-update-and-it-breaks.html # http://axod.blogspot.com/2010/06/websocket-gets-update-and-it-breaks.html
def _create_sec_websocket_key(): def _create_sec_websocket_key():
uid = uuid.uuid4() uid = uuid.uuid4()
return base64.encodestring(uid.bytes).strip() return base64.encodestring(uid.bytes).strip()
@@ -186,6 +194,7 @@ _HEADERS_TO_CHECK = {
"connection": "upgrade", "connection": "upgrade",
} }
class _SSLSocketWrapper(object): class _SSLSocketWrapper(object):
def __init__(self, sock): def __init__(self, sock):
self.ssl = socket.ssl(sock) self.ssl = socket.ssl(sock)
@@ -197,6 +206,8 @@ class _SSLSocketWrapper(object):
return self.ssl.write(payload) return self.ssl.write(payload)
_BOOL_VALUES = (0, 1) _BOOL_VALUES = (0, 1)
def _is_bool(*values): def _is_bool(*values):
for v in values: for v in values:
if v not in _BOOL_VALUES: if v not in _BOOL_VALUES:
@@ -204,6 +215,7 @@ def _is_bool(*values):
return True return True
class ABNF(object): class ABNF(object):
""" """
ABNF frame class. ABNF frame class.
@@ -316,6 +328,7 @@ class ABNF(object):
_d[i] ^= _m[i % 4] _d[i] ^= _m[i % 4]
return _d.tostring() return _d.tostring()
class WebSocket(object): class WebSocket(object):
""" """
Low level WebSocket interface. Low level WebSocket interface.
@@ -337,6 +350,7 @@ class WebSocket(object):
get_mask_key: a callable to produce new mask keys, see the set_mask_key get_mask_key: a callable to produce new mask keys, see the set_mask_key
function's docstring for more details function's docstring for more details
""" """
def __init__(self, get_mask_key = None): def __init__(self, get_mask_key = None):
""" """
Initalize WebSocket object. Initalize WebSocket object.
@@ -423,8 +437,8 @@ class WebSocket(object):
header_str = "\r\n".join(headers) header_str = "\r\n".join(headers)
sock.send(header_str) sock.send(header_str)
if traceEnabled: if traceEnabled:
logger.debug( "--- request header ---") logger.debug("--- request header ---")
logger.debug( header_str) logger.debug(header_str)
logger.debug("-----------------------") logger.debug("-----------------------")
status, resp_headers = self._read_headers() status, resp_headers = self._read_headers()
@@ -549,7 +563,6 @@ class WebSocket(object):
elif frame.opcode == ABNF.OPCODE_PING: elif frame.opcode == ABNF.OPCODE_PING:
self.pong(frame.data) self.pong(frame.data)
def recv_frame(self): def recv_frame(self):
""" """
recieve data as frame from server. recieve data as frame from server.
@@ -603,8 +616,6 @@ class WebSocket(object):
raise ValueError("code is invalid range") raise ValueError("code is invalid range")
self.send(struct.pack('!H', status) + reason, ABNF.OPCODE_CLOSE) self.send(struct.pack('!H', status) + reason, ABNF.OPCODE_CLOSE)
def close(self, status = STATUS_NORMAL, reason = ""): def close(self, status = STATUS_NORMAL, reason = ""):
""" """
Close Websocket object Close Websocket object
@@ -662,6 +673,7 @@ class WebSocket(object):
break break
return "".join(line) return "".join(line)
class WebSocketApp(object): class WebSocketApp(object):
""" """
Higher level of APIs are provided. Higher level of APIs are provided.
@@ -700,7 +712,7 @@ class WebSocketApp(object):
def send(self, data, opcode = ABNF.OPCODE_TEXT): def send(self, data, opcode = ABNF.OPCODE_TEXT):
""" """
send message. send message.
data: message to send. If you set opcode to OPCODE_TEXT, data must be utf-8 string or unicode. data: message to send. If you set opcode to OPCODE_TEXT, data must be utf-8 string or unicode.
opcode: operation code of data. default is OPCODE_TEXT. opcode: operation code of data. default is OPCODE_TEXT.
""" """
@@ -753,6 +765,6 @@ if __name__ == "__main__":
ws.send("Hello, World") ws.send("Hello, World")
print "Sent" print "Sent"
print "Receiving..." print "Receiving..."
result = ws.recv() result = ws.recv()
print "Received '%s'" % result print "Received '%s'" % result
ws.close() ws.close()