Merge remote-tracking branch 'origin/master' into PYTHON-190

Conflicts:
	cassandra/cqltypes.py
	tests/integration/standard/test_types.py
This commit is contained in:
Adam Holmberg
2015-01-30 13:54:08 -06:00
33 changed files with 1501 additions and 177 deletions

View File

@@ -1,14 +1,22 @@
2.1.3-post 2.1.4
========== =====
Features Features
-------- --------
* SaslAuthenticator for Kerberos support (PYTHON-109) * SaslAuthenticator for Kerberos support (PYTHON-109)
* Heartbeat for network device keepalive and detecting failures on idle connections (PYTHON-197)
* Support nested, frozen collections for Cassandra 2.1.3+ (PYTHON-186)
* Schema agreement wait bypass config, new call for synchronous schema refresh (PYTHON-205)
* Add eventlet connection support (PYTHON-194)
Bug Fixes Bug Fixes
--------- ---------
* Schema meta fix for complex thrift tables (PYTHON-191) * Schema meta fix for complex thrift tables (PYTHON-191)
* Support for 'unknown' replica placement strategies in schema meta (PYTHON-192) * Support for 'unknown' replica placement strategies in schema meta (PYTHON-192)
* Resolve stream ID leak on set_keyspace (PYTHON-195)
* Remove implicit timestamp scaling on serialization of numeric timestamps (PYTHON-204)
* Resolve stream id collision when using SASL auth (PYTHON-210)
* Correct unhexlify usage for user defined type meta in Python3 (PYTHON-208)
2.1.3 2.1.3
===== =====

View File

@@ -23,7 +23,7 @@ class NullHandler(logging.Handler):
logging.getLogger('cassandra').addHandler(NullHandler()) logging.getLogger('cassandra').addHandler(NullHandler())
__version_info__ = (2, 1, 3, 'post') __version_info__ = (2, 1, 4, 'post')
__version__ = '.'.join(map(str, __version_info__)) __version__ = '.'.join(map(str, __version_info__))

View File

@@ -130,7 +130,9 @@ class PlainTextAuthenticator(Authenticator):
class SaslAuthProvider(AuthProvider): class SaslAuthProvider(AuthProvider):
""" """
An :class:`~.AuthProvider` for Kerberos authenticators An :class:`~.AuthProvider` supporting general SASL auth mechanisms
Suitable for GSSAPI or other SASL mechanisms
Example usage:: Example usage::
@@ -144,7 +146,7 @@ class SaslAuthProvider(AuthProvider):
auth_provider = SaslAuthProvider(**sasl_kwargs) auth_provider = SaslAuthProvider(**sasl_kwargs)
cluster = Cluster(auth_provider=auth_provider) cluster = Cluster(auth_provider=auth_provider)
.. versionadded:: 2.1.3-post .. versionadded:: 2.1.4
""" """
def __init__(self, **sasl_kwargs): def __init__(self, **sasl_kwargs):
@@ -157,9 +159,10 @@ class SaslAuthProvider(AuthProvider):
class SaslAuthenticator(Authenticator): class SaslAuthenticator(Authenticator):
""" """
An :class:`~.Authenticator` that works with DSE's KerberosAuthenticator. A pass-through :class:`~.Authenticator` using the third party package
'pure-sasl' for authentication
.. versionadded:: 2.1.3-post .. versionadded:: 2.1.4
""" """
def __init__(self, host, service, mechanism='GSSAPI', **sasl_kwargs): def __init__(self, host, service, mechanism='GSSAPI', **sasl_kwargs):

View File

