Merge remote-tracking branch 'origin/master' into PYTHON-190

Conflicts:
	cassandra/cqltypes.py
	tests/integration/standard/test_types.py
This commit is contained in:
Adam Holmberg
2015-01-30 13:54:08 -06:00
33 changed files with 1501 additions and 177 deletions

View File

@@ -1,14 +1,22 @@
2.1.3-post
==========
2.1.4
=====
Features
--------
* SaslAuthenticator for Kerberos support (PYTHON-109)
* Heartbeat for network device keepalive and detecting failures on idle connections (PYTHON-197)
* Support nested, frozen collections for Cassandra 2.1.3+ (PYTHON-186)
* Schema agreement wait bypass config, new call for synchronous schema refresh (PYTHON-205)
* Add eventlet connection support (PYTHON-194)
Bug Fixes
---------
* Schema meta fix for complex thrift tables (PYTHON-191)
* Support for 'unknown' replica placement strategies in schema meta (PYTHON-192)
* Resolve stream ID leak on set_keyspace (PYTHON-195)
* Remove implicit timestamp scaling on serialization of numeric timestamps (PYTHON-204)
* Resolve stream id collision when using SASL auth (PYTHON-210)
* Correct unhexlify usage for user defined type meta in Python3 (PYTHON-208)
2.1.3
=====

View File

@@ -23,7 +23,7 @@ class NullHandler(logging.Handler):
logging.getLogger('cassandra').addHandler(NullHandler())
__version_info__ = (2, 1, 3, 'post')
__version_info__ = (2, 1, 4, 'post')
__version__ = '.'.join(map(str, __version_info__))

View File

@@ -130,7 +130,9 @@ class PlainTextAuthenticator(Authenticator):
class SaslAuthProvider(AuthProvider):
"""
An :class:`~.AuthProvider` for Kerberos authenticators
An :class:`~.AuthProvider` supporting general SASL auth mechanisms
Suitable for GSSAPI or other SASL mechanisms
Example usage::
@@ -144,7 +146,7 @@ class SaslAuthProvider(AuthProvider):
auth_provider = SaslAuthProvider(**sasl_kwargs)
cluster = Cluster(auth_provider=auth_provider)
.. versionadded:: 2.1.3-post
.. versionadded:: 2.1.4
"""
def __init__(self, **sasl_kwargs):
@@ -157,9 +159,10 @@ class SaslAuthProvider(AuthProvider):
class SaslAuthenticator(Authenticator):
"""
An :class:`~.Authenticator` that works with DSE's KerberosAuthenticator.
A pass-through :class:`~.Authenticator` using the third party package
'pure-sasl' for authentication
.. versionadded:: 2.1.3-post
.. versionadded:: 2.1.4
"""
def __init__(self, host, service, mechanism='GSSAPI', **sasl_kwargs):

View File

