401 lines
13 KiB
Python
401 lines
13 KiB
Python
# Copyright 2013-2017 DataStax, Inc.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
try:
|
|
import unittest2 as unittest
|
|
except ImportError:
|
|
import unittest # noqa
|
|
|
|
from functools import partial
|
|
from six.moves import range
|
|
import sys
|
|
from threading import Thread, Event
|
|
import time
|
|
|
|
from cassandra import ConsistencyLevel, OperationTimedOut
|
|
from cassandra.cluster import NoHostAvailable, ConnectionShutdown, Cluster
|
|
from cassandra.io.asyncorereactor import AsyncoreConnection
|
|
from cassandra.protocol import QueryMessage
|
|
from cassandra.connection import Connection
|
|
from cassandra.policies import WhiteListRoundRobinPolicy, HostStateListener
|
|
from cassandra.pool import HostConnectionPool
|
|
|
|
from tests import is_monkey_patched
|
|
from tests.integration import use_singledc, PROTOCOL_VERSION, get_node, CASSANDRA_IP, local
|
|
|
|
try:
|
|
from cassandra.io.libevreactor import LibevConnection
|
|
except ImportError:
|
|
LibevConnection = None
|
|
|
|
|
|
def setup_module():
|
|
use_singledc()
|
|
|
|
|
|
class ConnectionTimeoutTest(unittest.TestCase):
|
|
|
|
def setUp(self):
|
|
self.defaultInFlight = Connection.max_in_flight
|
|
Connection.max_in_flight = 2
|
|
self.cluster = Cluster(protocol_version=PROTOCOL_VERSION, load_balancing_policy=
|
|
WhiteListRoundRobinPolicy([CASSANDRA_IP]))
|
|
self.session = self.cluster.connect()
|
|
|
|
def tearDown(self):
|
|
Connection.max_in_flight = self.defaultInFlight
|
|
self.cluster.shutdown()
|
|
|
|
def test_in_flight_timeout(self):
|
|
"""
|
|
Test to ensure that connection id fetching will block when max_id is reached/
|
|
|
|
In previous versions of the driver this test will cause a
|
|
NoHostAvailable exception to be thrown, when the max_id is restricted
|
|
|
|
@since 3.3
|
|
@jira_ticket PYTHON-514
|
|
@expected_result When many requests are run on a single node connection acquisition should block
|
|
until connection is available or the request times out.
|
|
|
|
@test_category connection timeout
|
|
"""
|
|
futures = []
|
|
query = '''SELECT * FROM system.local'''
|
|
for i in range(100):
|
|
futures.append(self.session.execute_async(query))
|
|
|
|
for future in futures:
|
|
future.result()
|
|
|
|
|
|
class TestHostListener(HostStateListener):
|
|
host_down = None
|
|
|
|
def on_down(self, host):
|
|
host_down = host
|
|
|
|
|
|
class HeartbeatTest(unittest.TestCase):
|
|
"""
|
|
Test to validate failing a heartbeat check doesn't mark a host as down
|
|
|
|
@since 3.3
|
|
@jira_ticket PYTHON-286
|
|
@expected_result host should not be marked down when heartbeat fails
|
|
|
|
@test_category connection heartbeat
|
|
"""
|
|
|
|
def setUp(self):
|
|
self.cluster = Cluster(protocol_version=PROTOCOL_VERSION, idle_heartbeat_interval=1)
|
|
self.session = self.cluster.connect(wait_for_all_pools=True)
|
|
|
|
def tearDown(self):
|
|
self.cluster.shutdown()
|
|
|
|
@local
|
|
def test_heart_beat_timeout(self):
|
|
# Setup a host listener to ensure the nodes don't go down
|
|
test_listener = TestHostListener()
|
|
host = "127.0.0.1"
|
|
node = get_node(1)
|
|
initial_connections = self.fetch_connections(host, self.cluster)
|
|
self.assertNotEqual(len(initial_connections), 0)
|
|
self.cluster.register_listener(test_listener)
|
|
# Pause the node
|
|
try:
|
|
node.pause()
|
|
# Wait for connections associated with this host go away
|
|
self.wait_for_no_connections(host, self.cluster)
|
|
# Resume paused node
|
|
finally:
|
|
node.resume()
|
|
# Run a query to ensure connections are re-established
|
|
current_host = ""
|
|
count = 0
|
|
while current_host != host and count < 100:
|
|
rs = self.session.execute_async("SELECT * FROM system.local", trace=False)
|
|
rs.result()
|
|
current_host = str(rs._current_host)
|
|
count += 1
|
|
time.sleep(.1)
|
|
self.assertLess(count, 100, "Never connected to the first node")
|
|
new_connections = self.wait_for_connections(host, self.cluster)
|
|
self.assertIsNone(test_listener.host_down)
|
|
# Make sure underlying new connections don't match previous ones
|
|
for connection in initial_connections:
|
|
self.assertFalse(connection in new_connections)
|
|
|
|
def fetch_connections(self, host, cluster):
|
|
# Given a cluster object and host grab all connection associated with that host
|
|
connections = []
|
|
holders = cluster.get_connection_holders()
|
|
for conn in holders:
|
|
if host == str(getattr(conn, 'host', '')):
|
|
if isinstance(conn, HostConnectionPool):
|
|
if conn._connections is not None and len(conn._connections) > 0:
|
|
connections.append(conn._connections)
|
|
else:
|
|
if conn._connection is not None:
|
|
connections.append(conn._connection)
|
|
return connections
|
|
|
|
def wait_for_connections(self, host, cluster):
|
|
retry = 0
|
|
while(retry < 300):
|
|
retry += 1
|
|
connections = self.fetch_connections(host, cluster)
|
|
if len(connections) is not 0:
|
|
return connections
|
|
time.sleep(.1)
|
|
self.fail("No new connections found")
|
|
|
|
def wait_for_no_connections(self, host, cluster):
|
|
retry = 0
|
|
while(retry < 100):
|
|
retry += 1
|
|
connections = self.fetch_connections(host, cluster)
|
|
if len(connections) is 0:
|
|
return
|
|
time.sleep(.5)
|
|
self.fail("Connections never cleared")
|
|
|
|
|
|
class ConnectionTests(object):
|
|
|
|
klass = None
|
|
|
|
def setUp(self):
|
|
self.klass.initialize_reactor()
|
|
|
|
def get_connection(self, timeout=5):
|
|
"""
|
|
Helper method to solve automated testing issues within Jenkins.
|
|
Officially patched under the 2.0 branch through
|
|
17998ef72a2fe2e67d27dd602b6ced33a58ad8ef, but left as is for the
|
|
1.0 branch due to possible regressions for fixing an
|
|
automated testing edge-case.
|
|
"""
|
|
conn = None
|
|
e = None
|
|
for i in range(5):
|
|
try:
|
|
contact_point = CASSANDRA_IP
|
|
conn = self.klass.factory(host=contact_point, timeout=timeout, protocol_version=PROTOCOL_VERSION)
|
|
break
|
|
except (OperationTimedOut, NoHostAvailable, ConnectionShutdown) as e:
|
|
continue
|
|
|
|
if conn:
|
|
return conn
|
|
else:
|
|
raise e
|
|
|
|
def test_single_connection(self):
|
|
"""
|
|
Test a single connection with sequential requests.
|
|
"""
|
|
conn = self.get_connection()
|
|
query = "SELECT keyspace_name FROM system.schema_keyspaces LIMIT 1"
|
|
event = Event()
|
|
|
|
def cb(count, *args, **kwargs):
|
|
count += 1
|
|
if count >= 10:
|
|
conn.close()
|
|
event.set()
|
|
else:
|
|
conn.send_msg(
|
|
QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE),
|
|
request_id=0,
|
|
cb=partial(cb, count))
|
|
|
|
conn.send_msg(
|
|
QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE),
|
|
request_id=0,
|
|
cb=partial(cb, 0))
|
|
event.wait()
|
|
|
|
def test_single_connection_pipelined_requests(self):
|
|
"""
|
|
Test a single connection with pipelined requests.
|
|
"""
|
|
conn = self.get_connection()
|
|
query = "SELECT keyspace_name FROM system.schema_keyspaces LIMIT 1"
|
|
responses = [False] * 100
|
|
event = Event()
|
|
|
|
def cb(response_list, request_num, *args, **kwargs):
|
|
response_list[request_num] = True
|
|
if all(response_list):
|
|
conn.close()
|
|
event.set()
|
|
|
|
for i in range(100):
|
|
conn.send_msg(
|
|
QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE),
|
|
request_id=i,
|
|
cb=partial(cb, responses, i))
|
|
|
|
event.wait()
|
|
|
|
def test_multiple_connections(self):
|
|
"""
|
|
Test multiple connections with pipelined requests.
|
|
"""
|
|
conns = [self.get_connection() for i in range(5)]
|
|
events = [Event() for i in range(5)]
|
|
query = "SELECT keyspace_name FROM system.schema_keyspaces LIMIT 1"
|
|
|
|
def cb(event, conn, count, *args, **kwargs):
|
|
count += 1
|
|
if count >= 10:
|
|
conn.close()
|
|
event.set()
|
|
else:
|
|
conn.send_msg(
|
|
QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE),
|
|
request_id=count,
|
|
cb=partial(cb, event, conn, count))
|
|
|
|
for event, conn in zip(events, conns):
|
|
conn.send_msg(
|
|
QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE),
|
|
request_id=0,
|
|
cb=partial(cb, event, conn, 0))
|
|
|
|
for event in events:
|
|
event.wait()
|
|
|
|
def test_multiple_threads_shared_connection(self):
|
|
"""
|
|
Test sharing a single connections across multiple threads,
|
|
which will result in pipelined requests.
|
|
"""
|
|
num_requests_per_conn = 25
|
|
num_threads = 5
|
|
event = Event()
|
|
|
|
conn = self.get_connection()
|
|
query = "SELECT keyspace_name FROM system.schema_keyspaces LIMIT 1"
|
|
|
|
def cb(all_responses, thread_responses, request_num, *args, **kwargs):
|
|
thread_responses[request_num] = True
|
|
if all(map(all, all_responses)):
|
|
conn.close()
|
|
event.set()
|
|
|
|
def send_msgs(all_responses, thread_responses):
|
|
for i in range(num_requests_per_conn):
|
|
qmsg = QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE)
|
|
with conn.lock:
|
|
request_id = conn.get_request_id()
|
|
conn.send_msg(qmsg, request_id, cb=partial(cb, all_responses, thread_responses, i))
|
|
|
|
all_responses = []
|
|
threads = []
|
|
for i in range(num_threads):
|
|
thread_responses = [False] * num_requests_per_conn
|
|
all_responses.append(thread_responses)
|
|
t = Thread(target=send_msgs, args=(all_responses, thread_responses))
|
|
threads.append(t)
|
|
|
|
for t in threads:
|
|
t.start()
|
|
|
|
for t in threads:
|
|
t.join()
|
|
|
|
event.wait()
|
|
|
|
def test_multiple_threads_multiple_connections(self):
|
|
"""
|
|
Test several threads, each with their own Connection and pipelined
|
|
requests.
|
|
"""
|
|
num_requests_per_conn = 25
|
|
num_conns = 5
|
|
events = [Event() for i in range(5)]
|
|
|
|
query = "SELECT keyspace_name FROM system.schema_keyspaces LIMIT 1"
|
|
|
|
def cb(conn, event, thread_responses, request_num, *args, **kwargs):
|
|
thread_responses[request_num] = True
|
|
if all(thread_responses):
|
|
conn.close()
|
|
event.set()
|
|
|
|
def send_msgs(conn, event):
|
|
thread_responses = [False] * num_requests_per_conn
|
|
for i in range(num_requests_per_conn):
|
|
qmsg = QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE)
|
|
with conn.lock:
|
|
request_id = conn.get_request_id()
|
|
conn.send_msg(qmsg, request_id, cb=partial(cb, conn, event, thread_responses, i))
|
|
|
|
event.wait()
|
|
|
|
threads = []
|
|
for i in range(num_conns):
|
|
conn = self.get_connection()
|
|
t = Thread(target=send_msgs, args=(conn, events[i]))
|
|
threads.append(t)
|
|
|
|
for t in threads:
|
|
t.start()
|
|
|
|
for t in threads:
|
|
t.join()
|
|
|
|
def test_connect_timeout(self):
|
|
# Underlying socket implementations don't always throw a socket timeout even with min float
|
|
# This can be timing sensitive, added retry to ensure failure occurs if it can
|
|
max_retry_count = 10
|
|
exception_thrown = False
|
|
for i in range(max_retry_count):
|
|
start = time.time()
|
|
try:
|
|
conn = self.get_connection(timeout=sys.float_info.min)
|
|
conn.close()
|
|
except Exception as e:
|
|
end = time.time()
|
|
self.assertAlmostEqual(start, end, 1)
|
|
exception_thrown = True
|
|
break
|
|
self.assertTrue(exception_thrown)
|
|
|
|
|
|
class AsyncoreConnectionTests(ConnectionTests, unittest.TestCase):
|
|
|
|
klass = AsyncoreConnection
|
|
|
|
def setUp(self):
|
|
if is_monkey_patched():
|
|
raise unittest.SkipTest("Can't test asyncore with monkey patching")
|
|
ConnectionTests.setUp(self)
|
|
|
|
|
|
class LibevConnectionTests(ConnectionTests, unittest.TestCase):
|
|
|
|
klass = LibevConnection
|
|
|
|
def setUp(self):
|
|
if is_monkey_patched():
|
|
raise unittest.SkipTest("Can't test libev with monkey patching")
|
|
if LibevConnection is None:
|
|
raise unittest.SkipTest(
|
|
'libev does not appear to be installed properly')
|
|
ConnectionTests.setUp(self)
|