@@ -22,6 +22,7 @@ import atexit
from collections import defaultdict from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
import logging import logging
from random import random
import socket import socket
import sys import sys
import time import time
@@ -44,7 +45,8 @@ from itertools import groupby
from cassandra import (ConsistencyLevel, AuthenticationFailed, from cassandra import (ConsistencyLevel, AuthenticationFailed,
InvalidRequest, OperationTimedOut, InvalidRequest, OperationTimedOut,
UnsupportedOperation, Unauthorized) UnsupportedOperation, Unauthorized)
from cassandra.connection import ConnectionException, ConnectionShutdown from cassandra.connection import (ConnectionException, ConnectionShutdown,
ConnectionHeartbeat)
from cassandra.encoder import Encoder from cassandra.encoder import Encoder
from cassandra.protocol import (QueryMessage, ResultMessage, from cassandra.protocol import (QueryMessage, ResultMessage,
ErrorMessage, ReadTimeoutErrorMessage, ErrorMessage, ReadTimeoutErrorMessage,
@@ -68,10 +70,20 @@ from cassandra.query import (SimpleStatement, PreparedStatement, BoundStatement,
BatchStatement, bind_params, QueryTrace, Statement, BatchStatement, bind_params, QueryTrace, Statement,
named_tuple_factory, dict_factory, FETCH_SIZE_UNSET) named_tuple_factory, dict_factory, FETCH_SIZE_UNSET)
# default to gevent when we are monkey patched, otherwise if libev is available, use that as the
# default because it's fastest. Otherwise, use asyncore. def _is_eventlet_monkey_patched():
if 'eventlet.patcher' not in sys.modules:
return False
import eventlet.patcher
return eventlet.patcher.is_monkey_patched('socket')
# default to gevent when we are monkey patched with gevent, eventlet when
# monkey patched with eventlet, otherwise if libev is available, use that as
# the default because it's fastest. Otherwise, use asyncore.
if 'gevent.monkey' in sys.modules: if 'gevent.monkey' in sys.modules:
from cassandra.io.geventreactor import GeventConnection as DefaultConnection from cassandra.io.geventreactor import GeventConnection as DefaultConnection
elif _is_eventlet_monkey_patched():
from cassandra.io.eventletreactor import EventletConnection as DefaultConnection
else: else:
try: try:
from cassandra.io.libevreactor import LibevConnection as DefaultConnection # NOQA from cassandra.io.libevreactor import LibevConnection as DefaultConnection # NOQA
@@ -372,6 +384,48 @@ class Cluster(object):
If set to :const:`None`, there will be no timeout for these queries. If set to :const:`None`, there will be no timeout for these queries.
""" """
idle_heartbeat_interval = 30
"""
Interval, in seconds, on which to heartbeat idle connections. This helps
keep connections open through network devices that expire idle connections.
It also helps discover bad connections early in low-traffic scenarios.
Setting to zero disables heartbeats.
"""
schema_event_refresh_window = 2
"""
Window, in seconds, within which a schema component will be refreshed after
receiving a schema_change event.
The driver delays a random amount of time in the range [0.0, window)
before executing the refresh. This serves two purposes:
1.) Spread the refresh for deployments with large fanout from C* to client tier,
preventing a 'thundering herd' problem with many clients refreshing simultaneously.
2.) Remove redundant refreshes. Redundant events arriving within the delay period
are discarded, and only one refresh is executed.
Setting this to zero will execute refreshes immediately.
Setting this negative will disable schema refreshes in response to push events
(refreshes will still occur in response to schema change responses to DDL statements
executed by Sessions of this Cluster).
"""
topology_event_refresh_window = 10
"""
Window, in seconds, within which the node and token list will be refreshed after
receiving a topology_change event.
Setting this to zero will execute refreshes immediately.
Setting this negative will disable node refreshes in response to push events
(refreshes will still occur in response to new nodes observed on "UP" events).
See :attr:`.schema_event_refresh_window` for discussion of rationale
"""
sessions = None sessions = None
control_connection = None control_connection = None
scheduler = None scheduler = None
@@ -380,6 +434,7 @@ class Cluster(object):
_is_setup = False _is_setup = False
_prepared_statements = None _prepared_statements = None
_prepared_statement_lock = None _prepared_statement_lock = None
_idle_heartbeat = None
_user_types = None _user_types = None
""" """
@@ -406,7 +461,10 @@ class Cluster(object):
protocol_version=2, protocol_version=2,
executor_threads=2, executor_threads=2,
max_schema_agreement_wait=10, max_schema_agreement_wait=10,
control_connection_timeout=2.0): control_connection_timeout=2.0,
idle_heartbeat_interval=30,
schema_event_refresh_window=2,
topology_event_refresh_window=10):
""" """
Any of the mutable Cluster attributes may be set as keyword arguments Any of the mutable Cluster attributes may be set as keyword arguments
to the constructor. to the constructor.
@@ -456,6 +514,9 @@ class Cluster(object):
self.cql_version = cql_version self.cql_version = cql_version
self.max_schema_agreement_wait = max_schema_agreement_wait self.max_schema_agreement_wait = max_schema_agreement_wait
self.control_connection_timeout = control_connection_timeout self.control_connection_timeout = control_connection_timeout
self.idle_heartbeat_interval = idle_heartbeat_interval
self.schema_event_refresh_window = schema_event_refresh_window
self.topology_event_refresh_window = topology_event_refresh_window
self._listeners = set() self._listeners = set()
self._listener_lock = Lock() self._listener_lock = Lock()
@@ -500,7 +561,8 @@ class Cluster(object):
self.metrics = Metrics(weakref.proxy(self)) self.metrics = Metrics(weakref.proxy(self))
self.control_connection = ControlConnection( self.control_connection = ControlConnection(
self, self.control_connection_timeout) self, self.control_connection_timeout,
self.schema_event_refresh_window, self.topology_event_refresh_window)
def register_user_type(self, keyspace, user_type, klass): def register_user_type(self, keyspace, user_type, klass):
""" """
@@ -621,7 +683,7 @@ class Cluster(object):
def set_max_connections_per_host(self, host_distance, max_connections): def set_max_connections_per_host(self, host_distance, max_connections):
""" """
Gets the maximum number of connections per Session that will be opened Sets the maximum number of connections per Session that will be opened
for each host with :class:`~.HostDistance` equal to `host_distance`. for each host with :class:`~.HostDistance` equal to `host_distance`.
The default is 2 for :attr:`~HostDistance.LOCAL` and 1 for The default is 2 for :attr:`~HostDistance.LOCAL` and 1 for
:attr:`~HostDistance.REMOTE`. :attr:`~HostDistance.REMOTE`.
@@ -688,7 +750,6 @@ class Cluster(object):
self.load_balancing_policy.populate( self.load_balancing_policy.populate(
weakref.proxy(self), self.metadata.all_hosts()) weakref.proxy(self), self.metadata.all_hosts())
if self.control_connection:
try: try:
self.control_connection.connect() self.control_connection.connect()
log.debug("Control connection created") log.debug("Control connection created")
@@ -700,6 +761,8 @@ class Cluster(object):
self.load_balancing_policy.check_supported() self.load_balancing_policy.check_supported()
if self.idle_heartbeat_interval:
self._idle_heartbeat = ConnectionHeartbeat(self.idle_heartbeat_interval, self.get_connection_holders)
self._is_setup = True self._is_setup = True
session = self._new_session() session = self._new_session()
@@ -707,6 +770,13 @@ class Cluster(object):
session.set_keyspace(keyspace) session.set_keyspace(keyspace)
return session return session
def get_connection_holders(self):
holders = []
for s in self.sessions:
holders.extend(s.get_pools())
holders.append(self.control_connection)
return holders
def shutdown(self): def shutdown(self):
""" """
Closes all sessions and connection associated with this Cluster. Closes all sessions and connection associated with this Cluster.
@@ -721,17 +791,16 @@ class Cluster(object):
else: else:
self.is_shutdown = True self.is_shutdown = True
if self.scheduler: if self._idle_heartbeat:
self._idle_heartbeat.stop()
self.scheduler.shutdown() self.scheduler.shutdown()
if self.control_connection:
self.control_connection.shutdown() self.control_connection.shutdown()
if self.sessions:
for session in self.sessions: for session in self.sessions:
session.shutdown() session.shutdown()
if self.executor:
self.executor.shutdown() self.executor.shutdown()
def _new_session(self): def _new_session(self):
@@ -907,7 +976,7 @@ class Cluster(object):
self._start_reconnector(host, is_host_addition) self._start_reconnector(host, is_host_addition)
def on_add(self, host): def on_add(self, host, refresh_nodes=True):
if self.is_shutdown: if self.is_shutdown:
return return
@@ -919,7 +988,7 @@ class Cluster(object):
log.debug("Done preparing queries for new host %r", host) log.debug("Done preparing queries for new host %r", host)
self.load_balancing_policy.on_add(host) self.load_balancing_policy.on_add(host)
self.control_connection.on_add(host) self.control_connection.on_add(host, refresh_nodes)
if distance == HostDistance.IGNORED: if distance == HostDistance.IGNORED:
log.debug("Not adding connection pool for new host %r because the " log.debug("Not adding connection pool for new host %r because the "
@@ -995,7 +1064,7 @@ class Cluster(object):
self.on_down(host, is_host_addition, expect_host_to_be_down) self.on_down(host, is_host_addition, expect_host_to_be_down)
return is_down return is_down
def add_host(self, address, datacenter=None, rack=None, signal=True): def add_host(self, address, datacenter=None, rack=None, signal=True, refresh_nodes=True):
""" """
Called when adding initial contact points and when the control Called when adding initial contact points and when the control
connection subsequently discovers a new node. Intended for internal connection subsequently discovers a new node. Intended for internal
@@ -1004,7 +1073,7 @@ class Cluster(object):
new_host = self.metadata.add_host(address, datacenter, rack) new_host = self.metadata.add_host(address, datacenter, rack)
if new_host and signal: if new_host and signal:
log.info("New Cassandra host %r discovered", new_host) log.info("New Cassandra host %r discovered", new_host)
self.on_add(new_host) self.on_add(new_host, refresh_nodes)
return new_host return new_host
@@ -1045,16 +1114,20 @@ class Cluster(object):
for pool in session._pools.values(): for pool in session._pools.values():
pool.ensure_core_connections() pool.ensure_core_connections()
def refresh_schema(self, keyspace=None, table=None, usertype=None, schema_agreement_wait=None): def refresh_schema(self, keyspace=None, table=None, usertype=None, max_schema_agreement_wait=None):
""" """
Synchronously refresh the schema metadata. Synchronously refresh the schema metadata.
By default timeout for this operation is governed by :attr:`~.Cluster.max_schema_agreement_wait`
By default, the timeout for this operation is governed by :attr:`~.Cluster.max_schema_agreement_wait`
and :attr:`~.Cluster.control_connection_timeout`. and :attr:`~.Cluster.control_connection_timeout`.
Passing schema_agreement_wait here overrides :attr:`~.Cluster.max_schema_agreement_wait`.
Setting schema_agreement_wait <= 0 will bypass schema agreement and refresh schema immediately. Passing max_schema_agreement_wait here overrides :attr:`~.Cluster.max_schema_agreement_wait`.
Setting max_schema_agreement_wait <= 0 will bypass schema agreement and refresh schema immediately.
An Exception is raised if schema refresh fails for any reason. An Exception is raised if schema refresh fails for any reason.
""" """
if not self.control_connection.refresh_schema(keyspace, table, usertype, schema_agreement_wait): if not self.control_connection.refresh_schema(keyspace, table, usertype, max_schema_agreement_wait):
raise Exception("Schema was not refreshed. See log for details.") raise Exception("Schema was not refreshed. See log for details.")
def submit_schema_refresh(self, keyspace=None, table=None, usertype=None): def submit_schema_refresh(self, keyspace=None, table=None, usertype=None):
@@ -1066,6 +1139,27 @@ class Cluster(object):
return self.executor.submit( return self.executor.submit(
self.control_connection.refresh_schema, keyspace, table, usertype) self.control_connection.refresh_schema, keyspace, table, usertype)
def refresh_nodes(self):
"""
Synchronously refresh the node list and token metadata
An Exception is raised if node refresh fails for any reason.
"""
if not self.control_connection.refresh_node_list_and_token_map():
raise Exception("Node list was not refreshed. See log for details.")
def set_meta_refresh_enabled(self, enabled):
"""
Sets a flag to enable (True) or disable (False) all metadata refresh queries.
This applies to both schema and node topology.
Disabling this is useful to minimize refreshes during multiple changes.
Meta refresh must be enabled for the driver to become aware of any cluster
topology changes or schema updates.
"""
self.control_connection.set_meta_refresh_enabled(bool(enabled))
def _prepare_all_queries(self, host): def _prepare_all_queries(self, host):
if not self._prepared_statements: if not self._prepared_statements:
return return
@@ -1656,6 +1750,9 @@ class Session(object):
def get_pool_state(self): def get_pool_state(self):
return dict((host, pool.get_state()) for host, pool in self._pools.items()) return dict((host, pool.get_state()) for host, pool in self._pools.items())
def get_pools(self):
return self._pools.values()
class UserTypeDoesNotExist(Exception): class UserTypeDoesNotExist(Exception):
""" """
@@ -1734,16 +1831,26 @@ class ControlConnection(object):
_timeout = None _timeout = None
_protocol_version = None _protocol_version = None
_schema_event_refresh_window = None
_topology_event_refresh_window = None
_meta_refresh_enabled = True
# for testing purposes # for testing purposes
_time = time _time = time
def __init__(self, cluster, timeout): def __init__(self, cluster, timeout,
schema_event_refresh_window,
topology_event_refresh_window):
# use a weak reference to allow the Cluster instance to be GC'ed (and # use a weak reference to allow the Cluster instance to be GC'ed (and
# shutdown) since implementing __del__ disables the cycle detector # shutdown) since implementing __del__ disables the cycle detector
self._cluster = weakref.proxy(cluster) self._cluster = weakref.proxy(cluster)
self._connection = None self._connection = None
self._timeout = timeout self._timeout = timeout
self._schema_event_refresh_window = schema_event_refresh_window
self._topology_event_refresh_window = topology_event_refresh_window
self._lock = RLock() self._lock = RLock()
self._schema_agreement_lock = Lock() self._schema_agreement_lock = Lock()
@@ -1901,6 +2008,10 @@ class ControlConnection(object):
def refresh_schema(self, keyspace=None, table=None, usertype=None, def refresh_schema(self, keyspace=None, table=None, usertype=None,
schema_agreement_wait=None): schema_agreement_wait=None):
if not self._meta_refresh_enabled:
log.debug("[control connection] Skipping schema refresh because meta refresh is disabled")
return False
try: try:
if self._connection: if self._connection:
return self._refresh_schema(self._connection, keyspace, table, usertype, return self._refresh_schema(self._connection, keyspace, table, usertype,
@@ -2028,14 +2139,20 @@ class ControlConnection(object):
return True return True
def refresh_node_list_and_token_map(self, force_token_rebuild=False): def refresh_node_list_and_token_map(self, force_token_rebuild=False):
if not self._meta_refresh_enabled:
log.debug("[control connection] Skipping node list refresh because meta refresh is disabled")
return False
try: try:
if self._connection: if self._connection:
self._refresh_node_list_and_token_map(self._connection, force_token_rebuild=force_token_rebuild) self._refresh_node_list_and_token_map(self._connection, force_token_rebuild=force_token_rebuild)
return True
except ReferenceError: except ReferenceError:
pass # our weak reference to the Cluster is no good pass # our weak reference to the Cluster is no good
except Exception: except Exception:
log.debug("[control connection] Error refreshing node list and token map", exc_info=True) log.debug("[control connection] Error refreshing node list and token map", exc_info=True)
self._signal_error() self._signal_error()
return False
def _refresh_node_list_and_token_map(self, connection, preloaded_results=None, def _refresh_node_list_and_token_map(self, connection, preloaded_results=None,
force_token_rebuild=False): force_token_rebuild=False):
@@ -2096,7 +2213,7 @@ class ControlConnection(object):
rack = row.get("rack") rack = row.get("rack")
if host is None: if host is None:
log.debug("[control connection] Found new host to connect to: %s", addr) log.debug("[control connection] Found new host to connect to: %s", addr)
host = self._cluster.add_host(addr, datacenter, rack, signal=True) host = self._cluster.add_host(addr, datacenter, rack, signal=True, refresh_nodes=False)
should_rebuild_token_map = True should_rebuild_token_map = True
else: else:
should_rebuild_token_map |= self._update_location_info(host, datacenter, rack) should_rebuild_token_map |= self._update_location_info(host, datacenter, rack)
@@ -2131,25 +2248,25 @@ class ControlConnection(object):
def _handle_topology_change(self, event): def _handle_topology_change(self, event):
change_type = event["change_type"] change_type = event["change_type"]
addr, port = event["address"] addr, port = event["address"]
if change_type == "NEW_NODE": if change_type == "NEW_NODE" or change_type == "MOVED_NODE":
self._cluster.scheduler.schedule(10, self.refresh_node_list_and_token_map) if self._topology_event_refresh_window >= 0:
delay = random() * self._topology_event_refresh_window
self._cluster.scheduler.schedule_unique(delay, self.refresh_node_list_and_token_map)
elif change_type == "REMOVED_NODE": elif change_type == "REMOVED_NODE":
host = self._cluster.metadata.get_host(addr) host = self._cluster.metadata.get_host(addr)
self._cluster.scheduler.schedule(0, self._cluster.remove_host, host) self._cluster.scheduler.schedule_unique(0, self._cluster.remove_host, host)
elif change_type == "MOVED_NODE":
self._cluster.scheduler.schedule(1, self.refresh_node_list_and_token_map)
def _handle_status_change(self, event): def _handle_status_change(self, event):
change_type = event["change_type"] change_type = event["change_type"]
addr, port = event["address"] addr, port = event["address"]
host = self._cluster.metadata.get_host(addr) host = self._cluster.metadata.get_host(addr)
if change_type == "UP": if change_type == "UP":
delay = 1 + random() * 0.5 # randomness to avoid thundering herd problem on events
if host is None: if host is None:
# this is the first time we've seen the node # this is the first time we've seen the node
self._cluster.scheduler.schedule(2, self.refresh_node_list_and_token_map) self._cluster.scheduler.schedule_unique(delay, self.refresh_node_list_and_token_map)
else: else:
# this will be run by the scheduler self._cluster.scheduler.schedule_unique(delay, self._cluster.on_up, host)
self._cluster.scheduler.schedule(2, self._cluster.on_up, host)
elif change_type == "DOWN": elif change_type == "DOWN":
# Note that there is a slight risk we can receive the event late and thus # Note that there is a slight risk we can receive the event late and thus
# mark the host down even though we already had reconnected successfully. # mark the host down even though we already had reconnected successfully.
@@ -2160,10 +2277,14 @@ class ControlConnection(object):
self._cluster.on_down(host, is_host_addition=False) self._cluster.on_down(host, is_host_addition=False)
def _handle_schema_change(self, event): def _handle_schema_change(self, event):
if self._schema_event_refresh_window < 0:
return
keyspace = event.get('keyspace') keyspace = event.get('keyspace')
table = event.get('table') table = event.get('table')
usertype = event.get('type') usertype = event.get('type')
self._submit(self.refresh_schema, keyspace, table, usertype) delay = random() * self._schema_event_refresh_window
self._cluster.scheduler.schedule_unique(delay, self.refresh_schema, keyspace, table, usertype)
def wait_for_schema_agreement(self, connection=None, preloaded_results=None, wait_time=None): def wait_for_schema_agreement(self, connection=None, preloaded_results=None, wait_time=None):
@@ -2271,11 +2392,6 @@ class ControlConnection(object):
# manually # manually
self.reconnect() self.reconnect()
@property
def is_open(self):
conn = self._connection
return bool(conn and conn.is_open)
def on_up(self, host): def on_up(self, host):
pass pass
@@ -2289,12 +2405,24 @@ class ControlConnection(object):
# this will result in a task being submitted to the executor to reconnect # this will result in a task being submitted to the executor to reconnect
self.reconnect() self.reconnect()
def on_add(self, host): def on_add(self, host, refresh_nodes=True):
if refresh_nodes:
self.refresh_node_list_and_token_map(force_token_rebuild=True) self.refresh_node_list_and_token_map(force_token_rebuild=True)
def on_remove(self, host): def on_remove(self, host):
self.refresh_node_list_and_token_map(force_token_rebuild=True) self.refresh_node_list_and_token_map(force_token_rebuild=True)
def get_connections(self):
c = getattr(self, '_connection', None)
return [c] if c else []
def return_connection(self, connection):
if connection is self._connection and (connection.is_defunct or connection.is_closed):
self.reconnect()
def set_meta_refresh_enabled(self, enabled):
self._meta_refresh_enabled = enabled
def _stop_scheduler(scheduler, thread): def _stop_scheduler(scheduler, thread):
try: try:
@@ -2308,12 +2436,14 @@ def _stop_scheduler(scheduler, thread):
class _Scheduler(object): class _Scheduler(object):
_scheduled = None _queue = None
_scheduled_tasks = None
_executor = None _executor = None
is_shutdown = False is_shutdown = False
def __init__(self, executor): def __init__(self, executor):
self._scheduled = Queue.PriorityQueue() self._queue = Queue.PriorityQueue()
self._scheduled_tasks = set()
self._executor = executor self._executor = executor
t = Thread(target=self.run, name="Task Scheduler") t = Thread(target=self.run, name="Task Scheduler")
@@ -2331,14 +2461,25 @@ class _Scheduler(object):
# this can happen on interpreter shutdown # this can happen on interpreter shutdown
pass pass
self.is_shutdown = True self.is_shutdown = True
self._scheduled.put_nowait((0, None)) self._queue.put_nowait((0, None))
def schedule(self, delay, fn, *args, **kwargs): def schedule(self, delay, fn, *args):
self._insert_task(delay, (fn, args))
def schedule_unique(self, delay, fn, *args):
task = (fn, args)
if task not in self._scheduled_tasks:
self._insert_task(delay, task)
else:
log.debug("Ignoring schedule_unique for already-scheduled task: %r", task)
def _insert_task(self, delay, task):
if not self.is_shutdown: if not self.is_shutdown:
run_at = time.time() + delay run_at = time.time() + delay
self._scheduled.put_nowait((run_at, (fn, args, kwargs))) self._scheduled_tasks.add(task)
self._queue.put_nowait((run_at, task))
else: else:
log.debug("Ignoring scheduled function after shutdown: %r", fn) log.debug("Ignoring scheduled task after shutdown: %r", task)
def run(self): def run(self):
while True: while True:
@@ -2347,16 +2488,17 @@ class _Scheduler(object):
try: try:
while True: while True:
run_at, task = self._scheduled.get(block=True, timeout=None) run_at, task = self._queue.get(block=True, timeout=None)
if self.is_shutdown: if self.is_shutdown:
log.debug("Not executing scheduled task due to Scheduler shutdown") log.debug("Not executing scheduled task due to Scheduler shutdown")
return return
if run_at <= time.time(): if run_at <= time.time():
fn, args, kwargs = task self._scheduled_tasks.remove(task)
future = self._executor.submit(fn, *args, **kwargs) fn, args = task
future = self._executor.submit(fn, *args)
future.add_done_callback(self._log_if_failed) future.add_done_callback(self._log_if_failed)
else: else:
self._scheduled.put_nowait((run_at, task)) self._queue.put_nowait((run_at, task))
break break
except Queue.Empty: except Queue.Empty:
pass pass
@@ -2373,9 +2515,13 @@ class _Scheduler(object):
def refresh_schema_and_set_result(keyspace, table, usertype, control_conn, response_future): def refresh_schema_and_set_result(keyspace, table, usertype, control_conn, response_future):
try: try:
if control_conn._meta_refresh_enabled:
log.debug("Refreshing schema in response to schema change. Keyspace: %s; Table: %s, Type: %s", log.debug("Refreshing schema in response to schema change. Keyspace: %s; Table: %s, Type: %s",
keyspace, table, usertype) keyspace, table, usertype)
control_conn._refresh_schema(response_future._connection, keyspace, table, usertype) control_conn._refresh_schema(response_future._connection, keyspace, table, usertype)
else:
log.debug("Skipping schema refresh in response to schema change because meta refresh is disabled; "
"Keyspace: %s; Table: %s, Type: %s", keyspace, table, usertype)
except Exception: except Exception:
log.exception("Exception refreshing schema in response to schema change:") log.exception("Exception refreshing schema in response to schema change:")
response_future.session.submit( response_future.session.submit(

View File

@@ -20,7 +20,7 @@ import io
import logging import logging
import os import os
import sys import sys
from threading import Event, RLock from threading import Thread, Event, RLock
import time import time
if 'gevent.monkey' in sys.modules: if 'gevent.monkey' in sys.modules:
@@ -159,7 +159,7 @@ class Connection(object):
in_flight = 0 in_flight = 0
# A set of available request IDs. When using the v3 protocol or higher, # A set of available request IDs. When using the v3 protocol or higher,
# this will no initially include all request IDs in order to save memory, # this will not initially include all request IDs in order to save memory,
# but the set will grow if it is exhausted. # but the set will grow if it is exhausted.
request_ids = None request_ids = None
@@ -172,6 +172,8 @@ class Connection(object):
lock = None lock = None
user_type_map = None user_type_map = None
msg_received = False
is_control_connection = False is_control_connection = False
_iobuf = None _iobuf = None
@@ -401,6 +403,8 @@ class Connection(object):
with self.lock: with self.lock:
self.request_ids.append(stream_id) self.request_ids.append(stream_id)
self.msg_received = True
body = None body = None
try: try:
# check that the protocol version is supported # check that the protocol version is supported
@@ -673,6 +677,13 @@ class Connection(object):
self.send_msg(query, request_id, process_result) self.send_msg(query, request_id, process_result)
@property
def is_idle(self):
return not self.msg_received
def reset_idle(self):
self.msg_received = False
def __str__(self): def __str__(self):
status = "" status = ""
if self.is_defunct: if self.is_defunct:
@@ -732,3 +743,100 @@ class ResponseWaiter(object):
raise OperationTimedOut() raise OperationTimedOut()
else: else:
return self.responses return self.responses
class HeartbeatFuture(object):
def __init__(self, connection, owner):
self._exception = None
self._event = Event()
self.connection = connection
self.owner = owner
log.debug("Sending options message heartbeat on idle connection (%s) %s",
id(connection), connection.host)
with connection.lock:
if connection.in_flight < connection.max_request_id:
connection.in_flight += 1
connection.send_msg(OptionsMessage(), connection.get_request_id(), self._options_callback)
else:
self._exception = Exception("Failed to send heartbeat because connection 'in_flight' exceeds threshold")
self._event.set()
def wait(self, timeout):
self._event.wait(timeout)
if self._event.is_set():
if self._exception:
raise self._exception
else:
raise OperationTimedOut()
def _options_callback(self, response):
if not isinstance(response, SupportedMessage):
if isinstance(response, ConnectionException):
self._exception = response
else:
self._exception = ConnectionException("Received unexpected response to OptionsMessage: %s"
% (response,))
log.debug("Received options response on connection (%s) from %s",
id(self.connection), self.connection.host)
self._event.set()
class ConnectionHeartbeat(Thread):
def __init__(self, interval_sec, get_connection_holders):
Thread.__init__(self, name="Connection heartbeat")
self._interval = interval_sec
self._get_connection_holders = get_connection_holders
self._shutdown_event = Event()
self.daemon = True
self.start()
def run(self):
self._shutdown_event.wait(self._interval)
while not self._shutdown_event.is_set():
start_time = time.time()
futures = []
failed_connections = []
try:
for connections, owner in [(o.get_connections(), o) for o in self._get_connection_holders()]:
for connection in connections:
if not (connection.is_defunct or connection.is_closed):
if connection.is_idle:
try:
futures.append(HeartbeatFuture(connection, owner))
except Exception:
log.warning("Failed sending heartbeat message on connection (%s) to %s",
id(connection), connection.host, exc_info=True)
failed_connections.append((connection, owner))
else:
connection.reset_idle()
else:
# make sure the owner sees this defunt/closed connection
owner.return_connection(connection)
for f in futures:
connection = f.connection
try:
f.wait(self._interval)
# TODO: move this, along with connection locks in pool, down into Connection
with connection.lock:
connection.in_flight -= 1
connection.reset_idle()
except Exception:
log.warning("Heartbeat failed for connection (%s) to %s",
id(connection), connection.host, exc_info=True)
failed_connections.append((f.connection, f.owner))
for connection, owner in failed_connections:
connection.defunct(Exception('Connection heartbeat failure'))
owner.return_connection(connection)
except Exception:
log.error("Failed connection heartbeat", exc_info=True)
elapsed = time.time() - start_time
self._shutdown_event.wait(max(self._interval - elapsed, 0.01))
def stop(self):
self._shutdown_event.set()

View File

@@ -31,14 +31,14 @@ from __future__ import absolute_import # to enable import io from stdlib
from binascii import unhexlify from binascii import unhexlify
import calendar import calendar
from collections import namedtuple from collections import namedtuple
import datetime
from decimal import Decimal from decimal import Decimal
import io import io
import re import re
import socket import socket
import time import time
import datetime import sys
from uuid import UUID from uuid import UUID
import warnings
import six import six
from six.moves import range from six.moves import range
@@ -48,7 +48,7 @@ from cassandra.marshal import (int8_pack, int8_unpack,
int32_pack, int32_unpack, int64_pack, int64_unpack, int32_pack, int32_unpack, int64_pack, int64_unpack,
float_pack, float_unpack, double_pack, double_unpack, float_pack, float_unpack, double_pack, double_unpack,
varint_pack, varint_unpack) varint_pack, varint_unpack)
from cassandra.util import OrderedDict, sortedset from cassandra.util import OrderedMap, sortedset
apache_cassandra_type_prefix = 'org.apache.cassandra.db.marshal.' apache_cassandra_type_prefix = 'org.apache.cassandra.db.marshal.'
@@ -58,10 +58,15 @@ if six.PY3:
_time_types = frozenset((int,)) _time_types = frozenset((int,))
_date_types = frozenset((int,)) _date_types = frozenset((int,))
long = int long = int
def _name_from_hex_string(encoded_name):
bin_str = unhexlify(encoded_name)
return bin_str.decode('ascii')
else: else:
_number_types = frozenset((int, long, float)) _number_types = frozenset((int, long, float))
_time_types = frozenset((int, long)) _time_types = frozenset((int, long))
_date_types = frozenset((int, long)) _date_types = frozenset((int, long))
_name_from_hex_string = unhexlify
def trim_if_startswith(s, prefix): def trim_if_startswith(s, prefix):
@@ -569,7 +574,8 @@ class DateType(_CassandraType):
tval = time.strptime(val, tformat) tval = time.strptime(val, tformat)
except ValueError: except ValueError:
continue continue
return calendar.timegm(tval) + offset # scale seconds to millis for the raw value
return (calendar.timegm(tval) + offset) * 1e3
else: else:
raise ValueError("can't interpret %r as a date" % (val,)) raise ValueError("can't interpret %r as a date" % (val,))
@@ -584,31 +590,16 @@ class DateType(_CassandraType):
@staticmethod @staticmethod
def serialize(v, protocol_version): def serialize(v, protocol_version):
try: try:
converted = calendar.timegm(v.utctimetuple()) # v is datetime
converted = converted * 1e3 + getattr(v, 'microsecond', 0) / 1e3 timestamp_seconds = calendar.timegm(v.utctimetuple())
timestamp = timestamp_seconds * 1e3 + getattr(v, 'microsecond', 0) / 1e3
except AttributeError: except AttributeError:
# Ints and floats are valid timestamps too # Ints and floats are valid timestamps too
if type(v) not in _number_types: if type(v) not in _number_types:
raise TypeError('DateType arguments must be a datetime or timestamp') raise TypeError('DateType arguments must be a datetime or timestamp')
timestamp = v
global _have_warned_about_timestamps return int64_pack(long(timestamp))
if not _have_warned_about_timestamps:
_have_warned_about_timestamps = True
warnings.warn(
"timestamp columns in Cassandra hold a number of "
"milliseconds since the unix epoch. Currently, when executing "
"prepared statements, this driver multiplies timestamp "
"values by 1000 so that the result of time.time() "
"can be used directly. However, the driver cannot "
"match this behavior for non-prepared statements, "
"so the 2.0 version of the driver will no longer multiply "
"timestamps by 1000. It is suggested that you simply use "
"datetime.datetime objects for 'timestamp' values to avoid "
"any ambiguity and to guarantee a smooth upgrade of the "
"driver.")
converted = v * 1e3
return int64_pack(long(converted))
class TimestampType(DateType): class TimestampType(DateType):
@@ -838,7 +829,7 @@ class MapType(_ParameterizedType):
length = 2 length = 2
numelements = unpack(byts[:length]) numelements = unpack(byts[:length])
p = length p = length
themap = OrderedDict() themap = OrderedMap()
for _ in range(numelements): for _ in range(numelements):
key_len = unpack(byts[p:p + length]) key_len = unpack(byts[p:p + length])
p += length p += length
@@ -850,7 +841,7 @@ class MapType(_ParameterizedType):
p += val_len p += val_len
key = subkeytype.from_binary(keybytes, protocol_version) key = subkeytype.from_binary(keybytes, protocol_version)
val = subvaltype.from_binary(valbytes, protocol_version) val = subvaltype.from_binary(valbytes, protocol_version)
themap[key] = val themap._insert(key, val)
return themap return themap
@classmethod @classmethod
@@ -929,37 +920,39 @@ class UserType(TupleType):
typename = "'org.apache.cassandra.db.marshal.UserType'" typename = "'org.apache.cassandra.db.marshal.UserType'"
_cache = {} _cache = {}
_module = sys.modules[__name__]
@classmethod @classmethod
def make_udt_class(cls, keyspace, udt_name, names_and_types, mapped_class): def make_udt_class(cls, keyspace, udt_name, names_and_types, mapped_class):
if six.PY2 and isinstance(udt_name, unicode): if six.PY2 and isinstance(udt_name, unicode):
udt_name = udt_name.encode('utf-8') udt_name = udt_name.encode('utf-8')
try: try:
return cls._cache[(keyspace, udt_name)] return cls._cache[(keyspace, udt_name)]
except KeyError: except KeyError:
fieldnames, types = zip(*names_and_types) field_names, types = zip(*names_and_types)
instance = type(udt_name, (cls,), {'subtypes': types, instance = type(udt_name, (cls,), {'subtypes': types,
'cassname': cls.cassname, 'cassname': cls.cassname,
'typename': udt_name, 'typename': udt_name,
'fieldnames': fieldnames, 'fieldnames': field_names,
'keyspace': keyspace, 'keyspace': keyspace,
'mapped_class': mapped_class}) 'mapped_class': mapped_class,
'tuple_type': cls._make_registered_udt_namedtuple(keyspace, udt_name, field_names)})
cls._cache[(keyspace, udt_name)] = instance cls._cache[(keyspace, udt_name)] = instance
return instance return instance
@classmethod @classmethod
def apply_parameters(cls, subtypes, names): def apply_parameters(cls, subtypes, names):
keyspace = subtypes[0] keyspace = subtypes[0]
udt_name = unhexlify(subtypes[1].cassname) udt_name = _name_from_hex_string(subtypes[1].cassname)
field_names = [unhexlify(encoded_name) for encoded_name in names[2:]] field_names = [_name_from_hex_string(encoded_name) for encoded_name in names[2:]]
assert len(field_names) == len(subtypes[2:]) assert len(field_names) == len(subtypes[2:])
return type(udt_name, (cls,), {'subtypes': subtypes[2:], return type(udt_name, (cls,), {'subtypes': subtypes[2:],
'cassname': cls.cassname, 'cassname': cls.cassname,
'typename': udt_name, 'typename': udt_name,
'fieldnames': field_names, 'fieldnames': field_names,
'keyspace': keyspace, 'keyspace': keyspace,
'mapped_class': None}) 'mapped_class': None,
'tuple_type': namedtuple(udt_name, field_names)})
@classmethod @classmethod
def cql_parameterized_type(cls): def cql_parameterized_type(cls):
@@ -991,8 +984,7 @@ class UserType(TupleType):
if cls.mapped_class: if cls.mapped_class:
return cls.mapped_class(**dict(zip(cls.fieldnames, values))) return cls.mapped_class(**dict(zip(cls.fieldnames, values)))
else: else:
Result = namedtuple(cls.typename, cls.fieldnames) return cls.tuple_type(*values)
return Result(*values)
@classmethod @classmethod
def serialize_safe(cls, val, protocol_version): def serialize_safe(cls, val, protocol_version):
@@ -1008,6 +1000,18 @@ class UserType(TupleType):
buf.write(int32_pack(-1)) buf.write(int32_pack(-1))
return buf.getvalue() return buf.getvalue()
@classmethod
def _make_registered_udt_namedtuple(cls, keyspace, name, field_names):
# this is required to make the type resolvable via this module...
# required when unregistered udts are pickled for use as keys in
# util.OrderedMap
qualified_name = "%s_%s" % (keyspace, name)
nt = getattr(cls._module, qualified_name, None)
if not nt:
nt = namedtuple(qualified_name, field_names)
setattr(cls._module, qualified_name, nt)
return nt
class CompositeType(_ParameterizedType): class CompositeType(_ParameterizedType):
typename = "'org.apache.cassandra.db.marshal.CompositeType'" typename = "'org.apache.cassandra.db.marshal.CompositeType'"

View File

@@ -28,7 +28,7 @@ import types
from uuid import UUID from uuid import UUID
import six import six
from cassandra.util import OrderedDict, sortedset from cassandra.util import OrderedDict, OrderedMap, sortedset
if six.PY3: if six.PY3:
long = int long = int
@@ -77,6 +77,7 @@ class Encoder(object):
datetime.time: self.cql_encode_time, datetime.time: self.cql_encode_time,
dict: self.cql_encode_map_collection, dict: self.cql_encode_map_collection,
OrderedDict: self.cql_encode_map_collection, OrderedDict: self.cql_encode_map_collection,
OrderedMap: self.cql_encode_map_collection,
list: self.cql_encode_list_collection, list: self.cql_encode_list_collection,
tuple: self.cql_encode_list_collection, tuple: self.cql_encode_list_collection,
set: self.cql_encode_set_collection, set: self.cql_encode_set_collection,

View File

@@ -0,0 +1,193 @@
# Copyright 2014 Symantec Corporation
# Copyright 2013-2014 DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Originally derived from MagnetoDB source:
# https://github.com/stackforge/magnetodb/blob/2015.1.0b1/magnetodb/common/cassandra/io/eventletreactor.py
from collections import defaultdict
from errno import EALREADY, EINPROGRESS, EWOULDBLOCK, EINVAL
import eventlet
from eventlet.green import select, socket
from eventlet.queue import Queue
from functools import partial
import logging
import os
from threading import Event
from six.moves import xrange
from cassandra import OperationTimedOut
from cassandra.connection import Connection, ConnectionShutdown
from cassandra.protocol import RegisterMessage
log = logging.getLogger(__name__)
def is_timeout(err):
return (
err in (EINPROGRESS, EALREADY, EWOULDBLOCK) or
(err == EINVAL and os.name in ('nt', 'ce'))
)
class EventletConnection(Connection):
"""
An implementation of :class:`.Connection` that utilizes ``eventlet``.
"""
_total_reqd_bytes = 0
_read_watcher = None
_write_watcher = None
_socket = None
@classmethod
def initialize_reactor(cls):
eventlet.monkey_patch()
@classmethod
def factory(cls, *args, **kwargs):
timeout = kwargs.pop('timeout', 5.0)
conn = cls(*args, **kwargs)
conn.connected_event.wait(timeout)
if conn.last_error:
raise conn.last_error
elif not conn.connected_event.is_set():
conn.close()
raise OperationTimedOut("Timed out creating connection")
else:
return conn
def __init__(self, *args, **kwargs):
Connection.__init__(self, *args, **kwargs)
self.connected_event = Event()
self._write_queue = Queue()
self._callbacks = {}
self._push_watchers = defaultdict(set)
sockerr = None
addresses = socket.getaddrinfo(
self.host, self.port, socket.AF_UNSPEC, socket.SOCK_STREAM
)
for (af, socktype, proto, canonname, sockaddr) in addresses:
try:
self._socket = socket.socket(af, socktype, proto)
self._socket.settimeout(1.0)
self._socket.connect(sockaddr)
sockerr = None
break
except socket.error as err:
sockerr = err
if sockerr:
raise socket.error(
sockerr.errno,
"Tried connecting to %s. Last error: %s" % (
[a[4] for a in addresses], sockerr.strerror)
)
if self.sockopts:
for args in self.sockopts:
self._socket.setsockopt(*args)
self._read_watcher = eventlet.spawn(lambda: self.handle_read())
self._write_watcher = eventlet.spawn(lambda: self.handle_write())
self._send_options_message()
def close(self):
with self.lock:
if self.is_closed:
return
self.is_closed = True
log.debug("Closing connection (%s) to %s" % (id(self), self.host))
cur_gthread = eventlet.getcurrent()
if self._read_watcher and self._read_watcher != cur_gthread:
self._read_watcher.kill()
if self._write_watcher and self._write_watcher != cur_gthread:
self._write_watcher.kill()
if self._socket:
self._socket.close()
log.debug("Closed socket to %s" % (self.host,))
if not self.is_defunct:
self.error_all_callbacks(
ConnectionShutdown("Connection to %s was closed" % self.host))
# don't leave in-progress operations hanging
self.connected_event.set()
def handle_close(self):
log.debug("connection closed by server")
self.close()
def handle_write(self):
while True:
try:
next_msg = self._write_queue.get()
self._socket.sendall(next_msg)
except socket.error as err:
log.debug("Exception during socket send for %s: %s", self, err)
self.defunct(err)
return # Leave the write loop
def handle_read(self):
run_select = partial(select.select, (self._socket,), (), ())
while True:
try:
run_select()
except Exception as exc:
if not self.is_closed:
log.debug("Exception during read select() for %s: %s",
self, exc)
self.defunct(exc)
return
try:
buf = self._socket.recv(self.in_buffer_size)
self._iobuf.write(buf)
except socket.error as err:
if not is_timeout(err):
log.debug("Exception during socket recv for %s: %s",
self, err)
self.defunct(err)
return # leave the read loop
if self._iobuf.tell():
self.process_io_buffer()
else:
log.debug("Connection %s closed by server", self)
self.close()
return
def push(self, data):
chunk_size = self.out_buffer_size
for i in xrange(0, len(data), chunk_size):
self._write_queue.put(data[i:i + chunk_size])
def register_watcher(self, event_type, callback, register_timeout=None):
self._push_watchers[event_type].add(callback)
self.wait_for_response(
RegisterMessage(event_list=[event_type]),
timeout=register_timeout)
def register_watchers(self, type_callback_dict, register_timeout=None):
for event_type, callback in type_callback_dict.items():
self._push_watchers[event_type].add(callback)
self.wait_for_response(
RegisterMessage(event_list=type_callback_dict.keys()),
timeout=register_timeout)

View File

@@ -211,7 +211,7 @@ class Metadata(object):
return KeyspaceMetadata(name, durable_writes, strategy_class, strategy_options) return KeyspaceMetadata(name, durable_writes, strategy_class, strategy_options)
def _build_usertype(self, keyspace, usertype_row): def _build_usertype(self, keyspace, usertype_row):
type_classes = map(types.lookup_casstype, usertype_row['field_types']) type_classes = list(map(types.lookup_casstype, usertype_row['field_types']))
return UserType(usertype_row['keyspace_name'], usertype_row['type_name'], return UserType(usertype_row['keyspace_name'], usertype_row['type_name'],
usertype_row['field_names'], type_classes) usertype_row['field_names'], type_classes)

View File

@@ -355,15 +355,21 @@ class HostConnection(object):
return return
def connection_finished_setting_keyspace(conn, error): def connection_finished_setting_keyspace(conn, error):
self.return_connection(conn)
errors = [] if not error else [error] errors = [] if not error else [error]
callback(self, errors) callback(self, errors)
self._connection.set_keyspace_async(keyspace, connection_finished_setting_keyspace) self._connection.set_keyspace_async(keyspace, connection_finished_setting_keyspace)
def get_connections(self):
c = self._connection
return [c] if c else []
def get_state(self): def get_state(self):
have_conn = self._connection is not None connection = self._connection
in_flight = self._connection.in_flight if have_conn else 0 open_count = 1 if connection and not (connection.is_closed or connection.is_defunct) else 0
return "shutdown: %s, open: %s, in_flights: %s" % (self.is_shutdown, have_conn, in_flight) in_flights = [connection.in_flight] if connection else []
return {'shutdown': self.is_shutdown, 'open_count': open_count, 'in_flights': in_flights}
_MAX_SIMULTANEOUS_CREATION = 1 _MAX_SIMULTANEOUS_CREATION = 1
@@ -683,6 +689,7 @@ class HostConnectionPool(object):
return return
def connection_finished_setting_keyspace(conn, error): def connection_finished_setting_keyspace(conn, error):
self.return_connection(conn)
remaining_callbacks.remove(conn) remaining_callbacks.remove(conn)
if error: if error:
errors.append(error) errors.append(error)
@@ -693,6 +700,9 @@ class HostConnectionPool(object):
for conn in self._connections: for conn in self._connections:
conn.set_keyspace_async(keyspace, connection_finished_setting_keyspace) conn.set_keyspace_async(keyspace, connection_finished_setting_keyspace)
def get_connections(self):
return self._connections
def get_state(self): def get_state(self):
in_flights = ", ".join([str(c.in_flight) for c in self._connections]) in_flights = [c.in_flight for c in self._connections]
return "shutdown: %s, open_count: %d, in_flights: %s" % (self.is_shutdown, self.open_count, in_flights) return {'shutdown': self.is_shutdown, 'open_count': self.open_count, 'in_flights': in_flights}

View File

@@ -632,14 +632,14 @@ class ResultMessage(_MessageType):
typeclass = typeclass.apply_parameters((keysubtype, valsubtype)) typeclass = typeclass.apply_parameters((keysubtype, valsubtype))
elif typeclass == TupleType: elif typeclass == TupleType:
num_items = read_short(f) num_items = read_short(f)
types = tuple(cls.read_type(f, user_type_map) for _ in xrange(num_items)) types = tuple(cls.read_type(f, user_type_map) for _ in range(num_items))
typeclass = typeclass.apply_parameters(types) typeclass = typeclass.apply_parameters(types)
elif typeclass == UserType: elif typeclass == UserType:
ks = read_string(f) ks = read_string(f)
udt_name = read_string(f) udt_name = read_string(f)
num_fields = read_short(f) num_fields = read_short(f)
names_and_types = tuple((read_string(f), cls.read_type(f, user_type_map)) names_and_types = tuple((read_string(f), cls.read_type(f, user_type_map))
for _ in xrange(num_fields)) for _ in range(num_fields))
mapped_class = user_type_map.get(ks, {}).get(udt_name) mapped_class = user_type_map.get(ks, {}).get(udt_name)
typeclass = typeclass.make_udt_class( typeclass = typeclass.make_udt_class(
ks, udt_name, names_and_types, mapped_class) ks, udt_name, names_and_types, mapped_class)

View File

@@ -500,13 +500,13 @@ class BoundStatement(Statement):
try: try:
self.values.append(col_type.serialize(value, proto_version)) self.values.append(col_type.serialize(value, proto_version))
except (TypeError, struct.error): except (TypeError, struct.error) as exc:
col_name = col_spec[2] col_name = col_spec[2]
expected_type = col_type expected_type = col_type
actual_type = type(value) actual_type = type(value)
message = ('Received an argument of invalid type for column "%s". ' message = ('Received an argument of invalid type for column "%s". '
'Expected: %s, Got: %s' % (col_name, expected_type, actual_type)) 'Expected: %s, Got: %s; (%s)' % (col_name, expected_type, actual_type, exc))
raise TypeError(message) raise TypeError(message)
return self return self

View File

@@ -555,3 +555,101 @@ except ImportError:
if item in other: if item in other:
isect.add(item) isect.add(item)
return isect return isect
from collections import Mapping
import six
from six.moves import cPickle
class OrderedMap(Mapping):
'''
An ordered map that accepts non-hashable types for keys. It also maintains the
insertion order of items, behaving as OrderedDict in that regard. These maps
are constructed and read just as normal mapping types, exept that they may
contain arbitrary collections and other non-hashable items as keys::
>>> od = OrderedMap([({'one': 1, 'two': 2}, 'value'),
... ({'three': 3, 'four': 4}, 'value2')])
>>> list(od.keys())
[{'two': 2, 'one': 1}, {'three': 3, 'four': 4}]
>>> list(od.values())
['value', 'value2']
These constructs are needed to support nested collections in Cassandra 2.1.3+,
where frozen collections can be specified as parameters to others\*::
CREATE TABLE example (
...
value map<frozen<map<int, int>>, double>
...
)
This class dervies from the (immutable) Mapping API. Objects in these maps
are not intended be modified.
\* Note: Because of the way Cassandra encodes nested types, when using the
driver with nested collections, :attr:`~.Cluster.protocol_version` must be 3
or higher.
'''
def __init__(self, *args, **kwargs):
if len(args) > 1:
raise TypeError('expected at most 1 arguments, got %d' % len(args))
self._items = []
self._index = {}
if args:
e = args[0]
if callable(getattr(e, 'keys', None)):
for k in e.keys():
self._items.append((k, e[k]))
else:
for k, v in e:
self._insert(k, v)
for k, v in six.iteritems(kwargs):
self._insert(k, v)
def _insert(self, key, value):
flat_key = self._serialize_key(key)
i = self._index.get(flat_key, -1)
if i >= 0:
self._items[i] = (key, value)
else:
self._items.append((key, value))
self._index[flat_key] = len(self._items) - 1
def __getitem__(self, key):
index = self._index[self._serialize_key(key)]
return self._items[index][1]
def __iter__(self):
for i in self._items:
yield i[0]
def __len__(self):
return len(self._items)
def __eq__(self, other):
if isinstance(other, OrderedMap):
return self._items == other._items
try:
d = dict(other)
return len(d) == len(self._items) and all(i[1] == d[i[0]] for i in self._items)
except KeyError:
return False
except TypeError:
pass
return NotImplemented
def __repr__(self):
return '%s([%s])' % (
self.__class__.__name__,
', '.join("(%r, %r)" % (k, v) for k, v in self._items))
def __str__(self):
return '{%s}' % ', '.join("%s: %s" % (k, v) for k, v in self._items)
@staticmethod
def _serialize_key(key):
return cPickle.dumps(key)

View File

@@ -14,3 +14,9 @@
.. autoclass:: PlainTextAuthenticator .. autoclass:: PlainTextAuthenticator
:members: :members:
.. autoclass:: SaslAuthProvider
:members:
.. autoclass:: SaslAuthenticator
:members:

View File

@@ -39,6 +39,12 @@
.. autoattribute:: control_connection_timeout .. autoattribute:: control_connection_timeout
.. autoattribute:: idle_heartbeat_interval
.. autoattribute:: schema_event_refresh_window
.. autoattribute:: topology_event_refresh_window
.. automethod:: connect .. automethod:: connect
.. automethod:: shutdown .. automethod:: shutdown
@@ -57,6 +63,13 @@
.. automethod:: set_max_connections_per_host .. automethod:: set_max_connections_per_host
.. automethod:: refresh_schema
.. automethod:: refresh_nodes
.. automethod:: set_meta_refresh_enabled
.. autoclass:: Session () .. autoclass:: Session ()
.. autoattribute:: default_timeout .. autoattribute:: default_timeout

View File

@@ -0,0 +1,7 @@
``cassandra.io.eventletreactor`` - ``eventlet``-compatible Connection
=====================================================================
.. module:: cassandra.io.eventletreactor
.. autoclass:: EventletConnection
:members:

View File

@@ -0,0 +1,7 @@
``cassandra.util`` - Utilities
===================================
.. module:: cassandra.util
.. autoclass:: OrderedMap
:members:

View File

@@ -16,7 +16,9 @@ API Documentation
cassandra/decoder cassandra/decoder
cassandra/concurrent cassandra/concurrent
cassandra/connection cassandra/connection
cassandra/util
cassandra/io/asyncorereactor cassandra/io/asyncorereactor
cassandra/io/eventletreactor
cassandra/io/libevreactor cassandra/io/libevreactor
cassandra/io/geventreactor cassandra/io/geventreactor
cassandra/io/twistedreactor cassandra/io/twistedreactor

View File

@@ -34,9 +34,11 @@ to be explicit.
Custom Authenticators Custom Authenticators
^^^^^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^^^
If you're using something other than Cassandra's ``PasswordAuthenticator``, If you're using something other than Cassandra's ``PasswordAuthenticator``,
you may need to create your own subclasses of :class:`~.AuthProvider` and :class:`~.SaslAuthProvider` is provided for generic SASL authentication mechanisms,
:class:`~.Authenticator`. You can use :class:`~.PlainTextAuthProvider` utilizing the ``pure-sasl`` package.
and :class:`~.PlainTextAuthenticator` as example implementations. If these do not suit your needs, you may need to create your own subclasses of
:class:`~.AuthProvider` and :class:`~.Authenticator`. You can use the Sasl classes
as example implementations.
Protocol v1 Authentication Protocol v1 Authentication
^^^^^^^^^^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^^^^^^^^

View File

@@ -20,6 +20,11 @@ if __name__ == '__main__' and sys.argv[1] == "gevent_nosetests":
from gevent.monkey import patch_all from gevent.monkey import patch_all
patch_all() patch_all()
if __name__ == '__main__' and sys.argv[1] == "eventlet_nosetests":
print("Running eventlet tests")
from eventlet import monkey_patch
monkey_patch()
import ez_setup import ez_setup
ez_setup.use_setuptools() ez_setup.use_setuptools()
@@ -51,10 +56,14 @@ try:
from nose.commands import nosetests from nose.commands import nosetests
except ImportError: except ImportError:
gevent_nosetests = None gevent_nosetests = None
eventlet_nosetests = None
else: else:
class gevent_nosetests(nosetests): class gevent_nosetests(nosetests):
description = "run nosetests with gevent monkey patching" description = "run nosetests with gevent monkey patching"
class eventlet_nosetests(nosetests):
description = "run nosetests with eventlet monkey patching"
class DocCommand(Command): class DocCommand(Command):
@@ -174,10 +183,14 @@ On OSX, via homebrew:
def run_setup(extensions): def run_setup(extensions):
kw = {'cmdclass': {'doc': DocCommand}} kw = {'cmdclass': {'doc': DocCommand}}
if gevent_nosetests is not None: if gevent_nosetests is not None:
kw['cmdclass']['gevent_nosetests'] = gevent_nosetests kw['cmdclass']['gevent_nosetests'] = gevent_nosetests
if eventlet_nosetests is not None:
kw['cmdclass']['eventlet_nosetests'] = eventlet_nosetests
if extensions: if extensions:
kw['cmdclass']['build_ext'] = build_extensions kw['cmdclass']['build_ext'] = build_extensions
kw['ext_modules'] = extensions kw['ext_modules'] = extensions

View File

@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
import sys
log = logging.getLogger() log = logging.getLogger()
log.setLevel('DEBUG') log.setLevel('DEBUG')
@@ -21,3 +22,18 @@ if not log.handlers:
handler = logging.StreamHandler() handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter('%(asctime)s %(levelname)s [%(module)s:%(lineno)s]: %(message)s')) handler.setFormatter(logging.Formatter('%(asctime)s %(levelname)s [%(module)s:%(lineno)s]: %(message)s'))
log.addHandler(handler) log.addHandler(handler)
def is_gevent_monkey_patched():
return 'gevent.monkey' in sys.modules
def is_eventlet_monkey_patched():
if 'eventlet.patcher' in sys.modules:
import eventlet
return eventlet.patcher.is_monkey_patched('socket')
return False
def is_monkey_patched():
return is_gevent_monkey_patched() or is_eventlet_monkey_patched()

View File

@@ -13,11 +13,13 @@
# limitations under the License. # limitations under the License.
import logging import logging
import time
from tests.integration import use_singledc, get_cluster, remove_cluster, PROTOCOL_VERSION
from cassandra.cluster import Cluster, NoHostAvailable from cassandra.cluster import Cluster, NoHostAvailable
from cassandra.auth import PlainTextAuthProvider, SASLClient, SaslAuthProvider from cassandra.auth import PlainTextAuthProvider, SASLClient, SaslAuthProvider
from tests.integration import use_singledc, get_cluster, remove_cluster, PROTOCOL_VERSION
from tests.integration.util import assert_quiescent_pool_state
try: try:
import unittest2 as unittest import unittest2 as unittest
@@ -35,7 +37,10 @@ def setup_module():
'authorizer': 'CassandraAuthorizer'} 'authorizer': 'CassandraAuthorizer'}
ccm_cluster.set_configuration_options(config_options) ccm_cluster.set_configuration_options(config_options)
log.debug("Starting ccm test cluster with %s", config_options) log.debug("Starting ccm test cluster with %s", config_options)
ccm_cluster.start(wait_for_binary_proto=True) ccm_cluster.start(wait_for_binary_proto=True, wait_other_notice=True)
# there seems to be some race, with some versions of C* taking longer to
# get the auth (and default user) setup. Sleep here to give it a chance
time.sleep(2)
def teardown_module(): def teardown_module():
@@ -59,12 +64,13 @@ class AuthenticationTests(unittest.TestCase):
:return: authentication object suitable for Cluster.connect() :return: authentication object suitable for Cluster.connect()
""" """
if PROTOCOL_VERSION < 2: if PROTOCOL_VERSION < 2:
return lambda(hostname): dict(username=username, password=password) return lambda hostname: dict(username=username, password=password)
else: else:
return PlainTextAuthProvider(username=username, password=password) return PlainTextAuthProvider(username=username, password=password)
def cluster_as(self, usr, pwd): def cluster_as(self, usr, pwd):
return Cluster(protocol_version=PROTOCOL_VERSION, return Cluster(protocol_version=PROTOCOL_VERSION,
idle_heartbeat_interval=0,
auth_provider=self.get_authentication_provider(username=usr, password=pwd)) auth_provider=self.get_authentication_provider(username=usr, password=pwd))
def test_auth_connect(self): def test_auth_connect(self):
@@ -77,9 +83,11 @@ class AuthenticationTests(unittest.TestCase):
cluster = self.cluster_as(user, passwd) cluster = self.cluster_as(user, passwd)
session = cluster.connect() session = cluster.connect()
self.assertTrue(session.execute('SELECT release_version FROM system.local')) self.assertTrue(session.execute('SELECT release_version FROM system.local'))
assert_quiescent_pool_state(self, cluster)
cluster.shutdown() cluster.shutdown()
root_session.execute('DROP USER %s', user) root_session.execute('DROP USER %s', user)
assert_quiescent_pool_state(self, root_session.cluster)
root_session.cluster.shutdown() root_session.cluster.shutdown()
def test_connect_wrong_pwd(self): def test_connect_wrong_pwd(self):
@@ -88,6 +96,8 @@ class AuthenticationTests(unittest.TestCase):
'.*AuthenticationFailed.*Bad credentials.*Username and/or ' '.*AuthenticationFailed.*Bad credentials.*Username and/or '
'password are incorrect.*', 'password are incorrect.*',
cluster.connect) cluster.connect)
assert_quiescent_pool_state(self, cluster)
cluster.shutdown()
def test_connect_wrong_username(self): def test_connect_wrong_username(self):
cluster = self.cluster_as('wrong_user', 'cassandra') cluster = self.cluster_as('wrong_user', 'cassandra')
@@ -95,6 +105,8 @@ class AuthenticationTests(unittest.TestCase):
'.*AuthenticationFailed.*Bad credentials.*Username and/or ' '.*AuthenticationFailed.*Bad credentials.*Username and/or '
'password are incorrect.*', 'password are incorrect.*',
cluster.connect) cluster.connect)
assert_quiescent_pool_state(self, cluster)
cluster.shutdown()
def test_connect_empty_pwd(self): def test_connect_empty_pwd(self):
cluster = self.cluster_as('Cassandra', '') cluster = self.cluster_as('Cassandra', '')
@@ -102,12 +114,16 @@ class AuthenticationTests(unittest.TestCase):
'.*AuthenticationFailed.*Bad credentials.*Username and/or ' '.*AuthenticationFailed.*Bad credentials.*Username and/or '
'password are incorrect.*', 'password are incorrect.*',
cluster.connect) cluster.connect)
assert_quiescent_pool_state(self, cluster)
cluster.shutdown()
def test_connect_no_auth_provider(self): def test_connect_no_auth_provider(self):
cluster = Cluster(protocol_version=PROTOCOL_VERSION) cluster = Cluster(protocol_version=PROTOCOL_VERSION)
self.assertRaisesRegexp(NoHostAvailable, self.assertRaisesRegexp(NoHostAvailable,
'.*AuthenticationFailed.*Remote end requires authentication.*', '.*AuthenticationFailed.*Remote end requires authentication.*',
cluster.connect) cluster.connect)
assert_quiescent_pool_state(self, cluster)
cluster.shutdown()
class SaslAuthenticatorTests(AuthenticationTests): class SaslAuthenticatorTests(AuthenticationTests):