@@ -22,6 +22,7 @@ import atexit
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
import logging
from random import random
import socket
import sys
import time
@@ -44,7 +45,8 @@ from itertools import groupby
from cassandra import (ConsistencyLevel, AuthenticationFailed,
InvalidRequest, OperationTimedOut,
UnsupportedOperation, Unauthorized)
from cassandra.connection import ConnectionException, ConnectionShutdown
from cassandra.connection import (ConnectionException, ConnectionShutdown,
ConnectionHeartbeat)
from cassandra.encoder import Encoder
from cassandra.protocol import (QueryMessage, ResultMessage,
ErrorMessage, ReadTimeoutErrorMessage,
@@ -68,10 +70,20 @@ from cassandra.query import (SimpleStatement, PreparedStatement, BoundStatement,
BatchStatement, bind_params, QueryTrace, Statement,
named_tuple_factory, dict_factory, FETCH_SIZE_UNSET)
# default to gevent when we are monkey patched, otherwise if libev is available, use that as the
# default because it's fastest. Otherwise, use asyncore.
def _is_eventlet_monkey_patched():
if 'eventlet.patcher' not in sys.modules:
return False
import eventlet.patcher
return eventlet.patcher.is_monkey_patched('socket')
# default to gevent when we are monkey patched with gevent, eventlet when
# monkey patched with eventlet, otherwise if libev is available, use that as
# the default because it's fastest. Otherwise, use asyncore.
if 'gevent.monkey' in sys.modules:
from cassandra.io.geventreactor import GeventConnection as DefaultConnection
elif _is_eventlet_monkey_patched():
from cassandra.io.eventletreactor import EventletConnection as DefaultConnection
else:
try:
from cassandra.io.libevreactor import LibevConnection as DefaultConnection # NOQA
@@ -372,6 +384,48 @@ class Cluster(object):
If set to :const:`None`, there will be no timeout for these queries.
"""
idle_heartbeat_interval = 30
"""
Interval, in seconds, on which to heartbeat idle connections. This helps
keep connections open through network devices that expire idle connections.
It also helps discover bad connections early in low-traffic scenarios.
Setting to zero disables heartbeats.
"""
schema_event_refresh_window = 2
"""
Window, in seconds, within which a schema component will be refreshed after
receiving a schema_change event.
The driver delays a random amount of time in the range [0.0, window)
before executing the refresh. This serves two purposes:
1.) Spread the refresh for deployments with large fanout from C* to client tier,
preventing a 'thundering herd' problem with many clients refreshing simultaneously.
2.) Remove redundant refreshes. Redundant events arriving within the delay period
are discarded, and only one refresh is executed.
Setting this to zero will execute refreshes immediately.
Setting this negative will disable schema refreshes in response to push events
(refreshes will still occur in response to schema change responses to DDL statements
executed by Sessions of this Cluster).
"""
topology_event_refresh_window = 10
"""
Window, in seconds, within which the node and token list will be refreshed after
receiving a topology_change event.
Setting this to zero will execute refreshes immediately.
Setting this negative will disable node refreshes in response to push events
(refreshes will still occur in response to new nodes observed on "UP" events).
See :attr:`.schema_event_refresh_window` for discussion of rationale
"""
sessions = None
control_connection = None
scheduler = None
@@ -380,6 +434,7 @@ class Cluster(object):
_is_setup = False
_prepared_statements = None
_prepared_statement_lock = None
_idle_heartbeat = None
_user_types = None
"""
@@ -406,7 +461,10 @@ class Cluster(object):
protocol_version=2,
executor_threads=2,
max_schema_agreement_wait=10,
control_connection_timeout=2.0):
control_connection_timeout=2.0,
idle_heartbeat_interval=30,
schema_event_refresh_window=2,
topology_event_refresh_window=10):
"""
Any of the mutable Cluster attributes may be set as keyword arguments
to the constructor.
@@ -456,6 +514,9 @@ class Cluster(object):
self.cql_version = cql_version
self.max_schema_agreement_wait = max_schema_agreement_wait
self.control_connection_timeout = control_connection_timeout
self.idle_heartbeat_interval = idle_heartbeat_interval
self.schema_event_refresh_window = schema_event_refresh_window
self.topology_event_refresh_window = topology_event_refresh_window
self._listeners = set()
self._listener_lock = Lock()
@@ -500,7 +561,8 @@ class Cluster(object):
self.metrics = Metrics(weakref.proxy(self))
self.control_connection = ControlConnection(
self, self.control_connection_timeout)
self, self.control_connection_timeout,
self.schema_event_refresh_window, self.topology_event_refresh_window)
def register_user_type(self, keyspace, user_type, klass):
"""
@@ -621,7 +683,7 @@ class Cluster(object):
def set_max_connections_per_host(self, host_distance, max_connections):
"""
Gets the maximum number of connections per Session that will be opened
Sets the maximum number of connections per Session that will be opened
for each host with :class:`~.HostDistance` equal to `host_distance`.
The default is 2 for :attr:`~HostDistance.LOCAL` and 1 for
:attr:`~HostDistance.REMOTE`.
@@ -688,18 +750,19 @@ class Cluster(object):
self.load_balancing_policy.populate(
weakref.proxy(self), self.metadata.all_hosts())
if self.control_connection:
try:
self.control_connection.connect()
log.debug("Control connection created")
except Exception:
log.exception("Control connection failed to connect, "
"shutting down Cluster:")
self.shutdown()
raise
try:
self.control_connection.connect()
log.debug("Control connection created")
except Exception:
log.exception("Control connection failed to connect, "
"shutting down Cluster:")
self.shutdown()
raise
self.load_balancing_policy.check_supported()
if self.idle_heartbeat_interval:
self._idle_heartbeat = ConnectionHeartbeat(self.idle_heartbeat_interval, self.get_connection_holders)
self._is_setup = True
session = self._new_session()
@@ -707,6 +770,13 @@ class Cluster(object):
session.set_keyspace(keyspace)
return session
def get_connection_holders(self):
holders = []
for s in self.sessions:
holders.extend(s.get_pools())
holders.append(self.control_connection)
return holders
def shutdown(self):
"""
Closes all sessions and connection associated with this Cluster.
@@ -721,18 +791,17 @@ class Cluster(object):
else:
self.is_shutdown = True
if self.scheduler:
self.scheduler.shutdown()
if self._idle_heartbeat:
self._idle_heartbeat.stop()
if self.control_connection:
self.control_connection.shutdown()
self.scheduler.shutdown()
if self.sessions:
for session in self.sessions:
session.shutdown()
self.control_connection.shutdown()
if self.executor:
self.executor.shutdown()
for session in self.sessions:
session.shutdown()
self.executor.shutdown()
def _new_session(self):
session = Session(self, self.metadata.all_hosts())
@@ -907,7 +976,7 @@ class Cluster(object):
self._start_reconnector(host, is_host_addition)
def on_add(self, host):
def on_add(self, host, refresh_nodes=True):
if self.is_shutdown:
return
@@ -919,7 +988,7 @@ class Cluster(object):
log.debug("Done preparing queries for new host %r", host)
self.load_balancing_policy.on_add(host)
self.control_connection.on_add(host)
self.control_connection.on_add(host, refresh_nodes)
if distance == HostDistance.IGNORED:
log.debug("Not adding connection pool for new host %r because the "
@@ -995,7 +1064,7 @@ class Cluster(object):
self.on_down(host, is_host_addition, expect_host_to_be_down)
return is_down
def add_host(self, address, datacenter=None, rack=None, signal=True):
def add_host(self, address, datacenter=None, rack=None, signal=True, refresh_nodes=True):
"""
Called when adding initial contact points and when the control
connection subsequently discovers a new node. Intended for internal
@@ -1004,7 +1073,7 @@ class Cluster(object):
new_host = self.metadata.add_host(address, datacenter, rack)
if new_host and signal:
log.info("New Cassandra host %r discovered", new_host)
self.on_add(new_host)
self.on_add(new_host, refresh_nodes)
return new_host
@@ -1045,16 +1114,20 @@ class Cluster(object):
for pool in session._pools.values():
pool.ensure_core_connections()
def refresh_schema(self, keyspace=None, table=None, usertype=None, schema_agreement_wait=None):
def refresh_schema(self, keyspace=None, table=None, usertype=None, max_schema_agreement_wait=None):
"""
Synchronously refresh the schema metadata.
By default timeout for this operation is governed by :attr:`~.Cluster.max_schema_agreement_wait`
By default, the timeout for this operation is governed by :attr:`~.Cluster.max_schema_agreement_wait`
and :attr:`~.Cluster.control_connection_timeout`.
Passing schema_agreement_wait here overrides :attr:`~.Cluster.max_schema_agreement_wait`.
Setting schema_agreement_wait <= 0 will bypass schema agreement and refresh schema immediately.
Passing max_schema_agreement_wait here overrides :attr:`~.Cluster.max_schema_agreement_wait`.
Setting max_schema_agreement_wait <= 0 will bypass schema agreement and refresh schema immediately.
An Exception is raised if schema refresh fails for any reason.
"""
if not self.control_connection.refresh_schema(keyspace, table, usertype, schema_agreement_wait):
if not self.control_connection.refresh_schema(keyspace, table, usertype, max_schema_agreement_wait):
raise Exception("Schema was not refreshed. See log for details.")
def submit_schema_refresh(self, keyspace=None, table=None, usertype=None):
@@ -1066,6 +1139,27 @@ class Cluster(object):
return self.executor.submit(
self.control_connection.refresh_schema, keyspace, table, usertype)
def refresh_nodes(self):
"""
Synchronously refresh the node list and token metadata
An Exception is raised if node refresh fails for any reason.
"""
if not self.control_connection.refresh_node_list_and_token_map():
raise Exception("Node list was not refreshed. See log for details.")
def set_meta_refresh_enabled(self, enabled):
"""
Sets a flag to enable (True) or disable (False) all metadata refresh queries.
This applies to both schema and node topology.
Disabling this is useful to minimize refreshes during multiple changes.
Meta refresh must be enabled for the driver to become aware of any cluster
topology changes or schema updates.
"""
self.control_connection.set_meta_refresh_enabled(bool(enabled))
def _prepare_all_queries(self, host):
if not self._prepared_statements:
return
@@ -1656,6 +1750,9 @@ class Session(object):
def get_pool_state(self):
return dict((host, pool.get_state()) for host, pool in self._pools.items())
def get_pools(self):
return self._pools.values()
class UserTypeDoesNotExist(Exception):
"""
@@ -1734,16 +1831,26 @@ class ControlConnection(object):
_timeout = None
_protocol_version = None
_schema_event_refresh_window = None
_topology_event_refresh_window = None
_meta_refresh_enabled = True
# for testing purposes
_time = time
def __init__(self, cluster, timeout):
def __init__(self, cluster, timeout,
schema_event_refresh_window,
topology_event_refresh_window):
# use a weak reference to allow the Cluster instance to be GC'ed (and
# shutdown) since implementing __del__ disables the cycle detector
self._cluster = weakref.proxy(cluster)
self._connection = None
self._timeout = timeout
self._schema_event_refresh_window = schema_event_refresh_window
self._topology_event_refresh_window = topology_event_refresh_window
self._lock = RLock()
self._schema_agreement_lock = Lock()
@@ -1901,6 +2008,10 @@ class ControlConnection(object):
def refresh_schema(self, keyspace=None, table=None, usertype=None,
schema_agreement_wait=None):
if not self._meta_refresh_enabled:
log.debug("[control connection] Skipping schema refresh because meta refresh is disabled")
return False
try:
if self._connection:
return self._refresh_schema(self._connection, keyspace, table, usertype,
@@ -2028,14 +2139,20 @@ class ControlConnection(object):
return True
def refresh_node_list_and_token_map(self, force_token_rebuild=False):
if not self._meta_refresh_enabled:
log.debug("[control connection] Skipping node list refresh because meta refresh is disabled")
return False
try:
if self._connection:
self._refresh_node_list_and_token_map(self._connection, force_token_rebuild=force_token_rebuild)
return True
except ReferenceError:
pass # our weak reference to the Cluster is no good
except Exception:
log.debug("[control connection] Error refreshing node list and token map", exc_info=True)
self._signal_error()
return False
def _refresh_node_list_and_token_map(self, connection, preloaded_results=None,
force_token_rebuild=False):
@@ -2096,7 +2213,7 @@ class ControlConnection(object):
rack = row.get("rack")
if host is None:
log.debug("[control connection] Found new host to connect to: %s", addr)
host = self._cluster.add_host(addr, datacenter, rack, signal=True)
host = self._cluster.add_host(addr, datacenter, rack, signal=True, refresh_nodes=False)
should_rebuild_token_map = True
else:
should_rebuild_token_map |= self._update_location_info(host, datacenter, rack)
@@ -2131,25 +2248,25 @@ class ControlConnection(object):
def _handle_topology_change(self, event):
change_type = event["change_type"]
addr, port = event["address"]
if change_type == "NEW_NODE":
self._cluster.scheduler.schedule(10, self.refresh_node_list_and_token_map)
if change_type == "NEW_NODE" or change_type == "MOVED_NODE":
if self._topology_event_refresh_window >= 0:
delay = random() * self._topology_event_refresh_window
self._cluster.scheduler.schedule_unique(delay, self.refresh_node_list_and_token_map)
elif change_type == "REMOVED_NODE":
host = self._cluster.metadata.get_host(addr)
self._cluster.scheduler.schedule(0, self._cluster.remove_host, host)
elif change_type == "MOVED_NODE":
self._cluster.scheduler.schedule(1, self.refresh_node_list_and_token_map)
self._cluster.scheduler.schedule_unique(0, self._cluster.remove_host, host)
def _handle_status_change(self, event):
change_type = event["change_type"]
addr, port = event["address"]
host = self._cluster.metadata.get_host(addr)
if change_type == "UP":
delay = 1 + random() * 0.5 # randomness to avoid thundering herd problem on events
if host is None:
# this is the first time we've seen the node
self._cluster.scheduler.schedule(2, self.refresh_node_list_and_token_map)
self._cluster.scheduler.schedule_unique(delay, self.refresh_node_list_and_token_map)
else:
# this will be run by the scheduler
self._cluster.scheduler.schedule(2, self._cluster.on_up, host)
self._cluster.scheduler.schedule_unique(delay, self._cluster.on_up, host)
elif change_type == "DOWN":
# Note that there is a slight risk we can receive the event late and thus
# mark the host down even though we already had reconnected successfully.
@@ -2160,10 +2277,14 @@ class ControlConnection(object):
self._cluster.on_down(host, is_host_addition=False)
def _handle_schema_change(self, event):
if self._schema_event_refresh_window < 0:
return
keyspace = event.get('keyspace')
table = event.get('table')
usertype = event.get('type')
self._submit(self.refresh_schema, keyspace, table, usertype)
delay = random() * self._schema_event_refresh_window
self._cluster.scheduler.schedule_unique(delay, self.refresh_schema, keyspace, table, usertype)
def wait_for_schema_agreement(self, connection=None, preloaded_results=None, wait_time=None):
@@ -2271,11 +2392,6 @@ class ControlConnection(object):
# manually
self.reconnect()
@property
def is_open(self):
conn = self._connection
return bool(conn and conn.is_open)
def on_up(self, host):
pass
@@ -2289,12 +2405,24 @@ class ControlConnection(object):
# this will result in a task being submitted to the executor to reconnect
self.reconnect()
def on_add(self, host):
self.refresh_node_list_and_token_map(force_token_rebuild=True)
def on_add(self, host, refresh_nodes=True):
if refresh_nodes:
self.refresh_node_list_and_token_map(force_token_rebuild=True)
def on_remove(self, host):
self.refresh_node_list_and_token_map(force_token_rebuild=True)
def get_connections(self):
c = getattr(self, '_connection', None)
return [c] if c else []
def return_connection(self, connection):
if connection is self._connection and (connection.is_defunct or connection.is_closed):
self.reconnect()
def set_meta_refresh_enabled(self, enabled):
self._meta_refresh_enabled = enabled
def _stop_scheduler(scheduler, thread):
try:
@@ -2308,12 +2436,14 @@ def _stop_scheduler(scheduler, thread):
class _Scheduler(object):
_scheduled = None
_queue = None
_scheduled_tasks = None
_executor = None
is_shutdown = False
def __init__(self, executor):
self._scheduled = Queue.PriorityQueue()
self._queue = Queue.PriorityQueue()
self._scheduled_tasks = set()
self._executor = executor
t = Thread(target=self.run, name="Task Scheduler")
@@ -2331,14 +2461,25 @@ class _Scheduler(object):
# this can happen on interpreter shutdown
pass
self.is_shutdown = True
self._scheduled.put_nowait((0, None))
self._queue.put_nowait((0, None))
def schedule(self, delay, fn, *args, **kwargs):
def schedule(self, delay, fn, *args):
self._insert_task(delay, (fn, args))
def schedule_unique(self, delay, fn, *args):
task = (fn, args)
if task not in self._scheduled_tasks:
self._insert_task(delay, task)
else:
log.debug("Ignoring schedule_unique for already-scheduled task: %r", task)
def _insert_task(self, delay, task):
if not self.is_shutdown:
run_at = time.time() + delay
self._scheduled.put_nowait((run_at, (fn, args, kwargs)))
self._scheduled_tasks.add(task)
self._queue.put_nowait((run_at, task))
else:
log.debug("Ignoring scheduled function after shutdown: %r", fn)
log.debug("Ignoring scheduled task after shutdown: %r", task)
def run(self):
while True:
@@ -2347,16 +2488,17 @@ class _Scheduler(object):
try:
while True:
run_at, task = self._scheduled.get(block=True, timeout=None)
run_at, task = self._queue.get(block=True, timeout=None)
if self.is_shutdown:
log.debug("Not executing scheduled task due to Scheduler shutdown")
return
if run_at <= time.time():
fn, args, kwargs = task
future = self._executor.submit(fn, *args, **kwargs)
self._scheduled_tasks.remove(task)
fn, args = task
future = self._executor.submit(fn, *args)
future.add_done_callback(self._log_if_failed)
else:
self._scheduled.put_nowait((run_at, task))
self._queue.put_nowait((run_at, task))
break
except Queue.Empty:
pass
@@ -2373,9 +2515,13 @@ class _Scheduler(object):
def refresh_schema_and_set_result(keyspace, table, usertype, control_conn, response_future):
try:
log.debug("Refreshing schema in response to schema change. Keyspace: %s; Table: %s, Type: %s",
keyspace, table, usertype)
control_conn._refresh_schema(response_future._connection, keyspace, table, usertype)
if control_conn._meta_refresh_enabled:
log.debug("Refreshing schema in response to schema change. Keyspace: %s; Table: %s, Type: %s",
keyspace, table, usertype)
control_conn._refresh_schema(response_future._connection, keyspace, table, usertype)
else:
log.debug("Skipping schema refresh in response to schema change because meta refresh is disabled; "
"Keyspace: %s; Table: %s, Type: %s", keyspace, table, usertype)
except Exception:
log.exception("Exception refreshing schema in response to schema change:")
response_future.session.submit(

View File

@@ -20,7 +20,7 @@ import io
import logging
import os
import sys
from threading import Event, RLock
from threading import Thread, Event, RLock
import time
if 'gevent.monkey' in sys.modules:
@@ -159,7 +159,7 @@ class Connection(object):
in_flight = 0
# A set of available request IDs. When using the v3 protocol or higher,
# this will no initially include all request IDs in order to save memory,
# this will not initially include all request IDs in order to save memory,
# but the set will grow if it is exhausted.
request_ids = None
@@ -172,6 +172,8 @@ class Connection(object):
lock = None
user_type_map = None
msg_received = False
is_control_connection = False
_iobuf = None
@@ -401,6 +403,8 @@ class Connection(object):
with self.lock:
self.request_ids.append(stream_id)
self.msg_received = True
body = None
try:
# check that the protocol version is supported
@@ -673,6 +677,13 @@ class Connection(object):
self.send_msg(query, request_id, process_result)
@property
def is_idle(self):
return not self.msg_received
def reset_idle(self):
self.msg_received = False
def __str__(self):
status = ""
if self.is_defunct:
@@ -732,3 +743,100 @@ class ResponseWaiter(object):
raise OperationTimedOut()
else:
return self.responses
class HeartbeatFuture(object):
def __init__(self, connection, owner):
self._exception = None
self._event = Event()
self.connection = connection
self.owner = owner
log.debug("Sending options message heartbeat on idle connection (%s) %s",
id(connection), connection.host)
with connection.lock:
if connection.in_flight < connection.max_request_id:
connection.in_flight += 1
connection.send_msg(OptionsMessage(), connection.get_request_id(), self._options_callback)
else:
self._exception = Exception("Failed to send heartbeat because connection 'in_flight' exceeds threshold")
self._event.set()
def wait(self, timeout):
self._event.wait(timeout)
if self._event.is_set():
if self._exception:
raise self._exception
else:
raise OperationTimedOut()
def _options_callback(self, response):
if not isinstance(response, SupportedMessage):
if isinstance(response, ConnectionException):
self._exception = response
else:
self._exception = ConnectionException("Received unexpected response to OptionsMessage: %s"
% (response,))
log.debug("Received options response on connection (%s) from %s",
id(self.connection), self.connection.host)
self._event.set()
class ConnectionHeartbeat(Thread):
def __init__(self, interval_sec, get_connection_holders):
Thread.__init__(self, name="Connection heartbeat")
self._interval = interval_sec
self._get_connection_holders = get_connection_holders
self._shutdown_event = Event()
self.daemon = True
self.start()
def run(self):
self._shutdown_event.wait(self._interval)
while not self._shutdown_event.is_set():
start_time = time.time()
futures = []
failed_connections = []
try:
for connections, owner in [(o.get_connections(), o) for o in self._get_connection_holders()]:
for connection in connections:
if not (connection.is_defunct or connection.is_closed):
if connection.is_idle:
try:
futures.append(HeartbeatFuture(connection, owner))
except Exception:
log.warning("Failed sending heartbeat message on connection (%s) to %s",
id(connection), connection.host, exc_info=True)
failed_connections.append((connection, owner))
else:
connection.reset_idle()
else:
# make sure the owner sees this defunt/closed connection
owner.return_connection(connection)
for f in futures:
connection = f.connection
try:
f.wait(self._interval)
# TODO: move this, along with connection locks in pool, down into Connection
with connection.lock:
connection.in_flight -= 1
connection.reset_idle()
except Exception:
log.warning("Heartbeat failed for connection (%s) to %s",
id(connection), connection.host, exc_info=True)
failed_connections.append((f.connection, f.owner))
for connection, owner in failed_connections:
connection.defunct(Exception('Connection heartbeat failure'))
owner.return_connection(connection)
except Exception:
log.error("Failed connection heartbeat", exc_info=True)
elapsed = time.time() - start_time
self._shutdown_event.wait(max(self._interval - elapsed, 0.01))
def stop(self):
self._shutdown_event.set()

View File

@@ -31,14 +31,14 @@ from __future__ import absolute_import # to enable import io from stdlib
from binascii import unhexlify
import calendar
from collections import namedtuple
import datetime
from decimal import Decimal
import io
import re
import socket
import time
import datetime
import sys
from uuid import UUID
import warnings
import six
from six.moves import range
@@ -48,7 +48,7 @@ from cassandra.marshal import (int8_pack, int8_unpack,
int32_pack, int32_unpack, int64_pack, int64_unpack,
float_pack, float_unpack, double_pack, double_unpack,
varint_pack, varint_unpack)
from cassandra.util import OrderedDict, sortedset
from cassandra.util import OrderedMap, sortedset
apache_cassandra_type_prefix = 'org.apache.cassandra.db.marshal.'
@@ -58,10 +58,15 @@ if six.PY3:
_time_types = frozenset((int,))
_date_types = frozenset((int,))
long = int
def _name_from_hex_string(encoded_name):
bin_str = unhexlify(encoded_name)
return bin_str.decode('ascii')
else:
_number_types = frozenset((int, long, float))
_time_types = frozenset((int, long))
_date_types = frozenset((int, long))
_name_from_hex_string = unhexlify
def trim_if_startswith(s, prefix):
@@ -569,7 +574,8 @@ class DateType(_CassandraType):
tval = time.strptime(val, tformat)
except ValueError:
continue
return calendar.timegm(tval) + offset
# scale seconds to millis for the raw value
return (calendar.timegm(tval) + offset) * 1e3
else:
raise ValueError("can't interpret %r as a date" % (val,))
@@ -584,31 +590,16 @@ class DateType(_CassandraType):
@staticmethod
def serialize(v, protocol_version):
try:
converted = calendar.timegm(v.utctimetuple())
converted = converted * 1e3 + getattr(v, 'microsecond', 0) / 1e3
# v is datetime
timestamp_seconds = calendar.timegm(v.utctimetuple())
timestamp = timestamp_seconds * 1e3 + getattr(v, 'microsecond', 0) / 1e3
except AttributeError:
# Ints and floats are valid timestamps too
if type(v) not in _number_types:
raise TypeError('DateType arguments must be a datetime or timestamp')
timestamp = v
global _have_warned_about_timestamps
if not _have_warned_about_timestamps:
_have_warned_about_timestamps = True
warnings.warn(
"timestamp columns in Cassandra hold a number of "
"milliseconds since the unix epoch. Currently, when executing "
"prepared statements, this driver multiplies timestamp "
"values by 1000 so that the result of time.time() "
"can be used directly. However, the driver cannot "
"match this behavior for non-prepared statements, "
"so the 2.0 version of the driver will no longer multiply "
"timestamps by 1000. It is suggested that you simply use "
"datetime.datetime objects for 'timestamp' values to avoid "
"any ambiguity and to guarantee a smooth upgrade of the "
"driver.")
converted = v * 1e3
return int64_pack(long(converted))
return int64_pack(long(timestamp))
class TimestampType(DateType):
@@ -838,7 +829,7 @@ class MapType(_ParameterizedType):
length = 2
numelements = unpack(byts[:length])
p = length
themap = OrderedDict()
themap = OrderedMap()
for _ in range(numelements):
key_len = unpack(byts[p:p + length])
p += length
@@ -850,7 +841,7 @@ class MapType(_ParameterizedType):
p += val_len
key = subkeytype.from_binary(keybytes, protocol_version)
val = subvaltype.from_binary(valbytes, protocol_version)
themap[key] = val
themap._insert(key, val)
return themap
@classmethod
@@ -929,37 +920,39 @@ class UserType(TupleType):
typename = "'org.apache.cassandra.db.marshal.UserType'"
_cache = {}
_module = sys.modules[__name__]
@classmethod
def make_udt_class(cls, keyspace, udt_name, names_and_types, mapped_class):
if six.PY2 and isinstance(udt_name, unicode):
udt_name = udt_name.encode('utf-8')
try:
return cls._cache[(keyspace, udt_name)]
except KeyError:
fieldnames, types = zip(*names_and_types)
field_names, types = zip(*names_and_types)
instance = type(udt_name, (cls,), {'subtypes': types,
'cassname': cls.cassname,
'typename': udt_name,
'fieldnames': fieldnames,
'fieldnames': field_names,
'keyspace': keyspace,
'mapped_class': mapped_class})
'mapped_class': mapped_class,
'tuple_type': cls._make_registered_udt_namedtuple(keyspace, udt_name, field_names)})
cls._cache[(keyspace, udt_name)] = instance
return instance
@classmethod
def apply_parameters(cls, subtypes, names):
keyspace = subtypes[0]
udt_name = unhexlify(subtypes[1].cassname)
field_names = [unhexlify(encoded_name) for encoded_name in names[2:]]
udt_name = _name_from_hex_string(subtypes[1].cassname)
field_names = [_name_from_hex_string(encoded_name) for encoded_name in names[2:]]
assert len(field_names) == len(subtypes[2:])
return type(udt_name, (cls,), {'subtypes': subtypes[2:],
'cassname': cls.cassname,
'typename': udt_name,
'fieldnames': field_names,
'keyspace': keyspace,
'mapped_class': None})
'mapped_class': None,
'tuple_type': namedtuple(udt_name, field_names)})
@classmethod
def cql_parameterized_type(cls):
@@ -991,8 +984,7 @@ class UserType(TupleType):
if cls.mapped_class:
return cls.mapped_class(**dict(zip(cls.fieldnames, values)))
else:
Result = namedtuple(cls.typename, cls.fieldnames)
return Result(*values)
return cls.tuple_type(*values)
@classmethod
def serialize_safe(cls, val, protocol_version):
@@ -1008,6 +1000,18 @@ class UserType(TupleType):
buf.write(int32_pack(-1))
return buf.getvalue()
@classmethod
def _make_registered_udt_namedtuple(cls, keyspace, name, field_names):
# this is required to make the type resolvable via this module...
# required when unregistered udts are pickled for use as keys in
# util.OrderedMap
qualified_name = "%s_%s" % (keyspace, name)
nt = getattr(cls._module, qualified_name, None)
if not nt:
nt = namedtuple(qualified_name, field_names)
setattr(cls._module, qualified_name, nt)
return nt
class CompositeType(_ParameterizedType):
typename = "'org.apache.cassandra.db.marshal.CompositeType'"

