Merge branch '2.0' into py3k

Conflicts:
	cassandra/cluster.py
	cassandra/connection.py
	cassandra/io/asyncorereactor.py
	cassandra/io/libevreactor.py
	setup.py
	tests/integration/long/test_large_data.py
This commit is contained in:
Tyler Hobbs
2014-04-03 17:45:53 -05:00
33 changed files with 1158 additions and 432 deletions

View File

@@ -1,9 +1,11 @@
1.0.3
1.1.0
=====
In Progress
Features
--------
* Gevent is now supported through monkey-patching the stdlib (PYTHON-7,
github issue #46)
* Support static columns in schemas, which are available starting in
Cassandra 2.1. (github issue #91)
@@ -15,12 +17,23 @@ Bug Fixes
* Ignore SSL_ERROR_WANT_READ and SSL_ERROR_WANT_WRITE socket errors. Previously,
these resulted in the connection being defuncted, but they can safely be
ignored by the driver.
* Don't reconnect the control connection every time Cluster.connect() is
called
* Avoid race condition that could leave ResponseFuture callbacks uncalled
if the callback was added outside of the event loop thread (github issue #95)
* Properly escape keyspace name in Session.set_keyspace(). Previously, the
keyspace name was quoted, but any quotes in the string were not escaped.
* Avoid adding hosts to the load balancing policy before their datacenter
and rack information has been set, if possible.
* Avoid KeyError when updating metadata after dropping a table (github issues
#97, #98)
Other
-----
* Don't ignore column names when parsing typestrings. This is needed for
user-defined type support. (github issue #90)
* Better error message when libevwrapper is not found
* Only try to import scales when metrics are enabled (github issue #92)
1.0.2
=====

View File

@@ -2,6 +2,8 @@
This module houses the main classes you will interact with,
:class:`.Cluster` and :class:`.Session`.
"""
from __future__ import absolute_import
from concurrent.futures import ThreadPoolExecutor
import logging
import socket
@@ -37,8 +39,7 @@ from cassandra.decoder import (QueryMessage, ResultMessage,
BatchMessage, RESULT_KIND_PREPARED,
RESULT_KIND_SET_KEYSPACE, RESULT_KIND_ROWS,
RESULT_KIND_SCHEMA_CHANGE)
from cassandra.metadata import Metadata
# from cassandra.metrics import Metrics
from cassandra.metadata import Metadata, protect_name
from cassandra.policies import (RoundRobinPolicy, SimpleConvictionPolicy,
ExponentialReconnectionPolicy, HostDistance,
RetryPolicy)
@@ -48,11 +49,15 @@ from cassandra.query import (SimpleStatement, PreparedStatement, BoundStatement,
BatchStatement, bind_params, QueryTrace, Statement,
named_tuple_factory, dict_factory)
# libev is all around faster, so we want to try and default to using that when we can
try:
from cassandra.io.libevreactor import LibevConnection as DefaultConnection
except ImportError:
from cassandra.io.asyncorereactor import AsyncoreConnection as DefaultConnection # NOQA
# default to gevent when we are monkey patched, otherwise if libev is available, use that as the
# default because it's faster than asyncore
if 'gevent.monkey' in sys.modules:
from cassandra.io.geventreactor import GeventConnection as DefaultConnection
else:
try:
from cassandra.io.libevreactor import LibevConnection as DefaultConnection # NOQA
except ImportError:
from cassandra.io.asyncorereactor import AsyncoreConnection as DefaultConnection # NOQA
# Forces load of utf8 encoding module to avoid deadlock that occurs
# if code that is being imported tries to import the module in a seperate
@@ -147,8 +152,15 @@ class Cluster(object):
server will be automatically used.
"""
# TODO: docs
protocol_version = 2
"""
The version of the native protocol to use. The protocol version 2
add support for lightweight transactions, batch operations, and
automatic query paging, but is only supported by Cassandra 2.0+. When
working with Cassandra 1.2, this must be set to 1. You can also set
this to 1 when working with Cassandra 2.0+, but features that require
the version 2 protocol will not be enabled.
"""
compression = True
"""
@@ -287,13 +299,6 @@ class Cluster(object):
Any of the mutable Cluster attributes may be set as keyword arguments
to the constructor.
"""
if 'gevent.monkey' in sys.modules:
raise Exception(
"gevent monkey-patching detected. This driver does not currently "
"support gevent, and monkey patching will break the driver "
"completely. You can track progress towards adding gevent "
"support here: https://datastax-oss.atlassian.net/browse/PYTHON-7.")
self.contact_points = contact_points
self.port = port
self.compression = compression
@@ -478,20 +483,21 @@ class Cluster(object):
self.load_balancing_policy.populate(
weakref.proxy(self), self.metadata.all_hosts())
if self.control_connection:
try:
self.control_connection.connect()
log.debug("Control connection created")
except Exception:
log.exception("Control connection failed to connect, "
"shutting down Cluster:")
self.shutdown()
raise
self.load_balancing_policy.check_supported()
self._is_setup = True
if self.control_connection:
try:
self.control_connection.connect()
log.debug("Control connection created")
except Exception:
log.exception("Control connection failed to connect, "
"shutting down Cluster:")
self.shutdown()
raise
self.load_balancing_policy.check_supported()
session = self._new_session()
if keyspace:
session.set_keyspace(keyspace)
@@ -772,13 +778,13 @@ class Cluster(object):
self.on_down(host, is_host_addition, force_if_down=True)
return is_down
def add_host(self, address, signal):
def add_host(self, address, datacenter=None, rack=None, signal=True):
"""
Called when adding initial contact points and when the control
connection subsequently discovers a new node. Intended for internal
use only.
"""
new_host = self.metadata.add_host(address)
new_host = self.metadata.add_host(address, datacenter, rack)
if new_host and signal:
log.info("New Cassandra host %s added", address)
self.on_add(new_host)
@@ -947,10 +953,10 @@ class Session(object):
default_fetch_size = 5000
"""
By default, this many rows will be fetched at a time. This can be
specified per-query through :attr:`~Statement.fetch_size`.
specified per-query through :attr:`.Statement.fetch_size`.
This only takes effect when protocol version 2 or higher is used.
See :attr:`~Cluster.protocol_version` for details.
See :attr:`.Cluster.protocol_version` for details.
"""
_lock = None
@@ -1293,7 +1299,7 @@ class Session(object):
Set the default keyspace for all queries made through this Session.
This operation blocks until complete.
"""
self.execute('USE "%s"' % (keyspace,))
self.execute('USE %s' % (protect_name(keyspace),))
def _set_keyspace_for_all_pools(self, keyspace, callback):
"""
@@ -1602,7 +1608,9 @@ class ControlConnection(object):
host = self._cluster.metadata.get_host(connection.host)
if host:
host.set_location_info(local_row["data_center"], local_row["rack"])
datacenter = local_row.get("data_center")
rack = local_row.get("rack")
self._update_location_info(host, datacenter, rack)
partitioner = local_row.get("partitioner")
tokens = local_row.get("tokens")
@@ -1620,10 +1628,13 @@ class ControlConnection(object):
found_hosts.add(addr)
host = self._cluster.metadata.get_host(addr)
datacenter = row.get("data_center")
rack = row.get("rack")
if host is None:
log.debug("[control connection] Found new host to connect to: %s", addr)
host = self._cluster.add_host(addr, signal=True)
host.set_location_info(row.get("data_center"), row.get("rack"))
host = self._cluster.add_host(addr, datacenter, rack, signal=True)
else:
self._update_location_info(host, datacenter, rack)
tokens = row.get("tokens")
if partitioner and tokens:
@@ -1640,11 +1651,22 @@ class ControlConnection(object):
log.debug("[control connection] Fetched ring info, rebuilding metadata")
self._cluster.metadata.rebuild_token_map(partitioner, token_map)
def _update_location_info(self, host, datacenter, rack):
if host.datacenter == datacenter and host.rack == rack:
return
# If the dc/rack information changes, we need to update the load balancing policy.
# For that, we remove and re-add the node against the policy. Not the most elegant, and assumes
# that the policy will update correctly, but in practice this should work.
self._cluster.load_balancing_policy.on_down(host)
host.set_location_info(datacenter, rack)
self._cluster.load_balancing_policy.on_up(host)
def _handle_topology_change(self, event):
change_type = event["change_type"]
addr, port = event["address"]
if change_type == "NEW_NODE":
self._cluster.scheduler.schedule(10, self._cluster.add_host, addr, signal=True)
self._cluster.scheduler.schedule(10, self.refresh_node_list_and_token_map)
elif change_type == "REMOVED_NODE":
host = self._cluster.metadata.get_host(addr)
self._cluster.scheduler.schedule(0, self._cluster.remove_host, host)
@@ -1658,7 +1680,7 @@ class ControlConnection(object):
if change_type == "UP":
if host is None:
# this is the first time we've seen the node
self._cluster.scheduler.schedule(1, self._cluster.add_host, addr, signal=True)
self._cluster.scheduler.schedule(1, self.refresh_node_list_and_token_map)
else:
# this will be run by the scheduler
self._cluster.scheduler.schedule(1, self._cluster.on_up, host)
@@ -1897,23 +1919,19 @@ class ResponseFuture(object):
self.default_timeout = default_timeout
self._metrics = metrics
self.prepared_statement = prepared_statement
self._callback_lock = Lock()
if metrics is not None:
self._start_time = time.time()
# convert the list/generator/etc to an iterator so that subsequent
# calls to send_request (which retries may do) will resume where
# they last left off
self.query_plan = iter(session._load_balancer.make_query_plan(
session.keyspace, query))
self._make_query_plan()
self._event = Event()
self._errors = {}
def __del__(self):
try:
del self.session
except AttributeError:
pass
def _make_query_plan(self):
# convert the list/generator/etc to an iterator so that subsequent
# calls to send_request (which retries may do) will resume where
# they last left off
self.query_plan = iter(self.session._load_balancer.make_query_plan(
self.session.keyspace, self.query))
def send_request(self):
""" Internal """
@@ -1951,6 +1969,8 @@ class ResponseFuture(object):
except Exception as exc:
log.debug("Error querying host %s", host, exc_info=True)
self._errors[host] = exc
if self._metrics is not None:
self._metrics.on_connection_error()
if connection:
pool.return_connection(connection)
return None
@@ -1960,6 +1980,33 @@ class ResponseFuture(object):
self._connection = connection
return request_id
@property
def has_more_pages(self):
"""
Returns :const:`True` if there are more pages left in the
query results, :const:`False` otherwise. This should only
be checked after the first page has been returned.
"""
return self._paging_state is not None
def start_fetching_next_page(self):
"""
If there are more pages left in the query result, this asynchronously
starts fetching the next page. If there are no pages left, :exc:`.QueryExhausted`
is raised. Also see :attr:`.has_more_pages`.
This should only be called after the first page has been returned.
"""
if not self._paging_state:
raise QueryExhausted()
self._make_query_plan()
self.message.paging_state = self._paging_state
self._event.clear()
self._final_result = _NOT_SET
self._final_exception = None
self.send_request()
def _reprepare(self, prepare_message):
cb = partial(self.session.submit, self._execute_after_prepare)
request_id = self._query(self._current_host, prepare_message, cb=cb)
@@ -2163,12 +2210,10 @@ class ResponseFuture(object):
def _set_final_result(self, response):
if self._metrics is not None:
self._metrics.request_timer.addValue(time.time() - self._start_time)
if hasattr(self, 'session'):
try:
del self.session # clear reference cycles
except AttributeError:
pass
self._final_result = response
with self._callback_lock:
self._final_result = response
self._event.set()
if self._callback:
fn, args, kwargs = self._callback
@@ -2177,11 +2222,9 @@ class ResponseFuture(object):
def _set_final_exception(self, response):
if self._metrics is not None:
self._metrics.request_timer.addValue(time.time() - self._start_time)
try:
del self.session # clear reference cycles
except AttributeError:
pass
self._final_exception = response
with self._callback_lock:
self._final_exception = response
self._event.set()
if self._errback:
fn, args, kwargs = self._errback
@@ -2242,13 +2285,19 @@ class ResponseFuture(object):
timeout = self.default_timeout
if self._final_result is not _NOT_SET:
return self._final_result
if self._paging_state is None:
return self._final_result
else:
return PagedResult(self, self._final_result)
elif self._final_exception:
raise self._final_exception
else:
self._event.wait(timeout=timeout)
if self._final_result is not _NOT_SET:
return self._final_result
if self._paging_state is None:
return self._final_result
else:
return PagedResult(self, self._final_result)
elif self._final_exception:
raise self._final_exception
else:
@@ -2301,10 +2350,14 @@ class ResponseFuture(object):
>>> future.add_callback(handle_results, time.time(), should_log=True)
"""
if self._final_result is not _NOT_SET:
run_now = False
with self._callback_lock:
if self._final_result is not _NOT_SET:
run_now = True
else:
self._callback = (fn, args, kwargs)
if run_now:
fn(self._final_result, *args, **kwargs)
else:
self._callback = (fn, args, kwargs)
return self
def add_errback(self, fn, *args, **kwargs):
@@ -2313,10 +2366,14 @@ class ResponseFuture(object):
An Exception instance will be passed as the first positional argument
to `fn`.
"""
if self._final_exception:
run_now = False
with self._callback_lock:
if self._final_exception:
run_now = True
else:
self._errback = (fn, args, kwargs)
if run_now:
fn(self._final_exception, *args, **kwargs)
else:
self._errback = (fn, args, kwargs)
return self
def add_callbacks(self, callback, errback,
@@ -2352,3 +2409,58 @@ class ResponseFuture(object):
return "<ResponseFuture: query='%s' request_id=%s result=%s exception=%s host=%s>" \
% (self.query, self._req_id, result, self._final_exception, self._current_host)
__repr__ = __str__
class QueryExhausted(Exception):
    """
    Raised by :meth:`.ResultSet.start_fetching_next_page()` when no pages
    remain to be fetched.  Callers can avoid this by checking
    :attr:`.ResultSet.has_more_pages` first.
    """
class PagedResult(object):
    """
    Iterates over the rows of a paged query result.  An instance of this
    class is returned whenever a query produces more rows than the
    statement's :attr:`~.query.Statement.fetch_size` (falling back to
    :attr:`~.Session.default_fetch_size`).

    It behaves like an ordinary row iterator::

        >>> from cassandra.query import SimpleStatement
        >>> statement = SimpleStatement("SELECT * FROM users", fetch_size=10)
        >>> for user_row in session.execute(statement):
        ...     process_user(user_row)

    When the current page is exhausted, the following page is fetched
    transparently.  Note that fetching a page may raise an
    :class:`Exception`, exactly as a direct ``session.execute()`` call can.
    """

    def __init__(self, response_future, initial_response):
        self.response_future = response_future
        self.current_response = iter(initial_response)

    def __iter__(self):
        # this object is its own iterator
        return self

    def next(self):
        """Return the next row, transparently fetching the next page when needed."""
        try:
            return next(self.current_response)
        except StopIteration:
            # current page is exhausted; if there is no further paging
            # state, the result set is truly finished
            if self.response_future._paging_state is None:
                raise

        # more pages remain: fetch the next one and resume iteration from it
        future = self.response_future
        future.start_fetching_next_page()
        result = future.result()
        if future.has_more_pages:
            # result is itself a PagedResult; adopt its row iterator
            self.current_response = result.current_response
        else:
            # final page: result is a plain row sequence
            self.current_response = iter(result)

        return next(self.current_response)

    __next__ = next

148
cassandra/concurrent.py Normal file
View File

@@ -0,0 +1,148 @@
import sys
from itertools import count, cycle
from threading import Event
def execute_concurrent(session, statements_and_parameters, concurrency=100, raise_on_first_error=True):
    """
    Executes a sequence of (statement, parameters) tuples concurrently.  Each
    ``parameters`` item must be a sequence or :const:`None`.

    A sequence of ``(success, result_or_exc)`` tuples is returned in the same
    order that the statements were passed in.  If ``success`` is :const:`False`,
    there was an error executing the statement, and ``result_or_exc`` will be
    an :class:`Exception`.  If ``success`` is :const:`True`, ``result_or_exc``
    will be the query result.

    If `raise_on_first_error` is left as :const:`True`, execution will stop
    after the first failed statement and the corresponding exception will be
    raised.

    The `concurrency` parameter controls how many statements will be executed
    concurrently.  It is recommended that this be kept below the number of
    core connections per host times the number of connected hosts (see
    :meth:`.Cluster.set_core_connections_per_host`).  If that amount is exceeded,
    the event loop thread may attempt to block on new connection creation,
    substantially impacting throughput.

    Example usage::

        select_statement = session.prepare("SELECT * FROM users WHERE id=?")

        statements_and_params = []
        for user_id in user_ids:
            statements_and_params.append(
                (select_statement, user_id))

        results = execute_concurrent(
            session, statements_and_params, raise_on_first_error=False)

        for (success, result) in results:
            if not success:
                handle_error(result)  # result will be an Exception
            else:
                process_user(result[0])  # result will be a list of rows
    """
    if concurrency <= 0:
        raise ValueError("concurrency must be greater than 0")

    if not statements_and_parameters:
        return []

    # Shared state threaded through the callback chain: `event` signals
    # overall completion, `first_error` (when raise_on_first_error) holds the
    # first failure, `results` is filled in by index, and `num_finished` is a
    # thread-safe counter of completed statements.
    event = Event()
    first_error = [] if raise_on_first_error else None
    to_execute = len(statements_and_parameters)  # TODO handle iterators/generators
    results = [None] * to_execute
    num_finished = count(start=1)
    statements = enumerate(iter(statements_and_parameters))
    # Prime the pump: start up to `concurrency` statements.  Each completion
    # callback (_execute_next/_handle_error) starts the next pending one.
    # (range instead of the Python-2-only xrange.)
    for i in range(min(concurrency, to_execute)):
        _execute_next(_sentinel, i, event, session, statements, results, num_finished, to_execute, first_error)

    event.wait()
    if first_error:
        raise first_error[0]
    else:
        return results
def execute_concurrent_with_args(session, statement, parameters, *args, **kwargs):
    """
    Like :meth:`~.execute_concurrent`, but takes a single statement and a
    sequence of parameters.  Each item in ``parameters`` should be a sequence
    or :const:`None`.

    Example usage::

        statement = session.prepare("INSERT INTO mytable (a, b) VALUES (1, ?)")
        parameters = [(x,) for x in range(1000)]
        execute_concurrent_with_args(session, statement, parameters)
    """
    # Materialize a list rather than using zip(cycle((statement,)), ...):
    # on Python 3, zip() is lazy and has no len(), which execute_concurrent
    # relies on to size its results list.
    statements_and_parameters = [(statement, params) for params in parameters]
    return execute_concurrent(session, statements_and_parameters, *args, **kwargs)
_sentinel = object()
def _handle_error(error, result_index, event, session, statements, results, num_finished, to_execute, first_error):
    """
    Errback for a single statement: record the failure and, when not in
    raise_on_first_error mode, start the next pending statement.  Mirrors
    _execute_next(); see execute_concurrent() for the shared-state arguments.
    """
    if first_error is not None:
        # raise_on_first_error mode: stash the error and stop everything
        first_error.append(error)
        event.set()
        return
    else:
        results[result_index] = (False, error)
        # next() instead of the Python-2-only .next() method
        if next(num_finished) >= to_execute:
            event.set()
            return

    try:
        (next_index, (statement, params)) = next(statements)
    except StopIteration:
        # nothing left to start; other in-flight callbacks will finish up
        return

    args = (next_index, event, session, statements, results, num_finished, to_execute, first_error)
    try:
        session.execute_async(statement, params).add_callbacks(
            callback=_execute_next, callback_args=args,
            errback=_handle_error, errback_args=args)
    except Exception as exc:
        if first_error is not None:
            # defensive only: first_error is always None on this path (we
            # returned above otherwise).  Append the exception itself so
            # `raise first_error[0]` raises it (a sys.exc_info() tuple is
            # not raisable on Python 3).
            first_error.append(exc)
            event.set()
            return
        else:
            results[next_index] = (False, exc)
            if next(num_finished) >= to_execute:
                event.set()
                return
def _execute_next(result, result_index, event, session, statements, results, num_finished, to_execute, first_error):
    """
    Success callback for a single statement (and the kickoff entry point,
    invoked with the _sentinel marker): record the result, then start the
    next pending statement.  See execute_concurrent() for the shared-state
    arguments.
    """
    if result is not _sentinel:
        # a real completion (not the initial kickoff call): record it
        results[result_index] = (True, result)
        # next() instead of the Python-2-only .next() method
        finished = next(num_finished)
        if finished >= to_execute:
            event.set()
            return

    try:
        (next_index, (statement, params)) = next(statements)
    except StopIteration:
        # nothing left to start; other in-flight callbacks will finish up
        return

    args = (next_index, event, session, statements, results, num_finished, to_execute, first_error)
    try:
        session.execute_async(statement, params).add_callbacks(
            callback=_execute_next, callback_args=args,
            errback=_handle_error, errback_args=args)
    except Exception as exc:
        if first_error is not None:
            # Append the exception itself so `raise first_error[0]` raises
            # it (a sys.exc_info() tuple is not raisable on Python 3).
            first_error.append(exc)
            event.set()
            return
        else:
            results[next_index] = (False, exc)
            if next(num_finished) >= to_execute:
                event.set()
                return

View File

@@ -1,12 +1,20 @@
import errno
from functools import wraps, partial
import logging
import sys
from threading import Event, RLock
import time
import traceback
from six.moves.queue import Queue
if 'gevent.monkey' in sys.modules:
from gevent.queue import Queue, Empty
else:
from six.moves.queue import Queue, Empty # noqa
from six.moves import range
from cassandra import ConsistencyLevel, AuthenticationFailed, OperationTimedOut
from cassandra.marshal import int8_unpack, int32_pack, header_unpack
from cassandra.marshal import int32_pack, header_unpack
from cassandra.decoder import (ReadyMessage, AuthenticateMessage, OptionsMessage,
StartupMessage, ErrorMessage, CredentialsMessage,
QueryMessage, ResultMessage, decode_response,
@@ -151,16 +159,99 @@ class Connection(object):
raise NotImplementedError()
def defunct(self, exc):
raise NotImplementedError()
with self.lock:
if self.is_defunct or self.is_closed:
return
self.is_defunct = True
def send_msg(self, msg, cb):
raise NotImplementedError()
trace = traceback.format_exc(exc)
if trace != "None":
log.debug("Defuncting connection (%s) to %s: %s\n%s",
id(self), self.host, exc, traceback.format_exc(exc))
else:
log.debug("Defuncting connection (%s) to %s: %s", id(self), self.host, exc)
def wait_for_response(self, msg, **kwargs):
raise NotImplementedError()
self.last_error = exc
self.close()
self.error_all_callbacks(exc)
self.connected_event.set()
return exc
def error_all_callbacks(self, exc):
with self.lock:
callbacks = self._callbacks
self._callbacks = {}
new_exc = ConnectionShutdown(str(exc))
for cb in callbacks.values():
try:
cb(new_exc)
except Exception:
log.warn("Ignoring unhandled exception while erroring callbacks for a "
"failed connection (%s) to host %s:",
id(self), self.host, exc_info=True)
def handle_pushed(self, response):
log.debug("Message pushed from server: %r", response)
for cb in self._push_watchers.get(response.event_type, []):
try:
cb(response.event_args)
except Exception:
log.exception("Pushed event handler errored, ignoring:")
def send_msg(self, msg, cb, wait_for_id=False):
if self.is_defunct:
raise ConnectionShutdown("Connection to %s is defunct" % self.host)
elif self.is_closed:
raise ConnectionShutdown("Connection to %s is closed" % self.host)
if not wait_for_id:
try:
request_id = self._id_queue.get_nowait()
except Empty:
raise ConnectionBusy(
"Connection to %s is at the max number of requests" % self.host)
else:
request_id = self._id_queue.get()
self._callbacks[request_id] = cb
self.push(msg.to_binary(request_id, self.protocol_version, compression=self.compressor))
return request_id
def wait_for_response(self, msg, timeout=None):
return self.wait_for_responses(msg, timeout=timeout)[0]
def wait_for_responses(self, *msgs, **kwargs):
raise NotImplementedError()
timeout = kwargs.get('timeout')
waiter = ResponseWaiter(self, len(msgs))
# busy wait for sufficient space on the connection
messages_sent = 0
while True:
needed = len(msgs) - messages_sent
with self.lock:
available = min(needed, MAX_STREAM_PER_CONNECTION - self.in_flight)
self.in_flight += available
for i in range(messages_sent, messages_sent + available):
self.send_msg(msgs[i], partial(waiter.got_response, index=i), wait_for_id=True)
messages_sent += available
if messages_sent == len(msgs):
break
else:
if timeout is not None:
timeout -= 0.01
if timeout <= 0.0:
raise OperationTimedOut()
time.sleep(0.01)
try:
return waiter.deliver(timeout)
except OperationTimedOut:
raise
except Exception, exc:
self.defunct(exc)
raise
def register_watcher(self, event_type, callback):
raise NotImplementedError()

View File

@@ -1,16 +1,13 @@
from collections import defaultdict, deque
from functools import partial
import logging
import os
import socket
import sys
from threading import Event, Lock, Thread
import time
import traceback
from six import BytesIO
from six.moves import queue as Queue
from six.moves import xrange
from six.moves import range
from errno import EALREADY, EINPROGRESS, EWOULDBLOCK, EINVAL, EISCONN, errorcode
import asyncore
@@ -21,9 +18,8 @@ except ImportError:
ssl = None # NOQA
from cassandra import OperationTimedOut
from cassandra.connection import (Connection, ResponseWaiter, ConnectionShutdown,
ConnectionBusy, ConnectionException, NONBLOCKING,
MAX_STREAM_PER_CONNECTION)
from cassandra.connection import (Connection, ConnectionShutdown,
ConnectionException, NONBLOCKING)
from cassandra.decoder import RegisterMessage
from cassandra.marshal import int32_unpack
@@ -172,44 +168,11 @@ class AsyncoreConnection(Connection, asyncore.dispatcher):
_starting_conns.discard(self)
if not self.is_defunct:
self._error_all_callbacks(
self.error_all_callbacks(
ConnectionShutdown("Connection to %s was closed" % self.host))
# don't leave in-progress operations hanging
self.connected_event.set()
def defunct(self, exc):
with self.lock:
if self.is_defunct or self.is_closed:
return
self.is_defunct = True
trace = traceback.format_exc() #exc)
if trace != "None":
log.debug("Defuncting connection (%s) to %s: %s\n%s",
id(self), self.host, exc, traceback.format_exc())
else:
log.debug("Defuncting connection (%s) to %s: %s",
id(self), self.host, exc)
self.last_error = exc
self.close()
self._error_all_callbacks(exc)
self.connected_event.set()
return exc
def _error_all_callbacks(self, exc):
with self.lock:
callbacks = self._callbacks
self._callbacks = {}
new_exc = ConnectionShutdown(str(exc))
for cb in callbacks.values():
try:
cb(new_exc)
except Exception:
log.warn("Ignoring unhandled exception while erroring callbacks for a "
"failed connection (%s) to host %s:",
id(self), self.host, exc_info=True)
def handle_connect(self):
with _starting_conns_lock:
_starting_conns.discard(self)
@@ -300,19 +263,11 @@ class AsyncoreConnection(Connection, asyncore.dispatcher):
if not self._callbacks:
self._readable = False
def handle_pushed(self, response):
log.debug("Message pushed from server: %r", response)
for cb in self._push_watchers.get(response.event_type, []):
try:
cb(response.event_args)
except Exception:
log.exception("Pushed event handler errored, ignoring:")
def push(self, data):
sabs = self.out_buffer_size
if len(data) > sabs:
chunks = []
for i in xrange(0, len(data), sabs):
for i in range(0, len(data), sabs):
chunks.append(data[i:i + sabs])
else:
chunks = [data]
@@ -328,61 +283,6 @@ class AsyncoreConnection(Connection, asyncore.dispatcher):
def readable(self):
return self._readable or (self._have_listeners and not (self.is_defunct or self.is_closed))
def send_msg(self, msg, cb, wait_for_id=False):
if self.is_defunct:
raise ConnectionShutdown("Connection to %s is defunct" % self.host)
elif self.is_closed:
raise ConnectionShutdown("Connection to %s is closed" % self.host)
if not wait_for_id:
try:
request_id = self._id_queue.get_nowait()
except Queue.Empty:
raise ConnectionBusy(
"Connection to %s is at the max number of requests" % self.host)
else:
request_id = self._id_queue.get()
self._callbacks[request_id] = cb
self.push(msg.to_binary(request_id, self.protocol_version, compression=self.compressor))
return request_id
def wait_for_response(self, msg, timeout=None):
return self.wait_for_responses(msg, timeout=timeout)[0]
def wait_for_responses(self, *msgs, **kwargs):
timeout = kwargs.get('timeout')
waiter = ResponseWaiter(self, len(msgs))
# busy wait for sufficient space on the connection
messages_sent = 0
while True:
needed = len(msgs) - messages_sent
with self.lock:
available = min(needed, MAX_STREAM_PER_CONNECTION - self.in_flight)
self.in_flight += available
for i in range(messages_sent, messages_sent + available):
self.send_msg(msgs[i], partial(waiter.got_response, index=i), wait_for_id=True)
messages_sent += available
if messages_sent == len(msgs):
break
else:
if timeout is not None:
timeout -= 0.01
if timeout <= 0.0:
raise OperationTimedOut()
time.sleep(0.01)
try:
return waiter.deliver(timeout)
except OperationTimedOut:
raise
except Exception as exc:
self.defunct(exc)
raise
def register_watcher(self, event_type, callback):
self._push_watchers[event_type].add(callback)
self._have_listeners = True

View File

@@ -0,0 +1,190 @@
import gevent
from gevent import select, socket
from gevent.event import Event
from gevent.queue import Queue
from collections import defaultdict
from functools import partial
import logging
import os
try:
from cStringIO import StringIO
except ImportError:
from StringIO import StringIO # ignore flake8 warning: # NOQA
from errno import EALREADY, EINPROGRESS, EWOULDBLOCK, EINVAL
from cassandra import OperationTimedOut
from cassandra.connection import Connection, ConnectionShutdown
from cassandra.decoder import RegisterMessage
from cassandra.marshal import int32_unpack
log = logging.getLogger(__name__)
def is_timeout(err):
    """Return True if errno ``err`` indicates an in-progress/retryable socket state."""
    if err in (EINPROGRESS, EALREADY, EWOULDBLOCK):
        return True
    # Windows ('nt'/'ce') reports EINVAL for a connect() already in progress
    return err == EINVAL and os.name in ('nt', 'ce')
class GeventConnection(Connection):
"""
An implementation of :class:`.Connection` that utilizes ``gevent``.
"""
_total_reqd_bytes = 0
_read_watcher = None
_write_watcher = None
_socket = None
@classmethod
def factory(cls, *args, **kwargs):
timeout = kwargs.pop('timeout', 5.0)
conn = cls(*args, **kwargs)
conn.connected_event.wait(timeout)
if conn.last_error:
raise conn.last_error
elif not conn.connected_event.is_set():
conn.close()
raise OperationTimedOut("Timed out creating connection")
else:
return conn
def __init__(self, *args, **kwargs):
Connection.__init__(self, *args, **kwargs)
self.connected_event = Event()
self._iobuf = StringIO()
self._write_queue = Queue()
self._callbacks = {}
self._push_watchers = defaultdict(set)
self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self._socket.settimeout(1.0)
self._socket.connect((self.host, self.port))
if self.sockopts:
for args in self.sockopts:
self._socket.setsockopt(*args)
self._read_watcher = gevent.spawn(lambda: self.handle_read())
self._write_watcher = gevent.spawn(lambda: self.handle_write())
self._send_options_message()
def close(self):
with self.lock:
if self.is_closed:
return
self.is_closed = True
log.debug("Closing connection (%s) to %s" % (id(self), self.host))
if self._read_watcher:
self._read_watcher.kill()
if self._write_watcher:
self._write_watcher.kill()
if self._socket:
self._socket.close()
log.debug("Closed socket to %s" % (self.host,))
if not self.is_defunct:
self.error_all_callbacks(
ConnectionShutdown("Connection to %s was closed" % self.host))
# don't leave in-progress operations hanging
self.connected_event.set()
def handle_close(self):
log.debug("connection closed by server")
self.close()
def handle_write(self):
run_select = partial(select.select, (), (self._socket,), ())
while True:
try:
next_msg = self._write_queue.get()
run_select()
except Exception as exc:
log.debug("Exception during write select() for %s: %s", self, exc)
self.defunct(exc)
return
try:
self._socket.sendall(next_msg)
except socket.error as err:
log.debug("Exception during socket sendall for %s: %s", self, err)
self.defunct(err)
return # Leave the write loop
def handle_read(self):
run_select = partial(select.select, (self._socket,), (), ())
while True:
try:
run_select()
except Exception as exc:
log.debug("Exception during read select() for %s: %s", self, exc)
self.defunct(exc)
return
try:
buf = self._socket.recv(self.in_buffer_size)
self._iobuf.write(buf)
except socket.error as err:
if not is_timeout(err):
log.debug("Exception during socket recv for %s: %s", self, err)
self.defunct(err)
return # leave the read loop
if self._iobuf.tell():
while True:
pos = self._iobuf.tell()
if pos < 8 or (self._total_reqd_bytes > 0 and pos < self._total_reqd_bytes):
# we don't have a complete header yet or we
# already saw a header, but we don't have a
# complete message yet
break
else:
# have enough for header, read body len from header
self._iobuf.seek(4)
body_len = int32_unpack(self._iobuf.read(4))
# seek to end to get length of current buffer
self._iobuf.seek(0, os.SEEK_END)
pos = self._iobuf.tell()
if pos >= body_len + 8:
# read message header and body
self._iobuf.seek(0)
msg = self._iobuf.read(8 + body_len)
# leave leftover in current buffer
leftover = self._iobuf.read()
self._iobuf = StringIO()
self._iobuf.write(leftover)
self._total_reqd_bytes = 0
self.process_msg(msg, body_len)
else:
self._total_reqd_bytes = body_len + 8
break
else:
log.debug("connection closed by server")
self.close()
return
def push(self, data):
    """Queue ``data`` for sending, split into out_buffer_size chunks."""
    limit = self.out_buffer_size
    enqueue = self._write_queue.put
    # Hand the writer loop one bounded chunk at a time.
    for start in xrange(0, len(data), limit):
        enqueue(data[start:start + limit])
def register_watcher(self, event_type, callback):
    """Subscribe ``callback`` to server-pushed events of ``event_type``.

    Blocks until the server acknowledges the registration.
    """
    # Record the watcher first so no event can slip through between
    # the server's ack and our bookkeeping.
    self._push_watchers[event_type].add(callback)
    register = RegisterMessage(event_list=[event_type])
    self.wait_for_response(register)
def register_watchers(self, type_callback_dict):
    """Register several push-event callbacks with one server round trip.

    ``type_callback_dict`` maps event type to the callback to invoke
    when that event is pushed by the server.
    """
    # Record all watchers first, then cover every requested event type
    # with a single RegisterMessage.
    for event_type, callback in type_callback_dict.items():
        self._push_watchers[event_type].add(callback)
    self.wait_for_response(
        RegisterMessage(event_list=type_callback_dict.keys()))

View File

@@ -1,20 +1,13 @@
from collections import defaultdict, deque
from functools import partial, wraps
import logging
import os
import socket
from threading import Event, Lock, Thread
import time
import traceback
from six.moves.queue import Queue
from six.moves import cStringIO as StringIO
from six.moves import xrange
from six import BytesIO
from cassandra import OperationTimedOut
from cassandra.connection import (Connection, ResponseWaiter, ConnectionShutdown,
ConnectionBusy, NONBLOCKING,
MAX_STREAM_PER_CONNECTION)
from cassandra.connection import Connection, ConnectionShutdown, NONBLOCKING
from cassandra.decoder import RegisterMessage
from cassandra.marshal import int32_unpack
try:
@@ -80,18 +73,6 @@ def _start_loop():
return should_start
def defunct_on_error(f):
@wraps(f)
def wrapper(self, *args, **kwargs):
try:
return f(self, *args, **kwargs)
except Exception as exc:
self.defunct(exc)
return wrapper
class LibevConnection(Connection):
"""
An implementation of :class:`.Connection` that uses libev for its event loop.
@@ -192,7 +173,7 @@ class LibevConnection(Connection):
Connection.__init__(self, *args, **kwargs)
self.connected_event = Event()
self._iobuf = StringIO()
self._iobuf = BytesIO()
self._callbacks = {}
self._push_watchers = defaultdict(set)
@@ -237,41 +218,9 @@ class LibevConnection(Connection):
# don't leave in-progress operations hanging
if not self.is_defunct:
self._error_all_callbacks(
self.error_all_callbacks(
ConnectionShutdown("Connection to %s was closed" % self.host))
def defunct(self, exc):
with self.lock:
if self.is_defunct or self.is_closed:
return
self.is_defunct = True
trace = traceback.format_exc(exc)
if trace != "None":
log.debug("Defuncting connection (%s) to %s: %s\n%s",
id(self), self.host, exc, traceback.format_exc(exc))
else:
log.debug("Defuncting connection (%s) to %s: %s", id(self), self.host, exc)
self.last_error = exc
self.close()
self._error_all_callbacks(exc)
self.connected_event.set()
return exc
def _error_all_callbacks(self, exc):
with self.lock:
callbacks = self._callbacks
self._callbacks = {}
new_exc = ConnectionShutdown(str(exc))
for cb in callbacks.values():
try:
cb(new_exc)
except Exception:
log.warn("Ignoring unhandled exception while erroring callbacks for a "
"failed connection (%s) to host %s:",
id(self), self.host, exc_info=True)
def handle_write(self, watcher, revents, errno=None):
if revents & libev.EV_ERROR:
if errno:
@@ -351,7 +300,7 @@ class LibevConnection(Connection):
# leave leftover in current buffer
leftover = self._iobuf.read()
self._iobuf = StringIO()
self._iobuf = BytesIO()
self._iobuf.write(leftover)
self._total_reqd_bytes = 0
@@ -363,14 +312,6 @@ class LibevConnection(Connection):
log.debug("Connection %s closed by server", self)
self.close()
def handle_pushed(self, response):
log.debug("Message pushed from server: %r", response)
for cb in self._push_watchers.get(response.event_type, []):
try:
cb(response.event_args)
except Exception:
log.exception("Pushed event handler errored, ignoring:")
def push(self, data):
sabs = self.out_buffer_size
if len(data) > sabs:
@@ -384,61 +325,6 @@ class LibevConnection(Connection):
self.deque.extend(chunks)
_loop_notifier.send()
def send_msg(self, msg, cb, wait_for_id=False):
if self.is_defunct:
raise ConnectionShutdown("Connection to %s is defunct" % self.host)
elif self.is_closed:
raise ConnectionShutdown("Connection to %s is closed" % self.host)
if not wait_for_id:
try:
request_id = self._id_queue.get_nowait()
except Queue.Empty:
raise ConnectionBusy(
"Connection to %s is at the max number of requests" % self.host)
else:
request_id = self._id_queue.get()
self._callbacks[request_id] = cb
self.push(msg.to_binary(request_id, self.protocol_version, compression=self.compressor))
return request_id
def wait_for_response(self, msg, timeout=None):
return self.wait_for_responses(msg, timeout=timeout)[0]
def wait_for_responses(self, *msgs, **kwargs):
timeout = kwargs.get('timeout')
waiter = ResponseWaiter(self, len(msgs))
# busy wait for sufficient space on the connection
messages_sent = 0
while True:
needed = len(msgs) - messages_sent
with self.lock:
available = min(needed, MAX_STREAM_PER_CONNECTION - self.in_flight)
self.in_flight += available
for i in range(messages_sent, messages_sent + available):
self.send_msg(msgs[i], partial(waiter.got_response, index=i), wait_for_id=True)
messages_sent += available
if messages_sent == len(msgs):
break
else:
if timeout is not None:
timeout -= 0.01
if timeout <= 0.0:
raise OperationTimedOut()
time.sleep(0.01)
try:
return waiter.deliver(timeout)
except OperationTimedOut:
raise
except Exception as exc:
self.defunct(exc)
raise
def register_watcher(self, event_type, callback):
self._push_watchers[event_type].add(callback)
self.wait_for_response(RegisterMessage(event_list=[event_type]))

View File

@@ -157,7 +157,7 @@ class Metadata(object):
if not cf_results:
# the table was removed
del keyspace_meta.tables[table]
keyspace_meta.tables.pop(table, None)
else:
assert len(cf_results) == 1
keyspace_meta.tables[table] = self._build_table_metadata(
@@ -346,11 +346,12 @@ class Metadata(object):
else:
return True
def add_host(self, address):
def add_host(self, address, datacenter, rack):
cluster = self.cluster_ref()
with self._hosts_lock:
if address not in self._hosts:
new_host = Host(address, cluster.conviction_policy_factory)
new_host = Host(
address, cluster.conviction_policy_factory, datacenter, rack)
self._hosts[address] = new_host
else:
return None

View File

@@ -137,6 +137,7 @@ class RoundRobinPolicy(LoadBalancingPolicy):
This load balancing policy is used by default.
"""
_live_hosts = frozenset(())
def populate(self, cluster, hosts):
self._live_hosts = frozenset(hosts)

View File

@@ -71,7 +71,7 @@ class Host(object):
_currently_handling_node_up = False
_handle_node_up_condition = None
def __init__(self, inet_address, conviction_policy_factory):
def __init__(self, inet_address, conviction_policy_factory, datacenter=None, rack=None):
if inet_address is None:
raise ValueError("inet_address may not be None")
if conviction_policy_factory is None:
@@ -79,6 +79,7 @@ class Host(object):
self.address = inet_address
self.conviction_policy = conviction_policy_factory(self)
self.set_location_info(datacenter, rack)
self.lock = RLock()
self._handle_node_up_condition = Condition()

View File

@@ -73,6 +73,13 @@ class Statement(object):
"""
fetch_size = None
"""
How many rows will be fetched at a time. This overrides the default
of :attr:`.Session.default_fetch_size`
This only takes effect when protocol version 2 or higher is used.
See :attr:`.Cluster.protocol_version` for details.
"""
_serial_consistency_level = None
_routing_key = None
@@ -561,6 +568,7 @@ class QueryTrace(object):
if max_wait is not None and time_spent >= max_wait:
raise TraceUnavailable("Trace information was not available within %f seconds" % (max_wait,))
log.debug("Attempting to fetch trace info for trace ID: %s", self.trace_id)
session_results = self._execute(
self._SELECT_SESSIONS_FORMAT, (self.trace_id,), time_spent, max_wait)
@@ -568,6 +576,7 @@ class QueryTrace(object):
time.sleep(self._BASE_RETRY_SLEEP * (2 ** attempt))
attempt += 1
continue
log.debug("Fetched trace info for trace ID: %s", self.trace_id)
session_row = session_results[0]
self.request_type = session_row.request
@@ -576,9 +585,11 @@ class QueryTrace(object):
self.coordinator = session_row.coordinator
self.parameters = session_row.parameters
log.debug("Attempting to fetch trace events for trace ID: %s", self.trace_id)
time_spent = time.time() - start
event_results = self._execute(
self._SELECT_EVENTS_FORMAT, (self.trace_id,), time_spent, max_wait)
log.debug("Fetched trace events for trace ID: %s", self.trace_id)
self.events = tuple(TraceEvent(r.activity, r.event_id, r.source, r.source_elapsed, r.thread)
for r in event_results)
break

View File

@@ -7,6 +7,8 @@
.. autoattribute:: cql_version
.. autoattribute:: protocol_version
.. autoattribute:: port
.. autoattribute:: compression
@@ -59,6 +61,8 @@
.. autoattribute:: row_factory
.. autoattribute:: default_fetch_size
.. automethod:: execute(statement[, parameters][, timeout][, trace])
.. automethod:: execute_async(statement[, parameters][, trace])
@@ -77,11 +81,20 @@
.. automethod:: get_query_trace()
.. autoattribute:: has_more_pages
.. automethod:: start_fetching_next_page()
.. automethod:: add_callback(fn, *args, **kwargs)
.. automethod:: add_errback(fn, *args, **kwargs)
.. automethod:: add_callbacks(callback, errback, callback_args=(), callback_kwargs=None, errback_args=(), errback_args=None)
.. autoclass:: PagedResult ()
:members:
.. autoexception:: QueryExhausted ()
.. autoexception:: NoHostAvailable ()
:members:

View File

@@ -8,6 +8,7 @@ Python Cassandra Driver
installation
getting_started
performance
query_paging
Indices and Tables
==================

74
docs/query_paging.rst Normal file
View File

@@ -0,0 +1,74 @@
Paging Large Queries
====================
Cassandra 2.0+ offers support for automatic query paging. Starting with
version 2.0 of the driver, if :attr:`~.Cluster.protocol_version` is set to
:const:`2` (it is by default), queries returning large result sets will be
automatically paged.
Controlling the Page Size
-------------------------
By default, :attr:`.Session.default_fetch_size` controls how many rows will
be fetched per page. This can be overridden per-query by setting
:attr:`~.fetch_size` on a :class:`~.Statement`. By default, each page
will contain at most 5000 rows.
Handling Paged Results
----------------------
Whenever the number of result rows for a query exceeds the page size, an
instance of :class:`~.PagedResult` will be returned instead of a normal
list. This class implements the iterator interface, so you can treat
it like a normal iterator over rows::
from cassandra.query import SimpleStatement
query = "SELECT * FROM users" # users contains 100 rows
statement = SimpleStatement(query, fetch_size=10)
for user_row in session.execute(statement):
process_user(user_row)
Whenever there are no more rows in the current page, the next page will
be fetched transparently. However, note that it *is* possible for
an :class:`Exception` to be raised while fetching the next page, just
like you might see on a normal call to ``session.execute()``.
If you use :meth:`.Session.execute_async()` along with
:meth:`.ResponseFuture.result()`, the first page will be fetched before
:meth:`~.ResponseFuture.result()` returns, but later pages will be
transparently fetched synchronously while iterating the result.
Handling Paged Results with Callbacks
-------------------------------------
If callbacks are attached to a query that returns a paged result,
the callback will be called once per page with a normal list of rows.
Use :attr:`.ResponseFuture.has_more_pages` and
:meth:`.ResponseFuture.start_fetching_next_page()` to continue fetching
pages. For example::
class PagedResultHandler(object):
def __init__(self, future):
self.error = None
self.finished_event = Event()
self.future = future
self.future.add_callbacks(
callback=self.handle_page,
errback=self.handle_error)
def handle_page(self, rows):
for row in rows:
process_row(row)
if self.future.has_more_pages:
self.future.start_fetching_next_page()
else:
self.finished_event.set()
def handle_error(self, exc):
self.error = exc
self.finished_event.set()
future = session.execute_async("SELECT * FROM users")
handler = PagedResultHandler(future)
handler.finished_event.wait()
if handler.error:
raise handler.error

View File

@@ -1,18 +1,13 @@
from __future__ import print_function
import platform
import os
import sys
import warnings
try:
import subprocess
has_subprocess = True
except ImportError:
has_subprocess = False
import ez_setup
ez_setup.use_setuptools()
if __name__ == '__main__' and sys.argv[1] == "gevent_nosetests":
from gevent.monkey import patch_all
patch_all()
from setuptools import setup
from distutils.command.build_ext import build_ext
from distutils.core import Extension
@@ -20,6 +15,19 @@ from distutils.errors import (CCompilerError, DistutilsPlatformError,
DistutilsExecError)
from distutils.cmd import Command
import platform
import os
import warnings
try:
import subprocess
has_subprocess = True
except ImportError:
has_subprocess = False
from nose.commands import nosetests
from cassandra import __version__
long_description = ""
@@ -27,6 +35,10 @@ with open("README.rst") as f:
long_description = f.read()
class gevent_nosetests(nosetests):
description = "run nosetests with gevent monkey patching"
class DocCommand(Command):
description = "generate or test documentation"
@@ -144,12 +156,12 @@ On OSX, via homebrew:
def run_setup(extensions):
kw = {'cmdclass': {'doc': DocCommand}}
kw = {'cmdclass': {'doc': DocCommand, 'gevent_nosetests': gevent_nosetests}}
if extensions:
kw['cmdclass']['build_ext'] = build_extensions
kw['ext_modules'] = extensions
dependencies = ['futures', 'scales', 'blist', 'six >=1.6']
dependencies = ['futures', 'scales >=1.0.5', 'blist', 'six >=1.6']
if platform.python_implementation() != "CPython":
dependencies.remove('blist')
@@ -164,7 +176,7 @@ def run_setup(extensions):
packages=['cassandra', 'cassandra.io'],
include_package_data=True,
install_requires=dependencies,
tests_require=['nose', 'mock', 'PyYAML'],
tests_require=['nose', 'mock', 'PyYAML', 'pytz'],
classifiers=[
'Development Status :: 5 - Production/Stable',
'Intended Audience :: Developers',

View File

@@ -18,7 +18,14 @@ except ImportError as e:
CLUSTER_NAME = 'test_cluster'
CCM_CLUSTER = None
DEFAULT_CASSANDRA_VERSION = '2.0.5'
CASSANDRA_VERSION = os.getenv('CASSANDRA_VERSION', '2.0.6')
if CASSANDRA_VERSION.startswith('1'):
DEFAULT_PROTOCOL_VERSION = 1
else:
DEFAULT_PROTOCOL_VERSION = 2
PROTOCOL_VERSION = int(os.getenv('PROTOCOL_VERSION', DEFAULT_PROTOCOL_VERSION))
path = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'ccm')
if not os.path.exists(path):
@@ -38,7 +45,7 @@ def get_server_versions():
if cass_version is not None:
return (cass_version, cql_version)
c = Cluster()
c = Cluster(protocol_version=PROTOCOL_VERSION)
s = c.connect()
s.set_keyspace('system')
row = s.execute('SELECT cql_version, release_version FROM local')[0]
@@ -67,16 +74,16 @@ def get_node(node_id):
def setup_package():
version = os.getenv("CASSANDRA_VERSION", DEFAULT_CASSANDRA_VERSION)
print 'Using Cassandra version: %s' % CASSANDRA_VERSION
try:
try:
cluster = CCMCluster.load(path, CLUSTER_NAME)
log.debug("Found existing ccm test cluster, clearing")
cluster.clear()
cluster.set_cassandra_dir(cassandra_version=version)
cluster.set_cassandra_dir(cassandra_version=CASSANDRA_VERSION)
except Exception:
log.debug("Creating new ccm test cluster with version %s", version)
cluster = CCMCluster(path, CLUSTER_NAME, cassandra_version=version)
log.debug("Creating new ccm test cluster with version %s", CASSANDRA_VERSION)
cluster = CCMCluster(path, CLUSTER_NAME, cassandra_version=CASSANDRA_VERSION)
cluster.set_configuration_options({'start_native_transport': True})
common.switch_cluster(path, CLUSTER_NAME)
cluster.populate(3)
@@ -93,7 +100,7 @@ def setup_package():
def setup_test_keyspace():
cluster = Cluster()
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect()
try:

View File

@@ -7,6 +7,7 @@ from cassandra.cluster import Cluster
from cassandra.policies import TokenAwarePolicy, RoundRobinPolicy, \
DowngradingConsistencyRetryPolicy
from cassandra.query import SimpleStatement
from tests.integration import PROTOCOL_VERSION
from tests.integration.long.utils import force_stop, create_schema, \
wait_for_down, wait_for_up, start, CoordinatorStats
@@ -98,7 +99,8 @@ class ConsistencyTests(unittest.TestCase):
def _test_tokenaware_one_node_down(self, keyspace, rf, accepted):
cluster = Cluster(
load_balancing_policy=TokenAwarePolicy(RoundRobinPolicy()))
load_balancing_policy=TokenAwarePolicy(RoundRobinPolicy()),
protocol_version=PROTOCOL_VERSION)
session = cluster.connect()
wait_for_up(cluster, 1, wait=False)
wait_for_up(cluster, 2)
@@ -147,7 +149,8 @@ class ConsistencyTests(unittest.TestCase):
def test_rfthree_tokenaware_none_down(self):
keyspace = 'test_rfthree_tokenaware_none_down'
cluster = Cluster(
load_balancing_policy=TokenAwarePolicy(RoundRobinPolicy()))
load_balancing_policy=TokenAwarePolicy(RoundRobinPolicy()),
protocol_version=PROTOCOL_VERSION)
session = cluster.connect()
wait_for_up(cluster, 1, wait=False)
wait_for_up(cluster, 2)
@@ -169,7 +172,8 @@ class ConsistencyTests(unittest.TestCase):
def _test_downgrading_cl(self, keyspace, rf, accepted):
cluster = Cluster(
load_balancing_policy=TokenAwarePolicy(RoundRobinPolicy()),
default_retry_policy=DowngradingConsistencyRetryPolicy())
default_retry_policy=DowngradingConsistencyRetryPolicy(),
protocol_version=PROTOCOL_VERSION)
session = cluster.connect()
create_schema(session, keyspace, replication_factor=rf)
@@ -210,14 +214,16 @@ class ConsistencyTests(unittest.TestCase):
keyspace = 'test_rfthree_roundrobin_downgradingcl'
cluster = Cluster(
load_balancing_policy=RoundRobinPolicy(),
default_retry_policy=DowngradingConsistencyRetryPolicy())
default_retry_policy=DowngradingConsistencyRetryPolicy(),
protocol_version=PROTOCOL_VERSION)
self.rfthree_downgradingcl(cluster, keyspace, True)
def test_rfthree_tokenaware_downgradingcl(self):
keyspace = 'test_rfthree_tokenaware_downgradingcl'
cluster = Cluster(
load_balancing_policy=TokenAwarePolicy(RoundRobinPolicy()),
default_retry_policy=DowngradingConsistencyRetryPolicy())
default_retry_policy=DowngradingConsistencyRetryPolicy(),
protocol_version=PROTOCOL_VERSION)
self.rfthree_downgradingcl(cluster, keyspace, False)
def rfthree_downgradingcl(self, cluster, keyspace, roundrobin):

View File

@@ -1,14 +1,15 @@
try:
from Queue import Queue, Empty
except ImportError:
from queue import Queue, Empty
from queue import Queue, Empty # noqa
from struct import pack
import unittest
from cassandra import ConsistencyLevel
from cassandra.cluster import Cluster
from cassandra.decoder import dict_factory
from cassandra.query import dict_factory
from cassandra.query import SimpleStatement
from tests.integration import PROTOCOL_VERSION
from tests.integration.long.utils import create_schema
@@ -32,7 +33,7 @@ class LargeDataTests(unittest.TestCase):
self.keyspace = 'large_data'
def make_session_and_keyspace(self):
cluster = Cluster()
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect()
session.default_timeout = 20.0 # increase the default timeout
session.row_factory = dict_factory
@@ -41,9 +42,10 @@ class LargeDataTests(unittest.TestCase):
return session
def batch_futures(self, session, statement_generator):
futures = Queue(maxsize=121)
concurrency = 50
futures = Queue.Queue(maxsize=concurrency)
for i, statement in enumerate(statement_generator):
if i > 0 and i % 120 == 0:
if i > 0 and i % (concurrency - 1) == 0:
# clear the existing queue
while True:
try:
@@ -70,7 +72,7 @@ class LargeDataTests(unittest.TestCase):
session,
(SimpleStatement('INSERT INTO %s (k, i) VALUES (0, %s)' % (table, i),
consistency_level=ConsistencyLevel.QUORUM)
for i in range(1000000)))
for i in range(100000)))
# Read
results = session.execute('SELECT i FROM %s WHERE k=%s' % (table, 0))
@@ -112,7 +114,7 @@ class LargeDataTests(unittest.TestCase):
session,
(SimpleStatement('INSERT INTO %s (k, i, v) VALUES (0, %s, %s)' % (table, i, str(bb)),
consistency_level=ConsistencyLevel.QUORUM)
for i in range(1000000)))
for i in range(100000)))
# Read
results = session.execute('SELECT i, v FROM %s WHERE k=%s' % (table, 0))

View File

@@ -3,6 +3,7 @@ import logging
from cassandra import ConsistencyLevel
from cassandra.cluster import Cluster
from cassandra.query import SimpleStatement
from tests.integration import PROTOCOL_VERSION
try:
import unittest2 as unittest
@@ -15,7 +16,7 @@ log = logging.getLogger(__name__)
class SchemaTests(unittest.TestCase):
def test_recreates(self):
cluster = Cluster()
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect()
replication_factor = 3

View File

@@ -4,7 +4,7 @@ import time
from collections import defaultdict
from cassandra.decoder import named_tuple_factory
from cassandra.query import named_tuple_factory
from tests.integration import get_node
@@ -74,7 +74,9 @@ def stop(node):
def force_stop(node):
log.debug("Forcing stop of node %s", node)
get_node(node).stop(wait=False, gently=False)
log.debug("Node %s was stopped", node)
def ring(node):
@@ -85,6 +87,7 @@ def ring(node):
def wait_for_up(cluster, node, wait=True):
while True:
host = cluster.metadata.get_host('127.0.0.%s' % node)
time.sleep(0.1)
if host and host.is_up:
# BUG: shouldn't have to, but we do
if wait:
@@ -93,10 +96,14 @@ def wait_for_up(cluster, node, wait=True):
def wait_for_down(cluster, node, wait=True):
log.debug("Waiting for node %s to be down", node)
while True:
host = cluster.metadata.get_host('127.0.0.%s' % node)
time.sleep(0.1)
if not host or not host.is_up:
# BUG: shouldn't have to, but we do
if wait:
log.debug("Sleeping 5s until host is up")
time.sleep(5)
log.debug("Done waiting for node %s to be down", node)
return

View File

@@ -1,3 +1,5 @@
from tests.integration import PROTOCOL_VERSION
try:
import unittest2 as unittest
except ImportError:
@@ -5,7 +7,6 @@ except ImportError:
import cassandra
from cassandra.query import SimpleStatement, TraceUnavailable
from cassandra.io.asyncorereactor import AsyncoreConnection
from cassandra.policies import RoundRobinPolicy, ExponentialReconnectionPolicy, RetryPolicy, SimpleConvictionPolicy, HostDistance
from cassandra.cluster import Cluster, NoHostAvailable
@@ -18,7 +19,7 @@ class ClusterTests(unittest.TestCase):
Test basic connection and usage
"""
cluster = Cluster()
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect()
result = session.execute(
"""
@@ -54,7 +55,7 @@ class ClusterTests(unittest.TestCase):
Ensure clusters that connect on a keyspace, do
"""
cluster = Cluster()
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect()
result = session.execute(
"""
@@ -71,7 +72,7 @@ class ClusterTests(unittest.TestCase):
self.assertEqual(result, result2)
def test_set_keyspace_twice(self):
cluster = Cluster()
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect()
session.execute("USE system")
session.execute("USE system")
@@ -86,7 +87,7 @@ class ClusterTests(unittest.TestCase):
reconnection_policy=ExponentialReconnectionPolicy(1.0, 600.0),
default_retry_policy=RetryPolicy(),
conviction_policy_factory=SimpleConvictionPolicy,
connection_class=AsyncoreConnection
protocol_version=PROTOCOL_VERSION
)
def test_double_shutdown(self):
@@ -94,7 +95,7 @@ class ClusterTests(unittest.TestCase):
Ensure that a cluster can be shutdown twice, without error
"""
cluster = Cluster()
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
cluster.shutdown()
try:
@@ -108,7 +109,7 @@ class ClusterTests(unittest.TestCase):
Ensure you cannot connect to a cluster that's been shutdown
"""
cluster = Cluster()
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
cluster.shutdown()
self.assertRaises(Exception, cluster.connect)
@@ -132,7 +133,8 @@ class ClusterTests(unittest.TestCase):
when a cluster cannot connect to given hosts
"""
cluster = Cluster(['127.1.2.9', '127.1.2.10'])
cluster = Cluster(['127.1.2.9', '127.1.2.10'],
protocol_version=PROTOCOL_VERSION)
self.assertRaises(NoHostAvailable, cluster.connect)
def test_cluster_settings(self):
@@ -140,7 +142,7 @@ class ClusterTests(unittest.TestCase):
Test connection setting getters and setters
"""
cluster = Cluster()
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
min_requests_per_connection = cluster.get_min_requests_per_connection(HostDistance.LOCAL)
self.assertEqual(cassandra.cluster.DEFAULT_MIN_REQUESTS, min_requests_per_connection)
@@ -167,11 +169,11 @@ class ClusterTests(unittest.TestCase):
Ensure new new schema is refreshed after submit_schema_refresh()
"""
cluster = Cluster()
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
cluster.connect()
self.assertNotIn("newkeyspace", cluster.metadata.keyspaces)
other_cluster = Cluster()
other_cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = other_cluster.connect()
session.execute(
"""
@@ -189,15 +191,22 @@ class ClusterTests(unittest.TestCase):
Ensure trace can be requested for async and non-async queries
"""
cluster = Cluster()
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect()
self.assertRaises(TypeError, session.execute, "SELECT * FROM system.local", trace=True)
def check_trace(trace):
self.assertIsNot(None, trace.request_type)
self.assertIsNot(None, trace.duration)
self.assertIsNot(None, trace.started_at)
self.assertIsNot(None, trace.coordinator)
self.assertIsNot(None, trace.events)
query = "SELECT * FROM system.local"
statement = SimpleStatement(query)
session.execute(statement, trace=True)
self.assertEqual(query, statement.trace.parameters['query'])
check_trace(statement.trace)
query = "SELECT * FROM system.local"
statement = SimpleStatement(query)
@@ -207,15 +216,20 @@ class ClusterTests(unittest.TestCase):
statement2 = SimpleStatement(query)
future = session.execute_async(statement2, trace=True)
future.result()
self.assertEqual(query, future.get_query_trace().parameters['query'])
check_trace(future.get_query_trace())
statement2 = SimpleStatement(query)
future = session.execute_async(statement2)
future.result()
self.assertEqual(None, future.get_query_trace())
prepared = session.prepare("SELECT * FROM system.local")
future = session.execute_async(prepared, parameters=(), trace=True)
future.result()
check_trace(future.get_query_trace())
def test_trace_timeout(self):
cluster = Cluster()
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect()
query = "SELECT * FROM system.local"
@@ -229,7 +243,7 @@ class ClusterTests(unittest.TestCase):
Ensure str(future) returns without error
"""
cluster = Cluster()
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect()
query = "SELECT * FROM system.local"

View File

@@ -0,0 +1,113 @@
from tests.integration import PROTOCOL_VERSION
try:
import unittest2 as unittest
except ImportError:
import unittest # noqa
from itertools import cycle
from cassandra import InvalidRequest
from cassandra.cluster import Cluster
from cassandra.concurrent import (execute_concurrent,
execute_concurrent_with_args)
from cassandra.policies import HostDistance
from cassandra.query import tuple_factory
class ClusterTests(unittest.TestCase):
def setUp(self):
self.cluster = Cluster(protocol_version=PROTOCOL_VERSION)
self.cluster.set_core_connections_per_host(HostDistance.LOCAL, 1)
self.session = self.cluster.connect()
self.session.row_factory = tuple_factory
def test_execute_concurrent(self):
for num_statements in (0, 1, 2, 7, 10, 99, 100, 101, 199, 200, 201):
# write
statements = cycle(("INSERT INTO test3rf.test (k, v) VALUES (%s, %s)", ))
parameters = [(i, i) for i in range(num_statements)]
results = execute_concurrent(self.session, zip(statements, parameters))
self.assertEqual(num_statements, len(results))
self.assertEqual([(True, None)] * num_statements, results)
# read
statements = cycle(("SELECT v FROM test3rf.test WHERE k=%s", ))
parameters = [(i, ) for i in range(num_statements)]
results = execute_concurrent(self.session, zip(statements, parameters))
self.assertEqual(num_statements, len(results))
self.assertEqual([(True, [(i,)]) for i in range(num_statements)], results)
def test_execute_concurrent_with_args(self):
for num_statements in (0, 1, 2, 7, 10, 99, 100, 101, 199, 200, 201):
statement = "INSERT INTO test3rf.test (k, v) VALUES (%s, %s)"
parameters = [(i, i) for i in range(num_statements)]
results = execute_concurrent_with_args(self.session, statement, parameters)
self.assertEqual(num_statements, len(results))
self.assertEqual([(True, None)] * num_statements, results)
# read
statement = "SELECT v FROM test3rf.test WHERE k=%s"
parameters = [(i, ) for i in range(num_statements)]
results = execute_concurrent_with_args(self.session, statement, parameters)
self.assertEqual(num_statements, len(results))
self.assertEqual([(True, [(i,)]) for i in range(num_statements)], results)
def test_first_failure(self):
statements = cycle(("INSERT INTO test3rf.test (k, v) VALUES (%s, %s)", ))
parameters = [(i, i) for i in range(100)]
# we'll get an error back from the server
parameters[57] = ('efefef', 'awefawefawef')
self.assertRaises(
InvalidRequest,
execute_concurrent, self.session, zip(statements, parameters), raise_on_first_error=True)
def test_first_failure_client_side(self):
statements = cycle(("INSERT INTO test3rf.test (k, v) VALUES (%s, %s)", ))
parameters = [(i, i) for i in range(100)]
# the driver will raise an error when binding the params
parameters[57] = 1
self.assertRaises(
TypeError,
execute_concurrent, self.session, zip(statements, parameters), raise_on_first_error=True)
def test_no_raise_on_first_failure(self):
statements = cycle(("INSERT INTO test3rf.test (k, v) VALUES (%s, %s)", ))
parameters = [(i, i) for i in range(100)]
# we'll get an error back from the server
parameters[57] = ('efefef', 'awefawefawef')
results = execute_concurrent(self.session, zip(statements, parameters), raise_on_first_error=False)
for i, (success, result) in enumerate(results):
if i == 57:
self.assertFalse(success)
self.assertIsInstance(result, InvalidRequest)
else:
self.assertTrue(success)
self.assertEqual(None, result)
def test_no_raise_on_first_failure_client_side(self):
statements = cycle(("INSERT INTO test3rf.test (k, v) VALUES (%s, %s)", ))
parameters = [(i, i) for i in range(100)]
# the driver will raise an error when binding the params
parameters[57] = i
results = execute_concurrent(self.session, zip(statements, parameters), raise_on_first_error=False)
for i, (success, result) in enumerate(results):
if i == 57:
self.assertFalse(success)
self.assertIsInstance(result, TypeError)
else:
self.assertTrue(success)
self.assertEqual(None, result)

View File

@@ -1,9 +1,12 @@
from tests.integration import PROTOCOL_VERSION
try:
import unittest2 as unittest
except ImportError:
import unittest # noqa
from functools import partial
import sys
from threading import Thread, Event
from cassandra import ConsistencyLevel
@@ -24,7 +27,7 @@ class ConnectionTest(object):
"""
Test a single connection with sequential requests.
"""
conn = self.klass.factory()
conn = self.klass.factory(protocol_version=PROTOCOL_VERSION)
query = "SELECT keyspace_name FROM system.schema_keyspaces LIMIT 1"
event = Event()
@@ -47,7 +50,7 @@ class ConnectionTest(object):
"""
Test a single connection with pipelined requests.
"""
conn = self.klass.factory()
conn = self.klass.factory(protocol_version=PROTOCOL_VERSION)
query = "SELECT keyspace_name FROM system.schema_keyspaces LIMIT 1"
responses = [False] * 100
event = Event()
@@ -69,7 +72,7 @@ class ConnectionTest(object):
"""
Test multiple connections with pipelined requests.
"""
conns = [self.klass.factory() for i in range(5)]
conns = [self.klass.factory(protocol_version=PROTOCOL_VERSION) for i in range(5)]
events = [Event() for i in range(5)]
query = "SELECT keyspace_name FROM system.schema_keyspaces LIMIT 1"
@@ -100,7 +103,7 @@ class ConnectionTest(object):
num_threads = 5
event = Event()
conn = self.klass.factory()
conn = self.klass.factory(protocol_version=PROTOCOL_VERSION)
query = "SELECT keyspace_name FROM system.schema_keyspaces LIMIT 1"
def cb(all_responses, thread_responses, request_num, *args, **kwargs):
@@ -157,7 +160,7 @@ class ConnectionTest(object):
threads = []
for i in range(num_conns):
conn = self.klass.factory()
conn = self.klass.factory(protocol_version=PROTOCOL_VERSION)
t = Thread(target=send_msgs, args=(conn, events[i]))
threads.append(t)
@@ -172,12 +175,18 @@ class AsyncoreConnectionTest(ConnectionTest, unittest.TestCase):
klass = AsyncoreConnection
def setUp(self):
if 'gevent.monkey' in sys.modules:
raise unittest.SkipTest("Can't test libev with gevent monkey patching")
class LibevConnectionTest(ConnectionTest, unittest.TestCase):
klass = LibevConnection
def setUp(self):
if 'gevent.monkey' in sys.modules:
raise unittest.SkipTest("Can't test libev with gevent monkey patching")
if LibevConnection is None:
raise unittest.SkipTest(
'libev does not appear to be installed properly')

View File

@@ -1,3 +1,5 @@
from tests.integration import PROTOCOL_VERSION
try:
import unittest2 as unittest
except ImportError:
@@ -36,7 +38,7 @@ class TestFactories(unittest.TestCase):
'''
def test_tuple_factory(self):
cluster = Cluster()
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect()
session.row_factory = tuple_factory
@@ -58,7 +60,7 @@ class TestFactories(unittest.TestCase):
self.assertEqual(result[1][0], 2)
def test_named_tuple_factoryy(self):
cluster = Cluster()
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect()
session.row_factory = named_tuple_factory
@@ -79,7 +81,7 @@ class TestFactories(unittest.TestCase):
self.assertEqual(result[1].k, 2)
def test_dict_factory(self):
cluster = Cluster()
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect()
session.row_factory = dict_factory
@@ -101,7 +103,7 @@ class TestFactories(unittest.TestCase):
self.assertEqual(result[1]['k'], 2)
def test_ordered_dict_factory(self):
cluster = Cluster()
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect()
session.row_factory = ordered_dict_factory

View File

@@ -15,7 +15,7 @@ from cassandra.metadata import (Metadata, KeyspaceMetadata, TableMetadata,
from cassandra.policies import SimpleConvictionPolicy
from cassandra.pool import Host
from tests.integration import get_cluster
from tests.integration import get_cluster, PROTOCOL_VERSION
class SchemaMetadataTest(unittest.TestCase):
@@ -28,7 +28,7 @@ class SchemaMetadataTest(unittest.TestCase):
@classmethod
def setup_class(cls):
cluster = Cluster()
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect()
try:
results = session.execute("SELECT keyspace_name FROM system.schema_keyspaces")
@@ -46,7 +46,8 @@ class SchemaMetadataTest(unittest.TestCase):
@classmethod
def teardown_class(cls):
cluster = Cluster(['127.0.0.1'])
cluster = Cluster(['127.0.0.1'],
protocol_version=PROTOCOL_VERSION)
session = cluster.connect()
try:
session.execute("DROP KEYSPACE %s" % cls.ksname)
@@ -54,7 +55,8 @@ class SchemaMetadataTest(unittest.TestCase):
cluster.shutdown()
def setUp(self):
self.cluster = Cluster(['127.0.0.1'])
self.cluster = Cluster(['127.0.0.1'],
protocol_version=PROTOCOL_VERSION)
self.session = self.cluster.connect()
def tearDown(self):
@@ -294,7 +296,7 @@ class TestCodeCoverage(unittest.TestCase):
Test export schema functionality
"""
cluster = Cluster()
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
cluster.connect()
self.assertIsInstance(cluster.metadata.export_schema_as_string(), six.string_types)
@@ -304,7 +306,7 @@ class TestCodeCoverage(unittest.TestCase):
Test export keyspace schema functionality
"""
cluster = Cluster()
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
cluster.connect()
for keyspace in cluster.metadata.keyspaces:
@@ -317,7 +319,7 @@ class TestCodeCoverage(unittest.TestCase):
Test that names that need to be escaped in CREATE statements are
"""
cluster = Cluster()
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect()
ksname = 'AnInterestingKeyspace'
@@ -356,7 +358,7 @@ class TestCodeCoverage(unittest.TestCase):
Ensure AlreadyExists exception is thrown when hit
"""
cluster = Cluster()
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect()
ksname = 'test3rf'
@@ -380,7 +382,7 @@ class TestCodeCoverage(unittest.TestCase):
if murmur3 is None:
raise unittest.SkipTest('the murmur3 extension is not available')
cluster = Cluster()
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
self.assertEqual(cluster.metadata.get_replicas('test3rf', 'key'), [])
cluster.connect('test3rf')
@@ -395,7 +397,7 @@ class TestCodeCoverage(unittest.TestCase):
Test token mappings
"""
cluster = Cluster()
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
cluster.connect('test3rf')
ring = cluster.metadata.token_map.ring
owners = list(cluster.metadata.token_map.token_to_host_owner[token] for token in ring)
@@ -418,7 +420,7 @@ class TokenMetadataTest(unittest.TestCase):
def test_token(self):
expected_node_count = len(get_cluster().nodes)
cluster = Cluster()
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
cluster.connect()
tmap = cluster.metadata.token_map
self.assertTrue(issubclass(tmap.token_class, Token))

View File

@@ -7,7 +7,7 @@ from cassandra.query import SimpleStatement
from cassandra import ConsistencyLevel, WriteTimeout, Unavailable, ReadTimeout
from cassandra.cluster import Cluster, NoHostAvailable
from tests.integration import get_node, get_cluster
from tests.integration import get_node, get_cluster, PROTOCOL_VERSION
class MetricsTests(unittest.TestCase):
@@ -17,7 +17,8 @@ class MetricsTests(unittest.TestCase):
Trigger and ensure connection_errors are counted
"""
cluster = Cluster(metrics_enabled=True)
cluster = Cluster(metrics_enabled=True,
protocol_version=PROTOCOL_VERSION)
session = cluster.connect()
session.execute("USE test3rf")
@@ -45,7 +46,8 @@ class MetricsTests(unittest.TestCase):
Attempt a write at cl.ALL and receive a WriteTimeout.
"""
cluster = Cluster(metrics_enabled=True)
cluster = Cluster(metrics_enabled=True,
protocol_version=PROTOCOL_VERSION)
session = cluster.connect()
# Test write
@@ -75,7 +77,8 @@ class MetricsTests(unittest.TestCase):
Attempt a read at cl.ALL and receive a ReadTimeout.
"""
cluster = Cluster(metrics_enabled=True)
cluster = Cluster(metrics_enabled=True,
protocol_version=PROTOCOL_VERSION)
session = cluster.connect()
# Test write
@@ -105,7 +108,8 @@ class MetricsTests(unittest.TestCase):
Attempt an insert/read at cl.ALL and receive a Unavailable Exception.
"""
cluster = Cluster(metrics_enabled=True)
cluster = Cluster(metrics_enabled=True,
protocol_version=PROTOCOL_VERSION)
session = cluster.connect()
# Test write

View File

@@ -1,3 +1,5 @@
from tests.integration import PROTOCOL_VERSION
try:
import unittest2 as unittest
except ImportError:
@@ -15,7 +17,7 @@ class PreparedStatementTests(unittest.TestCase):
Test basic PreparedStatement usage
"""
cluster = Cluster()
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect()
session.execute(
"""
@@ -60,7 +62,7 @@ class PreparedStatementTests(unittest.TestCase):
when prepared statements are missing the primary key
"""
cluster = Cluster()
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect()
prepared = session.prepare(
@@ -77,7 +79,7 @@ class PreparedStatementTests(unittest.TestCase):
Ensure a ValueError is thrown when attempting to bind too many variables
"""
cluster = Cluster()
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect()
prepared = session.prepare(
@@ -93,7 +95,7 @@ class PreparedStatementTests(unittest.TestCase):
Ensure binding None is handled correctly
"""
cluster = Cluster()
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect()
prepared = session.prepare(
@@ -120,7 +122,7 @@ class PreparedStatementTests(unittest.TestCase):
Ensure None binding over async queries
"""
cluster = Cluster()
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect()
prepared = session.prepare(

View File

@@ -10,13 +10,13 @@ from cassandra.query import (PreparedStatement, BoundStatement, ValueSequence,
from cassandra.cluster import Cluster
from cassandra.policies import HostDistance
from tests.integration import get_server_versions
from tests.integration import get_server_versions, PROTOCOL_VERSION
class QueryTest(unittest.TestCase):
def test_query(self):
cluster = Cluster()
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect()
prepared = session.prepare(
@@ -44,7 +44,7 @@ class QueryTest(unittest.TestCase):
Code coverage to ensure trace prints to string without error
"""
cluster = Cluster()
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect()
query = "SELECT * FROM system.local"
@@ -57,7 +57,7 @@ class QueryTest(unittest.TestCase):
str(event)
def test_trace_ignores_row_factory(self):
cluster = Cluster()
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect()
session.row_factory = dict_factory
@@ -78,7 +78,7 @@ class PreparedStatementTests(unittest.TestCase):
Simple code coverage to ensure routing_keys can be accessed
"""
cluster = Cluster()
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect()
prepared = session.prepare(
@@ -96,7 +96,7 @@ class PreparedStatementTests(unittest.TestCase):
the routing key should be None
"""
cluster = Cluster()
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect()
prepared = session.prepare(
@@ -115,7 +115,7 @@ class PreparedStatementTests(unittest.TestCase):
overrides the current routing key
"""
cluster = Cluster()
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect()
prepared = session.prepare(
@@ -133,7 +133,7 @@ class PreparedStatementTests(unittest.TestCase):
Basic test that uses a fake routing_key_index
"""
cluster = Cluster()
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect()
prepared = session.prepare(
@@ -151,7 +151,7 @@ class PreparedStatementTests(unittest.TestCase):
Ensure that bound.keyspace works as expected
"""
cluster = Cluster()
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect()
prepared = session.prepare(
@@ -186,7 +186,7 @@ class PrintStatementTests(unittest.TestCase):
Highlight the difference between Prepared and Bound statements
"""
cluster = Cluster()
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect()
prepared = session.prepare('INSERT INTO test3rf.test (k, v) VALUES (?, ?)')
@@ -202,13 +202,12 @@ class PrintStatementTests(unittest.TestCase):
class BatchStatementTests(unittest.TestCase):
def setUp(self):
cass_version, _ = get_server_versions()
if cass_version < (2, 0):
if PROTOCOL_VERSION < 2:
raise unittest.SkipTest(
"Cassandra 2.0+ is required for BATCH operations, currently testing against %r"
% (cass_version,))
"Protocol 2.0+ is required for BATCH operations, currently testing against %r"
% (PROTOCOL_VERSION,))
self.cluster = Cluster(protocol_version=2)
self.cluster = Cluster(protocol_version=PROTOCOL_VERSION)
self.cluster.set_core_connections_per_host(HostDistance.LOCAL, 1)
self.session = self.cluster.connect()
@@ -272,13 +271,12 @@ class BatchStatementTests(unittest.TestCase):
class SerialConsistencyTests(unittest.TestCase):
def setUp(self):
cass_version, _ = get_server_versions()
if cass_version < (2, 0):
if PROTOCOL_VERSION < 2:
raise unittest.SkipTest(
"Cassandra 2.0+ is required for BATCH operations, currently testing against %r"
% (cass_version,))
"Protocol 2.0+ is required for BATCH operations, currently testing against %r"
% (PROTOCOL_VERSION,))
self.cluster = Cluster(protocol_version=2)
self.cluster = Cluster(protocol_version=PROTOCOL_VERSION)
self.cluster.set_core_connections_per_host(HostDistance.LOCAL, 1)
self.session = self.cluster.connect()

View File

@@ -0,0 +1,112 @@
from tests.integration import PROTOCOL_VERSION
import logging
log = logging.getLogger(__name__)
try:
import unittest2 as unittest
except ImportError:
import unittest # noqa
from itertools import cycle, count
from threading import Event
from cassandra.cluster import Cluster
from cassandra.concurrent import execute_concurrent
from cassandra.policies import HostDistance
from cassandra.query import SimpleStatement
class QueryPagingTests(unittest.TestCase):
    """
    Integration tests for automatic result paging (native protocol v2+):
    synchronous execution, async futures, and the page-by-page callback API.
    """

    def setUp(self):
        # Automatic paging is a protocol v2 feature; skip on v1.
        # (Message previously said "BATCH operations" — copy-paste from the
        # batch test suite.)
        if PROTOCOL_VERSION < 2:
            raise unittest.SkipTest(
                "Protocol 2.0+ is required for automatic query paging, currently testing against %r"
                % (PROTOCOL_VERSION,))

        self.cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        self.cluster.set_core_connections_per_host(HostDistance.LOCAL, 1)
        self.session = self.cluster.connect()
        self.session.execute("TRUNCATE test3rf.test")

    def test_paging(self):
        """All 100 rows are returned regardless of fetch_size, for simple
        string queries, SimpleStatements, and prepared statements."""
        statements_and_params = zip(cycle(["INSERT INTO test3rf.test (k, v) VALUES (%s, 0)"]),
                                    [(i, ) for i in range(100)])
        execute_concurrent(self.session, statements_and_params)

        prepared = self.session.prepare("SELECT * FROM test3rf.test")

        # Exercise page sizes below, at, and above the total row count.
        for fetch_size in (2, 3, 7, 10, 99, 100, 101, 10000):
            self.session.default_fetch_size = fetch_size
            self.assertEqual(100, len(list(self.session.execute("SELECT * FROM test3rf.test"))))

            statement = SimpleStatement("SELECT * FROM test3rf.test")
            self.assertEqual(100, len(list(self.session.execute(statement))))

            self.assertEqual(100, len(list(self.session.execute(prepared))))

    def test_async_paging(self):
        """Same coverage as test_paging, but through execute_async()."""
        statements_and_params = zip(cycle(["INSERT INTO test3rf.test (k, v) VALUES (%s, 0)"]),
                                    [(i, ) for i in range(100)])
        execute_concurrent(self.session, statements_and_params)

        prepared = self.session.prepare("SELECT * FROM test3rf.test")

        for fetch_size in (2, 3, 7, 10, 99, 100, 101, 10000):
            self.session.default_fetch_size = fetch_size
            self.assertEqual(100, len(list(self.session.execute_async("SELECT * FROM test3rf.test").result())))

            statement = SimpleStatement("SELECT * FROM test3rf.test")
            self.assertEqual(100, len(list(self.session.execute_async(statement).result())))

            self.assertEqual(100, len(list(self.session.execute_async(prepared).result())))

    def test_paging_callbacks(self):
        """Page through results manually via add_callbacks() +
        start_fetching_next_page(), counting every row seen."""
        statements_and_params = zip(cycle(["INSERT INTO test3rf.test (k, v) VALUES (%s, 0)"]),
                                    [(i, ) for i in range(100)])
        execute_concurrent(self.session, statements_and_params)

        prepared = self.session.prepare("SELECT * FROM test3rf.test")

        for fetch_size in (2, 3, 7, 10, 99, 100, 101, 10000):
            self.session.default_fetch_size = fetch_size
            future = self.session.execute_async("SELECT * FROM test3rf.test")

            event = Event()
            counter = count()

            def handle_page(rows, future, counter):
                for row in rows:
                    # py3 fix: itertools.count has no .next() method; use the
                    # next() builtin (works on py2 as well).
                    next(counter)

                if future.has_more_pages:
                    future.start_fetching_next_page()
                else:
                    event.set()

            def handle_error(err):
                # NOTE(review): set the event first so the waiting test thread
                # is released; self.fail() here runs on the event-loop thread —
                # presumably the driver surfaces it — TODO confirm.
                event.set()
                self.fail(err)

            future.add_callbacks(callback=handle_page, callback_args=(future, counter), errback=handle_error)
            event.wait()
            # next(counter) returns the count of rows consumed so far.
            # (assertEquals is a deprecated alias; use assertEqual.)
            self.assertEqual(next(counter), 100)

            # simple statement
            future = self.session.execute_async(SimpleStatement("SELECT * FROM test3rf.test"))
            event.clear()
            counter = count()

            future.add_callbacks(callback=handle_page, callback_args=(future, counter), errback=handle_error)
            event.wait()
            self.assertEqual(next(counter), 100)

            # prepared statement
            future = self.session.execute_async(prepared)
            event.clear()
            counter = count()

            future.add_callbacks(callback=handle_page, callback_args=(future, counter), errback=handle_error)
            event.wait()
            self.assertEqual(next(counter), 100)

View File

@@ -18,7 +18,7 @@ from cassandra.cqltypes import Int32Type, EMPTY
from cassandra.query import dict_factory
from cassandra.util import OrderedDict
from tests.integration import get_server_versions
from tests.integration import get_server_versions, PROTOCOL_VERSION
class TypeTests(unittest.TestCase):
@@ -27,7 +27,7 @@ class TypeTests(unittest.TestCase):
self._cass_version, self._cql_version = get_server_versions()
def test_blob_type_as_string(self):
c = Cluster()
c = Cluster(protocol_version=PROTOCOL_VERSION)
s = c.connect()
s.execute("""
@@ -69,7 +69,7 @@ class TypeTests(unittest.TestCase):
self.assertEqual(expected, actual)
def test_blob_type_as_bytearray(self):
c = Cluster()
c = Cluster(protocol_version=PROTOCOL_VERSION)
s = c.connect()
s.execute("""
@@ -129,7 +129,7 @@ class TypeTests(unittest.TestCase):
"""
def test_basic_types(self):
c = Cluster()
c = Cluster(protocol_version=PROTOCOL_VERSION)
s = c.connect()
s.execute("""
CREATE KEYSPACE typetests
@@ -226,7 +226,7 @@ class TypeTests(unittest.TestCase):
self.assertEqual(expected, actual)
def test_empty_strings_and_nones(self):
c = Cluster()
c = Cluster(protocol_version=PROTOCOL_VERSION)
s = c.connect()
s.execute("""
CREATE KEYSPACE test_empty_strings_and_nones
@@ -329,7 +329,7 @@ class TypeTests(unittest.TestCase):
self.assertEqual([], [(name, val) for (name, val) in results[0].items() if val is not None])
def test_empty_values(self):
c = Cluster()
c = Cluster(protocol_version=PROTOCOL_VERSION)
s = c.connect()
s.execute("""
CREATE KEYSPACE test_empty_values
@@ -356,7 +356,7 @@ class TypeTests(unittest.TestCase):
eastern_tz = pytz.timezone('US/Eastern')
eastern_tz.localize(dt)
c = Cluster()
c = Cluster(protocol_version=PROTOCOL_VERSION)
s = c.connect()
s.execute("""CREATE KEYSPACE tz_aware_test

View File

@@ -136,13 +136,8 @@ class ConnectionTest(unittest.TestCase):
"""
Ensure the following methods throw NIE's. If not, come back and test them.
"""
c = self.make_connection()
self.assertRaises(NotImplementedError, c.close)
self.assertRaises(NotImplementedError, c.defunct, None)
self.assertRaises(NotImplementedError, c.send_msg, None, None)
self.assertRaises(NotImplementedError, c.wait_for_response, None)
self.assertRaises(NotImplementedError, c.wait_for_responses)
self.assertRaises(NotImplementedError, c.register_watcher, None, None)
self.assertRaises(NotImplementedError, c.register_watchers, None)

View File

@@ -59,8 +59,8 @@ class MockCluster(object):
self.scheduler = Mock(spec=_Scheduler)
self.executor = Mock(spec=ThreadPoolExecutor)
def add_host(self, address, signal=False):
host = Host(address, SimpleConvictionPolicy)
def add_host(self, address, datacenter, rack, signal=False):
host = Host(address, SimpleConvictionPolicy, datacenter, rack)
self.added_hosts.append(host)
return host
@@ -212,6 +212,7 @@ class ControlConnectionTest(unittest.TestCase):
self.connection.peer_results[1].append(
["192.168.1.3", "10.0.0.3", "a", "dc1", "rack1", ["3", "103", "203"]]
)
self.cluster.scheduler.schedule = lambda delay, f, *args, **kwargs: f(*args, **kwargs)
self.control_connection.refresh_node_list_and_token_map()
self.assertEqual(1, len(self.cluster.added_hosts))
self.assertEqual(self.cluster.added_hosts[0].address, "192.168.1.3")
@@ -250,7 +251,7 @@ class ControlConnectionTest(unittest.TestCase):
'address': ('1.2.3.4', 9000)
}
self.control_connection._handle_topology_change(event)
self.cluster.scheduler.schedule.assert_called_with(ANY, self.cluster.add_host, '1.2.3.4', signal=True)
self.cluster.scheduler.schedule.assert_called_with(ANY, self.control_connection.refresh_node_list_and_token_map)
event = {
'change_type': 'REMOVED_NODE',
@@ -272,7 +273,7 @@ class ControlConnectionTest(unittest.TestCase):
'address': ('1.2.3.4', 9000)
}
self.control_connection._handle_status_change(event)
self.cluster.scheduler.schedule.assert_called_with(ANY, self.cluster.add_host, '1.2.3.4', signal=True)
self.cluster.scheduler.schedule.assert_called_with(ANY, self.control_connection.refresh_node_list_and_token_map)
# do the same with a known Host
event = {

View File

@@ -35,6 +35,9 @@ class ResponseFutureTests(unittest.TestCase):
message = QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE)
return ResponseFuture(session, message, query)
def make_mock_response(self, results):
return Mock(spec=ResultMessage, kind=RESULT_KIND_ROWS, results=results, paging_state=None)
def test_result_message(self):
session = self.make_basic_session()
session._load_balancer.make_query_plan.return_value = ['ip1', 'ip2']
@@ -49,9 +52,7 @@ class ResponseFutureTests(unittest.TestCase):
connection = pool.borrow_connection.return_value
connection.send_msg.assert_called_once_with(rf.message, cb=ANY)
response = Mock(spec=ResultMessage, kind=RESULT_KIND_ROWS, results=[{'col': 'val'}])
rf._set_result(response)
rf._set_result(self.make_mock_response([{'col': 'val'}]))
result = rf.result()
self.assertEqual(result, [{'col': 'val'}])
@@ -259,8 +260,7 @@ class ResponseFutureTests(unittest.TestCase):
rf = self.make_response_future(session)
rf.send_request()
response = Mock(spec=ResultMessage, kind=RESULT_KIND_ROWS, results=[{'col': 'val'}])
rf._set_result(response)
rf._set_result(self.make_mock_response([{'col': 'val'}]))
result = rf.result()
self.assertEqual(result, [{'col': 'val'}])
@@ -280,8 +280,7 @@ class ResponseFutureTests(unittest.TestCase):
rf = self.make_response_future(session)
rf.send_request()
response = Mock(spec=ResultMessage, kind=RESULT_KIND_ROWS, results=[{'col': 'val'}])
rf._set_result(response)
rf._set_result(self.make_mock_response([{'col': 'val'}]))
self.assertEqual(rf.result(), [{'col': 'val'}])
# make sure the exception is recorded correctly
@@ -294,8 +293,7 @@ class ResponseFutureTests(unittest.TestCase):
rf.add_callback(self.assertEqual, [{'col': 'val'}])
response = Mock(spec=ResultMessage, kind=RESULT_KIND_ROWS, results=[{'col': 'val'}])
rf._set_result(response)
rf._set_result(self.make_mock_response([{'col': 'val'}]))
result = rf.result()
self.assertEqual(result, [{'col': 'val'}])
@@ -349,8 +347,7 @@ class ResponseFutureTests(unittest.TestCase):
callback=self.assertEqual, callback_args=([{'col': 'val'}],),
errback=self.assertIsInstance, errback_args=(Exception,))
response = Mock(spec=ResultMessage, kind=RESULT_KIND_ROWS, results=[{'col': 'val'}])
rf._set_result(response)
rf._set_result(self.make_mock_response([{'col': 'val'}]))
self.assertEqual(rf.result(), [{'col': 'val'}])
def test_prepared_query_not_found(self):