130 lines
		
	
	
		
			4.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			130 lines
		
	
	
		
			4.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
try:
 | 
						|
    import unittest2 as unittest
 | 
						|
except ImportError:
 | 
						|
    import unittest # noqa
 | 
						|
 | 
						|
from StringIO import StringIO
 | 
						|
 | 
						|
from mock import Mock, ANY
 | 
						|
 | 
						|
from cassandra.connection import (Connection, PROTOCOL_VERSION,
 | 
						|
                                  HEADER_DIRECTION_TO_CLIENT,
 | 
						|
                                  HEADER_DIRECTION_FROM_CLIENT, ProtocolError)
 | 
						|
from cassandra.decoder import (write_stringmultimap, write_int, write_string,
 | 
						|
                               SupportedMessage)
 | 
						|
from cassandra.marshal import uint8_pack, uint32_pack
 | 
						|
 | 
						|
class ConnectionTest(unittest.TestCase):
 | 
						|
 | 
						|
    def make_connection(self):
        """Build a Connection to a dummy address whose socket is a Mock.

        The mocked socket's ``send`` returns the length of whatever it is
        given, i.e. it reports every payload as fully written.
        """
        conn = Connection('1.2.3.4')
        fake_socket = Mock()
        fake_socket.send.side_effect = lambda payload: len(payload)
        conn._socket = fake_socket
        return conn
 | 
						|
 | 
						|
    def make_header_prefix(self, message_class, version=PROTOCOL_VERSION, stream_id=0):
        """Return the frame-header prefix for a server-to-client message.

        The prefix is four ``uint8_pack``-ed fields, in order:
        the TO_CLIENT direction bit OR'd with ``version`` (masked to 0xff),
        a zero flags byte (no compression), ``stream_id``, and
        ``message_class.opcode``.
        """
        version_byte = (HEADER_DIRECTION_TO_CLIENT | version) & 0xff
        fields = (version_byte, 0, stream_id, message_class.opcode)
        return ''.join(uint8_pack(field) for field in fields)
 | 
						|
 | 
						|
    def make_options_body(self):
        """Return a serialized options body for a SupportedMessage.

        It advertises CQL version '3.0.1' and an empty COMPRESSION list,
        encoded with ``write_stringmultimap``.
        """
        supported = {
            'CQL_VERSION': ['3.0.1'],
            'COMPRESSION': [],
        }
        out = StringIO()
        write_stringmultimap(out, supported)
        return out.getvalue()
 | 
						|
 | 
						|
    def make_error_body(self, code, msg):
        """Return an error body: ``code`` via ``write_int``, then ``msg`` via ``write_string``."""
        out = StringIO()
        for writer, value in ((write_int, code), (write_string, msg)):
            writer(out, value)
        return out.getvalue()
 | 
						|
 | 
						|
    def make_msg(self, header, body=""):
        """Frame ``body`` behind ``header`` with a ``uint32_pack``-ed length prefix."""
        length_prefix = uint32_pack(len(body))
        framed = header + length_prefix
        return framed + body
 | 
						|
 | 
						|
    def test_bad_protocol_version(self, *args):
        """A response whose header carries protocol version 0x04 must defunct the connection."""
        conn = self.make_connection()
        conn._id_queue.get_nowait()
        conn._callbacks = Mock()
        conn.defunct = Mock()

        # Feed in a SupportedMessage frame whose header claims version 0x04
        # (assumed here to differ from the supported PROTOCOL_VERSION).
        bad_header = self.make_header_prefix(SupportedMessage, version=0x04)
        frame = self.make_msg(bad_header, self.make_options_body())
        # The body length excludes the 8 bytes in front of the body
        # (header prefix + uint32 length).
        conn.process_msg(frame, len(frame) - 8)

        # The connection must be defuncted exactly once, with a ProtocolError.
        conn.defunct.assert_called_once_with(ANY)
        call_args, _ = conn.defunct.call_args
        self.assertIsInstance(call_args[0], ProtocolError)
 | 
						|
 | 
						|
    def test_bad_header_direction(self, *args):
        """A response frame flagged as client-to-server must defunct the connection."""
        conn = self.make_connection()
        conn._id_queue.get_nowait()
        conn._callbacks = Mock()
        conn.defunct = Mock()

        # Hand-build the header so the direction bit says FROM_CLIENT;
        # make_header_prefix() always sets HEADER_DIRECTION_TO_CLIENT.
        header_fields = [
            0xff & (HEADER_DIRECTION_FROM_CLIENT | PROTOCOL_VERSION),
            0,  # flags (compression)
            0,  # stream id
            SupportedMessage.opcode,  # opcode
        ]
        wrong_way_header = ''.join([uint8_pack(field) for field in header_fields])
        frame = self.make_msg(wrong_way_header, self.make_options_body())
        conn.process_msg(frame, len(frame) - 8)

        # The connection must be defuncted exactly once, with a ProtocolError.
        conn.defunct.assert_called_once_with(ANY)
        call_args, _ = conn.defunct.call_args
        self.assertIsInstance(call_args[0], ProtocolError)
 | 
						|
 | 
						|
    def test_negative_body_length(self, *args):
        """A negative body length passed to process_msg must defunct the connection.

        The header deliberately uses the default PROTOCOL_VERSION and the
        TO_CLIENT direction so the negative length is the only thing wrong
        with the frame. The previous version=0x04 header is the same input
        test_bad_protocol_version uses to provoke a ProtocolError. With it,
        this test could pass without ever exercising body-length validation.
        """
        c = self.make_connection()
        c._id_queue.get_nowait()
        c._callbacks = Mock()
        c.defunct = Mock()

        # read in an otherwise well-formed SupportedMessage response
        header = self.make_header_prefix(SupportedMessage)
        options = self.make_options_body()
        message = self.make_msg(header, options)
        # -13 is an arbitrary negative body length
        c.process_msg(message, -13)

        # make sure it errored correctly
        c.defunct.assert_called_once_with(ANY)
        args, kwargs = c.defunct.call_args
        self.assertIsInstance(args[0], ProtocolError)
 | 
						|
 | 
						|
    def test_unsupported_cql_version(self, *args):
        """A server offering only an unrequested CQL version must defunct the connection.

        The connection asks for CQL "3.0.3" while the SupportedMessage body
        offers only "7.8.9". The real ``_handle_options_response`` is
        registered under stream id 0 (the default stream id of
        make_header_prefix), presumably so process_msg dispatches the frame
        to it.
        """
        conn = self.make_connection()
        conn._id_queue.get_nowait()
        conn._callbacks = {0: conn._handle_options_response}
        conn.defunct = Mock()
        conn.cql_version = "3.0.3"

        header = self.make_header_prefix(SupportedMessage)

        # Options body advertising a CQL version the client did not ask for.
        body_buf = StringIO()
        advertised = {'CQL_VERSION': ['7.8.9'], 'COMPRESSION': []}
        write_stringmultimap(body_buf, advertised)

        frame = self.make_msg(header, body_buf.getvalue())
        conn.process_msg(frame, len(frame) - 8)

        # The connection must be defuncted exactly once, with a ProtocolError.
        conn.defunct.assert_called_once_with(ANY)
        call_args, _ = conn.defunct.call_args
        self.assertIsInstance(call_args[0], ProtocolError)
 |