diff --git a/autobahn/asyncio/test/test_asyncio_websocket.py b/autobahn/asyncio/test/test_asyncio_websocket.py index a752d3bb..d5525f2b 100644 --- a/autobahn/asyncio/test/test_asyncio_websocket.py +++ b/autobahn/asyncio/test/test_asyncio_websocket.py @@ -32,3 +32,41 @@ class Test(TestCase): transport = Mock() server.connection_made(transport) + + def test_async_on_connect_server(self): + # see also issue 757 + + async def foo(x): + return x * x + + values = [] + async def on_connect(req): + x = await foo(42) + values.append(x) + + factory = WebSocketServerFactory() + server = factory() + server.onConnect = on_connect + transport = Mock() + + server.connection_made(transport) + # need/want to insert real-fake handshake data? + server.data = b"\r\n".join([ + b'GET /ws HTTP/1.1', + b'Host: www.example.com', + b'Sec-WebSocket-Version: 13', + b'Origin: http://www.example.com.malicious.com', + b'Sec-WebSocket-Extensions: permessage-deflate', + b'Sec-WebSocket-Key: tXAxWFUqnhi86Ajj7dRY5g==', + b'Connection: keep-alive, Upgrade', + b'Upgrade: websocket', + b'\r\n', # last string doesn't get a \r\n from join() + ]) + server.processHandshake() + + import asyncio + from asyncio.test_utils import run_once + run_once(asyncio.get_event_loop()) + + self.assertEqual(1, len(values)) + self.assertEqual(42 * 42, values[0]) diff --git a/autobahn/asyncio/websocket.py b/autobahn/asyncio/websocket.py index f08ac6d9..944f4fe4 100644 --- a/autobahn/asyncio/websocket.py +++ b/autobahn/asyncio/websocket.py @@ -202,14 +202,19 @@ class WebSocketServerProtocol(WebSocketAdapterProtocol, protocol.WebSocketServer # noinspection PyBroadException try: res = self.onConnect(request) - if yields(res): - asyncio.async(res) except ConnectionDeny as e: self.failHandshake(e.reason, e.code) except Exception as e: self.failHandshake("Internal server error: {}".format(e), ConnectionDeny.INTERNAL_SERVER_ERROR) else: - self.succeedHandshake(res) + if yields(res): + # if onConnect was an async method, we need to await + # the actual result before calling succeedHandshake + asyncio.async(res).add_done_callback( + lambda res: self.succeedHandshake(res.result()) + ) + else: + self.succeedHandshake(res) class WebSocketClientProtocol(WebSocketAdapterProtocol, protocol.WebSocketClientProtocol):