View File

@@ -28,7 +28,7 @@ import types
from uuid import UUID
import six
from cassandra.util import OrderedDict, sortedset
from cassandra.util import OrderedDict, OrderedMap, sortedset
if six.PY3:
long = int
@@ -77,6 +77,7 @@ class Encoder(object):
datetime.time: self.cql_encode_time,
dict: self.cql_encode_map_collection,
OrderedDict: self.cql_encode_map_collection,
OrderedMap: self.cql_encode_map_collection,
list: self.cql_encode_list_collection,
tuple: self.cql_encode_list_collection,
set: self.cql_encode_set_collection,

View File

@@ -0,0 +1,193 @@
# Copyright 2014 Symantec Corporation
# Copyright 2013-2014 DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Originally derived from MagnetoDB source:
# https://github.com/stackforge/magnetodb/blob/2015.1.0b1/magnetodb/common/cassandra/io/eventletreactor.py
from collections import defaultdict
from errno import EALREADY, EINPROGRESS, EWOULDBLOCK, EINVAL
import eventlet
from eventlet.green import select, socket
from eventlet.queue import Queue
from functools import partial
import logging
import os
from threading import Event
from six.moves import xrange
from cassandra import OperationTimedOut
from cassandra.connection import Connection, ConnectionShutdown
from cassandra.protocol import RegisterMessage
log = logging.getLogger(__name__)
def is_timeout(err):
return (
err in (EINPROGRESS, EALREADY, EWOULDBLOCK) or
(err == EINVAL and os.name in ('nt', 'ce'))
)
class EventletConnection(Connection):
"""
An implementation of :class:`.Connection` that utilizes ``eventlet``.
"""
_total_reqd_bytes = 0
_read_watcher = None
_write_watcher = None
_socket = None
@classmethod
def initialize_reactor(cls):
eventlet.monkey_patch()
@classmethod
def factory(cls, *args, **kwargs):
timeout = kwargs.pop('timeout', 5.0)
conn = cls(*args, **kwargs)
conn.connected_event.wait(timeout)
if conn.last_error:
raise conn.last_error
elif not conn.connected_event.is_set():
conn.close()
raise OperationTimedOut("Timed out creating connection")
else:
return conn
def __init__(self, *args, **kwargs):
Connection.__init__(self, *args, **kwargs)
self.connected_event = Event()
self._write_queue = Queue()
self._callbacks = {}
self._push_watchers = defaultdict(set)
sockerr = None
addresses = socket.getaddrinfo(
self.host, self.port, socket.AF_UNSPEC, socket.SOCK_STREAM
)
for (af, socktype, proto, canonname, sockaddr) in addresses:
try:
self._socket = socket.socket(af, socktype, proto)
self._socket.settimeout(1.0)
self._socket.connect(sockaddr)
sockerr = None
break
except socket.error as err:
sockerr = err
if sockerr:
raise socket.error(
sockerr.errno,
"Tried connecting to %s. Last error: %s" % (
[a[4] for a in addresses], sockerr.strerror)
)
if self.sockopts:
for args in self.sockopts:
self._socket.setsockopt(*args)
self._read_watcher = eventlet.spawn(lambda: self.handle_read())
self._write_watcher = eventlet.spawn(lambda: self.handle_write())
self._send_options_message()
def close(self):
with self.lock:
if self.is_closed:
return
self.is_closed = True
log.debug("Closing connection (%s) to %s" % (id(self), self.host))
cur_gthread = eventlet.getcurrent()
if self._read_watcher and self._read_watcher != cur_gthread:
self._read_watcher.kill()
if self._write_watcher and self._write_watcher != cur_gthread:
self._write_watcher.kill()
if self._socket:
self._socket.close()
log.debug("Closed socket to %s" % (self.host,))
if not self.is_defunct:
self.error_all_callbacks(
ConnectionShutdown("Connection to %s was closed" % self.host))
# don't leave in-progress operations hanging
self.connected_event.set()
def handle_close(self):
log.debug("connection closed by server")
self.close()
def handle_write(self):
while True:
try:
next_msg = self._write_queue.get()
self._socket.sendall(next_msg)
except socket.error as err:
log.debug("Exception during socket send for %s: %s", self, err)
self.defunct(err)
return # Leave the write loop
def handle_read(self):
run_select = partial(select.select, (self._socket,), (), ())
while True:
try:
run_select()
except Exception as exc:
if not self.is_closed:
log.debug("Exception during read select() for %s: %s",
self, exc)
self.defunct(exc)
return
try:
buf = self._socket.recv(self.in_buffer_size)
self._iobuf.write(buf)
except socket.error as err:
if not is_timeout(err):
log.debug("Exception during socket recv for %s: %s",
self, err)
self.defunct(err)
return # leave the read loop
if self._iobuf.tell():
self.process_io_buffer()
else:
log.debug("Connection %s closed by server", self)
self.close()
return
def push(self, data):
chunk_size = self.out_buffer_size
for i in xrange(0, len(data), chunk_size):
self._write_queue.put(data[i:i + chunk_size])
def register_watcher(self, event_type, callback, register_timeout=None):
self._push_watchers[event_type].add(callback)
self.wait_for_response(
RegisterMessage(event_list=[event_type]),
timeout=register_timeout)
def register_watchers(self, type_callback_dict, register_timeout=None):
for event_type, callback in type_callback_dict.items():
self._push_watchers[event_type].add(callback)
self.wait_for_response(
RegisterMessage(event_list=type_callback_dict.keys()),
timeout=register_timeout)

View File

@@ -211,7 +211,7 @@ class Metadata(object):
return KeyspaceMetadata(name, durable_writes, strategy_class, strategy_options)
def _build_usertype(self, keyspace, usertype_row):
type_classes = map(types.lookup_casstype, usertype_row['field_types'])
type_classes = list(map(types.lookup_casstype, usertype_row['field_types']))
return UserType(usertype_row['keyspace_name'], usertype_row['type_name'],
usertype_row['field_names'], type_classes)

View File

@@ -355,15 +355,21 @@ class HostConnection(object):
return
def connection_finished_setting_keyspace(conn, error):
self.return_connection(conn)
errors = [] if not error else [error]
callback(self, errors)
self._connection.set_keyspace_async(keyspace, connection_finished_setting_keyspace)
def get_connections(self):
c = self._connection
return [c] if c else []
def get_state(self):
have_conn = self._connection is not None
in_flight = self._connection.in_flight if have_conn else 0
return "shutdown: %s, open: %s, in_flights: %s" % (self.is_shutdown, have_conn, in_flight)
connection = self._connection
open_count = 1 if connection and not (connection.is_closed or connection.is_defunct) else 0
in_flights = [connection.in_flight] if connection else []
return {'shutdown': self.is_shutdown, 'open_count': open_count, 'in_flights': in_flights}
_MAX_SIMULTANEOUS_CREATION = 1
@@ -683,6 +689,7 @@ class HostConnectionPool(object):
return
def connection_finished_setting_keyspace(conn, error):
self.return_connection(conn)
remaining_callbacks.remove(conn)
if error:
errors.append(error)
@@ -693,6 +700,9 @@ class HostConnectionPool(object):
for conn in self._connections:
conn.set_keyspace_async(keyspace, connection_finished_setting_keyspace)
def get_connections(self):
return self._connections
def get_state(self):
in_flights = ", ".join([str(c.in_flight) for c in self._connections])
return "shutdown: %s, open_count: %d, in_flights: %s" % (self.is_shutdown, self.open_count, in_flights)
in_flights = [c.in_flight for c in self._connections]
return {'shutdown': self.is_shutdown, 'open_count': self.open_count, 'in_flights': in_flights}

View File

@@ -632,14 +632,14 @@ class ResultMessage(_MessageType):
typeclass = typeclass.apply_parameters((keysubtype, valsubtype))
elif typeclass == TupleType:
num_items = read_short(f)
types = tuple(cls.read_type(f, user_type_map) for _ in xrange(num_items))
types = tuple(cls.read_type(f, user_type_map) for _ in range(num_items))
typeclass = typeclass.apply_parameters(types)
elif typeclass == UserType:
ks = read_string(f)
udt_name = read_string(f)
num_fields = read_short(f)
names_and_types = tuple((read_string(f), cls.read_type(f, user_type_map))
for _ in xrange(num_fields))
for _ in range(num_fields))
mapped_class = user_type_map.get(ks, {}).get(udt_name)
typeclass = typeclass.make_udt_class(
ks, udt_name, names_and_types, mapped_class)

View File

@@ -500,13 +500,13 @@ class BoundStatement(Statement):
try:
self.values.append(col_type.serialize(value, proto_version))
except (TypeError, struct.error):
except (TypeError, struct.error) as exc:
col_name = col_spec[2]
expected_type = col_type
actual_type = type(value)
message = ('Received an argument of invalid type for column "%s". '
'Expected: %s, Got: %s' % (col_name, expected_type, actual_type))
'Expected: %s, Got: %s; (%s)' % (col_name, expected_type, actual_type, exc))
raise TypeError(message)
return self

View File

@@ -555,3 +555,101 @@ except ImportError:
if item in other:
isect.add(item)
return isect
from collections import Mapping
import six
from six.moves import cPickle
class OrderedMap(Mapping):
'''
An ordered map that accepts non-hashable types for keys. It also maintains the
insertion order of items, behaving as OrderedDict in that regard. These maps
are constructed and read just as normal mapping types, exept that they may
contain arbitrary collections and other non-hashable items as keys::
>>> od = OrderedMap([({'one': 1, 'two': 2}, 'value'),
... ({'three': 3, 'four': 4}, 'value2')])
>>> list(od.keys())
[{'two': 2, 'one': 1}, {'three': 3, 'four': 4}]
>>> list(od.values())
['value', 'value2']
These constructs are needed to support nested collections in Cassandra 2.1.3+,
where frozen collections can be specified as parameters to others\*::
CREATE TABLE example (
...
value map<frozen<map<int, int>>, double>
...
)
This class dervies from the (immutable) Mapping API. Objects in these maps
are not intended be modified.
\* Note: Because of the way Cassandra encodes nested types, when using the
driver with nested collections, :attr:`~.Cluster.protocol_version` must be 3
or higher.
'''
def __init__(self, *args, **kwargs):
if len(args) > 1:
raise TypeError('expected at most 1 arguments, got %d' % len(args))
self._items = []
self._index = {}
if args:
e = args[0]
if callable(getattr(e, 'keys', None)):
for k in e.keys():
self._items.append((k, e[k]))
else:
for k, v in e:
self._insert(k, v)
for k, v in six.iteritems(kwargs):
self._insert(k, v)
def _insert(self, key, value):
flat_key = self._serialize_key(key)
i = self._index.get(flat_key, -1)
if i >= 0:
self._items[i] = (key, value)
else:
self._items.append((key, value))
self._index[flat_key] = len(self._items) - 1
def __getitem__(self, key):
index = self._index[self._serialize_key(key)]
return self._items[index][1]
def __iter__(self):
for i in self._items:
yield i[0]
def __len__(self):
return len(self._items)
def __eq__(self, other):
if isinstance(other, OrderedMap):
return self._items == other._items
try:
d = dict(other)
return len(d) == len(self._items) and all(i[1] == d[i[0]] for i in self._items)
except KeyError:
return False
except TypeError:
pass
return NotImplemented
def __repr__(self):
return '%s([%s])' % (
self.__class__.__name__,
', '.join("(%r, %r)" % (k, v) for k, v in self._items))
def __str__(self):
return '{%s}' % ', '.join("%s: %s" % (k, v) for k, v in self._items)
@staticmethod
def _serialize_key(key):
return cPickle.dumps(key)

View File

@@ -14,3 +14,9 @@
.. autoclass:: PlainTextAuthenticator
:members:
.. autoclass:: SaslAuthProvider
:members:
.. autoclass:: SaslAuthenticator
:members:

View File

@@ -39,6 +39,12 @@
.. autoattribute:: control_connection_timeout
.. autoattribute:: idle_heartbeat_interval
.. autoattribute:: schema_event_refresh_window
.. autoattribute:: topology_event_refresh_window
.. automethod:: connect
.. automethod:: shutdown
@@ -57,6 +63,13 @@
.. automethod:: set_max_connections_per_host
.. automethod:: refresh_schema
.. automethod:: refresh_nodes
.. automethod:: set_meta_refresh_enabled
.. autoclass:: Session ()
.. autoattribute:: default_timeout

View File

@@ -0,0 +1,7 @@
``cassandra.io.eventletreactor`` - ``eventlet``-compatible Connection
=====================================================================
.. module:: cassandra.io.eventletreactor
.. autoclass:: EventletConnection
:members:

View File

@@ -0,0 +1,7 @@
``cassandra.util`` - Utilities
===================================
.. module:: cassandra.util
.. autoclass:: OrderedMap
:members:

View File

@@ -16,7 +16,9 @@ API Documentation
cassandra/decoder
cassandra/concurrent
cassandra/connection
cassandra/util
cassandra/io/asyncorereactor
cassandra/io/eventletreactor
cassandra/io/libevreactor
cassandra/io/geventreactor
cassandra/io/twistedreactor

View File

@@ -34,9 +34,11 @@ to be explicit.
Custom Authenticators
^^^^^^^^^^^^^^^^^^^^^
If you're using something other than Cassandra's ``PasswordAuthenticator``,
you may need to create your own subclasses of :class:`~.AuthProvider` and
:class:`~.Authenticator`. You can use :class:`~.PlainTextAuthProvider`
and :class:`~.PlainTextAuthenticator` as example implementations.
:class:`~.SaslAuthProvider` is provided for generic SASL authentication mechanisms,
utilizing the ``pure-sasl`` package.
If these do not suit your needs, you may need to create your own subclasses of
:class:`~.AuthProvider` and :class:`~.Authenticator`. You can use the Sasl classes
as example implementations.
Protocol v1 Authentication
^^^^^^^^^^^^^^^^^^^^^^^^^^

View File

