diff --git a/tests/unit/test_connection.py b/tests/unit/test_connection.py index 15fa6e72..1872bbe0 100644 --- a/tests/unit/test_connection.py +++ b/tests/unit/test_connection.py @@ -22,9 +22,11 @@ from six import BytesIO import time from threading import Lock -from cassandra.cluster import Cluster, Session +from cassandra import OperationTimedOut +from cassandra.cluster import Cluster from cassandra.connection import (Connection, HEADER_DIRECTION_TO_CLIENT, ProtocolError, - locally_supported_compressions, ConnectionHeartbeat, _Frame, Timer, TimerManager) + locally_supported_compressions, ConnectionHeartbeat, _Frame, Timer, TimerManager, + ConnectionException) from cassandra.marshal import uint8_pack, uint32_pack, int32_pack from cassandra.protocol import (write_stringmultimap, write_int, write_string, SupportedMessage, ProtocolHandler) @@ -382,8 +384,8 @@ class ConnectionHeartbeatTest(unittest.TestCase): connection.send_msg.assert_has_calls([call(ANY, request_id, ANY)] * get_holders.call_count) connection.defunct.assert_has_calls([call(ANY)] * get_holders.call_count) exc = connection.defunct.call_args_list[0][0][0] - self.assertIsInstance(exc, Exception) - self.assertEqual(exc.args, Exception('Connection heartbeat failure').args) + self.assertIsInstance(exc, ConnectionException) + self.assertRegex(exc.args[0], r'^Received unexpected response to OptionsMessage.*') holder.return_connection.assert_has_calls([call(connection)] * get_holders.call_count) def test_timeout(self, *args): @@ -410,8 +412,9 @@ class ConnectionHeartbeatTest(unittest.TestCase): connection.send_msg.assert_has_calls([call(ANY, request_id, ANY)] * get_holders.call_count) connection.defunct.assert_has_calls([call(ANY)] * get_holders.call_count) exc = connection.defunct.call_args_list[0][0][0] - self.assertIsInstance(exc, Exception) - self.assertEqual(exc.args, Exception('Connection heartbeat failure').args) + self.assertIsInstance(exc, OperationTimedOut) + self.assertEqual(exc.errors, 'Connection heartbeat timeout after 0.05 seconds') + self.assertEqual(exc.last_host, 'localhost') holder.return_connection.assert_has_calls([call(connection)] * get_holders.call_count)