Properly handle onConnect if it is an async method
Includes a simple unit-test, and should close #752
This commit is contained in:
parent
57c2f87d77
commit
d069a4c1e4
@ -32,3 +32,41 @@ class Test(TestCase):
|
|||||||
transport = Mock()
|
transport = Mock()
|
||||||
|
|
||||||
server.connection_made(transport)
|
server.connection_made(transport)
|
||||||
|
|
||||||
|
def test_async_on_connect_server(self):
|
||||||
|
# see also issue 757
|
||||||
|
|
||||||
|
async def foo(x):
|
||||||
|
return x * x
|
||||||
|
|
||||||
|
values = []
|
||||||
|
async def on_connect(req):
|
||||||
|
x = await foo(42)
|
||||||
|
values.append(x)
|
||||||
|
|
||||||
|
factory = WebSocketServerFactory()
|
||||||
|
server = factory()
|
||||||
|
server.onConnect = on_connect
|
||||||
|
transport = Mock()
|
||||||
|
|
||||||
|
server.connection_made(transport)
|
||||||
|
# need/want to insert real-fake handshake data?
|
||||||
|
server.data = b"\r\n".join([
|
||||||
|
b'GET /ws HTTP/1.1',
|
||||||
|
b'Host: www.example.com',
|
||||||
|
b'Sec-WebSocket-Version: 13',
|
||||||
|
b'Origin: http://www.example.com.malicious.com',
|
||||||
|
b'Sec-WebSocket-Extensions: permessage-deflate',
|
||||||
|
b'Sec-WebSocket-Key: tXAxWFUqnhi86Ajj7dRY5g==',
|
||||||
|
b'Connection: keep-alive, Upgrade',
|
||||||
|
b'Upgrade: websocket',
|
||||||
|
b'\r\n', # last string doesn't get a \r\n from join()
|
||||||
|
])
|
||||||
|
server.processHandshake()
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from asyncio.test_utils import run_once
|
||||||
|
run_once(asyncio.get_event_loop())
|
||||||
|
|
||||||
|
self.assertEqual(1, len(values))
|
||||||
|
self.assertEqual(42 * 42, values[0])
|
||||||
|
@ -202,12 +202,17 @@ class WebSocketServerProtocol(WebSocketAdapterProtocol, protocol.WebSocketServer
|
|||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
res = self.onConnect(request)
|
res = self.onConnect(request)
|
||||||
if yields(res):
|
|
||||||
asyncio.async(res)
|
|
||||||
except ConnectionDeny as e:
|
except ConnectionDeny as e:
|
||||||
self.failHandshake(e.reason, e.code)
|
self.failHandshake(e.reason, e.code)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.failHandshake("Internal server error: {}".format(e), ConnectionDeny.INTERNAL_SERVER_ERROR)
|
self.failHandshake("Internal server error: {}".format(e), ConnectionDeny.INTERNAL_SERVER_ERROR)
|
||||||
|
else:
|
||||||
|
if yields(res):
|
||||||
|
# if onConnect was an async method, we need to await
|
||||||
|
# the actual result before calling succeedHandshake
|
||||||
|
asyncio.async(res).add_done_callback(
|
||||||
|
lambda res: self.succeedHandshake(res.result())
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
self.succeedHandshake(res)
|
self.succeedHandshake(res)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user