337 lines
11 KiB
Python
337 lines
11 KiB
Python
# Copyright 2013-2015 DataStax, Inc.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import sys
|
|
import six
|
|
|
|
try:
|
|
import unittest2 as unittest
|
|
except ImportError:
|
|
import unittest # noqa
|
|
|
|
import errno
|
|
import math
|
|
import time
|
|
from mock import patch, Mock
|
|
import os
|
|
from six import BytesIO
|
|
import socket
|
|
from socket import error as socket_error
|
|
from cassandra.connection import (HEADER_DIRECTION_TO_CLIENT,
|
|
ConnectionException, ProtocolError,Timer)
|
|
from cassandra.io.asyncorereactor import AsyncoreConnection
|
|
from cassandra.protocol import (write_stringmultimap, write_int, write_string,
|
|
SupportedMessage, ReadyMessage, ServerError)
|
|
from cassandra.marshal import uint8_pack, uint32_pack, int32_pack
|
|
from tests import is_monkey_patched
|
|
from tests.unit.io.utils import submit_and_wait_for_completion, TimerCallback
|
|
|
|
|
|
class AsyncoreConnectionTest(unittest.TestCase):
    """Unit tests for AsyncoreConnection's frame reading and writing.

    The real socket layer is patched out; each test drives a connection by
    calling ``handle_write``/``handle_read`` directly and feeding canned
    server frames through the mocked ``socket.recv``.
    """

    @classmethod
    def setUpClass(cls):
        """Initialize the asyncore reactor and patch out real socket use.

        Does nothing when ``is_monkey_patched()`` is true; ``setUp`` then
        skips every test in this class.
        """
        if is_monkey_patched():
            return
        AsyncoreConnection.initialize_reactor()

        cls.socket_patcher = patch('socket.socket', spec=socket.socket)
        cls.mock_socket = cls.socket_patcher.start()
        # Calling the patched class always returns this same mock, so the
        # settings below apply to every socket created while patched.
        fake_socket = cls.mock_socket.return_value
        fake_socket.connect_ex.return_value = 0
        fake_socket.getsockopt.return_value = 0
        fake_socket.fileno.return_value = 100

        # Stub out add_channel with a no-op (presumably so connections built
        # here are not registered with the real asyncore channel map).
        # A patcher is used instead of plain attribute assignment so that
        # tearDownClass restores the original and the stub does not leak
        # into other tests that use AsyncoreConnection.
        cls.add_channel_patcher = patch.object(
            AsyncoreConnection, 'add_channel',
            new=lambda *args, **kwargs: None)
        cls.add_channel_patcher.start()

    @classmethod
    def tearDownClass(cls):
        """Undo the patches installed by setUpClass, in reverse order."""
        if is_monkey_patched():
            return
        cls.add_channel_patcher.stop()
        cls.socket_patcher.stop()

    def setUp(self):
        if is_monkey_patched():
            raise unittest.SkipTest("Can't test asyncore with monkey patching")

    def make_connection(self):
        """Return a new AsyncoreConnection whose socket is a Mock.

        ``send`` reports the whole buffer as written on every call.
        """
        c = AsyncoreConnection('1.2.3.4', cql_version='3.0.1')
        c.socket = Mock()
        c.socket.send.side_effect = lambda x: len(x)
        return c

    def make_started_connection(self):
        """Return a connection driven through the full startup handshake.

        Lets it write the OptionsMessage, feeds back a SupportedMessage,
        lets it write the StartupMessage, then feeds back a ReadyMessage
        on stream 1.
        """
        c = self.make_connection()

        # let it write the OptionsMessage
        c.handle_write()

        # read in a SupportedMessage response
        header = self.make_header_prefix(SupportedMessage)
        options = self.make_options_body()
        c.socket.recv.return_value = self.make_msg(header, options)
        c.handle_read()

        # let it write out a StartupMessage
        c.handle_write()

        header = self.make_header_prefix(ReadyMessage, stream_id=1)
        c.socket.recv.return_value = self.make_msg(header)
        c.handle_read()

        return c

    def make_header_prefix(self, message_class, version=2, stream_id=0):
        """Build the 4-byte frame prefix: version/direction, flags, stream, opcode.

        Each field is packed as one unsigned byte; ``make_msg`` appends the
        4-byte body length to complete the header.
        """
        return six.binary_type().join(map(uint8_pack, [
            0xff & (HEADER_DIRECTION_TO_CLIENT | version),
            0,  # flags (compression)
            stream_id,
            message_class.opcode  # opcode
        ]))

    def make_options_body(self):
        """Return a SUPPORTED body advertising CQL 3.0.1 and no compression."""
        options_buf = BytesIO()
        write_stringmultimap(options_buf, {
            'CQL_VERSION': ['3.0.1'],
            'COMPRESSION': []
        })
        return options_buf.getvalue()

    def make_error_body(self, code, msg):
        """Return an ERROR body: an int error code followed by a string."""
        buf = BytesIO()
        write_int(buf, code)
        write_string(buf, msg)
        return buf.getvalue()

    def make_msg(self, header, body=six.binary_type()):
        """Join a header prefix, the 4-byte body length, and the body."""
        return header + uint32_pack(len(body)) + body

    def test_successful_connection(self, *args):
        """A SUPPORTED then READY exchange marks the connection as connected."""
        c = self.make_started_connection()
        self.assertTrue(c.connected_event.is_set())

    def test_egain_on_buffer_size(self, *args):
        """EAGAIN from recv stops the current read without losing buffered data."""
        # get a connection that's already fully started
        c = self.make_started_connection()
        self.assertTrue(c.connected_event.is_set())

        header = six.b('\x00\x00\x00\x00') + int32_pack(20000)
        responses = [
            header + (six.b('a') * (4096 - len(header))),
            six.b('a') * 4096,
            socket_error(errno.EAGAIN),
            six.b('a') * 100,
            socket_error(errno.EAGAIN)]

        def side_effect(*args):
            # Hand out the scripted chunks in order; raise the error entries.
            response = responses.pop(0)
            if isinstance(response, socket_error):
                raise response
            else:
                return response

        c.socket.recv.side_effect = side_effect
        c.handle_read()
        self.assertEqual(c._current_frame.end_pos, 20000 + len(header))
        # the EAGAIN prevents it from reading the last 100 bytes
        c._iobuf.seek(0, os.SEEK_END)
        pos = c._iobuf.tell()
        self.assertEqual(pos, 4096 + 4096)

        # now tell it to read the last 100 bytes
        c.handle_read()
        c._iobuf.seek(0, os.SEEK_END)
        pos = c._iobuf.tell()
        self.assertEqual(pos, 4096 + 4096 + 100)

    def test_protocol_error(self, *args):
        """A response with version byte 0xa4 defuncts with a ProtocolError."""
        c = self.make_connection()

        # let it write the OptionsMessage
        c.handle_write()

        # read in a SupportedMessage response
        header = self.make_header_prefix(SupportedMessage, version=0xa4)
        options = self.make_options_body()
        c.socket.recv.return_value = self.make_msg(header, options)
        c.handle_read()

        # make sure it errored correctly
        self.assertTrue(c.is_defunct)
        self.assertTrue(c.connected_event.is_set())
        self.assertIsInstance(c.last_error, ProtocolError)

    def test_error_message_on_startup(self, *args):
        """A ServerError reply to STARTUP defuncts with a ConnectionException."""
        c = self.make_connection()

        # let it write the OptionsMessage
        c.handle_write()

        # read in a SupportedMessage response
        header = self.make_header_prefix(SupportedMessage)
        options = self.make_options_body()
        c.socket.recv.return_value = self.make_msg(header, options)
        c.handle_read()

        # let it write out a StartupMessage
        c.handle_write()

        header = self.make_header_prefix(ServerError, stream_id=1)
        body = self.make_error_body(ServerError.error_code, ServerError.summary)
        c.socket.recv.return_value = self.make_msg(header, body)
        c.handle_read()

        # make sure it errored correctly
        self.assertTrue(c.is_defunct)
        self.assertIsInstance(c.last_error, ConnectionException)
        self.assertTrue(c.connected_event.is_set())

    def test_socket_error_on_write(self, *args):
        """An EIO socket error during send defuncts the connection."""
        c = self.make_connection()

        # make the OptionsMessage write fail
        c.socket.send.side_effect = socket_error(errno.EIO, "bad stuff!")
        c.handle_write()

        # make sure it errored correctly
        self.assertTrue(c.is_defunct)
        self.assertIsInstance(c.last_error, socket_error)
        self.assertTrue(c.connected_event.is_set())

    def test_blocking_on_write(self, *args):
        """EAGAIN during send is not fatal, and a later write proceeds."""
        c = self.make_connection()

        # make the OptionsMessage write block
        c.socket.send.side_effect = socket_error(errno.EAGAIN, "socket busy")
        c.handle_write()

        self.assertFalse(c.is_defunct)

        # try again with normal behavior
        c.socket.send.side_effect = lambda x: len(x)
        c.handle_write()
        self.assertFalse(c.is_defunct)
        self.assertIsNotNone(c.socket.send.call_args)

    def test_partial_send(self, *args):
        """Short sends are repeated until the whole OptionsMessage is written."""
        c = self.make_connection()

        # only write the first four bytes of the OptionsMessage
        write_size = 4
        c.socket.send.side_effect = None
        c.socket.send.return_value = write_size
        c.handle_write()

        msg_size = 9  # v3+ frame header
        expected_writes = int(math.ceil(float(msg_size) / write_size))
        size_mod = msg_size % write_size
        last_write_size = size_mod if size_mod else write_size
        self.assertFalse(c.is_defunct)
        self.assertEqual(expected_writes, c.socket.send.call_count)
        self.assertEqual(last_write_size, len(c.socket.send.call_args[0][0]))

    def test_socket_error_on_read(self, *args):
        """An EIO socket error during recv defuncts the connection."""
        c = self.make_connection()

        # let it write the OptionsMessage
        c.handle_write()

        # read in a SupportedMessage response
        c.socket.recv.side_effect = socket_error(errno.EIO, "busy socket")
        c.handle_read()

        # make sure it errored correctly
        self.assertTrue(c.is_defunct)
        self.assertIsInstance(c.last_error, socket_error)
        self.assertTrue(c.connected_event.is_set())

    def test_partial_header_read(self, *args):
        """A header split after its first byte is buffered, then completed."""
        c = self.make_connection()

        header = self.make_header_prefix(SupportedMessage)
        options = self.make_options_body()
        message = self.make_msg(header, options)

        c.socket.recv.return_value = message[0:1]
        c.handle_read()
        self.assertEqual(c._iobuf.getvalue(), message[0:1])

        c.socket.recv.return_value = message[1:]
        c.handle_read()
        self.assertEqual(six.binary_type(), c._iobuf.getvalue())

        # let it write out a StartupMessage
        c.handle_write()

        header = self.make_header_prefix(ReadyMessage, stream_id=1)
        c.socket.recv.return_value = self.make_msg(header)
        c.handle_read()

        self.assertTrue(c.connected_event.is_set())
        self.assertFalse(c.is_defunct)

    def test_partial_message_read(self, *args):
        """A message split after nine bytes is buffered, then completed."""
        c = self.make_connection()

        header = self.make_header_prefix(SupportedMessage)
        options = self.make_options_body()
        message = self.make_msg(header, options)

        # read in the first nine bytes
        c.socket.recv.return_value = message[:9]
        c.handle_read()
        self.assertEqual(c._iobuf.getvalue(), message[:9])

        # ... then read in the rest
        c.socket.recv.return_value = message[9:]
        c.handle_read()
        self.assertEqual(six.binary_type(), c._iobuf.getvalue())

        # let it write out a StartupMessage
        c.handle_write()

        header = self.make_header_prefix(ReadyMessage, stream_id=1)
        c.socket.recv.return_value = self.make_msg(header)
        c.handle_read()

        self.assertTrue(c.connected_event.is_set())
        self.assertFalse(c.is_defunct)

    def test_multi_timer_validation(self, *args):
        """
        Verify that timer timeouts are honored appropriately
        """
        # The connection itself is unused; creating one first presumably
        # ensures the reactor loop is set up before timers are submitted.
        self.make_connection()
        # Tests timers submitted in order at various timeouts
        submit_and_wait_for_completion(self, AsyncoreConnection, 0, 100, 1, 100)
        # Tests timers submitted in reverse order at various timeouts
        submit_and_wait_for_completion(self, AsyncoreConnection, 100, 0, -1, 100)
        # Tests timers submitted in varying order at various timeouts
        submit_and_wait_for_completion(self, AsyncoreConnection, 0, 100, 1, 100, True)

    def test_timer_cancellation(self):
        """
        Verify that timer cancellation is honored
        """

        # Schedule a callback, then cancel it before its timeout elapses
        connection = self.make_connection()
        timeout = .1
        callback = TimerCallback(timeout)
        timer = connection.create_timer(timeout, callback.invoke)
        timer.cancel()
        # Release context allow for timer thread to run.
        time.sleep(.2)
        timer_manager = connection._loop._timers
        # Assert that the cancellation was honored
        self.assertFalse(timer_manager._queue)
        self.assertFalse(timer_manager._new_timers)
        self.assertFalse(callback.was_invoked())
|
|
|
|
|
|
|