Merge remote-tracking branch 'origin/master' into PYTHON-190
Conflicts:
    cassandra/cqltypes.py
    tests/integration/standard/test_types.py
@@ -1,14 +1,22 @@
2.1.3-post
==========
2.1.4
=====

Features
--------
* SaslAuthenticator for Kerberos support (PYTHON-109)
* Heartbeat for network device keepalive and detecting failures on idle connections (PYTHON-197)
* Support nested, frozen collections for Cassandra 2.1.3+ (PYTHON-186)
* Schema agreement wait bypass config, new call for synchronous schema refresh (PYTHON-205)
* Add eventlet connection support (PYTHON-194)

Bug Fixes
---------
* Schema meta fix for complex thrift tables (PYTHON-191)
* Support for 'unknown' replica placement strategies in schema meta (PYTHON-192)
* Resolve stream ID leak on set_keyspace (PYTHON-195)
* Remove implicit timestamp scaling on serialization of numeric timestamps (PYTHON-204)
* Resolve stream id collision when using SASL auth (PYTHON-210)
* Correct unhexlify usage for user defined type meta in Python3 (PYTHON-208)

2.1.3
=====

@@ -23,7 +23,7 @@ class NullHandler(logging.Handler):
logging.getLogger('cassandra').addHandler(NullHandler())

__version_info__ = (2, 1, 3, 'post')
__version_info__ = (2, 1, 4, 'post')
__version__ = '.'.join(map(str, __version_info__))
@@ -130,7 +130,9 @@ class PlainTextAuthenticator(Authenticator):

class SaslAuthProvider(AuthProvider):
    """
    An :class:`~.AuthProvider` for Kerberos authenticators
    An :class:`~.AuthProvider` supporting general SASL auth mechanisms

    Suitable for GSSAPI or other SASL mechanisms

    Example usage::

@@ -144,7 +146,7 @@ class SaslAuthProvider(AuthProvider):
        auth_provider = SaslAuthProvider(**sasl_kwargs)
        cluster = Cluster(auth_provider=auth_provider)

    .. versionadded:: 2.1.3-post
    .. versionadded:: 2.1.4
    """

    def __init__(self, **sasl_kwargs):

@@ -157,9 +159,10 @@ class SaslAuthProvider(AuthProvider):

class SaslAuthenticator(Authenticator):
    """
    An :class:`~.Authenticator` that works with DSE's KerberosAuthenticator.
    A pass-through :class:`~.Authenticator` using the third party package
    'pure-sasl' for authentication

    .. versionadded:: 2.1.3-post
    .. versionadded:: 2.1.4
    """

    def __init__(self, host, service, mechanism='GSSAPI', **sasl_kwargs):
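The hunks above generalize the Kerberos-only wording to SASL auth backed by the third-party ``pure-sasl`` package. As a rough sketch of how this gets wired up (the service principal name and QOP value below are assumptions, not part of this commit)::

    from cassandra.cluster import Cluster
    from cassandra.auth import SaslAuthProvider

    sasl_kwargs = {
        'service': 'dse',        # assumed Kerberos service name
        'mechanism': 'GSSAPI',
        'qops': ['auth'],        # assumed QOP setting for pure-sasl
    }
    auth_provider = SaslAuthProvider(**sasl_kwargs)
    cluster = Cluster(auth_provider=auth_provider)
    session = cluster.connect()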
@@ -22,6 +22,7 @@ import atexit
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
import logging
from random import random
import socket
import sys
import time

@@ -44,7 +45,8 @@ from itertools import groupby
from cassandra import (ConsistencyLevel, AuthenticationFailed,
                       InvalidRequest, OperationTimedOut,
                       UnsupportedOperation, Unauthorized)
from cassandra.connection import ConnectionException, ConnectionShutdown
from cassandra.connection import (ConnectionException, ConnectionShutdown,
                                  ConnectionHeartbeat)
from cassandra.encoder import Encoder
from cassandra.protocol import (QueryMessage, ResultMessage,
                                ErrorMessage, ReadTimeoutErrorMessage,

@@ -68,10 +70,20 @@ from cassandra.query import (SimpleStatement, PreparedStatement, BoundStatement,
                             BatchStatement, bind_params, QueryTrace, Statement,
                             named_tuple_factory, dict_factory, FETCH_SIZE_UNSET)

# default to gevent when we are monkey patched, otherwise if libev is available, use that as the
# default because it's fastest. Otherwise, use asyncore.

def _is_eventlet_monkey_patched():
    if 'eventlet.patcher' not in sys.modules:
        return False
    import eventlet.patcher
    return eventlet.patcher.is_monkey_patched('socket')

# default to gevent when we are monkey patched with gevent, eventlet when
# monkey patched with eventlet, otherwise if libev is available, use that as
# the default because it's fastest. Otherwise, use asyncore.
if 'gevent.monkey' in sys.modules:
    from cassandra.io.geventreactor import GeventConnection as DefaultConnection
elif _is_eventlet_monkey_patched():
    from cassandra.io.eventletreactor import EventletConnection as DefaultConnection
else:
    try:
        from cassandra.io.libevreactor import LibevConnection as DefaultConnection  # NOQA
@@ -372,6 +384,48 @@ class Cluster(object):
    If set to :const:`None`, there will be no timeout for these queries.
    """

    idle_heartbeat_interval = 30
    """
    Interval, in seconds, on which to heartbeat idle connections. This helps
    keep connections open through network devices that expire idle connections.
    It also helps discover bad connections early in low-traffic scenarios.
    Setting to zero disables heartbeats.
    """

    schema_event_refresh_window = 2
    """
    Window, in seconds, within which a schema component will be refreshed after
    receiving a schema_change event.

    The driver delays a random amount of time in the range [0.0, window)
    before executing the refresh. This serves two purposes:

    1.) Spread the refresh for deployments with large fanout from C* to client tier,
    preventing a 'thundering herd' problem with many clients refreshing simultaneously.

    2.) Remove redundant refreshes. Redundant events arriving within the delay period
    are discarded, and only one refresh is executed.

    Setting this to zero will execute refreshes immediately.

    Setting this negative will disable schema refreshes in response to push events
    (refreshes will still occur in response to schema change responses to DDL statements
    executed by Sessions of this Cluster).
    """

    topology_event_refresh_window = 10
    """
    Window, in seconds, within which the node and token list will be refreshed after
    receiving a topology_change event.

    Setting this to zero will execute refreshes immediately.

    Setting this negative will disable node refreshes in response to push events
    (refreshes will still occur in response to new nodes observed on "UP" events).

    See :attr:`.schema_event_refresh_window` for discussion of rationale
    """

    sessions = None
    control_connection = None
    scheduler = None
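All three new attributes are also accepted as ``Cluster`` constructor keywords (see the ``__init__`` hunk below), so a minimal tuning sketch looks like::

    from cassandra.cluster import Cluster

    cluster = Cluster(
        idle_heartbeat_interval=30,       # heartbeat connections idle for 30s; 0 disables
        schema_event_refresh_window=2,    # debounce schema_change events over a 2s window
        topology_event_refresh_window=10, # debounce topology_change events over a 10s window
    )
    session = cluster.connect()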
@@ -380,6 +434,7 @@ class Cluster(object):
    _is_setup = False
    _prepared_statements = None
    _prepared_statement_lock = None
    _idle_heartbeat = None

    _user_types = None
    """

@@ -406,7 +461,10 @@ class Cluster(object):
                 protocol_version=2,
                 executor_threads=2,
                 max_schema_agreement_wait=10,
                 control_connection_timeout=2.0):
                 control_connection_timeout=2.0,
                 idle_heartbeat_interval=30,
                 schema_event_refresh_window=2,
                 topology_event_refresh_window=10):
        """
        Any of the mutable Cluster attributes may be set as keyword arguments
        to the constructor.

@@ -456,6 +514,9 @@ class Cluster(object):
        self.cql_version = cql_version
        self.max_schema_agreement_wait = max_schema_agreement_wait
        self.control_connection_timeout = control_connection_timeout
        self.idle_heartbeat_interval = idle_heartbeat_interval
        self.schema_event_refresh_window = schema_event_refresh_window
        self.topology_event_refresh_window = topology_event_refresh_window

        self._listeners = set()
        self._listener_lock = Lock()

@@ -500,7 +561,8 @@ class Cluster(object):
            self.metrics = Metrics(weakref.proxy(self))

        self.control_connection = ControlConnection(
            self, self.control_connection_timeout)
            self, self.control_connection_timeout,
            self.schema_event_refresh_window, self.topology_event_refresh_window)

    def register_user_type(self, keyspace, user_type, klass):
        """

@@ -621,7 +683,7 @@ class Cluster(object):

    def set_max_connections_per_host(self, host_distance, max_connections):
        """
        Gets the maximum number of connections per Session that will be opened
        Sets the maximum number of connections per Session that will be opened
        for each host with :class:`~.HostDistance` equal to `host_distance`.
        The default is 2 for :attr:`~HostDistance.LOCAL` and 1 for
        :attr:`~HostDistance.REMOTE`.

@@ -688,18 +750,19 @@ class Cluster(object):
            self.load_balancing_policy.populate(
                weakref.proxy(self), self.metadata.all_hosts())

            if self.control_connection:
                try:
                    self.control_connection.connect()
                    log.debug("Control connection created")
                except Exception:
                    log.exception("Control connection failed to connect, "
                                  "shutting down Cluster:")
                    self.shutdown()
                    raise
            try:
                self.control_connection.connect()
                log.debug("Control connection created")
            except Exception:
                log.exception("Control connection failed to connect, "
                              "shutting down Cluster:")
                self.shutdown()
                raise

            self.load_balancing_policy.check_supported()

            if self.idle_heartbeat_interval:
                self._idle_heartbeat = ConnectionHeartbeat(self.idle_heartbeat_interval, self.get_connection_holders)
            self._is_setup = True

        session = self._new_session()

@@ -707,6 +770,13 @@ class Cluster(object):
            session.set_keyspace(keyspace)
        return session

    def get_connection_holders(self):
        holders = []
        for s in self.sessions:
            holders.extend(s.get_pools())
        holders.append(self.control_connection)
        return holders

    def shutdown(self):
        """
        Closes all sessions and connection associated with this Cluster.

@@ -721,18 +791,17 @@ class Cluster(object):
        else:
            self.is_shutdown = True

        if self.scheduler:
            self.scheduler.shutdown()
        if self._idle_heartbeat:
            self._idle_heartbeat.stop()

        if self.control_connection:
            self.control_connection.shutdown()
        self.scheduler.shutdown()

        if self.sessions:
            for session in self.sessions:
                session.shutdown()
        self.control_connection.shutdown()

        if self.executor:
            self.executor.shutdown()
        for session in self.sessions:
            session.shutdown()

        self.executor.shutdown()

    def _new_session(self):
        session = Session(self, self.metadata.all_hosts())

@@ -907,7 +976,7 @@ class Cluster(object):

            self._start_reconnector(host, is_host_addition)

    def on_add(self, host):
    def on_add(self, host, refresh_nodes=True):
        if self.is_shutdown:
            return

@@ -919,7 +988,7 @@ class Cluster(object):
        log.debug("Done preparing queries for new host %r", host)

        self.load_balancing_policy.on_add(host)
        self.control_connection.on_add(host)
        self.control_connection.on_add(host, refresh_nodes)

        if distance == HostDistance.IGNORED:
            log.debug("Not adding connection pool for new host %r because the "

@@ -995,7 +1064,7 @@ class Cluster(object):
            self.on_down(host, is_host_addition, expect_host_to_be_down)
        return is_down

    def add_host(self, address, datacenter=None, rack=None, signal=True):
    def add_host(self, address, datacenter=None, rack=None, signal=True, refresh_nodes=True):
        """
        Called when adding initial contact points and when the control
        connection subsequently discovers a new node. Intended for internal

@@ -1004,7 +1073,7 @@ class Cluster(object):
        new_host = self.metadata.add_host(address, datacenter, rack)
        if new_host and signal:
            log.info("New Cassandra host %r discovered", new_host)
            self.on_add(new_host)
            self.on_add(new_host, refresh_nodes)

        return new_host

@@ -1045,16 +1114,20 @@ class Cluster(object):
            for pool in session._pools.values():
                pool.ensure_core_connections()

    def refresh_schema(self, keyspace=None, table=None, usertype=None, schema_agreement_wait=None):
    def refresh_schema(self, keyspace=None, table=None, usertype=None, max_schema_agreement_wait=None):
        """
        Synchronously refresh the schema metadata.
        By default timeout for this operation is governed by :attr:`~.Cluster.max_schema_agreement_wait`

        By default, the timeout for this operation is governed by :attr:`~.Cluster.max_schema_agreement_wait`
        and :attr:`~.Cluster.control_connection_timeout`.
        Passing schema_agreement_wait here overrides :attr:`~.Cluster.max_schema_agreement_wait`.
        Setting schema_agreement_wait <= 0 will bypass schema agreement and refresh schema immediately.

        Passing max_schema_agreement_wait here overrides :attr:`~.Cluster.max_schema_agreement_wait`.

        Setting max_schema_agreement_wait <= 0 will bypass schema agreement and refresh schema immediately.

        An Exception is raised if schema refresh fails for any reason.
        """
        if not self.control_connection.refresh_schema(keyspace, table, usertype, schema_agreement_wait):
        if not self.control_connection.refresh_schema(keyspace, table, usertype, max_schema_agreement_wait):
            raise Exception("Schema was not refreshed. See log for details.")

    def submit_schema_refresh(self, keyspace=None, table=None, usertype=None):

@@ -1066,6 +1139,27 @@ class Cluster(object):
        return self.executor.submit(
            self.control_connection.refresh_schema, keyspace, table, usertype)

    def refresh_nodes(self):
        """
        Synchronously refresh the node list and token metadata

        An Exception is raised if node refresh fails for any reason.
        """
        if not self.control_connection.refresh_node_list_and_token_map():
            raise Exception("Node list was not refreshed. See log for details.")

    def set_meta_refresh_enabled(self, enabled):
        """
        Sets a flag to enable (True) or disable (False) all metadata refresh queries.
        This applies to both schema and node topology.

        Disabling this is useful to minimize refreshes during multiple changes.

        Meta refresh must be enabled for the driver to become aware of any cluster
        topology changes or schema updates.
        """
        self.control_connection.set_meta_refresh_enabled(bool(enabled))

    def _prepare_all_queries(self, host):
        if not self._prepared_statements:
            return
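A minimal usage sketch of the new calls, assuming an already-connected ``cluster`` object: disable refreshes while applying a batch of DDL, then re-enable and refresh once::

    cluster.set_meta_refresh_enabled(False)   # suppress per-event schema/topology refreshes
    # ... execute a series of DDL statements through any session ...
    cluster.set_meta_refresh_enabled(True)
    cluster.refresh_schema()                  # synchronous; raises if the refresh fails
    cluster.refresh_nodes()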
@@ -1656,6 +1750,9 @@ class Session(object):
    def get_pool_state(self):
        return dict((host, pool.get_state()) for host, pool in self._pools.items())

    def get_pools(self):
        return self._pools.values()


class UserTypeDoesNotExist(Exception):
    """

@@ -1734,16 +1831,26 @@ class ControlConnection(object):
    _timeout = None
    _protocol_version = None

    _schema_event_refresh_window = None
    _topology_event_refresh_window = None

    _meta_refresh_enabled = True

    # for testing purposes
    _time = time

    def __init__(self, cluster, timeout):
    def __init__(self, cluster, timeout,
                 schema_event_refresh_window,
                 topology_event_refresh_window):
        # use a weak reference to allow the Cluster instance to be GC'ed (and
        # shutdown) since implementing __del__ disables the cycle detector
        self._cluster = weakref.proxy(cluster)
        self._connection = None
        self._timeout = timeout

        self._schema_event_refresh_window = schema_event_refresh_window
        self._topology_event_refresh_window = topology_event_refresh_window

        self._lock = RLock()
        self._schema_agreement_lock = Lock()

@@ -1901,6 +2008,10 @@ class ControlConnection(object):

    def refresh_schema(self, keyspace=None, table=None, usertype=None,
                       schema_agreement_wait=None):
        if not self._meta_refresh_enabled:
            log.debug("[control connection] Skipping schema refresh because meta refresh is disabled")
            return False

        try:
            if self._connection:
                return self._refresh_schema(self._connection, keyspace, table, usertype,

@@ -2028,14 +2139,20 @@ class ControlConnection(object):
        return True

    def refresh_node_list_and_token_map(self, force_token_rebuild=False):
        if not self._meta_refresh_enabled:
            log.debug("[control connection] Skipping node list refresh because meta refresh is disabled")
            return False

        try:
            if self._connection:
                self._refresh_node_list_and_token_map(self._connection, force_token_rebuild=force_token_rebuild)
                return True
        except ReferenceError:
            pass  # our weak reference to the Cluster is no good
        except Exception:
            log.debug("[control connection] Error refreshing node list and token map", exc_info=True)
            self._signal_error()
        return False

    def _refresh_node_list_and_token_map(self, connection, preloaded_results=None,
                                         force_token_rebuild=False):

@@ -2096,7 +2213,7 @@ class ControlConnection(object):
            rack = row.get("rack")
            if host is None:
                log.debug("[control connection] Found new host to connect to: %s", addr)
                host = self._cluster.add_host(addr, datacenter, rack, signal=True)
                host = self._cluster.add_host(addr, datacenter, rack, signal=True, refresh_nodes=False)
                should_rebuild_token_map = True
            else:
                should_rebuild_token_map |= self._update_location_info(host, datacenter, rack)

@@ -2131,25 +2248,25 @@ class ControlConnection(object):
    def _handle_topology_change(self, event):
        change_type = event["change_type"]
        addr, port = event["address"]
        if change_type == "NEW_NODE":
            self._cluster.scheduler.schedule(10, self.refresh_node_list_and_token_map)
        if change_type == "NEW_NODE" or change_type == "MOVED_NODE":
            if self._topology_event_refresh_window >= 0:
                delay = random() * self._topology_event_refresh_window
                self._cluster.scheduler.schedule_unique(delay, self.refresh_node_list_and_token_map)
        elif change_type == "REMOVED_NODE":
            host = self._cluster.metadata.get_host(addr)
            self._cluster.scheduler.schedule(0, self._cluster.remove_host, host)
        elif change_type == "MOVED_NODE":
            self._cluster.scheduler.schedule(1, self.refresh_node_list_and_token_map)
            self._cluster.scheduler.schedule_unique(0, self._cluster.remove_host, host)

    def _handle_status_change(self, event):
        change_type = event["change_type"]
        addr, port = event["address"]
        host = self._cluster.metadata.get_host(addr)
        if change_type == "UP":
            delay = 1 + random() * 0.5  # randomness to avoid thundering herd problem on events
            if host is None:
                # this is the first time we've seen the node
                self._cluster.scheduler.schedule(2, self.refresh_node_list_and_token_map)
                self._cluster.scheduler.schedule_unique(delay, self.refresh_node_list_and_token_map)
            else:
                # this will be run by the scheduler
                self._cluster.scheduler.schedule(2, self._cluster.on_up, host)
                self._cluster.scheduler.schedule_unique(delay, self._cluster.on_up, host)
        elif change_type == "DOWN":
            # Note that there is a slight risk we can receive the event late and thus
            # mark the host down even though we already had reconnected successfully.

@@ -2160,10 +2277,14 @@ class ControlConnection(object):
                self._cluster.on_down(host, is_host_addition=False)

    def _handle_schema_change(self, event):
        if self._schema_event_refresh_window < 0:
            return

        keyspace = event.get('keyspace')
        table = event.get('table')
        usertype = event.get('type')
        self._submit(self.refresh_schema, keyspace, table, usertype)
        delay = random() * self._schema_event_refresh_window
        self._cluster.scheduler.schedule_unique(delay, self.refresh_schema, keyspace, table, usertype)

    def wait_for_schema_agreement(self, connection=None, preloaded_results=None, wait_time=None):

@@ -2271,11 +2392,6 @@ class ControlConnection(object):
            # manually
            self.reconnect()

    @property
    def is_open(self):
        conn = self._connection
        return bool(conn and conn.is_open)

    def on_up(self, host):
        pass

@@ -2289,12 +2405,24 @@ class ControlConnection(object):
            # this will result in a task being submitted to the executor to reconnect
            self.reconnect()

    def on_add(self, host):
        self.refresh_node_list_and_token_map(force_token_rebuild=True)
    def on_add(self, host, refresh_nodes=True):
        if refresh_nodes:
            self.refresh_node_list_and_token_map(force_token_rebuild=True)

    def on_remove(self, host):
        self.refresh_node_list_and_token_map(force_token_rebuild=True)

    def get_connections(self):
        c = getattr(self, '_connection', None)
        return [c] if c else []

    def return_connection(self, connection):
        if connection is self._connection and (connection.is_defunct or connection.is_closed):
            self.reconnect()

    def set_meta_refresh_enabled(self, enabled):
        self._meta_refresh_enabled = enabled


def _stop_scheduler(scheduler, thread):
    try:

@@ -2308,12 +2436,14 @@ def _stop_scheduler(scheduler, thread):

class _Scheduler(object):

    _scheduled = None
    _queue = None
    _scheduled_tasks = None
    _executor = None
    is_shutdown = False

    def __init__(self, executor):
        self._scheduled = Queue.PriorityQueue()
        self._queue = Queue.PriorityQueue()
        self._scheduled_tasks = set()
        self._executor = executor

        t = Thread(target=self.run, name="Task Scheduler")

@@ -2331,14 +2461,25 @@ class _Scheduler(object):
            # this can happen on interpreter shutdown
            pass
        self.is_shutdown = True
        self._scheduled.put_nowait((0, None))
        self._queue.put_nowait((0, None))

    def schedule(self, delay, fn, *args, **kwargs):
    def schedule(self, delay, fn, *args):
        self._insert_task(delay, (fn, args))

    def schedule_unique(self, delay, fn, *args):
        task = (fn, args)
        if task not in self._scheduled_tasks:
            self._insert_task(delay, task)
        else:
            log.debug("Ignoring schedule_unique for already-scheduled task: %r", task)

    def _insert_task(self, delay, task):
        if not self.is_shutdown:
            run_at = time.time() + delay
            self._scheduled.put_nowait((run_at, (fn, args, kwargs)))
            self._scheduled_tasks.add(task)
            self._queue.put_nowait((run_at, task))
        else:
            log.debug("Ignoring scheduled function after shutdown: %r", fn)
            log.debug("Ignoring scheduled task after shutdown: %r", task)

    def run(self):
        while True:

@@ -2347,16 +2488,17 @@ class _Scheduler(object):

        try:
            while True:
                run_at, task = self._scheduled.get(block=True, timeout=None)
                run_at, task = self._queue.get(block=True, timeout=None)
                if self.is_shutdown:
                    log.debug("Not executing scheduled task due to Scheduler shutdown")
                    return
                if run_at <= time.time():
                    fn, args, kwargs = task
                    future = self._executor.submit(fn, *args, **kwargs)
                    self._scheduled_tasks.remove(task)
                    fn, args = task
                    future = self._executor.submit(fn, *args)
                    future.add_done_callback(self._log_if_failed)
                else:
                    self._scheduled.put_nowait((run_at, task))
                    self._queue.put_nowait((run_at, task))
                    break
        except Queue.Empty:
            pass
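The scheduler hunks above replace plain ``schedule()`` calls with ``schedule_unique()``, which drops a task that is already queued; combined with the random delay from the refresh windows, many push events within the window collapse into a single refresh. A rough standalone sketch of that idea (illustrative only, not driver code)::

    import random
    import threading
    import time

    pending = set()

    def schedule_unique(delay, fn, *args):
        task = (fn, args)
        if task in pending:
            return                      # an identical task is already queued; drop the duplicate
        pending.add(task)

        def runner():
            time.sleep(delay)
            pending.discard(task)
            fn(*args)

        threading.Thread(target=runner, daemon=True).start()

    # five "schema_change" events arriving inside the window trigger only one refresh
    for _ in range(5):
        schedule_unique(random.random() * 2, print, "refresh schema")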
@@ -2373,9 +2515,13 @@ class _Scheduler(object):

def refresh_schema_and_set_result(keyspace, table, usertype, control_conn, response_future):
    try:
        log.debug("Refreshing schema in response to schema change. Keyspace: %s; Table: %s, Type: %s",
                  keyspace, table, usertype)
        control_conn._refresh_schema(response_future._connection, keyspace, table, usertype)
        if control_conn._meta_refresh_enabled:
            log.debug("Refreshing schema in response to schema change. Keyspace: %s; Table: %s, Type: %s",
                      keyspace, table, usertype)
            control_conn._refresh_schema(response_future._connection, keyspace, table, usertype)
        else:
            log.debug("Skipping schema refresh in response to schema change because meta refresh is disabled; "
                      "Keyspace: %s; Table: %s, Type: %s", keyspace, table, usertype)
    except Exception:
        log.exception("Exception refreshing schema in response to schema change:")
        response_future.session.submit(

@@ -20,7 +20,7 @@ import io
import logging
import os
import sys
from threading import Event, RLock
from threading import Thread, Event, RLock
import time

if 'gevent.monkey' in sys.modules:

@@ -159,7 +159,7 @@ class Connection(object):
    in_flight = 0

    # A set of available request IDs. When using the v3 protocol or higher,
    # this will no initially include all request IDs in order to save memory,
    # this will not initially include all request IDs in order to save memory,
    # but the set will grow if it is exhausted.
    request_ids = None

@@ -172,6 +172,8 @@ class Connection(object):
    lock = None
    user_type_map = None

    msg_received = False

    is_control_connection = False
    _iobuf = None

@@ -401,6 +403,8 @@ class Connection(object):
            with self.lock:
                self.request_ids.append(stream_id)

        self.msg_received = True

        body = None
        try:
            # check that the protocol version is supported

@@ -673,6 +677,13 @@ class Connection(object):

        self.send_msg(query, request_id, process_result)

    @property
    def is_idle(self):
        return not self.msg_received

    def reset_idle(self):
        self.msg_received = False

    def __str__(self):
        status = ""
        if self.is_defunct:
@@ -732,3 +743,100 @@ class ResponseWaiter(object):
                raise OperationTimedOut()
        else:
            return self.responses


class HeartbeatFuture(object):
    def __init__(self, connection, owner):
        self._exception = None
        self._event = Event()
        self.connection = connection
        self.owner = owner
        log.debug("Sending options message heartbeat on idle connection (%s) %s",
                  id(connection), connection.host)
        with connection.lock:
            if connection.in_flight < connection.max_request_id:
                connection.in_flight += 1
                connection.send_msg(OptionsMessage(), connection.get_request_id(), self._options_callback)
            else:
                self._exception = Exception("Failed to send heartbeat because connection 'in_flight' exceeds threshold")
                self._event.set()

    def wait(self, timeout):
        self._event.wait(timeout)
        if self._event.is_set():
            if self._exception:
                raise self._exception
        else:
            raise OperationTimedOut()

    def _options_callback(self, response):
        if not isinstance(response, SupportedMessage):
            if isinstance(response, ConnectionException):
                self._exception = response
            else:
                self._exception = ConnectionException("Received unexpected response to OptionsMessage: %s"
                                                      % (response,))

        log.debug("Received options response on connection (%s) from %s",
                  id(self.connection), self.connection.host)
        self._event.set()


class ConnectionHeartbeat(Thread):

    def __init__(self, interval_sec, get_connection_holders):
        Thread.__init__(self, name="Connection heartbeat")
        self._interval = interval_sec
        self._get_connection_holders = get_connection_holders
        self._shutdown_event = Event()
        self.daemon = True
        self.start()

    def run(self):
        self._shutdown_event.wait(self._interval)
        while not self._shutdown_event.is_set():
            start_time = time.time()

            futures = []
            failed_connections = []
            try:
                for connections, owner in [(o.get_connections(), o) for o in self._get_connection_holders()]:
                    for connection in connections:
                        if not (connection.is_defunct or connection.is_closed):
                            if connection.is_idle:
                                try:
                                    futures.append(HeartbeatFuture(connection, owner))
                                except Exception:
                                    log.warning("Failed sending heartbeat message on connection (%s) to %s",
                                                id(connection), connection.host, exc_info=True)
                                    failed_connections.append((connection, owner))
                            else:
                                connection.reset_idle()
                        else:
                            # make sure the owner sees this defunt/closed connection
                            owner.return_connection(connection)

                for f in futures:
                    connection = f.connection
                    try:
                        f.wait(self._interval)
                        # TODO: move this, along with connection locks in pool, down into Connection
                        with connection.lock:
                            connection.in_flight -= 1
                        connection.reset_idle()
                    except Exception:
                        log.warning("Heartbeat failed for connection (%s) to %s",
                                    id(connection), connection.host, exc_info=True)
                        failed_connections.append((f.connection, f.owner))

                for connection, owner in failed_connections:
                    connection.defunct(Exception('Connection heartbeat failure'))
                    owner.return_connection(connection)
            except Exception:
                log.error("Failed connection heartbeat", exc_info=True)

            elapsed = time.time() - start_time
            self._shutdown_event.wait(max(self._interval - elapsed, 0.01))

    def stop(self):
        self._shutdown_event.set()
@@ -31,14 +31,14 @@ from __future__ import absolute_import  # to enable import io from stdlib
from binascii import unhexlify
import calendar
from collections import namedtuple
import datetime
from decimal import Decimal
import io
import re
import socket
import time
import datetime
import sys
from uuid import UUID
import warnings

import six
from six.moves import range

@@ -48,7 +48,7 @@ from cassandra.marshal import (int8_pack, int8_unpack,
                               int32_pack, int32_unpack, int64_pack, int64_unpack,
                               float_pack, float_unpack, double_pack, double_unpack,
                               varint_pack, varint_unpack)
from cassandra.util import OrderedDict, sortedset
from cassandra.util import OrderedMap, sortedset

apache_cassandra_type_prefix = 'org.apache.cassandra.db.marshal.'

@@ -58,10 +58,15 @@ if six.PY3:
    _time_types = frozenset((int,))
    _date_types = frozenset((int,))
    long = int

    def _name_from_hex_string(encoded_name):
        bin_str = unhexlify(encoded_name)
        return bin_str.decode('ascii')
else:
    _number_types = frozenset((int, long, float))
    _time_types = frozenset((int, long))
    _date_types = frozenset((int, long))
    _name_from_hex_string = unhexlify


def trim_if_startswith(s, prefix):

@@ -569,7 +574,8 @@ class DateType(_CassandraType):
                tval = time.strptime(val, tformat)
            except ValueError:
                continue
            return calendar.timegm(tval) + offset
            # scale seconds to millis for the raw value
            return (calendar.timegm(tval) + offset) * 1e3
        else:
            raise ValueError("can't interpret %r as a date" % (val,))

@@ -584,31 +590,16 @@ class DateType(_CassandraType):
    @staticmethod
    def serialize(v, protocol_version):
        try:
            converted = calendar.timegm(v.utctimetuple())
            converted = converted * 1e3 + getattr(v, 'microsecond', 0) / 1e3
            # v is datetime
            timestamp_seconds = calendar.timegm(v.utctimetuple())
            timestamp = timestamp_seconds * 1e3 + getattr(v, 'microsecond', 0) / 1e3
        except AttributeError:
            # Ints and floats are valid timestamps too
            if type(v) not in _number_types:
                raise TypeError('DateType arguments must be a datetime or timestamp')
            timestamp = v

            global _have_warned_about_timestamps
            if not _have_warned_about_timestamps:
                _have_warned_about_timestamps = True
                warnings.warn(
                    "timestamp columns in Cassandra hold a number of "
                    "milliseconds since the unix epoch. Currently, when executing "
                    "prepared statements, this driver multiplies timestamp "
                    "values by 1000 so that the result of time.time() "
                    "can be used directly. However, the driver cannot "
                    "match this behavior for non-prepared statements, "
                    "so the 2.0 version of the driver will no longer multiply "
                    "timestamps by 1000. It is suggested that you simply use "
                    "datetime.datetime objects for 'timestamp' values to avoid "
                    "any ambiguity and to guarantee a smooth upgrade of the "
                    "driver.")
            converted = v * 1e3

        return int64_pack(long(converted))
        return int64_pack(long(timestamp))


class TimestampType(DateType):
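The serialization hunk above removes the implicit 1000x scaling of numeric timestamps (PYTHON-204): a number passed for a timestamp column is now written through unchanged, so it must already be epoch milliseconds, while ``datetime`` values are still converted. A small sketch, where ``prepared`` is an assumed prepared statement with one timestamp parameter::

    import datetime
    import time

    session.execute(prepared, [datetime.datetime.utcnow()])   # converted to epoch millis by the driver
    session.execute(prepared, [int(time.time() * 1000)])      # numeric values must already be millis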
@@ -838,7 +829,7 @@ class MapType(_ParameterizedType):
            length = 2
        numelements = unpack(byts[:length])
        p = length
        themap = OrderedDict()
        themap = OrderedMap()
        for _ in range(numelements):
            key_len = unpack(byts[p:p + length])
            p += length

@@ -850,7 +841,7 @@ class MapType(_ParameterizedType):
            p += val_len
            key = subkeytype.from_binary(keybytes, protocol_version)
            val = subvaltype.from_binary(valbytes, protocol_version)
            themap[key] = val
            themap._insert(key, val)
        return themap

    @classmethod

@@ -929,37 +920,39 @@ class UserType(TupleType):
    typename = "'org.apache.cassandra.db.marshal.UserType'"

    _cache = {}
    _module = sys.modules[__name__]

    @classmethod
    def make_udt_class(cls, keyspace, udt_name, names_and_types, mapped_class):
        if six.PY2 and isinstance(udt_name, unicode):
            udt_name = udt_name.encode('utf-8')

        try:
            return cls._cache[(keyspace, udt_name)]
        except KeyError:
            fieldnames, types = zip(*names_and_types)
            field_names, types = zip(*names_and_types)
            instance = type(udt_name, (cls,), {'subtypes': types,
                                               'cassname': cls.cassname,
                                               'typename': udt_name,
                                               'fieldnames': fieldnames,
                                               'fieldnames': field_names,
                                               'keyspace': keyspace,
                                               'mapped_class': mapped_class})
                                               'mapped_class': mapped_class,
                                               'tuple_type': cls._make_registered_udt_namedtuple(keyspace, udt_name, field_names)})
            cls._cache[(keyspace, udt_name)] = instance
            return instance

    @classmethod
    def apply_parameters(cls, subtypes, names):
        keyspace = subtypes[0]
        udt_name = unhexlify(subtypes[1].cassname)
        field_names = [unhexlify(encoded_name) for encoded_name in names[2:]]
        udt_name = _name_from_hex_string(subtypes[1].cassname)
        field_names = [_name_from_hex_string(encoded_name) for encoded_name in names[2:]]
        assert len(field_names) == len(subtypes[2:])
        return type(udt_name, (cls,), {'subtypes': subtypes[2:],
                                       'cassname': cls.cassname,
                                       'typename': udt_name,
                                       'fieldnames': field_names,
                                       'keyspace': keyspace,
                                       'mapped_class': None})
                                       'mapped_class': None,
                                       'tuple_type': namedtuple(udt_name, field_names)})

    @classmethod
    def cql_parameterized_type(cls):

@@ -991,8 +984,7 @@ class UserType(TupleType):
        if cls.mapped_class:
            return cls.mapped_class(**dict(zip(cls.fieldnames, values)))
        else:
            Result = namedtuple(cls.typename, cls.fieldnames)
            return Result(*values)
            return cls.tuple_type(*values)

    @classmethod
    def serialize_safe(cls, val, protocol_version):

@@ -1008,6 +1000,18 @@ class UserType(TupleType):
                buf.write(int32_pack(-1))
        return buf.getvalue()

    @classmethod
    def _make_registered_udt_namedtuple(cls, keyspace, name, field_names):
        # this is required to make the type resolvable via this module...
        # required when unregistered udts are pickled for use as keys in
        # util.OrderedMap
        qualified_name = "%s_%s" % (keyspace, name)
        nt = getattr(cls._module, qualified_name, None)
        if not nt:
            nt = namedtuple(qualified_name, field_names)
            setattr(cls._module, qualified_name, nt)
        return nt


class CompositeType(_ParameterizedType):
    typename = "'org.apache.cassandra.db.marshal.CompositeType'"

@@ -28,7 +28,7 @@ import types
from uuid import UUID
import six

from cassandra.util import OrderedDict, sortedset
from cassandra.util import OrderedDict, OrderedMap, sortedset

if six.PY3:
    long = int

@@ -77,6 +77,7 @@ class Encoder(object):
            datetime.time: self.cql_encode_time,
            dict: self.cql_encode_map_collection,
            OrderedDict: self.cql_encode_map_collection,
            OrderedMap: self.cql_encode_map_collection,
            list: self.cql_encode_list_collection,
            tuple: self.cql_encode_list_collection,
            set: self.cql_encode_set_collection,
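The UserType changes above cache one namedtuple per UDT and map results onto a registered class when one is supplied through ``Cluster.register_user_type`` (whose signature appears earlier in this diff). A minimal sketch, with hypothetical keyspace, type, and field names::

    class Address(object):
        def __init__(self, street, zipcode):
            self.street = street
            self.zipcode = zipcode

    cluster.register_user_type('mykeyspace', 'address', Address)
    # rows selecting an 'address' column now come back as Address instances;
    # unregistered UDTs are returned as the namedtuples built by make_udt_class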
cassandra/io/eventletreactor.py (new file, 193 lines)
@@ -0,0 +1,193 @@
# Copyright 2014 Symantec Corporation
# Copyright 2013-2014 DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Originally derived from MagnetoDB source:
# https://github.com/stackforge/magnetodb/blob/2015.1.0b1/magnetodb/common/cassandra/io/eventletreactor.py

from collections import defaultdict
from errno import EALREADY, EINPROGRESS, EWOULDBLOCK, EINVAL
import eventlet
from eventlet.green import select, socket
from eventlet.queue import Queue
from functools import partial
import logging
import os
from threading import Event

from six.moves import xrange

from cassandra import OperationTimedOut
from cassandra.connection import Connection, ConnectionShutdown
from cassandra.protocol import RegisterMessage


log = logging.getLogger(__name__)


def is_timeout(err):
    return (
        err in (EINPROGRESS, EALREADY, EWOULDBLOCK) or
        (err == EINVAL and os.name in ('nt', 'ce'))
    )


class EventletConnection(Connection):
    """
    An implementation of :class:`.Connection` that utilizes ``eventlet``.
    """

    _total_reqd_bytes = 0
    _read_watcher = None
    _write_watcher = None
    _socket = None

    @classmethod
    def initialize_reactor(cls):
        eventlet.monkey_patch()

    @classmethod
    def factory(cls, *args, **kwargs):
        timeout = kwargs.pop('timeout', 5.0)
        conn = cls(*args, **kwargs)
        conn.connected_event.wait(timeout)
        if conn.last_error:
            raise conn.last_error
        elif not conn.connected_event.is_set():
            conn.close()
            raise OperationTimedOut("Timed out creating connection")
        else:
            return conn

    def __init__(self, *args, **kwargs):
        Connection.__init__(self, *args, **kwargs)

        self.connected_event = Event()
        self._write_queue = Queue()

        self._callbacks = {}
        self._push_watchers = defaultdict(set)

        sockerr = None
        addresses = socket.getaddrinfo(
            self.host, self.port, socket.AF_UNSPEC, socket.SOCK_STREAM
        )
        for (af, socktype, proto, canonname, sockaddr) in addresses:
            try:
                self._socket = socket.socket(af, socktype, proto)
                self._socket.settimeout(1.0)
                self._socket.connect(sockaddr)
                sockerr = None
                break
            except socket.error as err:
                sockerr = err
        if sockerr:
            raise socket.error(
                sockerr.errno,
                "Tried connecting to %s. Last error: %s" % (
                    [a[4] for a in addresses], sockerr.strerror)
            )

        if self.sockopts:
            for args in self.sockopts:
                self._socket.setsockopt(*args)

        self._read_watcher = eventlet.spawn(lambda: self.handle_read())
        self._write_watcher = eventlet.spawn(lambda: self.handle_write())
        self._send_options_message()

    def close(self):
        with self.lock:
            if self.is_closed:
                return
            self.is_closed = True

        log.debug("Closing connection (%s) to %s" % (id(self), self.host))

        cur_gthread = eventlet.getcurrent()

        if self._read_watcher and self._read_watcher != cur_gthread:
            self._read_watcher.kill()
        if self._write_watcher and self._write_watcher != cur_gthread:
            self._write_watcher.kill()
        if self._socket:
            self._socket.close()
        log.debug("Closed socket to %s" % (self.host,))

        if not self.is_defunct:
            self.error_all_callbacks(
                ConnectionShutdown("Connection to %s was closed" % self.host))
            # don't leave in-progress operations hanging
            self.connected_event.set()

    def handle_close(self):
        log.debug("connection closed by server")
        self.close()

    def handle_write(self):
        while True:
            try:
                next_msg = self._write_queue.get()
                self._socket.sendall(next_msg)
            except socket.error as err:
                log.debug("Exception during socket send for %s: %s", self, err)
                self.defunct(err)
                return  # Leave the write loop

    def handle_read(self):
        run_select = partial(select.select, (self._socket,), (), ())
        while True:
            try:
                run_select()
            except Exception as exc:
                if not self.is_closed:
                    log.debug("Exception during read select() for %s: %s",
                              self, exc)
                    self.defunct(exc)
                return

            try:
                buf = self._socket.recv(self.in_buffer_size)
                self._iobuf.write(buf)
            except socket.error as err:
                if not is_timeout(err):
                    log.debug("Exception during socket recv for %s: %s",
                              self, err)
                    self.defunct(err)
                return  # leave the read loop

            if self._iobuf.tell():
                self.process_io_buffer()
            else:
                log.debug("Connection %s closed by server", self)
                self.close()
                return

    def push(self, data):
        chunk_size = self.out_buffer_size
        for i in xrange(0, len(data), chunk_size):
            self._write_queue.put(data[i:i + chunk_size])

    def register_watcher(self, event_type, callback, register_timeout=None):
        self._push_watchers[event_type].add(callback)
        self.wait_for_response(
            RegisterMessage(event_list=[event_type]),
            timeout=register_timeout)

    def register_watchers(self, type_callback_dict, register_timeout=None):
        for event_type, callback in type_callback_dict.items():
            self._push_watchers[event_type].add(callback)
        self.wait_for_response(
            RegisterMessage(event_list=type_callback_dict.keys()),
            timeout=register_timeout)
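A minimal sketch of using the new reactor; per the ``cluster.py`` hunk earlier in this diff it is selected automatically when the interpreter is already eventlet monkey-patched, so the explicit ``connection_class`` assignment below (assumed Cluster attribute) is optional::

    import eventlet
    eventlet.monkey_patch()

    from cassandra.cluster import Cluster
    from cassandra.io.eventletreactor import EventletConnection

    cluster = Cluster()
    cluster.connection_class = EventletConnection  # normally picked automatically
    session = cluster.connect()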
@@ -211,7 +211,7 @@ class Metadata(object):
        return KeyspaceMetadata(name, durable_writes, strategy_class, strategy_options)

    def _build_usertype(self, keyspace, usertype_row):
        type_classes = map(types.lookup_casstype, usertype_row['field_types'])
        type_classes = list(map(types.lookup_casstype, usertype_row['field_types']))
        return UserType(usertype_row['keyspace_name'], usertype_row['type_name'],
                        usertype_row['field_names'], type_classes)

@@ -355,15 +355,21 @@ class HostConnection(object):
            return

        def connection_finished_setting_keyspace(conn, error):
            self.return_connection(conn)
            errors = [] if not error else [error]
            callback(self, errors)

        self._connection.set_keyspace_async(keyspace, connection_finished_setting_keyspace)

    def get_connections(self):
        c = self._connection
        return [c] if c else []

    def get_state(self):
        have_conn = self._connection is not None
        in_flight = self._connection.in_flight if have_conn else 0
        return "shutdown: %s, open: %s, in_flights: %s" % (self.is_shutdown, have_conn, in_flight)
        connection = self._connection
        open_count = 1 if connection and not (connection.is_closed or connection.is_defunct) else 0
        in_flights = [connection.in_flight] if connection else []
        return {'shutdown': self.is_shutdown, 'open_count': open_count, 'in_flights': in_flights}


_MAX_SIMULTANEOUS_CREATION = 1

@@ -683,6 +689,7 @@ class HostConnectionPool(object):
            return

        def connection_finished_setting_keyspace(conn, error):
            self.return_connection(conn)
            remaining_callbacks.remove(conn)
            if error:
                errors.append(error)

@@ -693,6 +700,9 @@ class HostConnectionPool(object):
        for conn in self._connections:
            conn.set_keyspace_async(keyspace, connection_finished_setting_keyspace)

    def get_connections(self):
        return self._connections

    def get_state(self):
        in_flights = ", ".join([str(c.in_flight) for c in self._connections])
        return "shutdown: %s, open_count: %d, in_flights: %s" % (self.is_shutdown, self.open_count, in_flights)
        in_flights = [c.in_flight for c in self._connections]
        return {'shutdown': self.is_shutdown, 'open_count': self.open_count, 'in_flights': in_flights}
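With the pool hunks above, ``get_state`` returns a dict instead of a formatted string, which makes ``Session.get_pool_state()`` (added earlier in this diff) easy to inspect programmatically; a small sketch::

    session = cluster.connect()
    for host, state in session.get_pool_state().items():
        print(host, state['open_count'], state['in_flights'], state['shutdown'])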
@@ -632,14 +632,14 @@ class ResultMessage(_MessageType):
            typeclass = typeclass.apply_parameters((keysubtype, valsubtype))
        elif typeclass == TupleType:
            num_items = read_short(f)
            types = tuple(cls.read_type(f, user_type_map) for _ in xrange(num_items))
            types = tuple(cls.read_type(f, user_type_map) for _ in range(num_items))
            typeclass = typeclass.apply_parameters(types)
        elif typeclass == UserType:
            ks = read_string(f)
            udt_name = read_string(f)
            num_fields = read_short(f)
            names_and_types = tuple((read_string(f), cls.read_type(f, user_type_map))
                                    for _ in xrange(num_fields))
                                    for _ in range(num_fields))
            mapped_class = user_type_map.get(ks, {}).get(udt_name)
            typeclass = typeclass.make_udt_class(
                ks, udt_name, names_and_types, mapped_class)

@@ -500,13 +500,13 @@ class BoundStatement(Statement):

            try:
                self.values.append(col_type.serialize(value, proto_version))
            except (TypeError, struct.error):
            except (TypeError, struct.error) as exc:
                col_name = col_spec[2]
                expected_type = col_type
                actual_type = type(value)

                message = ('Received an argument of invalid type for column "%s". '
                           'Expected: %s, Got: %s' % (col_name, expected_type, actual_type))
                           'Expected: %s, Got: %s; (%s)' % (col_name, expected_type, actual_type, exc))
                raise TypeError(message)

        return self
@@ -555,3 +555,101 @@ except ImportError:
                if item in other:
                    isect.add(item)
            return isect

from collections import Mapping
import six
from six.moves import cPickle


class OrderedMap(Mapping):
    '''
    An ordered map that accepts non-hashable types for keys. It also maintains the
    insertion order of items, behaving as OrderedDict in that regard. These maps
    are constructed and read just as normal mapping types, exept that they may
    contain arbitrary collections and other non-hashable items as keys::

        >>> od = OrderedMap([({'one': 1, 'two': 2}, 'value'),
        ...                  ({'three': 3, 'four': 4}, 'value2')])
        >>> list(od.keys())
        [{'two': 2, 'one': 1}, {'three': 3, 'four': 4}]
        >>> list(od.values())
        ['value', 'value2']

    These constructs are needed to support nested collections in Cassandra 2.1.3+,
    where frozen collections can be specified as parameters to others\*::

        CREATE TABLE example (
            ...
            value map<frozen<map<int, int>>, double>
            ...
        )

    This class dervies from the (immutable) Mapping API. Objects in these maps
    are not intended be modified.

    \* Note: Because of the way Cassandra encodes nested types, when using the
    driver with nested collections, :attr:`~.Cluster.protocol_version` must be 3
    or higher.

    '''
    def __init__(self, *args, **kwargs):
        if len(args) > 1:
            raise TypeError('expected at most 1 arguments, got %d' % len(args))

        self._items = []
        self._index = {}
        if args:
            e = args[0]
            if callable(getattr(e, 'keys', None)):
                for k in e.keys():
                    self._items.append((k, e[k]))
            else:
                for k, v in e:
                    self._insert(k, v)

        for k, v in six.iteritems(kwargs):
            self._insert(k, v)

    def _insert(self, key, value):
        flat_key = self._serialize_key(key)
        i = self._index.get(flat_key, -1)
        if i >= 0:
            self._items[i] = (key, value)
        else:
            self._items.append((key, value))
            self._index[flat_key] = len(self._items) - 1

    def __getitem__(self, key):
        index = self._index[self._serialize_key(key)]
        return self._items[index][1]

    def __iter__(self):
        for i in self._items:
            yield i[0]

    def __len__(self):
        return len(self._items)

    def __eq__(self, other):
        if isinstance(other, OrderedMap):
            return self._items == other._items
        try:
            d = dict(other)
            return len(d) == len(self._items) and all(i[1] == d[i[0]] for i in self._items)
        except KeyError:
            return False
        except TypeError:
            pass
        return NotImplemented

    def __repr__(self):
        return '%s([%s])' % (
            self.__class__.__name__,
            ', '.join("(%r, %r)" % (k, v) for k, v in self._items))

    def __str__(self):
        return '{%s}' % ', '.join("%s: %s" % (k, v) for k, v in self._items)

    @staticmethod
    def _serialize_key(key):
        return cPickle.dumps(key)
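A short sketch of where ``OrderedMap`` shows up in practice, assuming the hypothetical ``example`` table from the docstring above and protocol version 3 or higher::

    cluster = Cluster(protocol_version=3)
    session = cluster.connect('mykeyspace')           # keyspace name is hypothetical
    rows = session.execute('SELECT value FROM example')
    for row in rows:
        # row.value is an OrderedMap; its keys are themselves (frozen) maps
        for key_map, score in row.value.items():
            print(key_map, score)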
@@ -14,3 +14,9 @@

.. autoclass:: PlainTextAuthenticator
    :members:

.. autoclass:: SaslAuthProvider
    :members:

.. autoclass:: SaslAuthenticator
    :members:

@@ -39,6 +39,12 @@

.. autoattribute:: control_connection_timeout

.. autoattribute:: idle_heartbeat_interval

.. autoattribute:: schema_event_refresh_window

.. autoattribute:: topology_event_refresh_window

.. automethod:: connect

.. automethod:: shutdown

@@ -57,6 +63,13 @@

.. automethod:: set_max_connections_per_host

.. automethod:: refresh_schema

.. automethod:: refresh_nodes

.. automethod:: set_meta_refresh_enabled


.. autoclass:: Session ()

    .. autoattribute:: default_timeout
docs/api/cassandra/io/eventletreactor.rst (new file, 7 lines)
@@ -0,0 +1,7 @@
``cassandra.io.eventletreactor`` - ``eventlet``-compatible Connection
=====================================================================

.. module:: cassandra.io.eventletreactor

.. autoclass:: EventletConnection
    :members:
docs/api/cassandra/util.rst (new file, 7 lines)
@@ -0,0 +1,7 @@
``cassandra.util`` - Utilities
===================================

.. module:: cassandra.util

.. autoclass:: OrderedMap
    :members:
@@ -16,7 +16,9 @@ API Documentation
    cassandra/decoder
    cassandra/concurrent
    cassandra/connection
    cassandra/util
    cassandra/io/asyncorereactor
    cassandra/io/eventletreactor
    cassandra/io/libevreactor
    cassandra/io/geventreactor
    cassandra/io/twistedreactor
@@ -34,9 +34,11 @@ to be explicit.
Custom Authenticators
^^^^^^^^^^^^^^^^^^^^^
If you're using something other than Cassandra's ``PasswordAuthenticator``,
you may need to create your own subclasses of :class:`~.AuthProvider` and
:class:`~.Authenticator`. You can use :class:`~.PlainTextAuthProvider`
and :class:`~.PlainTextAuthenticator` as example implementations.
:class:`~.SaslAuthProvider` is provided for generic SASL authentication mechanisms,
utilizing the ``pure-sasl`` package.
If these do not suit your needs, you may need to create your own subclasses of
:class:`~.AuthProvider` and :class:`~.Authenticator`. You can use the Sasl classes
as example implementations.

Protocol v1 Authentication
^^^^^^^^^^^^^^^^^^^^^^^^^^
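As a rough illustration of the subclassing route mentioned in the doc change above (method names follow the PlainText implementations in ``cassandra.auth``; the token scheme itself is made up)::

    from cassandra.auth import AuthProvider, Authenticator

    class TokenAuthProvider(AuthProvider):
        def __init__(self, token):
            self.token = token

        def new_authenticator(self, host):
            return TokenAuthenticator(self.token)

    class TokenAuthenticator(Authenticator):
        def __init__(self, token):
            self.token = token

        def initial_response(self):
            return self.token           # sent in response to the server's AUTHENTICATE message

        def evaluate_challenge(self, challenge):
            return None                  # no multi-round challenge in this made-up scheme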
setup.py
@@ -20,6 +20,11 @@ if __name__ == '__main__' and sys.argv[1] == "gevent_nosetests":
    from gevent.monkey import patch_all
    patch_all()

if __name__ == '__main__' and sys.argv[1] == "eventlet_nosetests":
    print("Running eventlet tests")
    from eventlet import monkey_patch
    monkey_patch()

import ez_setup
ez_setup.use_setuptools()

@@ -51,10 +56,14 @@ try:
    from nose.commands import nosetests
except ImportError:
    gevent_nosetests = None
    eventlet_nosetests = None
else:
    class gevent_nosetests(nosetests):
        description = "run nosetests with gevent monkey patching"

    class eventlet_nosetests(nosetests):
        description = "run nosetests with eventlet monkey patching"


class DocCommand(Command):

@@ -174,10 +183,14 @@ On OSX, via homebrew:


def run_setup(extensions):

    kw = {'cmdclass': {'doc': DocCommand}}
    if gevent_nosetests is not None:
        kw['cmdclass']['gevent_nosetests'] = gevent_nosetests

    if eventlet_nosetests is not None:
        kw['cmdclass']['eventlet_nosetests'] = eventlet_nosetests

    if extensions:
        kw['cmdclass']['build_ext'] = build_extensions
        kw['ext_modules'] = extensions
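With the hunks above, the new test command mirrors the existing gevent one and can be invoked as ``python setup.py eventlet_nosetests``.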
@@ -13,6 +13,7 @@
# limitations under the License.

import logging
import sys

log = logging.getLogger()
log.setLevel('DEBUG')
@@ -21,3 +22,18 @@ if not log.handlers:
    handler = logging.StreamHandler()
    handler.setFormatter(logging.Formatter('%(asctime)s %(levelname)s [%(module)s:%(lineno)s]: %(message)s'))
    log.addHandler(handler)


def is_gevent_monkey_patched():
    return 'gevent.monkey' in sys.modules


def is_eventlet_monkey_patched():
    if 'eventlet.patcher' in sys.modules:
        import eventlet
        return eventlet.patcher.is_monkey_patched('socket')
    return False


def is_monkey_patched():
    return is_gevent_monkey_patched() or is_eventlet_monkey_patched()

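A quick, hypothetical check of the helpers above from the repository root (assumes ``eventlet`` is installed)::

    from tests import is_monkey_patched, is_eventlet_monkey_patched

    print(is_monkey_patched())             # False in an unpatched interpreter

    import eventlet
    eventlet.monkey_patch()
    print(is_eventlet_monkey_patched())    # True once eventlet has patched the socket module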
@@ -13,11 +13,13 @@
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import time
|
||||
|
||||
from tests.integration import use_singledc, get_cluster, remove_cluster, PROTOCOL_VERSION
|
||||
from cassandra.cluster import Cluster, NoHostAvailable
|
||||
from cassandra.auth import PlainTextAuthProvider, SASLClient, SaslAuthProvider
|
||||
|
||||
from tests.integration import use_singledc, get_cluster, remove_cluster, PROTOCOL_VERSION
|
||||
from tests.integration.util import assert_quiescent_pool_state
|
||||
|
||||
try:
|
||||
import unittest2 as unittest
|
||||
@@ -35,7 +37,10 @@ def setup_module():
|
||||
'authorizer': 'CassandraAuthorizer'}
|
||||
ccm_cluster.set_configuration_options(config_options)
|
||||
log.debug("Starting ccm test cluster with %s", config_options)
|
||||
ccm_cluster.start(wait_for_binary_proto=True)
|
||||
ccm_cluster.start(wait_for_binary_proto=True, wait_other_notice=True)
|
||||
# there seems to be some race, with some versions of C* taking longer to
|
||||
# get the auth (and default user) setup. Sleep here to give it a chance
|
||||
time.sleep(2)
|
||||
|
||||
|
||||
def teardown_module():
|
||||
@@ -59,12 +64,13 @@ class AuthenticationTests(unittest.TestCase):
|
||||
:return: authentication object suitable for Cluster.connect()
|
||||
"""
|
||||
if PROTOCOL_VERSION < 2:
|
||||
return lambda(hostname): dict(username=username, password=password)
|
||||
return lambda hostname: dict(username=username, password=password)
|
||||
else:
|
||||
return PlainTextAuthProvider(username=username, password=password)
|
||||
|
||||
def cluster_as(self, usr, pwd):
|
||||
return Cluster(protocol_version=PROTOCOL_VERSION,
|
||||
idle_heartbeat_interval=0,
|
||||
auth_provider=self.get_authentication_provider(username=usr, password=pwd))
|
||||
|
||||
def test_auth_connect(self):
|
||||
@@ -77,9 +83,11 @@ class AuthenticationTests(unittest.TestCase):
|
||||
cluster = self.cluster_as(user, passwd)
|
||||
session = cluster.connect()
|
||||
self.assertTrue(session.execute('SELECT release_version FROM system.local'))
|
||||
assert_quiescent_pool_state(self, cluster)
|
||||
cluster.shutdown()
|
||||
|
||||
root_session.execute('DROP USER %s', user)
|
||||
assert_quiescent_pool_state(self, root_session.cluster)
|
||||
root_session.cluster.shutdown()
|
||||
|
||||
def test_connect_wrong_pwd(self):
|
||||
@@ -88,6 +96,8 @@ class AuthenticationTests(unittest.TestCase):
|
||||
'.*AuthenticationFailed.*Bad credentials.*Username and/or '
|
||||
'password are incorrect.*',
|
||||
cluster.connect)
|
||||
assert_quiescent_pool_state(self, cluster)
|
||||
cluster.shutdown()
|
||||
|
||||
def test_connect_wrong_username(self):
|
||||
cluster = self.cluster_as('wrong_user', 'cassandra')
|
||||
@@ -95,6 +105,8 @@ class AuthenticationTests(unittest.TestCase):
|
||||
'.*AuthenticationFailed.*Bad credentials.*Username and/or '
|
||||
'password are incorrect.*',
|
||||
cluster.connect)
|
||||
assert_quiescent_pool_state(self, cluster)
|
||||
cluster.shutdown()
|
||||
|
||||
def test_connect_empty_pwd(self):
|
||||
cluster = self.cluster_as('Cassandra', '')
|
||||
@@ -102,12 +114,16 @@ class AuthenticationTests(unittest.TestCase):
|
||||
'.*AuthenticationFailed.*Bad credentials.*Username and/or '
|
||||
'password are incorrect.*',
|
||||
cluster.connect)
|
||||
assert_quiescent_pool_state(self, cluster)
|
||||
cluster.shutdown()
|
||||
|
||||
def test_connect_no_auth_provider(self):
|
||||
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||
self.assertRaisesRegexp(NoHostAvailable,
|
||||
'.*AuthenticationFailed.*Remote end requires authentication.*',
|
||||
cluster.connect)
|
||||
assert_quiescent_pool_state(self, cluster)
|
||||
cluster.shutdown()
|
||||
|
||||
|
||||
class SaslAuthenticatorTests(AuthenticationTests):
|
||||
|
||||
@@ -12,18 +12,26 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from tests.integration import use_singledc, PROTOCOL_VERSION
|
||||
|
||||
try:
|
||||
import unittest2 as unittest
|
||||
except ImportError:
|
||||
import unittest # noqa
|
||||
|
||||
import cassandra
|
||||
from cassandra.query import SimpleStatement, TraceUnavailable
|
||||
from cassandra.policies import RoundRobinPolicy, ExponentialReconnectionPolicy, RetryPolicy, SimpleConvictionPolicy, HostDistance
|
||||
from collections import deque
|
||||
from mock import patch
|
||||
import time
|
||||
from uuid import uuid4
|
||||
|
||||
import cassandra
|
||||
from cassandra.cluster import Cluster, NoHostAvailable
|
||||
from cassandra.concurrent import execute_concurrent
|
||||
from cassandra.policies import (RoundRobinPolicy, ExponentialReconnectionPolicy,
|
||||
RetryPolicy, SimpleConvictionPolicy, HostDistance,
|
||||
WhiteListRoundRobinPolicy)
|
||||
from cassandra.query import SimpleStatement, TraceUnavailable
|
||||
|
||||
from tests.integration import use_singledc, PROTOCOL_VERSION, get_server_versions
|
||||
from tests.integration.util import assert_quiescent_pool_state
|
||||
|
||||
|
||||
def setup_module():
|
||||
@@ -66,6 +74,8 @@ class ClusterTests(unittest.TestCase):
|
||||
result = session.execute("SELECT * FROM clustertests.cf0")
|
||||
self.assertEqual([('a', 'b', 'c')], result)
|
||||
|
||||
session.execute("DROP KEYSPACE clustertests")
|
||||
|
||||
cluster.shutdown()
|
||||
|
||||
def test_connect_on_keyspace(self):
|
||||
@@ -202,6 +212,159 @@ class ClusterTests(unittest.TestCase):
|
||||
|
||||
self.assertIn("newkeyspace", cluster.metadata.keyspaces)
|
||||
|
||||
session.execute("DROP KEYSPACE newkeyspace")
|
||||
|
||||
def test_refresh_schema(self):
|
||||
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||
session = cluster.connect()
|
||||
|
||||
original_meta = cluster.metadata.keyspaces
|
||||
# full schema refresh, with wait
|
||||
cluster.refresh_schema()
|
||||
self.assertIsNot(original_meta, cluster.metadata.keyspaces)
|
||||
self.assertEqual(original_meta, cluster.metadata.keyspaces)
|
||||
|
||||
session.shutdown()
|
||||
|
||||
def test_refresh_schema_keyspace(self):
|
||||
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||
session = cluster.connect()
|
||||
|
||||
original_meta = cluster.metadata.keyspaces
|
||||
original_system_meta = original_meta['system']
|
||||
|
||||
# only refresh one keyspace
|
||||
cluster.refresh_schema(keyspace='system')
|
||||
current_meta = cluster.metadata.keyspaces
|
||||
self.assertIs(original_meta, current_meta)
|
||||
current_system_meta = current_meta['system']
|
||||
self.assertIsNot(original_system_meta, current_system_meta)
|
||||
self.assertEqual(original_system_meta.as_cql_query(), current_system_meta.as_cql_query())
|
||||
session.shutdown()
|
||||
|
||||
def test_refresh_schema_table(self):
|
||||
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||
session = cluster.connect()
|
||||
|
||||
original_meta = cluster.metadata.keyspaces
|
||||
original_system_meta = original_meta['system']
|
||||
original_system_schema_meta = original_system_meta.tables['schema_columnfamilies']
|
||||
|
||||
# only refresh one table
|
||||
cluster.refresh_schema(keyspace='system', table='schema_columnfamilies')
|
||||
current_meta = cluster.metadata.keyspaces
|
||||
current_system_meta = current_meta['system']
|
||||
current_system_schema_meta = current_system_meta.tables['schema_columnfamilies']
|
||||
self.assertIs(original_meta, current_meta)
|
||||
self.assertIs(original_system_meta, current_system_meta)
|
||||
self.assertIsNot(original_system_schema_meta, current_system_schema_meta)
|
||||
self.assertEqual(original_system_schema_meta.as_cql_query(), current_system_schema_meta.as_cql_query())
|
||||
session.shutdown()
|
||||
|
||||
def test_refresh_schema_type(self):
|
||||
if get_server_versions()[0] < (2, 1, 0):
|
||||
raise unittest.SkipTest('UDTs were introduced in Cassandra 2.1')
|
||||
|
||||
if PROTOCOL_VERSION < 3:
|
||||
raise unittest.SkipTest('UDTs are not specified in change events for protocol v2')
|
||||
# We may want to refresh types on keyspace change events in that case(?)
|
||||
|
||||
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||
session = cluster.connect()
|
||||
|
||||
keyspace_name = 'test1rf'
|
||||
type_name = self._testMethodName
|
||||
|
||||
session.execute('CREATE TYPE IF NOT EXISTS %s.%s (one int, two text)' % (keyspace_name, type_name))
|
||||
original_meta = cluster.metadata.keyspaces
|
||||
original_test1rf_meta = original_meta[keyspace_name]
|
||||
original_type_meta = original_test1rf_meta.user_types[type_name]
|
||||
|
||||
# only refresh one type
|
||||
cluster.refresh_schema(keyspace='test1rf', usertype=type_name)
|
||||
current_meta = cluster.metadata.keyspaces
|
||||
current_test1rf_meta = current_meta[keyspace_name]
|
||||
current_type_meta = current_test1rf_meta.user_types[type_name]
|
||||
self.assertIs(original_meta, current_meta)
|
||||
self.assertIs(original_test1rf_meta, current_test1rf_meta)
|
||||
self.assertIsNot(original_type_meta, current_type_meta)
|
||||
self.assertEqual(original_type_meta.as_cql_query(), current_type_meta.as_cql_query())
|
||||
session.shutdown()
|
||||
|
||||
def test_refresh_schema_no_wait(self):
|
||||
|
||||
contact_points = ['127.0.0.1']
|
||||
cluster = Cluster(protocol_version=PROTOCOL_VERSION, max_schema_agreement_wait=10,
|
||||
contact_points=contact_points, load_balancing_policy=WhiteListRoundRobinPolicy(contact_points))
|
||||
session = cluster.connect()
|
||||
|
||||
schema_ver = session.execute("SELECT schema_version FROM system.local WHERE key='local'")[0][0]
|
||||
|
||||
# create a schema disagreement
|
||||
session.execute("UPDATE system.local SET schema_version=%s WHERE key='local'", (uuid4(),))
|
||||
|
||||
try:
|
||||
agreement_timeout = 1
|
||||
|
||||
# cluster agreement wait exceeded
|
||||
c = Cluster(protocol_version=PROTOCOL_VERSION, max_schema_agreement_wait=agreement_timeout)
|
||||
start_time = time.time()
|
||||
s = c.connect()
|
||||
end_time = time.time()
|
||||
self.assertGreaterEqual(end_time - start_time, agreement_timeout)
|
||||
self.assertTrue(c.metadata.keyspaces)
|
||||
|
||||
# cluster agreement wait used for refresh
|
||||
original_meta = c.metadata.keyspaces
|
||||
start_time = time.time()
|
||||
self.assertRaisesRegexp(Exception, r"Schema was not refreshed.*", c.refresh_schema)
|
||||
end_time = time.time()
|
||||
self.assertGreaterEqual(end_time - start_time, agreement_timeout)
|
||||
self.assertIs(original_meta, c.metadata.keyspaces)
|
||||
|
||||
# refresh wait overrides cluster value
|
||||
original_meta = c.metadata.keyspaces
|
||||
start_time = time.time()
|
||||
c.refresh_schema(max_schema_agreement_wait=0)
|
||||
end_time = time.time()
|
||||
self.assertLess(end_time - start_time, agreement_timeout)
|
||||
self.assertIsNot(original_meta, c.metadata.keyspaces)
|
||||
self.assertEqual(original_meta, c.metadata.keyspaces)
|
||||
|
||||
s.shutdown()
|
||||
|
||||
refresh_threshold = 0.5
|
||||
# cluster agreement bypass
|
||||
c = Cluster(protocol_version=PROTOCOL_VERSION, max_schema_agreement_wait=0)
|
||||
start_time = time.time()
|
||||
s = c.connect()
|
||||
end_time = time.time()
|
||||
self.assertLess(end_time - start_time, refresh_threshold)
|
||||
self.assertTrue(c.metadata.keyspaces)
|
||||
|
||||
# cluster agreement wait used for refresh
|
||||
original_meta = c.metadata.keyspaces
|
||||
start_time = time.time()
|
||||
c.refresh_schema()
|
||||
end_time = time.time()
|
||||
self.assertLess(end_time - start_time, refresh_threshold)
|
||||
self.assertIsNot(original_meta, c.metadata.keyspaces)
|
||||
self.assertEqual(original_meta, c.metadata.keyspaces)
|
||||
|
||||
# refresh wait overrides cluster value
|
||||
original_meta = c.metadata.keyspaces
|
||||
start_time = time.time()
|
||||
self.assertRaisesRegexp(Exception, r"Schema was not refreshed.*", c.refresh_schema, max_schema_agreement_wait=agreement_timeout)
|
||||
end_time = time.time()
|
||||
self.assertGreaterEqual(end_time - start_time, agreement_timeout)
|
||||
self.assertIs(original_meta, c.metadata.keyspaces)
|
||||
|
||||
s.shutdown()
|
||||
finally:
|
||||
session.execute("UPDATE system.local SET schema_version=%s WHERE key='local'", (schema_ver,))
|
||||
|
||||
session.shutdown()
|
||||
|
||||
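The refresh variants exercised above, written out as direct calls; every keyword shown is one these tests themselves pass::

    cluster.refresh_schema()                                     # full refresh, waits for schema agreement
    cluster.refresh_schema(keyspace='system')                    # refresh a single keyspace
    cluster.refresh_schema(keyspace='system', table='schema_columnfamilies')
    cluster.refresh_schema(keyspace='test1rf', usertype='mytype')  # 'mytype' is a placeholder UDT name
    cluster.refresh_schema(max_schema_agreement_wait=0)          # bypass the agreement wait for this call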
def test_trace(self):
|
||||
"""
|
||||
Ensure trace can be requested for async and non-async queries
|
||||
@@ -271,3 +434,115 @@ class ClusterTests(unittest.TestCase):
|
||||
|
||||
self.assertIn(query, str(future))
|
||||
self.assertIn('result', str(future))
|
||||
|
||||
def test_idle_heartbeat(self):
|
||||
interval = 1
|
||||
cluster = Cluster(protocol_version=PROTOCOL_VERSION, idle_heartbeat_interval=interval)
|
||||
if PROTOCOL_VERSION < 3:
|
||||
cluster.set_core_connections_per_host(HostDistance.LOCAL, 1)
|
||||
session = cluster.connect()
|
||||
|
||||
# This test relies on impl details of connection req id management to see if heartbeats
|
||||
# are being sent. May need update if impl is changed
|
||||
connection_request_ids = {}
|
||||
for h in cluster.get_connection_holders():
|
||||
for c in h.get_connections():
|
||||
# make sure none are idle (should have startup messages)
|
||||
self.assertFalse(c.is_idle)
|
||||
with c.lock:
|
||||
connection_request_ids[id(c)] = deque(c.request_ids) # copy of request ids
|
||||
|
||||
# let two heartbeat intervals pass (first one had startup messages in it)
|
||||
time.sleep(2 * interval + interval/10.)
|
||||
|
||||
connections = [c for holders in cluster.get_connection_holders() for c in holders.get_connections()]
|
||||
|
||||
# make sure requests were sent on all connections
|
||||
for c in connections:
|
||||
expected_ids = connection_request_ids[id(c)]
|
||||
expected_ids.rotate(-1)
|
||||
with c.lock:
|
||||
self.assertListEqual(list(c.request_ids), list(expected_ids))
|
||||
|
||||
# assert idle status
|
||||
self.assertTrue(all(c.is_idle for c in connections))
|
||||
|
||||
# send messages on all connections
|
||||
statements_and_params = [("SELECT release_version FROM system.local", ())] * len(cluster.metadata.all_hosts())
|
||||
results = execute_concurrent(session, statements_and_params)
|
||||
for success, result in results:
|
||||
self.assertTrue(success)
|
||||
|
||||
# assert not idle status
|
||||
self.assertFalse(any(c.is_idle if not c.is_control_connection else False for c in connections))
|
||||
|
||||
# holders include session pools and cc
|
||||
holders = cluster.get_connection_holders()
|
||||
self.assertIn(cluster.control_connection, holders)
|
||||
self.assertEqual(len(holders), len(cluster.metadata.all_hosts()) + 1) # hosts pools, 1 for cc
|
||||
|
||||
# include additional sessions
|
||||
session2 = cluster.connect()
|
||||
|
||||
holders = cluster.get_connection_holders()
|
||||
self.assertIn(cluster.control_connection, holders)
|
||||
self.assertEqual(len(holders), 2 * len(cluster.metadata.all_hosts()) + 1) # 2 sessions' hosts pools, 1 for cc
|
||||
|
||||
cluster._idle_heartbeat.stop()
|
||||
cluster._idle_heartbeat.join()
|
||||
assert_quiescent_pool_state(self, cluster)
|
||||
|
||||
session.shutdown()
|
||||
|
||||
@patch('cassandra.cluster.Cluster.idle_heartbeat_interval', new=0.1)
|
||||
def test_idle_heartbeat_disabled(self):
|
||||
self.assertTrue(Cluster.idle_heartbeat_interval)
|
||||
|
||||
# heartbeat disabled with '0'
|
||||
cluster = Cluster(protocol_version=PROTOCOL_VERSION, idle_heartbeat_interval=0)
|
||||
self.assertEqual(cluster.idle_heartbeat_interval, 0)
|
||||
session = cluster.connect()
|
||||
|
||||
# let two heartbeat intervals pass (first one had startup messages in it)
|
||||
time.sleep(2 * Cluster.idle_heartbeat_interval)
|
||||
|
||||
connections = [c for holders in cluster.get_connection_holders() for c in holders.get_connections()]
|
||||
|
||||
# assert not idle status (should never get reset because there is no heartbeat)
|
||||
self.assertFalse(any(c.is_idle for c in connections))
|
||||
|
||||
session.shutdown()
|
||||
|
||||
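A minimal sketch of the knob these tests exercise; the interval is in seconds, and passing ``0`` (as above) disables the heartbeat entirely::

    from cassandra.cluster import Cluster

    cluster = Cluster(idle_heartbeat_interval=10)   # probe connections that have sat idle for ~10s
    quiet = Cluster(idle_heartbeat_interval=0)      # no heartbeat thread is started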
def test_pool_management(self):
|
||||
# Ensure that in_flight and request_ids quiesce after cluster operations
|
||||
cluster = Cluster(protocol_version=PROTOCOL_VERSION, idle_heartbeat_interval=0) # no idle heartbeat here, pool management is tested in test_idle_heartbeat
|
||||
session = cluster.connect()
|
||||
session2 = cluster.connect()
|
||||
|
||||
# prepare
|
||||
p = session.prepare("SELECT * FROM system.local WHERE key=?")
|
||||
self.assertTrue(session.execute(p, ('local',)))
|
||||
|
||||
# simple
|
||||
self.assertTrue(session.execute("SELECT * FROM system.local WHERE key='local'"))
|
||||
|
||||
# set keyspace
|
||||
session.set_keyspace('system')
|
||||
session.set_keyspace('system_traces')
|
||||
|
||||
# use keyspace
|
||||
session.execute('USE system')
|
||||
session.execute('USE system_traces')
|
||||
|
||||
# refresh schema
|
||||
cluster.refresh_schema()
|
||||
cluster.refresh_schema(max_schema_agreement_wait=0)
|
||||
|
||||
# submit schema refresh
|
||||
future = cluster.submit_schema_refresh()
|
||||
future.result()
|
||||
|
||||
assert_quiescent_pool_state(self, cluster)
|
||||
|
||||
session2.shutdown()
|
||||
session.shutdown()
|
||||
|
||||
@@ -12,8 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from tests.integration import use_singledc, PROTOCOL_VERSION
|
||||
|
||||
try:
|
||||
import unittest2 as unittest
|
||||
except ImportError:
|
||||
@@ -21,13 +19,15 @@ except ImportError:
|
||||
|
||||
from functools import partial
|
||||
from six.moves import range
|
||||
import sys
|
||||
from threading import Thread, Event
|
||||
|
||||
from cassandra import ConsistencyLevel, OperationTimedOut
|
||||
from cassandra.cluster import NoHostAvailable
|
||||
from cassandra.protocol import QueryMessage
|
||||
from cassandra.io.asyncorereactor import AsyncoreConnection
|
||||
from cassandra.protocol import QueryMessage
|
||||
|
||||
from tests import is_monkey_patched
|
||||
from tests.integration import use_singledc, PROTOCOL_VERSION
|
||||
|
||||
try:
|
||||
from cassandra.io.libevreactor import LibevConnection
|
||||
@@ -230,8 +230,8 @@ class AsyncoreConnectionTests(ConnectionTests, unittest.TestCase):
|
||||
klass = AsyncoreConnection
|
||||
|
||||
def setUp(self):
|
||||
if 'gevent.monkey' in sys.modules:
|
||||
raise unittest.SkipTest("Can't test asyncore with gevent monkey patching")
|
||||
if is_monkey_patched():
|
||||
raise unittest.SkipTest("Can't test asyncore with monkey patching")
|
||||
ConnectionTests.setUp(self)
|
||||
|
||||
|
||||
@@ -240,8 +240,8 @@ class LibevConnectionTests(ConnectionTests, unittest.TestCase):
|
||||
klass = LibevConnection
|
||||
|
||||
def setUp(self):
|
||||
if 'gevent.monkey' in sys.modules:
|
||||
raise unittest.SkipTest("Can't test libev with gevent monkey patching")
|
||||
if is_monkey_patched():
|
||||
raise unittest.SkipTest("Can't test libev with monkey patching")
|
||||
if LibevConnection is None:
|
||||
raise unittest.SkipTest(
|
||||
'libev does not appear to be installed properly')
|
||||
|
||||
@@ -11,15 +11,16 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import six
|
||||
import difflib
|
||||
|
||||
try:
|
||||
import unittest2 as unittest
|
||||
except ImportError:
|
||||
import unittest # noqa
|
||||
|
||||
import difflib
|
||||
from mock import Mock
|
||||
import six
|
||||
import sys
|
||||
|
||||
from cassandra import AlreadyExists
|
||||
|
||||
@@ -361,6 +362,9 @@ class TestCodeCoverage(unittest.TestCase):
|
||||
"Protocol 3.0+ is required for UDT change events, currently testing against %r"
|
||||
% (PROTOCOL_VERSION,))
|
||||
|
||||
if sys.version_info[:2] != (2, 7):
|
||||
raise unittest.SkipTest('This test compares static strings generated from dict items, which may change orders. Test with 2.7.')
|
||||
|
||||
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||
session = cluster.connect()
|
||||
|
||||
@@ -549,6 +553,9 @@ CREATE TABLE export_udts.users (
|
||||
if get_server_versions()[0] < (2, 1, 0):
|
||||
raise unittest.SkipTest('Test schema output assumes 2.1.0+ options')
|
||||
|
||||
if sys.version_info[:2] != (2, 7):
|
||||
raise unittest.SkipTest('This test compares static strings generated from dict items, which may change orders. Test with 2.7.')
|
||||
|
||||
cli_script = """CREATE KEYSPACE legacy
|
||||
WITH placement_strategy = 'SimpleStrategy'
|
||||
AND strategy_options = {replication_factor:1};
|
||||
@@ -633,7 +640,7 @@ create column family composite_comp_with_col
|
||||
index_type : 0}]
|
||||
and compression_options = {'sstable_compression' : 'org.apache.cassandra.io.compress.LZ4Compressor'};"""
|
||||
|
||||
# note: the innerlkey type for legacy.nested_composite_key
|
||||
# note: the inner key type for legacy.nested_composite_key
|
||||
# (org.apache.cassandra.db.marshal.CompositeType(org.apache.cassandra.db.marshal.UUIDType, org.apache.cassandra.db.marshal.UTF8Type))
|
||||
# is a bit strange, but it replays in CQL with desired results
|
||||
expected_string = """CREATE KEYSPACE legacy WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'} AND durable_writes = true;
|
||||
@@ -807,6 +814,7 @@ CREATE TABLE legacy.composite_comp_no_col (
|
||||
|
||||
cluster.shutdown()
|
||||
|
||||
|
||||
class TokenMetadataTest(unittest.TestCase):
|
||||
"""
|
||||
Test of TokenMap creation and other behavior.
|
||||
@@ -882,7 +890,7 @@ class KeyspaceAlterMetadata(unittest.TestCase):
|
||||
self.assertEqual(original_keyspace_meta.durable_writes, True)
|
||||
self.assertEqual(len(original_keyspace_meta.tables), 1)
|
||||
|
||||
self.session.execute('ALTER KEYSPACE %s WITH durable_writes = false' %name)
|
||||
self.session.execute('ALTER KEYSPACE %s WITH durable_writes = false' % name)
|
||||
new_keyspace_meta = self.cluster.metadata.keyspaces[name]
|
||||
self.assertNotEqual(original_keyspace_meta, new_keyspace_meta)
|
||||
self.assertEqual(new_keyspace_meta.durable_writes, False)
|
||||
|
||||
@@ -21,8 +21,10 @@ except ImportError:
|
||||
import logging
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
from collections import namedtuple
|
||||
from decimal import Decimal
|
||||
from datetime import datetime, date, time
|
||||
from functools import partial
|
||||
import six
|
||||
from uuid import uuid1, uuid4
|
||||
|
||||
@@ -30,10 +32,14 @@ from cassandra import InvalidRequest
|
||||
from cassandra.cluster import Cluster
|
||||
from cassandra.cqltypes import Int32Type, EMPTY
|
||||
from cassandra.query import dict_factory
|
||||
from cassandra.util import OrderedDict, sortedset
|
||||
from cassandra.util import OrderedMap, sortedset
|
||||
|
||||
from tests.integration import get_server_versions, use_singledc, PROTOCOL_VERSION
|
||||
|
||||
# defined in module scope for pickling in OrderedMap
|
||||
nested_collection_udt = namedtuple('nested_collection_udt', ['m', 't', 'l', 's'])
|
||||
nested_collection_udt_nested = namedtuple('nested_collection_udt_nested', ['m', 't', 'l', 's', 'u'])
|
||||
|
||||
|
||||
def setup_module():
|
||||
use_singledc()
|
||||
@@ -287,11 +293,11 @@ class TypeTests(unittest.TestCase):
|
||||
s.execute("INSERT INTO mytable (a, b, c, o, s, l, n) VALUES ('a', 'b', %s, %s, %s, %s, %s)",
|
||||
('', '', '', [''], {'': 3}))
|
||||
self.assertEqual(
|
||||
{'c': '', 'o': '', 's': '', 'l': [''], 'n': OrderedDict({'': 3})},
|
||||
{'c': '', 'o': '', 's': '', 'l': [''], 'n': OrderedMap({'': 3})},
|
||||
s.execute("SELECT c, o, s, l, n FROM mytable WHERE a='a' AND b='b'")[0])
|
||||
|
||||
self.assertEqual(
|
||||
{'c': '', 'o': '', 's': '', 'l': [''], 'n': OrderedDict({'': 3})},
|
||||
{'c': '', 'o': '', 's': '', 'l': [''], 'n': OrderedMap({'': 3})},
|
||||
s.execute(s.prepare("SELECT c, o, s, l, n FROM mytable WHERE a='a' AND b='b'"), [])[0])
|
||||
|
||||
# non-string types shouldn't accept empty strings
|
||||
@@ -708,10 +714,10 @@ class TypeTests(unittest.TestCase):
|
||||
self.assertEquals((None, None, None, None), s.execute(read)[0].t)
|
||||
|
||||
# also test empty strings where compatible
|
||||
s.execute(insert, [('', None, None, '')])
|
||||
s.execute(insert, [('', None, None, b'')])
|
||||
result = s.execute("SELECT * FROM mytable WHERE k=0")
|
||||
self.assertEquals(('', None, None, ''), result[0].t)
|
||||
self.assertEquals(('', None, None, ''), s.execute(read)[0].t)
|
||||
self.assertEquals(('', None, None, b''), result[0].t)
|
||||
self.assertEquals(('', None, None, b''), s.execute(read)[0].t)
|
||||
|
||||
c.shutdown()
|
||||
|
||||
@@ -721,3 +727,61 @@ class TypeTests(unittest.TestCase):
|
||||
|
||||
query = u"SELECT * FROM system.schema_columnfamilies WHERE keyspace_name = 'ef\u2052ef' AND columnfamily_name = %s"
|
||||
s.execute(query, (u"fe\u2051fe",))
|
||||
|
||||
def insert_select_column(self, session, table_name, column_name, value):
|
||||
insert = session.prepare("INSERT INTO %s (k, %s) VALUES (?, ?)" % (table_name, column_name))
|
||||
session.execute(insert, (0, value))
|
||||
result = session.execute("SELECT %s FROM %s WHERE k=%%s" % (column_name, table_name), (0,))[0][0]
|
||||
self.assertEqual(result, value)
|
||||
|
||||
def test_nested_collections(self):
|
||||
|
||||
if self._cass_version < (2, 1, 3):
|
||||
raise unittest.SkipTest("Support for nested collections was introduced in Cassandra 2.1.3")
|
||||
|
||||
name = self._testMethodName
|
||||
|
||||
c = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||
s = c.connect('test1rf')
|
||||
s.encoder.mapping[tuple] = s.encoder.cql_encode_tuple
|
||||
|
||||
s.execute("""
|
||||
CREATE TYPE %s (
|
||||
m frozen<map<int,text>>,
|
||||
t tuple<int,text>,
|
||||
l frozen<list<int>>,
|
||||
s frozen<set<int>>
|
||||
)""" % name)
|
||||
s.execute("""
|
||||
CREATE TYPE %s_nested (
|
||||
m frozen<map<int,text>>,
|
||||
t tuple<int,text>,
|
||||
l frozen<list<int>>,
|
||||
s frozen<set<int>>,
|
||||
u frozen<%s>
|
||||
)""" % (name, name))
|
||||
s.execute("""
|
||||
CREATE TABLE %s (
|
||||
k int PRIMARY KEY,
|
||||
map_map map<frozen<map<int,int>>, frozen<map<int,int>>>,
|
||||
map_set map<frozen<set<int>>, frozen<set<int>>>,
|
||||
map_list map<frozen<list<int>>, frozen<list<int>>>,
|
||||
map_tuple map<frozen<tuple<int, int>>, frozen<tuple<int>>>,
|
||||
map_udt map<frozen<%s_nested>, frozen<%s>>,
|
||||
)""" % (name, name, name))
|
||||
|
||||
validate = partial(self.insert_select_column, s, name)
|
||||
validate('map_map', OrderedMap([({1: 1, 2: 2}, {3: 3, 4: 4}), ({5: 5, 6: 6}, {7: 7, 8: 8})]))
|
||||
validate('map_set', OrderedMap([(set((1, 2)), set((3, 4))), (set((5, 6)), set((7, 8)))]))
|
||||
validate('map_list', OrderedMap([([1, 2], [3, 4]), ([5, 6], [7, 8])]))
|
||||
validate('map_tuple', OrderedMap([((1, 2), (3,)), ((4, 5), (6,))]))
|
||||
|
||||
value = nested_collection_udt({1: 'v1', 2: 'v2'}, (3, 'v3'), [4, 5, 6, 7], set((8, 9, 10)))
|
||||
key = nested_collection_udt_nested(value.m, value.t, value.l, value.s, value)
|
||||
key2 = nested_collection_udt_nested({3: 'v3'}, value.t, value.l, value.s, value)
|
||||
validate('map_udt', OrderedMap([(key, value), (key2, value)]))
|
||||
|
||||
s.execute("DROP TABLE %s" % (name))
|
||||
s.execute("DROP TYPE %s_nested" % (name))
|
||||
s.execute("DROP TYPE %s" % (name))
|
||||
s.shutdown()
|
||||
|
||||
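``OrderedMap`` shows up throughout the assertions above because map columns now deserialize to it instead of ``OrderedDict``; a small sketch of its behavior, matching the unit tests added later in this diff::

    from cassandra.util import OrderedMap

    om = OrderedMap([('one', 1), ('three', 3), ('two', 2)])
    list(om.keys())                           # ['one', 'three', 'two'] -- insertion order preserved
    om == {'one': 1, 'two': 2, 'three': 3}    # True -- equality with plain dicts ignores order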
37 tests/integration/util.py Normal file
@@ -0,0 +1,37 @@
# Copyright 2013-2014 DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from tests.integration import PROTOCOL_VERSION


def assert_quiescent_pool_state(test_case, cluster):

    for session in cluster.sessions:
        pool_states = session.get_pool_state().values()
        test_case.assertTrue(pool_states)

        for state in pool_states:
            test_case.assertFalse(state['shutdown'])
            test_case.assertGreater(state['open_count'], 0)
            test_case.assertTrue(all((i == 0 for i in state['in_flights'])))

    for holder in cluster.get_connection_holders():
        for connection in holder.get_connections():
            # all ids are unique
            req_ids = connection.request_ids
            test_case.assertEqual(len(req_ids), len(set(req_ids)))
            test_case.assertEqual(connection.highest_request_id, len(req_ids) - 1)
            test_case.assertEqual(connection.highest_request_id, max(req_ids))
            if PROTOCOL_VERSION < 3:
                test_case.assertEqual(connection.highest_request_id, connection.max_request_id)

@@ -31,20 +31,20 @@ from mock import patch, Mock
|
||||
|
||||
from cassandra.connection import (HEADER_DIRECTION_TO_CLIENT,
|
||||
ConnectionException)
|
||||
|
||||
from cassandra.io.asyncorereactor import AsyncoreConnection
|
||||
from cassandra.protocol import (write_stringmultimap, write_int, write_string,
|
||||
SupportedMessage, ReadyMessage, ServerError)
|
||||
from cassandra.marshal import uint8_pack, uint32_pack, int32_pack
|
||||
|
||||
from cassandra.io.asyncorereactor import AsyncoreConnection
|
||||
from tests import is_monkey_patched
|
||||
|
||||
|
||||
class AsyncoreConnectionTest(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
if 'gevent.monkey' in sys.modules:
|
||||
raise unittest.SkipTest("gevent monkey-patching detected")
|
||||
if is_monkey_patched():
|
||||
raise unittest.SkipTest("monkey-patching detected")
|
||||
AsyncoreConnection.initialize_reactor()
|
||||
cls.socket_patcher = patch('socket.socket', spec=socket.socket)
|
||||
cls.mock_socket = cls.socket_patcher.start()
|
||||
|
||||
@@ -11,21 +11,21 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import six
|
||||
|
||||
try:
|
||||
import unittest2 as unittest
|
||||
except ImportError:
|
||||
import unittest # noqa
|
||||
import unittest # noqa
|
||||
|
||||
from mock import Mock, ANY, call
|
||||
import six
|
||||
from six import BytesIO
|
||||
|
||||
from mock import Mock, ANY
|
||||
import time
|
||||
from threading import Lock
|
||||
|
||||
from cassandra.cluster import Cluster
|
||||
from cassandra.connection import (Connection, HEADER_DIRECTION_TO_CLIENT,
|
||||
HEADER_DIRECTION_FROM_CLIENT, ProtocolError,
|
||||
locally_supported_compressions)
|
||||
locally_supported_compressions, ConnectionHeartbeat)
|
||||
from cassandra.marshal import uint8_pack, uint32_pack
|
||||
from cassandra.protocol import (write_stringmultimap, write_int, write_string,
|
||||
SupportedMessage)
|
||||
@@ -280,3 +280,155 @@ class ConnectionTest(unittest.TestCase):
|
||||
def test_set_connection_class(self):
|
||||
cluster = Cluster(connection_class='test')
|
||||
self.assertEqual('test', cluster.connection_class)
|
||||
|
||||
|
||||
class ConnectionHeartbeatTest(unittest.TestCase):
|
||||
|
||||
@staticmethod
|
||||
def make_get_holders(len):
|
||||
holders = []
|
||||
for _ in range(len):
|
||||
holder = Mock()
|
||||
holder.get_connections = Mock(return_value=[])
|
||||
holders.append(holder)
|
||||
get_holders = Mock(return_value=holders)
|
||||
return get_holders
|
||||
|
||||
def run_heartbeat(self, get_holders_fun, count=2, interval=0.05):
|
||||
ch = ConnectionHeartbeat(interval, get_holders_fun)
|
||||
time.sleep(interval * count)
|
||||
ch.stop()
|
||||
ch.join()
|
||||
self.assertTrue(get_holders_fun.call_count)
|
||||
|
||||
def test_empty_connections(self):
|
||||
count = 3
|
||||
get_holders = self.make_get_holders(1)
|
||||
|
||||
self.run_heartbeat(get_holders, count)
|
||||
|
||||
self.assertGreaterEqual(get_holders.call_count, count - 1) # lower bound to account for thread spinup time
|
||||
self.assertLessEqual(get_holders.call_count, count)
|
||||
holder = get_holders.return_value[0]
|
||||
holder.get_connections.assert_has_calls([call()] * get_holders.call_count)
|
||||
|
||||
def test_idle_non_idle(self):
|
||||
request_id = 999
|
||||
|
||||
# connection.send_msg(OptionsMessage(), connection.get_request_id(), self._options_callback)
|
||||
def send_msg(msg, req_id, msg_callback):
|
||||
msg_callback(SupportedMessage([], {}))
|
||||
|
||||
idle_connection = Mock(spec=Connection, host='localhost',
|
||||
max_request_id=127,
|
||||
lock=Lock(),
|
||||
in_flight=0, is_idle=True,
|
||||
is_defunct=False, is_closed=False,
|
||||
get_request_id=lambda: request_id,
|
||||
send_msg=Mock(side_effect=send_msg))
|
||||
non_idle_connection = Mock(spec=Connection, in_flight=0, is_idle=False, is_defunct=False, is_closed=False)
|
||||
|
||||
get_holders = self.make_get_holders(1)
|
||||
holder = get_holders.return_value[0]
|
||||
holder.get_connections.return_value.append(idle_connection)
|
||||
holder.get_connections.return_value.append(non_idle_connection)
|
||||
|
||||
self.run_heartbeat(get_holders)
|
||||
|
||||
holder.get_connections.assert_has_calls([call()] * get_holders.call_count)
|
||||
self.assertEqual(idle_connection.in_flight, 0)
|
||||
self.assertEqual(non_idle_connection.in_flight, 0)
|
||||
|
||||
idle_connection.send_msg.assert_has_calls([call(ANY, request_id, ANY)] * get_holders.call_count)
|
||||
self.assertEqual(non_idle_connection.send_msg.call_count, 0)
|
||||
|
||||
def test_closed_defunct(self):
|
||||
get_holders = self.make_get_holders(1)
|
||||
closed_connection = Mock(spec=Connection, in_flight=0, is_idle=False, is_defunct=False, is_closed=True)
|
||||
defunct_connection = Mock(spec=Connection, in_flight=0, is_idle=False, is_defunct=True, is_closed=False)
|
||||
holder = get_holders.return_value[0]
|
||||
holder.get_connections.return_value.append(closed_connection)
|
||||
holder.get_connections.return_value.append(defunct_connection)
|
||||
|
||||
self.run_heartbeat(get_holders)
|
||||
|
||||
holder.get_connections.assert_has_calls([call()] * get_holders.call_count)
|
||||
self.assertEqual(closed_connection.in_flight, 0)
|
||||
self.assertEqual(defunct_connection.in_flight, 0)
|
||||
self.assertEqual(closed_connection.send_msg.call_count, 0)
|
||||
self.assertEqual(defunct_connection.send_msg.call_count, 0)
|
||||
|
||||
def test_no_req_ids(self):
|
||||
in_flight = 3
|
||||
|
||||
get_holders = self.make_get_holders(1)
|
||||
max_connection = Mock(spec=Connection, host='localhost',
|
||||
max_request_id=in_flight, in_flight=in_flight,
|
||||
is_idle=True, is_defunct=False, is_closed=False)
|
||||
holder = get_holders.return_value[0]
|
||||
holder.get_connections.return_value.append(max_connection)
|
||||
|
||||
self.run_heartbeat(get_holders)
|
||||
|
||||
holder.get_connections.assert_has_calls([call()] * get_holders.call_count)
|
||||
self.assertEqual(max_connection.in_flight, in_flight)
|
||||
self.assertEqual(max_connection.send_msg.call_count, 0)
|
||||
max_connection.defunct.assert_has_calls([call(ANY)] * get_holders.call_count)
|
||||
holder.return_connection.assert_has_calls([call(max_connection)] * get_holders.call_count)
|
||||
|
||||
def test_unexpected_response(self):
|
||||
request_id = 999
|
||||
|
||||
get_holders = self.make_get_holders(1)
|
||||
|
||||
def send_msg(msg, req_id, msg_callback):
|
||||
msg_callback(object())
|
||||
|
||||
connection = Mock(spec=Connection, host='localhost',
|
||||
max_request_id=127,
|
||||
lock=Lock(),
|
||||
in_flight=0, is_idle=True,
|
||||
is_defunct=False, is_closed=False,
|
||||
get_request_id=lambda: request_id,
|
||||
send_msg=Mock(side_effect=send_msg))
|
||||
holder = get_holders.return_value[0]
|
||||
holder.get_connections.return_value.append(connection)
|
||||
|
||||
self.run_heartbeat(get_holders)
|
||||
|
||||
self.assertEqual(connection.in_flight, get_holders.call_count)
|
||||
connection.send_msg.assert_has_calls([call(ANY, request_id, ANY)] * get_holders.call_count)
|
||||
connection.defunct.assert_has_calls([call(ANY)] * get_holders.call_count)
|
||||
exc = connection.defunct.call_args_list[0][0][0]
|
||||
self.assertIsInstance(exc, Exception)
|
||||
self.assertEqual(exc.args, Exception('Connection heartbeat failure').args)
|
||||
holder.return_connection.assert_has_calls([call(connection)] * get_holders.call_count)
|
||||
|
||||
def test_timeout(self):
|
||||
request_id = 999
|
||||
|
||||
get_holders = self.make_get_holders(1)
|
||||
|
||||
def send_msg(msg, req_id, msg_callback):
|
||||
pass
|
||||
|
||||
connection = Mock(spec=Connection, host='localhost',
|
||||
max_request_id=127,
|
||||
lock=Lock(),
|
||||
in_flight=0, is_idle=True,
|
||||
is_defunct=False, is_closed=False,
|
||||
get_request_id=lambda: request_id,
|
||||
send_msg=Mock(side_effect=send_msg))
|
||||
holder = get_holders.return_value[0]
|
||||
holder.get_connections.return_value.append(connection)
|
||||
|
||||
self.run_heartbeat(get_holders)
|
||||
|
||||
self.assertEqual(connection.in_flight, get_holders.call_count)
|
||||
connection.send_msg.assert_has_calls([call(ANY, request_id, ANY)] * get_holders.call_count)
|
||||
connection.defunct.assert_has_calls([call(ANY)] * get_holders.call_count)
|
||||
exc = connection.defunct.call_args_list[0][0][0]
|
||||
self.assertIsInstance(exc, Exception)
|
||||
self.assertEqual(exc.args, Exception('Connection heartbeat failure').args)
|
||||
holder.return_connection.assert_has_calls([call(connection)] * get_holders.call_count)
|
||||
|
||||
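These tests drive ``ConnectionHeartbeat`` directly; a rough sketch of the equivalent wiring (the cluster does this internally when ``idle_heartbeat_interval`` is non-zero, so applications normally never construct it themselves)::

    from cassandra.connection import ConnectionHeartbeat

    heartbeat = ConnectionHeartbeat(30, cluster.get_connection_holders)  # runs in a background thread
    # ... later, on shutdown:
    heartbeat.stop()
    heartbeat.join()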
@@ -73,7 +73,7 @@ class MockCluster(object):
|
||||
self.scheduler = Mock(spec=_Scheduler)
|
||||
self.executor = Mock(spec=ThreadPoolExecutor)
|
||||
|
||||
def add_host(self, address, datacenter, rack, signal=False):
|
||||
def add_host(self, address, datacenter, rack, signal=False, refresh_nodes=True):
|
||||
host = Host(address, SimpleConvictionPolicy, datacenter, rack)
|
||||
self.added_hosts.append(host)
|
||||
return host
|
||||
@@ -131,7 +131,7 @@ class ControlConnectionTest(unittest.TestCase):
|
||||
self.connection = MockConnection()
|
||||
self.time = FakeTime()
|
||||
|
||||
self.control_connection = ControlConnection(self.cluster, timeout=1)
|
||||
self.control_connection = ControlConnection(self.cluster, 1, 0, 0)
|
||||
self.control_connection._connection = self.connection
|
||||
self.control_connection._time = self.time
|
||||
|
||||
@@ -345,39 +345,44 @@ class ControlConnectionTest(unittest.TestCase):
|
||||
'change_type': 'NEW_NODE',
|
||||
'address': ('1.2.3.4', 9000)
|
||||
}
|
||||
self.cluster.scheduler.reset_mock()
|
||||
self.control_connection._handle_topology_change(event)
|
||||
self.cluster.scheduler.schedule.assert_called_with(ANY, self.control_connection.refresh_node_list_and_token_map)
|
||||
self.cluster.scheduler.schedule_unique.assert_called_once_with(ANY, self.control_connection.refresh_node_list_and_token_map)
|
||||
|
||||
event = {
|
||||
'change_type': 'REMOVED_NODE',
|
||||
'address': ('1.2.3.4', 9000)
|
||||
}
|
||||
self.cluster.scheduler.reset_mock()
|
||||
self.control_connection._handle_topology_change(event)
|
||||
self.cluster.scheduler.schedule.assert_called_with(ANY, self.cluster.remove_host, None)
|
||||
self.cluster.scheduler.schedule_unique.assert_called_once_with(ANY, self.cluster.remove_host, None)
|
||||
|
||||
event = {
|
||||
'change_type': 'MOVED_NODE',
|
||||
'address': ('1.2.3.4', 9000)
|
||||
}
|
||||
self.cluster.scheduler.reset_mock()
|
||||
self.control_connection._handle_topology_change(event)
|
||||
self.cluster.scheduler.schedule.assert_called_with(ANY, self.control_connection.refresh_node_list_and_token_map)
|
||||
self.cluster.scheduler.schedule_unique.assert_called_once_with(ANY, self.control_connection.refresh_node_list_and_token_map)
|
||||
|
||||
def test_handle_status_change(self):
|
||||
event = {
|
||||
'change_type': 'UP',
|
||||
'address': ('1.2.3.4', 9000)
|
||||
}
|
||||
self.cluster.scheduler.reset_mock()
|
||||
self.control_connection._handle_status_change(event)
|
||||
self.cluster.scheduler.schedule.assert_called_with(ANY, self.control_connection.refresh_node_list_and_token_map)
|
||||
self.cluster.scheduler.schedule_unique.assert_called_once_with(ANY, self.control_connection.refresh_node_list_and_token_map)
|
||||
|
||||
# do the same with a known Host
|
||||
event = {
|
||||
'change_type': 'UP',
|
||||
'address': ('192.168.1.0', 9000)
|
||||
}
|
||||
self.cluster.scheduler.reset_mock()
|
||||
self.control_connection._handle_status_change(event)
|
||||
host = self.cluster.metadata.hosts['192.168.1.0']
|
||||
self.cluster.scheduler.schedule.assert_called_with(ANY, self.cluster.on_up, host)
|
||||
self.cluster.scheduler.schedule_unique.assert_called_once_with(ANY, self.cluster.on_up, host)
|
||||
|
||||
self.cluster.scheduler.schedule.reset_mock()
|
||||
event = {
|
||||
@@ -404,9 +409,11 @@ class ControlConnectionTest(unittest.TestCase):
|
||||
'keyspace': 'ks1',
|
||||
'table': 'table1'
|
||||
}
|
||||
self.cluster.scheduler.reset_mock()
|
||||
self.control_connection._handle_schema_change(event)
|
||||
self.cluster.executor.submit.assert_called_with(self.control_connection.refresh_schema, 'ks1', 'table1', None)
|
||||
self.cluster.scheduler.schedule_unique.assert_called_once_with(0.0, self.control_connection.refresh_schema, 'ks1', 'table1', None)
|
||||
|
||||
self.cluster.scheduler.reset_mock()
|
||||
event['table'] = None
|
||||
self.control_connection._handle_schema_change(event)
|
||||
self.cluster.executor.submit.assert_called_with(self.control_connection.refresh_schema, 'ks1', None, None)
|
||||
self.cluster.scheduler.schedule_unique.assert_called_once_with(0.0, self.control_connection.refresh_schema, 'ks1', None, None)
|
||||
|
||||
@@ -24,7 +24,7 @@ from decimal import Decimal
|
||||
from uuid import UUID
|
||||
|
||||
from cassandra.cqltypes import lookup_casstype
|
||||
from cassandra.util import OrderedDict, sortedset
|
||||
from cassandra.util import OrderedMap, sortedset
|
||||
|
||||
marshalled_value_pairs = (
|
||||
# binary form, type, python native type
|
||||
@@ -75,7 +75,7 @@ marshalled_value_pairs = (
|
||||
(b'', 'MapType(AsciiType, BooleanType)', None),
|
||||
(b'', 'ListType(FloatType)', None),
|
||||
(b'', 'SetType(LongType)', None),
|
||||
(b'\x00\x00', 'MapType(DecimalType, BooleanType)', OrderedDict()),
|
||||
(b'\x00\x00', 'MapType(DecimalType, BooleanType)', OrderedMap()),
|
||||
(b'\x00\x00', 'ListType(FloatType)', []),
|
||||
(b'\x00\x00', 'SetType(IntegerType)', sortedset()),
|
||||
(b'\x00\x01\x00\x10\xafYC\xa3\xea<\x11\xe1\xabc\xc4,\x03"y\xf0', 'ListType(TimeUUIDType)', [UUID(bytes=b'\xafYC\xa3\xea<\x11\xe1\xabc\xc4,\x03"y\xf0')]),
|
||||
@@ -84,15 +84,14 @@ marshalled_value_pairs = (
|
||||
(b'\x00\x00\x00\x00\x00\x00\x00\x01', 'TimeType', 1)
|
||||
)
|
||||
|
||||
ordered_dict_value = OrderedDict()
|
||||
ordered_dict_value[u'\u307fbob'] = 199
|
||||
ordered_dict_value[u''] = -1
|
||||
ordered_dict_value[u'\\'] = 0
|
||||
ordered_map_value = OrderedMap([(u'\u307fbob', 199),
|
||||
(u'', -1),
|
||||
(u'\\', 0)])
|
||||
|
||||
# these following entries work for me right now, but they're dependent on
|
||||
# vagaries of internal python ordering for unordered types
|
||||
marshalled_value_pairs_unsafe = (
|
||||
(b'\x00\x03\x00\x06\xe3\x81\xbfbob\x00\x04\x00\x00\x00\xc7\x00\x00\x00\x04\xff\xff\xff\xff\x00\x01\\\x00\x04\x00\x00\x00\x00', 'MapType(UTF8Type, Int32Type)', ordered_dict_value),
|
||||
(b'\x00\x03\x00\x06\xe3\x81\xbfbob\x00\x04\x00\x00\x00\xc7\x00\x00\x00\x04\xff\xff\xff\xff\x00\x01\\\x00\x04\x00\x00\x00\x00', 'MapType(UTF8Type, Int32Type)', ordered_map_value),
|
||||
(b'\x00\x02\x00\x08@\x01\x99\x99\x99\x99\x99\x9a\x00\x08@\x14\x00\x00\x00\x00\x00\x00', 'SetType(DoubleType)', sortedset([2.2, 5.0])),
|
||||
(b'\x00', 'IntegerType', 0),
|
||||
)
|
||||
|
||||
127 tests/unit/test_orderedmap.py Normal file
@@ -0,0 +1,127 @@
|
||||
# Copyright 2013-2014 DataStax, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
try:
|
||||
import unittest2 as unittest
|
||||
except ImportError:
|
||||
import unittest # noqa
|
||||
|
||||
from cassandra.util import OrderedMap
|
||||
from cassandra.cqltypes import EMPTY
|
||||
import six
|
||||
|
||||
class OrderedMapTest(unittest.TestCase):
|
||||
def test_init(self):
|
||||
a = OrderedMap(zip(['one', 'three', 'two'], [1, 3, 2]))
|
||||
b = OrderedMap([('one', 1), ('three', 3), ('two', 2)])
|
||||
c = OrderedMap(a)
|
||||
builtin = {'one': 1, 'two': 2, 'three': 3}
|
||||
self.assertEqual(a, b)
|
||||
self.assertEqual(a, c)
|
||||
self.assertEqual(a, builtin)
|
||||
self.assertEqual(OrderedMap([(1, 1), (1, 2)]), {1: 2})
|
||||
|
||||
def test_contains(self):
|
||||
keys = ['first', 'middle', 'last']
|
||||
|
||||
od = OrderedMap()
|
||||
|
||||
od = OrderedMap(zip(keys, range(len(keys))))
|
||||
|
||||
for k in keys:
|
||||
self.assertTrue(k in od)
|
||||
self.assertFalse(k not in od)
|
||||
|
||||
self.assertTrue('notthere' not in od)
|
||||
self.assertFalse('notthere' in od)
|
||||
|
||||
def test_keys(self):
|
||||
keys = ['first', 'middle', 'last']
|
||||
od = OrderedMap(zip(keys, range(len(keys))))
|
||||
|
||||
self.assertListEqual(list(od.keys()), keys)
|
||||
|
||||
def test_values(self):
|
||||
keys = ['first', 'middle', 'last']
|
||||
values = list(range(len(keys)))
|
||||
od = OrderedMap(zip(keys, values))
|
||||
|
||||
self.assertListEqual(list(od.values()), values)
|
||||
|
||||
def test_items(self):
|
||||
keys = ['first', 'middle', 'last']
|
||||
items = list(zip(keys, range(len(keys))))
|
||||
od = OrderedMap(items)
|
||||
|
||||
self.assertListEqual(list(od.items()), items)
|
||||
|
||||
def test_get(self):
|
||||
keys = ['first', 'middle', 'last']
|
||||
od = OrderedMap(zip(keys, range(len(keys))))
|
||||
|
||||
for v, k in enumerate(keys):
|
||||
self.assertEqual(od.get(k), v)
|
||||
|
||||
self.assertEqual(od.get('notthere', 'default'), 'default')
|
||||
self.assertIsNone(od.get('notthere'))
|
||||
|
||||
def test_equal(self):
|
||||
d1 = {'one': 1}
|
||||
d12 = {'one': 1, 'two': 2}
|
||||
od1 = OrderedMap({'one': 1})
|
||||
od12 = OrderedMap([('one', 1), ('two', 2)])
|
||||
od21 = OrderedMap([('two', 2), ('one', 1)])
|
||||
|
||||
self.assertEqual(od1, d1)
|
||||
self.assertEqual(od12, d12)
|
||||
self.assertEqual(od21, d12)
|
||||
self.assertNotEqual(od1, od12)
|
||||
self.assertNotEqual(od12, od1)
|
||||
self.assertNotEqual(od12, od21)
|
||||
self.assertNotEqual(od1, d12)
|
||||
self.assertNotEqual(od12, d1)
|
||||
self.assertNotEqual(od1, EMPTY)
|
||||
|
||||
def test_getitem(self):
|
||||
keys = ['first', 'middle', 'last']
|
||||
od = OrderedMap(zip(keys, range(len(keys))))
|
||||
|
||||
for v, k in enumerate(keys):
|
||||
self.assertEqual(od[k], v)
|
||||
|
||||
with self.assertRaises(KeyError):
|
||||
od['notthere']
|
||||
|
||||
def test_iter(self):
|
||||
keys = ['first', 'middle', 'last']
|
||||
values = list(range(len(keys)))
|
||||
items = list(zip(keys, values))
|
||||
od = OrderedMap(items)
|
||||
|
||||
itr = iter(od)
|
||||
self.assertEqual(sum([1 for _ in itr]), len(keys))
|
||||
self.assertRaises(StopIteration, six.next, itr)
|
||||
|
||||
self.assertEqual(list(iter(od)), keys)
|
||||
self.assertEqual(list(six.iteritems(od)), items)
|
||||
self.assertEqual(list(six.itervalues(od)), values)
|
||||
|
||||
def test_len(self):
|
||||
self.assertEqual(len(OrderedMap()), 0)
|
||||
self.assertEqual(len(OrderedMap([(1, 1)])), 1)
|
||||
|
||||
def test_mutable_keys(self):
|
||||
d = {'1': 1}
|
||||
s = set([1, 2, 3])
|
||||
od = OrderedMap([(d, 'dict'), (s, 'set')])
|
||||
@@ -262,11 +262,13 @@ class TypeTests(unittest.TestCase):
        self.assertEqual(cassandra_type.val, 'randomvaluetocheck')

    def test_datetype(self):
        now_timestamp = time.time()
        now_datetime = datetime.datetime.utcfromtimestamp(now_timestamp)
        now_time_seconds = time.time()
        now_datetime = datetime.datetime.utcfromtimestamp(now_time_seconds)

        # Cassandra timestamps in millis
        now_timestamp = now_time_seconds * 1e3

        # same results serialized
        # (this could change if we follow up on the timestamp multiplication warning in DateType.serialize)
        self.assertEqual(DateType.serialize(now_datetime, 0), DateType.serialize(now_timestamp, 0))

        # from timestamp