@@ -20,6 +20,11 @@ if __name__ == '__main__' and sys.argv[1] == "gevent_nosetests":
from gevent.monkey import patch_all
patch_all()
if __name__ == '__main__' and sys.argv[1] == "eventlet_nosetests":
print("Running eventlet tests")
from eventlet import monkey_patch
monkey_patch()
import ez_setup
ez_setup.use_setuptools()
@@ -51,10 +56,14 @@ try:
from nose.commands import nosetests
except ImportError:
gevent_nosetests = None
eventlet_nosetests = None
else:
class gevent_nosetests(nosetests):
description = "run nosetests with gevent monkey patching"
class eventlet_nosetests(nosetests):
description = "run nosetests with eventlet monkey patching"
class DocCommand(Command):
@@ -174,10 +183,14 @@ On OSX, via homebrew:
def run_setup(extensions):
kw = {'cmdclass': {'doc': DocCommand}}
if gevent_nosetests is not None:
kw['cmdclass']['gevent_nosetests'] = gevent_nosetests
if eventlet_nosetests is not None:
kw['cmdclass']['eventlet_nosetests'] = eventlet_nosetests
if extensions:
kw['cmdclass']['build_ext'] = build_extensions
kw['ext_modules'] = extensions

View File

@@ -13,6 +13,7 @@
# limitations under the License.
import logging
import sys
log = logging.getLogger()
log.setLevel('DEBUG')
@@ -21,3 +22,18 @@ if not log.handlers:
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter('%(asctime)s %(levelname)s [%(module)s:%(lineno)s]: %(message)s'))
log.addHandler(handler)
def is_gevent_monkey_patched():
return 'gevent.monkey' in sys.modules
def is_eventlet_monkey_patched():
if 'eventlet.patcher' in sys.modules:
import eventlet
return eventlet.patcher.is_monkey_patched('socket')
return False
def is_monkey_patched():
return is_gevent_monkey_patched() or is_eventlet_monkey_patched()

View File

@@ -13,11 +13,13 @@
# limitations under the License.
import logging
import time
from tests.integration import use_singledc, get_cluster, remove_cluster, PROTOCOL_VERSION
from cassandra.cluster import Cluster, NoHostAvailable
from cassandra.auth import PlainTextAuthProvider, SASLClient, SaslAuthProvider
from tests.integration import use_singledc, get_cluster, remove_cluster, PROTOCOL_VERSION
from tests.integration.util import assert_quiescent_pool_state
try:
import unittest2 as unittest
@@ -35,7 +37,10 @@ def setup_module():
'authorizer': 'CassandraAuthorizer'}
ccm_cluster.set_configuration_options(config_options)
log.debug("Starting ccm test cluster with %s", config_options)
ccm_cluster.start(wait_for_binary_proto=True)
ccm_cluster.start(wait_for_binary_proto=True, wait_other_notice=True)
# there seems to be some race, with some versions of C* taking longer to
# get the auth (and default user) setup. Sleep here to give it a chance
time.sleep(2)
def teardown_module():
@@ -59,12 +64,13 @@ class AuthenticationTests(unittest.TestCase):
:return: authentication object suitable for Cluster.connect()
"""
if PROTOCOL_VERSION < 2:
return lambda(hostname): dict(username=username, password=password)
return lambda hostname: dict(username=username, password=password)
else:
return PlainTextAuthProvider(username=username, password=password)
def cluster_as(self, usr, pwd):
return Cluster(protocol_version=PROTOCOL_VERSION,
idle_heartbeat_interval=0,
auth_provider=self.get_authentication_provider(username=usr, password=pwd))
def test_auth_connect(self):
@@ -77,9 +83,11 @@ class AuthenticationTests(unittest.TestCase):
cluster = self.cluster_as(user, passwd)
session = cluster.connect()
self.assertTrue(session.execute('SELECT release_version FROM system.local'))
assert_quiescent_pool_state(self, cluster)
cluster.shutdown()
root_session.execute('DROP USER %s', user)
assert_quiescent_pool_state(self, root_session.cluster)
root_session.cluster.shutdown()
def test_connect_wrong_pwd(self):
@@ -88,6 +96,8 @@ class AuthenticationTests(unittest.TestCase):
'.*AuthenticationFailed.*Bad credentials.*Username and/or '
'password are incorrect.*',
cluster.connect)
assert_quiescent_pool_state(self, cluster)
cluster.shutdown()
def test_connect_wrong_username(self):
cluster = self.cluster_as('wrong_user', 'cassandra')
@@ -95,6 +105,8 @@ class AuthenticationTests(unittest.TestCase):
'.*AuthenticationFailed.*Bad credentials.*Username and/or '
'password are incorrect.*',
cluster.connect)
assert_quiescent_pool_state(self, cluster)
cluster.shutdown()
def test_connect_empty_pwd(self):
cluster = self.cluster_as('Cassandra', '')
@@ -102,12 +114,16 @@ class AuthenticationTests(unittest.TestCase):
'.*AuthenticationFailed.*Bad credentials.*Username and/or '
'password are incorrect.*',
cluster.connect)
assert_quiescent_pool_state(self, cluster)
cluster.shutdown()
def test_connect_no_auth_provider(self):
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
self.assertRaisesRegexp(NoHostAvailable,
'.*AuthenticationFailed.*Remote end requires authentication.*',
cluster.connect)
assert_quiescent_pool_state(self, cluster)
cluster.shutdown()
class SaslAuthenticatorTests(AuthenticationTests):

View File

@@ -12,18 +12,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from tests.integration import use_singledc, PROTOCOL_VERSION
try:
import unittest2 as unittest
except ImportError:
import unittest # noqa
import cassandra
from cassandra.query import SimpleStatement, TraceUnavailable
from cassandra.policies import RoundRobinPolicy, ExponentialReconnectionPolicy, RetryPolicy, SimpleConvictionPolicy, HostDistance
from collections import deque
from mock import patch
import time
from uuid import uuid4
import cassandra
from cassandra.cluster import Cluster, NoHostAvailable
from cassandra.concurrent import execute_concurrent
from cassandra.policies import (RoundRobinPolicy, ExponentialReconnectionPolicy,
RetryPolicy, SimpleConvictionPolicy, HostDistance,
WhiteListRoundRobinPolicy)
from cassandra.query import SimpleStatement, TraceUnavailable
from tests.integration import use_singledc, PROTOCOL_VERSION, get_server_versions
from tests.integration.util import assert_quiescent_pool_state
def setup_module():
@@ -66,6 +74,8 @@ class ClusterTests(unittest.TestCase):
result = session.execute("SELECT * FROM clustertests.cf0")
self.assertEqual([('a', 'b', 'c')], result)
session.execute("DROP KEYSPACE clustertests")
cluster.shutdown()
def test_connect_on_keyspace(self):
@@ -202,6 +212,159 @@ class ClusterTests(unittest.TestCase):
self.assertIn("newkeyspace", cluster.metadata.keyspaces)
session.execute("DROP KEYSPACE newkeyspace")
def test_refresh_schema(self):
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect()
original_meta = cluster.metadata.keyspaces
# full schema refresh, with wait
cluster.refresh_schema()
self.assertIsNot(original_meta, cluster.metadata.keyspaces)
self.assertEqual(original_meta, cluster.metadata.keyspaces)
session.shutdown()
def test_refresh_schema_keyspace(self):
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect()
original_meta = cluster.metadata.keyspaces
original_system_meta = original_meta['system']
# only refresh one keyspace
cluster.refresh_schema(keyspace='system')
current_meta = cluster.metadata.keyspaces
self.assertIs(original_meta, current_meta)
current_system_meta = current_meta['system']
self.assertIsNot(original_system_meta, current_system_meta)
self.assertEqual(original_system_meta.as_cql_query(), current_system_meta.as_cql_query())
session.shutdown()
def test_refresh_schema_table(self):
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect()
original_meta = cluster.metadata.keyspaces
original_system_meta = original_meta['system']
original_system_schema_meta = original_system_meta.tables['schema_columnfamilies']
# only refresh one table
cluster.refresh_schema(keyspace='system', table='schema_columnfamilies')
current_meta = cluster.metadata.keyspaces
current_system_meta = current_meta['system']
current_system_schema_meta = current_system_meta.tables['schema_columnfamilies']
self.assertIs(original_meta, current_meta)
self.assertIs(original_system_meta, current_system_meta)
self.assertIsNot(original_system_schema_meta, current_system_schema_meta)
self.assertEqual(original_system_schema_meta.as_cql_query(), current_system_schema_meta.as_cql_query())
session.shutdown()
def test_refresh_schema_type(self):
if get_server_versions()[0] < (2, 1, 0):
raise unittest.SkipTest('UDTs were introduced in Cassandra 2.1')
if PROTOCOL_VERSION < 3:
raise unittest.SkipTest('UDTs are not specified in change events for protocol v2')
# We may want to refresh types on keyspace change events in that case(?)
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect()
keyspace_name = 'test1rf'
type_name = self._testMethodName
session.execute('CREATE TYPE IF NOT EXISTS %s.%s (one int, two text)' % (keyspace_name, type_name))
original_meta = cluster.metadata.keyspaces
original_test1rf_meta = original_meta[keyspace_name]
original_type_meta = original_test1rf_meta.user_types[type_name]
# only refresh one type
cluster.refresh_schema(keyspace='test1rf', usertype=type_name)
current_meta = cluster.metadata.keyspaces
current_test1rf_meta = current_meta[keyspace_name]
current_type_meta = current_test1rf_meta.user_types[type_name]
self.assertIs(original_meta, current_meta)
self.assertIs(original_test1rf_meta, current_test1rf_meta)
self.assertIsNot(original_type_meta, current_type_meta)
self.assertEqual(original_type_meta.as_cql_query(), current_type_meta.as_cql_query())
session.shutdown()
def test_refresh_schema_no_wait(self):
contact_points = ['127.0.0.1']
cluster = Cluster(protocol_version=PROTOCOL_VERSION, max_schema_agreement_wait=10,
contact_points=contact_points, load_balancing_policy=WhiteListRoundRobinPolicy(contact_points))
session = cluster.connect()
schema_ver = session.execute("SELECT schema_version FROM system.local WHERE key='local'")[0][0]
# create a schema disagreement
session.execute("UPDATE system.local SET schema_version=%s WHERE key='local'", (uuid4(),))
try:
agreement_timeout = 1
# cluster agreement wait exceeded
c = Cluster(protocol_version=PROTOCOL_VERSION, max_schema_agreement_wait=agreement_timeout)
start_time = time.time()
s = c.connect()
end_time = time.time()
self.assertGreaterEqual(end_time - start_time, agreement_timeout)
self.assertTrue(c.metadata.keyspaces)
# cluster agreement wait used for refresh
original_meta = c.metadata.keyspaces
start_time = time.time()
self.assertRaisesRegexp(Exception, r"Schema was not refreshed.*", c.refresh_schema)
end_time = time.time()
self.assertGreaterEqual(end_time - start_time, agreement_timeout)
self.assertIs(original_meta, c.metadata.keyspaces)
# refresh wait overrides cluster value
original_meta = c.metadata.keyspaces
start_time = time.time()
c.refresh_schema(max_schema_agreement_wait=0)
end_time = time.time()
self.assertLess(end_time - start_time, agreement_timeout)
self.assertIsNot(original_meta, c.metadata.keyspaces)
self.assertEqual(original_meta, c.metadata.keyspaces)
s.shutdown()
refresh_threshold = 0.5
# cluster agreement bypass
c = Cluster(protocol_version=PROTOCOL_VERSION, max_schema_agreement_wait=0)
start_time = time.time()
s = c.connect()
end_time = time.time()
self.assertLess(end_time - start_time, refresh_threshold)
self.assertTrue(c.metadata.keyspaces)
# cluster agreement wait used for refresh
original_meta = c.metadata.keyspaces
start_time = time.time()
c.refresh_schema()
end_time = time.time()
self.assertLess(end_time - start_time, refresh_threshold)
self.assertIsNot(original_meta, c.metadata.keyspaces)
self.assertEqual(original_meta, c.metadata.keyspaces)
# refresh wait overrides cluster value
original_meta = c.metadata.keyspaces
start_time = time.time()
self.assertRaisesRegexp(Exception, r"Schema was not refreshed.*", c.refresh_schema, max_schema_agreement_wait=agreement_timeout)
end_time = time.time()
self.assertGreaterEqual(end_time - start_time, agreement_timeout)
self.assertIs(original_meta, c.metadata.keyspaces)
s.shutdown()
finally:
session.execute("UPDATE system.local SET schema_version=%s WHERE key='local'", (schema_ver,))
session.shutdown()
def test_trace(self):
"""
Ensure trace can be requested for async and non-async queries
@@ -271,3 +434,115 @@ class ClusterTests(unittest.TestCase):
self.assertIn(query, str(future))
self.assertIn('result', str(future))
def test_idle_heartbeat(self):
interval = 1
cluster = Cluster(protocol_version=PROTOCOL_VERSION, idle_heartbeat_interval=interval)
if PROTOCOL_VERSION < 3:
cluster.set_core_connections_per_host(HostDistance.LOCAL, 1)
session = cluster.connect()
# This test relies on impl details of connection req id management to see if heartbeats
# are being sent. May need update if impl is changed
connection_request_ids = {}
for h in cluster.get_connection_holders():
for c in h.get_connections():
# make sure none are idle (should have startup messages)
self.assertFalse(c.is_idle)
with c.lock:
connection_request_ids[id(c)] = deque(c.request_ids) # copy of request ids
# let two heatbeat intervals pass (first one had startup messages in it)
time.sleep(2 * interval + interval/10.)
connections = [c for holders in cluster.get_connection_holders() for c in holders.get_connections()]
# make sure requests were sent on all connections
for c in connections:
expected_ids = connection_request_ids[id(c)]
expected_ids.rotate(-1)
with c.lock:
self.assertListEqual(list(c.request_ids), list(expected_ids))
# assert idle status
self.assertTrue(all(c.is_idle for c in connections))
# send messages on all connections
statements_and_params = [("SELECT release_version FROM system.local", ())] * len(cluster.metadata.all_hosts())
results = execute_concurrent(session, statements_and_params)
for success, result in results:
self.assertTrue(success)
# assert not idle status
self.assertFalse(any(c.is_idle if not c.is_control_connection else False for c in connections))
# holders include session pools and cc
holders = cluster.get_connection_holders()
self.assertIn(cluster.control_connection, holders)
self.assertEqual(len(holders), len(cluster.metadata.all_hosts()) + 1) # hosts pools, 1 for cc
# include additional sessions
session2 = cluster.connect()
holders = cluster.get_connection_holders()
self.assertIn(cluster.control_connection, holders)
self.assertEqual(len(holders), 2 * len(cluster.metadata.all_hosts()) + 1) # 2 sessions' hosts pools, 1 for cc
cluster._idle_heartbeat.stop()
cluster._idle_heartbeat.join()
assert_quiescent_pool_state(self, cluster)
session.shutdown()
@patch('cassandra.cluster.Cluster.idle_heartbeat_interval', new=0.1)
def test_idle_heartbeat_disabled(self):
self.assertTrue(Cluster.idle_heartbeat_interval)
# heartbeat disabled with '0'
cluster = Cluster(protocol_version=PROTOCOL_VERSION, idle_heartbeat_interval=0)
self.assertEqual(cluster.idle_heartbeat_interval, 0)
session = cluster.connect()
# let two heatbeat intervals pass (first one had startup messages in it)
time.sleep(2 * Cluster.idle_heartbeat_interval)
connections = [c for holders in cluster.get_connection_holders() for c in holders.get_connections()]
# assert not idle status (should never get reset because there is not heartbeat)
self.assertFalse(any(c.is_idle for c in connections))
session.shutdown()
def test_pool_management(self):
# Ensure that in_flight and request_ids quiesce after cluster operations
cluster = Cluster(protocol_version=PROTOCOL_VERSION, idle_heartbeat_interval=0) # no idle heartbeat here, pool management is tested in test_idle_heartbeat
session = cluster.connect()
session2 = cluster.connect()
# prepare
p = session.prepare("SELECT * FROM system.local WHERE key=?")
self.assertTrue(session.execute(p, ('local',)))
# simple
self.assertTrue(session.execute("SELECT * FROM system.local WHERE key='local'"))
# set keyspace
session.set_keyspace('system')
session.set_keyspace('system_traces')
# use keyspace
session.execute('USE system')
session.execute('USE system_traces')
# refresh schema
cluster.refresh_schema()
cluster.refresh_schema(max_schema_agreement_wait=0)
# submit schema refresh
future = cluster.submit_schema_refresh()
future.result()
assert_quiescent_pool_state(self, cluster)
session2.shutdown()
session.shutdown()

