195 lines
7.4 KiB
Python
195 lines
7.4 KiB
Python
# Copyright 2013-2014 DataStax, Inc.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
try:
|
|
import unittest2 as unittest
|
|
except ImportError:
|
|
import unittest
|
|
from mock import Mock, patch
|
|
|
|
try:
|
|
from twisted.test import proto_helpers
|
|
from twisted.python.failure import Failure
|
|
from cassandra.io import twistedreactor
|
|
except ImportError:
|
|
twistedreactor = None # NOQA
|
|
|
|
|
|
class TestTwistedProtocol(unittest.TestCase):
    """Tests for TwistedConnectionProtocol driven through a string transport."""

    def setUp(self):
        """Build a protocol, a string transport and a mocked connection."""
        if twistedreactor is None:
            raise unittest.SkipTest("Twisted libraries not available")
        twistedreactor.TwistedConnection.initialize_reactor()

        connection = Mock()
        transport = proto_helpers.StringTransportWithDisconnection()
        # The transport carries a mocked connector whose factory is a real
        # client factory wrapping the mocked connection.
        transport.connector = Mock()
        transport.connector.factory = (
            twistedreactor.TwistedConnectionClientFactory(connection))

        protocol = twistedreactor.TwistedConnectionProtocol()
        transport.protocol = protocol

        self.tr = transport
        self.mock_connection = connection
        self.obj_ut = protocol

    def tearDown(self):
        """Nothing to release; kept as an explicit no-op."""

    def test_makeConnection(self):
        """
        makeConnection() must report the new connection to the
        connection object via client_connection_made().
        """
        self.obj_ut.makeConnection(self.tr)

        notified = self.mock_connection.client_connection_made
        self.assertTrue(notified.called)

    def test_receiving_data(self):
        """
        dataReceived() must write the received data into the connection
        object's _iobuf and then invoke handle_read().
        """
        connection = self.mock_connection

        self.obj_ut.makeConnection(self.tr)
        self.obj_ut.dataReceived('foobar')

        self.assertTrue(connection.handle_read.called)
        connection._iobuf.write.assert_called_with("foobar")
|
|
|
|
|
|
class TestTwistedClientFactory(unittest.TestCase):
    """Tests for TwistedConnectionClientFactory failure/loss callbacks."""

    def setUp(self):
        """Create a client factory wrapped around a mocked connection."""
        if twistedreactor is None:
            raise unittest.SkipTest("Twisted libraries not available")
        twistedreactor.TwistedConnection.initialize_reactor()

        connection = Mock()
        self.mock_connection = connection
        self.obj_ut = twistedreactor.TwistedConnectionClientFactory(
            connection)

    def test_client_connection_failed(self):
        """
        clientConnectionFailed() must call defunct() on the connection
        object with the exception wrapped by the Failure.
        """
        error = Exception('a test')
        reason = Failure(error)

        self.obj_ut.clientConnectionFailed(None, reason)

        self.mock_connection.defunct.assert_called_with(error)

    def test_client_connection_lost(self):
        """
        clientConnectionLost() must call defunct() on the connection
        object with the exception wrapped by the Failure.
        """
        error = Exception('a test')
        reason = Failure(error)

        self.obj_ut.clientConnectionLost(None, reason)

        self.mock_connection.defunct.assert_called_with(error)
