Merge
This commit is contained in:
88
eventlet/websocket.py
Normal file
88
eventlet/websocket.py
Normal file
@@ -0,0 +1,88 @@
|
||||
import collections
|
||||
import errno
|
||||
from eventlet import wsgi
|
||||
from eventlet import pools
|
||||
import eventlet
|
||||
from eventlet.common import get_errno
|
||||
from eventlet.green import socket
|
||||
#from pprint import pformat
|
||||
|
||||
class WebSocket(object):
|
||||
"""Handles access to the actual socket"""
|
||||
|
||||
def __init__(self, sock, environ):
|
||||
"""
|
||||
:param socket: The eventlet socket
|
||||
:type socket: :class:`eventlet.greenio.GreenSocket`
|
||||
:param environ: The wsgi environment
|
||||
"""
|
||||
self.socket = sock
|
||||
self.origin = environ.get('HTTP_ORIGIN')
|
||||
self.protocol = environ.get('HTTP_WEBSOCKET_PROTOCOL')
|
||||
self.path = environ.get('PATH_INFO')
|
||||
self.environ = environ
|
||||
self._buf = ""
|
||||
self._msgs = collections.deque()
|
||||
self._sendlock = pools.TokenPool(1)
|
||||
|
||||
@staticmethod
|
||||
def pack_message(message):
|
||||
"""Pack the message inside ``00`` and ``FF``
|
||||
|
||||
As per the dataframing section (5.3) for the websocket spec
|
||||
"""
|
||||
if isinstance(message, unicode):
|
||||
message = message.encode('utf-8')
|
||||
elif not isinstance(message, str):
|
||||
message = str(message)
|
||||
packed = "\x00%s\xFF" % message
|
||||
return packed
|
||||
|
||||
def parse_messages(self):
|
||||
""" Parses for messages in the buffer *buf*. It is assumed that
|
||||
the buffer contains the start character for a message, but that it
|
||||
may contain only part of the rest of the message. NOTE: only understands
|
||||
lengthless messages for now.
|
||||
|
||||
Returns an array of messages, and the buffer remainder that didn't contain
|
||||
any full messages."""
|
||||
msgs = []
|
||||
end_idx = 0
|
||||
buf = self._buf
|
||||
while buf:
|
||||
assert ord(buf[0]) == 0, "Don't understand how to parse this type of message: %r" % buf
|
||||
end_idx = buf.find("\xFF")
|
||||
if end_idx == -1: #pragma NO COVER
|
||||
break
|
||||
msgs.append(buf[1:end_idx].decode('utf-8', 'replace'))
|
||||
buf = buf[end_idx+1:]
|
||||
self._buf = buf
|
||||
return msgs
|
||||
|
||||
def send(self, message):
|
||||
"""Send a message to the client"""
|
||||
packed = self.pack_message(message)
|
||||
# if two greenthreads are trying to send at the same time
|
||||
# on the same socket, sendlock prevents interleaving and corruption
|
||||
t = self._sendlock.get()
|
||||
try:
|
||||
self.socket.sendall(packed)
|
||||
finally:
|
||||
self._sendlock.put(t)
|
||||
|
||||
def wait(self):
|
||||
"""Waits for an deserializes messages"""
|
||||
|
||||
while not self._msgs:
|
||||
# no parsed messages, must mean buf needs more data
|
||||
delta = self.socket.recv(1024)
|
||||
if delta == '':
|
||||
return None
|
||||
self._buf += delta
|
||||
msgs = self.parse_messages()
|
||||
self._msgs.extend(msgs)
|
||||
return self._msgs.popleft()
|
||||
|
||||
def close(self):
|
||||
self.socket.shutdown(True)
|
||||
|
@@ -312,7 +312,7 @@ class HttpProtocol(BaseHTTPServer.BaseHTTPRequestHandler):
|
||||
try:
|
||||
try:
|
||||
result = self.application(self.environ, start_response)
|
||||
if result is ALREADY_HANDLED:
|
||||
if isinstance(result, _AlreadyHandled):
|
||||
self.close_connection = 1
|
||||
return
|
||||
if not headers_sent and hasattr(result, '__len__') and \
|
||||
|
1
setup.py
1
setup.py
@@ -30,6 +30,7 @@ setup(
|
||||
)
|
||||
).read(),
|
||||
test_suite = 'nose.collector',
|
||||
tests_require = 'httplib2',
|
||||
classifiers=[
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Programming Language :: Python",
|
||||
|
271
tests/mock.py
Executable file
271
tests/mock.py
Executable file
@@ -0,0 +1,271 @@
|
||||
# mock.py
|
||||
# Test tools for mocking and patching.
|
||||
# Copyright (C) 2007-2009 Michael Foord
|
||||
# E-mail: fuzzyman AT voidspace DOT org DOT uk
|
||||
|
||||
# mock 0.6.0
|
||||
# http://www.voidspace.org.uk/python/mock/
|
||||
|
||||
# Released subject to the BSD License
|
||||
# Please see http://www.voidspace.org.uk/python/license.shtml
|
||||
|
||||
# Scripts maintained at http://www.voidspace.org.uk/python/index.shtml
|
||||
# Comments, suggestions and bug reports welcome.
|
||||
|
||||
|
||||
__all__ = (
|
||||
'Mock',
|
||||
'patch',
|
||||
'patch_object',
|
||||
'sentinel',
|
||||
'DEFAULT'
|
||||
)
|
||||
|
||||
__version__ = '0.6.0'
|
||||
|
||||
class SentinelObject(object):
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
def __repr__(self):
|
||||
return '<SentinelObject "%s">' % self.name
|
||||
|
||||
|
||||
class Sentinel(object):
|
||||
def __init__(self):
|
||||
self._sentinels = {}
|
||||
|
||||
def __getattr__(self, name):
|
||||
return self._sentinels.setdefault(name, SentinelObject(name))
|
||||
|
||||
|
||||
sentinel = Sentinel()
|
||||
|
||||
DEFAULT = sentinel.DEFAULT
|
||||
|
||||
class OldStyleClass:
|
||||
pass
|
||||
ClassType = type(OldStyleClass)
|
||||
|
||||
def _is_magic(name):
|
||||
return '__%s__' % name[2:-2] == name
|
||||
|
||||
def _copy(value):
|
||||
if type(value) in (dict, list, tuple, set):
|
||||
return type(value)(value)
|
||||
return value
|
||||
|
||||
|
||||
class Mock(object):
|
||||
|
||||
def __init__(self, spec=None, side_effect=None, return_value=DEFAULT,
|
||||
name=None, parent=None, wraps=None):
|
||||
self._parent = parent
|
||||
self._name = name
|
||||
if spec is not None and not isinstance(spec, list):
|
||||
spec = [member for member in dir(spec) if not _is_magic(member)]
|
||||
|
||||
self._methods = spec
|
||||
self._children = {}
|
||||
self._return_value = return_value
|
||||
self.side_effect = side_effect
|
||||
self._wraps = wraps
|
||||
|
||||
self.reset_mock()
|
||||
|
||||
|
||||
def reset_mock(self):
|
||||
self.called = False
|
||||
self.call_args = None
|
||||
self.call_count = 0
|
||||
self.call_args_list = []
|
||||
self.method_calls = []
|
||||
for child in self._children.itervalues():
|
||||
child.reset_mock()
|
||||
if isinstance(self._return_value, Mock):
|
||||
self._return_value.reset_mock()
|
||||
|
||||
|
||||
def __get_return_value(self):
|
||||
if self._return_value is DEFAULT:
|
||||
self._return_value = Mock()
|
||||
return self._return_value
|
||||
|
||||
def __set_return_value(self, value):
|
||||
self._return_value = value
|
||||
|
||||
return_value = property(__get_return_value, __set_return_value)
|
||||
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
self.called = True
|
||||
self.call_count += 1
|
||||
self.call_args = (args, kwargs)
|
||||
self.call_args_list.append((args, kwargs))
|
||||
|
||||
parent = self._parent
|
||||
name = self._name
|
||||
while parent is not None:
|
||||
parent.method_calls.append((name, args, kwargs))
|
||||
if parent._parent is None:
|
||||
break
|
||||
name = parent._name + '.' + name
|
||||
parent = parent._parent
|
||||
|
||||
ret_val = DEFAULT
|
||||
if self.side_effect is not None:
|
||||
if (isinstance(self.side_effect, Exception) or
|
||||
isinstance(self.side_effect, (type, ClassType)) and
|
||||
issubclass(self.side_effect, Exception)):
|
||||
raise self.side_effect
|
||||
|
||||
ret_val = self.side_effect(*args, **kwargs)
|
||||
if ret_val is DEFAULT:
|
||||
ret_val = self.return_value
|
||||
|
||||
if self._wraps is not None and self._return_value is DEFAULT:
|
||||
return self._wraps(*args, **kwargs)
|
||||
if ret_val is DEFAULT:
|
||||
ret_val = self.return_value
|
||||
return ret_val
|
||||
|
||||
|
||||
def __getattr__(self, name):
|
||||
if self._methods is not None:
|
||||
if name not in self._methods:
|
||||
raise AttributeError("Mock object has no attribute '%s'" % name)
|
||||
elif _is_magic(name):
|
||||
raise AttributeError(name)
|
||||
|
||||
if name not in self._children:
|
||||
wraps = None
|
||||
if self._wraps is not None:
|
||||
wraps = getattr(self._wraps, name)
|
||||
self._children[name] = Mock(parent=self, name=name, wraps=wraps)
|
||||
|
||||
return self._children[name]
|
||||
|
||||
|
||||
def assert_called_with(self, *args, **kwargs):
|
||||
assert self.call_args == (args, kwargs), 'Expected: %s\nCalled with: %s' % ((args, kwargs), self.call_args)
|
||||
|
||||
|
||||
def _dot_lookup(thing, comp, import_path):
|
||||
try:
|
||||
return getattr(thing, comp)
|
||||
except AttributeError:
|
||||
__import__(import_path)
|
||||
return getattr(thing, comp)
|
||||
|
||||
|
||||
def _importer(target):
|
||||
components = target.split('.')
|
||||
import_path = components.pop(0)
|
||||
thing = __import__(import_path)
|
||||
|
||||
for comp in components:
|
||||
import_path += ".%s" % comp
|
||||
thing = _dot_lookup(thing, comp, import_path)
|
||||
return thing
|
||||
|
||||
|
||||
class _patch(object):
|
||||
def __init__(self, target, attribute, new, spec, create):
|
||||
self.target = target
|
||||
self.attribute = attribute
|
||||
self.new = new
|
||||
self.spec = spec
|
||||
self.create = create
|
||||
self.has_local = False
|
||||
|
||||
|
||||
def __call__(self, func):
|
||||
if hasattr(func, 'patchings'):
|
||||
func.patchings.append(self)
|
||||
return func
|
||||
|
||||
def patched(*args, **keywargs):
|
||||
# don't use a with here (backwards compatability with 2.5)
|
||||
extra_args = []
|
||||
for patching in patched.patchings:
|
||||
arg = patching.__enter__()
|
||||
if patching.new is DEFAULT:
|
||||
extra_args.append(arg)
|
||||
args += tuple(extra_args)
|
||||
try:
|
||||
return func(*args, **keywargs)
|
||||
finally:
|
||||
for patching in getattr(patched, 'patchings', []):
|
||||
patching.__exit__()
|
||||
|
||||
patched.patchings = [self]
|
||||
patched.__name__ = func.__name__
|
||||
patched.compat_co_firstlineno = getattr(func, "compat_co_firstlineno",
|
||||
func.func_code.co_firstlineno)
|
||||
return patched
|
||||
|
||||
|
||||
def get_original(self):
|
||||
target = self.target
|
||||
name = self.attribute
|
||||
create = self.create
|
||||
|
||||
original = DEFAULT
|
||||
if _has_local_attr(target, name):
|
||||
try:
|
||||
original = target.__dict__[name]
|
||||
except AttributeError:
|
||||
# for instances of classes with slots, they have no __dict__
|
||||
original = getattr(target, name)
|
||||
elif not create and not hasattr(target, name):
|
||||
raise AttributeError("%s does not have the attribute %r" % (target, name))
|
||||
return original
|
||||
|
||||
|
||||
def __enter__(self):
|
||||
new, spec, = self.new, self.spec
|
||||
original = self.get_original()
|
||||
if new is DEFAULT:
|
||||
# XXXX what if original is DEFAULT - shouldn't use it as a spec
|
||||
inherit = False
|
||||
if spec == True:
|
||||
# set spec to the object we are replacing
|
||||
spec = original
|
||||
if isinstance(spec, (type, ClassType)):
|
||||
inherit = True
|
||||
new = Mock(spec=spec)
|
||||
if inherit:
|
||||
new.return_value = Mock(spec=spec)
|
||||
self.temp_original = original
|
||||
setattr(self.target, self.attribute, new)
|
||||
return new
|
||||
|
||||
|
||||
def __exit__(self, *_):
|
||||
if self.temp_original is not DEFAULT:
|
||||
setattr(self.target, self.attribute, self.temp_original)
|
||||
else:
|
||||
delattr(self.target, self.attribute)
|
||||
del self.temp_original
|
||||
|
||||
|
||||
def patch_object(target, attribute, new=DEFAULT, spec=None, create=False):
|
||||
return _patch(target, attribute, new, spec, create)
|
||||
|
||||
|
||||
def patch(target, new=DEFAULT, spec=None, create=False):
|
||||
try:
|
||||
target, attribute = target.rsplit('.', 1)
|
||||
except (TypeError, ValueError):
|
||||
raise TypeError("Need a valid target to patch. You supplied: %r" % (target,))
|
||||
target = _importer(target)
|
||||
return _patch(target, attribute, new, spec, create)
|
||||
|
||||
|
||||
|
||||
def _has_local_attr(obj, name):
|
||||
try:
|
||||
return name in vars(obj)
|
||||
except TypeError:
|
||||
# objects without a __dict__
|
||||
return hasattr(obj, name)
|
228
tests/test_websocket.py
Normal file
228
tests/test_websocket.py
Normal file
@@ -0,0 +1,228 @@
|
||||
import eventlet
|
||||
from eventlet import debug, hubs, Timeout, spawn_n, greenthread, wsgi, patcher
|
||||
from eventlet.green import urllib2
|
||||
from eventlet.websocket import WebSocket
|
||||
from nose.tools import ok_, eq_, set_trace, raises
|
||||
from StringIO import StringIO
|
||||
from unittest import TestCase
|
||||
from tests.wsgi_test import _TestBase
|
||||
import logging
|
||||
from tests import mock
|
||||
import random
|
||||
|
||||
httplib2 = patcher.import_patched('httplib2')
|
||||
|
||||
|
||||
class WebSocketWSGI(object):
|
||||
def __init__(self, handler):
|
||||
self.handler = handler
|
||||
|
||||
def __call__(self, environ, start_response):
|
||||
print environ
|
||||
if not (environ.get('HTTP_CONNECTION') == 'Upgrade' and
|
||||
environ.get('HTTP_UPGRADE') == 'WebSocket'):
|
||||
# need to check a few more things here for true compliance
|
||||
start_response('400 Bad Request', [('Connection','close')])
|
||||
return []
|
||||
|
||||
sock = environ['eventlet.input'].get_socket()
|
||||
ws = WebSocket(sock, environ)
|
||||
handshake_reply = ("HTTP/1.1 101 Web Socket Protocol Handshake\r\n"
|
||||
"Upgrade: WebSocket\r\n"
|
||||
"Connection: Upgrade\r\n"
|
||||
"WebSocket-Origin: %s\r\n"
|
||||
"WebSocket-Location: ws://%s%s\r\n\r\n" % (
|
||||
environ.get('HTTP_ORIGIN'),
|
||||
environ.get('HTTP_HOST'),
|
||||
environ.get('PATH_INFO')))
|
||||
sock.sendall(handshake_reply)
|
||||
try:
|
||||
self.handler(ws)
|
||||
except socket.error, e:
|
||||
if get_errno(e) != errno.EPIPE:
|
||||
raise
|
||||
# use this undocumented feature of eventlet.wsgi to ensure that it
|
||||
# doesn't barf on the fact that we didn't call start_response
|
||||
return wsgi.ALREADY_HANDLED
|
||||
|
||||
# demo app
|
||||
import os
|
||||
import random
|
||||
def handle(ws):
|
||||
""" This is the websocket handler function. Note that we
|
||||
can dispatch based on path in here, too."""
|
||||
if ws.path == '/echo':
|
||||
while True:
|
||||
m = ws.wait()
|
||||
if m is None:
|
||||
break
|
||||
ws.send(m)
|
||||
|
||||
elif ws.path == '/range':
|
||||
for i in xrange(10):
|
||||
ws.send("msg %d" % i)
|
||||
eventlet.sleep(0.1)
|
||||
|
||||
else:
|
||||
ws.close()
|
||||
|
||||
wsapp = WebSocketWSGI(handle)
|
||||
|
||||
|
||||
class TestWebSocket(_TestBase):
|
||||
|
||||
# def setUp(self):
|
||||
# super(_TestBase, self).setUp()
|
||||
# self.logfile = StringIO()
|
||||
# self.site = Site()
|
||||
# self.killer = None
|
||||
# self.set_site()
|
||||
# self.spawn_server()
|
||||
# self.site.application = WebSocketWSGI(handle, 'http://localhost:%s' % self.port)
|
||||
|
||||
TEST_TIMEOUT = 5
|
||||
|
||||
def set_site(self):
|
||||
self.site = wsapp
|
||||
|
||||
|
||||
@raises(urllib2.HTTPError)
|
||||
def test_incorrect_headers(self):
|
||||
try:
|
||||
urllib2.urlopen("http://localhost:%s/echo" % self.port)
|
||||
except urllib2.HTTPError, e:
|
||||
eq_(e.code, 400)
|
||||
raise
|
||||
|
||||
def test_incomplete_headers(self):
|
||||
headers = dict(kv.split(': ') for kv in [
|
||||
"Upgrade: WebSocket",
|
||||
#"Connection: Upgrade", Without this should trigger the HTTPServerError
|
||||
"Host: localhost:%s" % self.port,
|
||||
"Origin: http://localhost:%s" % self.port,
|
||||
"WebSocket-Protocol: ws",
|
||||
])
|
||||
http = httplib2.Http()
|
||||
resp, content = http.request("http://localhost:%s/echo" % self.port, headers=headers)
|
||||
|
||||
self.assertEqual(resp['status'], '400')
|
||||
self.assertEqual(resp['connection'], 'close')
|
||||
self.assertEqual(content, '')
|
||||
|
||||
def test_correct_upgrade_request(self):
|
||||
connect = [
|
||||
"GET /echo HTTP/1.1",
|
||||
"Upgrade: WebSocket",
|
||||
"Connection: Upgrade",
|
||||
"Host: localhost:%s" % self.port,
|
||||
"Origin: http://localhost:%s" % self.port,
|
||||
"WebSocket-Protocol: ws",
|
||||
]
|
||||
sock = eventlet.connect(
|
||||
('localhost', self.port))
|
||||
|
||||
fd = sock.makefile('rw', close=True)
|
||||
fd.write('\r\n'.join(connect) + '\r\n\r\n')
|
||||
fd.flush()
|
||||
result = sock.recv(1024)
|
||||
fd.close()
|
||||
## The server responds the correct Websocket handshake
|
||||
self.assertEqual(result,
|
||||
'\r\n'.join(['HTTP/1.1 101 Web Socket Protocol Handshake',
|
||||
'Upgrade: WebSocket',
|
||||
'Connection: Upgrade',
|
||||
'WebSocket-Origin: http://localhost:%s' % self.port,
|
||||
'WebSocket-Location: ws://localhost:%s/echo\r\n\r\n' % self.port]))
|
||||
|
||||
def test_sending_messages_to_websocket(self):
|
||||
connect = [
|
||||
"GET /echo HTTP/1.1",
|
||||
"Upgrade: WebSocket",
|
||||
"Connection: Upgrade",
|
||||
"Host: localhost:%s" % self.port,
|
||||
"Origin: http://localhost:%s" % self.port,
|
||||
"WebSocket-Protocol: ws",
|
||||
]
|
||||
sock = eventlet.connect(
|
||||
('localhost', self.port))
|
||||
|
||||
fd = sock.makefile('rw', close=True)
|
||||
fd.write('\r\n'.join(connect) + '\r\n\r\n')
|
||||
fd.flush()
|
||||
first_resp = sock.recv(1024)
|
||||
fd.write('\x00hello\xFF')
|
||||
fd.flush()
|
||||
result = sock.recv(1024)
|
||||
self.assertEqual(result, '\x00hello\xff')
|
||||
fd.write('\x00start')
|
||||
fd.flush()
|
||||
fd.write(' end\xff')
|
||||
fd.flush()
|
||||
result = sock.recv(1024)
|
||||
self.assertEqual(result, '\x00start end\xff')
|
||||
fd.write('')
|
||||
fd.flush()
|
||||
|
||||
|
||||
|
||||
def test_getting_messages_from_websocket(self):
|
||||
connect = [
|
||||
"GET /range HTTP/1.1",
|
||||
"Upgrade: WebSocket",
|
||||
"Connection: Upgrade",
|
||||
"Host: localhost:%s" % self.port,
|
||||
"Origin: http://localhost:%s" % self.port,
|
||||
"WebSocket-Protocol: ws",
|
||||
]
|
||||
sock = eventlet.connect(
|
||||
('localhost', self.port))
|
||||
|
||||
fd = sock.makefile('rw', close=True)
|
||||
fd.write('\r\n'.join(connect) + '\r\n\r\n')
|
||||
fd.flush()
|
||||
resp = sock.recv(1024)
|
||||
headers, result = resp.split('\r\n\r\n')
|
||||
msgs = [result.strip('\x00\xff')]
|
||||
cnt = 10
|
||||
while cnt:
|
||||
msgs.append(sock.recv(20).strip('\x00\xff'))
|
||||
cnt -= 1
|
||||
# Last item in msgs is an empty string
|
||||
self.assertEqual(msgs[:-1], ['msg %d' % i for i in range(10)])
|
||||
|
||||
|
||||
class TestWebSocketObject(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.mock_socket = s = mock.Mock()
|
||||
self.environ = env = dict(HTTP_ORIGIN='http://localhost', HTTP_WEBSOCKET_PROTOCOL='ws',
|
||||
PATH_INFO='test')
|
||||
|
||||
self.test_ws = WebSocket(s, env)
|
||||
|
||||
def test_recieve(self):
|
||||
ws = self.test_ws
|
||||
ws.socket.recv.return_value = '\x00hello\xFF'
|
||||
eq_(ws.wait(), 'hello')
|
||||
eq_(ws._buf, '')
|
||||
eq_(len(ws._msgs), 0)
|
||||
ws.socket.recv.return_value = ''
|
||||
eq_(ws.wait(), None)
|
||||
eq_(ws._buf, '')
|
||||
eq_(len(ws._msgs), 0)
|
||||
|
||||
|
||||
def test_send_to_ws(self):
|
||||
ws = self.test_ws
|
||||
ws.send(u'hello')
|
||||
ok_(ws.socket.sendall.called_with("\x00hello\xFF"))
|
||||
ws.send(10)
|
||||
ok_(ws.socket.sendall.called_with("\x0010\xFF"))
|
||||
|
||||
def test_close_ws(self):
|
||||
ws = self.test_ws
|
||||
ws.close()
|
||||
ok_(ws.socket.shutdown.called_with(True))
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user