diff --git a/tests/integration/standard/test_connection.py b/tests/integration/standard/test_connection.py index 1b29c528..f7acae67 100644 --- a/tests/integration/standard/test_connection.py +++ b/tests/integration/standard/test_connection.py @@ -28,10 +28,10 @@ from cassandra.cluster import NoHostAvailable, Cluster from cassandra.io.asyncorereactor import AsyncoreConnection from cassandra.protocol import QueryMessage from cassandra.connection import Connection -from cassandra.policies import WhiteListRoundRobinPolicy +from cassandra.policies import WhiteListRoundRobinPolicy, HostStateListener from tests import is_monkey_patched -from tests.integration import use_singledc, PROTOCOL_VERSION +from tests.integration import use_singledc, PROTOCOL_VERSION, get_node try: from cassandra.io.libevreactor import LibevConnection @@ -78,6 +78,91 @@ class ConnectionTimeoutTest(unittest.TestCase): future.result() +class TestHostListener(HostStateListener): + host_down = None + + def on_down(self, host): + host_down = host + + +class HeartbeatTest(unittest.TestCase): + """ + Test to validate failing a heartbeat check doesn't mark a host as down + + @since 3.3 + @jira_ticket PYTHON-286 + @expected_result host should not be marked down when heartbeat fails + + @test_category connection heartbeat + """ + + def setUp(self): + self.cluster = Cluster(protocol_version=PROTOCOL_VERSION, idle_heartbeat_interval=1) + self.session = self.cluster.connect() + + def tearDown(self): + self.cluster.shutdown() + + def test_heart_beat_timeout(self): + # Setup a host listener to ensure the nodes don't go down + test_listener = TestHostListener() + host = "127.0.0.1" + node = get_node(1) + initial_connections = self.fetch_connections(host, self.cluster) + self.assertNotEqual(len(initial_connections), 0) + self.cluster.register_listener(test_listener) + # Pause the node + node.pause() + # Wait for connections associated with this host go away + self.wait_for_no_connections(host, self.cluster) + # Resume paused node + node.resume() + # Run a query to ensure connections are re-established + current_host = "" + count = 0 + while current_host != host and count < 100: + rs = self.session.execute_async("SELECT * FROM system.local", trace=False) + rs.result() + current_host = str(rs._current_host) + count += 1 + self.assertLess(count, 100, "Never connected to the first node") + new_connections = self.wait_for_connections(host, self.cluster) + self.assertIsNone(test_listener.host_down) + # Make sure underlying new connections don't match previous ones + for connection in initial_connections: + self.assertFalse(connection in new_connections) + + def fetch_connections(self, host, cluster): + # Given a cluster object and host grab all connection associated with that host + connections = [] + holders = cluster.get_connection_holders() + for conn in holders: + if host == str(getattr(conn, 'host', '')): + if conn._connection is not None: + connections.append(conn._connection) + return connections + + def wait_for_connections(self, host, cluster): + retry = 0 + while(retry < 300): + retry += 1 + connections = self.fetch_connections(host, cluster) + if len(connections) is not 0: + return connections + time.sleep(.1) + self.fail("No new connections found") + + def wait_for_no_connections(self, host, cluster): + retry = 0 + while(retry < 100): + retry += 1 + connections = self.fetch_connections(host, cluster) + if len(connections) is 0: + return + time.sleep(.1) + self.fail("Connections never cleared") + + class ConnectionTests(object): klass = None