View File

@@ -12,18 +12,26 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from tests.integration import use_singledc, PROTOCOL_VERSION
try: try:
import unittest2 as unittest import unittest2 as unittest
except ImportError: except ImportError:
import unittest # noqa import unittest # noqa
import cassandra from collections import deque
from cassandra.query import SimpleStatement, TraceUnavailable from mock import patch
from cassandra.policies import RoundRobinPolicy, ExponentialReconnectionPolicy, RetryPolicy, SimpleConvictionPolicy, HostDistance import time
from uuid import uuid4
import cassandra
from cassandra.cluster import Cluster, NoHostAvailable from cassandra.cluster import Cluster, NoHostAvailable
from cassandra.concurrent import execute_concurrent
from cassandra.policies import (RoundRobinPolicy, ExponentialReconnectionPolicy,
RetryPolicy, SimpleConvictionPolicy, HostDistance,
WhiteListRoundRobinPolicy)
from cassandra.query import SimpleStatement, TraceUnavailable
from tests.integration import use_singledc, PROTOCOL_VERSION, get_server_versions
from tests.integration.util import assert_quiescent_pool_state
def setup_module(): def setup_module():
@@ -66,6 +74,8 @@ class ClusterTests(unittest.TestCase):
result = session.execute("SELECT * FROM clustertests.cf0") result = session.execute("SELECT * FROM clustertests.cf0")
self.assertEqual([('a', 'b', 'c')], result) self.assertEqual([('a', 'b', 'c')], result)
session.execute("DROP KEYSPACE clustertests")
cluster.shutdown() cluster.shutdown()
def test_connect_on_keyspace(self): def test_connect_on_keyspace(self):
@@ -202,6 +212,159 @@ class ClusterTests(unittest.TestCase):
self.assertIn("newkeyspace", cluster.metadata.keyspaces) self.assertIn("newkeyspace", cluster.metadata.keyspaces)
session.execute("DROP KEYSPACE newkeyspace")
def test_refresh_schema(self):
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect()
original_meta = cluster.metadata.keyspaces
# full schema refresh, with wait
cluster.refresh_schema()
self.assertIsNot(original_meta, cluster.metadata.keyspaces)
self.assertEqual(original_meta, cluster.metadata.keyspaces)
session.shutdown()
def test_refresh_schema_keyspace(self):
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect()
original_meta = cluster.metadata.keyspaces
original_system_meta = original_meta['system']
# only refresh one keyspace
cluster.refresh_schema(keyspace='system')
current_meta = cluster.metadata.keyspaces
self.assertIs(original_meta, current_meta)
current_system_meta = current_meta['system']
self.assertIsNot(original_system_meta, current_system_meta)
self.assertEqual(original_system_meta.as_cql_query(), current_system_meta.as_cql_query())
session.shutdown()
def test_refresh_schema_table(self):
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect()
original_meta = cluster.metadata.keyspaces
original_system_meta = original_meta['system']
original_system_schema_meta = original_system_meta.tables['schema_columnfamilies']
# only refresh one table
cluster.refresh_schema(keyspace='system', table='schema_columnfamilies')
current_meta = cluster.metadata.keyspaces
current_system_meta = current_meta['system']
current_system_schema_meta = current_system_meta.tables['schema_columnfamilies']
self.assertIs(original_meta, current_meta)
self.assertIs(original_system_meta, current_system_meta)
self.assertIsNot(original_system_schema_meta, current_system_schema_meta)
self.assertEqual(original_system_schema_meta.as_cql_query(), current_system_schema_meta.as_cql_query())
session.shutdown()
def test_refresh_schema_type(self):
if get_server_versions()[0] < (2, 1, 0):
raise unittest.SkipTest('UDTs were introduced in Cassandra 2.1')
if PROTOCOL_VERSION < 3:
raise unittest.SkipTest('UDTs are not specified in change events for protocol v2')
# We may want to refresh types on keyspace change events in that case(?)
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect()
keyspace_name = 'test1rf'
type_name = self._testMethodName
session.execute('CREATE TYPE IF NOT EXISTS %s.%s (one int, two text)' % (keyspace_name, type_name))
original_meta = cluster.metadata.keyspaces
original_test1rf_meta = original_meta[keyspace_name]
original_type_meta = original_test1rf_meta.user_types[type_name]
# only refresh one type
cluster.refresh_schema(keyspace='test1rf', usertype=type_name)
current_meta = cluster.metadata.keyspaces
current_test1rf_meta = current_meta[keyspace_name]
current_type_meta = current_test1rf_meta.user_types[type_name]
self.assertIs(original_meta, current_meta)
self.assertIs(original_test1rf_meta, current_test1rf_meta)
self.assertIsNot(original_type_meta, current_type_meta)
self.assertEqual(original_type_meta.as_cql_query(), current_type_meta.as_cql_query())
session.shutdown()
def test_refresh_schema_no_wait(self):
contact_points = ['127.0.0.1']
cluster = Cluster(protocol_version=PROTOCOL_VERSION, max_schema_agreement_wait=10,
contact_points=contact_points, load_balancing_policy=WhiteListRoundRobinPolicy(contact_points))
session = cluster.connect()
schema_ver = session.execute("SELECT schema_version FROM system.local WHERE key='local'")[0][0]
# create a schema disagreement
session.execute("UPDATE system.local SET schema_version=%s WHERE key='local'", (uuid4(),))
try:
agreement_timeout = 1
# cluster agreement wait exceeded
c = Cluster(protocol_version=PROTOCOL_VERSION, max_schema_agreement_wait=agreement_timeout)
start_time = time.time()
s = c.connect()
end_time = time.time()
self.assertGreaterEqual(end_time - start_time, agreement_timeout)
self.assertTrue(c.metadata.keyspaces)
# cluster agreement wait used for refresh
original_meta = c.metadata.keyspaces
start_time = time.time()
self.assertRaisesRegexp(Exception, r"Schema was not refreshed.*", c.refresh_schema)
end_time = time.time()
self.assertGreaterEqual(end_time - start_time, agreement_timeout)
self.assertIs(original_meta, c.metadata.keyspaces)
# refresh wait overrides cluster value
original_meta = c.metadata.keyspaces
start_time = time.time()
c.refresh_schema(max_schema_agreement_wait=0)
end_time = time.time()
self.assertLess(end_time - start_time, agreement_timeout)
self.assertIsNot(original_meta, c.metadata.keyspaces)
self.assertEqual(original_meta, c.metadata.keyspaces)
s.shutdown()
refresh_threshold = 0.5
# cluster agreement bypass
c = Cluster(protocol_version=PROTOCOL_VERSION, max_schema_agreement_wait=0)
start_time = time.time()
s = c.connect()
end_time = time.time()
self.assertLess(end_time - start_time, refresh_threshold)
self.assertTrue(c.metadata.keyspaces)
# cluster agreement wait used for refresh
original_meta = c.metadata.keyspaces
start_time = time.time()
c.refresh_schema()
end_time = time.time()
self.assertLess(end_time - start_time, refresh_threshold)
self.assertIsNot(original_meta, c.metadata.keyspaces)
self.assertEqual(original_meta, c.metadata.keyspaces)
# refresh wait overrides cluster value
original_meta = c.metadata.keyspaces
start_time = time.time()
self.assertRaisesRegexp(Exception, r"Schema was not refreshed.*", c.refresh_schema, max_schema_agreement_wait=agreement_timeout)
end_time = time.time()
self.assertGreaterEqual(end_time - start_time, agreement_timeout)
self.assertIs(original_meta, c.metadata.keyspaces)
s.shutdown()
finally:
session.execute("UPDATE system.local SET schema_version=%s WHERE key='local'", (schema_ver,))
session.shutdown()
def test_trace(self): def test_trace(self):
""" """
Ensure trace can be requested for async and non-async queries Ensure trace can be requested for async and non-async queries
@@ -271,3 +434,115 @@ class ClusterTests(unittest.TestCase):
self.assertIn(query, str(future)) self.assertIn(query, str(future))
self.assertIn('result', str(future)) self.assertIn('result', str(future))
def test_idle_heartbeat(self):
interval = 1
cluster = Cluster(protocol_version=PROTOCOL_VERSION, idle_heartbeat_interval=interval)
if PROTOCOL_VERSION < 3:
cluster.set_core_connections_per_host(HostDistance.LOCAL, 1)
session = cluster.connect()
# This test relies on impl details of connection req id management to see if heartbeats
# are being sent. May need update if impl is changed
connection_request_ids = {}
for h in cluster.get_connection_holders():
for c in h.get_connections():
# make sure none are idle (should have startup messages)
self.assertFalse(c.is_idle)
with c.lock:
connection_request_ids[id(c)] = deque(c.request_ids) # copy of request ids
# let two heatbeat intervals pass (first one had startup messages in it)
time.sleep(2 * interval + interval/10.)
connections = [c for holders in cluster.get_connection_holders() for c in holders.get_connections()]
# make sure requests were sent on all connections
for c in connections:
expected_ids = connection_request_ids[id(c)]
expected_ids.rotate(-1)
with c.lock:
self.assertListEqual(list(c.request_ids), list(expected_ids))
# assert idle status
self.assertTrue(all(c.is_idle for c in connections))
# send messages on all connections
statements_and_params = [("SELECT release_version FROM system.local", ())] * len(cluster.metadata.all_hosts())
results = execute_concurrent(session, statements_and_params)
for success, result in results:
self.assertTrue(success)
# assert not idle status
self.assertFalse(any(c.is_idle if not c.is_control_connection else False for c in connections))
# holders include session pools and cc
holders = cluster.get_connection_holders()
self.assertIn(cluster.control_connection, holders)
self.assertEqual(len(holders), len(cluster.metadata.all_hosts()) + 1) # hosts pools, 1 for cc
# include additional sessions
session2 = cluster.connect()
holders = cluster.get_connection_holders()
self.assertIn(cluster.control_connection, holders)
self.assertEqual(len(holders), 2 * len(cluster.metadata.all_hosts()) + 1) # 2 sessions' hosts pools, 1 for cc
cluster._idle_heartbeat.stop()
cluster._idle_heartbeat.join()
assert_quiescent_pool_state(self, cluster)
session.shutdown()
@patch('cassandra.cluster.Cluster.idle_heartbeat_interval', new=0.1)
def test_idle_heartbeat_disabled(self):
self.assertTrue(Cluster.idle_heartbeat_interval)
# heartbeat disabled with '0'
cluster = Cluster(protocol_version=PROTOCOL_VERSION, idle_heartbeat_interval=0)
self.assertEqual(cluster.idle_heartbeat_interval, 0)
session = cluster.connect()
# let two heatbeat intervals pass (first one had startup messages in it)
time.sleep(2 * Cluster.idle_heartbeat_interval)
connections = [c for holders in cluster.get_connection_holders() for c in holders.get_connections()]
# assert not idle status (should never get reset because there is not heartbeat)
self.assertFalse(any(c.is_idle for c in connections))
session.shutdown()
def test_pool_management(self):
# Ensure that in_flight and request_ids quiesce after cluster operations
cluster = Cluster(protocol_version=PROTOCOL_VERSION, idle_heartbeat_interval=0) # no idle heartbeat here, pool management is tested in test_idle_heartbeat
session = cluster.connect()
session2 = cluster.connect()
# prepare
p = session.prepare("SELECT * FROM system.local WHERE key=?")
self.assertTrue(session.execute(p, ('local',)))
# simple
self.assertTrue(session.execute("SELECT * FROM system.local WHERE key='local'"))
# set keyspace
session.set_keyspace('system')
session.set_keyspace('system_traces')
# use keyspace
session.execute('USE system')
session.execute('USE system_traces')
# refresh schema
cluster.refresh_schema()
cluster.refresh_schema(max_schema_agreement_wait=0)
# submit schema refresh
future = cluster.submit_schema_refresh()
future.result()
assert_quiescent_pool_state(self, cluster)
session2.shutdown()
session.shutdown()

View File

@@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from tests.integration import use_singledc, PROTOCOL_VERSION
try: try:
import unittest2 as unittest import unittest2 as unittest
except ImportError: except ImportError:
@@ -21,13 +19,15 @@ except ImportError:
from functools import partial from functools import partial
from six.moves import range from six.moves import range
import sys
from threading import Thread, Event from threading import Thread, Event
from cassandra import ConsistencyLevel, OperationTimedOut from cassandra import ConsistencyLevel, OperationTimedOut
from cassandra.cluster import NoHostAvailable from cassandra.cluster import NoHostAvailable
from cassandra.protocol import QueryMessage
from cassandra.io.asyncorereactor import AsyncoreConnection from cassandra.io.asyncorereactor import AsyncoreConnection
from cassandra.protocol import QueryMessage
from tests import is_monkey_patched
from tests.integration import use_singledc, PROTOCOL_VERSION
try: try:
from cassandra.io.libevreactor import LibevConnection from cassandra.io.libevreactor import LibevConnection
@@ -230,8 +230,8 @@ class AsyncoreConnectionTests(ConnectionTests, unittest.TestCase):
klass = AsyncoreConnection klass = AsyncoreConnection
def setUp(self): def setUp(self):
if 'gevent.monkey' in sys.modules: if is_monkey_patched():
raise unittest.SkipTest("Can't test asyncore with gevent monkey patching") raise unittest.SkipTest("Can't test asyncore with monkey patching")
ConnectionTests.setUp(self) ConnectionTests.setUp(self)
@@ -240,8 +240,8 @@ class LibevConnectionTests(ConnectionTests, unittest.TestCase):
klass = LibevConnection klass = LibevConnection
def setUp(self): def setUp(self):
if 'gevent.monkey' in sys.modules: if is_monkey_patched():
raise unittest.SkipTest("Can't test libev with gevent monkey patching") raise unittest.SkipTest("Can't test libev with monkey patching")
if LibevConnection is None: if LibevConnection is None:
raise unittest.SkipTest( raise unittest.SkipTest(
'libev does not appear to be installed properly') 'libev does not appear to be installed properly')

