diff --git a/tempest/common/compute.py b/tempest/common/compute.py index 9110c4add9..cb9525bbf2 100644 --- a/tempest/common/compute.py +++ b/tempest/common/compute.py @@ -252,16 +252,34 @@ class _WebSocket(object): def __init__(self, client_socket, url): """Contructor for the WebSocket wrapper to the socket.""" self._socket = client_socket + # cached stream for early frames. + self.cached_stream = b'' # Upgrade the HTTP connection to a WebSocket self._upgrade(url) + def _recv(self, recv_size): + """Wrapper to receive data from the cached stream or socket.""" + if recv_size <= 0: + return None + + data_from_cached = b'' + data_from_socket = b'' + if len(self.cached_stream) > 0: + read_from_cached = min(len(self.cached_stream), recv_size) + data_from_cached += self.cached_stream[:read_from_cached] + self.cached_stream = self.cached_stream[read_from_cached:] + recv_size -= read_from_cached + if recv_size > 0: + data_from_socket = self._socket.recv(recv_size) + return data_from_cached + data_from_socket + def receive_frame(self): """Wrapper for receiving data to parse the WebSocket frame format""" # We need to loop until we either get some bytes back in the frame # or no data was received (meaning the socket was closed). This is # done to handle the case where we get back some empty frames while True: - header = self._socket.recv(2) + header = self._recv(2) # If we didn't receive any data, just return None if not header: return None @@ -270,7 +288,7 @@ class _WebSocket(object): # that only the 2nd byte contains the length, and since the # server doesn't do masking, we can just read the data length if ord_func(header[1]) & 127 > 0: - return self._socket.recv(ord_func(header[1]) & 127) + return self._recv(ord_func(header[1]) & 127) def send_frame(self, data): """Wrapper for sending data to add in the WebSocket frame format.""" @@ -318,6 +336,15 @@ class _WebSocket(object): self._socket.sendall(reqdata.encode('utf8')) self.response = data = self._socket.recv(4096) # Loop through & concatenate all of the data in the response body - while data and self.response.find(b'\r\n\r\n') < 0: + end_loc = self.response.find(b'\r\n\r\n') + while data and end_loc < 0: data = self._socket.recv(4096) self.response += data + end_loc = self.response.find(b'\r\n\r\n') + + if len(self.response) > end_loc + 4: + # In case some frames (e.g. the first RFP negotiation) have + # arrived, cache it for next reading. + self.cached_stream = self.response[end_loc + 4:] + # ensure response ends with '\r\n\r\n'. + self.response = self.response[:end_loc + 4] diff --git a/tempest/tests/common/test_compute.py b/tempest/tests/common/test_compute.py new file mode 100644 index 0000000000..c108be981e --- /dev/null +++ b/tempest/tests/common/test_compute.py @@ -0,0 +1,106 @@ +# Copyright 2017 Citrix Systems +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from six.moves.urllib import parse as urlparse + +import mock + +from tempest.common import compute +from tempest.tests import base + + +class TestCompute(base.TestCase): + def setUp(self): + super(TestCompute, self).setUp() + self.client_sock = mock.Mock() + self.url = urlparse.urlparse("http://www.fake.com:80") + + def test_rfp_frame_not_cached(self): + # rfp negotiation frame arrived separately after upgrade + # response, so it's not cached. + RFP_VERSION = b'RFB.003.003\x0a' + rfp_frame_header = b'\x82\x0c' + + self.client_sock.recv.side_effect = [ + b'fake response start\r\n', + b'fake response end\r\n\r\n', + rfp_frame_header, + RFP_VERSION] + expect_response = b'fake response start\r\nfake response end\r\n\r\n' + + webSocket = compute._WebSocket(self.client_sock, self.url) + + self.assertEqual(webSocket.response, expect_response) + # no cache + self.assertEqual(webSocket.cached_stream, b'') + self.client_sock.recv.assert_has_calls([mock.call(4096), + mock.call(4096)]) + + self.client_sock.recv.reset_mock() + recv_version = webSocket.receive_frame() + + self.assertEqual(recv_version, RFP_VERSION) + self.client_sock.recv.assert_has_calls([mock.call(2), + mock.call(12)]) + + def test_rfp_frame_fully_cached(self): + RFP_VERSION = b'RFB.003.003\x0a' + rfp_version_frame = b'\x82\x0c%s' % RFP_VERSION + + self.client_sock.recv.side_effect = [ + b'fake response start\r\n', + b'fake response end\r\n\r\n%s' % rfp_version_frame] + expect_response = b'fake response start\r\nfake response end\r\n\r\n' + webSocket = compute._WebSocket(self.client_sock, self.url) + + self.client_sock.recv.assert_has_calls([mock.call(4096), + mock.call(4096)]) + self.assertEqual(webSocket.response, expect_response) + self.assertEqual(webSocket.cached_stream, rfp_version_frame) + + self.client_sock.recv.reset_mock() + recv_version = webSocket.receive_frame() + + self.client_sock.recv.assert_not_called() + self.assertEqual(recv_version, RFP_VERSION) + # cached_stream should be empty in the end. + self.assertEqual(webSocket.cached_stream, b'') + + def test_rfp_frame_partially_cached(self): + RFP_VERSION = b'RFB.003.003\x0a' + rfp_version_frame = b'\x82\x0c%s' % RFP_VERSION + frame_part1 = rfp_version_frame[:6] + frame_part2 = rfp_version_frame[6:] + + self.client_sock.recv.side_effect = [ + b'fake response start\r\n', + b'fake response end\r\n\r\n%s' % frame_part1, + frame_part2] + expect_response = b'fake response start\r\nfake response end\r\n\r\n' + webSocket = compute._WebSocket(self.client_sock, self.url) + + self.client_sock.recv.assert_has_calls([mock.call(4096), + mock.call(4096)]) + self.assertEqual(webSocket.response, expect_response) + self.assertEqual(webSocket.cached_stream, frame_part1) + + self.client_sock.recv.reset_mock() + + recv_version = webSocket.receive_frame() + + self.client_sock.recv.assert_called_once_with(len(frame_part2)) + self.assertEqual(recv_version, RFP_VERSION) + # cached_stream should be empty in the end. + self.assertEqual(webSocket.cached_stream, b'')