From 5f7e4a9016d5215063b9521b8dfd45b368670c62 Mon Sep 17 00:00:00 2001 From: Tyler Hobbs Date: Tue, 14 Jan 2014 17:56:37 -0600 Subject: [PATCH] Add asyncore unit test --- tests/unit/io/test_asyncorereactor.py | 234 ++++++++++++++++++++++++++ 1 file changed, 234 insertions(+) create mode 100644 tests/unit/io/test_asyncorereactor.py diff --git a/tests/unit/io/test_asyncorereactor.py b/tests/unit/io/test_asyncorereactor.py new file mode 100644 index 00000000..0ed26eda --- /dev/null +++ b/tests/unit/io/test_asyncorereactor.py @@ -0,0 +1,234 @@ +try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + +import errno +from StringIO import StringIO +import socket +from socket import error as socket_error + +from mock import patch, Mock + +from cassandra.connection import (PROTOCOL_VERSION, + HEADER_DIRECTION_TO_CLIENT, + ConnectionException) + +from cassandra.decoder import (write_stringmultimap, write_int, write_string, + SupportedMessage, ReadyMessage, ServerError) +from cassandra.marshal import uint8_pack, uint32_pack + +from cassandra.io.asyncorereactor import AsyncoreConnection + + +class LibevConnectionTest(unittest.TestCase): + + def setUp(self): + self.socket_patcher = patch('socket.socket', spec=socket.socket) + self.mock_socket = self.socket_patcher.start() + self.mock_socket().connect_ex.return_value = 0 + self.mock_socket().getsockopt.return_value = 0 + + def tearDown(self): + self.socket_patcher.stop() + + def make_connection(self): + c = AsyncoreConnection('1.2.3.4') + c.socket = Mock() + c.socket.send.side_effect = lambda x: len(x) + return c + + def make_header_prefix(self, message_class, version=PROTOCOL_VERSION, stream_id=0): + return ''.join(map(uint8_pack, [ + 0xff & (HEADER_DIRECTION_TO_CLIENT | version), + 0, # flags (compression) + stream_id, + message_class.opcode # opcode + ])) + + def make_options_body(self): + options_buf = StringIO() + write_stringmultimap(options_buf, { + 'CQL_VERSION': ['3.0.1'], + 'COMPRESSION': [] + }) + return options_buf.getvalue() + + def make_error_body(self, code, msg): + buf = StringIO() + write_int(buf, code) + write_string(buf, msg) + return buf.getvalue() + + def make_msg(self, header, body=""): + return header + uint32_pack(len(body)) + body + + def test_successful_connection(self, *args): + c = self.make_connection() + + # let it write the OptionsMessage + c.handle_write() + + # read in a SupportedMessage response + header = self.make_header_prefix(SupportedMessage) + options = self.make_options_body() + c.socket.recv.return_value = self.make_msg(header, options) + c.handle_read() + + # let it write out a StartupMessage + c.handle_write() + + header = self.make_header_prefix(ReadyMessage, stream_id=1) + c.socket.recv.return_value = self.make_msg(header) + c.handle_read() + + self.assertTrue(c.connected_event.is_set()) + + def test_protocol_error(self, *args): + c = self.make_connection() + + # let it write the OptionsMessage + c.handle_write() + + # read in a SupportedMessage response + header = self.make_header_prefix(SupportedMessage, version=0xa4) + options = self.make_options_body() + c.socket.recv.return_value = self.make_msg(header, options) + c.handle_read() + + # make sure it errored correctly + self.assertTrue(c.is_defunct) + self.assertTrue(c.connected_event.is_set()) + self.assertIsInstance(c.last_error, ConnectionException) + + def test_error_message_on_startup(self, *args): + c = self.make_connection() + + # let it write the OptionsMessage + c.handle_write() + + # read in a SupportedMessage response + header = self.make_header_prefix(SupportedMessage) + options = self.make_options_body() + c.socket.recv.return_value = self.make_msg(header, options) + c.handle_read() + + # let it write out a StartupMessage + c.handle_write() + + header = self.make_header_prefix(ServerError, stream_id=1) + body = self.make_error_body(ServerError.error_code, ServerError.summary) + c.socket.recv.return_value = self.make_msg(header, body) + c.handle_read() + + # make sure it errored correctly + self.assertTrue(c.is_defunct) + self.assertIsInstance(c.last_error, ConnectionException) + self.assertTrue(c.connected_event.is_set()) + + def test_socket_error_on_write(self, *args): + c = self.make_connection() + + # make the OptionsMessage write fail + c.socket.send.side_effect = socket_error(errno.EIO, "bad stuff!") + c.handle_write() + + # make sure it errored correctly + self.assertTrue(c.is_defunct) + self.assertIsInstance(c.last_error, socket_error) + self.assertTrue(c.connected_event.is_set()) + + def test_blocking_on_write(self, *args): + c = self.make_connection() + + # make the OptionsMessage write block + c.socket.send.side_effect = socket_error(errno.EAGAIN, "socket busy") + c.handle_write() + + self.assertFalse(c.is_defunct) + + # try again with normal behavior + c.socket.send.side_effect = lambda x: len(x) + c.handle_write() + self.assertFalse(c.is_defunct) + self.assertTrue(c.socket.send.call_args is not None) + + def test_partial_send(self, *args): + c = self.make_connection() + + # only write the first four bytes of the OptionsMessage + c.socket.send.side_effect = None + c.socket.send.return_value = 4 + c.handle_write() + + self.assertFalse(c.is_defunct) + self.assertEqual(2, c.socket.send.call_count) + self.assertEqual(4, len(c.socket.send.call_args[0][0])) + + def test_socket_error_on_read(self, *args): + c = self.make_connection() + + # let it write the OptionsMessage + c.handle_write() + + # read in a SupportedMessage response + c.socket.recv.side_effect = socket_error(errno.EIO, "busy socket") + c.handle_read() + + # make sure it errored correctly + self.assertTrue(c.is_defunct) + self.assertIsInstance(c.last_error, socket_error) + self.assertTrue(c.connected_event.is_set()) + + def test_partial_header_read(self, *args): + c = self.make_connection() + + header = self.make_header_prefix(SupportedMessage) + options = self.make_options_body() + message = self.make_msg(header, options) + + # read in the first byte + c.socket.recv.return_value = message[0] + c.handle_read() + self.assertEquals(c._iobuf.getvalue(), message[0]) + + c.socket.recv.return_value = message[1:] + c.handle_read() + self.assertEquals("", c._iobuf.getvalue()) + + # let it write out a StartupMessage + c.handle_write() + + header = self.make_header_prefix(ReadyMessage, stream_id=1) + c.socket.recv.return_value = self.make_msg(header) + c.handle_read() + + self.assertTrue(c.connected_event.is_set()) + self.assertFalse(c.is_defunct) + + def test_partial_message_read(self, *args): + c = self.make_connection() + + header = self.make_header_prefix(SupportedMessage) + options = self.make_options_body() + message = self.make_msg(header, options) + + # read in the first nine bytes + c.socket.recv.return_value = message[:9] + c.handle_read() + self.assertEquals(c._iobuf.getvalue(), message[:9]) + + # ... then read in the rest + c.socket.recv.return_value = message[9:] + c.handle_read() + self.assertEquals("", c._iobuf.getvalue()) + + # let it write out a StartupMessage + c.handle_write() + + header = self.make_header_prefix(ReadyMessage, stream_id=1) + c.socket.recv.return_value = self.make_msg(header) + c.handle_read() + + self.assertTrue(c.connected_event.is_set()) + self.assertFalse(c.is_defunct)