View File

@@ -11,15 +11,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import six
import difflib
try: try:
import unittest2 as unittest import unittest2 as unittest
except ImportError: except ImportError:
import unittest # noqa import unittest # noqa
import difflib
from mock import Mock from mock import Mock
import six
import sys
from cassandra import AlreadyExists from cassandra import AlreadyExists
@@ -361,6 +362,9 @@ class TestCodeCoverage(unittest.TestCase):
"Protocol 3.0+ is required for UDT change events, currently testing against %r" "Protocol 3.0+ is required for UDT change events, currently testing against %r"
% (PROTOCOL_VERSION,)) % (PROTOCOL_VERSION,))
if sys.version_info[2:] != (2, 7):
raise unittest.SkipTest('This test compares static strings generated from dict items, which may change orders. Test with 2.7.')
cluster = Cluster(protocol_version=PROTOCOL_VERSION) cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect() session = cluster.connect()
@@ -549,6 +553,9 @@ CREATE TABLE export_udts.users (
if get_server_versions()[0] < (2, 1, 0): if get_server_versions()[0] < (2, 1, 0):
raise unittest.SkipTest('Test schema output assumes 2.1.0+ options') raise unittest.SkipTest('Test schema output assumes 2.1.0+ options')
if sys.version_info[2:] != (2, 7):
raise unittest.SkipTest('This test compares static strings generated from dict items, which may change orders. Test with 2.7.')
cli_script = """CREATE KEYSPACE legacy cli_script = """CREATE KEYSPACE legacy
WITH placement_strategy = 'SimpleStrategy' WITH placement_strategy = 'SimpleStrategy'
AND strategy_options = {replication_factor:1}; AND strategy_options = {replication_factor:1};
@@ -633,7 +640,7 @@ create column family composite_comp_with_col
index_type : 0}] index_type : 0}]
and compression_options = {'sstable_compression' : 'org.apache.cassandra.io.compress.LZ4Compressor'};""" and compression_options = {'sstable_compression' : 'org.apache.cassandra.io.compress.LZ4Compressor'};"""
# note: the innerlkey type for legacy.nested_composite_key # note: the inner key type for legacy.nested_composite_key
# (org.apache.cassandra.db.marshal.CompositeType(org.apache.cassandra.db.marshal.UUIDType, org.apache.cassandra.db.marshal.UTF8Type)) # (org.apache.cassandra.db.marshal.CompositeType(org.apache.cassandra.db.marshal.UUIDType, org.apache.cassandra.db.marshal.UTF8Type))
# is a bit strange, but it replays in CQL with desired results # is a bit strange, but it replays in CQL with desired results
expected_string = """CREATE KEYSPACE legacy WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'} AND durable_writes = true; expected_string = """CREATE KEYSPACE legacy WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'} AND durable_writes = true;
@@ -807,6 +814,7 @@ CREATE TABLE legacy.composite_comp_no_col (
cluster.shutdown() cluster.shutdown()
class TokenMetadataTest(unittest.TestCase): class TokenMetadataTest(unittest.TestCase):
""" """
Test of TokenMap creation and other behavior. Test of TokenMap creation and other behavior.

View File

@@ -21,8 +21,10 @@ except ImportError:
import logging import logging
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
from collections import namedtuple
from decimal import Decimal from decimal import Decimal
from datetime import datetime, date, time from datetime import datetime, date, time
from functools import partial
import six import six
from uuid import uuid1, uuid4 from uuid import uuid1, uuid4
@@ -30,10 +32,14 @@ from cassandra import InvalidRequest
from cassandra.cluster import Cluster from cassandra.cluster import Cluster
from cassandra.cqltypes import Int32Type, EMPTY from cassandra.cqltypes import Int32Type, EMPTY
from cassandra.query import dict_factory from cassandra.query import dict_factory
from cassandra.util import OrderedDict, sortedset from cassandra.util import OrderedMap, sortedset
from tests.integration import get_server_versions, use_singledc, PROTOCOL_VERSION from tests.integration import get_server_versions, use_singledc, PROTOCOL_VERSION
# defined in module scope for pickling in OrderedMap
nested_collection_udt = namedtuple('nested_collection_udt', ['m', 't', 'l', 's'])
nested_collection_udt_nested = namedtuple('nested_collection_udt_nested', ['m', 't', 'l', 's', 'u'])
def setup_module(): def setup_module():
use_singledc() use_singledc()
@@ -287,11 +293,11 @@ class TypeTests(unittest.TestCase):
s.execute("INSERT INTO mytable (a, b, c, o, s, l, n) VALUES ('a', 'b', %s, %s, %s, %s, %s)", s.execute("INSERT INTO mytable (a, b, c, o, s, l, n) VALUES ('a', 'b', %s, %s, %s, %s, %s)",
('', '', '', [''], {'': 3})) ('', '', '', [''], {'': 3}))
self.assertEqual( self.assertEqual(
{'c': '', 'o': '', 's': '', 'l': [''], 'n': OrderedDict({'': 3})}, {'c': '', 'o': '', 's': '', 'l': [''], 'n': OrderedMap({'': 3})},
s.execute("SELECT c, o, s, l, n FROM mytable WHERE a='a' AND b='b'")[0]) s.execute("SELECT c, o, s, l, n FROM mytable WHERE a='a' AND b='b'")[0])
self.assertEqual( self.assertEqual(
{'c': '', 'o': '', 's': '', 'l': [''], 'n': OrderedDict({'': 3})}, {'c': '', 'o': '', 's': '', 'l': [''], 'n': OrderedMap({'': 3})},
s.execute(s.prepare("SELECT c, o, s, l, n FROM mytable WHERE a='a' AND b='b'"), [])[0]) s.execute(s.prepare("SELECT c, o, s, l, n FROM mytable WHERE a='a' AND b='b'"), [])[0])
# non-string types shouldn't accept empty strings # non-string types shouldn't accept empty strings
@@ -708,10 +714,10 @@ class TypeTests(unittest.TestCase):
self.assertEquals((None, None, None, None), s.execute(read)[0].t) self.assertEquals((None, None, None, None), s.execute(read)[0].t)
# also test empty strings where compatible # also test empty strings where compatible
s.execute(insert, [('', None, None, '')]) s.execute(insert, [('', None, None, b'')])
result = s.execute("SELECT * FROM mytable WHERE k=0") result = s.execute("SELECT * FROM mytable WHERE k=0")
self.assertEquals(('', None, None, ''), result[0].t) self.assertEquals(('', None, None, b''), result[0].t)
self.assertEquals(('', None, None, ''), s.execute(read)[0].t) self.assertEquals(('', None, None, b''), s.execute(read)[0].t)
c.shutdown() c.shutdown()
@@ -721,3 +727,61 @@ class TypeTests(unittest.TestCase):
query = u"SELECT * FROM system.schema_columnfamilies WHERE keyspace_name = 'ef\u2052ef' AND columnfamily_name = %s" query = u"SELECT * FROM system.schema_columnfamilies WHERE keyspace_name = 'ef\u2052ef' AND columnfamily_name = %s"
s.execute(query, (u"fe\u2051fe",)) s.execute(query, (u"fe\u2051fe",))
def insert_select_column(self, session, table_name, column_name, value):
insert = session.prepare("INSERT INTO %s (k, %s) VALUES (?, ?)" % (table_name, column_name))
session.execute(insert, (0, value))
result = session.execute("SELECT %s FROM %s WHERE k=%%s" % (column_name, table_name), (0,))[0][0]
self.assertEqual(result, value)
def test_nested_collections(self):
if self._cass_version < (2, 1, 3):
raise unittest.SkipTest("Support for nested collections was introduced in Cassandra 2.1.3")
name = self._testMethodName
c = Cluster(protocol_version=PROTOCOL_VERSION)
s = c.connect('test1rf')
s.encoder.mapping[tuple] = s.encoder.cql_encode_tuple
s.execute("""
CREATE TYPE %s (
m frozen<map<int,text>>,
t tuple<int,text>,
l frozen<list<int>>,
s frozen<set<int>>
)""" % name)
s.execute("""
CREATE TYPE %s_nested (
m frozen<map<int,text>>,
t tuple<int,text>,
l frozen<list<int>>,
s frozen<set<int>>,
u frozen<%s>
)""" % (name, name))
s.execute("""
CREATE TABLE %s (
k int PRIMARY KEY,
map_map map<frozen<map<int,int>>, frozen<map<int,int>>>,
map_set map<frozen<set<int>>, frozen<set<int>>>,
map_list map<frozen<list<int>>, frozen<list<int>>>,
map_tuple map<frozen<tuple<int, int>>, frozen<tuple<int>>>,
map_udt map<frozen<%s_nested>, frozen<%s>>,
)""" % (name, name, name))
validate = partial(self.insert_select_column, s, name)
validate('map_map', OrderedMap([({1: 1, 2: 2}, {3: 3, 4: 4}), ({5: 5, 6: 6}, {7: 7, 8: 8})]))
validate('map_set', OrderedMap([(set((1, 2)), set((3, 4))), (set((5, 6)), set((7, 8)))]))
validate('map_list', OrderedMap([([1, 2], [3, 4]), ([5, 6], [7, 8])]))
validate('map_tuple', OrderedMap([((1, 2), (3,)), ((4, 5), (6,))]))
value = nested_collection_udt({1: 'v1', 2: 'v2'}, (3, 'v3'), [4, 5, 6, 7], set((8, 9, 10)))
key = nested_collection_udt_nested(value.m, value.t, value.l, value.s, value)
key2 = nested_collection_udt_nested({3: 'v3'}, value.t, value.l, value.s, value)
validate('map_udt', OrderedMap([(key, value), (key2, value)]))
s.execute("DROP TABLE %s" % (name))
s.execute("DROP TYPE %s_nested" % (name))
s.execute("DROP TYPE %s" % (name))
s.shutdown()

37
tests/integration/util.py Normal file
View File

@@ -0,0 +1,37 @@
# Copyright 2013-2014 DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from tests.integration import PROTOCOL_VERSION
def assert_quiescent_pool_state(test_case, cluster):
for session in cluster.sessions:
pool_states = session.get_pool_state().values()
test_case.assertTrue(pool_states)
for state in pool_states:
test_case.assertFalse(state['shutdown'])
test_case.assertGreater(state['open_count'], 0)
test_case.assertTrue(all((i == 0 for i in state['in_flights'])))
for holder in cluster.get_connection_holders():
for connection in holder.get_connections():
# all ids are unique
req_ids = connection.request_ids
test_case.assertEqual(len(req_ids), len(set(req_ids)))
test_case.assertEqual(connection.highest_request_id, len(req_ids) - 1)
test_case.assertEqual(connection.highest_request_id, max(req_ids))
if PROTOCOL_VERSION < 3:
test_case.assertEqual(connection.highest_request_id, connection.max_request_id)

View File

@@ -31,20 +31,20 @@ from mock import patch, Mock
from cassandra.connection import (HEADER_DIRECTION_TO_CLIENT, from cassandra.connection import (HEADER_DIRECTION_TO_CLIENT,
ConnectionException) ConnectionException)
from cassandra.io.asyncorereactor import AsyncoreConnection
from cassandra.protocol import (write_stringmultimap, write_int, write_string, from cassandra.protocol import (write_stringmultimap, write_int, write_string,
SupportedMessage, ReadyMessage, ServerError) SupportedMessage, ReadyMessage, ServerError)
from cassandra.marshal import uint8_pack, uint32_pack, int32_pack from cassandra.marshal import uint8_pack, uint32_pack, int32_pack
from cassandra.io.asyncorereactor import AsyncoreConnection from tests import is_monkey_patched
class AsyncoreConnectionTest(unittest.TestCase): class AsyncoreConnectionTest(unittest.TestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
if 'gevent.monkey' in sys.modules: if is_monkey_patched():
raise unittest.SkipTest("gevent monkey-patching detected") raise unittest.SkipTest("monkey-patching detected")
AsyncoreConnection.initialize_reactor() AsyncoreConnection.initialize_reactor()
cls.socket_patcher = patch('socket.socket', spec=socket.socket) cls.socket_patcher = patch('socket.socket', spec=socket.socket)
cls.mock_socket = cls.socket_patcher.start() cls.mock_socket = cls.socket_patcher.start()

View File

@@ -11,21 +11,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import six
try: try:
import unittest2 as unittest import unittest2 as unittest
except ImportError: except ImportError:
import unittest # noqa import unittest # noqa
from mock import Mock, ANY, call
import six
from six import BytesIO from six import BytesIO
import time
from mock import Mock, ANY from threading import Lock
from cassandra.cluster import Cluster from cassandra.cluster import Cluster
from cassandra.connection import (Connection, HEADER_DIRECTION_TO_CLIENT, from cassandra.connection import (Connection, HEADER_DIRECTION_TO_CLIENT,
HEADER_DIRECTION_FROM_CLIENT, ProtocolError, HEADER_DIRECTION_FROM_CLIENT, ProtocolError,
locally_supported_compressions) locally_supported_compressions, ConnectionHeartbeat)
from cassandra.marshal import uint8_pack, uint32_pack from cassandra.marshal import uint8_pack, uint32_pack
from cassandra.protocol import (write_stringmultimap, write_int, write_string, from cassandra.protocol import (write_stringmultimap, write_int, write_string,
SupportedMessage) SupportedMessage)
@@ -280,3 +280,155 @@ class ConnectionTest(unittest.TestCase):
def test_set_connection_class(self): def test_set_connection_class(self):
cluster = Cluster(connection_class='test') cluster = Cluster(connection_class='test')
self.assertEqual('test', cluster.connection_class) self.assertEqual('test', cluster.connection_class)
class ConnectionHeartbeatTest(unittest.TestCase):
@staticmethod
def make_get_holders(len):
holders = []
for _ in range(len):
holder = Mock()
holder.get_connections = Mock(return_value=[])
holders.append(holder)
get_holders = Mock(return_value=holders)
return get_holders
def run_heartbeat(self, get_holders_fun, count=2, interval=0.05):
ch = ConnectionHeartbeat(interval, get_holders_fun)
time.sleep(interval * count)
ch.stop()
ch.join()
self.assertTrue(get_holders_fun.call_count)
def test_empty_connections(self):
count = 3
get_holders = self.make_get_holders(1)
self.run_heartbeat(get_holders, count)
self.assertGreaterEqual(get_holders.call_count, count - 1) # lower bound to account for thread spinup time
self.assertLessEqual(get_holders.call_count, count)
holder = get_holders.return_value[0]
holder.get_connections.assert_has_calls([call()] * get_holders.call_count)
def test_idle_non_idle(self):
request_id = 999
# connection.send_msg(OptionsMessage(), connection.get_request_id(), self._options_callback)
def send_msg(msg, req_id, msg_callback):
msg_callback(SupportedMessage([], {}))
idle_connection = Mock(spec=Connection, host='localhost',
max_request_id=127,
lock=Lock(),
in_flight=0, is_idle=True,
is_defunct=False, is_closed=False,
get_request_id=lambda: request_id,
send_msg=Mock(side_effect=send_msg))
non_idle_connection = Mock(spec=Connection, in_flight=0, is_idle=False, is_defunct=False, is_closed=False)
get_holders = self.make_get_holders(1)
holder = get_holders.return_value[0]
holder.get_connections.return_value.append(idle_connection)
holder.get_connections.return_value.append(non_idle_connection)
self.run_heartbeat(get_holders)
holder.get_connections.assert_has_calls([call()] * get_holders.call_count)
self.assertEqual(idle_connection.in_flight, 0)
self.assertEqual(non_idle_connection.in_flight, 0)
idle_connection.send_msg.assert_has_calls([call(ANY, request_id, ANY)] * get_holders.call_count)
self.assertEqual(non_idle_connection.send_msg.call_count, 0)
def test_closed_defunct(self):
get_holders = self.make_get_holders(1)
closed_connection = Mock(spec=Connection, in_flight=0, is_idle=False, is_defunct=False, is_closed=True)
defunct_connection = Mock(spec=Connection, in_flight=0, is_idle=False, is_defunct=True, is_closed=False)
holder = get_holders.return_value[0]
holder.get_connections.return_value.append(closed_connection)
holder.get_connections.return_value.append(defunct_connection)
self.run_heartbeat(get_holders)
holder.get_connections.assert_has_calls([call()] * get_holders.call_count)
self.assertEqual(closed_connection.in_flight, 0)
self.assertEqual(defunct_connection.in_flight, 0)
self.assertEqual(closed_connection.send_msg.call_count, 0)
self.assertEqual(defunct_connection.send_msg.call_count, 0)
def test_no_req_ids(self):
in_flight = 3
get_holders = self.make_get_holders(1)
max_connection = Mock(spec=Connection, host='localhost',
max_request_id=in_flight, in_flight=in_flight,
is_idle=True, is_defunct=False, is_closed=False)
holder = get_holders.return_value[0]
holder.get_connections.return_value.append(max_connection)
self.run_heartbeat(get_holders)
holder.get_connections.assert_has_calls([call()] * get_holders.call_count)
self.assertEqual(max_connection.in_flight, in_flight)
self.assertEqual(max_connection.send_msg.call_count, 0)
self.assertEqual(max_connection.send_msg.call_count, 0)
max_connection.defunct.assert_has_calls([call(ANY)] * get_holders.call_count)
holder.return_connection.assert_has_calls([call(max_connection)] * get_holders.call_count)
def test_unexpected_response(self):
request_id = 999
get_holders = self.make_get_holders(1)
def send_msg(msg, req_id, msg_callback):
msg_callback(object())
connection = Mock(spec=Connection, host='localhost',
max_request_id=127,
lock=Lock(),
in_flight=0, is_idle=True,
is_defunct=False, is_closed=False,
get_request_id=lambda: request_id,
send_msg=Mock(side_effect=send_msg))
holder = get_holders.return_value[0]
holder.get_connections.return_value.append(connection)
self.run_heartbeat(get_holders)
self.assertEqual(connection.in_flight, get_holders.call_count)
connection.send_msg.assert_has_calls([call(ANY, request_id, ANY)] * get_holders.call_count)
connection.defunct.assert_has_calls([call(ANY)] * get_holders.call_count)
exc = connection.defunct.call_args_list[0][0][0]
self.assertIsInstance(exc, Exception)
self.assertEqual(exc.args, Exception('Connection heartbeat failure').args)
holder.return_connection.assert_has_calls([call(connection)] * get_holders.call_count)
def test_timeout(self):
request_id = 999
get_holders = self.make_get_holders(1)
def send_msg(msg, req_id, msg_callback):
pass
connection = Mock(spec=Connection, host='localhost',
max_request_id=127,
lock=Lock(),
in_flight=0, is_idle=True,
is_defunct=False, is_closed=False,
get_request_id=lambda: request_id,
send_msg=Mock(side_effect=send_msg))
holder = get_holders.return_value[0]
holder.get_connections.return_value.append(connection)
self.run_heartbeat(get_holders)
self.assertEqual(connection.in_flight, get_holders.call_count)
connection.send_msg.assert_has_calls([call(ANY, request_id, ANY)] * get_holders.call_count)
connection.defunct.assert_has_calls([call(ANY)] * get_holders.call_count)
exc = connection.defunct.call_args_list[0][0][0]
self.assertIsInstance(exc, Exception)
self.assertEqual(exc.args, Exception('Connection heartbeat failure').args)
holder.return_connection.assert_has_calls([call(connection)] * get_holders.call_count)

View File

@@ -73,7 +73,7 @@ class MockCluster(object):
self.scheduler = Mock(spec=_Scheduler) self.scheduler = Mock(spec=_Scheduler)
self.executor = Mock(spec=ThreadPoolExecutor) self.executor = Mock(spec=ThreadPoolExecutor)
def add_host(self, address, datacenter, rack, signal=False): def add_host(self, address, datacenter, rack, signal=False, refresh_nodes=True):
host = Host(address, SimpleConvictionPolicy, datacenter, rack) host = Host(address, SimpleConvictionPolicy, datacenter, rack)
self.added_hosts.append(host) self.added_hosts.append(host)
return host return host
@@ -131,7 +131,7 @@ class ControlConnectionTest(unittest.TestCase):
self.connection = MockConnection() self.connection = MockConnection()
self.time = FakeTime() self.time = FakeTime()
self.control_connection = ControlConnection(self.cluster, timeout=1) self.control_connection = ControlConnection(self.cluster, 1, 0, 0)
self.control_connection._connection = self.connection self.control_connection._connection = self.connection
self.control_connection._time = self.time self.control_connection._time = self.time
@@ -345,39 +345,44 @@ class ControlConnectionTest(unittest.TestCase):
'change_type': 'NEW_NODE', 'change_type': 'NEW_NODE',
'address': ('1.2.3.4', 9000) 'address': ('1.2.3.4', 9000)
} }
self.cluster.scheduler.reset_mock()
self.control_connection._handle_topology_change(event) self.control_connection._handle_topology_change(event)
self.cluster.scheduler.schedule.assert_called_with(ANY, self.control_connection.refresh_node_list_and_token_map) self.cluster.scheduler.schedule_unique.assert_called_once_with(ANY, self.control_connection.refresh_node_list_and_token_map)
event = { event = {
'change_type': 'REMOVED_NODE', 'change_type': 'REMOVED_NODE',
'address': ('1.2.3.4', 9000) 'address': ('1.2.3.4', 9000)
} }
self.cluster.scheduler.reset_mock()
self.control_connection._handle_topology_change(event) self.control_connection._handle_topology_change(event)
self.cluster.scheduler.schedule.assert_called_with(ANY, self.cluster.remove_host, None) self.cluster.scheduler.schedule_unique.assert_called_once_with(ANY, self.cluster.remove_host, None)
event = { event = {
'change_type': 'MOVED_NODE', 'change_type': 'MOVED_NODE',
'address': ('1.2.3.4', 9000) 'address': ('1.2.3.4', 9000)
} }
self.cluster.scheduler.reset_mock()
self.control_connection._handle_topology_change(event) self.control_connection._handle_topology_change(event)
self.cluster.scheduler.schedule.assert_called_with(ANY, self.control_connection.refresh_node_list_and_token_map) self.cluster.scheduler.schedule_unique.assert_called_once_with(ANY, self.control_connection.refresh_node_list_and_token_map)
def test_handle_status_change(self): def test_handle_status_change(self):
event = { event = {
'change_type': 'UP', 'change_type': 'UP',
'address': ('1.2.3.4', 9000) 'address': ('1.2.3.4', 9000)
} }
self.cluster.scheduler.reset_mock()
self.control_connection._handle_status_change(event) self.control_connection._handle_status_change(event)
self.cluster.scheduler.schedule.assert_called_with(ANY, self.control_connection.refresh_node_list_and_token_map) self.cluster.scheduler.schedule_unique.assert_called_once_with(ANY, self.control_connection.refresh_node_list_and_token_map)
# do the same with a known Host # do the same with a known Host
event = { event = {
'change_type': 'UP', 'change_type': 'UP',
'address': ('192.168.1.0', 9000) 'address': ('192.168.1.0', 9000)
} }
self.cluster.scheduler.reset_mock()
self.control_connection._handle_status_change(event) self.control_connection._handle_status_change(event)
host = self.cluster.metadata.hosts['192.168.1.0'] host = self.cluster.metadata.hosts['192.168.1.0']
self.cluster.scheduler.schedule.assert_called_with(ANY, self.cluster.on_up, host) self.cluster.scheduler.schedule_unique.assert_called_once_with(ANY, self.cluster.on_up, host)
self.cluster.scheduler.schedule.reset_mock() self.cluster.scheduler.schedule.reset_mock()
event = { event = {
@@ -404,9 +409,11 @@ class ControlConnectionTest(unittest.TestCase):
'keyspace': 'ks1', 'keyspace': 'ks1',
'table': 'table1' 'table': 'table1'
} }
self.cluster.scheduler.reset_mock()
self.control_connection._handle_schema_change(event) self.control_connection._handle_schema_change(event)
self.cluster.executor.submit.assert_called_with(self.control_connection.refresh_schema, 'ks1', 'table1', None) self.cluster.scheduler.schedule_unique.assert_called_once_with(0.0, self.control_connection.refresh_schema, 'ks1', 'table1', None)
self.cluster.scheduler.reset_mock()
event['table'] = None event['table'] = None
self.control_connection._handle_schema_change(event) self.control_connection._handle_schema_change(event)
self.cluster.executor.submit.assert_called_with(self.control_connection.refresh_schema, 'ks1', None, None) self.cluster.scheduler.schedule_unique.assert_called_once_with(0.0, self.control_connection.refresh_schema, 'ks1', None, None)

