Merge branch 'master' into 2.1-support
This commit is contained in:
@@ -1,10 +1,13 @@
|
|||||||
2.0.3
|
2.1.0
|
||||||
=====
|
=====
|
||||||
In Progress
|
In Progress
|
||||||
|
|
||||||
Features
|
Features
|
||||||
--------
|
--------
|
||||||
* Use io.BytesIO for reduced CPU consumption (github #143)
|
* Use io.BytesIO for reduced CPU consumption (github #143)
|
||||||
|
* Support Twisted as a reactor. Note that a Twisted-compatible
|
||||||
|
API is not exposed (so no Deferreds), this is just a reactor
|
||||||
|
implementation. (github #135, PYTHON-8)
|
||||||
|
|
||||||
Bug Fixes
|
Bug Fixes
|
||||||
---------
|
---------
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ dirname = os.path.dirname(os.path.abspath(__file__))
|
|||||||
sys.path.append(dirname)
|
sys.path.append(dirname)
|
||||||
sys.path.append(os.path.join(dirname, '..'))
|
sys.path.append(os.path.join(dirname, '..'))
|
||||||
|
|
||||||
|
import cassandra
|
||||||
from cassandra.cluster import Cluster
|
from cassandra.cluster import Cluster
|
||||||
from cassandra.io.asyncorereactor import AsyncoreConnection
|
from cassandra.io.asyncorereactor import AsyncoreConnection
|
||||||
from cassandra.policies import HostDistance
|
from cassandra.policies import HostDistance
|
||||||
@@ -44,39 +45,47 @@ try:
|
|||||||
except ImportError as exc:
|
except ImportError as exc:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
KEYSPACE = "testkeyspace"
|
have_twisted = False
|
||||||
|
try:
|
||||||
|
from cassandra.io.twistedreactor import TwistedConnection
|
||||||
|
have_twisted = True
|
||||||
|
supported_reactors.append(TwistedConnection)
|
||||||
|
except ImportError as exc:
|
||||||
|
log.exception("Error importing twisted")
|
||||||
|
pass
|
||||||
|
|
||||||
|
KEYSPACE = "testkeyspace" + str(int(time.time()))
|
||||||
TABLE = "testtable"
|
TABLE = "testtable"
|
||||||
|
|
||||||
|
|
||||||
def setup(hosts):
|
def setup(hosts):
|
||||||
|
log.info("Using 'cassandra' package from %s", cassandra.__path__)
|
||||||
|
|
||||||
cluster = Cluster(hosts)
|
cluster = Cluster(hosts)
|
||||||
cluster.set_core_connections_per_host(HostDistance.LOCAL, 1)
|
cluster.set_core_connections_per_host(HostDistance.LOCAL, 1)
|
||||||
session = cluster.connect()
|
try:
|
||||||
|
session = cluster.connect()
|
||||||
|
|
||||||
rows = session.execute("SELECT keyspace_name FROM system.schema_keyspaces")
|
log.debug("Creating keyspace...")
|
||||||
if KEYSPACE in [row[0] for row in rows]:
|
session.execute("""
|
||||||
log.debug("dropping existing keyspace...")
|
CREATE KEYSPACE %s
|
||||||
session.execute("DROP KEYSPACE " + KEYSPACE)
|
WITH replication = { 'class': 'SimpleStrategy', 'replication_factor': '2' }
|
||||||
|
""" % KEYSPACE)
|
||||||
|
|
||||||
log.debug("Creating keyspace...")
|
log.debug("Setting keyspace...")
|
||||||
session.execute("""
|
session.set_keyspace(KEYSPACE)
|
||||||
CREATE KEYSPACE %s
|
|
||||||
WITH replication = { 'class': 'SimpleStrategy', 'replication_factor': '2' }
|
|
||||||
""" % KEYSPACE)
|
|
||||||
|
|
||||||
log.debug("Setting keyspace...")
|
log.debug("Creating table...")
|
||||||
session.set_keyspace(KEYSPACE)
|
session.execute("""
|
||||||
|
CREATE TABLE %s (
|
||||||
log.debug("Creating table...")
|
thekey text,
|
||||||
session.execute("""
|
col1 text,
|
||||||
CREATE TABLE %s (
|
col2 text,
|
||||||
thekey text,
|
PRIMARY KEY (thekey, col1)
|
||||||
col1 text,
|
)
|
||||||
col2 text,
|
""" % TABLE)
|
||||||
PRIMARY KEY (thekey, col1)
|
finally:
|
||||||
)
|
cluster.shutdown()
|
||||||
""" % TABLE)
|
|
||||||
|
|
||||||
|
|
||||||
def teardown(hosts):
|
def teardown(hosts):
|
||||||
@@ -84,6 +93,7 @@ def teardown(hosts):
|
|||||||
cluster.set_core_connections_per_host(HostDistance.LOCAL, 1)
|
cluster.set_core_connections_per_host(HostDistance.LOCAL, 1)
|
||||||
session = cluster.connect()
|
session = cluster.connect()
|
||||||
session.execute("DROP KEYSPACE " + KEYSPACE)
|
session.execute("DROP KEYSPACE " + KEYSPACE)
|
||||||
|
cluster.shutdown()
|
||||||
|
|
||||||
|
|
||||||
def benchmark(thread_class):
|
def benchmark(thread_class):
|
||||||
@@ -124,6 +134,7 @@ def benchmark(thread_class):
|
|||||||
|
|
||||||
end = time.time()
|
end = time.time()
|
||||||
finally:
|
finally:
|
||||||
|
cluster.shutdown()
|
||||||
teardown(options.hosts)
|
teardown(options.hosts)
|
||||||
|
|
||||||
total = end - start
|
total = end - start
|
||||||
@@ -164,6 +175,8 @@ def parse_options():
|
|||||||
help='only benchmark with asyncore connections')
|
help='only benchmark with asyncore connections')
|
||||||
parser.add_option('--libev-only', action='store_true', dest='libev_only',
|
parser.add_option('--libev-only', action='store_true', dest='libev_only',
|
||||||
help='only benchmark with libev connections')
|
help='only benchmark with libev connections')
|
||||||
|
parser.add_option('--twisted-only', action='store_true', dest='twisted_only',
|
||||||
|
help='only benchmark with Twisted connections')
|
||||||
parser.add_option('-m', '--metrics', action='store_true', dest='enable_metrics',
|
parser.add_option('-m', '--metrics', action='store_true', dest='enable_metrics',
|
||||||
help='enable and print metrics for operations')
|
help='enable and print metrics for operations')
|
||||||
parser.add_option('-l', '--log-level', default='info',
|
parser.add_option('-l', '--log-level', default='info',
|
||||||
@@ -184,6 +197,11 @@ def parse_options():
|
|||||||
log.error("libev is not available")
|
log.error("libev is not available")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
options.supported_reactors = [LibevConnection]
|
options.supported_reactors = [LibevConnection]
|
||||||
|
elif options.twisted_only:
|
||||||
|
if not have_twisted:
|
||||||
|
log.error("Twisted is not available")
|
||||||
|
sys.exit(1)
|
||||||
|
options.supported_reactors = [TwistedConnection]
|
||||||
else:
|
else:
|
||||||
options.supported_reactors = supported_reactors
|
options.supported_reactors = supported_reactors
|
||||||
if not have_libev:
|
if not have_libev:
|
||||||
|
|||||||
@@ -67,7 +67,7 @@ from cassandra.query import (SimpleStatement, PreparedStatement, BoundStatement,
|
|||||||
named_tuple_factory, dict_factory)
|
named_tuple_factory, dict_factory)
|
||||||
|
|
||||||
# default to gevent when we are monkey patched, otherwise if libev is available, use that as the
|
# default to gevent when we are monkey patched, otherwise if libev is available, use that as the
|
||||||
# default because it's faster than asyncore
|
# default because it's fastest. Otherwise, use asyncore.
|
||||||
if 'gevent.monkey' in sys.modules:
|
if 'gevent.monkey' in sys.modules:
|
||||||
from cassandra.io.geventreactor import GeventConnection as DefaultConnection
|
from cassandra.io.geventreactor import GeventConnection as DefaultConnection
|
||||||
else:
|
else:
|
||||||
@@ -310,6 +310,8 @@ class Cluster(object):
|
|||||||
|
|
||||||
* :class:`cassandra.io.asyncorereactor.AsyncoreConnection`
|
* :class:`cassandra.io.asyncorereactor.AsyncoreConnection`
|
||||||
* :class:`cassandra.io.libevreactor.LibevConnection`
|
* :class:`cassandra.io.libevreactor.LibevConnection`
|
||||||
|
* :class:`cassandra.io.libevreactor.GeventConnection` (requires monkey-patching)
|
||||||
|
* :class:`cassandra.io.libevreactor.TwistedConnection`
|
||||||
|
|
||||||
By default, ``AsyncoreConnection`` will be used, which uses
|
By default, ``AsyncoreConnection`` will be used, which uses
|
||||||
the ``asyncore`` module in the Python standard library. The
|
the ``asyncore`` module in the Python standard library. The
|
||||||
@@ -317,6 +319,9 @@ class Cluster(object):
|
|||||||
supported on a wider range of systems.
|
supported on a wider range of systems.
|
||||||
|
|
||||||
If ``libev`` is installed, ``LibevConnection`` will be used instead.
|
If ``libev`` is installed, ``LibevConnection`` will be used instead.
|
||||||
|
|
||||||
|
If gevent monkey-patching of the standard library is detected,
|
||||||
|
GeventConnection will be used automatically.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
control_connection_timeout = 2.0
|
control_connection_timeout = 2.0
|
||||||
|
|||||||
295
cassandra/io/twistedreactor.py
Normal file
295
cassandra/io/twistedreactor.py
Normal file
@@ -0,0 +1,295 @@
|
|||||||
|
# Copyright 2013-2014 DataStax, Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
Module that implements an event loop based on twisted
|
||||||
|
( https://twistedmatrix.com ).
|
||||||
|
"""
|
||||||
|
from twisted.internet import reactor, protocol
|
||||||
|
from threading import Event, Thread, Lock
|
||||||
|
from functools import partial
|
||||||
|
import logging
|
||||||
|
import weakref
|
||||||
|
import atexit
|
||||||
|
import os
|
||||||
|
|
||||||
|
from io import BytesIO
|
||||||
|
|
||||||
|
from cassandra import OperationTimedOut
|
||||||
|
from cassandra.connection import Connection, ConnectionShutdown
|
||||||
|
from cassandra.protocol import RegisterMessage
|
||||||
|
from cassandra.marshal import int32_unpack
|
||||||
|
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _cleanup(cleanup_weakref):
|
||||||
|
try:
|
||||||
|
cleanup_weakref()._cleanup()
|
||||||
|
except ReferenceError:
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
class TwistedConnectionProtocol(protocol.Protocol):
|
||||||
|
"""
|
||||||
|
Twisted Protocol class for handling data received and connection
|
||||||
|
made events.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def dataReceived(self, data):
|
||||||
|
"""
|
||||||
|
Callback function that is called when data has been received
|
||||||
|
on the connection.
|
||||||
|
|
||||||
|
Reaches back to the Connection object and queues the data for
|
||||||
|
processing.
|
||||||
|
"""
|
||||||
|
self.transport.connector.factory.conn._iobuf.write(data)
|
||||||
|
self.transport.connector.factory.conn.handle_read()
|
||||||
|
|
||||||
|
def connectionMade(self):
|
||||||
|
"""
|
||||||
|
Callback function that is called when a connection has succeeded.
|
||||||
|
|
||||||
|
Reaches back to the Connection object and confirms that the connection
|
||||||
|
is ready.
|
||||||
|
"""
|
||||||
|
self.transport.connector.factory.conn.client_connection_made()
|
||||||
|
|
||||||
|
def connectionLost(self, reason):
|
||||||
|
# reason is a Failure instance
|
||||||
|
self.transport.connector.factory.conn.defunct(reason.value)
|
||||||
|
|
||||||
|
|
||||||
|
class TwistedConnectionClientFactory(protocol.ClientFactory):
|
||||||
|
|
||||||
|
def __init__(self, connection):
|
||||||
|
# ClientFactory does not define __init__() in parent classes
|
||||||
|
# and does not inherit from object.
|
||||||
|
self.conn = connection
|
||||||
|
|
||||||
|
def buildProtocol(self, addr):
|
||||||
|
"""
|
||||||
|
Twisted function that defines which kind of protocol to use
|
||||||
|
in the ClientFactory.
|
||||||
|
"""
|
||||||
|
return TwistedConnectionProtocol()
|
||||||
|
|
||||||
|
def clientConnectionFailed(self, connector, reason):
|
||||||
|
"""
|
||||||
|
Overridden twisted callback which is called when the
|
||||||
|
connection attempt fails.
|
||||||
|
"""
|
||||||
|
log.debug("Connect failed: %s", reason)
|
||||||
|
self.conn.defunct(reason.value)
|
||||||
|
|
||||||
|
def clientConnectionLost(self, connector, reason):
|
||||||
|
"""
|
||||||
|
Overridden twisted callback which is called when the
|
||||||
|
connection goes away (cleanly or otherwise).
|
||||||
|
|
||||||
|
It should be safe to call defunct() here instead of just close, because
|
||||||
|
we can assume that if the connection was closed cleanly, there are no
|
||||||
|
callbacks to error out. If this assumption turns out to be false, we
|
||||||
|
can call close() instead of defunct() when "reason" is an appropriate
|
||||||
|
type.
|
||||||
|
"""
|
||||||
|
log.debug("Connect lost: %s", reason)
|
||||||
|
self.conn.defunct(reason.value)
|
||||||
|
|
||||||
|
|
||||||
|
class TwistedLoop(object):
|
||||||
|
|
||||||
|
_lock = None
|
||||||
|
_thread = None
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._lock = Lock()
|
||||||
|
|
||||||
|
def maybe_start(self):
|
||||||
|
with self._lock:
|
||||||
|
if not reactor.running:
|
||||||
|
self._thread = Thread(target=reactor.run,
|
||||||
|
name="cassandra_driver_event_loop",
|
||||||
|
kwargs={'installSignalHandlers': False})
|
||||||
|
self._thread.daemon = True
|
||||||
|
self._thread.start()
|
||||||
|
atexit.register(partial(_cleanup, weakref.ref(self)))
|
||||||
|
|
||||||
|
def _cleanup(self):
|
||||||
|
if self._thread:
|
||||||
|
reactor.callFromThread(reactor.stop)
|
||||||
|
self._thread.join(timeout=1.0)
|
||||||
|
if self._thread.is_alive():
|
||||||
|
log.warning("Event loop thread could not be joined, so "
|
||||||
|
"shutdown may not be clean. Please call "
|
||||||
|
"Cluster.shutdown() to avoid this.")
|
||||||
|
log.debug("Event loop thread was joined")
|
||||||
|
|
||||||
|
|
||||||
|
class TwistedConnection(Connection):
|
||||||
|
"""
|
||||||
|
An implementation of :class:`.Connection` that utilizes the
|
||||||
|
Twisted event loop.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_loop = None
|
||||||
|
_total_reqd_bytes = 0
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def initialize_reactor(cls):
|
||||||
|
if not cls._loop:
|
||||||
|
cls._loop = TwistedLoop()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def factory(cls, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
A factory function which returns connections which have
|
||||||
|
succeeded in connecting and are ready for service (or
|
||||||
|
raises an exception otherwise).
|
||||||
|
"""
|
||||||
|
timeout = kwargs.pop('timeout', 5.0)
|
||||||
|
conn = cls(*args, **kwargs)
|
||||||
|
conn.connected_event.wait(timeout)
|
||||||
|
if conn.last_error:
|
||||||
|
raise conn.last_error
|
||||||
|
elif not conn.connected_event.is_set():
|
||||||
|
conn.close()
|
||||||
|
raise OperationTimedOut("Timed out creating connection")
|
||||||
|
else:
|
||||||
|
return conn
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Initialization method.
|
||||||
|
|
||||||
|
Note that we can't call reactor methods directly here because
|
||||||
|
it's not thread-safe, so we schedule the reactor/connection
|
||||||
|
stuff to be run from the event loop thread when it gets the
|
||||||
|
chance.
|
||||||
|
"""
|
||||||
|
Connection.__init__(self, *args, **kwargs)
|
||||||
|
|
||||||
|
self.connected_event = Event()
|
||||||
|
self._iobuf = BytesIO()
|
||||||
|
self.is_closed = True
|
||||||
|
self.connector = None
|
||||||
|
|
||||||
|
self._callbacks = {}
|
||||||
|
reactor.callFromThread(self.add_connection)
|
||||||
|
self._loop.maybe_start()
|
||||||
|
|
||||||
|
def add_connection(self):
|
||||||
|
"""
|
||||||
|
Convenience function to connect and store the resulting
|
||||||
|
connector.
|
||||||
|
"""
|
||||||
|
self.connector = reactor.connectTCP(
|
||||||
|
host=self.host, port=self.port,
|
||||||
|
factory=TwistedConnectionClientFactory(self))
|
||||||
|
|
||||||
|
def client_connection_made(self):
|
||||||
|
"""
|
||||||
|
Called by twisted protocol when a connection attempt has
|
||||||
|
succeeded.
|
||||||
|
"""
|
||||||
|
with self.lock:
|
||||||
|
self.is_closed = False
|
||||||
|
self._send_options_message()
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
"""
|
||||||
|
Disconnect and error-out all callbacks.
|
||||||
|
"""
|
||||||
|
with self.lock:
|
||||||
|
if self.is_closed:
|
||||||
|
return
|
||||||
|
self.is_closed = True
|
||||||
|
|
||||||
|
log.debug("Closing connection (%s) to %s", id(self), self.host)
|
||||||
|
self.connector.disconnect()
|
||||||
|
log.debug("Closed socket to %s", self.host)
|
||||||
|
|
||||||
|
if not self.is_defunct:
|
||||||
|
self.error_all_callbacks(
|
||||||
|
ConnectionShutdown("Connection to %s was closed" % self.host))
|
||||||
|
# don't leave in-progress operations hanging
|
||||||
|
self.connected_event.set()
|
||||||
|
|
||||||
|
def handle_read(self):
|
||||||
|
"""
|
||||||
|
Process the incoming data buffer.
|
||||||
|
"""
|
||||||
|
while True:
|
||||||
|
pos = self._iobuf.tell()
|
||||||
|
if pos < 8 or (self._total_reqd_bytes > 0 and
|
||||||
|
pos < self._total_reqd_bytes):
|
||||||
|
# we don't have a complete header yet or we
|
||||||
|
# already saw a header, but we don't have a
|
||||||
|
# complete message yet
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
# have enough for header, read body len from header
|
||||||
|
self._iobuf.seek(4)
|
||||||
|
body_len = int32_unpack(self._iobuf.read(4))
|
||||||
|
|
||||||
|
# seek to end to get length of current buffer
|
||||||
|
self._iobuf.seek(0, os.SEEK_END)
|
||||||
|
pos = self._iobuf.tell()
|
||||||
|
|
||||||
|
if pos >= body_len + 8:
|
||||||
|
# read message header and body
|
||||||
|
self._iobuf.seek(0)
|
||||||
|
msg = self._iobuf.read(8 + body_len)
|
||||||
|
|
||||||
|
# leave leftover in current buffer
|
||||||
|
leftover = self._iobuf.read()
|
||||||
|
self._iobuf = BytesIO()
|
||||||
|
self._iobuf.write(leftover)
|
||||||
|
|
||||||
|
self._total_reqd_bytes = 0
|
||||||
|
self.process_msg(msg, body_len)
|
||||||
|
else:
|
||||||
|
self._total_reqd_bytes = body_len + 8
|
||||||
|
return
|
||||||
|
|
||||||
|
def push(self, data):
|
||||||
|
"""
|
||||||
|
This function is called when outgoing data should be queued
|
||||||
|
for sending.
|
||||||
|
|
||||||
|
Note that we can't call transport.write() directly because
|
||||||
|
it is not thread-safe, so we schedule it to run from within
|
||||||
|
the event loop when it gets the chance.
|
||||||
|
"""
|
||||||
|
reactor.callFromThread(self.connector.transport.write, data)
|
||||||
|
|
||||||
|
def register_watcher(self, event_type, callback, register_timeout=None):
|
||||||
|
"""
|
||||||
|
Register a callback for a given event type.
|
||||||
|
"""
|
||||||
|
self._push_watchers[event_type].add(callback)
|
||||||
|
self.wait_for_response(
|
||||||
|
RegisterMessage(event_list=[event_type]),
|
||||||
|
timeout=register_timeout)
|
||||||
|
|
||||||
|
def register_watchers(self, type_callback_dict, register_timeout=None):
|
||||||
|
"""
|
||||||
|
Register multiple callback/event type pairs, expressed as a dict.
|
||||||
|
"""
|
||||||
|
for event_type, callback in type_callback_dict.items():
|
||||||
|
self._push_watchers[event_type].add(callback)
|
||||||
|
self.wait_for_response(
|
||||||
|
RegisterMessage(event_list=type_callback_dict.keys()),
|
||||||
|
timeout=register_timeout)
|
||||||
194
tests/unit/io/test_twistedreactor.py
Normal file
194
tests/unit/io/test_twistedreactor.py
Normal file
@@ -0,0 +1,194 @@
|
|||||||
|
# Copyright 2013-2014 DataStax, Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
try:
|
||||||
|
import unittest2 as unittest
|
||||||
|
except ImportError:
|
||||||
|
import unittest
|
||||||
|
from mock import Mock, patch
|
||||||
|
|
||||||
|
try:
|
||||||
|
from twisted.test import proto_helpers
|
||||||
|
from twisted.python.failure import Failure
|
||||||
|
from cassandra.io import twistedreactor
|
||||||
|
except ImportError:
|
||||||
|
twistedreactor = None # NOQA
|
||||||
|
|
||||||
|
|
||||||
|
class TestTwistedProtocol(unittest.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
if twistedreactor is None:
|
||||||
|
raise unittest.SkipTest("Twisted libraries not available")
|
||||||
|
twistedreactor.TwistedConnection.initialize_reactor()
|
||||||
|
self.tr = proto_helpers.StringTransportWithDisconnection()
|
||||||
|
self.tr.connector = Mock()
|
||||||
|
self.mock_connection = Mock()
|
||||||
|
self.tr.connector.factory = twistedreactor.TwistedConnectionClientFactory(
|
||||||
|
self.mock_connection)
|
||||||
|
self.obj_ut = twistedreactor.TwistedConnectionProtocol()
|
||||||
|
self.tr.protocol = self.obj_ut
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_makeConnection(self):
|
||||||
|
"""
|
||||||
|
Verify that the protocol class notifies the connection
|
||||||
|
object that a successful connection was made.
|
||||||
|
"""
|
||||||
|
self.obj_ut.makeConnection(self.tr)
|
||||||
|
self.assertTrue(self.mock_connection.client_connection_made.called)
|
||||||
|
|
||||||
|
def test_receiving_data(self):
|
||||||
|
"""
|
||||||
|
Verify that the dataReceived() callback writes the data to
|
||||||
|
the connection object's buffer and calls handle_read().
|
||||||
|
"""
|
||||||
|
self.obj_ut.makeConnection(self.tr)
|
||||||
|
self.obj_ut.dataReceived('foobar')
|
||||||
|
self.assertTrue(self.mock_connection.handle_read.called)
|
||||||
|
self.mock_connection._iobuf.write.assert_called_with("foobar")
|
||||||
|
|
||||||
|
|
||||||
|
class TestTwistedClientFactory(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
if twistedreactor is None:
|
||||||
|
raise unittest.SkipTest("Twisted libraries not available")
|
||||||
|
twistedreactor.TwistedConnection.initialize_reactor()
|
||||||
|
self.mock_connection = Mock()
|
||||||
|
self.obj_ut = twistedreactor.TwistedConnectionClientFactory(
|
||||||
|
self.mock_connection)
|
||||||
|
|
||||||
|
def test_client_connection_failed(self):
|
||||||
|
"""
|
||||||
|
Verify that connection failed causes the connection object to close.
|
||||||
|
"""
|
||||||
|
exc = Exception('a test')
|
||||||
|
self.obj_ut.clientConnectionFailed(None, Failure(exc))
|
||||||
|
self.mock_connection.defunct.assert_called_with(exc)
|
||||||
|
|
||||||
|
def test_client_connection_lost(self):
|
||||||
|
"""
|
||||||
|
Verify that connection lost causes the connection object to close.
|
||||||
|
"""
|
||||||
|
exc = Exception('a test')
|
||||||
|
self.obj_ut.clientConnectionLost(None, Failure(exc))
|
||||||
|
self.mock_connection.defunct.assert_called_with(exc)
|
||||||
|
|
||||||
|
|
||||||
|
class TestTwistedConnection(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
if twistedreactor is None:
|
||||||
|
raise unittest.SkipTest("Twisted libraries not available")
|
||||||
|
twistedreactor.TwistedConnection.initialize_reactor()
|
||||||
|
self.reactor_cft_patcher = patch(
|
||||||
|
'twisted.internet.reactor.callFromThread')
|
||||||
|
self.reactor_running_patcher = patch(
|
||||||
|
'twisted.internet.reactor.running', False)
|
||||||
|
self.reactor_run_patcher = patch('twisted.internet.reactor.run')
|
||||||
|
self.mock_reactor_cft = self.reactor_cft_patcher.start()
|
||||||
|
self.mock_reactor_run = self.reactor_run_patcher.start()
|
||||||
|
self.obj_ut = twistedreactor.TwistedConnection('1.2.3.4',
|
||||||
|
cql_version='3.0.1')
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
self.reactor_cft_patcher.stop()
|
||||||
|
self.reactor_run_patcher.stop()
|
||||||
|
self.obj_ut._loop._cleanup()
|
||||||
|
|
||||||
|
def test_connection_initialization(self):
|
||||||
|
"""
|
||||||
|
Verify that __init__() works correctly.
|
||||||
|
"""
|
||||||
|
self.mock_reactor_cft.assert_called_with(self.obj_ut.add_connection)
|
||||||
|
self.obj_ut._loop._cleanup()
|
||||||
|
self.mock_reactor_run.assert_called_with(installSignalHandlers=False)
|
||||||
|
|
||||||
|
@patch('twisted.internet.reactor.connectTCP')
|
||||||
|
def test_add_connection(self, mock_connectTCP):
|
||||||
|
"""
|
||||||
|
Verify that add_connection() gives us a valid twisted connector.
|
||||||
|
"""
|
||||||
|
self.obj_ut.add_connection()
|
||||||
|
self.assertTrue(self.obj_ut.connector is not None)
|
||||||
|
self.assertTrue(mock_connectTCP.called)
|
||||||
|
|
||||||
|
def test_client_connection_made(self):
|
||||||
|
"""
|
||||||
|
Verifiy that _send_options_message() is called in
|
||||||
|
client_connection_made()
|
||||||
|
"""
|
||||||
|
self.obj_ut._send_options_message = Mock()
|
||||||
|
self.obj_ut.client_connection_made()
|
||||||
|
self.obj_ut._send_options_message.assert_called_with()
|
||||||
|
|
||||||
|
@patch('twisted.internet.reactor.connectTCP')
|
||||||
|
def test_close(self, mock_connectTCP):
|
||||||
|
"""
|
||||||
|
Verify that close() disconnects the connector and errors callbacks.
|
||||||
|
"""
|
||||||
|
self.obj_ut.error_all_callbacks = Mock()
|
||||||
|
self.obj_ut.add_connection()
|
||||||
|
self.obj_ut.is_closed = False
|
||||||
|
self.obj_ut.close()
|
||||||
|
self.obj_ut.connector.disconnect.assert_called_with()
|
||||||
|
self.assertTrue(self.obj_ut.connected_event.is_set())
|
||||||
|
self.assertTrue(self.obj_ut.error_all_callbacks.called)
|
||||||
|
|
||||||
|
def test_handle_read__incomplete(self):
|
||||||
|
"""
|
||||||
|
Verify that handle_read() processes incomplete messages properly.
|
||||||
|
"""
|
||||||
|
self.obj_ut.process_msg = Mock()
|
||||||
|
self.assertEqual(self.obj_ut._iobuf.getvalue(), '') # buf starts empty
|
||||||
|
# incomplete header
|
||||||
|
self.obj_ut._iobuf.write('\xff\x00\x00\x00')
|
||||||
|
self.obj_ut.handle_read()
|
||||||
|
self.assertEqual(self.obj_ut._iobuf.getvalue(), '\xff\x00\x00\x00')
|
||||||
|
|
||||||
|
# full header, but incomplete body
|
||||||
|
self.obj_ut._iobuf.write('\x00\x00\x00\x15')
|
||||||
|
self.obj_ut.handle_read()
|
||||||
|
self.assertEqual(self.obj_ut._iobuf.getvalue(),
|
||||||
|
'\xff\x00\x00\x00\x00\x00\x00\x15')
|
||||||
|
self.assertEqual(self.obj_ut._total_reqd_bytes, 29)
|
||||||
|
|
||||||
|
# verify we never attempted to process the incomplete message
|
||||||
|
self.assertFalse(self.obj_ut.process_msg.called)
|
||||||
|
|
||||||
|
def test_handle_read__fullmessage(self):
|
||||||
|
"""
|
||||||
|
Verify that handle_read() processes complete messages properly.
|
||||||
|
"""
|
||||||
|
self.obj_ut.process_msg = Mock()
|
||||||
|
self.assertEqual(self.obj_ut._iobuf.getvalue(), '') # buf starts empty
|
||||||
|
|
||||||
|
# write a complete message, plus 'NEXT' (to simulate next message)
|
||||||
|
self.obj_ut._iobuf.write(
|
||||||
|
'\xff\x00\x00\x00\x00\x00\x00\x15this is the drum rollNEXT')
|
||||||
|
self.obj_ut.handle_read()
|
||||||
|
self.assertEqual(self.obj_ut._iobuf.getvalue(), 'NEXT')
|
||||||
|
self.obj_ut.process_msg.assert_called_with(
|
||||||
|
'\xff\x00\x00\x00\x00\x00\x00\x15this is the drum roll', 21)
|
||||||
|
|
||||||
|
@patch('twisted.internet.reactor.connectTCP')
|
||||||
|
def test_push(self, mock_connectTCP):
|
||||||
|
"""
|
||||||
|
Verifiy that push() calls transport.write(data).
|
||||||
|
"""
|
||||||
|
self.obj_ut.add_connection()
|
||||||
|
self.obj_ut.push('123 pickup')
|
||||||
|
self.mock_reactor_cft.assert_called_with(
|
||||||
|
self.obj_ut.connector.transport.write, '123 pickup')
|
||||||
Reference in New Issue
Block a user