Properly handle onConnect if it is an async method
Includes a simple unit-test, and should close #752
This commit is contained in:
parent
57c2f87d77
commit
d069a4c1e4
@ -32,3 +32,41 @@ class Test(TestCase):
|
||||
transport = Mock()
|
||||
|
||||
server.connection_made(transport)
|
||||
|
||||
def test_async_on_connect_server(self):
|
||||
# see also issue 757
|
||||
|
||||
async def foo(x):
|
||||
return x * x
|
||||
|
||||
values = []
|
||||
async def on_connect(req):
|
||||
x = await foo(42)
|
||||
values.append(x)
|
||||
|
||||
factory = WebSocketServerFactory()
|
||||
server = factory()
|
||||
server.onConnect = on_connect
|
||||
transport = Mock()
|
||||
|
||||
server.connection_made(transport)
|
||||
# need/want to insert real-fake handshake data?
|
||||
server.data = b"\r\n".join([
|
||||
b'GET /ws HTTP/1.1',
|
||||
b'Host: www.example.com',
|
||||
b'Sec-WebSocket-Version: 13',
|
||||
b'Origin: http://www.example.com.malicious.com',
|
||||
b'Sec-WebSocket-Extensions: permessage-deflate',
|
||||
b'Sec-WebSocket-Key: tXAxWFUqnhi86Ajj7dRY5g==',
|
||||
b'Connection: keep-alive, Upgrade',
|
||||
b'Upgrade: websocket',
|
||||
b'\r\n', # last string doesn't get a \r\n from join()
|
||||
])
|
||||
server.processHandshake()
|
||||
|
||||
import asyncio
|
||||
from asyncio.test_utils import run_once
|
||||
run_once(asyncio.get_event_loop())
|
||||
|
||||
self.assertEqual(1, len(values))
|
||||
self.assertEqual(42 * 42, values[0])
|
||||
|
@ -202,14 +202,19 @@ class WebSocketServerProtocol(WebSocketAdapterProtocol, protocol.WebSocketServer
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
res = self.onConnect(request)
|
||||
if yields(res):
|
||||
asyncio.async(res)
|
||||
except ConnectionDeny as e:
|
||||
self.failHandshake(e.reason, e.code)
|
||||
except Exception as e:
|
||||
self.failHandshake("Internal server error: {}".format(e), ConnectionDeny.INTERNAL_SERVER_ERROR)
|
||||
else:
|
||||
self.succeedHandshake(res)
|
||||
if yields(res):
|
||||
# if onConnect was an async method, we need to await
|
||||
# the actual result before calling succeedHandshake
|
||||
asyncio.async(res).add_done_callback(
|
||||
lambda res: self.succeedHandshake(res.result())
|
||||
)
|
||||
else:
|
||||
self.succeedHandshake(res)
|
||||
|
||||
|
||||
class WebSocketClientProtocol(WebSocketAdapterProtocol, protocol.WebSocketClientProtocol):
|
||||
|
Loading…
Reference in New Issue
Block a user