- workable with echo.websocket.org

This commit is contained in:
liris
2012-01-11 14:32:35 +09:00
parent bae967e517
commit 0e4f658534
2 changed files with 72 additions and 25 deletions

View File

@@ -6,6 +6,9 @@ import websocket as ws
TRACABLE=False TRACABLE=False
def create_mask_key(n):
return "abcd"
class StringSockMock: class StringSockMock:
def __init__(self): def __init__(self):
self.set_data("") self.set_data("")
@@ -137,43 +140,51 @@ class WebSocketTest(unittest.TestCase):
self.assertRaises(ws.WebSocketException, sock._read_headers) self.assertRaises(ws.WebSocketException, sock._read_headers)
def testSend(self): def testSend(self):
# TODO: add longer frame data
sock = ws.WebSocket() sock = ws.WebSocket()
sock.set_mask_key(create_mask_key)
s = sock.io_sock = sock.sock = HeaderSockMock("data/header01.txt") s = sock.io_sock = sock.sock = HeaderSockMock("data/header01.txt")
sock.send("Hello") sock.send("Hello")
#self.assertEquals(s.sent[0], "\x00Hello\xff") self.assertEquals(s.sent[0], "\x81\x85abcd)\x07\x0f\x08\x0e")
sock.send("こんにちは") sock.send("こんにちは")
#self.assertEquals(s.sent[1], "\x00こんにちは\xff") self.assertEquals(s.sent[1], "\x81\x8fabcd\x82\xe3\xf0\x87\xe3\xf1\x80\xe5\xca\x81\xe2\xc5\x82\xe3\xcc")
sock.send(u"こんにちは") sock.send(u"こんにちは")
#self.assertEquals(s.sent[1], "\x00こんにちは\xff") self.assertEquals(s.sent[1], "\x81\x8fabcd\x82\xe3\xf0\x87\xe3\xf1\x80\xe5\xca\x81\xe2\xc5\x82\xe3\xcc")
def testRecv(self): def testRecv(self):
# TODO: add longer frame data
sock = ws.WebSocket() sock = ws.WebSocket()
s = sock.io_sock = sock.sock = StringSockMock() s = sock.io_sock = sock.sock = StringSockMock()
s.set_data("\x00こんにちは\xff") s.set_data("\x81\x8fabcd\x82\xe3\xf0\x87\xe3\xf1\x80\xe5\xca\x81\xe2\xc5\x82\xe3\xcc")
data = sock.recv() data = sock.recv()
self.assertEquals(data, "こんにちは") self.assertEquals(data, "こんにちは")
s.set_data("\x81\x05Hello") s.set_data("\x81\x85abcd)\x07\x0f\x08\x0e")
data = sock.recv() data = sock.recv()
self.assertEquals(data, "Hello") self.assertEquals(data, "Hello")
s.set_data("\x81\x81\x7f" + ("a"*255))
data = sock.recv()
self.assertEquals(len(data), 255)
self.assertEquals(data, "a" * 255)
def testWebSocket(self): def testWebSocket(self):
s = ws.create_connection("ws://echo.websocket.org/") #ws://localhost:8080/echo") s = ws.create_connection("ws://echo.websocket.org/") #ws://localhost:8080/echo")
self.assertNotEquals(s, None) self.assertNotEquals(s, None)
s.send("Hello, World") s.send("Hello, World")
result = s.recv() result = s.recv()
self.assertEquals(result, "Hello, World") self.assertEquals(result, "Hello, World")
s.send("こにゃにゃちは、世界") s.send("こにゃにゃちは、世界")
result = s.recv() result = s.recv()
self.assertEquals(result, "こにゃにゃちは、世界") self.assertEquals(result, "こにゃにゃちは、世界")
s.close() s.close()
def testSecureWebsocket(self): def testPingPong(self):
s = ws.create_connection("ws://echo.websocket.org/")
self.assertNotEquals(s, None)
s.ping("Hello")
s.pong("Hi")
s.close()
def testSecureWebSocket(self):
s = ws.create_connection("wss://echo.websocket.org/") s = ws.create_connection("wss://echo.websocket.org/")
self.assertNotEquals(s, None) self.assertNotEquals(s, None)
self.assert_(isinstance(s.io_sock, ws._SSLSocketWrapper)) self.assert_(isinstance(s.io_sock, ws._SSLSocketWrapper))

View File

