Files
deb-python-cassandra-driver/tests/unit/test_connection.py
Tyler Hobbs 80715767a6 Merge branch 'master' into 2.0
Conflicts:
	cassandra/connection.py
	requirements.txt
	setup.py
	tests/unit/test_connection.py
	tox.ini
2014-05-07 17:06:13 -05:00

291 lines
10 KiB
Python

# Copyright 2013-2014 DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import six
try:
import unittest2 as unittest
except ImportError:
import unittest # noqa
from six import BytesIO
from mock import Mock, ANY
from cassandra.cluster import Cluster
from cassandra.connection import (Connection, HEADER_DIRECTION_TO_CLIENT,
HEADER_DIRECTION_FROM_CLIENT, ProtocolError,
locally_supported_compressions)
from cassandra.marshal import uint8_pack, uint32_pack
from cassandra.protocol import (write_stringmultimap, write_int, write_string,
SupportedMessage)
class ConnectionTest(unittest.TestCase):
protocol_version = 2
def make_connection(self):
c = Connection('1.2.3.4')
c._socket = Mock()
c._socket.send.side_effect = lambda x: len(x)
return c
def make_header_prefix(self, message_class, version=2, stream_id=0):
return six.binary_type().join(map(uint8_pack, [
0xff & (HEADER_DIRECTION_TO_CLIENT | version),
0, # flags (compression)
stream_id,
message_class.opcode # opcode
]))
def make_options_body(self):
options_buf = BytesIO()
write_stringmultimap(options_buf, {
'CQL_VERSION': ['3.0.1'],
'COMPRESSION': []
})
return options_buf.getvalue()
def make_error_body(self, code, msg):
buf = BytesIO()
write_int(buf, code)
write_string(buf, msg)
return buf.getvalue()
def make_msg(self, header, body=""):
return header + uint32_pack(len(body)) + body
def test_bad_protocol_version(self, *args):
c = self.make_connection()
c._id_queue.get_nowait()
c._callbacks = Mock()
c.defunct = Mock()
# read in a SupportedMessage response
header = self.make_header_prefix(SupportedMessage, version=0x04)
options = self.make_options_body()
message = self.make_msg(header, options)
c.process_msg(message, len(message) - 8)
# make sure it errored correctly
c.defunct.assert_called_once_with(ANY)
args, kwargs = c.defunct.call_args
self.assertIsInstance(args[0], ProtocolError)
def test_bad_header_direction(self, *args):
c = self.make_connection()
c._id_queue.get_nowait()
c._callbacks = Mock()
c.defunct = Mock()
# read in a SupportedMessage response
header = six.binary_type().join(uint8_pack(i) for i in (
0xff & (HEADER_DIRECTION_FROM_CLIENT | self.protocol_version),
0, # flags (compression)
0,
SupportedMessage.opcode # opcode
))
options = self.make_options_body()
message = self.make_msg(header, options)
c.process_msg(message, len(message) - 8)
# make sure it errored correctly
c.defunct.assert_called_once_with(ANY)
args, kwargs = c.defunct.call_args
self.assertIsInstance(args[0], ProtocolError)
def test_negative_body_length(self, *args):
c = self.make_connection()
c._id_queue.get_nowait()
c._callbacks = Mock()
c.defunct = Mock()
# read in a SupportedMessage response
header = self.make_header_prefix(SupportedMessage)
options = self.make_options_body()
message = self.make_msg(header, options)
c.process_msg(message, -13)
# make sure it errored correctly
c.defunct.assert_called_once_with(ANY)
args, kwargs = c.defunct.call_args
self.assertIsInstance(args[0], ProtocolError)
def test_unsupported_cql_version(self, *args):
c = self.make_connection()
c._id_queue.get_nowait()
c._callbacks = {0: c._handle_options_response}
c.defunct = Mock()
c.cql_version = "3.0.3"
# read in a SupportedMessage response
header = self.make_header_prefix(SupportedMessage)
options_buf = BytesIO()
write_stringmultimap(options_buf, {
'CQL_VERSION': ['7.8.9'],
'COMPRESSION': []
})
options = options_buf.getvalue()
message = self.make_msg(header, options)
c.process_msg(message, len(message) - 8)
# make sure it errored correctly
c.defunct.assert_called_once_with(ANY)
args, kwargs = c.defunct.call_args
self.assertIsInstance(args[0], ProtocolError)
def test_prefer_lz4_compression(self, *args):
c = self.make_connection()
c._id_queue.get_nowait()
c._callbacks = {0: c._handle_options_response}
c.defunct = Mock()
c.cql_version = "3.0.3"
locally_supported_compressions.pop('lz4', None)
locally_supported_compressions.pop('snappy', None)
locally_supported_compressions['lz4'] = ('lz4compress', 'lz4decompress')
locally_supported_compressions['snappy'] = ('snappycompress', 'snappydecompress')
# read in a SupportedMessage response
header = self.make_header_prefix(SupportedMessage)
options_buf = BytesIO()
write_stringmultimap(options_buf, {
'CQL_VERSION': ['3.0.3'],
'COMPRESSION': ['snappy', 'lz4']
})
options = options_buf.getvalue()
message = self.make_msg(header, options)
c.process_msg(message, len(message) - 8)
self.assertEqual(c.decompressor, locally_supported_compressions['lz4'][1])
def test_requested_compression_not_available(self, *args):
c = self.make_connection()
c._id_queue.get_nowait()
c._callbacks = {0: c._handle_options_response}
c.defunct = Mock()
# request lz4 compression
c.compression = "lz4"
locally_supported_compressions.pop('lz4', None)
locally_supported_compressions.pop('snappy', None)
locally_supported_compressions['lz4'] = ('lz4compress', 'lz4decompress')
locally_supported_compressions['snappy'] = ('snappycompress', 'snappydecompress')
# read in a SupportedMessage response
header = self.make_header_prefix(SupportedMessage)
# the server only supports snappy
options_buf = BytesIO()
write_stringmultimap(options_buf, {
'CQL_VERSION': ['3.0.3'],
'COMPRESSION': ['snappy']
})
options = options_buf.getvalue()
message = self.make_msg(header, options)
c.process_msg(message, len(message) - 8)
# make sure it errored correctly
c.defunct.assert_called_once_with(ANY)
args, kwargs = c.defunct.call_args
self.assertIsInstance(args[0], ProtocolError)
def test_use_requested_compression(self, *args):
c = self.make_connection()
c._id_queue.get_nowait()
c._callbacks = {0: c._handle_options_response}
c.defunct = Mock()
# request snappy compression
c.compression = "snappy"
locally_supported_compressions.pop('lz4', None)
locally_supported_compressions.pop('snappy', None)
locally_supported_compressions['lz4'] = ('lz4compress', 'lz4decompress')
locally_supported_compressions['snappy'] = ('snappycompress', 'snappydecompress')
# read in a SupportedMessage response
header = self.make_header_prefix(SupportedMessage)
# the server only supports snappy
options_buf = BytesIO()
write_stringmultimap(options_buf, {
'CQL_VERSION': ['3.0.3'],
'COMPRESSION': ['snappy', 'lz4']
})
options = options_buf.getvalue()
message = self.make_msg(header, options)
c.process_msg(message, len(message) - 8)
self.assertEqual(c.decompressor, locally_supported_compressions['snappy'][1])
def test_disable_compression(self, *args):
c = self.make_connection()
c._id_queue.get_nowait()
c._callbacks = {0: c._handle_options_response}
c.defunct = Mock()
# disable compression
c.compression = False
locally_supported_compressions.pop('lz4', None)
locally_supported_compressions.pop('snappy', None)
locally_supported_compressions['lz4'] = ('lz4compress', 'lz4decompress')
locally_supported_compressions['snappy'] = ('snappycompress', 'snappydecompress')
# read in a SupportedMessage response
header = self.make_header_prefix(SupportedMessage)
# the server only supports snappy
options_buf = BytesIO()
write_stringmultimap(options_buf, {
'CQL_VERSION': ['3.0.3'],
'COMPRESSION': ['snappy', 'lz4']
})
options = options_buf.getvalue()
message = self.make_msg(header, options)
c.process_msg(message, len(message) - 8)
self.assertEqual(c.decompressor, None)
def test_not_implemented(self):
"""
Ensure the following methods throw NIE's. If not, come back and test them.
"""
c = self.make_connection()
self.assertRaises(NotImplementedError, c.close)
self.assertRaises(NotImplementedError, c.register_watcher, None, None)
self.assertRaises(NotImplementedError, c.register_watchers, None)
def test_set_keyspace_blocking(self):
c = self.make_connection()
self.assertEqual(c.keyspace, None)
c.set_keyspace_blocking(None)
self.assertEqual(c.keyspace, None)
c.keyspace = 'ks'
c.set_keyspace_blocking('ks')
self.assertEqual(c.keyspace, 'ks')
def test_set_connection_class(self):
cluster = Cluster(connection_class='test')
self.assertEqual('test', cluster.connection_class)