From 6ddba84970459011883b74b08997b4967eece245 Mon Sep 17 00:00:00 2001 From: Ben Ford Date: Wed, 5 May 2010 07:08:07 +0100 Subject: [PATCH 1/3] Changed is ALREADY_HANDLED test to isinstance test --- eventlet/wsgi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/eventlet/wsgi.py b/eventlet/wsgi.py index 35a7944..1b46395 100644 --- a/eventlet/wsgi.py +++ b/eventlet/wsgi.py @@ -312,7 +312,7 @@ class HttpProtocol(BaseHTTPServer.BaseHTTPRequestHandler): try: try: result = self.application(self.environ, start_response) - if result is ALREADY_HANDLED: + if isinstance(result, _AlreadyHandled): self.close_connection = 1 return if not headers_sent and hasattr(result, '__len__') and \ From 8d6ad99285516d8ce9b1a1f2b01db3368d8b702c Mon Sep 17 00:00:00 2001 From: Ben Ford Date: Wed, 5 May 2010 07:14:57 +0100 Subject: [PATCH 2/3] Added websocket. This has 100% test coverage but has introduced a testing dependency on httplib2 and mock. --- eventlet/websocket.py | 88 ++++++++++++++++ setup.py | 1 + tests/test_websocket.py | 228 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 317 insertions(+) create mode 100644 eventlet/websocket.py create mode 100644 tests/test_websocket.py diff --git a/eventlet/websocket.py b/eventlet/websocket.py new file mode 100644 index 0000000..83071e2 --- /dev/null +++ b/eventlet/websocket.py @@ -0,0 +1,88 @@ +import collections +import errno +from eventlet import wsgi +from eventlet import pools +import eventlet +from eventlet.common import get_errno +from eventlet.green import socket +#from pprint import pformat + +class WebSocket(object): + """Handles access to the actual socket""" + + def __init__(self, sock, environ): + """ + :param socket: The eventlet socket + :type socket: :class:`eventlet.greenio.GreenSocket` + :param environ: The wsgi environment + """ + self.socket = sock + self.origin = environ.get('HTTP_ORIGIN') + self.protocol = environ.get('HTTP_WEBSOCKET_PROTOCOL') + self.path = environ.get('PATH_INFO') + self.environ = environ + self._buf = "" + self._msgs = collections.deque() + self._sendlock = pools.TokenPool(1) + + @staticmethod + def pack_message(message): + """Pack the message inside ``00`` and ``FF`` + + As per the dataframing section (5.3) for the websocket spec + """ + if isinstance(message, unicode): + message = message.encode('utf-8') + elif not isinstance(message, str): + message = str(message) + packed = "\x00%s\xFF" % message + return packed + + def parse_messages(self): + """ Parses for messages in the buffer *buf*. It is assumed that + the buffer contains the start character for a message, but that it + may contain only part of the rest of the message. NOTE: only understands + lengthless messages for now. + + Returns an array of messages, and the buffer remainder that didn't contain + any full messages.""" + msgs = [] + end_idx = 0 + buf = self._buf + while buf: + assert ord(buf[0]) == 0, "Don't understand how to parse this type of message: %r" % buf + end_idx = buf.find("\xFF") + if end_idx == -1: #pragma NO COVER + break + msgs.append(buf[1:end_idx].decode('utf-8', 'replace')) + buf = buf[end_idx+1:] + self._buf = buf + return msgs + + def send(self, message): + """Send a message to the client""" + packed = self.pack_message(message) + # if two greenthreads are trying to send at the same time + # on the same socket, sendlock prevents interleaving and corruption + t = self._sendlock.get() + try: + self.socket.sendall(packed) + finally: + self._sendlock.put(t) + + def wait(self): + """Waits for an deserializes messages""" + + while not self._msgs: + # no parsed messages, must mean buf needs more data + delta = self.socket.recv(1024) + if delta == '': + return None + self._buf += delta + msgs = self.parse_messages() + self._msgs.extend(msgs) + return self._msgs.popleft() + + def close(self): + self.socket.shutdown(True) + diff --git a/setup.py b/setup.py index 1d312e9..7098c23 100644 --- a/setup.py +++ b/setup.py @@ -30,6 +30,7 @@ setup( ) ).read(), test_suite = 'nose.collector', + tests_require = 'mock', classifiers=[ "License :: OSI Approved :: MIT License", "Programming Language :: Python", diff --git a/tests/test_websocket.py b/tests/test_websocket.py new file mode 100644 index 0000000..db4638e --- /dev/null +++ b/tests/test_websocket.py @@ -0,0 +1,228 @@ +import eventlet +from eventlet import debug, hubs, Timeout, spawn_n, greenthread, wsgi, patcher +from eventlet.green import urllib2 +from eventlet.websocket import WebSocket +from nose.tools import ok_, eq_, set_trace, raises +from StringIO import StringIO +from unittest import TestCase +from tests.wsgi_test import _TestBase +import logging +import mock +import random + +httplib2 = patcher.import_patched('httplib2') + + +class WebSocketWSGI(object): + def __init__(self, handler): + self.handler = handler + + def __call__(self, environ, start_response): + print environ + if not (environ.get('HTTP_CONNECTION') == 'Upgrade' and + environ.get('HTTP_UPGRADE') == 'WebSocket'): + # need to check a few more things here for true compliance + start_response('400 Bad Request', [('Connection','close')]) + return [] + + sock = environ['eventlet.input'].get_socket() + ws = WebSocket(sock, environ) + handshake_reply = ("HTTP/1.1 101 Web Socket Protocol Handshake\r\n" + "Upgrade: WebSocket\r\n" + "Connection: Upgrade\r\n" + "WebSocket-Origin: %s\r\n" + "WebSocket-Location: ws://%s%s\r\n\r\n" % ( + environ.get('HTTP_ORIGIN'), + environ.get('HTTP_HOST'), + environ.get('PATH_INFO'))) + sock.sendall(handshake_reply) + try: + self.handler(ws) + except socket.error, e: + if get_errno(e) != errno.EPIPE: + raise + # use this undocumented feature of eventlet.wsgi to ensure that it + # doesn't barf on the fact that we didn't call start_response + return wsgi.ALREADY_HANDLED + +# demo app +import os +import random +def handle(ws): + """ This is the websocket handler function. Note that we + can dispatch based on path in here, too.""" + if ws.path == '/echo': + while True: + m = ws.wait() + if m is None: + break + ws.send(m) + + elif ws.path == '/range': + for i in xrange(10): + ws.send("msg %d" % i) + eventlet.sleep(0.1) + + else: + ws.close() + +wsapp = WebSocketWSGI(handle) + + +class TestWebSocket(_TestBase): + +# def setUp(self): +# super(_TestBase, self).setUp() +# self.logfile = StringIO() +# self.site = Site() +# self.killer = None +# self.set_site() +# self.spawn_server() +# self.site.application = WebSocketWSGI(handle, 'http://localhost:%s' % self.port) + + TEST_TIMEOUT = 5 + + def set_site(self): + self.site = wsapp + + + @raises(urllib2.HTTPError) + def test_incorrect_headers(self): + try: + urllib2.urlopen("http://localhost:%s/echo" % self.port) + except urllib2.HTTPError, e: + eq_(e.code, 400) + raise + + def test_incomplete_headers(self): + headers = dict(kv.split(': ') for kv in [ + "Upgrade: WebSocket", + #"Connection: Upgrade", Without this should trigger the HTTPServerError + "Host: localhost:%s" % self.port, + "Origin: http://localhost:%s" % self.port, + "WebSocket-Protocol: ws", + ]) + http = httplib2.Http() + resp, content = http.request("http://localhost:%s/echo" % self.port, headers=headers) + + self.assertEqual(resp['status'], '400') + self.assertEqual(resp['connection'], 'close') + self.assertEqual(content, '') + + def test_correct_upgrade_request(self): + connect = [ + "GET /echo HTTP/1.1", + "Upgrade: WebSocket", + "Connection: Upgrade", + "Host: localhost:%s" % self.port, + "Origin: http://localhost:%s" % self.port, + "WebSocket-Protocol: ws", + ] + sock = eventlet.connect( + ('localhost', self.port)) + + fd = sock.makefile('rw', close=True) + fd.write('\r\n'.join(connect) + '\r\n\r\n') + fd.flush() + result = sock.recv(1024) + fd.close() + ## The server responds the correct Websocket handshake + self.assertEqual(result, + '\r\n'.join(['HTTP/1.1 101 Web Socket Protocol Handshake', + 'Upgrade: WebSocket', + 'Connection: Upgrade', + 'WebSocket-Origin: http://localhost:%s' % self.port, + 'WebSocket-Location: ws://localhost:%s/echo\r\n\r\n' % self.port])) + + def test_sending_messages_to_websocket(self): + connect = [ + "GET /echo HTTP/1.1", + "Upgrade: WebSocket", + "Connection: Upgrade", + "Host: localhost:%s" % self.port, + "Origin: http://localhost:%s" % self.port, + "WebSocket-Protocol: ws", + ] + sock = eventlet.connect( + ('localhost', self.port)) + + fd = sock.makefile('rw', close=True) + fd.write('\r\n'.join(connect) + '\r\n\r\n') + fd.flush() + first_resp = sock.recv(1024) + fd.write('\x00hello\xFF') + fd.flush() + result = sock.recv(1024) + self.assertEqual(result, '\x00hello\xff') + fd.write('\x00start') + fd.flush() + fd.write(' end\xff') + fd.flush() + result = sock.recv(1024) + self.assertEqual(result, '\x00start end\xff') + fd.write('') + fd.flush() + + + + def test_getting_messages_from_websocket(self): + connect = [ + "GET /range HTTP/1.1", + "Upgrade: WebSocket", + "Connection: Upgrade", + "Host: localhost:%s" % self.port, + "Origin: http://localhost:%s" % self.port, + "WebSocket-Protocol: ws", + ] + sock = eventlet.connect( + ('localhost', self.port)) + + fd = sock.makefile('rw', close=True) + fd.write('\r\n'.join(connect) + '\r\n\r\n') + fd.flush() + resp = sock.recv(1024) + headers, result = resp.split('\r\n\r\n') + msgs = [result.strip('\x00\xff')] + cnt = 10 + while cnt: + msgs.append(sock.recv(20).strip('\x00\xff')) + cnt -= 1 + # Last item in msgs is an empty string + self.assertEqual(msgs[:-1], ['msg %d' % i for i in range(10)]) + + +class TestWebSocketObject(TestCase): + + def setUp(self): + self.mock_socket = s = mock.Mock() + self.environ = env = dict(HTTP_ORIGIN='http://localhost', HTTP_WEBSOCKET_PROTOCOL='ws', + PATH_INFO='test') + + self.test_ws = WebSocket(s, env) + + def test_recieve(self): + ws = self.test_ws + ws.socket.recv.return_value = '\x00hello\xFF' + eq_(ws.wait(), 'hello') + eq_(ws._buf, '') + eq_(len(ws._msgs), 0) + ws.socket.recv.return_value = '' + eq_(ws.wait(), None) + eq_(ws._buf, '') + eq_(len(ws._msgs), 0) + + + def test_send_to_ws(self): + ws = self.test_ws + ws.send(u'hello') + ok_(ws.socket.sendall.called_with("\x00hello\xFF")) + ws.send(10) + ok_(ws.socket.sendall.called_with("\x0010\xFF")) + + def test_close_ws(self): + ws = self.test_ws + ws.close() + ok_(ws.socket.shutdown.called_with(True)) + + + From 16dce8a8ae1d5147f7698add4223bf74fbb6cc7f Mon Sep 17 00:00:00 2001 From: Ben Ford Date: Wed, 5 May 2010 07:18:06 +0100 Subject: [PATCH 3/3] Realised that mock is only one file and is extremely useful so included it as tests.mock --- setup.py | 2 +- tests/mock.py | 271 ++++++++++++++++++++++++++++++++++++++++ tests/test_websocket.py | 2 +- 3 files changed, 273 insertions(+), 2 deletions(-) create mode 100755 tests/mock.py diff --git a/setup.py b/setup.py index 7098c23..821ccc3 100644 --- a/setup.py +++ b/setup.py @@ -30,7 +30,7 @@ setup( ) ).read(), test_suite = 'nose.collector', - tests_require = 'mock', + tests_require = 'httplib2', classifiers=[ "License :: OSI Approved :: MIT License", "Programming Language :: Python", diff --git a/tests/mock.py b/tests/mock.py new file mode 100755 index 0000000..03871d6 --- /dev/null +++ b/tests/mock.py @@ -0,0 +1,271 @@ +# mock.py +# Test tools for mocking and patching. +# Copyright (C) 2007-2009 Michael Foord +# E-mail: fuzzyman AT voidspace DOT org DOT uk + +# mock 0.6.0 +# http://www.voidspace.org.uk/python/mock/ + +# Released subject to the BSD License +# Please see http://www.voidspace.org.uk/python/license.shtml + +# Scripts maintained at http://www.voidspace.org.uk/python/index.shtml +# Comments, suggestions and bug reports welcome. + + +__all__ = ( + 'Mock', + 'patch', + 'patch_object', + 'sentinel', + 'DEFAULT' +) + +__version__ = '0.6.0' + +class SentinelObject(object): + def __init__(self, name): + self.name = name + + def __repr__(self): + return '' % self.name + + +class Sentinel(object): + def __init__(self): + self._sentinels = {} + + def __getattr__(self, name): + return self._sentinels.setdefault(name, SentinelObject(name)) + + +sentinel = Sentinel() + +DEFAULT = sentinel.DEFAULT + +class OldStyleClass: + pass +ClassType = type(OldStyleClass) + +def _is_magic(name): + return '__%s__' % name[2:-2] == name + +def _copy(value): + if type(value) in (dict, list, tuple, set): + return type(value)(value) + return value + + +class Mock(object): + + def __init__(self, spec=None, side_effect=None, return_value=DEFAULT, + name=None, parent=None, wraps=None): + self._parent = parent + self._name = name + if spec is not None and not isinstance(spec, list): + spec = [member for member in dir(spec) if not _is_magic(member)] + + self._methods = spec + self._children = {} + self._return_value = return_value + self.side_effect = side_effect + self._wraps = wraps + + self.reset_mock() + + + def reset_mock(self): + self.called = False + self.call_args = None + self.call_count = 0 + self.call_args_list = [] + self.method_calls = [] + for child in self._children.itervalues(): + child.reset_mock() + if isinstance(self._return_value, Mock): + self._return_value.reset_mock() + + + def __get_return_value(self): + if self._return_value is DEFAULT: + self._return_value = Mock() + return self._return_value + + def __set_return_value(self, value): + self._return_value = value + + return_value = property(__get_return_value, __set_return_value) + + + def __call__(self, *args, **kwargs): + self.called = True + self.call_count += 1 + self.call_args = (args, kwargs) + self.call_args_list.append((args, kwargs)) + + parent = self._parent + name = self._name + while parent is not None: + parent.method_calls.append((name, args, kwargs)) + if parent._parent is None: + break + name = parent._name + '.' + name + parent = parent._parent + + ret_val = DEFAULT + if self.side_effect is not None: + if (isinstance(self.side_effect, Exception) or + isinstance(self.side_effect, (type, ClassType)) and + issubclass(self.side_effect, Exception)): + raise self.side_effect + + ret_val = self.side_effect(*args, **kwargs) + if ret_val is DEFAULT: + ret_val = self.return_value + + if self._wraps is not None and self._return_value is DEFAULT: + return self._wraps(*args, **kwargs) + if ret_val is DEFAULT: + ret_val = self.return_value + return ret_val + + + def __getattr__(self, name): + if self._methods is not None: + if name not in self._methods: + raise AttributeError("Mock object has no attribute '%s'" % name) + elif _is_magic(name): + raise AttributeError(name) + + if name not in self._children: + wraps = None + if self._wraps is not None: + wraps = getattr(self._wraps, name) + self._children[name] = Mock(parent=self, name=name, wraps=wraps) + + return self._children[name] + + + def assert_called_with(self, *args, **kwargs): + assert self.call_args == (args, kwargs), 'Expected: %s\nCalled with: %s' % ((args, kwargs), self.call_args) + + +def _dot_lookup(thing, comp, import_path): + try: + return getattr(thing, comp) + except AttributeError: + __import__(import_path) + return getattr(thing, comp) + + +def _importer(target): + components = target.split('.') + import_path = components.pop(0) + thing = __import__(import_path) + + for comp in components: + import_path += ".%s" % comp + thing = _dot_lookup(thing, comp, import_path) + return thing + + +class _patch(object): + def __init__(self, target, attribute, new, spec, create): + self.target = target + self.attribute = attribute + self.new = new + self.spec = spec + self.create = create + self.has_local = False + + + def __call__(self, func): + if hasattr(func, 'patchings'): + func.patchings.append(self) + return func + + def patched(*args, **keywargs): + # don't use a with here (backwards compatability with 2.5) + extra_args = [] + for patching in patched.patchings: + arg = patching.__enter__() + if patching.new is DEFAULT: + extra_args.append(arg) + args += tuple(extra_args) + try: + return func(*args, **keywargs) + finally: + for patching in getattr(patched, 'patchings', []): + patching.__exit__() + + patched.patchings = [self] + patched.__name__ = func.__name__ + patched.compat_co_firstlineno = getattr(func, "compat_co_firstlineno", + func.func_code.co_firstlineno) + return patched + + + def get_original(self): + target = self.target + name = self.attribute + create = self.create + + original = DEFAULT + if _has_local_attr(target, name): + try: + original = target.__dict__[name] + except AttributeError: + # for instances of classes with slots, they have no __dict__ + original = getattr(target, name) + elif not create and not hasattr(target, name): + raise AttributeError("%s does not have the attribute %r" % (target, name)) + return original + + + def __enter__(self): + new, spec, = self.new, self.spec + original = self.get_original() + if new is DEFAULT: + # XXXX what if original is DEFAULT - shouldn't use it as a spec + inherit = False + if spec == True: + # set spec to the object we are replacing + spec = original + if isinstance(spec, (type, ClassType)): + inherit = True + new = Mock(spec=spec) + if inherit: + new.return_value = Mock(spec=spec) + self.temp_original = original + setattr(self.target, self.attribute, new) + return new + + + def __exit__(self, *_): + if self.temp_original is not DEFAULT: + setattr(self.target, self.attribute, self.temp_original) + else: + delattr(self.target, self.attribute) + del self.temp_original + + +def patch_object(target, attribute, new=DEFAULT, spec=None, create=False): + return _patch(target, attribute, new, spec, create) + + +def patch(target, new=DEFAULT, spec=None, create=False): + try: + target, attribute = target.rsplit('.', 1) + except (TypeError, ValueError): + raise TypeError("Need a valid target to patch. You supplied: %r" % (target,)) + target = _importer(target) + return _patch(target, attribute, new, spec, create) + + + +def _has_local_attr(obj, name): + try: + return name in vars(obj) + except TypeError: + # objects without a __dict__ + return hasattr(obj, name) diff --git a/tests/test_websocket.py b/tests/test_websocket.py index db4638e..ca53253 100644 --- a/tests/test_websocket.py +++ b/tests/test_websocket.py @@ -7,7 +7,7 @@ from StringIO import StringIO from unittest import TestCase from tests.wsgi_test import _TestBase import logging -import mock +from tests import mock import random httplib2 = patcher.import_patched('httplib2')