Files
deb-python-cassandra-driver/tests/integration/standard/test_connection.py
Tyler Hobbs 567655c6b0 Merge branch 'master' into 2.0
Conflicts:
	cassandra/cluster.py
	cassandra/io/asyncorereactor.py
	cassandra/io/libevreactor.py
	tests/integration/__init__.py
	tests/integration/standard/test_cluster.py
	tests/integration/standard/test_connection.py
	tests/unit/test_connection.py
2014-04-02 16:17:36 -05:00

193 lines
6.1 KiB
Python

from tests.integration import PROTOCOL_VERSION
try:
import unittest2 as unittest
except ImportError:
import unittest # noqa
from functools import partial
import sys
from threading import Thread, Event
from cassandra import ConsistencyLevel
from cassandra.decoder import QueryMessage
from cassandra.io.asyncorereactor import AsyncoreConnection
try:
from cassandra.io.libevreactor import LibevConnection
except ImportError:
LibevConnection = None
class ConnectionTest(object):
    """
    Integration tests shared by every connection implementation.

    This is a mixin: concrete subclasses also derive from
    ``unittest.TestCase`` and set ``klass`` to the connection class under
    test. Every test opens its connections with ``klass.factory`` and
    issues the same small ``system.schema_keyspaces`` query.
    """

    # Connection class under test; set by each concrete subclass.
    klass = None

    # Seconds to wait for a test's callbacks to signal completion. Without
    # a bound, a lost response (or an exception raised inside a callback
    # before it sets its event) would hang the whole run instead of failing
    # this test.
    wait_timeout = 30

    query = "SELECT keyspace_name FROM system.schema_keyspaces LIMIT 1"

    def _new_query_message(self):
        """Build a fresh ``QueryMessage`` for ``self.query`` at CL ONE."""
        return QueryMessage(query=self.query,
                            consistency_level=ConsistencyLevel.ONE)

    def _wait_for(self, event):
        """
        Block until ``event`` is set, failing the test if ``wait_timeout``
        seconds elapse first.
        """
        # Inspect is_set() instead of wait()'s return value: on Python 2.6
        # Event.wait() always returns None.
        event.wait(self.wait_timeout)
        if not event.is_set():
            self.fail("Timed out after %s seconds waiting for query responses"
                      % (self.wait_timeout,))

    def test_single_connection(self):
        """
        Test a single connection with sequential requests.
        """
        conn = self.klass.factory(protocol_version=PROTOCOL_VERSION)
        event = Event()

        def cb(count, *args, **kwargs):
            # Each response triggers the next request, so only one request
            # is outstanding at a time.
            count += 1
            if count >= 10:
                conn.close()
                event.set()
            else:
                conn.send_msg(self._new_query_message(),
                              cb=partial(cb, count))

        conn.send_msg(self._new_query_message(), cb=partial(cb, 0))
        self._wait_for(event)

    def test_single_connection_pipelined_requests(self):
        """
        Test a single connection with pipelined requests.
        """
        num_requests = 100
        conn = self.klass.factory(protocol_version=PROTOCOL_VERSION)
        responses = [False] * num_requests
        event = Event()

        def cb(response_list, request_num, *args, **kwargs):
            response_list[request_num] = True
            if all(response_list):
                conn.close()
                event.set()

        # Send every request up front without waiting for any response.
        for i in range(num_requests):
            conn.send_msg(self._new_query_message(),
                          cb=partial(cb, responses, i))
        self._wait_for(event)

    def test_multiple_connections(self):
        """
        Test multiple connections running concurrently, each issuing
        sequential requests.
        """
        num_conns = 5
        conns = [self.klass.factory(protocol_version=PROTOCOL_VERSION)
                 for _ in range(num_conns)]
        events = [Event() for _ in range(num_conns)]

        def cb(event, conn, count, *args, **kwargs):
            count += 1
            if count >= 10:
                conn.close()
                event.set()
            else:
                conn.send_msg(self._new_query_message(),
                              cb=partial(cb, event, conn, count))

        for event, conn in zip(events, conns):
            conn.send_msg(self._new_query_message(),
                          cb=partial(cb, event, conn, 0))
        for event in events:
            self._wait_for(event)

    def test_multiple_threads_shared_connection(self):
        """
        Test sharing a single connection across multiple threads,
        which will result in pipelined requests.
        """
        num_requests_per_thread = 25
        num_threads = 5
        event = Event()
        conn = self.klass.factory(protocol_version=PROTOCOL_VERSION)

        def cb(all_responses, thread_responses, request_num, *args, **kwargs):
            thread_responses[request_num] = True
            if all(map(all, all_responses)):
                conn.close()
                event.set()

        def send_msgs(all_responses, thread_responses):
            for i in range(num_requests_per_thread):
                conn.send_msg(
                    self._new_query_message(),
                    cb=partial(cb, all_responses, thread_responses, i))

        # Every per-thread list is registered before any thread starts, so a
        # callback can never see a partially-built all_responses and finish
        # early.
        all_responses = []
        threads = []
        for _ in range(num_threads):
            thread_responses = [False] * num_requests_per_thread
            all_responses.append(thread_responses)
            threads.append(Thread(target=send_msgs,
                                  args=(all_responses, thread_responses)))
        for t in threads:
            t.start()
        for t in threads:
            t.join()
        self._wait_for(event)

    def test_multiple_threads_multiple_connections(self):
        """
        Test several threads, each with their own Connection and pipelined
        requests.
        """
        num_requests_per_conn = 25
        num_conns = 5
        # One completion event per connection, sized from num_conns so the
        # two counts cannot drift apart.
        events = [Event() for _ in range(num_conns)]

        def cb(conn, event, thread_responses, request_num, *args, **kwargs):
            thread_responses[request_num] = True
            if all(thread_responses):
                conn.close()
                event.set()

        def send_msgs(conn, event):
            thread_responses = [False] * num_requests_per_conn
            for i in range(num_requests_per_conn):
                conn.send_msg(
                    self._new_query_message(),
                    cb=partial(cb, conn, event, thread_responses, i))
            # Bounded wait only; the outcome is asserted in the main thread,
            # because an assertion raised in this worker thread would be
            # printed rather than reported as a test failure.
            event.wait(self.wait_timeout)

        threads = []
        for event in events:
            conn = self.klass.factory(protocol_version=PROTOCOL_VERSION)
            threads.append(Thread(target=send_msgs, args=(conn, event)))
        for t in threads:
            t.start()
        for t in threads:
            t.join()
        for event in events:
            self.assertTrue(event.is_set(),
                            "Timed out waiting for query responses")
class AsyncoreConnectionTest(ConnectionTest, unittest.TestCase):
    """Run the shared connection tests against the asyncore reactor."""

    klass = AsyncoreConnection

    def setUp(self):
        # Skip when gevent's monkey patching has been loaded into the process
        # (detected by 'gevent.monkey' being present in sys.modules).
        if 'gevent.monkey' in sys.modules:
            raise unittest.SkipTest("Can't test asyncore with gevent monkey patching")
class LibevConnectionTest(ConnectionTest, unittest.TestCase):
    """Run the shared connection tests against the libev reactor."""

    klass = LibevConnection

    def setUp(self):
        # The gevent check runs first, so a monkey-patched run reports that
        # reason even when libev is also unavailable.
        gevent_patched = 'gevent.monkey' in sys.modules
        if gevent_patched:
            raise unittest.SkipTest("Can't test libev with gevent monkey patching")

        # LibevConnection is None when its import failed at module load.
        libev_missing = LibevConnection is None
        if libev_missing:
            raise unittest.SkipTest('libev does not appear to be installed properly')