@@ -22,6 +22,7 @@ Copyright (C) 2010 Hiroki Ohtani(liris)
import socket import socket
from urlparse import urlparse from urlparse import urlparse
import os
import struct import struct
import uuid import uuid
import sha import sha
@@ -175,7 +176,7 @@ class ABNF(object):
LENGTH_63 = 1 << 63 LENGTH_63 = 1 << 63
def __init__(self, fin = 0, rsv1 = 0, rsv2 = 0, rsv3 = 0, def __init__(self, fin = 0, rsv1 = 0, rsv2 = 0, rsv3 = 0,
opcode = OPCODE_TEXT, mask = 0, data = ""): opcode = OPCODE_TEXT, mask = 1, data = ""):
self.fin = fin self.fin = fin
self.rsv1 = rsv1 self.rsv1 = rsv1
self.rsv2 = rsv2 self.rsv2 = rsv2
@@ -183,12 +184,13 @@ class ABNF(object):
self.opcode = opcode self.opcode = opcode
self.mask = mask self.mask = mask
self.data = data self.data = data
self.get_mask_key = os.urandom
@staticmethod @staticmethod
def create_frame(data, opcode): def create_frame(data, opcode):
if opcode == ABNF.OPCODE_TEXT and isinstance(data, unicode): if opcode == ABNF.OPCODE_TEXT and isinstance(data, unicode):
data = data.encode("utf-8") data = data.encode("utf-8")
return ABNF(1, 0, 0, 0, opcode, 0, data) return ABNF(1, 0, 0, 0, opcode, 1, data)
def format(self): def format(self):
if not is_bool(self.fin, self.rsv1, self.rsv2, self.rsv3): if not is_bool(self.fin, self.rsv1, self.rsv2, self.rsv3):
@@ -213,8 +215,24 @@ class ABNF(object):
if not self.mask: if not self.mask:
return frame_header + self.data return frame_header + self.data
else:
raise NotImplementedError("masked format is not implemented") mask_key = self.get_mask_key(4)
return frame_header + self._get_masked(mask_key)
def _get_masked(self, mask_key):
s = ABNF.mask(mask_key, self.data)
return mask_key + "".join(s)
@staticmethod
def mask(mask_key, data):
_m = map(ord, mask_key)
_d = map(ord, data)
for i in range(len(_d)):
_d[i] ^= _m[i % 4]
s = map(chr, _d)
return "".join(s)
@@ -242,6 +260,10 @@ class WebSocket(object):
""" """
self.connected = False self.connected = False
self.io_sock = self.sock = socket.socket() self.io_sock = self.sock = socket.socket()
self.get_mask_key = None
def set_mask_key(self, func):
self.get_mask_key = func
def settimeout(self, timeout): def settimeout(self, timeout):
""" """
@@ -361,25 +383,38 @@ class WebSocket(object):
return status, headers return status, headers
def send(self, payload, binary = False): def send(self, payload, opcode = ABNF.OPCODE_TEXT, binary = False):
""" """
Send the data as string. payload must be utf-8 string or unicoce. Send the data as string. payload must be utf-8 string or unicoce.
""" """
frame = ABNF.create_frame(payload, ABNF.OPCODE_TEXT) frame = ABNF.create_frame(payload, opcode)
if self.get_mask_key:
frame.get_mask_key = self.get_mask_key
data = frame.format() data = frame.format()
print repr(data)
self.io_sock.send(data) self.io_sock.send(data)
if traceEnabled: if traceEnabled:
logger.debug("send: " + repr(data)) logger.debug("send: " + repr(data))
def ping(self, payload):
self.send(payload, ABNF.OPCODE_PING)
def pong(self, payload):
self.send(payload, ABNF.OPCODE_PONG)
def recv(self): def recv(self):
""" """
Receive utf-8 string data from the server. Receive utf-8 string data from the server.
""" """
frame = self.read_frame() while True:
return frame.data frame = self.recv_frame()
if frame.opcode in (ABNF.OPCODE_TEXT, ABNF.OPCODE_BINARY):
return frame.data
elif frame.opcode == ABNF.OPCODE_CLOSE:
return None
elif frame.opcode == ABNF.OPCODE_PING:
self.pong("Hi!")
def read_frame(self): def recv_frame(self):
header_bytes = self._recv(2) header_bytes = self._recv(2)
b1 = ord(header_bytes[0]) b1 = ord(header_bytes[0])
fin = b1 >> 7 & 1 fin = b1 >> 7 & 1
@@ -398,11 +433,12 @@ class WebSocket(object):
elif length == 0x7f: elif length == 0x7f:
l = self._recv(8) l = self._recv(8)
length = struct.unpack("!Q", l)[0] length = struct.unpack("!Q", l)[0]
data = self._recv(length)
if mask: if mask:
raise NotImplementedError("masked data transfer is not implemented") mask_key = self._recv(4)
data = self._recv(length)
if mask:
data = ABNF.mask(mask_key, data)
frame = ABNF(fin, rsv1, rsv2, rsv3, opcode, mask, data) frame = ABNF(fin, rsv1, rsv2, rsv3, opcode, mask, data)
return frame return frame
@@ -414,7 +450,7 @@ class WebSocket(object):
""" """
if self.connected: if self.connected:
try: try:
self.io_sock.send("\xff\x00") self.send("bye", ABNF.OPCODE_CLOSE)
timeout = self.sock.gettimeout() timeout = self.sock.gettimeout()
self.sock.settimeout(1) self.sock.settimeout(1)
try: try: