Files
deb-python-cassandra-driver/tests/integration/standard/test_connection.py
2015-08-17 11:55:49 -05:00

257 lines
8.1 KiB
Python

# Copyright 2013-2015 DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
try:
import unittest2 as unittest
except ImportError:
import unittest # noqa
from functools import partial
from six.moves import range
import sys
from threading import Thread, Event
import time
from cassandra import ConsistencyLevel, OperationTimedOut
from cassandra.cluster import NoHostAvailable
from cassandra.io.asyncorereactor import AsyncoreConnection
from cassandra.protocol import QueryMessage
from tests import is_monkey_patched
from tests.integration import use_singledc, PROTOCOL_VERSION
try:
from cassandra.io.libevreactor import LibevConnection
except ImportError:
LibevConnection = None
def setup_module():
use_singledc()
class ConnectionTests(object):
klass = None
def setUp(self):
self.klass.initialize_reactor()
def get_connection(self, timeout=5):
"""
Helper method to solve automated testing issues within Jenkins.
Officially patched under the 2.0 branch through
17998ef72a2fe2e67d27dd602b6ced33a58ad8ef, but left as is for the
1.0 branch due to possible regressions for fixing an
automated testing edge-case.
"""
conn = None
e = None
for i in range(5):
try:
conn = self.klass.factory(host='127.0.0.1', timeout=timeout, protocol_version=PROTOCOL_VERSION)
break
except (OperationTimedOut, NoHostAvailable) as e:
continue
if conn:
return conn
else:
raise e
def test_single_connection(self):
"""
Test a single connection with sequential requests.
"""
conn = self.get_connection()
query = "SELECT keyspace_name FROM system.schema_keyspaces LIMIT 1"
event = Event()
def cb(count, *args, **kwargs):
count += 1
if count >= 10:
conn.close()
event.set()
else:
conn.send_msg(
QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE),
request_id=0,
cb=partial(cb, count))
conn.send_msg(
QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE),
request_id=0,
cb=partial(cb, 0))
event.wait()
def test_single_connection_pipelined_requests(self):
"""
Test a single connection with pipelined requests.
"""
conn = self.get_connection()
query = "SELECT keyspace_name FROM system.schema_keyspaces LIMIT 1"
responses = [False] * 100
event = Event()
def cb(response_list, request_num, *args, **kwargs):
response_list[request_num] = True
if all(response_list):
conn.close()
event.set()
for i in range(100):
conn.send_msg(
QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE),
request_id=i,
cb=partial(cb, responses, i))
event.wait()
def test_multiple_connections(self):
"""
Test multiple connections with pipelined requests.
"""
conns = [self.get_connection() for i in range(5)]
events = [Event() for i in range(5)]
query = "SELECT keyspace_name FROM system.schema_keyspaces LIMIT 1"
def cb(event, conn, count, *args, **kwargs):
count += 1
if count >= 10:
conn.close()
event.set()
else:
conn.send_msg(
QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE),
request_id=count,
cb=partial(cb, event, conn, count))
for event, conn in zip(events, conns):
conn.send_msg(
QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE),
request_id=0,
cb=partial(cb, event, conn, 0))
for event in events:
event.wait()
def test_multiple_threads_shared_connection(self):
"""
Test sharing a single connections across multiple threads,
which will result in pipelined requests.
"""
num_requests_per_conn = 25
num_threads = 5
event = Event()
conn = self.get_connection()
query = "SELECT keyspace_name FROM system.schema_keyspaces LIMIT 1"
def cb(all_responses, thread_responses, request_num, *args, **kwargs):
thread_responses[request_num] = True
if all(map(all, all_responses)):
conn.close()
event.set()
def send_msgs(all_responses, thread_responses):
for i in range(num_requests_per_conn):
qmsg = QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE)
with conn.lock:
request_id = conn.get_request_id()
conn.send_msg(qmsg, request_id, cb=partial(cb, all_responses, thread_responses, i))
all_responses = []
threads = []
for i in range(num_threads):
thread_responses = [False] * num_requests_per_conn
all_responses.append(thread_responses)
t = Thread(target=send_msgs, args=(all_responses, thread_responses))
threads.append(t)
for t in threads:
t.start()
for t in threads:
t.join()
event.wait()
def test_multiple_threads_multiple_connections(self):
"""
Test several threads, each with their own Connection and pipelined
requests.
"""
num_requests_per_conn = 25
num_conns = 5
events = [Event() for i in range(5)]
query = "SELECT keyspace_name FROM system.schema_keyspaces LIMIT 1"
def cb(conn, event, thread_responses, request_num, *args, **kwargs):
thread_responses[request_num] = True
if all(thread_responses):
conn.close()
event.set()
def send_msgs(conn, event):
thread_responses = [False] * num_requests_per_conn
for i in range(num_requests_per_conn):
qmsg = QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE)
with conn.lock:
request_id = conn.get_request_id()
conn.send_msg(qmsg, request_id, cb=partial(cb, conn, event, thread_responses, i))
event.wait()
threads = []
for i in range(num_conns):
conn = self.get_connection()
t = Thread(target=send_msgs, args=(conn, events[i]))
threads.append(t)
for t in threads:
t.start()
for t in threads:
t.join()
def test_connect_timeout(self):
start = time.time()
self.assertRaises(Exception, self.get_connection, timeout=sys.float_info.min)
end = time.time()
self.assertAlmostEqual(start, end, 1)
class AsyncoreConnectionTests(ConnectionTests, unittest.TestCase):
klass = AsyncoreConnection
def setUp(self):
if is_monkey_patched():
raise unittest.SkipTest("Can't test asyncore with monkey patching")
ConnectionTests.setUp(self)
class LibevConnectionTests(ConnectionTests, unittest.TestCase):
klass = LibevConnection
def setUp(self):
if is_monkey_patched():
raise unittest.SkipTest("Can't test libev with monkey patching")
if LibevConnection is None:
raise unittest.SkipTest(
'libev does not appear to be installed properly')
ConnectionTests.setUp(self)