416 lines
16 KiB
Python
416 lines
16 KiB
Python
# Copyright 2013-2015 DataStax, Inc.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
try:
|
|
import unittest2 as unittest
|
|
except ImportError:
|
|
import unittest # noqa
|
|
|
|
from mock import Mock, ANY, call, patch
|
|
import six
|
|
from six import BytesIO
|
|
import time
|
|
from threading import Lock
|
|
|
|
from cassandra.cluster import Cluster
|
|
from cassandra.connection import (Connection, HEADER_DIRECTION_TO_CLIENT, ProtocolError,
|
|
locally_supported_compressions, ConnectionHeartbeat, _Frame)
|
|
from cassandra.marshal import uint8_pack, uint32_pack, int32_pack
|
|
from cassandra.protocol import (write_stringmultimap, write_int, write_string,
|
|
SupportedMessage, ProtocolHandler)
|
|
|
|
|
|
class ConnectionTest(unittest.TestCase):
    """Unit tests for ``Connection`` frame parsing and OPTIONS/SUPPORTED handling.

    No network I/O happens: ``make_connection`` replaces the socket with a ``Mock``.
    """

    def make_connection(self):
        """Return a ``Connection`` to '1.2.3.4' whose socket is a ``Mock``."""
        c = Connection('1.2.3.4')
        c._socket = Mock()
        # Report every send() as having written the whole buffer.
        c._socket.send.side_effect = lambda x: len(x)
        return c

    def make_header_prefix(self, message_class, version=Connection.protocol_version, stream_id=0):
        """Build the frame header bytes that precede the 4-byte body length.

        :param message_class: message type whose ``opcode`` goes in the header.
        :param version: value OR'ed with ``HEADER_DIRECTION_TO_CLIENT`` to form
            the first byte; tests may pass a deliberately invalid value (0x7f).
        :param stream_id: stream id written into the header.
        """
        # The header *layout* is chosen from the protocol version the Connection
        # class parses with, not from ``version``, so a bogus version byte can be
        # placed inside an otherwise well-formed frame. From v3 on, the stream id
        # is preceded by an extra (most-significant) byte.
        fields = [0xff & (HEADER_DIRECTION_TO_CLIENT | version),
                  0]  # flags (compression)
        if Connection.protocol_version >= 3:
            fields.append(0)  # MSB for v3+ stream
        fields.extend([stream_id, message_class.opcode])
        return six.binary_type().join(map(uint8_pack, fields))

    def make_options_body(self):
        """Return an encoded SUPPORTED body: CQL 3.0.1 and no compression."""
        return self._supported_body(['3.0.1'], [])

    def make_error_body(self, code, msg):
        """Return an encoded ERROR body: int ``code`` followed by string ``msg``."""
        buf = BytesIO()
        write_int(buf, code)
        write_string(buf, msg)
        return buf.getvalue()

    def make_msg(self, header, body=b''):
        """Return a full frame: ``header``, the 4-byte body length, then ``body``.

        ``body`` defaults to empty *bytes* (not ``""``) so that it can be
        concatenated with the bytes header on Python 3.
        """
        return header + uint32_pack(len(body)) + body

    def _supported_body(self, cql_versions, compressions):
        """Return an encoded SUPPORTED body.

        :param cql_versions: list of strings for the 'CQL_VERSION' key.
        :param compressions: list of strings for the 'COMPRESSION' key.
        """
        options_buf = BytesIO()
        write_stringmultimap(options_buf, {
            'CQL_VERSION': cql_versions,
            'COMPRESSION': compressions,
        })
        return options_buf.getvalue()

    def _supported_frame(self, body):
        """Return the protocol-v4 ``_Frame`` for a SUPPORTED response on stream 0.

        The body is described as starting at offset 9 and running for ``len(body)`` bytes.
        """
        return _Frame(version=4, flags=0, stream=0, opcode=SupportedMessage.opcode,
                      body_offset=9, end_pos=9 + len(body))

    def _expect_options_response(self, c):
        """Route stream 0 of ``c`` to its OPTIONS-response handler and mock ``defunct``."""
        c._requests = {0: (c._handle_options_response, ProtocolHandler.decode_message)}
        c.defunct = Mock()

    def _install_fake_compressions(self):
        """Register placeholder lz4 and snappy codecs for the current test only.

        ``locally_supported_compressions`` is module-level state. It is
        snapshotted with ``patch.dict`` and restored at cleanup, so the string
        placeholders cannot leak into later tests. lz4 is re-inserted before
        snappy (presumably so lz4 comes first if the mapping is ordered).
        """
        patcher = patch.dict(locally_supported_compressions)
        patcher.start()
        self.addCleanup(patcher.stop)
        locally_supported_compressions.pop('lz4', None)
        locally_supported_compressions.pop('snappy', None)
        locally_supported_compressions['lz4'] = ('lz4compress', 'lz4decompress')
        locally_supported_compressions['snappy'] = ('snappycompress', 'snappydecompress')

    def _assert_defunct_with_protocol_error(self, c):
        """Assert ``c.defunct`` was called exactly once, with a ``ProtocolError``."""
        c.defunct.assert_called_once_with(ANY)
        defunct_args, _ = c.defunct.call_args
        self.assertIsInstance(defunct_args[0], ProtocolError)

    def test_bad_protocol_version(self, *args):
        c = self.make_connection()
        c._requests = Mock()
        c.defunct = Mock()

        # read in a SupportedMessage response carrying an invalid version byte
        header = self.make_header_prefix(SupportedMessage, version=0x7f)
        options = self.make_options_body()
        message = self.make_msg(header, options)
        c._iobuf = BytesIO()
        c._iobuf.write(message)
        c.process_io_buffer()

        # make sure it errored correctly
        self._assert_defunct_with_protocol_error(c)

    def test_negative_body_length(self, *args):
        c = self.make_connection()
        c._requests = Mock()
        c.defunct = Mock()

        # read in a SupportedMessage header followed by a negative body length
        header = self.make_header_prefix(SupportedMessage)
        message = header + int32_pack(-13)
        c._iobuf = BytesIO()
        c._iobuf.write(message)
        c.process_io_buffer()

        # make sure it errored correctly
        self._assert_defunct_with_protocol_error(c)

    def test_unsupported_cql_version(self, *args):
        c = self.make_connection()
        self._expect_options_response(c)
        c.cql_version = "3.0.3"

        # the server only advertises a CQL version the client did not request
        options = self._supported_body(['7.8.9'], [])
        c.process_msg(self._supported_frame(options), options)

        # make sure it errored correctly
        self._assert_defunct_with_protocol_error(c)

    def test_prefer_lz4_compression(self, *args):
        c = self.make_connection()
        self._expect_options_response(c)
        c.cql_version = "3.0.3"
        self._install_fake_compressions()

        # the server supports both; with the default compression setting lz4 is expected
        options = self._supported_body(['3.0.3'], ['snappy', 'lz4'])
        c.process_msg(self._supported_frame(options), options)

        self.assertEqual(c.decompressor, locally_supported_compressions['lz4'][1])

    def test_requested_compression_not_available(self, *args):
        c = self.make_connection()
        self._expect_options_response(c)
        # request lz4 compression
        c.compression = "lz4"
        self._install_fake_compressions()

        # the server only supports snappy
        options = self._supported_body(['3.0.3'], ['snappy'])
        c.process_msg(self._supported_frame(options), options)

        # make sure it errored correctly
        self._assert_defunct_with_protocol_error(c)

    def test_use_requested_compression(self, *args):
        c = self.make_connection()
        self._expect_options_response(c)
        # request snappy compression
        c.compression = "snappy"
        self._install_fake_compressions()

        # the server supports both snappy and lz4
        options = self._supported_body(['3.0.3'], ['snappy', 'lz4'])
        c.process_msg(self._supported_frame(options), options)

        self.assertEqual(c.decompressor, locally_supported_compressions['snappy'][1])

    def test_disable_compression(self, *args):
        c = self.make_connection()
        self._expect_options_response(c)
        # disable compression
        c.compression = False
        self._install_fake_compressions()

        # the server supports both snappy and lz4
        options = self._supported_body(['3.0.3'], ['snappy', 'lz4'])
        # process_msg takes a parsed _Frame plus the body, like the sibling tests
        c.process_msg(self._supported_frame(options), options)

        self.assertEqual(c.decompressor, None)

    def test_not_implemented(self):
        """
        Ensure the following methods throw NIE's. If not, come back and test them.
        """
        c = self.make_connection()
        self.assertRaises(NotImplementedError, c.close)

    def test_set_keyspace_blocking(self):
        c = self.make_connection()

        self.assertEqual(c.keyspace, None)
        c.set_keyspace_blocking(None)
        self.assertEqual(c.keyspace, None)

        c.keyspace = 'ks'
        c.set_keyspace_blocking('ks')
        self.assertEqual(c.keyspace, 'ks')

    def test_set_connection_class(self):
        cluster = Cluster(connection_class='test')
        self.assertEqual('test', cluster.connection_class)
