Handle the case that RFP negotiation message arrived early.

The RFP server will send RFP negotiation message to client once
connection is setup. In some case this message may arrive before
it retrieve the response for WebSocket upgrade. Then it will cause
issues for the following RFP verification as no more RFP negotiation
initial message will arrive.
This commit will cache the data if the above case happened. Defined
a new function _recv() which will read from the cached buff if it has
data, otherwise read from the socket.

Change-Id: Icc3f312156b8d2cee6e0500218edf5d8b991ade7
Closes-Bug: #1691185
This commit is contained in:
jianghua wang 2017-05-08 08:05:04 +01:00 committed by Jianghua Wang
parent 9ef109d999
commit d22514a522
2 changed files with 136 additions and 3 deletions

View File

@ -252,16 +252,34 @@ class _WebSocket(object):
def __init__(self, client_socket, url):
"""Contructor for the WebSocket wrapper to the socket."""
self._socket = client_socket
# cached stream for early frames.
self.cached_stream = b''
# Upgrade the HTTP connection to a WebSocket
self._upgrade(url)
def _recv(self, recv_size):
"""Wrapper to receive data from the cached stream or socket."""
if recv_size <= 0:
return None
data_from_cached = b''
data_from_socket = b''
if len(self.cached_stream) > 0:
read_from_cached = min(len(self.cached_stream), recv_size)
data_from_cached += self.cached_stream[:read_from_cached]
self.cached_stream = self.cached_stream[read_from_cached:]
recv_size -= read_from_cached
if recv_size > 0:
data_from_socket = self._socket.recv(recv_size)
return data_from_cached + data_from_socket
def receive_frame(self):
"""Wrapper for receiving data to parse the WebSocket frame format"""
# We need to loop until we either get some bytes back in the frame
# or no data was received (meaning the socket was closed). This is
# done to handle the case where we get back some empty frames
while True:
header = self._socket.recv(2)
header = self._recv(2)
# If we didn't receive any data, just return None
if not header:
return None
@ -270,7 +288,7 @@ class _WebSocket(object):
# that only the 2nd byte contains the length, and since the
# server doesn't do masking, we can just read the data length
if ord_func(header[1]) & 127 > 0:
return self._socket.recv(ord_func(header[1]) & 127)
return self._recv(ord_func(header[1]) & 127)
def send_frame(self, data):
"""Wrapper for sending data to add in the WebSocket frame format."""
@ -318,6 +336,15 @@ class _WebSocket(object):
self._socket.sendall(reqdata.encode('utf8'))
self.response = data = self._socket.recv(4096)
# Loop through & concatenate all of the data in the response body
while data and self.response.find(b'\r\n\r\n') < 0:
end_loc = self.response.find(b'\r\n\r\n')
while data and end_loc < 0:
data = self._socket.recv(4096)
self.response += data
end_loc = self.response.find(b'\r\n\r\n')
if len(self.response) > end_loc + 4:
# In case some frames (e.g. the first RFP negotiation) have
# arrived, cache it for next reading.
self.cached_stream = self.response[end_loc + 4:]
# ensure response ends with '\r\n\r\n'.
self.response = self.response[:end_loc + 4]

View File

@ -0,0 +1,106 @@
# Copyright 2017 Citrix Systems
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from six.moves.urllib import parse as urlparse
import mock
from tempest.common import compute
from tempest.tests import base
class TestCompute(base.TestCase):
def setUp(self):
super(TestCompute, self).setUp()
self.client_sock = mock.Mock()
self.url = urlparse.urlparse("http://www.fake.com:80")
def test_rfp_frame_not_cached(self):
# rfp negotiation frame arrived separately after upgrade
# response, so it's not cached.
RFP_VERSION = b'RFB.003.003\x0a'
rfp_frame_header = b'\x82\x0c'
self.client_sock.recv.side_effect = [
b'fake response start\r\n',
b'fake response end\r\n\r\n',
rfp_frame_header,
RFP_VERSION]
expect_response = b'fake response start\r\nfake response end\r\n\r\n'
webSocket = compute._WebSocket(self.client_sock, self.url)
self.assertEqual(webSocket.response, expect_response)
# no cache
self.assertEqual(webSocket.cached_stream, b'')
self.client_sock.recv.assert_has_calls([mock.call(4096),
mock.call(4096)])
self.client_sock.recv.reset_mock()
recv_version = webSocket.receive_frame()
self.assertEqual(recv_version, RFP_VERSION)
self.client_sock.recv.assert_has_calls([mock.call(2),
mock.call(12)])
def test_rfp_frame_fully_cached(self):
RFP_VERSION = b'RFB.003.003\x0a'
rfp_version_frame = b'\x82\x0c%s' % RFP_VERSION
self.client_sock.recv.side_effect = [
b'fake response start\r\n',
b'fake response end\r\n\r\n%s' % rfp_version_frame]
expect_response = b'fake response start\r\nfake response end\r\n\r\n'
webSocket = compute._WebSocket(self.client_sock, self.url)
self.client_sock.recv.assert_has_calls([mock.call(4096),
mock.call(4096)])
self.assertEqual(webSocket.response, expect_response)
self.assertEqual(webSocket.cached_stream, rfp_version_frame)
self.client_sock.recv.reset_mock()
recv_version = webSocket.receive_frame()
self.client_sock.recv.assert_not_called()
self.assertEqual(recv_version, RFP_VERSION)
# cached_stream should be empty in the end.
self.assertEqual(webSocket.cached_stream, b'')
def test_rfp_frame_partially_cached(self):
RFP_VERSION = b'RFB.003.003\x0a'
rfp_version_frame = b'\x82\x0c%s' % RFP_VERSION
frame_part1 = rfp_version_frame[:6]
frame_part2 = rfp_version_frame[6:]
self.client_sock.recv.side_effect = [
b'fake response start\r\n',
b'fake response end\r\n\r\n%s' % frame_part1,
frame_part2]
expect_response = b'fake response start\r\nfake response end\r\n\r\n'
webSocket = compute._WebSocket(self.client_sock, self.url)
self.client_sock.recv.assert_has_calls([mock.call(4096),
mock.call(4096)])
self.assertEqual(webSocket.response, expect_response)
self.assertEqual(webSocket.cached_stream, frame_part1)
self.client_sock.recv.reset_mock()
recv_version = webSocket.receive_frame()
self.client_sock.recv.assert_called_once_with(len(frame_part2))
self.assertEqual(recv_version, RFP_VERSION)
# cached_stream should be empty in the end.
self.assertEqual(webSocket.cached_stream, b'')