diff --git a/eventlet/websocket.py b/eventlet/websocket.py index 2976d7e..5cd2447 100644 --- a/eventlet/websocket.py +++ b/eventlet/websocket.py @@ -1,15 +1,64 @@ import collections import errno -from eventlet import wsgi -from eventlet import pools + import eventlet -from eventlet.support import get_errno +from eventlet import semaphore +from eventlet import wsgi from eventlet.green import socket -#from pprint import pformat +from eventlet.support import get_errno + +ACCEPTABLE_CLIENT_ERRORS = set((errno.ECONNRESET, errno.EPIPE)) + +class WebSocketWSGI(object): + """This is a WSGI application that serves up websocket connections. + """ + def __init__(self, handler): + self.handler = handler + + def __call__(self, environ, start_response): + if not (environ.get('HTTP_CONNECTION') == 'Upgrade' and + environ.get('HTTP_UPGRADE') == 'WebSocket'): + # need to check a few more things here for true compliance + start_response('400 Bad Request', [('Connection','close')]) + return [] + + sock = environ['eventlet.input'].get_socket() + ws = WebSocket(sock, environ) + handshake_reply = ("HTTP/1.1 101 Web Socket Protocol Handshake\r\n" + "Upgrade: WebSocket\r\n" + "Connection: Upgrade\r\n" + "WebSocket-Origin: %s\r\n" + "WebSocket-Location: ws://%s%s\r\n\r\n" % ( + environ.get('HTTP_ORIGIN'), + environ.get('HTTP_HOST'), + environ.get('PATH_INFO'))) + sock.sendall(handshake_reply) + try: + self.handler(ws) + except socket.error, e: + if get_errno(e) not in ACCEPTABLE_CLIENT_ERRORS: + raise + # use this undocumented feature of eventlet.wsgi to ensure that it + # doesn't barf on the fact that we didn't call start_response + return wsgi.ALREADY_HANDLED + class WebSocket(object): - """Handles access to the actual socket""" - + """The object representing the server side of a websocket. + + The primary way to interact with a WebSocket object is to call + :meth:`send` and :meth:`wait` in order to pass messages back and + forth with the client. Also available are the following properties: + + path + The path value of the request. This is the same as the WSGI PATH_INFO variable. + protocol + The value of the Websocket-Protocol header. + origin + The value of the 'Origin' header. + environ + The full WSGI environment for this request. + """ def __init__(self, sock, environ): """ :param socket: The eventlet socket @@ -23,7 +72,7 @@ class WebSocket(object): self.environ = environ self._buf = "" self._msgs = collections.deque() - self._sendlock = pools.TokenPool(1) + self._sendlock = semaphore.Semaphore() @staticmethod def pack_message(message): @@ -44,8 +93,8 @@ class WebSocket(object): may contain only part of the rest of the message. NOTE: only understands lengthless messages for now. - Returns an array of messages, and the buffer remainder that didn't contain - any full messages.""" + Returns an array of messages, and the buffer remainder that + didn't contain any full messages.""" msgs = [] end_idx = 0 buf = self._buf @@ -60,22 +109,24 @@ class WebSocket(object): return msgs def send(self, message): - """Send a message to the client""" + """Send a message to the client. *message* should be + convertable to a string; unicode objects should be encodable + as utf-8.""" packed = self.pack_message(message) # if two greenthreads are trying to send at the same time # on the same socket, sendlock prevents interleaving and corruption - t = self._sendlock.get() + self._sendlock.acquire() try: self.socket.sendall(packed) finally: - self._sendlock.put(t) - - def wait(self): - """Waits for an deserializes messages""" + self._sendlock.release() + def wait(self): + """Waits for and deserializes messages. Returns a single + message; the oldest not yet processed.""" while not self._msgs: # no parsed messages, must mean buf needs more data - delta = self.socket.recv(1024) + delta = self.socket.recv(8096) if delta == '': return None self._buf += delta @@ -84,5 +135,8 @@ class WebSocket(object): return self._msgs.popleft() def close(self): + """Forcibly close the websocket; generally it is preferable to + return from the handler method.""" self.socket.shutdown(True) + self.socket.close() diff --git a/examples/websocket.py b/examples/websocket.py index 56f6d1e..9bac0e2 100644 --- a/examples/websocket.py +++ b/examples/websocket.py @@ -1,109 +1,6 @@ -import collections -import errno import eventlet from eventlet import wsgi -from eventlet import pools -from eventlet.support import get_errno - -class WebSocketWSGI(object): - def __init__(self, handler, origin): - self.handler = handler - self.origin = origin - - def verify_client(self, ws): - pass - - def __call__(self, environ, start_response): - if not (environ['HTTP_CONNECTION'] == 'Upgrade' and - environ['HTTP_UPGRADE'] == 'WebSocket'): - # need to check a few more things here for true compliance - start_response('400 Bad Request', [('Connection','close')]) - return [] - - sock = environ['eventlet.input'].get_socket() - ws = WebSocket(sock, - environ.get('HTTP_ORIGIN'), - environ.get('HTTP_WEBSOCKET_PROTOCOL'), - environ.get('PATH_INFO')) - self.verify_client(ws) - handshake_reply = ("HTTP/1.1 101 Web Socket Protocol Handshake\r\n" - "Upgrade: WebSocket\r\n" - "Connection: Upgrade\r\n" - "WebSocket-Origin: %s\r\n" - "WebSocket-Location: ws://%s%s\r\n\r\n" % ( - self.origin, - environ.get('HTTP_HOST'), - ws.path)) - sock.sendall(handshake_reply) - try: - self.handler(ws) - except socket.error, e: - if get_errno(e) != errno.EPIPE: - raise - # use this undocumented feature of eventlet.wsgi to ensure that it - # doesn't barf on the fact that we didn't call start_response - return wsgi.ALREADY_HANDLED - -def parse_messages(buf): - """ Parses for messages in the buffer *buf*. It is assumed that - the buffer contains the start character for a message, but that it - may contain only part of the rest of the message. NOTE: only understands - lengthless messages for now. - - Returns an array of messages, and the buffer remainder that didn't contain - any full messages.""" - msgs = [] - end_idx = 0 - while buf: - assert ord(buf[0]) == 0, "Don't understand how to parse this type of message: %r" % buf - end_idx = buf.find("\xFF") - if end_idx == -1: - break - msgs.append(buf[1:end_idx].decode('utf-8', 'replace')) - buf = buf[end_idx+1:] - return msgs, buf - -def format_message(message): - # TODO support iterable messages - if isinstance(message, unicode): - message = message.encode('utf-8') - elif not isinstance(message, str): - message = str(message) - packed = "\x00%s\xFF" % message - return packed - - -class WebSocket(object): - def __init__(self, sock, origin, protocol, path): - self.sock = sock - self.origin = origin - self.protocol = protocol - self.path = path - self._buf = "" - self._msgs = collections.deque() - self._sendlock = pools.TokenPool(1) - - def send(self, message): - packed = format_message(message) - # if two greenthreads are trying to send at the same time - # on the same socket, sendlock prevents interleaving and corruption - t = self._sendlock.get() - try: - self.sock.sendall(packed) - finally: - self._sendlock.put(t) - - def wait(self): - while not self._msgs: - # no parsed messages, must mean buf needs more data - delta = self.sock.recv(1024) - if delta == '': - return None - self._buf += delta - msgs, self._buf = parse_messages(self._buf) - self._msgs.extend(msgs) - return self._msgs.popleft() - +from eventlet import websocket # demo app import os @@ -122,21 +19,21 @@ def handle(ws): for i in xrange(10000): ws.send("0 %s %s\n" % (i, random.random())) eventlet.sleep(0.1) - -wsapp = WebSocketWSGI(handle, 'http://localhost:7000') + +wsapp = websocket.WebSocketWSGI(handle) def dispatch(environ, start_response): """ This resolves to the web page or the websocket depending on the path.""" - if environ['PATH_INFO'] == '/': + if environ['PATH_INFO'] == '/data': + return wsapp(environ, start_response) + else: start_response('200 OK', [('content-type', 'text/html')]) return [open(os.path.join( os.path.dirname(__file__), 'websocket.html')).read()] - else: - return wsapp(environ, start_response) - if __name__ == "__main__": # run an example app from the command line listener = eventlet.listen(('localhost', 7000)) + print "\nVisit http://localhost:7000/ in your websocket-capable browser.\n" wsgi.server(listener, dispatch) diff --git a/tests/websocket_test.py b/tests/websocket_test.py index 3e9ea59..3fff51d 100644 --- a/tests/websocket_test.py +++ b/tests/websocket_test.py @@ -1,65 +1,37 @@ +import socket +import errno + import eventlet from eventlet.green import urllib2 from eventlet.green import httplib -from eventlet.websocket import WebSocket +from eventlet.websocket import WebSocket, WebSocketWSGI from eventlet import wsgi +from eventlet import event from tests import mock, LimitedTestCase from tests.wsgi_test import _TestBase -class WebSocketWSGI(object): - def __init__(self, handler): - self.handler = handler - - def __call__(self, environ, start_response): - if not (environ.get('HTTP_CONNECTION') == 'Upgrade' and - environ.get('HTTP_UPGRADE') == 'WebSocket'): - # need to check a few more things here for true compliance - start_response('400 Bad Request', [('Connection','close')]) - return [] - - sock = environ['eventlet.input'].get_socket() - ws = WebSocket(sock, environ) - handshake_reply = ("HTTP/1.1 101 Web Socket Protocol Handshake\r\n" - "Upgrade: WebSocket\r\n" - "Connection: Upgrade\r\n" - "WebSocket-Origin: %s\r\n" - "WebSocket-Location: ws://%s%s\r\n\r\n" % ( - environ.get('HTTP_ORIGIN'), - environ.get('HTTP_HOST'), - environ.get('PATH_INFO'))) - sock.sendall(handshake_reply) - try: - self.handler(ws) - except socket.error, e: - if get_errno(e) != errno.EPIPE: - raise - # use this undocumented feature of eventlet.wsgi to ensure that it - # doesn't barf on the fact that we didn't call start_response - return wsgi.ALREADY_HANDLED # demo app def handle(ws): - """ This is the websocket handler function. Note that we - can dispatch based on path in here, too.""" if ws.path == '/echo': while True: m = ws.wait() if m is None: break ws.send(m) - elif ws.path == '/range': for i in xrange(10): ws.send("msg %d" % i) - eventlet.sleep(0.1) - + eventlet.sleep(0.01) + elif ws.path == '/error': + # some random socket error that we shouldn't normally get + raise socket.error(errno.ENOTSOCK) else: ws.close() wsapp = WebSocketWSGI(handle) - class TestWebSocket(_TestBase): TEST_TIMEOUT = 5 @@ -135,7 +107,9 @@ class TestWebSocket(_TestBase): sock.sendall(' end\xff') result = sock.recv(1024) self.assertEqual(result, '\x00start end\xff') + sock.shutdown(socket.SHUT_RDWR) sock.close() + eventlet.sleep(0.01) def test_getting_messages_from_websocket(self): connect = [ @@ -160,6 +134,67 @@ class TestWebSocket(_TestBase): # Last item in msgs is an empty string self.assertEqual(msgs[:-1], ['msg %d' % i for i in range(10)]) + def test_breaking_the_connection(self): + error_detected = [False] + done_with_request = event.Event() + site = self.site + def error_detector(environ, start_response): + try: + try: + return site(environ, start_response) + except: + error_detected[0] = True + raise + finally: + done_with_request.send(True) + self.site = error_detector + self.spawn_server() + connect = [ + "GET /range HTTP/1.1", + "Upgrade: WebSocket", + "Connection: Upgrade", + "Host: localhost:%s" % self.port, + "Origin: http://localhost:%s" % self.port, + "WebSocket-Protocol: ws", + ] + sock = eventlet.connect( + ('localhost', self.port)) + sock.sendall('\r\n'.join(connect) + '\r\n\r\n') + resp = sock.recv(1024) # get the headers + sock.close() # close while the app is running + done_with_request.wait() + self.assert_(not error_detected[0]) + + def test_app_socket_errors(self): + error_detected = [False] + done_with_request = event.Event() + site = self.site + def error_detector(environ, start_response): + try: + try: + return site(environ, start_response) + except: + error_detected[0] = True + raise + finally: + done_with_request.send(True) + self.site = error_detector + self.spawn_server() + connect = [ + "GET /error HTTP/1.1", + "Upgrade: WebSocket", + "Connection: Upgrade", + "Host: localhost:%s" % self.port, + "Origin: http://localhost:%s" % self.port, + "WebSocket-Protocol: ws", + ] + sock = eventlet.connect( + ('localhost', self.port)) + sock.sendall('\r\n'.join(connect) + '\r\n\r\n') + resp = sock.recv(1024) + done_with_request.wait() + self.assert_(error_detected[0]) + class TestWebSocketObject(LimitedTestCase):