View File

@@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from tests.integration import use_singledc, PROTOCOL_VERSION
try:
import unittest2 as unittest
except ImportError:
@@ -21,13 +19,15 @@ except ImportError:
from functools import partial
from six.moves import range
import sys
from threading import Thread, Event
from cassandra import ConsistencyLevel, OperationTimedOut
from cassandra.cluster import NoHostAvailable
from cassandra.protocol import QueryMessage
from cassandra.io.asyncorereactor import AsyncoreConnection
from cassandra.protocol import QueryMessage
from tests import is_monkey_patched
from tests.integration import use_singledc, PROTOCOL_VERSION
try:
from cassandra.io.libevreactor import LibevConnection
@@ -230,8 +230,8 @@ class AsyncoreConnectionTests(ConnectionTests, unittest.TestCase):
klass = AsyncoreConnection
def setUp(self):
if 'gevent.monkey' in sys.modules:
raise unittest.SkipTest("Can't test asyncore with gevent monkey patching")
if is_monkey_patched():
raise unittest.SkipTest("Can't test asyncore with monkey patching")
ConnectionTests.setUp(self)
@@ -240,8 +240,8 @@ class LibevConnectionTests(ConnectionTests, unittest.TestCase):
klass = LibevConnection
def setUp(self):
if 'gevent.monkey' in sys.modules:
raise unittest.SkipTest("Can't test libev with gevent monkey patching")
if is_monkey_patched():
raise unittest.SkipTest("Can't test libev with monkey patching")
if LibevConnection is None:
raise unittest.SkipTest(
'libev does not appear to be installed properly')

View File