|
|
|
|
|
|
class TestTwistedConnection(unittest.TestCase):
    """Tests for TwistedConnection with the Twisted reactor patched out."""

    def setUp(self):
        """
        Patch the reactor entry points used by the tests and build the
        connection under test.
        """
        if twistedreactor is None:
            raise unittest.SkipTest("Twisted libraries not available")
        twistedreactor.TwistedConnection.initialize_reactor()
        self.reactor_cft_patcher = patch(
            'twisted.internet.reactor.callFromThread')
        self.reactor_running_patcher = patch(
            'twisted.internet.reactor.running', False)
        self.reactor_run_patcher = patch('twisted.internet.reactor.run')
        self.mock_reactor_cft = self.reactor_cft_patcher.start()
        # This patcher was previously created but never started, so
        # reactor.running was not actually pinned to False for the tests.
        self.reactor_running_patcher.start()
        self.mock_reactor_run = self.reactor_run_patcher.start()
        self.obj_ut = twistedreactor.TwistedConnection('1.2.3.4',
                                                       cql_version='3.0.1')

    def tearDown(self):
        """Undo every reactor patch, then clean up the connection's loop."""
        # The patches are stopped before _cleanup() runs, matching the
        # original ordering of this teardown.
        self.reactor_cft_patcher.stop()
        self.reactor_running_patcher.stop()
        self.reactor_run_patcher.stop()
        self.obj_ut._loop._cleanup()

    def test_connection_initialization(self):
        """
        Verify that __init__() works correctly.
        """
        self.mock_reactor_cft.assert_called_with(self.obj_ut.add_connection)
        self.obj_ut._loop._cleanup()
        self.mock_reactor_run.assert_called_with(installSignalHandlers=False)

    @patch('twisted.internet.reactor.connectTCP')
    def test_add_connection(self, mock_connectTCP):
        """
        Verify that add_connection() gives us a valid twisted connector.
        """
        self.obj_ut.add_connection()
        self.assertTrue(self.obj_ut.connector is not None)
        self.assertTrue(mock_connectTCP.called)

    def test_client_connection_made(self):
        """
        Verify that _send_options_message() is called in
        client_connection_made()
        """
        self.obj_ut._send_options_message = Mock()
        self.obj_ut.client_connection_made()
        self.obj_ut._send_options_message.assert_called_with()

    @patch('twisted.internet.reactor.connectTCP')
    def test_close(self, mock_connectTCP):
        """
        Verify that close() disconnects the connector and errors callbacks.
        """
        self.obj_ut.error_all_callbacks = Mock()
        self.obj_ut.add_connection()
        self.obj_ut.is_closed = False
        self.obj_ut.close()
        self.obj_ut.connector.disconnect.assert_called_with()
        self.assertTrue(self.obj_ut.connected_event.is_set())
        self.assertTrue(self.obj_ut.error_all_callbacks.called)

    def test_handle_read__incomplete(self):
        """
        Verify that handle_read() processes incomplete messages properly.
        """
        self.obj_ut.process_msg = Mock()
        self.assertEqual(self.obj_ut._iobuf.getvalue(), '')  # buf starts empty
        # incomplete header
        self.obj_ut._iobuf.write('\xff\x00\x00\x00')
        self.obj_ut.handle_read()
        self.assertEqual(self.obj_ut._iobuf.getvalue(), '\xff\x00\x00\x00')

        # full header, but incomplete body
        self.obj_ut._iobuf.write('\x00\x00\x00\x15')
        self.obj_ut.handle_read()
        self.assertEqual(self.obj_ut._iobuf.getvalue(),
                         '\xff\x00\x00\x00\x00\x00\x00\x15')
        # 29 == the 8 bytes written so far plus the 0x15 (21) carried in
        # the last four of them.
        self.assertEqual(self.obj_ut._total_reqd_bytes, 29)

        # verify we never attempted to process the incomplete message
        self.assertFalse(self.obj_ut.process_msg.called)

    def test_handle_read__fullmessage(self):
        """
        Verify that handle_read() processes complete messages properly.
        """
        self.obj_ut.process_msg = Mock()
        self.assertEqual(self.obj_ut._iobuf.getvalue(), '')  # buf starts empty

        # write a complete message, plus 'NEXT' (to simulate next message)
        self.obj_ut._iobuf.write(
            '\xff\x00\x00\x00\x00\x00\x00\x15this is the drum rollNEXT')
        self.obj_ut.handle_read()
        # only the trailing bytes of the "next" message stay buffered
        self.assertEqual(self.obj_ut._iobuf.getvalue(), 'NEXT')
        self.obj_ut.process_msg.assert_called_with(
            '\xff\x00\x00\x00\x00\x00\x00\x15this is the drum roll', 21)

    @patch('twisted.internet.reactor.connectTCP')
    def test_push(self, mock_connectTCP):
        """
        Verify that push() calls transport.write(data).
        """
        self.obj_ut.add_connection()
        self.obj_ut.push('123 pickup')
        # push() is expected to hand the write to the reactor via
        # callFromThread, which is mocked in setUp().
        self.mock_reactor_cft.assert_called_with(
            self.obj_ut.connector.transport.write, '123 pickup')
|