|
|
|
|
|
|
@patch('cassandra.connection.ConnectionHeartbeat._raise_if_stopped')
class ConnectionHeartbeatTest(unittest.TestCase):
    """Tests for ``ConnectionHeartbeat`` driven by mocked connection holders.

    The class decorator patches ``ConnectionHeartbeat._raise_if_stopped`` for
    each ``test_*`` method. The resulting mock is passed in positionally and
    absorbed by each test's ``*args``.
    """

    @staticmethod
    def make_get_holders(len):
        """Return a ``Mock`` callable whose return value is ``len`` holder mocks.

        Each holder's ``get_connections()`` returns one list that starts out
        empty, so tests can append connection mocks to it.
        """
        # NOTE: the parameter name shadows the builtin ``len``. It is kept
        # unchanged for compatibility with keyword callers.
        holders = []
        for _ in range(len):
            holder = Mock()
            holder.get_connections = Mock(return_value=[])
            holders.append(holder)
        return Mock(return_value=holders)

    @staticmethod
    def _make_idle_connection(request_id, send_msg):
        """Return an idle, healthy ``Connection`` mock for 'localhost'.

        ``get_request_id`` always returns ``request_id``. ``send_msg`` is a
        ``Mock`` that delegates to the supplied ``send_msg`` callable.
        """
        return Mock(spec=Connection, host='localhost',
                    max_request_id=127,
                    lock=Lock(),
                    in_flight=0, is_idle=True,
                    is_defunct=False, is_closed=False,
                    get_request_id=lambda: request_id,
                    send_msg=Mock(side_effect=send_msg))

    def run_heartbeat(self, get_holders_fun, count=2, interval=0.05):
        """Run a ``ConnectionHeartbeat`` for about ``count`` intervals, then stop it."""
        ch = ConnectionHeartbeat(interval, get_holders_fun)
        try:
            time.sleep(interval * count)
        finally:
            # Stop even if the sleep is interrupted, so the heartbeat
            # (which spins up a thread) does not outlive the test.
            ch.stop()
        self.assertTrue(get_holders_fun.call_count)

    def _assert_heartbeat_failed(self, connection, holder, get_holders, request_id):
        """Assert every heartbeat round on ``connection`` ended in failure.

        Each round must have sent a request with ``request_id``, defuncted the
        connection with the generic heartbeat-failure exception, and returned
        the connection to ``holder``.
        """
        self.assertEqual(connection.in_flight, get_holders.call_count)
        connection.send_msg.assert_has_calls([call(ANY, request_id, ANY)] * get_holders.call_count)
        connection.defunct.assert_has_calls([call(ANY)] * get_holders.call_count)
        exc = connection.defunct.call_args_list[0][0][0]
        self.assertIsInstance(exc, Exception)
        self.assertEqual(exc.args, Exception('Connection heartbeat failure').args)
        holder.return_connection.assert_has_calls([call(connection)] * get_holders.call_count)

    def test_empty_connections(self, *args):
        count = 3
        get_holders = self.make_get_holders(1)

        self.run_heartbeat(get_holders, count)

        self.assertGreaterEqual(get_holders.call_count, count - 1)  # lower bound to account for thread spinup time
        self.assertLessEqual(get_holders.call_count, count)
        holder = get_holders.return_value[0]
        holder.get_connections.assert_has_calls([call()] * get_holders.call_count)

    def test_idle_non_idle(self, *args):
        request_id = 999

        # Mimics connection.send_msg(OptionsMessage(), connection.get_request_id(), self._options_callback)
        # by answering immediately with a SupportedMessage.
        def send_msg(msg, req_id, msg_callback):
            msg_callback(SupportedMessage([], {}))

        idle_connection = self._make_idle_connection(request_id, send_msg)
        non_idle_connection = Mock(spec=Connection, in_flight=0, is_idle=False, is_defunct=False, is_closed=False)

        get_holders = self.make_get_holders(1)
        holder = get_holders.return_value[0]
        holder.get_connections.return_value.append(idle_connection)
        holder.get_connections.return_value.append(non_idle_connection)

        self.run_heartbeat(get_holders)

        holder.get_connections.assert_has_calls([call()] * get_holders.call_count)
        self.assertEqual(idle_connection.in_flight, 0)
        self.assertEqual(non_idle_connection.in_flight, 0)

        idle_connection.send_msg.assert_has_calls([call(ANY, request_id, ANY)] * get_holders.call_count)
        self.assertEqual(non_idle_connection.send_msg.call_count, 0)

    def test_closed_defunct(self, *args):
        get_holders = self.make_get_holders(1)
        closed_connection = Mock(spec=Connection, in_flight=0, is_idle=False, is_defunct=False, is_closed=True)
        defunct_connection = Mock(spec=Connection, in_flight=0, is_idle=False, is_defunct=True, is_closed=False)
        holder = get_holders.return_value[0]
        holder.get_connections.return_value.append(closed_connection)
        holder.get_connections.return_value.append(defunct_connection)

        self.run_heartbeat(get_holders)

        holder.get_connections.assert_has_calls([call()] * get_holders.call_count)
        self.assertEqual(closed_connection.in_flight, 0)
        self.assertEqual(defunct_connection.in_flight, 0)
        self.assertEqual(closed_connection.send_msg.call_count, 0)
        self.assertEqual(defunct_connection.send_msg.call_count, 0)

    def test_no_req_ids(self, *args):
        in_flight = 3

        get_holders = self.make_get_holders(1)
        # in_flight already equals max_request_id (presumably: no free request ids)
        max_connection = Mock(spec=Connection, host='localhost',
                              lock=Lock(),
                              max_request_id=in_flight, in_flight=in_flight,
                              is_idle=True, is_defunct=False, is_closed=False)
        holder = get_holders.return_value[0]
        holder.get_connections.return_value.append(max_connection)

        self.run_heartbeat(get_holders)

        holder.get_connections.assert_has_calls([call()] * get_holders.call_count)
        self.assertEqual(max_connection.in_flight, in_flight)
        self.assertEqual(max_connection.send_msg.call_count, 0)
        max_connection.defunct.assert_has_calls([call(ANY)] * get_holders.call_count)
        holder.return_connection.assert_has_calls([call(max_connection)] * get_holders.call_count)

    def test_unexpected_response(self, *args):
        request_id = 999

        get_holders = self.make_get_holders(1)

        # Answer every heartbeat with an object that is not a valid response.
        def send_msg(msg, req_id, msg_callback):
            msg_callback(object())

        connection = self._make_idle_connection(request_id, send_msg)
        holder = get_holders.return_value[0]
        holder.get_connections.return_value.append(connection)

        self.run_heartbeat(get_holders)

        self._assert_heartbeat_failed(connection, holder, get_holders, request_id)

    def test_timeout(self, *args):
        request_id = 999

        get_holders = self.make_get_holders(1)

        # Never invoke the callback, so every heartbeat goes unanswered.
        def send_msg(msg, req_id, msg_callback):
            pass

        connection = self._make_idle_connection(request_id, send_msg)
        holder = get_holders.return_value[0]
        holder.get_connections.return_value.append(connection)

        self.run_heartbeat(get_holders)

        self._assert_heartbeat_failed(connection, holder, get_holders, request_id)
|