@@ -11,15 +11,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import six
import difflib
try:
import unittest2 as unittest
except ImportError:
import unittest # noqa
import difflib
from mock import Mock
import six
import sys
from cassandra import AlreadyExists
@@ -361,6 +362,9 @@ class TestCodeCoverage(unittest.TestCase):
"Protocol 3.0+ is required for UDT change events, currently testing against %r"
% (PROTOCOL_VERSION,))
if sys.version_info[2:] != (2, 7):
raise unittest.SkipTest('This test compares static strings generated from dict items, which may change orders. Test with 2.7.')
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect()
@@ -549,6 +553,9 @@ CREATE TABLE export_udts.users (
if get_server_versions()[0] < (2, 1, 0):
raise unittest.SkipTest('Test schema output assumes 2.1.0+ options')
if sys.version_info[2:] != (2, 7):
raise unittest.SkipTest('This test compares static strings generated from dict items, which may change orders. Test with 2.7.')
cli_script = """CREATE KEYSPACE legacy
WITH placement_strategy = 'SimpleStrategy'
AND strategy_options = {replication_factor:1};
@@ -633,7 +640,7 @@ create column family composite_comp_with_col
index_type : 0}]
and compression_options = {'sstable_compression' : 'org.apache.cassandra.io.compress.LZ4Compressor'};"""
# note: the innerlkey type for legacy.nested_composite_key
# note: the inner key type for legacy.nested_composite_key
# (org.apache.cassandra.db.marshal.CompositeType(org.apache.cassandra.db.marshal.UUIDType, org.apache.cassandra.db.marshal.UTF8Type))
# is a bit strange, but it replays in CQL with desired results
expected_string = """CREATE KEYSPACE legacy WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'} AND durable_writes = true;
@@ -807,6 +814,7 @@ CREATE TABLE legacy.composite_comp_no_col (
cluster.shutdown()
class TokenMetadataTest(unittest.TestCase):
"""
Test of TokenMap creation and other behavior.
@@ -882,7 +890,7 @@ class KeyspaceAlterMetadata(unittest.TestCase):
self.assertEqual(original_keyspace_meta.durable_writes, True)
self.assertEqual(len(original_keyspace_meta.tables), 1)
self.session.execute('ALTER KEYSPACE %s WITH durable_writes = false' %name)
self.session.execute('ALTER KEYSPACE %s WITH durable_writes = false' % name)
new_keyspace_meta = self.cluster.metadata.keyspaces[name]
self.assertNotEqual(original_keyspace_meta, new_keyspace_meta)
self.assertEqual(new_keyspace_meta.durable_writes, False)

View File

@@ -21,8 +21,10 @@ except ImportError:
import logging
log = logging.getLogger(__name__)
from collections import namedtuple
from decimal import Decimal
from datetime import datetime, date, time
from functools import partial
import six
from uuid import uuid1, uuid4
@@ -30,10 +32,14 @@ from cassandra import InvalidRequest
from cassandra.cluster import Cluster
from cassandra.cqltypes import Int32Type, EMPTY
from cassandra.query import dict_factory
from cassandra.util import OrderedDict, sortedset
from cassandra.util import OrderedMap, sortedset
from tests.integration import get_server_versions, use_singledc, PROTOCOL_VERSION
# defined in module scope for pickling in OrderedMap
nested_collection_udt = namedtuple('nested_collection_udt', ['m', 't', 'l', 's'])
nested_collection_udt_nested = namedtuple('nested_collection_udt_nested', ['m', 't', 'l', 's', 'u'])
def setup_module():
use_singledc()
@@ -287,11 +293,11 @@ class TypeTests(unittest.TestCase):
s.execute("INSERT INTO mytable (a, b, c, o, s, l, n) VALUES ('a', 'b', %s, %s, %s, %s, %s)",
('', '', '', [''], {'': 3}))
self.assertEqual(
{'c': '', 'o': '', 's': '', 'l': [''], 'n': OrderedDict({'': 3})},
{'c': '', 'o': '', 's': '', 'l': [''], 'n': OrderedMap({'': 3})},
s.execute("SELECT c, o, s, l, n FROM mytable WHERE a='a' AND b='b'")[0])
self.assertEqual(
{'c': '', 'o': '', 's': '', 'l': [''], 'n': OrderedDict({'': 3})},
{'c': '', 'o': '', 's': '', 'l': [''], 'n': OrderedMap({'': 3})},
s.execute(s.prepare("SELECT c, o, s, l, n FROM mytable WHERE a='a' AND b='b'"), [])[0])
# non-string types shouldn't accept empty strings
@@ -708,10 +714,10 @@ class TypeTests(unittest.TestCase):
self.assertEquals((None, None, None, None), s.execute(read)[0].t)
# also test empty strings where compatible
s.execute(insert, [('', None, None, '')])
s.execute(insert, [('', None, None, b'')])
result = s.execute("SELECT * FROM mytable WHERE k=0")
self.assertEquals(('', None, None, ''), result[0].t)
self.assertEquals(('', None, None, ''), s.execute(read)[0].t)
self.assertEquals(('', None, None, b''), result[0].t)
self.assertEquals(('', None, None, b''), s.execute(read)[0].t)
c.shutdown()
@@ -721,3 +727,61 @@ class TypeTests(unittest.TestCase):
query = u"SELECT * FROM system.schema_columnfamilies WHERE keyspace_name = 'ef\u2052ef' AND columnfamily_name = %s"
s.execute(query, (u"fe\u2051fe",))
def insert_select_column(self, session, table_name, column_name, value):
insert = session.prepare("INSERT INTO %s (k, %s) VALUES (?, ?)" % (table_name, column_name))
session.execute(insert, (0, value))
result = session.execute("SELECT %s FROM %s WHERE k=%%s" % (column_name, table_name), (0,))[0][0]
self.assertEqual(result, value)
def test_nested_collections(self):
if self._cass_version < (2, 1, 3):
raise unittest.SkipTest("Support for nested collections was introduced in Cassandra 2.1.3")
name = self._testMethodName
c = Cluster(protocol_version=PROTOCOL_VERSION)
s = c.connect('test1rf')
s.encoder.mapping[tuple] = s.encoder.cql_encode_tuple
s.execute("""
CREATE TYPE %s (
m frozen<map<int,text>>,
t tuple<int,text>,
l frozen<list<int>>,
s frozen<set<int>>
)""" % name)
s.execute("""
CREATE TYPE %s_nested (
m frozen<map<int,text>>,
t tuple<int,text>,
l frozen<list<int>>,
s frozen<set<int>>,
u frozen<%s>
)""" % (name, name))
s.execute("""
CREATE TABLE %s (
k int PRIMARY KEY,
map_map map<frozen<map<int,int>>, frozen<map<int,int>>>,
map_set map<frozen<set<int>>, frozen<set<int>>>,
map_list map<frozen<list<int>>, frozen<list<int>>>,
map_tuple map<frozen<tuple<int, int>>, frozen<tuple<int>>>,
map_udt map<frozen<%s_nested>, frozen<%s>>,
)""" % (name, name, name))
validate = partial(self.insert_select_column, s, name)
validate('map_map', OrderedMap([({1: 1, 2: 2}, {3: 3, 4: 4}), ({5: 5, 6: 6}, {7: 7, 8: 8})]))
validate('map_set', OrderedMap([(set((1, 2)), set((3, 4))), (set((5, 6)), set((7, 8)))]))
validate('map_list', OrderedMap([([1, 2], [3, 4]), ([5, 6], [7, 8])]))
validate('map_tuple', OrderedMap([((1, 2), (3,)), ((4, 5), (6,))]))
value = nested_collection_udt({1: 'v1', 2: 'v2'}, (3, 'v3'), [4, 5, 6, 7], set((8, 9, 10)))
key = nested_collection_udt_nested(value.m, value.t, value.l, value.s, value)
key2 = nested_collection_udt_nested({3: 'v3'}, value.t, value.l, value.s, value)
validate('map_udt', OrderedMap([(key, value), (key2, value)]))
s.execute("DROP TABLE %s" % (name))
s.execute("DROP TYPE %s_nested" % (name))
s.execute("DROP TYPE %s" % (name))
s.shutdown()

37
tests/integration/util.py Normal file
View File

@@ -0,0 +1,37 @@
# Copyright 2013-2014 DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from tests.integration import PROTOCOL_VERSION
def assert_quiescent_pool_state(test_case, cluster):
for session in cluster.sessions:
pool_states = session.get_pool_state().values()
test_case.assertTrue(pool_states)
for state in pool_states:
test_case.assertFalse(state['shutdown'])
test_case.assertGreater(state['open_count'], 0)
test_case.assertTrue(all((i == 0 for i in state['in_flights'])))
for holder in cluster.get_connection_holders():
for connection in holder.get_connections():
# all ids are unique
req_ids = connection.request_ids
test_case.assertEqual(len(req_ids), len(set(req_ids)))
test_case.assertEqual(connection.highest_request_id, len(req_ids) - 1)
test_case.assertEqual(connection.highest_request_id, max(req_ids))
if PROTOCOL_VERSION < 3:
test_case.assertEqual(connection.highest_request_id, connection.max_request_id)

View File

@@ -31,20 +31,20 @@ from mock import patch, Mock
from cassandra.connection import (HEADER_DIRECTION_TO_CLIENT,
ConnectionException)
from cassandra.io.asyncorereactor import AsyncoreConnection
from cassandra.protocol import (write_stringmultimap, write_int, write_string,
SupportedMessage, ReadyMessage, ServerError)
from cassandra.marshal import uint8_pack, uint32_pack, int32_pack
from cassandra.io.asyncorereactor import AsyncoreConnection
from tests import is_monkey_patched
class AsyncoreConnectionTest(unittest.TestCase):
@classmethod
def setUpClass(cls):
if 'gevent.monkey' in sys.modules:
raise unittest.SkipTest("gevent monkey-patching detected")
if is_monkey_patched():
raise unittest.SkipTest("monkey-patching detected")
AsyncoreConnection.initialize_reactor()
cls.socket_patcher = patch('socket.socket', spec=socket.socket)
cls.mock_socket = cls.socket_patcher.start()

View File

@@ -11,21 +11,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import six
try:
import unittest2 as unittest
except ImportError:
import unittest # noqa
import unittest # noqa
from mock import Mock, ANY, call
import six
from six import BytesIO
from mock import Mock, ANY
import time
from threading import Lock
from cassandra.cluster import Cluster
from cassandra.connection import (Connection, HEADER_DIRECTION_TO_CLIENT,
HEADER_DIRECTION_FROM_CLIENT, ProtocolError,
locally_supported_compressions)
locally_supported_compressions, ConnectionHeartbeat)
from cassandra.marshal import uint8_pack, uint32_pack
from cassandra.protocol import (write_stringmultimap, write_int, write_string,
SupportedMessage)
@@ -280,3 +280,155 @@ class ConnectionTest(unittest.TestCase):
def test_set_connection_class(self):
cluster = Cluster(connection_class='test')
self.assertEqual('test', cluster.connection_class)
class ConnectionHeartbeatTest(unittest.TestCase):
@staticmethod
def make_get_holders(len):
holders = []
for _ in range(len):
holder = Mock()
holder.get_connections = Mock(return_value=[])
holders.append(holder)
get_holders = Mock(return_value=holders)
return get_holders
def run_heartbeat(self, get_holders_fun, count=2, interval=0.05):
ch = ConnectionHeartbeat(interval, get_holders_fun)
time.sleep(interval * count)
ch.stop()
ch.join()
self.assertTrue(get_holders_fun.call_count)
def test_empty_connections(self):
count = 3
get_holders = self.make_get_holders(1)
self.run_heartbeat(get_holders, count)
self.assertGreaterEqual(get_holders.call_count, count - 1) # lower bound to account for thread spinup time
self.assertLessEqual(get_holders.call_count, count)
holder = get_holders.return_value[0]
holder.get_connections.assert_has_calls([call()] * get_holders.call_count)
def test_idle_non_idle(self):
request_id = 999
# connection.send_msg(OptionsMessage(), connection.get_request_id(), self._options_callback)
def send_msg(msg, req_id, msg_callback):
msg_callback(SupportedMessage([], {}))
idle_connection = Mock(spec=Connection, host='localhost',
max_request_id=127,
lock=Lock(),
in_flight=0, is_idle=True,
is_defunct=False, is_closed=False,
get_request_id=lambda: request_id,
send_msg=Mock(side_effect=send_msg))
non_idle_connection = Mock(spec=Connection, in_flight=0, is_idle=False, is_defunct=False, is_closed=False)
get_holders = self.make_get_holders(1)
holder = get_holders.return_value[0]
holder.get_connections.return_value.append(idle_connection)
holder.get_connections.return_value.append(non_idle_connection)
self.run_heartbeat(get_holders)
holder.get_connections.assert_has_calls([call()] * get_holders.call_count)
self.assertEqual(idle_connection.in_flight, 0)
self.assertEqual(non_idle_connection.in_flight, 0)
idle_connection.send_msg.assert_has_calls([call(ANY, request_id, ANY)] * get_holders.call_count)
self.assertEqual(non_idle_connection.send_msg.call_count, 0)
def test_closed_defunct(self):
get_holders = self.make_get_holders(1)
closed_connection = Mock(spec=Connection, in_flight=0, is_idle=False, is_defunct=False, is_closed=True)
defunct_connection = Mock(spec=Connection, in_flight=0, is_idle=False, is_defunct=True, is_closed=False)
holder = get_holders.return_value[0]
holder.get_connections.return_value.append(closed_connection)
holder.get_connections.return_value.append(defunct_connection)
self.run_heartbeat(get_holders)
holder.get_connections.assert_has_calls([call()] * get_holders.call_count)
self.assertEqual(closed_connection.in_flight, 0)
self.assertEqual(defunct_connection.in_flight, 0)
self.assertEqual(closed_connection.send_msg.call_count, 0)
self.assertEqual(defunct_connection.send_msg.call_count, 0)
def test_no_req_ids(self):
in_flight = 3
get_holders = self.make_get_holders(1)
max_connection = Mock(spec=Connection, host='localhost',
max_request_id=in_flight, in_flight=in_flight,
is_idle=True, is_defunct=False, is_closed=False)
holder = get_holders.return_value[0]
holder.get_connections.return_value.append(max_connection)
self.run_heartbeat(get_holders)
holder.get_connections.assert_has_calls([call()] * get_holders.call_count)
self.assertEqual(max_connection.in_flight, in_flight)
self.assertEqual(max_connection.send_msg.call_count, 0)
self.assertEqual(max_connection.send_msg.call_count, 0)
max_connection.defunct.assert_has_calls([call(ANY)] * get_holders.call_count)
holder.return_connection.assert_has_calls([call(max_connection)] * get_holders.call_count)
def test_unexpected_response(self):
request_id = 999
get_holders = self.make_get_holders(1)
def send_msg(msg, req_id, msg_callback):
msg_callback(object())
connection = Mock(spec=Connection, host='localhost',
max_request_id=127,
lock=Lock(),
in_flight=0, is_idle=True,
is_defunct=False, is_closed=False,
get_request_id=lambda: request_id,
send_msg=Mock(side_effect=send_msg))
holder = get_holders.return_value[0]
holder.get_connections.return_value.append(connection)
self.run_heartbeat(get_holders)
self.assertEqual(connection.in_flight, get_holders.call_count)
connection.send_msg.assert_has_calls([call(ANY, request_id, ANY)] * get_holders.call_count)
connection.defunct.assert_has_calls([call(ANY)] * get_holders.call_count)
exc = connection.defunct.call_args_list[0][0][0]
self.assertIsInstance(exc, Exception)
self.assertEqual(exc.args, Exception('Connection heartbeat failure').args)
holder.return_connection.assert_has_calls([call(connection)] * get_holders.call_count)
def test_timeout(self):
request_id = 999
get_holders = self.make_get_holders(1)
def send_msg(msg, req_id, msg_callback):
pass
connection = Mock(spec=Connection, host='localhost',
max_request_id=127,
lock=Lock(),
in_flight=0, is_idle=True,
is_defunct=False, is_closed=False,
get_request_id=lambda: request_id,
send_msg=Mock(side_effect=send_msg))
holder = get_holders.return_value[0]
holder.get_connections.return_value.append(connection)
self.run_heartbeat(get_holders)
self.assertEqual(connection.in_flight, get_holders.call_count)
connection.send_msg.assert_has_calls([call(ANY, request_id, ANY)] * get_holders.call_count)
connection.defunct.assert_has_calls([call(ANY)] * get_holders.call_count)
exc = connection.defunct.call_args_list[0][0][0]
self.assertIsInstance(exc, Exception)
self.assertEqual(exc.args, Exception('Connection heartbeat failure').args)
holder.return_connection.assert_has_calls([call(connection)] * get_holders.call_count)