View File

@@ -24,7 +24,7 @@ from decimal import Decimal
from uuid import UUID from uuid import UUID
from cassandra.cqltypes import lookup_casstype from cassandra.cqltypes import lookup_casstype
from cassandra.util import OrderedDict, sortedset from cassandra.util import OrderedMap, sortedset
marshalled_value_pairs = ( marshalled_value_pairs = (
# binary form, type, python native type # binary form, type, python native type
@@ -75,7 +75,7 @@ marshalled_value_pairs = (
(b'', 'MapType(AsciiType, BooleanType)', None), (b'', 'MapType(AsciiType, BooleanType)', None),
(b'', 'ListType(FloatType)', None), (b'', 'ListType(FloatType)', None),
(b'', 'SetType(LongType)', None), (b'', 'SetType(LongType)', None),
(b'\x00\x00', 'MapType(DecimalType, BooleanType)', OrderedDict()), (b'\x00\x00', 'MapType(DecimalType, BooleanType)', OrderedMap()),
(b'\x00\x00', 'ListType(FloatType)', []), (b'\x00\x00', 'ListType(FloatType)', []),
(b'\x00\x00', 'SetType(IntegerType)', sortedset()), (b'\x00\x00', 'SetType(IntegerType)', sortedset()),
(b'\x00\x01\x00\x10\xafYC\xa3\xea<\x11\xe1\xabc\xc4,\x03"y\xf0', 'ListType(TimeUUIDType)', [UUID(bytes=b'\xafYC\xa3\xea<\x11\xe1\xabc\xc4,\x03"y\xf0')]), (b'\x00\x01\x00\x10\xafYC\xa3\xea<\x11\xe1\xabc\xc4,\x03"y\xf0', 'ListType(TimeUUIDType)', [UUID(bytes=b'\xafYC\xa3\xea<\x11\xe1\xabc\xc4,\x03"y\xf0')]),
@@ -84,15 +84,14 @@ marshalled_value_pairs = (
(b'\x00\x00\x00\x00\x00\x00\x00\x01', 'TimeType', 1) (b'\x00\x00\x00\x00\x00\x00\x00\x01', 'TimeType', 1)
) )
ordered_dict_value = OrderedDict() ordered_map_value = OrderedMap([(u'\u307fbob', 199),
ordered_dict_value[u'\u307fbob'] = 199 (u'', -1),
ordered_dict_value[u''] = -1 (u'\\', 0)])
ordered_dict_value[u'\\'] = 0
# these following entries work for me right now, but they're dependent on # these following entries work for me right now, but they're dependent on
# vagaries of internal python ordering for unordered types # vagaries of internal python ordering for unordered types
marshalled_value_pairs_unsafe = ( marshalled_value_pairs_unsafe = (
(b'\x00\x03\x00\x06\xe3\x81\xbfbob\x00\x04\x00\x00\x00\xc7\x00\x00\x00\x04\xff\xff\xff\xff\x00\x01\\\x00\x04\x00\x00\x00\x00', 'MapType(UTF8Type, Int32Type)', ordered_dict_value), (b'\x00\x03\x00\x06\xe3\x81\xbfbob\x00\x04\x00\x00\x00\xc7\x00\x00\x00\x04\xff\xff\xff\xff\x00\x01\\\x00\x04\x00\x00\x00\x00', 'MapType(UTF8Type, Int32Type)', ordered_map_value),
(b'\x00\x02\x00\x08@\x01\x99\x99\x99\x99\x99\x9a\x00\x08@\x14\x00\x00\x00\x00\x00\x00', 'SetType(DoubleType)', sortedset([2.2, 5.0])), (b'\x00\x02\x00\x08@\x01\x99\x99\x99\x99\x99\x9a\x00\x08@\x14\x00\x00\x00\x00\x00\x00', 'SetType(DoubleType)', sortedset([2.2, 5.0])),
(b'\x00', 'IntegerType', 0), (b'\x00', 'IntegerType', 0),
) )

View File

@@ -0,0 +1,127 @@
# Copyright 2013-2014 DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
try:
import unittest2 as unittest
except ImportError:
import unittest # noqa
from cassandra.util import OrderedMap
from cassandra.cqltypes import EMPTY
import six
class OrderedMapTest(unittest.TestCase):
def test_init(self):
a = OrderedMap(zip(['one', 'three', 'two'], [1, 3, 2]))
b = OrderedMap([('one', 1), ('three', 3), ('two', 2)])
c = OrderedMap(a)
builtin = {'one': 1, 'two': 2, 'three': 3}
self.assertEqual(a, b)
self.assertEqual(a, c)
self.assertEqual(a, builtin)
self.assertEqual(OrderedMap([(1, 1), (1, 2)]), {1: 2})
def test_contains(self):
keys = ['first', 'middle', 'last']
od = OrderedMap()
od = OrderedMap(zip(keys, range(len(keys))))
for k in keys:
self.assertTrue(k in od)
self.assertFalse(k not in od)
self.assertTrue('notthere' not in od)
self.assertFalse('notthere' in od)
def test_keys(self):
keys = ['first', 'middle', 'last']
od = OrderedMap(zip(keys, range(len(keys))))
self.assertListEqual(list(od.keys()), keys)
def test_values(self):
keys = ['first', 'middle', 'last']
values = list(range(len(keys)))
od = OrderedMap(zip(keys, values))
self.assertListEqual(list(od.values()), values)
def test_items(self):
keys = ['first', 'middle', 'last']
items = list(zip(keys, range(len(keys))))
od = OrderedMap(items)
self.assertListEqual(list(od.items()), items)
def test_get(self):
keys = ['first', 'middle', 'last']
od = OrderedMap(zip(keys, range(len(keys))))
for v, k in enumerate(keys):
self.assertEqual(od.get(k), v)
self.assertEqual(od.get('notthere', 'default'), 'default')
self.assertIsNone(od.get('notthere'))
def test_equal(self):
d1 = {'one': 1}
d12 = {'one': 1, 'two': 2}
od1 = OrderedMap({'one': 1})
od12 = OrderedMap([('one', 1), ('two', 2)])
od21 = OrderedMap([('two', 2), ('one', 1)])
self.assertEqual(od1, d1)
self.assertEqual(od12, d12)
self.assertEqual(od21, d12)
self.assertNotEqual(od1, od12)
self.assertNotEqual(od12, od1)
self.assertNotEqual(od12, od21)
self.assertNotEqual(od1, d12)
self.assertNotEqual(od12, d1)
self.assertNotEqual(od1, EMPTY)
def test_getitem(self):
keys = ['first', 'middle', 'last']
od = OrderedMap(zip(keys, range(len(keys))))
for v, k in enumerate(keys):
self.assertEqual(od[k], v)
with self.assertRaises(KeyError):
od['notthere']
def test_iter(self):
keys = ['first', 'middle', 'last']
values = list(range(len(keys)))
items = list(zip(keys, values))
od = OrderedMap(items)
itr = iter(od)
self.assertEqual(sum([1 for _ in itr]), len(keys))
self.assertRaises(StopIteration, six.next, itr)
self.assertEqual(list(iter(od)), keys)
self.assertEqual(list(six.iteritems(od)), items)
self.assertEqual(list(six.itervalues(od)), values)
def test_len(self):
self.assertEqual(len(OrderedMap()), 0)
self.assertEqual(len(OrderedMap([(1, 1)])), 1)
def test_mutable_keys(self):
d = {'1': 1}
s = set([1, 2, 3])
od = OrderedMap([(d, 'dict'), (s, 'set')])

View File

@@ -262,11 +262,13 @@ class TypeTests(unittest.TestCase):
self.assertEqual(cassandra_type.val, 'randomvaluetocheck') self.assertEqual(cassandra_type.val, 'randomvaluetocheck')
def test_datetype(self): def test_datetype(self):
now_timestamp = time.time() now_time_seconds = time.time()
now_datetime = datetime.datetime.utcfromtimestamp(now_timestamp) now_datetime = datetime.datetime.utcfromtimestamp(now_time_seconds)
# Cassandra timestamps in millis
now_timestamp = now_time_seconds * 1e3
# same results serialized # same results serialized
# (this could change if we follow up on the timestamp multiplication warning in DateType.serialize)
self.assertEqual(DateType.serialize(now_datetime, 0), DateType.serialize(now_timestamp, 0)) self.assertEqual(DateType.serialize(now_datetime, 0), DateType.serialize(now_timestamp, 0))
# from timestamp # from timestamp