Properly handle onConnect if it is an async method

Includes a simple unit-test, and should close #752
This commit is contained in:
meejah 2016-11-25 16:16:11 -07:00
parent 57c2f87d77
commit d069a4c1e4
2 changed files with 46 additions and 3 deletions

View File

@ -32,3 +32,41 @@ class Test(TestCase):
transport = Mock() transport = Mock()
server.connection_made(transport) server.connection_made(transport)
def test_async_on_connect_server(self):
# see also issue 757
async def foo(x):
return x * x
values = []
async def on_connect(req):
x = await foo(42)
values.append(x)
factory = WebSocketServerFactory()
server = factory()
server.onConnect = on_connect
transport = Mock()
server.connection_made(transport)
# need/want to insert real-fake handshake data?
server.data = b"\r\n".join([
b'GET /ws HTTP/1.1',
b'Host: www.example.com',
b'Sec-WebSocket-Version: 13',
b'Origin: http://www.example.com.malicious.com',
b'Sec-WebSocket-Extensions: permessage-deflate',
b'Sec-WebSocket-Key: tXAxWFUqnhi86Ajj7dRY5g==',
b'Connection: keep-alive, Upgrade',
b'Upgrade: websocket',
b'\r\n', # last string doesn't get a \r\n from join()
])
server.processHandshake()
import asyncio
from asyncio.test_utils import run_once
run_once(asyncio.get_event_loop())
self.assertEqual(1, len(values))
self.assertEqual(42 * 42, values[0])

View File

@ -202,12 +202,17 @@ class WebSocketServerProtocol(WebSocketAdapterProtocol, protocol.WebSocketServer
# noinspection PyBroadException # noinspection PyBroadException
try: try:
res = self.onConnect(request) res = self.onConnect(request)
if yields(res):
asyncio.async(res)
except ConnectionDeny as e: except ConnectionDeny as e:
self.failHandshake(e.reason, e.code) self.failHandshake(e.reason, e.code)
except Exception as e: except Exception as e:
self.failHandshake("Internal server error: {}".format(e), ConnectionDeny.INTERNAL_SERVER_ERROR) self.failHandshake("Internal server error: {}".format(e), ConnectionDeny.INTERNAL_SERVER_ERROR)
else:
if yields(res):
# if onConnect was an async method, we need to await
# the actual result before calling succeedHandshake
asyncio.async(res).add_done_callback(
lambda res: self.succeedHandshake(res.result())
)
else: else:
self.succeedHandshake(res) self.succeedHandshake(res)