View File

@@ -73,7 +73,7 @@ class MockCluster(object):
self.scheduler = Mock(spec=_Scheduler)
self.executor = Mock(spec=ThreadPoolExecutor)
def add_host(self, address, datacenter, rack, signal=False):
def add_host(self, address, datacenter, rack, signal=False, refresh_nodes=True):
host = Host(address, SimpleConvictionPolicy, datacenter, rack)
self.added_hosts.append(host)
return host
@@ -131,7 +131,7 @@ class ControlConnectionTest(unittest.TestCase):
self.connection = MockConnection()
self.time = FakeTime()
self.control_connection = ControlConnection(self.cluster, timeout=1)
self.control_connection = ControlConnection(self.cluster, 1, 0, 0)
self.control_connection._connection = self.connection
self.control_connection._time = self.time
@@ -345,39 +345,44 @@ class ControlConnectionTest(unittest.TestCase):
'change_type': 'NEW_NODE',
'address': ('1.2.3.4', 9000)
}
self.cluster.scheduler.reset_mock()
self.control_connection._handle_topology_change(event)
self.cluster.scheduler.schedule.assert_called_with(ANY, self.control_connection.refresh_node_list_and_token_map)
self.cluster.scheduler.schedule_unique.assert_called_once_with(ANY, self.control_connection.refresh_node_list_and_token_map)
event = {
'change_type': 'REMOVED_NODE',
'address': ('1.2.3.4', 9000)
}
self.cluster.scheduler.reset_mock()
self.control_connection._handle_topology_change(event)
self.cluster.scheduler.schedule.assert_called_with(ANY, self.cluster.remove_host, None)
self.cluster.scheduler.schedule_unique.assert_called_once_with(ANY, self.cluster.remove_host, None)
event = {
'change_type': 'MOVED_NODE',
'address': ('1.2.3.4', 9000)
}
self.cluster.scheduler.reset_mock()
self.control_connection._handle_topology_change(event)
self.cluster.scheduler.schedule.assert_called_with(ANY, self.control_connection.refresh_node_list_and_token_map)
self.cluster.scheduler.schedule_unique.assert_called_once_with(ANY, self.control_connection.refresh_node_list_and_token_map)
def test_handle_status_change(self):
event = {
'change_type': 'UP',
'address': ('1.2.3.4', 9000)
}
self.cluster.scheduler.reset_mock()
self.control_connection._handle_status_change(event)
self.cluster.scheduler.schedule.assert_called_with(ANY, self.control_connection.refresh_node_list_and_token_map)
self.cluster.scheduler.schedule_unique.assert_called_once_with(ANY, self.control_connection.refresh_node_list_and_token_map)
# do the same with a known Host
event = {
'change_type': 'UP',
'address': ('192.168.1.0', 9000)
}
self.cluster.scheduler.reset_mock()
self.control_connection._handle_status_change(event)
host = self.cluster.metadata.hosts['192.168.1.0']
self.cluster.scheduler.schedule.assert_called_with(ANY, self.cluster.on_up, host)
self.cluster.scheduler.schedule_unique.assert_called_once_with(ANY, self.cluster.on_up, host)
self.cluster.scheduler.schedule.reset_mock()
event = {
@@ -404,9 +409,11 @@ class ControlConnectionTest(unittest.TestCase):
'keyspace': 'ks1',
'table': 'table1'
}
self.cluster.scheduler.reset_mock()
self.control_connection._handle_schema_change(event)
self.cluster.executor.submit.assert_called_with(self.control_connection.refresh_schema, 'ks1', 'table1', None)
self.cluster.scheduler.schedule_unique.assert_called_once_with(0.0, self.control_connection.refresh_schema, 'ks1', 'table1', None)
self.cluster.scheduler.reset_mock()
event['table'] = None
self.control_connection._handle_schema_change(event)
self.cluster.executor.submit.assert_called_with(self.control_connection.refresh_schema, 'ks1', None, None)
self.cluster.scheduler.schedule_unique.assert_called_once_with(0.0, self.control_connection.refresh_schema, 'ks1', None, None)

View File

@@ -24,7 +24,7 @@ from decimal import Decimal
from uuid import UUID
from cassandra.cqltypes import lookup_casstype
from cassandra.util import OrderedDict, sortedset
from cassandra.util import OrderedMap, sortedset
marshalled_value_pairs = (
# binary form, type, python native type
@@ -75,7 +75,7 @@ marshalled_value_pairs = (
(b'', 'MapType(AsciiType, BooleanType)', None),
(b'', 'ListType(FloatType)', None),
(b'', 'SetType(LongType)', None),
(b'\x00\x00', 'MapType(DecimalType, BooleanType)', OrderedDict()),
(b'\x00\x00', 'MapType(DecimalType, BooleanType)', OrderedMap()),
(b'\x00\x00', 'ListType(FloatType)', []),
(b'\x00\x00', 'SetType(IntegerType)', sortedset()),
(b'\x00\x01\x00\x10\xafYC\xa3\xea<\x11\xe1\xabc\xc4,\x03"y\xf0', 'ListType(TimeUUIDType)', [UUID(bytes=b'\xafYC\xa3\xea<\x11\xe1\xabc\xc4,\x03"y\xf0')]),
@@ -84,15 +84,14 @@ marshalled_value_pairs = (
(b'\x00\x00\x00\x00\x00\x00\x00\x01', 'TimeType', 1)
)
ordered_dict_value = OrderedDict()
ordered_dict_value[u'\u307fbob'] = 199
ordered_dict_value[u''] = -1
ordered_dict_value[u'\\'] = 0
ordered_map_value = OrderedMap([(u'\u307fbob', 199),
(u'', -1),
(u'\\', 0)])
# these following entries work for me right now, but they're dependent on
# vagaries of internal python ordering for unordered types
marshalled_value_pairs_unsafe = (
(b'\x00\x03\x00\x06\xe3\x81\xbfbob\x00\x04\x00\x00\x00\xc7\x00\x00\x00\x04\xff\xff\xff\xff\x00\x01\\\x00\x04\x00\x00\x00\x00', 'MapType(UTF8Type, Int32Type)', ordered_dict_value),
(b'\x00\x03\x00\x06\xe3\x81\xbfbob\x00\x04\x00\x00\x00\xc7\x00\x00\x00\x04\xff\xff\xff\xff\x00\x01\\\x00\x04\x00\x00\x00\x00', 'MapType(UTF8Type, Int32Type)', ordered_map_value),
(b'\x00\x02\x00\x08@\x01\x99\x99\x99\x99\x99\x9a\x00\x08@\x14\x00\x00\x00\x00\x00\x00', 'SetType(DoubleType)', sortedset([2.2, 5.0])),
(b'\x00', 'IntegerType', 0),
)

View File

@@ -0,0 +1,127 @@
# Copyright 2013-2014 DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
try:
import unittest2 as unittest
except ImportError:
import unittest # noqa
from cassandra.util import OrderedMap
from cassandra.cqltypes import EMPTY
import six
class OrderedMapTest(unittest.TestCase):
def test_init(self):
a = OrderedMap(zip(['one', 'three', 'two'], [1, 3, 2]))
b = OrderedMap([('one', 1), ('three', 3), ('two', 2)])
c = OrderedMap(a)
builtin = {'one': 1, 'two': 2, 'three': 3}
self.assertEqual(a, b)
self.assertEqual(a, c)
self.assertEqual(a, builtin)
self.assertEqual(OrderedMap([(1, 1), (1, 2)]), {1: 2})
def test_contains(self):
keys = ['first', 'middle', 'last']
od = OrderedMap()
od = OrderedMap(zip(keys, range(len(keys))))
for k in keys:
self.assertTrue(k in od)
self.assertFalse(k not in od)
self.assertTrue('notthere' not in od)
self.assertFalse('notthere' in od)
def test_keys(self):
keys = ['first', 'middle', 'last']
od = OrderedMap(zip(keys, range(len(keys))))
self.assertListEqual(list(od.keys()), keys)
def test_values(self):
keys = ['first', 'middle', 'last']
values = list(range(len(keys)))
od = OrderedMap(zip(keys, values))
self.assertListEqual(list(od.values()), values)
def test_items(self):
keys = ['first', 'middle', 'last']
items = list(zip(keys, range(len(keys))))
od = OrderedMap(items)
self.assertListEqual(list(od.items()), items)
def test_get(self):
keys = ['first', 'middle', 'last']
od = OrderedMap(zip(keys, range(len(keys))))
for v, k in enumerate(keys):
self.assertEqual(od.get(k), v)
self.assertEqual(od.get('notthere', 'default'), 'default')
self.assertIsNone(od.get('notthere'))
def test_equal(self):
d1 = {'one': 1}
d12 = {'one': 1, 'two': 2}
od1 = OrderedMap({'one': 1})
od12 = OrderedMap([('one', 1), ('two', 2)])
od21 = OrderedMap([('two', 2), ('one', 1)])
self.assertEqual(od1, d1)
self.assertEqual(od12, d12)
self.assertEqual(od21, d12)
self.assertNotEqual(od1, od12)
self.assertNotEqual(od12, od1)
self.assertNotEqual(od12, od21)
self.assertNotEqual(od1, d12)
self.assertNotEqual(od12, d1)
self.assertNotEqual(od1, EMPTY)
def test_getitem(self):
keys = ['first', 'middle', 'last']
od = OrderedMap(zip(keys, range(len(keys))))
for v, k in enumerate(keys):
self.assertEqual(od[k], v)
with self.assertRaises(KeyError):
od['notthere']
def test_iter(self):
keys = ['first', 'middle', 'last']
values = list(range(len(keys)))
items = list(zip(keys, values))
od = OrderedMap(items)
itr = iter(od)
self.assertEqual(sum([1 for _ in itr]), len(keys))
self.assertRaises(StopIteration, six.next, itr)
self.assertEqual(list(iter(od)), keys)
self.assertEqual(list(six.iteritems(od)), items)
self.assertEqual(list(six.itervalues(od)), values)
def test_len(self):
self.assertEqual(len(OrderedMap()), 0)
self.assertEqual(len(OrderedMap([(1, 1)])), 1)
def test_mutable_keys(self):
d = {'1': 1}
s = set([1, 2, 3])
od = OrderedMap([(d, 'dict'), (s, 'set')])

View File

@@ -262,11 +262,13 @@ class TypeTests(unittest.TestCase):
self.assertEqual(cassandra_type.val, 'randomvaluetocheck')
def test_datetype(self):
now_timestamp = time.time()
now_datetime = datetime.datetime.utcfromtimestamp(now_timestamp)
now_time_seconds = time.time()
now_datetime = datetime.datetime.utcfromtimestamp(now_time_seconds)
# Cassandra timestamps in millis
now_timestamp = now_time_seconds * 1e3
# same results serialized
# (this could change if we follow up on the timestamp multiplication warning in DateType.serialize)
self.assertEqual(DateType.serialize(now_datetime, 0), DateType.serialize(now_timestamp, 0))
# from timestamp