diff --git a/CHANGELOG.rst b/CHANGELOG.rst index aaa67c8e..47b9d29c 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,14 +1,22 @@ -2.1.3-post -========== +2.1.4 +===== Features -------- * SaslAuthenticator for Kerberos support (PYTHON-109) +* Heartbeat for network device keepalive and detecting failures on idle connections (PYTHON-197) +* Support nested, frozen collections for Cassandra 2.1.3+ (PYTHON-186) +* Schema agreement wait bypass config, new call for synchronous schema refresh (PYTHON-205) +* Add eventlet connection support (PYTHON-194) Bug Fixes --------- * Schema meta fix for complex thrift tables (PYTHON-191) * Support for 'unknown' replica placement strategies in schema meta (PYTHON-192) +* Resolve stream ID leak on set_keyspace (PYTHON-195) +* Remove implicit timestamp scaling on serialization of numeric timestamps (PYTHON-204) +* Resolve stream id collision when using SASL auth (PYTHON-210) +* Correct unhexlify usage for user defined type meta in Python3 (PYTHON-208) 2.1.3 ===== diff --git a/cassandra/__init__.py b/cassandra/__init__.py index 93c53dcc..96f08576 100644 --- a/cassandra/__init__.py +++ b/cassandra/__init__.py @@ -23,7 +23,7 @@ class NullHandler(logging.Handler): logging.getLogger('cassandra').addHandler(NullHandler()) -__version_info__ = (2, 1, 3, 'post') +__version_info__ = (2, 1, 4, 'post') __version__ = '.'.join(map(str, __version_info__)) diff --git a/cassandra/auth.py b/cassandra/auth.py index 77c5bcf9..fc13a821 100644 --- a/cassandra/auth.py +++ b/cassandra/auth.py @@ -130,7 +130,9 @@ class PlainTextAuthenticator(Authenticator): class SaslAuthProvider(AuthProvider): """ - An :class:`~.AuthProvider` for Kerberos authenticators + An :class:`~.AuthProvider` supporting general SASL auth mechanisms + + Suitable for GSSAPI or other SASL mechanisms Example usage:: @@ -144,7 +146,7 @@ class SaslAuthProvider(AuthProvider): auth_provider = SaslAuthProvider(**sasl_kwargs) cluster = Cluster(auth_provider=auth_provider) - .. versionadded:: 2.1.3-post + .. versionadded:: 2.1.4 """ def __init__(self, **sasl_kwargs): @@ -157,9 +159,10 @@ class SaslAuthProvider(AuthProvider): class SaslAuthenticator(Authenticator): """ - An :class:`~.Authenticator` that works with DSE's KerberosAuthenticator. + A pass-through :class:`~.Authenticator` using the third party package + 'pure-sasl' for authentication - .. versionadded:: 2.1.3-post + .. 
versionadded:: 2.1.4 """ def __init__(self, host, service, mechanism='GSSAPI', **sasl_kwargs): diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 0822e2a2..b639af69 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -22,6 +22,7 @@ import atexit from collections import defaultdict from concurrent.futures import ThreadPoolExecutor import logging +from random import random import socket import sys import time @@ -44,7 +45,8 @@ from itertools import groupby from cassandra import (ConsistencyLevel, AuthenticationFailed, InvalidRequest, OperationTimedOut, UnsupportedOperation, Unauthorized) -from cassandra.connection import ConnectionException, ConnectionShutdown +from cassandra.connection import (ConnectionException, ConnectionShutdown, + ConnectionHeartbeat) from cassandra.encoder import Encoder from cassandra.protocol import (QueryMessage, ResultMessage, ErrorMessage, ReadTimeoutErrorMessage, @@ -68,10 +70,20 @@ from cassandra.query import (SimpleStatement, PreparedStatement, BoundStatement, BatchStatement, bind_params, QueryTrace, Statement, named_tuple_factory, dict_factory, FETCH_SIZE_UNSET) -# default to gevent when we are monkey patched, otherwise if libev is available, use that as the -# default because it's fastest. Otherwise, use asyncore. + +def _is_eventlet_monkey_patched(): + if 'eventlet.patcher' not in sys.modules: + return False + import eventlet.patcher + return eventlet.patcher.is_monkey_patched('socket') + +# default to gevent when we are monkey patched with gevent, eventlet when +# monkey patched with eventlet, otherwise if libev is available, use that as +# the default because it's fastest. Otherwise, use asyncore. if 'gevent.monkey' in sys.modules: from cassandra.io.geventreactor import GeventConnection as DefaultConnection +elif _is_eventlet_monkey_patched(): + from cassandra.io.eventletreactor import EventletConnection as DefaultConnection else: try: from cassandra.io.libevreactor import LibevConnection as DefaultConnection # NOQA @@ -372,6 +384,48 @@ class Cluster(object): If set to :const:`None`, there will be no timeout for these queries. """ + idle_heartbeat_interval = 30 + """ + Interval, in seconds, on which to heartbeat idle connections. This helps + keep connections open through network devices that expire idle connections. + It also helps discover bad connections early in low-traffic scenarios. + Setting to zero disables heartbeats. + """ + + schema_event_refresh_window = 2 + """ + Window, in seconds, within which a schema component will be refreshed after + receiving a schema_change event. + + The driver delays a random amount of time in the range [0.0, window) + before executing the refresh. This serves two purposes: + + 1.) Spread the refresh for deployments with large fanout from C* to client tier, + preventing a 'thundering herd' problem with many clients refreshing simultaneously. + + 2.) Remove redundant refreshes. Redundant events arriving within the delay period + are discarded, and only one refresh is executed. + + Setting this to zero will execute refreshes immediately. + + Setting this negative will disable schema refreshes in response to push events + (refreshes will still occur in response to schema change responses to DDL statements + executed by Sessions of this Cluster). + """ + + topology_event_refresh_window = 10 + """ + Window, in seconds, within which the node and token list will be refreshed after + receiving a topology_change event. + + Setting this to zero will execute refreshes immediately. 
+ + Setting this negative will disable node refreshes in response to push events + (refreshes will still occur in response to new nodes observed on "UP" events). + + See :attr:`.schema_event_refresh_window` for discussion of rationale + """ + sessions = None control_connection = None scheduler = None @@ -380,6 +434,7 @@ class Cluster(object): _is_setup = False _prepared_statements = None _prepared_statement_lock = None + _idle_heartbeat = None _user_types = None """ @@ -406,7 +461,10 @@ class Cluster(object): protocol_version=2, executor_threads=2, max_schema_agreement_wait=10, - control_connection_timeout=2.0): + control_connection_timeout=2.0, + idle_heartbeat_interval=30, + schema_event_refresh_window=2, + topology_event_refresh_window=10): """ Any of the mutable Cluster attributes may be set as keyword arguments to the constructor. @@ -456,6 +514,9 @@ class Cluster(object): self.cql_version = cql_version self.max_schema_agreement_wait = max_schema_agreement_wait self.control_connection_timeout = control_connection_timeout + self.idle_heartbeat_interval = idle_heartbeat_interval + self.schema_event_refresh_window = schema_event_refresh_window + self.topology_event_refresh_window = topology_event_refresh_window self._listeners = set() self._listener_lock = Lock() @@ -500,7 +561,8 @@ class Cluster(object): self.metrics = Metrics(weakref.proxy(self)) self.control_connection = ControlConnection( - self, self.control_connection_timeout) + self, self.control_connection_timeout, + self.schema_event_refresh_window, self.topology_event_refresh_window) def register_user_type(self, keyspace, user_type, klass): """ @@ -621,7 +683,7 @@ class Cluster(object): def set_max_connections_per_host(self, host_distance, max_connections): """ - Gets the maximum number of connections per Session that will be opened + Sets the maximum number of connections per Session that will be opened for each host with :class:`~.HostDistance` equal to `host_distance`. The default is 2 for :attr:`~HostDistance.LOCAL` and 1 for :attr:`~HostDistance.REMOTE`. @@ -688,18 +750,19 @@ class Cluster(object): self.load_balancing_policy.populate( weakref.proxy(self), self.metadata.all_hosts()) - if self.control_connection: - try: - self.control_connection.connect() - log.debug("Control connection created") - except Exception: - log.exception("Control connection failed to connect, " - "shutting down Cluster:") - self.shutdown() - raise + try: + self.control_connection.connect() + log.debug("Control connection created") + except Exception: + log.exception("Control connection failed to connect, " + "shutting down Cluster:") + self.shutdown() + raise self.load_balancing_policy.check_supported() + if self.idle_heartbeat_interval: + self._idle_heartbeat = ConnectionHeartbeat(self.idle_heartbeat_interval, self.get_connection_holders) self._is_setup = True session = self._new_session() @@ -707,6 +770,13 @@ class Cluster(object): session.set_keyspace(keyspace) return session + def get_connection_holders(self): + holders = [] + for s in self.sessions: + holders.extend(s.get_pools()) + holders.append(self.control_connection) + return holders + def shutdown(self): """ Closes all sessions and connection associated with this Cluster. 
@@ -721,18 +791,17 @@ class Cluster(object): else: self.is_shutdown = True - if self.scheduler: - self.scheduler.shutdown() + if self._idle_heartbeat: + self._idle_heartbeat.stop() - if self.control_connection: - self.control_connection.shutdown() + self.scheduler.shutdown() - if self.sessions: - for session in self.sessions: - session.shutdown() + self.control_connection.shutdown() - if self.executor: - self.executor.shutdown() + for session in self.sessions: + session.shutdown() + + self.executor.shutdown() def _new_session(self): session = Session(self, self.metadata.all_hosts()) @@ -907,7 +976,7 @@ class Cluster(object): self._start_reconnector(host, is_host_addition) - def on_add(self, host): + def on_add(self, host, refresh_nodes=True): if self.is_shutdown: return @@ -919,7 +988,7 @@ class Cluster(object): log.debug("Done preparing queries for new host %r", host) self.load_balancing_policy.on_add(host) - self.control_connection.on_add(host) + self.control_connection.on_add(host, refresh_nodes) if distance == HostDistance.IGNORED: log.debug("Not adding connection pool for new host %r because the " @@ -995,7 +1064,7 @@ class Cluster(object): self.on_down(host, is_host_addition, expect_host_to_be_down) return is_down - def add_host(self, address, datacenter=None, rack=None, signal=True): + def add_host(self, address, datacenter=None, rack=None, signal=True, refresh_nodes=True): """ Called when adding initial contact points and when the control connection subsequently discovers a new node. Intended for internal @@ -1004,7 +1073,7 @@ class Cluster(object): new_host = self.metadata.add_host(address, datacenter, rack) if new_host and signal: log.info("New Cassandra host %r discovered", new_host) - self.on_add(new_host) + self.on_add(new_host, refresh_nodes) return new_host @@ -1045,16 +1114,20 @@ class Cluster(object): for pool in session._pools.values(): pool.ensure_core_connections() - def refresh_schema(self, keyspace=None, table=None, usertype=None, schema_agreement_wait=None): + def refresh_schema(self, keyspace=None, table=None, usertype=None, max_schema_agreement_wait=None): """ Synchronously refresh the schema metadata. - By default timeout for this operation is governed by :attr:`~.Cluster.max_schema_agreement_wait` + + By default, the timeout for this operation is governed by :attr:`~.Cluster.max_schema_agreement_wait` and :attr:`~.Cluster.control_connection_timeout`. - Passing schema_agreement_wait here overrides :attr:`~.Cluster.max_schema_agreement_wait`. - Setting schema_agreement_wait <= 0 will bypass schema agreement and refresh schema immediately. + + Passing max_schema_agreement_wait here overrides :attr:`~.Cluster.max_schema_agreement_wait`. + + Setting max_schema_agreement_wait <= 0 will bypass schema agreement and refresh schema immediately. + An Exception is raised if schema refresh fails for any reason. """ - if not self.control_connection.refresh_schema(keyspace, table, usertype, schema_agreement_wait): + if not self.control_connection.refresh_schema(keyspace, table, usertype, max_schema_agreement_wait): raise Exception("Schema was not refreshed. See log for details.") def submit_schema_refresh(self, keyspace=None, table=None, usertype=None): @@ -1066,6 +1139,27 @@ class Cluster(object): return self.executor.submit( self.control_connection.refresh_schema, keyspace, table, usertype) + def refresh_nodes(self): + """ + Synchronously refresh the node list and token metadata + + An Exception is raised if node refresh fails for any reason. 
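+
+        Example (a hypothetical call site)::
+
+            cluster.refresh_nodes()  # blocks; raises if the refresh fails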
+ """ + if not self.control_connection.refresh_node_list_and_token_map(): + raise Exception("Node list was not refreshed. See log for details.") + + def set_meta_refresh_enabled(self, enabled): + """ + Sets a flag to enable (True) or disable (False) all metadata refresh queries. + This applies to both schema and node topology. + + Disabling this is useful to minimize refreshes during multiple changes. + + Meta refresh must be enabled for the driver to become aware of any cluster + topology changes or schema updates. + """ + self.control_connection.set_meta_refresh_enabled(bool(enabled)) + def _prepare_all_queries(self, host): if not self._prepared_statements: return @@ -1656,6 +1750,9 @@ class Session(object): def get_pool_state(self): return dict((host, pool.get_state()) for host, pool in self._pools.items()) + def get_pools(self): + return self._pools.values() + class UserTypeDoesNotExist(Exception): """ @@ -1734,16 +1831,26 @@ class ControlConnection(object): _timeout = None _protocol_version = None + _schema_event_refresh_window = None + _topology_event_refresh_window = None + + _meta_refresh_enabled = True + # for testing purposes _time = time - def __init__(self, cluster, timeout): + def __init__(self, cluster, timeout, + schema_event_refresh_window, + topology_event_refresh_window): # use a weak reference to allow the Cluster instance to be GC'ed (and # shutdown) since implementing __del__ disables the cycle detector self._cluster = weakref.proxy(cluster) self._connection = None self._timeout = timeout + self._schema_event_refresh_window = schema_event_refresh_window + self._topology_event_refresh_window = topology_event_refresh_window + self._lock = RLock() self._schema_agreement_lock = Lock() @@ -1901,6 +2008,10 @@ class ControlConnection(object): def refresh_schema(self, keyspace=None, table=None, usertype=None, schema_agreement_wait=None): + if not self._meta_refresh_enabled: + log.debug("[control connection] Skipping schema refresh because meta refresh is disabled") + return False + try: if self._connection: return self._refresh_schema(self._connection, keyspace, table, usertype, @@ -2028,14 +2139,20 @@ class ControlConnection(object): return True def refresh_node_list_and_token_map(self, force_token_rebuild=False): + if not self._meta_refresh_enabled: + log.debug("[control connection] Skipping node list refresh because meta refresh is disabled") + return False + try: if self._connection: self._refresh_node_list_and_token_map(self._connection, force_token_rebuild=force_token_rebuild) + return True except ReferenceError: pass # our weak reference to the Cluster is no good except Exception: log.debug("[control connection] Error refreshing node list and token map", exc_info=True) self._signal_error() + return False def _refresh_node_list_and_token_map(self, connection, preloaded_results=None, force_token_rebuild=False): @@ -2096,7 +2213,7 @@ class ControlConnection(object): rack = row.get("rack") if host is None: log.debug("[control connection] Found new host to connect to: %s", addr) - host = self._cluster.add_host(addr, datacenter, rack, signal=True) + host = self._cluster.add_host(addr, datacenter, rack, signal=True, refresh_nodes=False) should_rebuild_token_map = True else: should_rebuild_token_map |= self._update_location_info(host, datacenter, rack) @@ -2131,25 +2248,25 @@ class ControlConnection(object): def _handle_topology_change(self, event): change_type = event["change_type"] addr, port = event["address"] - if change_type == "NEW_NODE": - 
self._cluster.scheduler.schedule(10, self.refresh_node_list_and_token_map) + if change_type == "NEW_NODE" or change_type == "MOVED_NODE": + if self._topology_event_refresh_window >= 0: + delay = random() * self._topology_event_refresh_window + self._cluster.scheduler.schedule_unique(delay, self.refresh_node_list_and_token_map) elif change_type == "REMOVED_NODE": host = self._cluster.metadata.get_host(addr) - self._cluster.scheduler.schedule(0, self._cluster.remove_host, host) - elif change_type == "MOVED_NODE": - self._cluster.scheduler.schedule(1, self.refresh_node_list_and_token_map) + self._cluster.scheduler.schedule_unique(0, self._cluster.remove_host, host) def _handle_status_change(self, event): change_type = event["change_type"] addr, port = event["address"] host = self._cluster.metadata.get_host(addr) if change_type == "UP": + delay = 1 + random() * 0.5 # randomness to avoid thundering herd problem on events if host is None: # this is the first time we've seen the node - self._cluster.scheduler.schedule(2, self.refresh_node_list_and_token_map) + self._cluster.scheduler.schedule_unique(delay, self.refresh_node_list_and_token_map) else: - # this will be run by the scheduler - self._cluster.scheduler.schedule(2, self._cluster.on_up, host) + self._cluster.scheduler.schedule_unique(delay, self._cluster.on_up, host) elif change_type == "DOWN": # Note that there is a slight risk we can receive the event late and thus # mark the host down even though we already had reconnected successfully. @@ -2160,10 +2277,14 @@ class ControlConnection(object): self._cluster.on_down(host, is_host_addition=False) def _handle_schema_change(self, event): + if self._schema_event_refresh_window < 0: + return + keyspace = event.get('keyspace') table = event.get('table') usertype = event.get('type') - self._submit(self.refresh_schema, keyspace, table, usertype) + delay = random() * self._schema_event_refresh_window + self._cluster.scheduler.schedule_unique(delay, self.refresh_schema, keyspace, table, usertype) def wait_for_schema_agreement(self, connection=None, preloaded_results=None, wait_time=None): @@ -2271,11 +2392,6 @@ class ControlConnection(object): # manually self.reconnect() - @property - def is_open(self): - conn = self._connection - return bool(conn and conn.is_open) - def on_up(self, host): pass @@ -2289,12 +2405,24 @@ class ControlConnection(object): # this will result in a task being submitted to the executor to reconnect self.reconnect() - def on_add(self, host): - self.refresh_node_list_and_token_map(force_token_rebuild=True) + def on_add(self, host, refresh_nodes=True): + if refresh_nodes: + self.refresh_node_list_and_token_map(force_token_rebuild=True) def on_remove(self, host): self.refresh_node_list_and_token_map(force_token_rebuild=True) + def get_connections(self): + c = getattr(self, '_connection', None) + return [c] if c else [] + + def return_connection(self, connection): + if connection is self._connection and (connection.is_defunct or connection.is_closed): + self.reconnect() + + def set_meta_refresh_enabled(self, enabled): + self._meta_refresh_enabled = enabled + def _stop_scheduler(scheduler, thread): try: @@ -2308,12 +2436,14 @@ def _stop_scheduler(scheduler, thread): class _Scheduler(object): - _scheduled = None + _queue = None + _scheduled_tasks = None _executor = None is_shutdown = False def __init__(self, executor): - self._scheduled = Queue.PriorityQueue() + self._queue = Queue.PriorityQueue() + self._scheduled_tasks = set() self._executor = executor t = 
Thread(target=self.run, name="Task Scheduler") @@ -2331,14 +2461,25 @@ class _Scheduler(object): # this can happen on interpreter shutdown pass self.is_shutdown = True - self._scheduled.put_nowait((0, None)) + self._queue.put_nowait((0, None)) - def schedule(self, delay, fn, *args, **kwargs): + def schedule(self, delay, fn, *args): + self._insert_task(delay, (fn, args)) + + def schedule_unique(self, delay, fn, *args): + task = (fn, args) + if task not in self._scheduled_tasks: + self._insert_task(delay, task) + else: + log.debug("Ignoring schedule_unique for already-scheduled task: %r", task) + + def _insert_task(self, delay, task): if not self.is_shutdown: run_at = time.time() + delay - self._scheduled.put_nowait((run_at, (fn, args, kwargs))) + self._scheduled_tasks.add(task) + self._queue.put_nowait((run_at, task)) else: - log.debug("Ignoring scheduled function after shutdown: %r", fn) + log.debug("Ignoring scheduled task after shutdown: %r", task) def run(self): while True: @@ -2347,16 +2488,17 @@ class _Scheduler(object): try: while True: - run_at, task = self._scheduled.get(block=True, timeout=None) + run_at, task = self._queue.get(block=True, timeout=None) if self.is_shutdown: log.debug("Not executing scheduled task due to Scheduler shutdown") return if run_at <= time.time(): - fn, args, kwargs = task - future = self._executor.submit(fn, *args, **kwargs) + self._scheduled_tasks.remove(task) + fn, args = task + future = self._executor.submit(fn, *args) future.add_done_callback(self._log_if_failed) else: - self._scheduled.put_nowait((run_at, task)) + self._queue.put_nowait((run_at, task)) break except Queue.Empty: pass @@ -2373,9 +2515,13 @@ class _Scheduler(object): def refresh_schema_and_set_result(keyspace, table, usertype, control_conn, response_future): try: - log.debug("Refreshing schema in response to schema change. Keyspace: %s; Table: %s, Type: %s", - keyspace, table, usertype) - control_conn._refresh_schema(response_future._connection, keyspace, table, usertype) + if control_conn._meta_refresh_enabled: + log.debug("Refreshing schema in response to schema change. Keyspace: %s; Table: %s, Type: %s", + keyspace, table, usertype) + control_conn._refresh_schema(response_future._connection, keyspace, table, usertype) + else: + log.debug("Skipping schema refresh in response to schema change because meta refresh is disabled; " + "Keyspace: %s; Table: %s, Type: %s", keyspace, table, usertype) except Exception: log.exception("Exception refreshing schema in response to schema change:") response_future.session.submit( diff --git a/cassandra/connection.py b/cassandra/connection.py index 5a58793c..9a41a0d6 100644 --- a/cassandra/connection.py +++ b/cassandra/connection.py @@ -20,7 +20,7 @@ import io import logging import os import sys -from threading import Event, RLock +from threading import Thread, Event, RLock import time if 'gevent.monkey' in sys.modules: @@ -159,7 +159,7 @@ class Connection(object): in_flight = 0 # A set of available request IDs. When using the v3 protocol or higher, - # this will no initially include all request IDs in order to save memory, + # this will not initially include all request IDs in order to save memory, # but the set will grow if it is exhausted. 
 request_ids = None
@@ -172,6 +172,8 @@
     lock = None
     user_type_map = None
 
+    msg_received = False
+    is_control_connection = False
 
     _iobuf = None
@@ -401,6 +403,8 @@
         with self.lock:
             self.request_ids.append(stream_id)
 
+        self.msg_received = True
+
         body = None
         try:
             # check that the protocol version is supported
@@ -673,6 +677,13 @@
 
         self.send_msg(query, request_id, process_result)
 
+    @property
+    def is_idle(self):
+        return not self.msg_received
+
+    def reset_idle(self):
+        self.msg_received = False
+
     def __str__(self):
         status = ""
         if self.is_defunct:
@@ -732,3 +743,100 @@
             raise OperationTimedOut()
         else:
             return self.responses
+
+
+class HeartbeatFuture(object):
+    def __init__(self, connection, owner):
+        self._exception = None
+        self._event = Event()
+        self.connection = connection
+        self.owner = owner
+        log.debug("Sending options message heartbeat on idle connection (%s) %s",
+                  id(connection), connection.host)
+        with connection.lock:
+            if connection.in_flight < connection.max_request_id:
+                connection.in_flight += 1
+                connection.send_msg(OptionsMessage(), connection.get_request_id(), self._options_callback)
+            else:
+                self._exception = Exception("Failed to send heartbeat because connection 'in_flight' exceeds threshold")
+                self._event.set()
+
+    def wait(self, timeout):
+        self._event.wait(timeout)
+        if self._event.is_set():
+            if self._exception:
+                raise self._exception
+        else:
+            raise OperationTimedOut()
+
+    def _options_callback(self, response):
+        if not isinstance(response, SupportedMessage):
+            if isinstance(response, ConnectionException):
+                self._exception = response
+            else:
+                self._exception = ConnectionException("Received unexpected response to OptionsMessage: %s"
+                                                      % (response,))
+
+        log.debug("Received options response on connection (%s) from %s",
+                  id(self.connection), self.connection.host)
+        self._event.set()
+
+
+class ConnectionHeartbeat(Thread):
+
+    def __init__(self, interval_sec, get_connection_holders):
+        Thread.__init__(self, name="Connection heartbeat")
+        self._interval = interval_sec
+        self._get_connection_holders = get_connection_holders
+        self._shutdown_event = Event()
+        self.daemon = True
+        self.start()
+
+    def run(self):
+        self._shutdown_event.wait(self._interval)
+        while not self._shutdown_event.is_set():
+            start_time = time.time()
+
+            futures = []
+            failed_connections = []
+            try:
+                for connections, owner in [(o.get_connections(), o) for o in self._get_connection_holders()]:
+                    for connection in connections:
+                        if not (connection.is_defunct or connection.is_closed):
+                            if connection.is_idle:
+                                try:
+                                    futures.append(HeartbeatFuture(connection, owner))
+                                except Exception:
+                                    log.warning("Failed sending heartbeat message on connection (%s) to %s",
+                                                id(connection), connection.host, exc_info=True)
+                                    failed_connections.append((connection, owner))
+                            else:
+                                connection.reset_idle()
+                        else:
+                            # make sure the owner sees this defunct/closed connection
+                            owner.return_connection(connection)
+
+                for f in futures:
+                    connection = f.connection
+                    try:
+                        f.wait(self._interval)
+                        # TODO: move this, along with connection locks in pool, down into Connection
+                        with connection.lock:
+                            connection.in_flight -= 1
+                        connection.reset_idle()
+                    except Exception:
+                        log.warning("Heartbeat failed for connection (%s) to %s",
+                                    id(connection), connection.host, exc_info=True)
+                        failed_connections.append((f.connection, f.owner))
+
+                for connection, owner in failed_connections:
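+                    # A send failure, error response, or timeout above defuncts
+                    # the connection and returns it to its owner (pool or control
+                    # connection) so a replacement can be opened.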
connection.defunct(Exception('Connection heartbeat failure')) + owner.return_connection(connection) + except Exception: + log.error("Failed connection heartbeat", exc_info=True) + + elapsed = time.time() - start_time + self._shutdown_event.wait(max(self._interval - elapsed, 0.01)) + + def stop(self): + self._shutdown_event.set() diff --git a/cassandra/cqltypes.py b/cassandra/cqltypes.py index abcb4a10..a9ba4aac 100644 --- a/cassandra/cqltypes.py +++ b/cassandra/cqltypes.py @@ -31,14 +31,14 @@ from __future__ import absolute_import # to enable import io from stdlib from binascii import unhexlify import calendar from collections import namedtuple +import datetime from decimal import Decimal import io import re import socket import time -import datetime +import sys from uuid import UUID -import warnings import six from six.moves import range @@ -48,7 +48,7 @@ from cassandra.marshal import (int8_pack, int8_unpack, int32_pack, int32_unpack, int64_pack, int64_unpack, float_pack, float_unpack, double_pack, double_unpack, varint_pack, varint_unpack) -from cassandra.util import OrderedDict, sortedset +from cassandra.util import OrderedMap, sortedset apache_cassandra_type_prefix = 'org.apache.cassandra.db.marshal.' @@ -58,10 +58,15 @@ if six.PY3: _time_types = frozenset((int,)) _date_types = frozenset((int,)) long = int + + def _name_from_hex_string(encoded_name): + bin_str = unhexlify(encoded_name) + return bin_str.decode('ascii') else: _number_types = frozenset((int, long, float)) _time_types = frozenset((int, long)) _date_types = frozenset((int, long)) + _name_from_hex_string = unhexlify def trim_if_startswith(s, prefix): @@ -569,7 +574,8 @@ class DateType(_CassandraType): tval = time.strptime(val, tformat) except ValueError: continue - return calendar.timegm(tval) + offset + # scale seconds to millis for the raw value + return (calendar.timegm(tval) + offset) * 1e3 else: raise ValueError("can't interpret %r as a date" % (val,)) @@ -584,31 +590,16 @@ class DateType(_CassandraType): @staticmethod def serialize(v, protocol_version): try: - converted = calendar.timegm(v.utctimetuple()) - converted = converted * 1e3 + getattr(v, 'microsecond', 0) / 1e3 + # v is datetime + timestamp_seconds = calendar.timegm(v.utctimetuple()) + timestamp = timestamp_seconds * 1e3 + getattr(v, 'microsecond', 0) / 1e3 except AttributeError: # Ints and floats are valid timestamps too if type(v) not in _number_types: raise TypeError('DateType arguments must be a datetime or timestamp') + timestamp = v - global _have_warned_about_timestamps - if not _have_warned_about_timestamps: - _have_warned_about_timestamps = True - warnings.warn( - "timestamp columns in Cassandra hold a number of " - "milliseconds since the unix epoch. Currently, when executing " - "prepared statements, this driver multiplies timestamp " - "values by 1000 so that the result of time.time() " - "can be used directly. However, the driver cannot " - "match this behavior for non-prepared statements, " - "so the 2.0 version of the driver will no longer multiply " - "timestamps by 1000. 
It is suggested that you simply use " - "datetime.datetime objects for 'timestamp' values to avoid " - "any ambiguity and to guarantee a smooth upgrade of the " - "driver.") - converted = v * 1e3 - - return int64_pack(long(converted)) + return int64_pack(long(timestamp)) class TimestampType(DateType): @@ -838,7 +829,7 @@ class MapType(_ParameterizedType): length = 2 numelements = unpack(byts[:length]) p = length - themap = OrderedDict() + themap = OrderedMap() for _ in range(numelements): key_len = unpack(byts[p:p + length]) p += length @@ -850,7 +841,7 @@ class MapType(_ParameterizedType): p += val_len key = subkeytype.from_binary(keybytes, protocol_version) val = subvaltype.from_binary(valbytes, protocol_version) - themap[key] = val + themap._insert(key, val) return themap @classmethod @@ -929,37 +920,39 @@ class UserType(TupleType): typename = "'org.apache.cassandra.db.marshal.UserType'" _cache = {} + _module = sys.modules[__name__] @classmethod def make_udt_class(cls, keyspace, udt_name, names_and_types, mapped_class): if six.PY2 and isinstance(udt_name, unicode): udt_name = udt_name.encode('utf-8') - try: return cls._cache[(keyspace, udt_name)] except KeyError: - fieldnames, types = zip(*names_and_types) + field_names, types = zip(*names_and_types) instance = type(udt_name, (cls,), {'subtypes': types, 'cassname': cls.cassname, 'typename': udt_name, - 'fieldnames': fieldnames, + 'fieldnames': field_names, 'keyspace': keyspace, - 'mapped_class': mapped_class}) + 'mapped_class': mapped_class, + 'tuple_type': cls._make_registered_udt_namedtuple(keyspace, udt_name, field_names)}) cls._cache[(keyspace, udt_name)] = instance return instance @classmethod def apply_parameters(cls, subtypes, names): keyspace = subtypes[0] - udt_name = unhexlify(subtypes[1].cassname) - field_names = [unhexlify(encoded_name) for encoded_name in names[2:]] + udt_name = _name_from_hex_string(subtypes[1].cassname) + field_names = [_name_from_hex_string(encoded_name) for encoded_name in names[2:]] assert len(field_names) == len(subtypes[2:]) return type(udt_name, (cls,), {'subtypes': subtypes[2:], 'cassname': cls.cassname, 'typename': udt_name, 'fieldnames': field_names, 'keyspace': keyspace, - 'mapped_class': None}) + 'mapped_class': None, + 'tuple_type': namedtuple(udt_name, field_names)}) @classmethod def cql_parameterized_type(cls): @@ -991,8 +984,7 @@ class UserType(TupleType): if cls.mapped_class: return cls.mapped_class(**dict(zip(cls.fieldnames, values))) else: - Result = namedtuple(cls.typename, cls.fieldnames) - return Result(*values) + return cls.tuple_type(*values) @classmethod def serialize_safe(cls, val, protocol_version): @@ -1008,6 +1000,18 @@ class UserType(TupleType): buf.write(int32_pack(-1)) return buf.getvalue() + @classmethod + def _make_registered_udt_namedtuple(cls, keyspace, name, field_names): + # this is required to make the type resolvable via this module... 
+ # required when unregistered udts are pickled for use as keys in + # util.OrderedMap + qualified_name = "%s_%s" % (keyspace, name) + nt = getattr(cls._module, qualified_name, None) + if not nt: + nt = namedtuple(qualified_name, field_names) + setattr(cls._module, qualified_name, nt) + return nt + class CompositeType(_ParameterizedType): typename = "'org.apache.cassandra.db.marshal.CompositeType'" diff --git a/cassandra/encoder.py b/cassandra/encoder.py index 2f3cfa92..516945b3 100644 --- a/cassandra/encoder.py +++ b/cassandra/encoder.py @@ -28,7 +28,7 @@ import types from uuid import UUID import six -from cassandra.util import OrderedDict, sortedset +from cassandra.util import OrderedDict, OrderedMap, sortedset if six.PY3: long = int @@ -77,6 +77,7 @@ class Encoder(object): datetime.time: self.cql_encode_time, dict: self.cql_encode_map_collection, OrderedDict: self.cql_encode_map_collection, + OrderedMap: self.cql_encode_map_collection, list: self.cql_encode_list_collection, tuple: self.cql_encode_list_collection, set: self.cql_encode_set_collection, diff --git a/cassandra/io/eventletreactor.py b/cassandra/io/eventletreactor.py new file mode 100644 index 00000000..357d8634 --- /dev/null +++ b/cassandra/io/eventletreactor.py @@ -0,0 +1,193 @@ +# Copyright 2014 Symantec Corporation +# Copyright 2013-2014 DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Originally derived from MagnetoDB source: +# https://github.com/stackforge/magnetodb/blob/2015.1.0b1/magnetodb/common/cassandra/io/eventletreactor.py + +from collections import defaultdict +from errno import EALREADY, EINPROGRESS, EWOULDBLOCK, EINVAL +import eventlet +from eventlet.green import select, socket +from eventlet.queue import Queue +from functools import partial +import logging +import os +from threading import Event + +from six.moves import xrange + +from cassandra import OperationTimedOut +from cassandra.connection import Connection, ConnectionShutdown +from cassandra.protocol import RegisterMessage + + +log = logging.getLogger(__name__) + + +def is_timeout(err): + return ( + err in (EINPROGRESS, EALREADY, EWOULDBLOCK) or + (err == EINVAL and os.name in ('nt', 'ce')) + ) + + +class EventletConnection(Connection): + """ + An implementation of :class:`.Connection` that utilizes ``eventlet``. 
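+
+    Example usage (a sketch; ``connection_class`` is the existing Cluster
+    attribute for selecting a reactor)::
+
+        from cassandra.cluster import Cluster
+        from cassandra.io.eventletreactor import EventletConnection
+
+        cluster = Cluster()
+        cluster.connection_class = EventletConnection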
+ """ + + _total_reqd_bytes = 0 + _read_watcher = None + _write_watcher = None + _socket = None + + @classmethod + def initialize_reactor(cls): + eventlet.monkey_patch() + + @classmethod + def factory(cls, *args, **kwargs): + timeout = kwargs.pop('timeout', 5.0) + conn = cls(*args, **kwargs) + conn.connected_event.wait(timeout) + if conn.last_error: + raise conn.last_error + elif not conn.connected_event.is_set(): + conn.close() + raise OperationTimedOut("Timed out creating connection") + else: + return conn + + def __init__(self, *args, **kwargs): + Connection.__init__(self, *args, **kwargs) + + self.connected_event = Event() + self._write_queue = Queue() + + self._callbacks = {} + self._push_watchers = defaultdict(set) + + sockerr = None + addresses = socket.getaddrinfo( + self.host, self.port, socket.AF_UNSPEC, socket.SOCK_STREAM + ) + for (af, socktype, proto, canonname, sockaddr) in addresses: + try: + self._socket = socket.socket(af, socktype, proto) + self._socket.settimeout(1.0) + self._socket.connect(sockaddr) + sockerr = None + break + except socket.error as err: + sockerr = err + if sockerr: + raise socket.error( + sockerr.errno, + "Tried connecting to %s. Last error: %s" % ( + [a[4] for a in addresses], sockerr.strerror) + ) + + if self.sockopts: + for args in self.sockopts: + self._socket.setsockopt(*args) + + self._read_watcher = eventlet.spawn(lambda: self.handle_read()) + self._write_watcher = eventlet.spawn(lambda: self.handle_write()) + self._send_options_message() + + def close(self): + with self.lock: + if self.is_closed: + return + self.is_closed = True + + log.debug("Closing connection (%s) to %s" % (id(self), self.host)) + + cur_gthread = eventlet.getcurrent() + + if self._read_watcher and self._read_watcher != cur_gthread: + self._read_watcher.kill() + if self._write_watcher and self._write_watcher != cur_gthread: + self._write_watcher.kill() + if self._socket: + self._socket.close() + log.debug("Closed socket to %s" % (self.host,)) + + if not self.is_defunct: + self.error_all_callbacks( + ConnectionShutdown("Connection to %s was closed" % self.host)) + # don't leave in-progress operations hanging + self.connected_event.set() + + def handle_close(self): + log.debug("connection closed by server") + self.close() + + def handle_write(self): + while True: + try: + next_msg = self._write_queue.get() + self._socket.sendall(next_msg) + except socket.error as err: + log.debug("Exception during socket send for %s: %s", self, err) + self.defunct(err) + return # Leave the write loop + + def handle_read(self): + run_select = partial(select.select, (self._socket,), (), ()) + while True: + try: + run_select() + except Exception as exc: + if not self.is_closed: + log.debug("Exception during read select() for %s: %s", + self, exc) + self.defunct(exc) + return + + try: + buf = self._socket.recv(self.in_buffer_size) + self._iobuf.write(buf) + except socket.error as err: + if not is_timeout(err): + log.debug("Exception during socket recv for %s: %s", + self, err) + self.defunct(err) + return # leave the read loop + + if self._iobuf.tell(): + self.process_io_buffer() + else: + log.debug("Connection %s closed by server", self) + self.close() + return + + def push(self, data): + chunk_size = self.out_buffer_size + for i in xrange(0, len(data), chunk_size): + self._write_queue.put(data[i:i + chunk_size]) + + def register_watcher(self, event_type, callback, register_timeout=None): + self._push_watchers[event_type].add(callback) + self.wait_for_response( + 
RegisterMessage(event_list=[event_type]), + timeout=register_timeout) + + def register_watchers(self, type_callback_dict, register_timeout=None): + for event_type, callback in type_callback_dict.items(): + self._push_watchers[event_type].add(callback) + self.wait_for_response( + RegisterMessage(event_list=type_callback_dict.keys()), + timeout=register_timeout) diff --git a/cassandra/metadata.py b/cassandra/metadata.py index c80e1c33..f51b9c93 100644 --- a/cassandra/metadata.py +++ b/cassandra/metadata.py @@ -211,7 +211,7 @@ class Metadata(object): return KeyspaceMetadata(name, durable_writes, strategy_class, strategy_options) def _build_usertype(self, keyspace, usertype_row): - type_classes = map(types.lookup_casstype, usertype_row['field_types']) + type_classes = list(map(types.lookup_casstype, usertype_row['field_types'])) return UserType(usertype_row['keyspace_name'], usertype_row['type_name'], usertype_row['field_names'], type_classes) diff --git a/cassandra/pool.py b/cassandra/pool.py index 587fa277..a99724b0 100644 --- a/cassandra/pool.py +++ b/cassandra/pool.py @@ -355,15 +355,21 @@ class HostConnection(object): return def connection_finished_setting_keyspace(conn, error): + self.return_connection(conn) errors = [] if not error else [error] callback(self, errors) self._connection.set_keyspace_async(keyspace, connection_finished_setting_keyspace) + def get_connections(self): + c = self._connection + return [c] if c else [] + def get_state(self): - have_conn = self._connection is not None - in_flight = self._connection.in_flight if have_conn else 0 - return "shutdown: %s, open: %s, in_flights: %s" % (self.is_shutdown, have_conn, in_flight) + connection = self._connection + open_count = 1 if connection and not (connection.is_closed or connection.is_defunct) else 0 + in_flights = [connection.in_flight] if connection else [] + return {'shutdown': self.is_shutdown, 'open_count': open_count, 'in_flights': in_flights} _MAX_SIMULTANEOUS_CREATION = 1 @@ -683,6 +689,7 @@ class HostConnectionPool(object): return def connection_finished_setting_keyspace(conn, error): + self.return_connection(conn) remaining_callbacks.remove(conn) if error: errors.append(error) @@ -693,6 +700,9 @@ class HostConnectionPool(object): for conn in self._connections: conn.set_keyspace_async(keyspace, connection_finished_setting_keyspace) + def get_connections(self): + return self._connections + def get_state(self): - in_flights = ", ".join([str(c.in_flight) for c in self._connections]) - return "shutdown: %s, open_count: %d, in_flights: %s" % (self.is_shutdown, self.open_count, in_flights) + in_flights = [c.in_flight for c in self._connections] + return {'shutdown': self.is_shutdown, 'open_count': self.open_count, 'in_flights': in_flights} diff --git a/cassandra/protocol.py b/cassandra/protocol.py index 00239f1c..ca628a43 100644 --- a/cassandra/protocol.py +++ b/cassandra/protocol.py @@ -632,14 +632,14 @@ class ResultMessage(_MessageType): typeclass = typeclass.apply_parameters((keysubtype, valsubtype)) elif typeclass == TupleType: num_items = read_short(f) - types = tuple(cls.read_type(f, user_type_map) for _ in xrange(num_items)) + types = tuple(cls.read_type(f, user_type_map) for _ in range(num_items)) typeclass = typeclass.apply_parameters(types) elif typeclass == UserType: ks = read_string(f) udt_name = read_string(f) num_fields = read_short(f) names_and_types = tuple((read_string(f), cls.read_type(f, user_type_map)) - for _ in xrange(num_fields)) + for _ in range(num_fields)) mapped_class = user_type_map.get(ks, 
{}).get(udt_name)
                 typeclass = typeclass.make_udt_class(
                     ks, udt_name, names_and_types, mapped_class)
diff --git a/cassandra/query.py b/cassandra/query.py
index 1b0a56d4..c159e787 100644
--- a/cassandra/query.py
+++ b/cassandra/query.py
@@ -500,13 +500,13 @@
 
             try:
                 self.values.append(col_type.serialize(value, proto_version))
-            except (TypeError, struct.error):
+            except (TypeError, struct.error) as exc:
                 col_name = col_spec[2]
                 expected_type = col_type
                 actual_type = type(value)
                 message = ('Received an argument of invalid type for column "%s". '
-                           'Expected: %s, Got: %s' % (col_name, expected_type, actual_type))
+                           'Expected: %s, Got: %s; (%s)' % (col_name, expected_type, actual_type, exc))
                 raise TypeError(message)
 
         return self
diff --git a/cassandra/util.py b/cassandra/util.py
index d02219ff..cf5e0f4a 100644
--- a/cassandra/util.py
+++ b/cassandra/util.py
@@ -555,3 +555,101 @@
                 if item in other:
                     isect.add(item)
             return isect
+
+from collections import Mapping
+import six
+from six.moves import cPickle
+
+
+class OrderedMap(Mapping):
+    '''
+    An ordered map that accepts non-hashable types for keys. It also maintains the
+    insertion order of items, behaving as OrderedDict in that regard. These maps
+    are constructed and read just as normal mapping types, except that they may
+    contain arbitrary collections and other non-hashable items as keys::
+
+        >>> od = OrderedMap([({'one': 1, 'two': 2}, 'value'),
+        ...                  ({'three': 3, 'four': 4}, 'value2')])
+        >>> list(od.keys())
+        [{'two': 2, 'one': 1}, {'three': 3, 'four': 4}]
+        >>> list(od.values())
+        ['value', 'value2']
+
+    These constructs are needed to support nested collections in Cassandra 2.1.3+,
+    where frozen collections can be specified as parameters to others\*::
+
+        CREATE TABLE example (
+            ...
+            value map<frozen<tuple<int, int>>, double>
+            ...
+        )
+
+    This class derives from the (immutable) Mapping API. Objects in these maps
+    are not intended to be modified.
+
+    \* Note: Because of the way Cassandra encodes nested types, when using the
+    driver with nested collections, :attr:`~.Cluster.protocol_version` must be 3
+    or higher.
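+
+    A small sketch of access with a non-hashable key (values here are only
+    illustrative)::
+
+        >>> om = OrderedMap([([1, 2], 'a'), ([3, 4], 'b')])
+        >>> om[[1, 2]]
+        'a'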
+ + ''' + def __init__(self, *args, **kwargs): + if len(args) > 1: + raise TypeError('expected at most 1 arguments, got %d' % len(args)) + + self._items = [] + self._index = {} + if args: + e = args[0] + if callable(getattr(e, 'keys', None)): + for k in e.keys(): + self._items.append((k, e[k])) + else: + for k, v in e: + self._insert(k, v) + + for k, v in six.iteritems(kwargs): + self._insert(k, v) + + def _insert(self, key, value): + flat_key = self._serialize_key(key) + i = self._index.get(flat_key, -1) + if i >= 0: + self._items[i] = (key, value) + else: + self._items.append((key, value)) + self._index[flat_key] = len(self._items) - 1 + + def __getitem__(self, key): + index = self._index[self._serialize_key(key)] + return self._items[index][1] + + def __iter__(self): + for i in self._items: + yield i[0] + + def __len__(self): + return len(self._items) + + def __eq__(self, other): + if isinstance(other, OrderedMap): + return self._items == other._items + try: + d = dict(other) + return len(d) == len(self._items) and all(i[1] == d[i[0]] for i in self._items) + except KeyError: + return False + except TypeError: + pass + return NotImplemented + + def __repr__(self): + return '%s([%s])' % ( + self.__class__.__name__, + ', '.join("(%r, %r)" % (k, v) for k, v in self._items)) + + def __str__(self): + return '{%s}' % ', '.join("%s: %s" % (k, v) for k, v in self._items) + + @staticmethod + def _serialize_key(key): + return cPickle.dumps(key) diff --git a/docs/api/cassandra/auth.rst b/docs/api/cassandra/auth.rst index 0ee6e539..58c964cf 100644 --- a/docs/api/cassandra/auth.rst +++ b/docs/api/cassandra/auth.rst @@ -14,3 +14,9 @@ .. autoclass:: PlainTextAuthenticator :members: + +.. autoclass:: SaslAuthProvider + :members: + +.. autoclass:: SaslAuthenticator + :members: diff --git a/docs/api/cassandra/cluster.rst b/docs/api/cassandra/cluster.rst index 2d6a5342..b1b1aebb 100644 --- a/docs/api/cassandra/cluster.rst +++ b/docs/api/cassandra/cluster.rst @@ -39,6 +39,12 @@ .. autoattribute:: control_connection_timeout + .. autoattribute:: idle_heartbeat_interval + + .. autoattribute:: schema_event_refresh_window + + .. autoattribute:: topology_event_refresh_window + .. automethod:: connect .. automethod:: shutdown @@ -57,6 +63,13 @@ .. automethod:: set_max_connections_per_host + .. automethod:: refresh_schema + + .. automethod:: refresh_nodes + + .. automethod:: set_meta_refresh_enabled + + .. autoclass:: Session () .. autoattribute:: default_timeout diff --git a/docs/api/cassandra/io/eventletreactor.rst b/docs/api/cassandra/io/eventletreactor.rst new file mode 100644 index 00000000..1ba742c7 --- /dev/null +++ b/docs/api/cassandra/io/eventletreactor.rst @@ -0,0 +1,7 @@ +``cassandra.io.eventletreactor`` - ``eventlet``-compatible Connection +===================================================================== + +.. module:: cassandra.io.eventletreactor + +.. autoclass:: EventletConnection + :members: diff --git a/docs/api/cassandra/util.rst b/docs/api/cassandra/util.rst new file mode 100644 index 00000000..2e79758d --- /dev/null +++ b/docs/api/cassandra/util.rst @@ -0,0 +1,7 @@ +``cassandra.util`` - Utilities +=================================== + +.. module:: cassandra.util + +.. 
autoclass:: OrderedMap + :members: diff --git a/docs/api/index.rst b/docs/api/index.rst index 7db7c7ee..27aebf0f 100644 --- a/docs/api/index.rst +++ b/docs/api/index.rst @@ -16,7 +16,9 @@ API Documentation cassandra/decoder cassandra/concurrent cassandra/connection + cassandra/util cassandra/io/asyncorereactor + cassandra/io/eventletreactor cassandra/io/libevreactor cassandra/io/geventreactor cassandra/io/twistedreactor diff --git a/docs/security.rst b/docs/security.rst index c87c5de8..9f7af68b 100644 --- a/docs/security.rst +++ b/docs/security.rst @@ -34,9 +34,11 @@ to be explicit. Custom Authenticators ^^^^^^^^^^^^^^^^^^^^^ If you're using something other than Cassandra's ``PasswordAuthenticator``, -you may need to create your own subclasses of :class:`~.AuthProvider` and -:class:`~.Authenticator`. You can use :class:`~.PlainTextAuthProvider` -and :class:`~.PlainTextAuthenticator` as example implementations. +:class:`~.SaslAuthProvider` is provided for generic SASL authentication mechanisms, +utilizing the ``pure-sasl`` package. +If these do not suit your needs, you may need to create your own subclasses of +:class:`~.AuthProvider` and :class:`~.Authenticator`. You can use the Sasl classes +as example implementations. Protocol v1 Authentication ^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/setup.py b/setup.py index 11cfd4c5..8536755e 100644 --- a/setup.py +++ b/setup.py @@ -20,6 +20,11 @@ if __name__ == '__main__' and sys.argv[1] == "gevent_nosetests": from gevent.monkey import patch_all patch_all() +if __name__ == '__main__' and sys.argv[1] == "eventlet_nosetests": + print("Running eventlet tests") + from eventlet import monkey_patch + monkey_patch() + import ez_setup ez_setup.use_setuptools() @@ -51,10 +56,14 @@ try: from nose.commands import nosetests except ImportError: gevent_nosetests = None + eventlet_nosetests = None else: class gevent_nosetests(nosetests): description = "run nosetests with gevent monkey patching" + class eventlet_nosetests(nosetests): + description = "run nosetests with eventlet monkey patching" + class DocCommand(Command): @@ -174,10 +183,14 @@ On OSX, via homebrew: def run_setup(extensions): + kw = {'cmdclass': {'doc': DocCommand}} if gevent_nosetests is not None: kw['cmdclass']['gevent_nosetests'] = gevent_nosetests + if eventlet_nosetests is not None: + kw['cmdclass']['eventlet_nosetests'] = eventlet_nosetests + if extensions: kw['cmdclass']['build_ext'] = build_extensions kw['ext_modules'] = extensions diff --git a/tests/__init__.py b/tests/__init__.py index 64d56ea8..5de2c765 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -13,6 +13,7 @@ # limitations under the License. 
import logging +import sys log = logging.getLogger() log.setLevel('DEBUG') @@ -21,3 +22,18 @@ if not log.handlers: handler = logging.StreamHandler() handler.setFormatter(logging.Formatter('%(asctime)s %(levelname)s [%(module)s:%(lineno)s]: %(message)s')) log.addHandler(handler) + + +def is_gevent_monkey_patched(): + return 'gevent.monkey' in sys.modules + + +def is_eventlet_monkey_patched(): + if 'eventlet.patcher' in sys.modules: + import eventlet + return eventlet.patcher.is_monkey_patched('socket') + return False + + +def is_monkey_patched(): + return is_gevent_monkey_patched() or is_eventlet_monkey_patched() diff --git a/tests/integration/standard/test_authentication.py b/tests/integration/standard/test_authentication.py index 08832ecd..1534947a 100644 --- a/tests/integration/standard/test_authentication.py +++ b/tests/integration/standard/test_authentication.py @@ -13,11 +13,13 @@ # limitations under the License. import logging +import time -from tests.integration import use_singledc, get_cluster, remove_cluster, PROTOCOL_VERSION from cassandra.cluster import Cluster, NoHostAvailable from cassandra.auth import PlainTextAuthProvider, SASLClient, SaslAuthProvider +from tests.integration import use_singledc, get_cluster, remove_cluster, PROTOCOL_VERSION +from tests.integration.util import assert_quiescent_pool_state try: import unittest2 as unittest @@ -35,7 +37,10 @@ def setup_module(): 'authorizer': 'CassandraAuthorizer'} ccm_cluster.set_configuration_options(config_options) log.debug("Starting ccm test cluster with %s", config_options) - ccm_cluster.start(wait_for_binary_proto=True) + ccm_cluster.start(wait_for_binary_proto=True, wait_other_notice=True) + # there seems to be some race, with some versions of C* taking longer to + # get the auth (and default user) setup. 
Sleep here to give it a chance + time.sleep(2) def teardown_module(): @@ -59,12 +64,13 @@ class AuthenticationTests(unittest.TestCase): :return: authentication object suitable for Cluster.connect() """ if PROTOCOL_VERSION < 2: - return lambda(hostname): dict(username=username, password=password) + return lambda hostname: dict(username=username, password=password) else: return PlainTextAuthProvider(username=username, password=password) def cluster_as(self, usr, pwd): return Cluster(protocol_version=PROTOCOL_VERSION, + idle_heartbeat_interval=0, auth_provider=self.get_authentication_provider(username=usr, password=pwd)) def test_auth_connect(self): @@ -77,9 +83,11 @@ class AuthenticationTests(unittest.TestCase): cluster = self.cluster_as(user, passwd) session = cluster.connect() self.assertTrue(session.execute('SELECT release_version FROM system.local')) + assert_quiescent_pool_state(self, cluster) cluster.shutdown() root_session.execute('DROP USER %s', user) + assert_quiescent_pool_state(self, root_session.cluster) root_session.cluster.shutdown() def test_connect_wrong_pwd(self): @@ -88,6 +96,8 @@ class AuthenticationTests(unittest.TestCase): '.*AuthenticationFailed.*Bad credentials.*Username and/or ' 'password are incorrect.*', cluster.connect) + assert_quiescent_pool_state(self, cluster) + cluster.shutdown() def test_connect_wrong_username(self): cluster = self.cluster_as('wrong_user', 'cassandra') @@ -95,6 +105,8 @@ class AuthenticationTests(unittest.TestCase): '.*AuthenticationFailed.*Bad credentials.*Username and/or ' 'password are incorrect.*', cluster.connect) + assert_quiescent_pool_state(self, cluster) + cluster.shutdown() def test_connect_empty_pwd(self): cluster = self.cluster_as('Cassandra', '') @@ -102,12 +114,16 @@ class AuthenticationTests(unittest.TestCase): '.*AuthenticationFailed.*Bad credentials.*Username and/or ' 'password are incorrect.*', cluster.connect) + assert_quiescent_pool_state(self, cluster) + cluster.shutdown() def test_connect_no_auth_provider(self): cluster = Cluster(protocol_version=PROTOCOL_VERSION) self.assertRaisesRegexp(NoHostAvailable, '.*AuthenticationFailed.*Remote end requires authentication.*', cluster.connect) + assert_quiescent_pool_state(self, cluster) + cluster.shutdown() class SaslAuthenticatorTests(AuthenticationTests): diff --git a/tests/integration/standard/test_cluster.py b/tests/integration/standard/test_cluster.py index 957a4214..95e6082b 100644 --- a/tests/integration/standard/test_cluster.py +++ b/tests/integration/standard/test_cluster.py @@ -12,18 +12,26 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from tests.integration import use_singledc, PROTOCOL_VERSION - try: import unittest2 as unittest except ImportError: import unittest # noqa -import cassandra -from cassandra.query import SimpleStatement, TraceUnavailable -from cassandra.policies import RoundRobinPolicy, ExponentialReconnectionPolicy, RetryPolicy, SimpleConvictionPolicy, HostDistance +from collections import deque +from mock import patch +import time +from uuid import uuid4 +import cassandra from cassandra.cluster import Cluster, NoHostAvailable +from cassandra.concurrent import execute_concurrent +from cassandra.policies import (RoundRobinPolicy, ExponentialReconnectionPolicy, + RetryPolicy, SimpleConvictionPolicy, HostDistance, + WhiteListRoundRobinPolicy) +from cassandra.query import SimpleStatement, TraceUnavailable + +from tests.integration import use_singledc, PROTOCOL_VERSION, get_server_versions +from tests.integration.util import assert_quiescent_pool_state def setup_module(): @@ -66,6 +74,8 @@ class ClusterTests(unittest.TestCase): result = session.execute("SELECT * FROM clustertests.cf0") self.assertEqual([('a', 'b', 'c')], result) + session.execute("DROP KEYSPACE clustertests") + cluster.shutdown() def test_connect_on_keyspace(self): @@ -202,6 +212,159 @@ class ClusterTests(unittest.TestCase): self.assertIn("newkeyspace", cluster.metadata.keyspaces) + session.execute("DROP KEYSPACE newkeyspace") + + def test_refresh_schema(self): + cluster = Cluster(protocol_version=PROTOCOL_VERSION) + session = cluster.connect() + + original_meta = cluster.metadata.keyspaces + # full schema refresh, with wait + cluster.refresh_schema() + self.assertIsNot(original_meta, cluster.metadata.keyspaces) + self.assertEqual(original_meta, cluster.metadata.keyspaces) + + session.shutdown() + + def test_refresh_schema_keyspace(self): + cluster = Cluster(protocol_version=PROTOCOL_VERSION) + session = cluster.connect() + + original_meta = cluster.metadata.keyspaces + original_system_meta = original_meta['system'] + + # only refresh one keyspace + cluster.refresh_schema(keyspace='system') + current_meta = cluster.metadata.keyspaces + self.assertIs(original_meta, current_meta) + current_system_meta = current_meta['system'] + self.assertIsNot(original_system_meta, current_system_meta) + self.assertEqual(original_system_meta.as_cql_query(), current_system_meta.as_cql_query()) + session.shutdown() + + def test_refresh_schema_table(self): + cluster = Cluster(protocol_version=PROTOCOL_VERSION) + session = cluster.connect() + + original_meta = cluster.metadata.keyspaces + original_system_meta = original_meta['system'] + original_system_schema_meta = original_system_meta.tables['schema_columnfamilies'] + + # only refresh one table + cluster.refresh_schema(keyspace='system', table='schema_columnfamilies') + current_meta = cluster.metadata.keyspaces + current_system_meta = current_meta['system'] + current_system_schema_meta = current_system_meta.tables['schema_columnfamilies'] + self.assertIs(original_meta, current_meta) + self.assertIs(original_system_meta, current_system_meta) + self.assertIsNot(original_system_schema_meta, current_system_schema_meta) + self.assertEqual(original_system_schema_meta.as_cql_query(), current_system_schema_meta.as_cql_query()) + session.shutdown() + + def test_refresh_schema_type(self): + if get_server_versions()[0] < (2, 1, 0): + raise unittest.SkipTest('UDTs were introduced in Cassandra 2.1') + + if PROTOCOL_VERSION < 3: + raise unittest.SkipTest('UDTs are not specified in change events for protocol v2') + # We may 
want to refresh types on keyspace change events in that case(?) + + cluster = Cluster(protocol_version=PROTOCOL_VERSION) + session = cluster.connect() + + keyspace_name = 'test1rf' + type_name = self._testMethodName + + session.execute('CREATE TYPE IF NOT EXISTS %s.%s (one int, two text)' % (keyspace_name, type_name)) + original_meta = cluster.metadata.keyspaces + original_test1rf_meta = original_meta[keyspace_name] + original_type_meta = original_test1rf_meta.user_types[type_name] + + # only refresh one type + cluster.refresh_schema(keyspace='test1rf', usertype=type_name) + current_meta = cluster.metadata.keyspaces + current_test1rf_meta = current_meta[keyspace_name] + current_type_meta = current_test1rf_meta.user_types[type_name] + self.assertIs(original_meta, current_meta) + self.assertIs(original_test1rf_meta, current_test1rf_meta) + self.assertIsNot(original_type_meta, current_type_meta) + self.assertEqual(original_type_meta.as_cql_query(), current_type_meta.as_cql_query()) + session.shutdown() + + def test_refresh_schema_no_wait(self): + + contact_points = ['127.0.0.1'] + cluster = Cluster(protocol_version=PROTOCOL_VERSION, max_schema_agreement_wait=10, + contact_points=contact_points, load_balancing_policy=WhiteListRoundRobinPolicy(contact_points)) + session = cluster.connect() + + schema_ver = session.execute("SELECT schema_version FROM system.local WHERE key='local'")[0][0] + + # create a schema disagreement + session.execute("UPDATE system.local SET schema_version=%s WHERE key='local'", (uuid4(),)) + + try: + agreement_timeout = 1 + + # cluster agreement wait exceeded + c = Cluster(protocol_version=PROTOCOL_VERSION, max_schema_agreement_wait=agreement_timeout) + start_time = time.time() + s = c.connect() + end_time = time.time() + self.assertGreaterEqual(end_time - start_time, agreement_timeout) + self.assertTrue(c.metadata.keyspaces) + + # cluster agreement wait used for refresh + original_meta = c.metadata.keyspaces + start_time = time.time() + self.assertRaisesRegexp(Exception, r"Schema was not refreshed.*", c.refresh_schema) + end_time = time.time() + self.assertGreaterEqual(end_time - start_time, agreement_timeout) + self.assertIs(original_meta, c.metadata.keyspaces) + + # refresh wait overrides cluster value + original_meta = c.metadata.keyspaces + start_time = time.time() + c.refresh_schema(max_schema_agreement_wait=0) + end_time = time.time() + self.assertLess(end_time - start_time, agreement_timeout) + self.assertIsNot(original_meta, c.metadata.keyspaces) + self.assertEqual(original_meta, c.metadata.keyspaces) + + s.shutdown() + + refresh_threshold = 0.5 + # cluster agreement bypass + c = Cluster(protocol_version=PROTOCOL_VERSION, max_schema_agreement_wait=0) + start_time = time.time() + s = c.connect() + end_time = time.time() + self.assertLess(end_time - start_time, refresh_threshold) + self.assertTrue(c.metadata.keyspaces) + + # cluster agreement wait used for refresh + original_meta = c.metadata.keyspaces + start_time = time.time() + c.refresh_schema() + end_time = time.time() + self.assertLess(end_time - start_time, refresh_threshold) + self.assertIsNot(original_meta, c.metadata.keyspaces) + self.assertEqual(original_meta, c.metadata.keyspaces) + + # refresh wait overrides cluster value + original_meta = c.metadata.keyspaces + start_time = time.time() + self.assertRaisesRegexp(Exception, r"Schema was not refreshed.*", c.refresh_schema, max_schema_agreement_wait=agreement_timeout) + end_time = time.time() + self.assertGreaterEqual(end_time - start_time, 
agreement_timeout) + self.assertIs(original_meta, c.metadata.keyspaces) + + s.shutdown() + finally: + session.execute("UPDATE system.local SET schema_version=%s WHERE key='local'", (schema_ver,)) + + session.shutdown() + def test_trace(self): """ Ensure trace can be requested for async and non-async queries @@ -271,3 +434,115 @@ class ClusterTests(unittest.TestCase): self.assertIn(query, str(future)) self.assertIn('result', str(future)) + + def test_idle_heartbeat(self): + interval = 1 + cluster = Cluster(protocol_version=PROTOCOL_VERSION, idle_heartbeat_interval=interval) + if PROTOCOL_VERSION < 3: + cluster.set_core_connections_per_host(HostDistance.LOCAL, 1) + session = cluster.connect() + + # This test relies on impl details of connection req id management to see if heartbeats + # are being sent. May need update if impl is changed + connection_request_ids = {} + for h in cluster.get_connection_holders(): + for c in h.get_connections(): + # make sure none are idle (should have startup messages) + self.assertFalse(c.is_idle) + with c.lock: + connection_request_ids[id(c)] = deque(c.request_ids) # copy of request ids + + # let two heartbeat intervals pass (first one had startup messages in it) + time.sleep(2 * interval + interval/10.) + + connections = [c for holders in cluster.get_connection_holders() for c in holders.get_connections()] + + # make sure requests were sent on all connections + for c in connections: + expected_ids = connection_request_ids[id(c)] + expected_ids.rotate(-1) + with c.lock: + self.assertListEqual(list(c.request_ids), list(expected_ids)) + + # assert idle status + self.assertTrue(all(c.is_idle for c in connections)) + + # send messages on all connections + statements_and_params = [("SELECT release_version FROM system.local", ())] * len(cluster.metadata.all_hosts()) + results = execute_concurrent(session, statements_and_params) + for success, result in results: + self.assertTrue(success) + + # assert not idle status + self.assertFalse(any(c.is_idle if not c.is_control_connection else False for c in connections)) + + # holders include session pools and cc + holders = cluster.get_connection_holders() + self.assertIn(cluster.control_connection, holders) + self.assertEqual(len(holders), len(cluster.metadata.all_hosts()) + 1) # hosts pools, 1 for cc + + # include additional sessions + session2 = cluster.connect() + + holders = cluster.get_connection_holders() + self.assertIn(cluster.control_connection, holders) + self.assertEqual(len(holders), 2 * len(cluster.metadata.all_hosts()) + 1) # 2 sessions' hosts pools, 1 for cc + + cluster._idle_heartbeat.stop() + cluster._idle_heartbeat.join() + assert_quiescent_pool_state(self, cluster) + + session.shutdown() + + @patch('cassandra.cluster.Cluster.idle_heartbeat_interval', new=0.1) + def test_idle_heartbeat_disabled(self): + self.assertTrue(Cluster.idle_heartbeat_interval) + + # heartbeat disabled with '0' + cluster = Cluster(protocol_version=PROTOCOL_VERSION, idle_heartbeat_interval=0) + self.assertEqual(cluster.idle_heartbeat_interval, 0) + session = cluster.connect() + + # let two heartbeat intervals pass (first one had startup messages in it) + time.sleep(2 * Cluster.idle_heartbeat_interval) + + connections = [c for holders in cluster.get_connection_holders() for c in holders.get_connections()] + + # assert not idle status (should never get reset because there is no heartbeat) + self.assertFalse(any(c.is_idle for c in connections)) + + session.shutdown() + + def test_pool_management(self): + # Ensure that in_flight and
request_ids quiesce after cluster operations + cluster = Cluster(protocol_version=PROTOCOL_VERSION, idle_heartbeat_interval=0) # no idle heartbeat here, pool management is tested in test_idle_heartbeat + session = cluster.connect() + session2 = cluster.connect() + + # prepare + p = session.prepare("SELECT * FROM system.local WHERE key=?") + self.assertTrue(session.execute(p, ('local',))) + + # simple + self.assertTrue(session.execute("SELECT * FROM system.local WHERE key='local'")) + + # set keyspace + session.set_keyspace('system') + session.set_keyspace('system_traces') + + # use keyspace + session.execute('USE system') + session.execute('USE system_traces') + + # refresh schema + cluster.refresh_schema() + cluster.refresh_schema(max_schema_agreement_wait=0) + + # submit schema refresh + future = cluster.submit_schema_refresh() + future.result() + + assert_quiescent_pool_state(self, cluster) + + session2.shutdown() + session.shutdown() diff --git a/tests/integration/standard/test_connection.py b/tests/integration/standard/test_connection.py index d17f692a..15e2b198 100644 --- a/tests/integration/standard/test_connection.py +++ b/tests/integration/standard/test_connection.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from tests.integration import use_singledc, PROTOCOL_VERSION - try: import unittest2 as unittest except ImportError: @@ -21,13 +19,15 @@ except ImportError: from functools import partial from six.moves import range -import sys from threading import Thread, Event from cassandra import ConsistencyLevel, OperationTimedOut from cassandra.cluster import NoHostAvailable -from cassandra.protocol import QueryMessage from cassandra.io.asyncorereactor import AsyncoreConnection +from cassandra.protocol import QueryMessage + +from tests import is_monkey_patched +from tests.integration import use_singledc, PROTOCOL_VERSION try: from cassandra.io.libevreactor import LibevConnection @@ -230,8 +230,8 @@ class AsyncoreConnectionTests(ConnectionTests, unittest.TestCase): klass = AsyncoreConnection def setUp(self): - if 'gevent.monkey' in sys.modules: - raise unittest.SkipTest("Can't test asyncore with gevent monkey patching") + if is_monkey_patched(): + raise unittest.SkipTest("Can't test asyncore with monkey patching") ConnectionTests.setUp(self) @@ -240,8 +240,8 @@ class LibevConnectionTests(ConnectionTests, unittest.TestCase): klass = LibevConnection def setUp(self): - if 'gevent.monkey' in sys.modules: - raise unittest.SkipTest("Can't test libev with gevent monkey patching") + if is_monkey_patched(): + raise unittest.SkipTest("Can't test libev with monkey patching") if LibevConnection is None: raise unittest.SkipTest( 'libev does not appear to be installed properly') diff --git a/tests/integration/standard/test_metadata.py b/tests/integration/standard/test_metadata.py index 344e5491..1e46d3c5 100644 --- a/tests/integration/standard/test_metadata.py +++ b/tests/integration/standard/test_metadata.py @@ -11,15 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import six -import difflib try: import unittest2 as unittest except ImportError: import unittest # noqa +import difflib from mock import Mock +import six +import sys from cassandra import AlreadyExists @@ -361,6 +362,9 @@ class TestCodeCoverage(unittest.TestCase): "Protocol 3.0+ is required for UDT change events, currently testing against %r" % (PROTOCOL_VERSION,)) + if sys.version_info[:2] != (2, 7): + raise unittest.SkipTest('This test compares static strings generated from dict items, which may change orders. Test with 2.7.') + cluster = Cluster(protocol_version=PROTOCOL_VERSION) session = cluster.connect() @@ -549,6 +553,9 @@ CREATE TABLE export_udts.users ( if get_server_versions()[0] < (2, 1, 0): raise unittest.SkipTest('Test schema output assumes 2.1.0+ options') + if sys.version_info[:2] != (2, 7): + raise unittest.SkipTest('This test compares static strings generated from dict items, which may change orders. Test with 2.7.') + cli_script = """CREATE KEYSPACE legacy WITH placement_strategy = 'SimpleStrategy' AND strategy_options = {replication_factor:1}; @@ -633,7 +640,7 @@ create column family composite_comp_with_col index_type : 0}] and compression_options = {'sstable_compression' : 'org.apache.cassandra.io.compress.LZ4Compressor'};""" - # note: the innerlkey type for legacy.nested_composite_key + # note: the inner key type for legacy.nested_composite_key # (org.apache.cassandra.db.marshal.CompositeType(org.apache.cassandra.db.marshal.UUIDType, org.apache.cassandra.db.marshal.UTF8Type)) # is a bit strange, but it replays in CQL with desired results expected_string = """CREATE KEYSPACE legacy WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'} AND durable_writes = true; @@ -807,6 +814,7 @@ CREATE TABLE legacy.composite_comp_no_col ( cluster.shutdown() + + class TokenMetadataTest(unittest.TestCase): """ Test of TokenMap creation and other behavior.
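A note on the interpreter gates added above: the slice matters. ``sys.version_info[:2]`` yields ``(major, minor)``, whereas ``version_info[2:]`` would yield ``(micro, releaselevel, serial)`` and could never equal ``(2, 7)``, so the skip would silently never fire. A minimal sketch of the intended pattern (the test name and class here are hypothetical, not part of this suite)::

    import sys
    import unittest

    class ExportStringTest(unittest.TestCase):
        def test_export_matches_static_string(self):
            # dict iteration order varies across interpreter versions and
            # builds, so CQL export strings generated from dict items are
            # only stable on the interpreter they were captured on (2.7)
            if sys.version_info[:2] != (2, 7):
                raise unittest.SkipTest('order-sensitive export strings; '
                                        'test with 2.7')
            # ... compare generated schema CQL against the static string ...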
@@ -882,7 +890,7 @@ class KeyspaceAlterMetadata(unittest.TestCase): self.assertEqual(original_keyspace_meta.durable_writes, True) self.assertEqual(len(original_keyspace_meta.tables), 1) - self.session.execute('ALTER KEYSPACE %s WITH durable_writes = false' %name) + self.session.execute('ALTER KEYSPACE %s WITH durable_writes = false' % name) new_keyspace_meta = self.cluster.metadata.keyspaces[name] self.assertNotEqual(original_keyspace_meta, new_keyspace_meta) self.assertEqual(new_keyspace_meta.durable_writes, False) diff --git a/tests/integration/standard/test_types.py b/tests/integration/standard/test_types.py index 4944725e..01da0dad 100644 --- a/tests/integration/standard/test_types.py +++ b/tests/integration/standard/test_types.py @@ -21,8 +21,10 @@ except ImportError: import logging log = logging.getLogger(__name__) +from collections import namedtuple from decimal import Decimal from datetime import datetime, date, time +from functools import partial import six from uuid import uuid1, uuid4 @@ -30,10 +32,14 @@ from cassandra import InvalidRequest from cassandra.cluster import Cluster from cassandra.cqltypes import Int32Type, EMPTY from cassandra.query import dict_factory -from cassandra.util import OrderedDict, sortedset +from cassandra.util import OrderedMap, sortedset from tests.integration import get_server_versions, use_singledc, PROTOCOL_VERSION +# defined in module scope for pickling in OrderedMap +nested_collection_udt = namedtuple('nested_collection_udt', ['m', 't', 'l', 's']) +nested_collection_udt_nested = namedtuple('nested_collection_udt_nested', ['m', 't', 'l', 's', 'u']) + def setup_module(): use_singledc() @@ -287,11 +293,11 @@ class TypeTests(unittest.TestCase): s.execute("INSERT INTO mytable (a, b, c, o, s, l, n) VALUES ('a', 'b', %s, %s, %s, %s, %s)", ('', '', '', [''], {'': 3})) self.assertEqual( - {'c': '', 'o': '', 's': '', 'l': [''], 'n': OrderedDict({'': 3})}, + {'c': '', 'o': '', 's': '', 'l': [''], 'n': OrderedMap({'': 3})}, s.execute("SELECT c, o, s, l, n FROM mytable WHERE a='a' AND b='b'")[0]) self.assertEqual( - {'c': '', 'o': '', 's': '', 'l': [''], 'n': OrderedDict({'': 3})}, + {'c': '', 'o': '', 's': '', 'l': [''], 'n': OrderedMap({'': 3})}, s.execute(s.prepare("SELECT c, o, s, l, n FROM mytable WHERE a='a' AND b='b'"), [])[0]) # non-string types shouldn't accept empty strings @@ -708,10 +714,10 @@ class TypeTests(unittest.TestCase): self.assertEquals((None, None, None, None), s.execute(read)[0].t) # also test empty strings where compatible - s.execute(insert, [('', None, None, '')]) + s.execute(insert, [('', None, None, b'')]) result = s.execute("SELECT * FROM mytable WHERE k=0") - self.assertEquals(('', None, None, ''), result[0].t) - self.assertEquals(('', None, None, ''), s.execute(read)[0].t) + self.assertEquals(('', None, None, b''), result[0].t) + self.assertEquals(('', None, None, b''), s.execute(read)[0].t) c.shutdown() @@ -721,3 +727,61 @@ class TypeTests(unittest.TestCase): query = u"SELECT * FROM system.schema_columnfamilies WHERE keyspace_name = 'ef\u2052ef' AND columnfamily_name = %s" s.execute(query, (u"fe\u2051fe",)) + + def insert_select_column(self, session, table_name, column_name, value): + insert = session.prepare("INSERT INTO %s (k, %s) VALUES (?, ?)" % (table_name, column_name)) + session.execute(insert, (0, value)) + result = session.execute("SELECT %s FROM %s WHERE k=%%s" % (column_name, table_name), (0,))[0][0] + self.assertEqual(result, value) + + def test_nested_collections(self): + + if self._cass_version < (2, 1, 3): + 
raise unittest.SkipTest("Support for nested collections was introduced in Cassandra 2.1.3") + + name = self._testMethodName + + c = Cluster(protocol_version=PROTOCOL_VERSION) + s = c.connect('test1rf') + s.encoder.mapping[tuple] = s.encoder.cql_encode_tuple + + s.execute(""" + CREATE TYPE %s ( + m frozen>, + t tuple, + l frozen>, + s frozen> + )""" % name) + s.execute(""" + CREATE TYPE %s_nested ( + m frozen>, + t tuple, + l frozen>, + s frozen>, + u frozen<%s> + )""" % (name, name)) + s.execute(""" + CREATE TABLE %s ( + k int PRIMARY KEY, + map_map map>, frozen>>, + map_set map>, frozen>>, + map_list map>, frozen>>, + map_tuple map>, frozen>>, + map_udt map, frozen<%s>>, + )""" % (name, name, name)) + + validate = partial(self.insert_select_column, s, name) + validate('map_map', OrderedMap([({1: 1, 2: 2}, {3: 3, 4: 4}), ({5: 5, 6: 6}, {7: 7, 8: 8})])) + validate('map_set', OrderedMap([(set((1, 2)), set((3, 4))), (set((5, 6)), set((7, 8)))])) + validate('map_list', OrderedMap([([1, 2], [3, 4]), ([5, 6], [7, 8])])) + validate('map_tuple', OrderedMap([((1, 2), (3,)), ((4, 5), (6,))])) + + value = nested_collection_udt({1: 'v1', 2: 'v2'}, (3, 'v3'), [4, 5, 6, 7], set((8, 9, 10))) + key = nested_collection_udt_nested(value.m, value.t, value.l, value.s, value) + key2 = nested_collection_udt_nested({3: 'v3'}, value.t, value.l, value.s, value) + validate('map_udt', OrderedMap([(key, value), (key2, value)])) + + s.execute("DROP TABLE %s" % (name)) + s.execute("DROP TYPE %s_nested" % (name)) + s.execute("DROP TYPE %s" % (name)) + s.shutdown() diff --git a/tests/integration/util.py b/tests/integration/util.py new file mode 100644 index 00000000..08ebac06 --- /dev/null +++ b/tests/integration/util.py @@ -0,0 +1,37 @@ +# Copyright 2013-2014 DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from tests.integration import PROTOCOL_VERSION + +def assert_quiescent_pool_state(test_case, cluster): + + for session in cluster.sessions: + pool_states = session.get_pool_state().values() + test_case.assertTrue(pool_states) + + for state in pool_states: + test_case.assertFalse(state['shutdown']) + test_case.assertGreater(state['open_count'], 0) + test_case.assertTrue(all((i == 0 for i in state['in_flights']))) + + for holder in cluster.get_connection_holders(): + for connection in holder.get_connections(): + # all ids are unique + req_ids = connection.request_ids + test_case.assertEqual(len(req_ids), len(set(req_ids))) + test_case.assertEqual(connection.highest_request_id, len(req_ids) - 1) + test_case.assertEqual(connection.highest_request_id, max(req_ids)) + if PROTOCOL_VERSION < 3: + test_case.assertEqual(connection.highest_request_id, connection.max_request_id) + diff --git a/tests/unit/io/test_asyncorereactor.py b/tests/unit/io/test_asyncorereactor.py index 6acb1ef8..c2138aa5 100644 --- a/tests/unit/io/test_asyncorereactor.py +++ b/tests/unit/io/test_asyncorereactor.py @@ -31,20 +31,20 @@ from mock import patch, Mock from cassandra.connection import (HEADER_DIRECTION_TO_CLIENT, ConnectionException) - +from cassandra.io.asyncorereactor import AsyncoreConnection from cassandra.protocol import (write_stringmultimap, write_int, write_string, SupportedMessage, ReadyMessage, ServerError) from cassandra.marshal import uint8_pack, uint32_pack, int32_pack -from cassandra.io.asyncorereactor import AsyncoreConnection +from tests import is_monkey_patched class AsyncoreConnectionTest(unittest.TestCase): @classmethod def setUpClass(cls): - if 'gevent.monkey' in sys.modules: - raise unittest.SkipTest("gevent monkey-patching detected") + if is_monkey_patched(): + raise unittest.SkipTest("monkey-patching detected") AsyncoreConnection.initialize_reactor() cls.socket_patcher = patch('socket.socket', spec=socket.socket) cls.mock_socket = cls.socket_patcher.start() diff --git a/tests/unit/test_connection.py b/tests/unit/test_connection.py index 6bc50c22..39ef8ce5 100644 --- a/tests/unit/test_connection.py +++ b/tests/unit/test_connection.py @@ -11,21 +11,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import six - try: import unittest2 as unittest except ImportError: - import unittest # noqa + import unittest # noqa +from mock import Mock, ANY, call +import six from six import BytesIO - -from mock import Mock, ANY +import time +from threading import Lock from cassandra.cluster import Cluster from cassandra.connection import (Connection, HEADER_DIRECTION_TO_CLIENT, HEADER_DIRECTION_FROM_CLIENT, ProtocolError, - locally_supported_compressions) + locally_supported_compressions, ConnectionHeartbeat) from cassandra.marshal import uint8_pack, uint32_pack from cassandra.protocol import (write_stringmultimap, write_int, write_string, SupportedMessage) @@ -280,3 +280,155 @@ class ConnectionTest(unittest.TestCase): def test_set_connection_class(self): cluster = Cluster(connection_class='test') self.assertEqual('test', cluster.connection_class) + + +class ConnectionHeartbeatTest(unittest.TestCase): + + @staticmethod + def make_get_holders(len): + holders = [] + for _ in range(len): + holder = Mock() + holder.get_connections = Mock(return_value=[]) + holders.append(holder) + get_holders = Mock(return_value=holders) + return get_holders + + def run_heartbeat(self, get_holders_fun, count=2, interval=0.05): + ch = ConnectionHeartbeat(interval, get_holders_fun) + time.sleep(interval * count) + ch.stop() + ch.join() + self.assertTrue(get_holders_fun.call_count) + + def test_empty_connections(self): + count = 3 + get_holders = self.make_get_holders(1) + + self.run_heartbeat(get_holders, count) + + self.assertGreaterEqual(get_holders.call_count, count - 1) # lower bound to account for thread spinup time + self.assertLessEqual(get_holders.call_count, count) + holder = get_holders.return_value[0] + holder.get_connections.assert_has_calls([call()] * get_holders.call_count) + + def test_idle_non_idle(self): + request_id = 999 + + # connection.send_msg(OptionsMessage(), connection.get_request_id(), self._options_callback) + def send_msg(msg, req_id, msg_callback): + msg_callback(SupportedMessage([], {})) + + idle_connection = Mock(spec=Connection, host='localhost', + max_request_id=127, + lock=Lock(), + in_flight=0, is_idle=True, + is_defunct=False, is_closed=False, + get_request_id=lambda: request_id, + send_msg=Mock(side_effect=send_msg)) + non_idle_connection = Mock(spec=Connection, in_flight=0, is_idle=False, is_defunct=False, is_closed=False) + + get_holders = self.make_get_holders(1) + holder = get_holders.return_value[0] + holder.get_connections.return_value.append(idle_connection) + holder.get_connections.return_value.append(non_idle_connection) + + self.run_heartbeat(get_holders) + + holder.get_connections.assert_has_calls([call()] * get_holders.call_count) + self.assertEqual(idle_connection.in_flight, 0) + self.assertEqual(non_idle_connection.in_flight, 0) + + idle_connection.send_msg.assert_has_calls([call(ANY, request_id, ANY)] * get_holders.call_count) + self.assertEqual(non_idle_connection.send_msg.call_count, 0) + + def test_closed_defunct(self): + get_holders = self.make_get_holders(1) + closed_connection = Mock(spec=Connection, in_flight=0, is_idle=False, is_defunct=False, is_closed=True) + defunct_connection = Mock(spec=Connection, in_flight=0, is_idle=False, is_defunct=True, is_closed=False) + holder = get_holders.return_value[0] + holder.get_connections.return_value.append(closed_connection) + holder.get_connections.return_value.append(defunct_connection) + + self.run_heartbeat(get_holders) + + holder.get_connections.assert_has_calls([call()] * get_holders.call_count) + 
self.assertEqual(closed_connection.in_flight, 0) + self.assertEqual(defunct_connection.in_flight, 0) + self.assertEqual(closed_connection.send_msg.call_count, 0) + self.assertEqual(defunct_connection.send_msg.call_count, 0) + + def test_no_req_ids(self): + in_flight = 3 + + get_holders = self.make_get_holders(1) + max_connection = Mock(spec=Connection, host='localhost', + max_request_id=in_flight, in_flight=in_flight, + is_idle=True, is_defunct=False, is_closed=False) + holder = get_holders.return_value[0] + holder.get_connections.return_value.append(max_connection) + + self.run_heartbeat(get_holders) + + holder.get_connections.assert_has_calls([call()] * get_holders.call_count) + self.assertEqual(max_connection.in_flight, in_flight) + self.assertEqual(max_connection.send_msg.call_count, 0) + max_connection.defunct.assert_has_calls([call(ANY)] * get_holders.call_count) + holder.return_connection.assert_has_calls([call(max_connection)] * get_holders.call_count) + + def test_unexpected_response(self): + request_id = 999 + + get_holders = self.make_get_holders(1) + + def send_msg(msg, req_id, msg_callback): + msg_callback(object()) + + connection = Mock(spec=Connection, host='localhost', + max_request_id=127, + lock=Lock(), + in_flight=0, is_idle=True, + is_defunct=False, is_closed=False, + get_request_id=lambda: request_id, + send_msg=Mock(side_effect=send_msg)) + holder = get_holders.return_value[0] + holder.get_connections.return_value.append(connection) + + self.run_heartbeat(get_holders) + + self.assertEqual(connection.in_flight, get_holders.call_count) + connection.send_msg.assert_has_calls([call(ANY, request_id, ANY)] * get_holders.call_count) + connection.defunct.assert_has_calls([call(ANY)] * get_holders.call_count) + exc = connection.defunct.call_args_list[0][0][0] + self.assertIsInstance(exc, Exception) + self.assertEqual(exc.args, Exception('Connection heartbeat failure').args) + holder.return_connection.assert_has_calls([call(connection)] * get_holders.call_count) + + def test_timeout(self): + request_id = 999 + + get_holders = self.make_get_holders(1) + + def send_msg(msg, req_id, msg_callback): + pass + + connection = Mock(spec=Connection, host='localhost', + max_request_id=127, + lock=Lock(), + in_flight=0, is_idle=True, + is_defunct=False, is_closed=False, + get_request_id=lambda: request_id, + send_msg=Mock(side_effect=send_msg)) + holder = get_holders.return_value[0] + holder.get_connections.return_value.append(connection) + + self.run_heartbeat(get_holders) + + self.assertEqual(connection.in_flight, get_holders.call_count) + connection.send_msg.assert_has_calls([call(ANY, request_id, ANY)] * get_holders.call_count) + connection.defunct.assert_has_calls([call(ANY)] * get_holders.call_count) + exc = connection.defunct.call_args_list[0][0][0] + self.assertIsInstance(exc, Exception) + self.assertEqual(exc.args, Exception('Connection heartbeat failure').args) + holder.return_connection.assert_has_calls([call(connection)] * get_holders.call_count) diff --git a/tests/unit/test_control_connection.py b/tests/unit/test_control_connection.py index 12d09781..de95dbf4 100644 --- a/tests/unit/test_control_connection.py +++ b/tests/unit/test_control_connection.py @@ -73,7 +73,7 @@ class MockCluster(object): self.scheduler = Mock(spec=_Scheduler) self.executor = Mock(spec=ThreadPoolExecutor) - def add_host(self, address, datacenter, rack, signal=False): + def add_host(self, address, datacenter, rack, signal=False,
refresh_nodes=True): host = Host(address, SimpleConvictionPolicy, datacenter, rack) self.added_hosts.append(host) return host @@ -131,7 +131,7 @@ class ControlConnectionTest(unittest.TestCase): self.connection = MockConnection() self.time = FakeTime() - self.control_connection = ControlConnection(self.cluster, timeout=1) + self.control_connection = ControlConnection(self.cluster, 1, 0, 0) self.control_connection._connection = self.connection self.control_connection._time = self.time @@ -345,39 +345,44 @@ class ControlConnectionTest(unittest.TestCase): 'change_type': 'NEW_NODE', 'address': ('1.2.3.4', 9000) } + self.cluster.scheduler.reset_mock() self.control_connection._handle_topology_change(event) - self.cluster.scheduler.schedule.assert_called_with(ANY, self.control_connection.refresh_node_list_and_token_map) + self.cluster.scheduler.schedule_unique.assert_called_once_with(ANY, self.control_connection.refresh_node_list_and_token_map) event = { 'change_type': 'REMOVED_NODE', 'address': ('1.2.3.4', 9000) } + self.cluster.scheduler.reset_mock() self.control_connection._handle_topology_change(event) - self.cluster.scheduler.schedule.assert_called_with(ANY, self.cluster.remove_host, None) + self.cluster.scheduler.schedule_unique.assert_called_once_with(ANY, self.cluster.remove_host, None) event = { 'change_type': 'MOVED_NODE', 'address': ('1.2.3.4', 9000) } + self.cluster.scheduler.reset_mock() self.control_connection._handle_topology_change(event) - self.cluster.scheduler.schedule.assert_called_with(ANY, self.control_connection.refresh_node_list_and_token_map) + self.cluster.scheduler.schedule_unique.assert_called_once_with(ANY, self.control_connection.refresh_node_list_and_token_map) def test_handle_status_change(self): event = { 'change_type': 'UP', 'address': ('1.2.3.4', 9000) } + self.cluster.scheduler.reset_mock() self.control_connection._handle_status_change(event) - self.cluster.scheduler.schedule.assert_called_with(ANY, self.control_connection.refresh_node_list_and_token_map) + self.cluster.scheduler.schedule_unique.assert_called_once_with(ANY, self.control_connection.refresh_node_list_and_token_map) # do the same with a known Host event = { 'change_type': 'UP', 'address': ('192.168.1.0', 9000) } + self.cluster.scheduler.reset_mock() self.control_connection._handle_status_change(event) host = self.cluster.metadata.hosts['192.168.1.0'] - self.cluster.scheduler.schedule.assert_called_with(ANY, self.cluster.on_up, host) + self.cluster.scheduler.schedule_unique.assert_called_once_with(ANY, self.cluster.on_up, host) self.cluster.scheduler.schedule.reset_mock() event = { @@ -404,9 +409,11 @@ class ControlConnectionTest(unittest.TestCase): 'keyspace': 'ks1', 'table': 'table1' } + self.cluster.scheduler.reset_mock() self.control_connection._handle_schema_change(event) - self.cluster.executor.submit.assert_called_with(self.control_connection.refresh_schema, 'ks1', 'table1', None) + self.cluster.scheduler.schedule_unique.assert_called_once_with(0.0, self.control_connection.refresh_schema, 'ks1', 'table1', None) + self.cluster.scheduler.reset_mock() event['table'] = None self.control_connection._handle_schema_change(event) - self.cluster.executor.submit.assert_called_with(self.control_connection.refresh_schema, 'ks1', None, None) + self.cluster.scheduler.schedule_unique.assert_called_once_with(0.0, self.control_connection.refresh_schema, 'ks1', None, None) diff --git a/tests/unit/test_marshalling.py b/tests/unit/test_marshalling.py index a329b4f5..4adba748 100644 --- 
a/tests/unit/test_marshalling.py +++ b/tests/unit/test_marshalling.py @@ -24,7 +24,7 @@ from decimal import Decimal from uuid import UUID from cassandra.cqltypes import lookup_casstype -from cassandra.util import OrderedDict, sortedset +from cassandra.util import OrderedMap, sortedset marshalled_value_pairs = ( # binary form, type, python native type @@ -75,7 +75,7 @@ marshalled_value_pairs = ( (b'', 'MapType(AsciiType, BooleanType)', None), (b'', 'ListType(FloatType)', None), (b'', 'SetType(LongType)', None), - (b'\x00\x00', 'MapType(DecimalType, BooleanType)', OrderedDict()), + (b'\x00\x00', 'MapType(DecimalType, BooleanType)', OrderedMap()), (b'\x00\x00', 'ListType(FloatType)', []), (b'\x00\x00', 'SetType(IntegerType)', sortedset()), (b'\x00\x01\x00\x10\xafYC\xa3\xea<\x11\xe1\xabc\xc4,\x03"y\xf0', 'ListType(TimeUUIDType)', [UUID(bytes=b'\xafYC\xa3\xea<\x11\xe1\xabc\xc4,\x03"y\xf0')]), @@ -84,15 +84,14 @@ marshalled_value_pairs = ( (b'\x00\x00\x00\x00\x00\x00\x00\x01', 'TimeType', 1) ) -ordered_dict_value = OrderedDict() -ordered_dict_value[u'\u307fbob'] = 199 -ordered_dict_value[u''] = -1 -ordered_dict_value[u'\\'] = 0 +ordered_map_value = OrderedMap([(u'\u307fbob', 199), + (u'', -1), + (u'\\', 0)]) # these following entries work for me right now, but they're dependent on # vagaries of internal python ordering for unordered types marshalled_value_pairs_unsafe = ( - (b'\x00\x03\x00\x06\xe3\x81\xbfbob\x00\x04\x00\x00\x00\xc7\x00\x00\x00\x04\xff\xff\xff\xff\x00\x01\\\x00\x04\x00\x00\x00\x00', 'MapType(UTF8Type, Int32Type)', ordered_dict_value), + (b'\x00\x03\x00\x06\xe3\x81\xbfbob\x00\x04\x00\x00\x00\xc7\x00\x00\x00\x04\xff\xff\xff\xff\x00\x01\\\x00\x04\x00\x00\x00\x00', 'MapType(UTF8Type, Int32Type)', ordered_map_value), (b'\x00\x02\x00\x08@\x01\x99\x99\x99\x99\x99\x9a\x00\x08@\x14\x00\x00\x00\x00\x00\x00', 'SetType(DoubleType)', sortedset([2.2, 5.0])), (b'\x00', 'IntegerType', 0), ) diff --git a/tests/unit/test_orderedmap.py b/tests/unit/test_orderedmap.py new file mode 100644 index 00000000..3cee3a11 --- /dev/null +++ b/tests/unit/test_orderedmap.py @@ -0,0 +1,127 @@ +# Copyright 2013-2014 DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + +from cassandra.util import OrderedMap +from cassandra.cqltypes import EMPTY +import six + +class OrderedMapTest(unittest.TestCase): + def test_init(self): + a = OrderedMap(zip(['one', 'three', 'two'], [1, 3, 2])) + b = OrderedMap([('one', 1), ('three', 3), ('two', 2)]) + c = OrderedMap(a) + builtin = {'one': 1, 'two': 2, 'three': 3} + self.assertEqual(a, b) + self.assertEqual(a, c) + self.assertEqual(a, builtin) + self.assertEqual(OrderedMap([(1, 1), (1, 2)]), {1: 2}) + + def test_contains(self): + keys = ['first', 'middle', 'last'] + + od = OrderedMap() + + od = OrderedMap(zip(keys, range(len(keys)))) + + for k in keys: + self.assertTrue(k in od) + self.assertFalse(k not in od) + + self.assertTrue('notthere' not in od) + self.assertFalse('notthere' in od) + + def test_keys(self): + keys = ['first', 'middle', 'last'] + od = OrderedMap(zip(keys, range(len(keys)))) + + self.assertListEqual(list(od.keys()), keys) + + def test_values(self): + keys = ['first', 'middle', 'last'] + values = list(range(len(keys))) + od = OrderedMap(zip(keys, values)) + + self.assertListEqual(list(od.values()), values) + + def test_items(self): + keys = ['first', 'middle', 'last'] + items = list(zip(keys, range(len(keys)))) + od = OrderedMap(items) + + self.assertListEqual(list(od.items()), items) + + def test_get(self): + keys = ['first', 'middle', 'last'] + od = OrderedMap(zip(keys, range(len(keys)))) + + for v, k in enumerate(keys): + self.assertEqual(od.get(k), v) + + self.assertEqual(od.get('notthere', 'default'), 'default') + self.assertIsNone(od.get('notthere')) + + def test_equal(self): + d1 = {'one': 1} + d12 = {'one': 1, 'two': 2} + od1 = OrderedMap({'one': 1}) + od12 = OrderedMap([('one', 1), ('two', 2)]) + od21 = OrderedMap([('two', 2), ('one', 1)]) + + self.assertEqual(od1, d1) + self.assertEqual(od12, d12) + self.assertEqual(od21, d12) + self.assertNotEqual(od1, od12) + self.assertNotEqual(od12, od1) + self.assertNotEqual(od12, od21) + self.assertNotEqual(od1, d12) + self.assertNotEqual(od12, d1) + self.assertNotEqual(od1, EMPTY) + + def test_getitem(self): + keys = ['first', 'middle', 'last'] + od = OrderedMap(zip(keys, range(len(keys)))) + + for v, k in enumerate(keys): + self.assertEqual(od[k], v) + + with self.assertRaises(KeyError): + od['notthere'] + + def test_iter(self): + keys = ['first', 'middle', 'last'] + values = list(range(len(keys))) + items = list(zip(keys, values)) + od = OrderedMap(items) + + itr = iter(od) + self.assertEqual(sum([1 for _ in itr]), len(keys)) + self.assertRaises(StopIteration, six.next, itr) + + self.assertEqual(list(iter(od)), keys) + self.assertEqual(list(six.iteritems(od)), items) + self.assertEqual(list(six.itervalues(od)), values) + + def test_len(self): + self.assertEqual(len(OrderedMap()), 0) + self.assertEqual(len(OrderedMap([(1, 1)])), 1) + + def test_mutable_keys(self): + d = {'1': 1} + s = set([1, 2, 3]) + od = OrderedMap([(d, 'dict'), (s, 'set')]) diff --git a/tests/unit/test_types.py b/tests/unit/test_types.py index 42e2c992..e5fb7575 100644 --- a/tests/unit/test_types.py +++ b/tests/unit/test_types.py @@ -262,11 +262,13 @@ class TypeTests(unittest.TestCase): self.assertEqual(cassandra_type.val, 'randomvaluetocheck') def test_datetype(self): - now_timestamp = time.time() - now_datetime = datetime.datetime.utcfromtimestamp(now_timestamp) + now_time_seconds = time.time() + now_datetime = datetime.datetime.utcfromtimestamp(now_time_seconds) + + # Cassandra timestamps in 
millis + now_timestamp = now_time_seconds * 1e3 # same results serialized - # (this could change if we follow up on the timestamp multiplication warning in DateType.serialize) self.assertEqual(DateType.serialize(now_datetime, 0), DateType.serialize(now_timestamp, 0)) # from timestamp
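The DateType test change above reflects PYTHON-204: with implicit scaling removed, a numeric value serialized into a ``timestamp`` column is taken as-is in milliseconds since the Unix epoch, and only ``datetime`` instances are converted by the driver. A standalone sketch of the equivalence the test asserts (plain arithmetic only, no driver API)::

    import datetime
    import time

    now_seconds = time.time()
    now_millis = now_seconds * 1e3  # what a Cassandra timestamp stores

    # a naive UTC datetime for the same instant
    now_datetime = datetime.datetime.utcfromtimestamp(now_seconds)

    # converting the datetime back to epoch milliseconds recovers the same
    # value (up to float precision), which is why serializing either form
    # yields identical bytes
    epoch = datetime.datetime(1970, 1, 1)
    recovered_millis = (now_datetime - epoch).total_seconds() * 1e3
    assert abs(recovered_millis - now_millis) < 1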