229 lines
7.3 KiB
Python
229 lines
7.3 KiB
Python
import eventlet
|
|
from eventlet import debug, hubs, Timeout, spawn_n, greenthread, wsgi, patcher
|
|
from eventlet.green import urllib2
|
|
from eventlet.websocket import WebSocket
|
|
from nose.tools import ok_, eq_, set_trace, raises
|
|
from StringIO import StringIO
|
|
from unittest import TestCase
|
|
from tests.wsgi_test import _TestBase
|
|
import logging
|
|
from tests import mock
|
|
import random
|
|
|
|
httplib2 = patcher.import_patched('httplib2')
|
|
|
|
|
|
class WebSocketWSGI(object):
|
|
def __init__(self, handler):
|
|
self.handler = handler
|
|
|
|
def __call__(self, environ, start_response):
|
|
print environ
|
|
if not (environ.get('HTTP_CONNECTION') == 'Upgrade' and
|
|
environ.get('HTTP_UPGRADE') == 'WebSocket'):
|
|
# need to check a few more things here for true compliance
|
|
start_response('400 Bad Request', [('Connection','close')])
|
|
return []
|
|
|
|
sock = environ['eventlet.input'].get_socket()
|
|
ws = WebSocket(sock, environ)
|
|
handshake_reply = ("HTTP/1.1 101 Web Socket Protocol Handshake\r\n"
|
|
"Upgrade: WebSocket\r\n"
|
|
"Connection: Upgrade\r\n"
|
|
"WebSocket-Origin: %s\r\n"
|
|
"WebSocket-Location: ws://%s%s\r\n\r\n" % (
|
|
environ.get('HTTP_ORIGIN'),
|
|
environ.get('HTTP_HOST'),
|
|
environ.get('PATH_INFO')))
|
|
sock.sendall(handshake_reply)
|
|
try:
|
|
self.handler(ws)
|
|
except socket.error, e:
|
|
if get_errno(e) != errno.EPIPE:
|
|
raise
|
|
# use this undocumented feature of eventlet.wsgi to ensure that it
|
|
# doesn't barf on the fact that we didn't call start_response
|
|
return wsgi.ALREADY_HANDLED
|
|
|
|
# demo app
|
|
import os
|
|
import random
|
|
def handle(ws):
|
|
""" This is the websocket handler function. Note that we
|
|
can dispatch based on path in here, too."""
|
|
if ws.path == '/echo':
|
|
while True:
|
|
m = ws.wait()
|
|
if m is None:
|
|
break
|
|
ws.send(m)
|
|
|
|
elif ws.path == '/range':
|
|
for i in xrange(10):
|
|
ws.send("msg %d" % i)
|
|
eventlet.sleep(0.1)
|
|
|
|
else:
|
|
ws.close()
|
|
|
|
wsapp = WebSocketWSGI(handle)
|
|
|
|
|
|
class TestWebSocket(_TestBase):
|
|
|
|
# def setUp(self):
|
|
# super(_TestBase, self).setUp()
|
|
# self.logfile = StringIO()
|
|
# self.site = Site()
|
|
# self.killer = None
|
|
# self.set_site()
|
|
# self.spawn_server()
|
|
# self.site.application = WebSocketWSGI(handle, 'http://localhost:%s' % self.port)
|
|
|
|
TEST_TIMEOUT = 5
|
|
|
|
def set_site(self):
|
|
self.site = wsapp
|
|
|
|
|
|
@raises(urllib2.HTTPError)
|
|
def test_incorrect_headers(self):
|
|
try:
|
|
urllib2.urlopen("http://localhost:%s/echo" % self.port)
|
|
except urllib2.HTTPError, e:
|
|
eq_(e.code, 400)
|
|
raise
|
|
|
|
def test_incomplete_headers(self):
|
|
headers = dict(kv.split(': ') for kv in [
|
|
"Upgrade: WebSocket",
|
|
#"Connection: Upgrade", Without this should trigger the HTTPServerError
|
|
"Host: localhost:%s" % self.port,
|
|
"Origin: http://localhost:%s" % self.port,
|
|
"WebSocket-Protocol: ws",
|
|
])
|
|
http = httplib2.Http()
|
|
resp, content = http.request("http://localhost:%s/echo" % self.port, headers=headers)
|
|
|
|
self.assertEqual(resp['status'], '400')
|
|
self.assertEqual(resp['connection'], 'close')
|
|
self.assertEqual(content, '')
|
|
|
|
def test_correct_upgrade_request(self):
|
|
connect = [
|
|
"GET /echo HTTP/1.1",
|
|
"Upgrade: WebSocket",
|
|
"Connection: Upgrade",
|
|
"Host: localhost:%s" % self.port,
|
|
"Origin: http://localhost:%s" % self.port,
|
|
"WebSocket-Protocol: ws",
|
|
]
|
|
sock = eventlet.connect(
|
|
('localhost', self.port))
|
|
|
|
fd = sock.makefile('rw', close=True)
|
|
fd.write('\r\n'.join(connect) + '\r\n\r\n')
|
|
fd.flush()
|
|
result = sock.recv(1024)
|
|
fd.close()
|
|
## The server responds the correct Websocket handshake
|
|
self.assertEqual(result,
|
|
'\r\n'.join(['HTTP/1.1 101 Web Socket Protocol Handshake',
|
|
'Upgrade: WebSocket',
|
|
'Connection: Upgrade',
|
|
'WebSocket-Origin: http://localhost:%s' % self.port,
|
|
'WebSocket-Location: ws://localhost:%s/echo\r\n\r\n' % self.port]))
|
|
|
|
def test_sending_messages_to_websocket(self):
|
|
connect = [
|
|
"GET /echo HTTP/1.1",
|
|
"Upgrade: WebSocket",
|
|
"Connection: Upgrade",
|
|
"Host: localhost:%s" % self.port,
|
|
"Origin: http://localhost:%s" % self.port,
|
|
"WebSocket-Protocol: ws",
|
|
]
|
|
sock = eventlet.connect(
|
|
('localhost', self.port))
|
|
|
|
fd = sock.makefile('rw', close=True)
|
|
fd.write('\r\n'.join(connect) + '\r\n\r\n')
|
|
fd.flush()
|
|
first_resp = sock.recv(1024)
|
|
fd.write('\x00hello\xFF')
|
|
fd.flush()
|
|
result = sock.recv(1024)
|
|
self.assertEqual(result, '\x00hello\xff')
|
|
fd.write('\x00start')
|
|
fd.flush()
|
|
fd.write(' end\xff')
|
|
fd.flush()
|
|
result = sock.recv(1024)
|
|
self.assertEqual(result, '\x00start end\xff')
|
|
fd.write('')
|
|
fd.flush()
|
|
|
|
|
|
|
|
def test_getting_messages_from_websocket(self):
|
|
connect = [
|
|
"GET /range HTTP/1.1",
|
|
"Upgrade: WebSocket",
|
|
"Connection: Upgrade",
|
|
"Host: localhost:%s" % self.port,
|
|
"Origin: http://localhost:%s" % self.port,
|
|
"WebSocket-Protocol: ws",
|
|
]
|
|
sock = eventlet.connect(
|
|
('localhost', self.port))
|
|
|
|
fd = sock.makefile('rw', close=True)
|
|
fd.write('\r\n'.join(connect) + '\r\n\r\n')
|
|
fd.flush()
|
|
resp = sock.recv(1024)
|
|
headers, result = resp.split('\r\n\r\n')
|
|
msgs = [result.strip('\x00\xff')]
|
|
cnt = 10
|
|
while cnt:
|
|
msgs.append(sock.recv(20).strip('\x00\xff'))
|
|
cnt -= 1
|
|
# Last item in msgs is an empty string
|
|
self.assertEqual(msgs[:-1], ['msg %d' % i for i in range(10)])
|
|
|
|
|
|
class TestWebSocketObject(TestCase):
|
|
|
|
def setUp(self):
|
|
self.mock_socket = s = mock.Mock()
|
|
self.environ = env = dict(HTTP_ORIGIN='http://localhost', HTTP_WEBSOCKET_PROTOCOL='ws',
|
|
PATH_INFO='test')
|
|
|
|
self.test_ws = WebSocket(s, env)
|
|
|
|
def test_recieve(self):
|
|
ws = self.test_ws
|
|
ws.socket.recv.return_value = '\x00hello\xFF'
|
|
eq_(ws.wait(), 'hello')
|
|
eq_(ws._buf, '')
|
|
eq_(len(ws._msgs), 0)
|
|
ws.socket.recv.return_value = ''
|
|
eq_(ws.wait(), None)
|
|
eq_(ws._buf, '')
|
|
eq_(len(ws._msgs), 0)
|
|
|
|
|
|
def test_send_to_ws(self):
|
|
ws = self.test_ws
|
|
ws.send(u'hello')
|
|
ok_(ws.socket.sendall.called_with("\x00hello\xFF"))
|
|
ws.send(10)
|
|
ok_(ws.socket.sendall.called_with("\x0010\xFF"))
|
|
|
|
def test_close_ws(self):
|
|
ws = self.test_ws
|
|
ws.close()
|
|
ok_(ws.socket.shutdown.called_with(True))
|
|
|
|
|
|
|