Merge branch '2.0' into py3k

Conflicts:
	cassandra/cluster.py
	cassandra/connection.py
	cassandra/io/asyncorereactor.py
	cassandra/io/libevreactor.py
	setup.py
	tests/integration/long/test_large_data.py
Tyler Hobbs
2014-04-03 17:45:53 -05:00
33 changed files with 1158 additions and 432 deletions

CHANGELOG.rst

@@ -1,9 +1,11 @@
-1.0.3
+1.1.0
 =====
 In Progress
 
 Features
 --------
+* Gevent is now supported through monkey-patching the stdlib (PYTHON-7,
+  github issue #46)
 * Support static columns in schemas, which are available starting in
   Cassandra 2.1. (github issue #91)
@@ -15,12 +17,23 @@ Bug Fixes
 * Ignore SSL_ERROR_WANT_READ and SSL_ERROR_WANT_WRITE socket errors. Previously,
   these resulted in the connection being defuncted, but they can safely be
   ignored by the driver.
+* Don't reconnect the control connection every time Cluster.connect() is
+  called
+* Avoid race condition that could leave ResponseFuture callbacks uncalled
+  if the callback was added outside of the event loop thread (github issue #95)
+* Properly escape keyspace name in Session.set_keyspace(). Previously, the
+  keyspace name was quoted, but any quotes in the string were not escaped.
+* Avoid adding hosts to the load balancing policy before their datacenter
+  and rack information has been set, if possible.
+* Avoid KeyError when updating metadata after dropping a table (github issues
+  #97, #98)
 
 Other
 -----
 * Don't ignore column names when parsing typestrings. This is needed for
   user-defined type support. (github issue #90)
 * Better error message when libevwrapper is not found
+* Only try to import scales when metrics are enabled (github issue #92)
 
 1.0.2
 =====

cassandra/cluster.py

@@ -2,6 +2,8 @@
 This module houses the main classes you will interact with,
 :class:`.Cluster` and :class:`.Session`.
 """
+from __future__ import absolute_import
+
 from concurrent.futures import ThreadPoolExecutor
 import logging
 import socket
@@ -37,8 +39,7 @@ from cassandra.decoder import (QueryMessage, ResultMessage,
                                BatchMessage, RESULT_KIND_PREPARED,
                                RESULT_KIND_SET_KEYSPACE, RESULT_KIND_ROWS,
                                RESULT_KIND_SCHEMA_CHANGE)
-from cassandra.metadata import Metadata
-# from cassandra.metrics import Metrics
+from cassandra.metadata import Metadata, protect_name
 from cassandra.policies import (RoundRobinPolicy, SimpleConvictionPolicy,
                                 ExponentialReconnectionPolicy, HostDistance,
                                 RetryPolicy)
@@ -48,11 +49,15 @@ from cassandra.query import (SimpleStatement, PreparedStatement, BoundStatement,
                              BatchStatement, bind_params, QueryTrace, Statement,
                              named_tuple_factory, dict_factory)
 
-# libev is all around faster, so we want to try and default to using that when we can
-try:
-    from cassandra.io.libevreactor import LibevConnection as DefaultConnection
-except ImportError:
-    from cassandra.io.asyncorereactor import AsyncoreConnection as DefaultConnection  # NOQA
+# default to gevent when we are monkey patched, otherwise if libev is available, use that as the
+# default because it's faster than asyncore
+if 'gevent.monkey' in sys.modules:
+    from cassandra.io.geventreactor import GeventConnection as DefaultConnection
+else:
+    try:
+        from cassandra.io.libevreactor import LibevConnection as DefaultConnection  # NOQA
+    except ImportError:
+        from cassandra.io.asyncorereactor import AsyncoreConnection as DefaultConnection  # NOQA
 
 # Forces load of utf8 encoding module to avoid deadlock that occurs
 # if code that is being imported tries to import the module in a seperate
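
The effect of this selection logic, from an application's point of view (a usage sketch, assuming gevent is installed; ``monkey.patch_all()`` must run before the driver is imported so that ``'gevent.monkey'`` is already in ``sys.modules``)::

    from gevent import monkey
    monkey.patch_all()  # must happen before importing the driver

    # gevent.monkey is now in sys.modules, so the driver selects
    # GeventConnection; otherwise it would try LibevConnection and
    # fall back to AsyncoreConnection.
    from cassandra.cluster import Cluster

    session = Cluster(['127.0.0.1']).connect()
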
@@ -147,8 +152,15 @@ class Cluster(object):
     server will be automatically used.
     """
 
-    # TODO: docs
     protocol_version = 2
+    """
+    The version of the native protocol to use.  The protocol version 2
+    adds support for lightweight transactions, batch operations, and
+    automatic query paging, but is only supported by Cassandra 2.0+.  When
+    working with Cassandra 1.2, this must be set to 1.  You can also set
+    this to 1 when working with Cassandra 2.0+, but features that require
+    the version 2 protocol will not be enabled.
+    """
 
     compression = True
     """
@@ -287,13 +299,6 @@ class Cluster(object):
         Any of the mutable Cluster attributes may be set as keyword arguments
         to the constructor.
         """
-        if 'gevent.monkey' in sys.modules:
-            raise Exception(
-                "gevent monkey-patching detected. This driver does not currently "
-                "support gevent, and monkey patching will break the driver "
-                "completely. You can track progress towards adding gevent "
-                "support here: https://datastax-oss.atlassian.net/browse/PYTHON-7.")
-
         self.contact_points = contact_points
         self.port = port
         self.compression = compression
@@ -478,20 +483,21 @@ class Cluster(object):
             self.load_balancing_policy.populate(
                 weakref.proxy(self), self.metadata.all_hosts())
 
-            if self.control_connection:
-                try:
-                    self.control_connection.connect()
-                    log.debug("Control connection created")
-                except Exception:
-                    log.exception("Control connection failed to connect, "
-                                  "shutting down Cluster:")
-                    self.shutdown()
-                    raise
-
-                self.load_balancing_policy.check_supported()
-
             self._is_setup = True
 
+            if self.control_connection:
+                try:
+                    self.control_connection.connect()
+                    log.debug("Control connection created")
+                except Exception:
+                    log.exception("Control connection failed to connect, "
+                                  "shutting down Cluster:")
+                    self.shutdown()
+                    raise
+
+                self.load_balancing_policy.check_supported()
+
         session = self._new_session()
         if keyspace:
             session.set_keyspace(keyspace)
@@ -772,13 +778,13 @@ class Cluster(object):
             self.on_down(host, is_host_addition, force_if_down=True)
         return is_down
 
-    def add_host(self, address, signal):
+    def add_host(self, address, datacenter=None, rack=None, signal=True):
         """
         Called when adding initial contact points and when the control
         connection subsequently discovers a new node.  Intended for internal
         use only.
         """
-        new_host = self.metadata.add_host(address)
+        new_host = self.metadata.add_host(address, datacenter, rack)
         if new_host and signal:
             log.info("New Cassandra host %s added", address)
             self.on_add(new_host)
@@ -947,10 +953,10 @@ class Session(object):
     default_fetch_size = 5000
     """
     By default, this many rows will be fetched at a time.  This can be
-    specified per-query through :attr:`~Statement.fetch_size`.
+    specified per-query through :attr:`.Statement.fetch_size`.
 
     This only takes effect when protocol version 2 or higher is used.
-    See :attr:`~Cluster.protocol_version` for details.
+    See :attr:`.Cluster.protocol_version` for details.
     """
 
     _lock = None
@@ -1293,7 +1299,7 @@ class Session(object):
         Set the default keyspace for all queries made through this Session.
         This operation blocks until complete.
         """
-        self.execute('USE "%s"' % (keyspace,))
+        self.execute('USE %s' % (protect_name(keyspace),))
 
     def _set_keyspace_for_all_pools(self, keyspace, callback):
         """
@@ -1602,7 +1608,9 @@ class ControlConnection(object):
         host = self._cluster.metadata.get_host(connection.host)
         if host:
-            host.set_location_info(local_row["data_center"], local_row["rack"])
+            datacenter = local_row.get("data_center")
+            rack = local_row.get("rack")
+            self._update_location_info(host, datacenter, rack)
 
         partitioner = local_row.get("partitioner")
         tokens = local_row.get("tokens")
@@ -1620,10 +1628,13 @@ class ControlConnection(object):
             found_hosts.add(addr)
 
             host = self._cluster.metadata.get_host(addr)
+            datacenter = row.get("data_center")
+            rack = row.get("rack")
             if host is None:
                 log.debug("[control connection] Found new host to connect to: %s", addr)
-                host = self._cluster.add_host(addr, signal=True)
-                host.set_location_info(row.get("data_center"), row.get("rack"))
+                host = self._cluster.add_host(addr, datacenter, rack, signal=True)
+            else:
+                self._update_location_info(host, datacenter, rack)
 
             tokens = row.get("tokens")
             if partitioner and tokens:
@@ -1640,11 +1651,22 @@ class ControlConnection(object):
         log.debug("[control connection] Fetched ring info, rebuilding metadata")
         self._cluster.metadata.rebuild_token_map(partitioner, token_map)
 
+    def _update_location_info(self, host, datacenter, rack):
+        if host.datacenter == datacenter and host.rack == rack:
+            return
+
+        # If the dc/rack information changes, we need to update the load balancing policy.
+        # For that, we remove and re-add the node against the policy. Not the most elegant,
+        # and assumes that the policy will update correctly, but in practice this should work.
+        self._cluster.load_balancing_policy.on_down(host)
+        host.set_location_info(datacenter, rack)
+        self._cluster.load_balancing_policy.on_up(host)
+
     def _handle_topology_change(self, event):
         change_type = event["change_type"]
         addr, port = event["address"]
         if change_type == "NEW_NODE":
-            self._cluster.scheduler.schedule(10, self._cluster.add_host, addr, signal=True)
+            self._cluster.scheduler.schedule(10, self.refresh_node_list_and_token_map)
         elif change_type == "REMOVED_NODE":
             host = self._cluster.metadata.get_host(addr)
             self._cluster.scheduler.schedule(0, self._cluster.remove_host, host)
@@ -1658,7 +1680,7 @@ class ControlConnection(object):
         if change_type == "UP":
             if host is None:
                 # this is the first time we've seen the node
-                self._cluster.scheduler.schedule(1, self._cluster.add_host, addr, signal=True)
+                self._cluster.scheduler.schedule(1, self.refresh_node_list_and_token_map)
             else:
                 # this will be run by the scheduler
                 self._cluster.scheduler.schedule(1, self._cluster.on_up, host)
@@ -1897,23 +1919,19 @@ class ResponseFuture(object):
         self.default_timeout = default_timeout
         self._metrics = metrics
         self.prepared_statement = prepared_statement
+        self._callback_lock = Lock()
 
         if metrics is not None:
             self._start_time = time.time()
 
-        # convert the list/generator/etc to an iterator so that subsequent
-        # calls to send_request (which retries may do) will resume where
-        # they last left off
-        self.query_plan = iter(session._load_balancer.make_query_plan(
-            session.keyspace, query))
+        self._make_query_plan()
 
         self._event = Event()
         self._errors = {}
 
-    def __del__(self):
-        try:
-            del self.session
-        except AttributeError:
-            pass
+    def _make_query_plan(self):
+        # convert the list/generator/etc to an iterator so that subsequent
+        # calls to send_request (which retries may do) will resume where
+        # they last left off
+        self.query_plan = iter(self.session._load_balancer.make_query_plan(
+            self.session.keyspace, self.query))
 
     def send_request(self):
         """ Internal """
@@ -1951,6 +1969,8 @@ class ResponseFuture(object):
             except Exception as exc:
                 log.debug("Error querying host %s", host, exc_info=True)
                 self._errors[host] = exc
+                if self._metrics is not None:
+                    self._metrics.on_connection_error()
                 if connection:
                     pool.return_connection(connection)
                 return None
@@ -1960,6 +1980,33 @@ class ResponseFuture(object):
         self._connection = connection
         return request_id
 
+    @property
+    def has_more_pages(self):
+        """
+        Returns :const:`True` if there are more pages left in the
+        query results, :const:`False` otherwise.  This should only
+        be checked after the first page has been returned.
+        """
+        return self._paging_state is not None
+
+    def start_fetching_next_page(self):
+        """
+        If there are more pages left in the query result, this asynchronously
+        starts fetching the next page.  If there are no pages left, :exc:`.QueryExhausted`
+        is raised.  Also see :attr:`.has_more_pages`.
+
+        This should only be called after the first page has been returned.
+        """
+        if not self._paging_state:
+            raise QueryExhausted()
+
+        self._make_query_plan()
+        self.message.paging_state = self._paging_state
+        self._event.clear()
+        self._final_result = _NOT_SET
+        self._final_exception = None
+        self.send_request()
+
     def _reprepare(self, prepare_message):
         cb = partial(self.session.submit, self._execute_after_prepare)
         request_id = self._query(self._current_host, prepare_message, cb=cb)
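
A sketch of driving these two additions from user code (this mirrors the callback-based example in docs/query_paging.rst below; ``process_row`` is a stand-in for application logic)::

    def handle_page(rows):
        for row in rows:
            process_row(row)
        if future.has_more_pages:
            future.start_fetching_next_page()  # next page hits this callback

    future = session.execute_async("SELECT * FROM users")
    future.add_callback(handle_page)
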
@@ -2163,12 +2210,10 @@ class ResponseFuture(object):
     def _set_final_result(self, response):
         if self._metrics is not None:
             self._metrics.request_timer.addValue(time.time() - self._start_time)
-        if hasattr(self, 'session'):
-            try:
-                del self.session  # clear reference cycles
-            except AttributeError:
-                pass
-        self._final_result = response
+
+        with self._callback_lock:
+            self._final_result = response
+
         self._event.set()
         if self._callback:
             fn, args, kwargs = self._callback
@@ -2177,11 +2222,9 @@ class ResponseFuture(object):
     def _set_final_exception(self, response):
         if self._metrics is not None:
             self._metrics.request_timer.addValue(time.time() - self._start_time)
-        try:
-            del self.session  # clear reference cycles
-        except AttributeError:
-            pass
-        self._final_exception = response
+
+        with self._callback_lock:
+            self._final_exception = response
         self._event.set()
         if self._errback:
             fn, args, kwargs = self._errback
@@ -2242,13 +2285,19 @@ class ResponseFuture(object):
             timeout = self.default_timeout
 
         if self._final_result is not _NOT_SET:
-            return self._final_result
+            if self._paging_state is None:
+                return self._final_result
+            else:
+                return PagedResult(self, self._final_result)
         elif self._final_exception:
             raise self._final_exception
         else:
             self._event.wait(timeout=timeout)
             if self._final_result is not _NOT_SET:
-                return self._final_result
+                if self._paging_state is None:
+                    return self._final_result
+                else:
+                    return PagedResult(self, self._final_result)
             elif self._final_exception:
                 raise self._final_exception
             else:
@@ -2301,10 +2350,14 @@ class ResponseFuture(object):
         >>> future.add_callback(handle_results, time.time(), should_log=True)
 
         """
-        if self._final_result is not _NOT_SET:
+        run_now = False
+        with self._callback_lock:
+            if self._final_result is not _NOT_SET:
+                run_now = True
+            else:
+                self._callback = (fn, args, kwargs)
+
+        if run_now:
             fn(self._final_result, *args, **kwargs)
-        else:
-            self._callback = (fn, args, kwargs)
         return self
 
     def add_errback(self, fn, *args, **kwargs):
@@ -2313,10 +2366,14 @@ class ResponseFuture(object):
         An Exception instance will be passed as the first positional argument
         to `fn`.
         """
-        if self._final_exception:
+        run_now = False
+        with self._callback_lock:
+            if self._final_exception:
+                run_now = True
+            else:
+                self._errback = (fn, args, kwargs)
+
+        if run_now:
             fn(self._final_exception, *args, **kwargs)
-        else:
-            self._errback = (fn, args, kwargs)
         return self
 
     def add_callbacks(self, callback, errback,
@@ -2352,3 +2409,58 @@ class ResponseFuture(object):
         return "<ResponseFuture: query='%s' request_id=%s result=%s exception=%s host=%s>" \
                % (self.query, self._req_id, result, self._final_exception, self._current_host)
 
     __repr__ = __str__
+
+
+class QueryExhausted(Exception):
+    """
+    Raised when :meth:`.ResponseFuture.start_fetching_next_page()` is called and
+    there are no more pages.  You can check :attr:`.ResponseFuture.has_more_pages`
+    before calling to avoid this.
+    """
+    pass
+
+
+class PagedResult(object):
+    """
+    An iterator over the rows from a paged query result.  Whenever the number
+    of result rows for a query exceeds the :attr:`~.query.Statement.fetch_size`
+    (or :attr:`~.Session.default_fetch_size`, if not set) an instance of this
+    class will be returned.
+
+    You can treat this as a normal iterator over rows::
+
+        >>> from cassandra.query import SimpleStatement
+        >>> statement = SimpleStatement("SELECT * FROM users", fetch_size=10)
+        >>> for user_row in session.execute(statement):
+        ...     process_user(user_row)
+
+    Whenever there are no more rows in the current page, the next page will
+    be fetched transparently.  However, note that it *is* possible for
+    an :class:`Exception` to be raised while fetching the next page, just
+    like you might see on a normal call to ``session.execute()``.
+    """
+
+    def __init__(self, response_future, initial_response):
+        self.response_future = response_future
+        self.current_response = iter(initial_response)
+
+    def __iter__(self):
+        return self
+
+    def next(self):
+        try:
+            return next(self.current_response)
+        except StopIteration:
+            if self.response_future._paging_state is None:
+                raise
+
+        self.response_future.start_fetching_next_page()
+        result = self.response_future.result()
+        if self.response_future.has_more_pages:
+            self.current_response = result.current_response
+        else:
+            self.current_response = iter(result)
+
+        return next(self.current_response)
+
+    __next__ = next

cassandra/concurrent.py (new file)

@@ -0,0 +1,148 @@
import sys

from itertools import count, cycle
from threading import Event


def execute_concurrent(session, statements_and_parameters, concurrency=100, raise_on_first_error=True):
    """
    Executes a sequence of (statement, parameters) tuples concurrently.  Each
    ``parameters`` item must be a sequence or :const:`None`.

    A sequence of ``(success, result_or_exc)`` tuples is returned in the same
    order that the statements were passed in.  If ``success`` is :const:`False`,
    there was an error executing the statement, and ``result_or_exc`` will be
    an :class:`Exception`.  If ``success`` is :const:`True`, ``result_or_exc``
    will be the query result.

    If `raise_on_first_error` is left as :const:`True`, execution will stop
    after the first failed statement and the corresponding exception will be
    raised.

    The `concurrency` parameter controls how many statements will be executed
    concurrently.  It is recommended that this be kept below the number of
    core connections per host times the number of connected hosts (see
    :meth:`.Cluster.set_core_connections_per_host`).  If that amount is exceeded,
    the event loop thread may attempt to block on new connection creation,
    substantially impacting throughput.

    Example usage::

        select_statement = session.prepare("SELECT * FROM users WHERE id=?")

        statements_and_params = []
        for user_id in user_ids:
            statements_and_params.append(
                (select_statement, user_id))

        results = execute_concurrent(
            session, statements_and_params, raise_on_first_error=False)

        for (success, result) in results:
            if not success:
                handle_error(result)  # result will be an Exception
            else:
                process_user(result[0])  # result will be a list of rows
    """
    if concurrency <= 0:
        raise ValueError("concurrency must be greater than 0")

    if not statements_and_parameters:
        return []

    event = Event()
    first_error = [] if raise_on_first_error else None
    to_execute = len(statements_and_parameters)  # TODO handle iterators/generators
    results = [None] * to_execute
    num_finished = count(start=1)
    statements = enumerate(iter(statements_and_parameters))
    for i in xrange(min(concurrency, len(statements_and_parameters))):
        _execute_next(_sentinel, i, event, session, statements, results, num_finished, to_execute, first_error)

    event.wait()
    if first_error:
        raise first_error[0]
    else:
        return results


def execute_concurrent_with_args(session, statement, parameters, *args, **kwargs):
    """
    Like :meth:`~.execute_concurrent`, but takes a single statement and a
    sequence of parameters.  Each item in ``parameters`` should be a sequence
    or :const:`None`.

    Example usage::

        statement = session.prepare("INSERT INTO mytable (a, b) VALUES (1, ?)")
        parameters = [(x,) for x in range(1000)]
        execute_concurrent_with_args(session, statement, parameters)
    """
    return execute_concurrent(session, zip(cycle((statement,)), parameters), *args, **kwargs)


_sentinel = object()


def _handle_error(error, result_index, event, session, statements, results, num_finished, to_execute, first_error):
    if first_error is not None:
        first_error.append(error)
        event.set()
        return
    else:
        results[result_index] = (False, error)
        if num_finished.next() >= to_execute:
            event.set()
            return

    try:
        (next_index, (statement, params)) = statements.next()
    except StopIteration:
        return

    args = (next_index, event, session, statements, results, num_finished, to_execute, first_error)
    try:
        session.execute_async(statement, params).add_callbacks(
            callback=_execute_next, callback_args=args,
            errback=_handle_error, errback_args=args)
    except Exception as exc:
        if first_error is not None:
            first_error.append(sys.exc_info())
            event.set()
            return
        else:
            results[next_index] = (False, exc)
            if num_finished.next() >= to_execute:
                event.set()
                return


def _execute_next(result, result_index, event, session, statements, results, num_finished, to_execute, first_error):
    if result is not _sentinel:
        results[result_index] = (True, result)
        finished = num_finished.next()
        if finished >= to_execute:
            event.set()
            return

    try:
        (next_index, (statement, params)) = statements.next()
    except StopIteration:
        return

    args = (next_index, event, session, statements, results, num_finished, to_execute, first_error)
    try:
        session.execute_async(statement, params).add_callbacks(
            callback=_execute_next, callback_args=args,
            errback=_handle_error, errback_args=args)
    except Exception as exc:
        if first_error is not None:
            first_error.append(sys.exc_info())
            event.set()
            return
        else:
            results[next_index] = (False, exc)
            if num_finished.next() >= to_execute:
                event.set()
                return

cassandra/connection.py

@@ -1,12 +1,20 @@
 import errno
 from functools import wraps, partial
 import logging
+import sys
 from threading import Event, RLock
+import time
+import traceback
 
-from six.moves.queue import Queue
+if 'gevent.monkey' in sys.modules:
+    from gevent.queue import Queue, Empty
+else:
+    from six.moves.queue import Queue, Empty  # noqa
+
+from six.moves import range
 
 from cassandra import ConsistencyLevel, AuthenticationFailed, OperationTimedOut
-from cassandra.marshal import int8_unpack, int32_pack, header_unpack
+from cassandra.marshal import int32_pack, header_unpack
 from cassandra.decoder import (ReadyMessage, AuthenticateMessage, OptionsMessage,
                                StartupMessage, ErrorMessage, CredentialsMessage,
                                QueryMessage, ResultMessage, decode_response,
@@ -151,16 +159,99 @@ class Connection(object):
         raise NotImplementedError()
 
     def defunct(self, exc):
-        raise NotImplementedError()
-
-    def send_msg(self, msg, cb):
-        raise NotImplementedError()
-
-    def wait_for_response(self, msg, **kwargs):
-        raise NotImplementedError()
+        with self.lock:
+            if self.is_defunct or self.is_closed:
+                return
+            self.is_defunct = True
+
+        trace = traceback.format_exc(exc)
+        if trace != "None":
+            log.debug("Defuncting connection (%s) to %s: %s\n%s",
+                      id(self), self.host, exc, traceback.format_exc(exc))
+        else:
+            log.debug("Defuncting connection (%s) to %s: %s", id(self), self.host, exc)
+
+        self.last_error = exc
+        self.close()
+        self.error_all_callbacks(exc)
+        self.connected_event.set()
+        return exc
+
+    def error_all_callbacks(self, exc):
+        with self.lock:
+            callbacks = self._callbacks
+            self._callbacks = {}
+        new_exc = ConnectionShutdown(str(exc))
+        for cb in callbacks.values():
+            try:
+                cb(new_exc)
+            except Exception:
+                log.warn("Ignoring unhandled exception while erroring callbacks for a "
+                         "failed connection (%s) to host %s:",
+                         id(self), self.host, exc_info=True)
+
+    def handle_pushed(self, response):
+        log.debug("Message pushed from server: %r", response)
+        for cb in self._push_watchers.get(response.event_type, []):
+            try:
+                cb(response.event_args)
+            except Exception:
+                log.exception("Pushed event handler errored, ignoring:")
+
+    def send_msg(self, msg, cb, wait_for_id=False):
+        if self.is_defunct:
+            raise ConnectionShutdown("Connection to %s is defunct" % self.host)
+        elif self.is_closed:
+            raise ConnectionShutdown("Connection to %s is closed" % self.host)
+
+        if not wait_for_id:
+            try:
+                request_id = self._id_queue.get_nowait()
+            except Empty:
+                raise ConnectionBusy(
+                    "Connection to %s is at the max number of requests" % self.host)
+        else:
+            request_id = self._id_queue.get()
+
+        self._callbacks[request_id] = cb
+        self.push(msg.to_binary(request_id, self.protocol_version, compression=self.compressor))
+        return request_id
+
+    def wait_for_response(self, msg, timeout=None):
+        return self.wait_for_responses(msg, timeout=timeout)[0]
 
     def wait_for_responses(self, *msgs, **kwargs):
-        raise NotImplementedError()
+        timeout = kwargs.get('timeout')
+        waiter = ResponseWaiter(self, len(msgs))
+
+        # busy wait for sufficient space on the connection
+        messages_sent = 0
+        while True:
+            needed = len(msgs) - messages_sent
+            with self.lock:
+                available = min(needed, MAX_STREAM_PER_CONNECTION - self.in_flight)
+                self.in_flight += available
+
+            for i in range(messages_sent, messages_sent + available):
+                self.send_msg(msgs[i], partial(waiter.got_response, index=i), wait_for_id=True)
+            messages_sent += available
+
+            if messages_sent == len(msgs):
+                break
+            else:
+                if timeout is not None:
+                    timeout -= 0.01
+                    if timeout <= 0.0:
+                        raise OperationTimedOut()
+                time.sleep(0.01)
+
+        try:
+            return waiter.deliver(timeout)
+        except OperationTimedOut:
+            raise
+        except Exception, exc:
+            self.defunct(exc)
+            raise
 
     def register_watcher(self, event_type, callback):
         raise NotImplementedError()

cassandra/io/asyncorereactor.py

@@ -1,16 +1,13 @@
 from collections import defaultdict, deque
-from functools import partial
 import logging
 import os
 import socket
 import sys
 from threading import Event, Lock, Thread
-import time
-import traceback
 
 from six import BytesIO
-from six.moves import queue as Queue
-from six.moves import xrange
+from six.moves import range
 
 from errno import EALREADY, EINPROGRESS, EWOULDBLOCK, EINVAL, EISCONN, errorcode
 
 import asyncore
@@ -21,9 +18,8 @@
 except ImportError:
     ssl = None  # NOQA
 
 from cassandra import OperationTimedOut
-from cassandra.connection import (Connection, ResponseWaiter, ConnectionShutdown,
-                                  ConnectionBusy, ConnectionException, NONBLOCKING,
-                                  MAX_STREAM_PER_CONNECTION)
+from cassandra.connection import (Connection, ConnectionShutdown,
+                                  ConnectionException, NONBLOCKING)
 from cassandra.decoder import RegisterMessage
 from cassandra.marshal import int32_unpack
@@ -172,44 +168,11 @@ class AsyncoreConnection(Connection, asyncore.dispatcher):
             _starting_conns.discard(self)
 
         if not self.is_defunct:
-            self._error_all_callbacks(
+            self.error_all_callbacks(
                 ConnectionShutdown("Connection to %s was closed" % self.host))
             # don't leave in-progress operations hanging
             self.connected_event.set()
 
-    def defunct(self, exc):
-        with self.lock:
-            if self.is_defunct or self.is_closed:
-                return
-            self.is_defunct = True
-
-        trace = traceback.format_exc()  # exc)
-        if trace != "None":
-            log.debug("Defuncting connection (%s) to %s: %s\n%s",
-                      id(self), self.host, exc, traceback.format_exc())
-        else:
-            log.debug("Defuncting connection (%s) to %s: %s",
-                      id(self), self.host, exc)
-
-        self.last_error = exc
-        self.close()
-        self._error_all_callbacks(exc)
-        self.connected_event.set()
-        return exc
-
-    def _error_all_callbacks(self, exc):
-        with self.lock:
-            callbacks = self._callbacks
-            self._callbacks = {}
-        new_exc = ConnectionShutdown(str(exc))
-        for cb in callbacks.values():
-            try:
-                cb(new_exc)
-            except Exception:
-                log.warn("Ignoring unhandled exception while erroring callbacks for a "
-                         "failed connection (%s) to host %s:",
-                         id(self), self.host, exc_info=True)
-
     def handle_connect(self):
         with _starting_conns_lock:
             _starting_conns.discard(self)
@@ -300,19 +263,11 @@ class AsyncoreConnection(Connection, asyncore.dispatcher):
             if not self._callbacks:
                 self._readable = False
 
-    def handle_pushed(self, response):
-        log.debug("Message pushed from server: %r", response)
-        for cb in self._push_watchers.get(response.event_type, []):
-            try:
-                cb(response.event_args)
-            except Exception:
-                log.exception("Pushed event handler errored, ignoring:")
-
     def push(self, data):
         sabs = self.out_buffer_size
         if len(data) > sabs:
             chunks = []
-            for i in xrange(0, len(data), sabs):
+            for i in range(0, len(data), sabs):
                 chunks.append(data[i:i + sabs])
         else:
             chunks = [data]
@@ -328,61 +283,6 @@ class AsyncoreConnection(Connection, asyncore.dispatcher):
     def readable(self):
         return self._readable or (self._have_listeners and not (self.is_defunct or self.is_closed))
 
-    def send_msg(self, msg, cb, wait_for_id=False):
-        if self.is_defunct:
-            raise ConnectionShutdown("Connection to %s is defunct" % self.host)
-        elif self.is_closed:
-            raise ConnectionShutdown("Connection to %s is closed" % self.host)
-
-        if not wait_for_id:
-            try:
-                request_id = self._id_queue.get_nowait()
-            except Queue.Empty:
-                raise ConnectionBusy(
-                    "Connection to %s is at the max number of requests" % self.host)
-        else:
-            request_id = self._id_queue.get()
-
-        self._callbacks[request_id] = cb
-        self.push(msg.to_binary(request_id, self.protocol_version, compression=self.compressor))
-        return request_id
-
-    def wait_for_response(self, msg, timeout=None):
-        return self.wait_for_responses(msg, timeout=timeout)[0]
-
-    def wait_for_responses(self, *msgs, **kwargs):
-        timeout = kwargs.get('timeout')
-        waiter = ResponseWaiter(self, len(msgs))
-
-        # busy wait for sufficient space on the connection
-        messages_sent = 0
-        while True:
-            needed = len(msgs) - messages_sent
-            with self.lock:
-                available = min(needed, MAX_STREAM_PER_CONNECTION - self.in_flight)
-                self.in_flight += available
-
-            for i in range(messages_sent, messages_sent + available):
-                self.send_msg(msgs[i], partial(waiter.got_response, index=i), wait_for_id=True)
-            messages_sent += available
-
-            if messages_sent == len(msgs):
-                break
-            else:
-                if timeout is not None:
-                    timeout -= 0.01
-                    if timeout <= 0.0:
-                        raise OperationTimedOut()
-                time.sleep(0.01)
-
-        try:
-            return waiter.deliver(timeout)
-        except OperationTimedOut:
-            raise
-        except Exception as exc:
-            self.defunct(exc)
-            raise
-
     def register_watcher(self, event_type, callback):
         self._push_watchers[event_type].add(callback)
         self._have_listeners = True

cassandra/io/geventreactor.py (new file)

@@ -0,0 +1,190 @@
import gevent
from gevent import select, socket
from gevent.event import Event
from gevent.queue import Queue

from collections import defaultdict
from functools import partial
import logging
import os

try:
    from cStringIO import StringIO
except ImportError:
    from StringIO import StringIO  # ignore flake8 warning: # NOQA

from errno import EALREADY, EINPROGRESS, EWOULDBLOCK, EINVAL

from cassandra import OperationTimedOut
from cassandra.connection import Connection, ConnectionShutdown
from cassandra.decoder import RegisterMessage
from cassandra.marshal import int32_unpack

log = logging.getLogger(__name__)


def is_timeout(err):
    return (
        err in (EINPROGRESS, EALREADY, EWOULDBLOCK) or
        (err == EINVAL and os.name in ('nt', 'ce'))
    )


class GeventConnection(Connection):
    """
    An implementation of :class:`.Connection` that utilizes ``gevent``.
    """

    _total_reqd_bytes = 0
    _read_watcher = None
    _write_watcher = None
    _socket = None

    @classmethod
    def factory(cls, *args, **kwargs):
        timeout = kwargs.pop('timeout', 5.0)
        conn = cls(*args, **kwargs)
        conn.connected_event.wait(timeout)
        if conn.last_error:
            raise conn.last_error
        elif not conn.connected_event.is_set():
            conn.close()
            raise OperationTimedOut("Timed out creating connection")
        else:
            return conn

    def __init__(self, *args, **kwargs):
        Connection.__init__(self, *args, **kwargs)

        self.connected_event = Event()
        self._iobuf = StringIO()
        self._write_queue = Queue()

        self._callbacks = {}
        self._push_watchers = defaultdict(set)

        self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self._socket.settimeout(1.0)
        self._socket.connect((self.host, self.port))

        if self.sockopts:
            for args in self.sockopts:
                self._socket.setsockopt(*args)

        self._read_watcher = gevent.spawn(lambda: self.handle_read())
        self._write_watcher = gevent.spawn(lambda: self.handle_write())
        self._send_options_message()

    def close(self):
        with self.lock:
            if self.is_closed:
                return
            self.is_closed = True

        log.debug("Closing connection (%s) to %s" % (id(self), self.host))
        if self._read_watcher:
            self._read_watcher.kill()
        if self._write_watcher:
            self._write_watcher.kill()
        if self._socket:
            self._socket.close()
        log.debug("Closed socket to %s" % (self.host,))

        if not self.is_defunct:
            self.error_all_callbacks(
                ConnectionShutdown("Connection to %s was closed" % self.host))
            # don't leave in-progress operations hanging
            self.connected_event.set()

    def handle_close(self):
        log.debug("connection closed by server")
        self.close()

    def handle_write(self):
        run_select = partial(select.select, (), (self._socket,), ())
        while True:
            try:
                next_msg = self._write_queue.get()
                run_select()
            except Exception as exc:
                log.debug("Exception during write select() for %s: %s", self, exc)
                self.defunct(exc)
                return

            try:
                self._socket.sendall(next_msg)
            except socket.error as err:
                log.debug("Exception during socket sendall for %s: %s", self, err)
                self.defunct(err)
                return  # Leave the write loop

    def handle_read(self):
        run_select = partial(select.select, (self._socket,), (), ())
        while True:
            try:
                run_select()
            except Exception as exc:
                log.debug("Exception during read select() for %s: %s", self, exc)
                self.defunct(exc)
                return

            try:
                buf = self._socket.recv(self.in_buffer_size)
                self._iobuf.write(buf)
            except socket.error as err:
                if not is_timeout(err):
                    log.debug("Exception during socket recv for %s: %s", self, err)
                    self.defunct(err)
                    return  # leave the read loop

            if self._iobuf.tell():
                while True:
                    pos = self._iobuf.tell()
                    if pos < 8 or (self._total_reqd_bytes > 0 and pos < self._total_reqd_bytes):
                        # we don't have a complete header yet or we
                        # already saw a header, but we don't have a
                        # complete message yet
                        break
                    else:
                        # have enough for header, read body len from header
                        self._iobuf.seek(4)
                        body_len = int32_unpack(self._iobuf.read(4))

                        # seek to end to get length of current buffer
                        self._iobuf.seek(0, os.SEEK_END)
                        pos = self._iobuf.tell()

                        if pos >= body_len + 8:
                            # read message header and body
                            self._iobuf.seek(0)
                            msg = self._iobuf.read(8 + body_len)

                            # leave leftover in current buffer
                            leftover = self._iobuf.read()
                            self._iobuf = StringIO()
                            self._iobuf.write(leftover)

                            self._total_reqd_bytes = 0
                            self.process_msg(msg, body_len)
                        else:
                            self._total_reqd_bytes = body_len + 8
                            break
            else:
                log.debug("connection closed by server")
                self.close()
                return

    def push(self, data):
        chunk_size = self.out_buffer_size
        for i in xrange(0, len(data), chunk_size):
            self._write_queue.put(data[i:i + chunk_size])

    def register_watcher(self, event_type, callback):
        self._push_watchers[event_type].add(callback)
        self.wait_for_response(RegisterMessage(event_list=[event_type]))

    def register_watchers(self, type_callback_dict):
        for event_type, callback in type_callback_dict.items():
            self._push_watchers[event_type].add(callback)
        self.wait_for_response(RegisterMessage(event_list=type_callback_dict.keys()))
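
A hypothetical sketch of opting into this reactor explicitly instead of relying on the monkey-patch autodetection in cassandra/cluster.py (``Cluster.connection_class`` is an assumed attribute here; gevent must still be monkey-patched first)::

    from gevent import monkey
    monkey.patch_all()

    from cassandra.cluster import Cluster
    from cassandra.io.geventreactor import GeventConnection

    cluster = Cluster(['127.0.0.1'])
    cluster.connection_class = GeventConnection  # normally selected automatically
    session = cluster.connect()
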

cassandra/io/libevreactor.py

@@ -1,20 +1,13 @@
 from collections import defaultdict, deque
-from functools import partial, wraps
 import logging
 import os
 import socket
 from threading import Event, Lock, Thread
-import time
-import traceback
 
-from six.moves.queue import Queue
-from six.moves import cStringIO as StringIO
-from six.moves import xrange
+from six import BytesIO
 
 from cassandra import OperationTimedOut
-from cassandra.connection import (Connection, ResponseWaiter, ConnectionShutdown,
-                                  ConnectionBusy, NONBLOCKING,
-                                  MAX_STREAM_PER_CONNECTION)
+from cassandra.connection import Connection, ConnectionShutdown, NONBLOCKING
 from cassandra.decoder import RegisterMessage
 from cassandra.marshal import int32_unpack
 
 try:
@@ -80,18 +73,6 @@ def _start_loop():
     return should_start
 
 
-def defunct_on_error(f):
-
-    @wraps(f)
-    def wrapper(self, *args, **kwargs):
-        try:
-            return f(self, *args, **kwargs)
-        except Exception as exc:
-            self.defunct(exc)
-    return wrapper
-
-
 class LibevConnection(Connection):
     """
     An implementation of :class:`.Connection` that uses libev for its event loop.
@@ -192,7 +173,7 @@ class LibevConnection(Connection):
         Connection.__init__(self, *args, **kwargs)
 
         self.connected_event = Event()
-        self._iobuf = StringIO()
+        self._iobuf = BytesIO()
 
         self._callbacks = {}
         self._push_watchers = defaultdict(set)
@@ -237,41 +218,9 @@ class LibevConnection(Connection):
         # don't leave in-progress operations hanging
         if not self.is_defunct:
-            self._error_all_callbacks(
+            self.error_all_callbacks(
                 ConnectionShutdown("Connection to %s was closed" % self.host))
 
-    def defunct(self, exc):
-        with self.lock:
-            if self.is_defunct or self.is_closed:
-                return
-            self.is_defunct = True
-
-        trace = traceback.format_exc(exc)
-        if trace != "None":
-            log.debug("Defuncting connection (%s) to %s: %s\n%s",
-                      id(self), self.host, exc, traceback.format_exc(exc))
-        else:
-            log.debug("Defuncting connection (%s) to %s: %s", id(self), self.host, exc)
-
-        self.last_error = exc
-        self.close()
-        self._error_all_callbacks(exc)
-        self.connected_event.set()
-        return exc
-
-    def _error_all_callbacks(self, exc):
-        with self.lock:
-            callbacks = self._callbacks
-            self._callbacks = {}
-        new_exc = ConnectionShutdown(str(exc))
-        for cb in callbacks.values():
-            try:
-                cb(new_exc)
-            except Exception:
-                log.warn("Ignoring unhandled exception while erroring callbacks for a "
-                         "failed connection (%s) to host %s:",
-                         id(self), self.host, exc_info=True)
-
     def handle_write(self, watcher, revents, errno=None):
         if revents & libev.EV_ERROR:
             if errno:
@@ -351,7 +300,7 @@ class LibevConnection(Connection):
                             # leave leftover in current buffer
                             leftover = self._iobuf.read()
-                            self._iobuf = StringIO()
+                            self._iobuf = BytesIO()
                             self._iobuf.write(leftover)
 
                             self._total_reqd_bytes = 0
@@ -363,14 +312,6 @@ class LibevConnection(Connection):
             log.debug("Connection %s closed by server", self)
             self.close()
 
-    def handle_pushed(self, response):
-        log.debug("Message pushed from server: %r", response)
-        for cb in self._push_watchers.get(response.event_type, []):
-            try:
-                cb(response.event_args)
-            except Exception:
-                log.exception("Pushed event handler errored, ignoring:")
-
     def push(self, data):
         sabs = self.out_buffer_size
         if len(data) > sabs:
@@ -384,61 +325,6 @@ class LibevConnection(Connection):
         self.deque.extend(chunks)
         _loop_notifier.send()
 
-    def send_msg(self, msg, cb, wait_for_id=False):
-        if self.is_defunct:
-            raise ConnectionShutdown("Connection to %s is defunct" % self.host)
-        elif self.is_closed:
-            raise ConnectionShutdown("Connection to %s is closed" % self.host)
-
-        if not wait_for_id:
-            try:
-                request_id = self._id_queue.get_nowait()
-            except Queue.Empty:
-                raise ConnectionBusy(
-                    "Connection to %s is at the max number of requests" % self.host)
-        else:
-            request_id = self._id_queue.get()
-
-        self._callbacks[request_id] = cb
-        self.push(msg.to_binary(request_id, self.protocol_version, compression=self.compressor))
-        return request_id
-
-    def wait_for_response(self, msg, timeout=None):
-        return self.wait_for_responses(msg, timeout=timeout)[0]
-
-    def wait_for_responses(self, *msgs, **kwargs):
-        timeout = kwargs.get('timeout')
-        waiter = ResponseWaiter(self, len(msgs))
-
-        # busy wait for sufficient space on the connection
-        messages_sent = 0
-        while True:
-            needed = len(msgs) - messages_sent
-            with self.lock:
-                available = min(needed, MAX_STREAM_PER_CONNECTION - self.in_flight)
-                self.in_flight += available
-
-            for i in range(messages_sent, messages_sent + available):
-                self.send_msg(msgs[i], partial(waiter.got_response, index=i), wait_for_id=True)
-            messages_sent += available
-
-            if messages_sent == len(msgs):
-                break
-            else:
-                if timeout is not None:
-                    timeout -= 0.01
-                    if timeout <= 0.0:
-                        raise OperationTimedOut()
-                time.sleep(0.01)
-
-        try:
-            return waiter.deliver(timeout)
-        except OperationTimedOut:
-            raise
-        except Exception as exc:
-            self.defunct(exc)
-            raise
-
     def register_watcher(self, event_type, callback):
         self._push_watchers[event_type].add(callback)
         self.wait_for_response(RegisterMessage(event_list=[event_type]))

cassandra/metadata.py

@@ -157,7 +157,7 @@ class Metadata(object):
             if not cf_results:
                 # the table was removed
-                del keyspace_meta.tables[table]
+                keyspace_meta.tables.pop(table, None)
             else:
                 assert len(cf_results) == 1
                 keyspace_meta.tables[table] = self._build_table_metadata(
@@ -346,11 +346,12 @@ class Metadata(object):
         else:
             return True
 
-    def add_host(self, address):
+    def add_host(self, address, datacenter, rack):
         cluster = self.cluster_ref()
         with self._hosts_lock:
             if address not in self._hosts:
-                new_host = Host(address, cluster.conviction_policy_factory)
+                new_host = Host(
+                    address, cluster.conviction_policy_factory, datacenter, rack)
                 self._hosts[address] = new_host
             else:
                 return None

cassandra/policies.py

@@ -137,6 +137,7 @@ class RoundRobinPolicy(LoadBalancingPolicy):
     This load balancing policy is used by default.
     """
 
+    _live_hosts = frozenset(())
+
     def populate(self, cluster, hosts):
         self._live_hosts = frozenset(hosts)

cassandra/pool.py

@@ -71,7 +71,7 @@ class Host(object):
     _currently_handling_node_up = False
     _handle_node_up_condition = None
 
-    def __init__(self, inet_address, conviction_policy_factory):
+    def __init__(self, inet_address, conviction_policy_factory, datacenter=None, rack=None):
         if inet_address is None:
             raise ValueError("inet_address may not be None")
         if conviction_policy_factory is None:
@@ -79,6 +79,7 @@ class Host(object):
         self.address = inet_address
         self.conviction_policy = conviction_policy_factory(self)
+        self.set_location_info(datacenter, rack)
         self.lock = RLock()
 
         self._handle_node_up_condition = Condition()

cassandra/query.py

@@ -73,6 +73,13 @@ class Statement(object):
     """
 
     fetch_size = None
+    """
+    How many rows will be fetched at a time.  This overrides the default
+    of :attr:`.Session.default_fetch_size`.
+
+    This only takes effect when protocol version 2 or higher is used.
+    See :attr:`.Cluster.protocol_version` for details.
+    """
 
     _serial_consistency_level = None
     _routing_key = None
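
For example, a per-statement page size (a usage sketch; the same pattern appears in docs/query_paging.rst below)::

    from cassandra.query import SimpleStatement

    # Fetch 10 rows per page for this query only, regardless of
    # Session.default_fetch_size:
    statement = SimpleStatement("SELECT * FROM users", fetch_size=10)
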
@@ -561,6 +568,7 @@ class QueryTrace(object):
             if max_wait is not None and time_spent >= max_wait:
                 raise TraceUnavailable("Trace information was not available within %f seconds" % (max_wait,))
 
+            log.debug("Attempting to fetch trace info for trace ID: %s", self.trace_id)
             session_results = self._execute(
                 self._SELECT_SESSIONS_FORMAT, (self.trace_id,), time_spent, max_wait)
@@ -568,6 +576,7 @@ class QueryTrace(object):
                 time.sleep(self._BASE_RETRY_SLEEP * (2 ** attempt))
                 attempt += 1
                 continue
+            log.debug("Fetched trace info for trace ID: %s", self.trace_id)
 
             session_row = session_results[0]
             self.request_type = session_row.request
@@ -576,9 +585,11 @@ class QueryTrace(object):
             self.coordinator = session_row.coordinator
             self.parameters = session_row.parameters
 
+            log.debug("Attempting to fetch trace events for trace ID: %s", self.trace_id)
             time_spent = time.time() - start
             event_results = self._execute(
                 self._SELECT_EVENTS_FORMAT, (self.trace_id,), time_spent, max_wait)
+            log.debug("Fetched trace events for trace ID: %s", self.trace_id)
             self.events = tuple(TraceEvent(r.activity, r.event_id, r.source, r.source_elapsed, r.thread)
                                 for r in event_results)
             break

docs/api/cassandra/cluster.rst

@@ -7,6 +7,8 @@
    .. autoattribute:: cql_version
 
+   .. autoattribute:: protocol_version
+
    .. autoattribute:: port
 
    .. autoattribute:: compression
@@ -59,6 +61,8 @@
    .. autoattribute:: row_factory
 
+   .. autoattribute:: default_fetch_size
+
    .. automethod:: execute(statement[, parameters][, timeout][, trace])
 
    .. automethod:: execute_async(statement[, parameters][, trace])
@@ -77,11 +81,20 @@
    .. automethod:: get_query_trace()
 
+   .. autoattribute:: has_more_pages
+
+   .. automethod:: start_fetching_next_page()
+
    .. automethod:: add_callback(fn, *args, **kwargs)
 
    .. automethod:: add_errback(fn, *args, **kwargs)
 
    .. automethod:: add_callbacks(callback, errback, callback_args=(), callback_kwargs=None, errback_args=(), errback_kwargs=None)
 
+.. autoclass:: PagedResult ()
+   :members:
+
+.. autoexception:: QueryExhausted ()
+
 .. autoexception:: NoHostAvailable ()
    :members:
View File

@@ -8,6 +8,7 @@ Python Cassandra Driver
    installation
    getting_started
    performance
+   query_paging
 
 Indices and Tables
 ==================

docs/query_paging.rst (new file)

@@ -0,0 +1,74 @@
Paging Large Queries
====================

Cassandra 2.0+ offers support for automatic query paging.  Starting with
version 2.0 of the driver, if :attr:`~.Cluster.protocol_version` is set to
:const:`2` (it is by default), queries returning large result sets will be
automatically paged.

Controlling the Page Size
-------------------------

By default, :attr:`.Session.default_fetch_size` controls how many rows will
be fetched per page.  This can be overridden per-query by setting
:attr:`~.fetch_size` on a :class:`~.Statement`.  By default, each page
will contain at most 5000 rows.

Handling Paged Results
----------------------

Whenever the number of result rows for a query exceeds the page size, an
instance of :class:`~.PagedResult` will be returned instead of a normal
list.  This class implements the iterator interface, so you can treat
it like a normal iterator over rows::

    from cassandra.query import SimpleStatement
    query = "SELECT * FROM users"  # users contains 100 rows
    statement = SimpleStatement(query, fetch_size=10)
    for user_row in session.execute(statement):
        process_user(user_row)

Whenever there are no more rows in the current page, the next page will
be fetched transparently.  However, note that it *is* possible for
an :class:`Exception` to be raised while fetching the next page, just
like you might see on a normal call to ``session.execute()``.

If you use :meth:`.Session.execute_async()` along with
:meth:`.ResponseFuture.result()`, the first page will be fetched before
:meth:`~.ResponseFuture.result()` returns, but later pages will be
transparently fetched synchronously while iterating the result.
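
A short sketch of that pattern (``process_user`` stands in for your own row handling)::

    future = session.execute_async("SELECT * FROM users")
    rows = future.result()       # blocks until the first page has arrived
    for user_row in rows:        # later pages are fetched while iterating
        process_user(user_row)
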
Handling Paged Results with Callbacks
-------------------------------------

If callbacks are attached to a query that returns a paged result,
the callback will be called once per page with a normal list of rows.
Use :attr:`.ResponseFuture.has_more_pages` and
:meth:`.ResponseFuture.start_fetching_next_page()` to continue fetching
pages.  For example::

    class PagedResultHandler(object):

        def __init__(self, future):
            self.error = None
            self.finished_event = Event()
            self.future = future
            self.future.add_callbacks(
                callback=self.handle_page,
                errback=self.handle_error)

        def handle_page(self, rows):
            for row in rows:
                process_row(row)

            if self.future.has_more_pages:
                self.future.start_fetching_next_page()
            else:
                self.finished_event.set()

        def handle_error(self, exc):
            self.error = exc
            self.finished_event.set()

    future = session.execute_async("SELECT * FROM users")
    handler = PagedResultHandler(future)
    handler.finished_event.wait()
    if handler.error:
        raise handler.error

setup.py

@@ -1,18 +1,13 @@
 from __future__ import print_function
-import platform
-import os
 import sys
-import warnings
-
-try:
-    import subprocess
-    has_subprocess = True
-except ImportError:
-    has_subprocess = False
 
 import ez_setup
 ez_setup.use_setuptools()
 
+if __name__ == '__main__' and sys.argv[1] == "gevent_nosetests":
+    from gevent.monkey import patch_all
+    patch_all()
+
 from setuptools import setup
 from distutils.command.build_ext import build_ext
 from distutils.core import Extension
@@ -20,6 +15,19 @@ from distutils.errors import (CCompilerError, DistutilsPlatformError,
                               DistutilsExecError)
 from distutils.cmd import Command
 
+import platform
+import os
+import warnings
+
+try:
+    import subprocess
+    has_subprocess = True
+except ImportError:
+    has_subprocess = False
+
+from nose.commands import nosetests
+
 from cassandra import __version__
 
 long_description = ""
@@ -27,6 +35,10 @@ with open("README.rst") as f:
     long_description = f.read()
 
 
+class gevent_nosetests(nosetests):
+    description = "run nosetests with gevent monkey patching"
+
+
 class DocCommand(Command):
 
     description = "generate or test documentation"
@@ -144,12 +156,12 @@ On OSX, via homebrew:
 def run_setup(extensions):
-    kw = {'cmdclass': {'doc': DocCommand}}
+    kw = {'cmdclass': {'doc': DocCommand, 'gevent_nosetests': gevent_nosetests}}
     if extensions:
         kw['cmdclass']['build_ext'] = build_extensions
         kw['ext_modules'] = extensions
 
-    dependencies = ['futures', 'scales', 'blist', 'six >=1.6']
+    dependencies = ['futures', 'scales >=1.0.5', 'blist', 'six >=1.6']
 
     if platform.python_implementation() != "CPython":
         dependencies.remove('blist')
@@ -164,7 +176,7 @@ def run_setup(extensions):
         packages=['cassandra', 'cassandra.io'],
         include_package_data=True,
         install_requires=dependencies,
-        tests_require=['nose', 'mock', 'PyYAML'],
+        tests_require=['nose', 'mock', 'PyYAML', 'pytz'],
         classifiers=[
             'Development Status :: 5 - Production/Stable',
             'Intended Audience :: Developers',

View File

@@ -18,7 +18,14 @@ except ImportError as e:
CLUSTER_NAME = 'test_cluster' CLUSTER_NAME = 'test_cluster'
CCM_CLUSTER = None CCM_CLUSTER = None
DEFAULT_CASSANDRA_VERSION = '2.0.5'
CASSANDRA_VERSION = os.getenv('CASSANDRA_VERSION', '2.0.6')
if CASSANDRA_VERSION.startswith('1'):
DEFAULT_PROTOCOL_VERSION = 1
else:
DEFAULT_PROTOCOL_VERSION = 2
PROTOCOL_VERSION = int(os.getenv('PROTOCOL_VERSION', DEFAULT_PROTOCOL_VERSION))
path = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'ccm') path = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'ccm')
if not os.path.exists(path): if not os.path.exists(path):
@@ -38,7 +45,7 @@ def get_server_versions():
if cass_version is not None: if cass_version is not None:
return (cass_version, cql_version) return (cass_version, cql_version)
c = Cluster() c = Cluster(protocol_version=PROTOCOL_VERSION)
s = c.connect() s = c.connect()
s.set_keyspace('system') s.set_keyspace('system')
row = s.execute('SELECT cql_version, release_version FROM local')[0] row = s.execute('SELECT cql_version, release_version FROM local')[0]
@@ -67,16 +74,16 @@ def get_node(node_id):
def setup_package(): def setup_package():
version = os.getenv("CASSANDRA_VERSION", DEFAULT_CASSANDRA_VERSION) print 'Using Cassandra version: %s' % CASSANDRA_VERSION
try: try:
try: try:
cluster = CCMCluster.load(path, CLUSTER_NAME) cluster = CCMCluster.load(path, CLUSTER_NAME)
log.debug("Found existing ccm test cluster, clearing") log.debug("Found existing ccm test cluster, clearing")
cluster.clear() cluster.clear()
cluster.set_cassandra_dir(cassandra_version=version) cluster.set_cassandra_dir(cassandra_version=CASSANDRA_VERSION)
except Exception: except Exception:
log.debug("Creating new ccm test cluster with version %s", version) log.debug("Creating new ccm test cluster with version %s", CASSANDRA_VERSION)
cluster = CCMCluster(path, CLUSTER_NAME, cassandra_version=version) cluster = CCMCluster(path, CLUSTER_NAME, cassandra_version=CASSANDRA_VERSION)
cluster.set_configuration_options({'start_native_transport': True}) cluster.set_configuration_options({'start_native_transport': True})
common.switch_cluster(path, CLUSTER_NAME) common.switch_cluster(path, CLUSTER_NAME)
cluster.populate(3) cluster.populate(3)
@@ -93,7 +100,7 @@ def setup_package():
def setup_test_keyspace(): def setup_test_keyspace():
cluster = Cluster() cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect() session = cluster.connect()
try: try:

View File

@@ -7,6 +7,7 @@ from cassandra.cluster import Cluster
from cassandra.policies import TokenAwarePolicy, RoundRobinPolicy, \ from cassandra.policies import TokenAwarePolicy, RoundRobinPolicy, \
DowngradingConsistencyRetryPolicy DowngradingConsistencyRetryPolicy
from cassandra.query import SimpleStatement from cassandra.query import SimpleStatement
from tests.integration import PROTOCOL_VERSION
from tests.integration.long.utils import force_stop, create_schema, \ from tests.integration.long.utils import force_stop, create_schema, \
wait_for_down, wait_for_up, start, CoordinatorStats wait_for_down, wait_for_up, start, CoordinatorStats
@@ -98,7 +99,8 @@ class ConsistencyTests(unittest.TestCase):
def _test_tokenaware_one_node_down(self, keyspace, rf, accepted): def _test_tokenaware_one_node_down(self, keyspace, rf, accepted):
cluster = Cluster( cluster = Cluster(
load_balancing_policy=TokenAwarePolicy(RoundRobinPolicy())) load_balancing_policy=TokenAwarePolicy(RoundRobinPolicy()),
protocol_version=PROTOCOL_VERSION)
session = cluster.connect() session = cluster.connect()
wait_for_up(cluster, 1, wait=False) wait_for_up(cluster, 1, wait=False)
wait_for_up(cluster, 2) wait_for_up(cluster, 2)
@@ -147,7 +149,8 @@ class ConsistencyTests(unittest.TestCase):
def test_rfthree_tokenaware_none_down(self): def test_rfthree_tokenaware_none_down(self):
keyspace = 'test_rfthree_tokenaware_none_down' keyspace = 'test_rfthree_tokenaware_none_down'
cluster = Cluster( cluster = Cluster(
load_balancing_policy=TokenAwarePolicy(RoundRobinPolicy())) load_balancing_policy=TokenAwarePolicy(RoundRobinPolicy()),
protocol_version=PROTOCOL_VERSION)
session = cluster.connect() session = cluster.connect()
wait_for_up(cluster, 1, wait=False) wait_for_up(cluster, 1, wait=False)
wait_for_up(cluster, 2) wait_for_up(cluster, 2)
@@ -169,7 +172,8 @@ class ConsistencyTests(unittest.TestCase):
def _test_downgrading_cl(self, keyspace, rf, accepted): def _test_downgrading_cl(self, keyspace, rf, accepted):
cluster = Cluster( cluster = Cluster(
load_balancing_policy=TokenAwarePolicy(RoundRobinPolicy()), load_balancing_policy=TokenAwarePolicy(RoundRobinPolicy()),
default_retry_policy=DowngradingConsistencyRetryPolicy()) default_retry_policy=DowngradingConsistencyRetryPolicy(),
protocol_version=PROTOCOL_VERSION)
session = cluster.connect() session = cluster.connect()
create_schema(session, keyspace, replication_factor=rf) create_schema(session, keyspace, replication_factor=rf)
@@ -210,14 +214,16 @@ class ConsistencyTests(unittest.TestCase):
keyspace = 'test_rfthree_roundrobin_downgradingcl' keyspace = 'test_rfthree_roundrobin_downgradingcl'
cluster = Cluster( cluster = Cluster(
load_balancing_policy=RoundRobinPolicy(), load_balancing_policy=RoundRobinPolicy(),
default_retry_policy=DowngradingConsistencyRetryPolicy()) default_retry_policy=DowngradingConsistencyRetryPolicy(),
protocol_version=PROTOCOL_VERSION)
self.rfthree_downgradingcl(cluster, keyspace, True) self.rfthree_downgradingcl(cluster, keyspace, True)
def test_rfthree_tokenaware_downgradingcl(self): def test_rfthree_tokenaware_downgradingcl(self):
keyspace = 'test_rfthree_tokenaware_downgradingcl' keyspace = 'test_rfthree_tokenaware_downgradingcl'
cluster = Cluster( cluster = Cluster(
load_balancing_policy=TokenAwarePolicy(RoundRobinPolicy()), load_balancing_policy=TokenAwarePolicy(RoundRobinPolicy()),
default_retry_policy=DowngradingConsistencyRetryPolicy()) default_retry_policy=DowngradingConsistencyRetryPolicy(),
protocol_version=PROTOCOL_VERSION)
self.rfthree_downgradingcl(cluster, keyspace, False) self.rfthree_downgradingcl(cluster, keyspace, False)
def rfthree_downgradingcl(self, cluster, keyspace, roundrobin): def rfthree_downgradingcl(self, cluster, keyspace, roundrobin):

View File

@@ -1,14 +1,15 @@
try: try:
from Queue import Queue, Empty from Queue import Queue, Empty
except ImportError: except ImportError:
from queue import Queue, Empty from queue import Queue, Empty # noqa
from struct import pack from struct import pack
import unittest import unittest
from cassandra import ConsistencyLevel from cassandra import ConsistencyLevel
from cassandra.cluster import Cluster from cassandra.cluster import Cluster
from cassandra.decoder import dict_factory from cassandra.query import dict_factory
from cassandra.query import SimpleStatement from cassandra.query import SimpleStatement
from tests.integration import PROTOCOL_VERSION
from tests.integration.long.utils import create_schema from tests.integration.long.utils import create_schema
@@ -32,7 +33,7 @@ class LargeDataTests(unittest.TestCase):
self.keyspace = 'large_data' self.keyspace = 'large_data'
def make_session_and_keyspace(self): def make_session_and_keyspace(self):
cluster = Cluster() cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect() session = cluster.connect()
session.default_timeout = 20.0 # increase the default timeout session.default_timeout = 20.0 # increase the default timeout
session.row_factory = dict_factory session.row_factory = dict_factory
@@ -41,9 +42,10 @@ class LargeDataTests(unittest.TestCase):
return session return session
def batch_futures(self, session, statement_generator): def batch_futures(self, session, statement_generator):
futures = Queue(maxsize=121) concurrency = 50
futures = Queue.Queue(maxsize=concurrency)
for i, statement in enumerate(statement_generator): for i, statement in enumerate(statement_generator):
if i > 0 and i % 120 == 0: if i > 0 and i % (concurrency - 1) == 0:
# clear the existing queue # clear the existing queue
while True: while True:
try: try:
@@ -70,7 +72,7 @@ class LargeDataTests(unittest.TestCase):
session, session,
(SimpleStatement('INSERT INTO %s (k, i) VALUES (0, %s)' % (table, i), (SimpleStatement('INSERT INTO %s (k, i) VALUES (0, %s)' % (table, i),
consistency_level=ConsistencyLevel.QUORUM) consistency_level=ConsistencyLevel.QUORUM)
for i in range(1000000))) for i in range(100000)))
# Read # Read
results = session.execute('SELECT i FROM %s WHERE k=%s' % (table, 0)) results = session.execute('SELECT i FROM %s WHERE k=%s' % (table, 0))
@@ -112,7 +114,7 @@ class LargeDataTests(unittest.TestCase):
session, session,
(SimpleStatement('INSERT INTO %s (k, i, v) VALUES (0, %s, %s)' % (table, i, str(bb)), (SimpleStatement('INSERT INTO %s (k, i, v) VALUES (0, %s, %s)' % (table, i, str(bb)),
consistency_level=ConsistencyLevel.QUORUM) consistency_level=ConsistencyLevel.QUORUM)
for i in range(1000000))) for i in range(100000)))
# Read # Read
results = session.execute('SELECT i, v FROM %s WHERE k=%s' % (table, 0)) results = session.execute('SELECT i, v FROM %s WHERE k=%s' % (table, 0))

View File

@@ -3,6 +3,7 @@ import logging
from cassandra import ConsistencyLevel from cassandra import ConsistencyLevel
from cassandra.cluster import Cluster from cassandra.cluster import Cluster
from cassandra.query import SimpleStatement from cassandra.query import SimpleStatement
from tests.integration import PROTOCOL_VERSION
try: try:
import unittest2 as unittest import unittest2 as unittest
@@ -15,7 +16,7 @@ log = logging.getLogger(__name__)
class SchemaTests(unittest.TestCase): class SchemaTests(unittest.TestCase):
def test_recreates(self): def test_recreates(self):
cluster = Cluster() cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect() session = cluster.connect()
replication_factor = 3 replication_factor = 3

View File

@@ -4,7 +4,7 @@ import time
from collections import defaultdict from collections import defaultdict
from cassandra.decoder import named_tuple_factory from cassandra.query import named_tuple_factory
from tests.integration import get_node from tests.integration import get_node
@@ -74,7 +74,9 @@ def stop(node):
def force_stop(node): def force_stop(node):
log.debug("Forcing stop of node %s", node)
get_node(node).stop(wait=False, gently=False) get_node(node).stop(wait=False, gently=False)
log.debug("Node %s was stopped", node)
def ring(node): def ring(node):
@@ -85,6 +87,7 @@ def ring(node):
def wait_for_up(cluster, node, wait=True): def wait_for_up(cluster, node, wait=True):
while True: while True:
host = cluster.metadata.get_host('127.0.0.%s' % node) host = cluster.metadata.get_host('127.0.0.%s' % node)
time.sleep(0.1)
if host and host.is_up: if host and host.is_up:
# BUG: shouldn't have to, but we do # BUG: shouldn't have to, but we do
if wait: if wait:
@@ -93,10 +96,14 @@ def wait_for_up(cluster, node, wait=True):
def wait_for_down(cluster, node, wait=True): def wait_for_down(cluster, node, wait=True):
log.debug("Waiting for node %s to be down", node)
while True: while True:
host = cluster.metadata.get_host('127.0.0.%s' % node) host = cluster.metadata.get_host('127.0.0.%s' % node)
time.sleep(0.1)
if not host or not host.is_up: if not host or not host.is_up:
# BUG: shouldn't have to, but we do # BUG: shouldn't have to, but we do
if wait: if wait:
log.debug("Sleeping 5s until host is up")
time.sleep(5) time.sleep(5)
log.debug("Done waiting for node %s to be down", node)
return return

View File

@@ -1,3 +1,5 @@
from tests.integration import PROTOCOL_VERSION
try: try:
import unittest2 as unittest import unittest2 as unittest
except ImportError: except ImportError:
@@ -5,7 +7,6 @@ except ImportError:
import cassandra import cassandra
from cassandra.query import SimpleStatement, TraceUnavailable from cassandra.query import SimpleStatement, TraceUnavailable
from cassandra.io.asyncorereactor import AsyncoreConnection
from cassandra.policies import RoundRobinPolicy, ExponentialReconnectionPolicy, RetryPolicy, SimpleConvictionPolicy, HostDistance from cassandra.policies import RoundRobinPolicy, ExponentialReconnectionPolicy, RetryPolicy, SimpleConvictionPolicy, HostDistance
from cassandra.cluster import Cluster, NoHostAvailable from cassandra.cluster import Cluster, NoHostAvailable
@@ -18,7 +19,7 @@ class ClusterTests(unittest.TestCase):
Test basic connection and usage Test basic connection and usage
""" """
cluster = Cluster() cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect() session = cluster.connect()
result = session.execute( result = session.execute(
""" """
@@ -54,7 +55,7 @@ class ClusterTests(unittest.TestCase):
Ensure clusters that connect on a keyspace, do Ensure clusters that connect on a keyspace, do
""" """
cluster = Cluster() cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect() session = cluster.connect()
result = session.execute( result = session.execute(
""" """
@@ -71,7 +72,7 @@ class ClusterTests(unittest.TestCase):
self.assertEqual(result, result2) self.assertEqual(result, result2)
def test_set_keyspace_twice(self): def test_set_keyspace_twice(self):
cluster = Cluster() cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect() session = cluster.connect()
session.execute("USE system") session.execute("USE system")
session.execute("USE system") session.execute("USE system")
@@ -86,7 +87,7 @@ class ClusterTests(unittest.TestCase):
reconnection_policy=ExponentialReconnectionPolicy(1.0, 600.0), reconnection_policy=ExponentialReconnectionPolicy(1.0, 600.0),
default_retry_policy=RetryPolicy(), default_retry_policy=RetryPolicy(),
conviction_policy_factory=SimpleConvictionPolicy, conviction_policy_factory=SimpleConvictionPolicy,
connection_class=AsyncoreConnection protocol_version=PROTOCOL_VERSION
) )
def test_double_shutdown(self): def test_double_shutdown(self):
@@ -94,7 +95,7 @@ class ClusterTests(unittest.TestCase):
Ensure that a cluster can be shutdown twice, without error Ensure that a cluster can be shutdown twice, without error
""" """
cluster = Cluster() cluster = Cluster(protocol_version=PROTOCOL_VERSION)
cluster.shutdown() cluster.shutdown()
try: try:
@@ -108,7 +109,7 @@ class ClusterTests(unittest.TestCase):
Ensure you cannot connect to a cluster that's been shutdown Ensure you cannot connect to a cluster that's been shutdown
""" """
cluster = Cluster() cluster = Cluster(protocol_version=PROTOCOL_VERSION)
cluster.shutdown() cluster.shutdown()
self.assertRaises(Exception, cluster.connect) self.assertRaises(Exception, cluster.connect)
@@ -132,7 +133,8 @@ class ClusterTests(unittest.TestCase):
when a cluster cannot connect to given hosts when a cluster cannot connect to given hosts
""" """
cluster = Cluster(['127.1.2.9', '127.1.2.10']) cluster = Cluster(['127.1.2.9', '127.1.2.10'],
protocol_version=PROTOCOL_VERSION)
self.assertRaises(NoHostAvailable, cluster.connect) self.assertRaises(NoHostAvailable, cluster.connect)
def test_cluster_settings(self): def test_cluster_settings(self):
@@ -140,7 +142,7 @@ class ClusterTests(unittest.TestCase):
Test connection setting getters and setters Test connection setting getters and setters
""" """
cluster = Cluster() cluster = Cluster(protocol_version=PROTOCOL_VERSION)
min_requests_per_connection = cluster.get_min_requests_per_connection(HostDistance.LOCAL) min_requests_per_connection = cluster.get_min_requests_per_connection(HostDistance.LOCAL)
self.assertEqual(cassandra.cluster.DEFAULT_MIN_REQUESTS, min_requests_per_connection) self.assertEqual(cassandra.cluster.DEFAULT_MIN_REQUESTS, min_requests_per_connection)
@@ -167,11 +169,11 @@ class ClusterTests(unittest.TestCase):
Ensure new new schema is refreshed after submit_schema_refresh() Ensure new new schema is refreshed after submit_schema_refresh()
""" """
cluster = Cluster() cluster = Cluster(protocol_version=PROTOCOL_VERSION)
cluster.connect() cluster.connect()
self.assertNotIn("newkeyspace", cluster.metadata.keyspaces) self.assertNotIn("newkeyspace", cluster.metadata.keyspaces)
other_cluster = Cluster() other_cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = other_cluster.connect() session = other_cluster.connect()
session.execute( session.execute(
""" """
@@ -189,15 +191,22 @@ class ClusterTests(unittest.TestCase):
Ensure trace can be requested for async and non-async queries Ensure trace can be requested for async and non-async queries
""" """
cluster = Cluster() cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect() session = cluster.connect()
self.assertRaises(TypeError, session.execute, "SELECT * FROM system.local", trace=True) self.assertRaises(TypeError, session.execute, "SELECT * FROM system.local", trace=True)
def check_trace(trace):
self.assertIsNot(None, trace.request_type)
self.assertIsNot(None, trace.duration)
self.assertIsNot(None, trace.started_at)
self.assertIsNot(None, trace.coordinator)
self.assertIsNot(None, trace.events)
query = "SELECT * FROM system.local" query = "SELECT * FROM system.local"
statement = SimpleStatement(query) statement = SimpleStatement(query)
session.execute(statement, trace=True) session.execute(statement, trace=True)
self.assertEqual(query, statement.trace.parameters['query']) check_trace(statement.trace)
query = "SELECT * FROM system.local" query = "SELECT * FROM system.local"
statement = SimpleStatement(query) statement = SimpleStatement(query)
@@ -207,15 +216,20 @@ class ClusterTests(unittest.TestCase):
statement2 = SimpleStatement(query) statement2 = SimpleStatement(query)
future = session.execute_async(statement2, trace=True) future = session.execute_async(statement2, trace=True)
future.result() future.result()
self.assertEqual(query, future.get_query_trace().parameters['query']) check_trace(future.get_query_trace())
statement2 = SimpleStatement(query) statement2 = SimpleStatement(query)
future = session.execute_async(statement2) future = session.execute_async(statement2)
future.result() future.result()
self.assertEqual(None, future.get_query_trace()) self.assertEqual(None, future.get_query_trace())
prepared = session.prepare("SELECT * FROM system.local")
future = session.execute_async(prepared, parameters=(), trace=True)
future.result()
check_trace(future.get_query_trace())
def test_trace_timeout(self): def test_trace_timeout(self):
cluster = Cluster() cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect() session = cluster.connect()
query = "SELECT * FROM system.local" query = "SELECT * FROM system.local"
@@ -229,7 +243,7 @@ class ClusterTests(unittest.TestCase):
Ensure str(future) returns without error Ensure str(future) returns without error
""" """
cluster = Cluster() cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect() session = cluster.connect()
query = "SELECT * FROM system.local" query = "SELECT * FROM system.local"

View File

@@ -0,0 +1,113 @@
from tests.integration import PROTOCOL_VERSION
try:
import unittest2 as unittest
except ImportError:
import unittest # noqa
from itertools import cycle
from cassandra import InvalidRequest
from cassandra.cluster import Cluster
from cassandra.concurrent import (execute_concurrent,
execute_concurrent_with_args)
from cassandra.policies import HostDistance
from cassandra.query import tuple_factory
class ClusterTests(unittest.TestCase):
def setUp(self):
self.cluster = Cluster(protocol_version=PROTOCOL_VERSION)
self.cluster.set_core_connections_per_host(HostDistance.LOCAL, 1)
self.session = self.cluster.connect()
self.session.row_factory = tuple_factory
def test_execute_concurrent(self):
for num_statements in (0, 1, 2, 7, 10, 99, 100, 101, 199, 200, 201):
# write
statements = cycle(("INSERT INTO test3rf.test (k, v) VALUES (%s, %s)", ))
parameters = [(i, i) for i in range(num_statements)]
results = execute_concurrent(self.session, zip(statements, parameters))
self.assertEqual(num_statements, len(results))
self.assertEqual([(True, None)] * num_statements, results)
# read
statements = cycle(("SELECT v FROM test3rf.test WHERE k=%s", ))
parameters = [(i, ) for i in range(num_statements)]
results = execute_concurrent(self.session, zip(statements, parameters))
self.assertEqual(num_statements, len(results))
self.assertEqual([(True, [(i,)]) for i in range(num_statements)], results)
def test_execute_concurrent_with_args(self):
for num_statements in (0, 1, 2, 7, 10, 99, 100, 101, 199, 200, 201):
statement = "INSERT INTO test3rf.test (k, v) VALUES (%s, %s)"
parameters = [(i, i) for i in range(num_statements)]
results = execute_concurrent_with_args(self.session, statement, parameters)
self.assertEqual(num_statements, len(results))
self.assertEqual([(True, None)] * num_statements, results)
# read
statement = "SELECT v FROM test3rf.test WHERE k=%s"
parameters = [(i, ) for i in range(num_statements)]
results = execute_concurrent_with_args(self.session, statement, parameters)
self.assertEqual(num_statements, len(results))
self.assertEqual([(True, [(i,)]) for i in range(num_statements)], results)
def test_first_failure(self):
statements = cycle(("INSERT INTO test3rf.test (k, v) VALUES (%s, %s)", ))
parameters = [(i, i) for i in range(100)]
# we'll get an error back from the server
parameters[57] = ('efefef', 'awefawefawef')
self.assertRaises(
InvalidRequest,
execute_concurrent, self.session, zip(statements, parameters), raise_on_first_error=True)
def test_first_failure_client_side(self):
statements = cycle(("INSERT INTO test3rf.test (k, v) VALUES (%s, %s)", ))
parameters = [(i, i) for i in range(100)]
# the driver will raise an error when binding the params
parameters[57] = 1
self.assertRaises(
TypeError,
execute_concurrent, self.session, zip(statements, parameters), raise_on_first_error=True)
def test_no_raise_on_first_failure(self):
statements = cycle(("INSERT INTO test3rf.test (k, v) VALUES (%s, %s)", ))
parameters = [(i, i) for i in range(100)]
# we'll get an error back from the server
parameters[57] = ('efefef', 'awefawefawef')
results = execute_concurrent(self.session, zip(statements, parameters), raise_on_first_error=False)
for i, (success, result) in enumerate(results):
if i == 57:
self.assertFalse(success)
self.assertIsInstance(result, InvalidRequest)
else:
self.assertTrue(success)
self.assertEqual(None, result)
def test_no_raise_on_first_failure_client_side(self):
statements = cycle(("INSERT INTO test3rf.test (k, v) VALUES (%s, %s)", ))
parameters = [(i, i) for i in range(100)]
# the driver will raise an error when binding the params
parameters[57] = i
results = execute_concurrent(self.session, zip(statements, parameters), raise_on_first_error=False)
for i, (success, result) in enumerate(results):
if i == 57:
self.assertFalse(success)
self.assertIsInstance(result, TypeError)
else:
self.assertTrue(success)
self.assertEqual(None, result)

View File

@@ -1,9 +1,12 @@
from tests.integration import PROTOCOL_VERSION
try: try:
import unittest2 as unittest import unittest2 as unittest
except ImportError: except ImportError:
import unittest # noqa import unittest # noqa
from functools import partial from functools import partial
import sys
from threading import Thread, Event from threading import Thread, Event
from cassandra import ConsistencyLevel from cassandra import ConsistencyLevel
@@ -24,7 +27,7 @@ class ConnectionTest(object):
""" """
Test a single connection with sequential requests. Test a single connection with sequential requests.
""" """
conn = self.klass.factory() conn = self.klass.factory(protocol_version=PROTOCOL_VERSION)
query = "SELECT keyspace_name FROM system.schema_keyspaces LIMIT 1" query = "SELECT keyspace_name FROM system.schema_keyspaces LIMIT 1"
event = Event() event = Event()
@@ -47,7 +50,7 @@ class ConnectionTest(object):
""" """
Test a single connection with pipelined requests. Test a single connection with pipelined requests.
""" """
conn = self.klass.factory() conn = self.klass.factory(protocol_version=PROTOCOL_VERSION)
query = "SELECT keyspace_name FROM system.schema_keyspaces LIMIT 1" query = "SELECT keyspace_name FROM system.schema_keyspaces LIMIT 1"
responses = [False] * 100 responses = [False] * 100
event = Event() event = Event()
@@ -69,7 +72,7 @@ class ConnectionTest(object):
""" """
Test multiple connections with pipelined requests. Test multiple connections with pipelined requests.
""" """
conns = [self.klass.factory() for i in range(5)] conns = [self.klass.factory(protocol_version=PROTOCOL_VERSION) for i in range(5)]
events = [Event() for i in range(5)] events = [Event() for i in range(5)]
query = "SELECT keyspace_name FROM system.schema_keyspaces LIMIT 1" query = "SELECT keyspace_name FROM system.schema_keyspaces LIMIT 1"
@@ -100,7 +103,7 @@ class ConnectionTest(object):
num_threads = 5 num_threads = 5
event = Event() event = Event()
conn = self.klass.factory() conn = self.klass.factory(protocol_version=PROTOCOL_VERSION)
query = "SELECT keyspace_name FROM system.schema_keyspaces LIMIT 1" query = "SELECT keyspace_name FROM system.schema_keyspaces LIMIT 1"
def cb(all_responses, thread_responses, request_num, *args, **kwargs): def cb(all_responses, thread_responses, request_num, *args, **kwargs):
@@ -157,7 +160,7 @@ class ConnectionTest(object):
threads = [] threads = []
for i in range(num_conns): for i in range(num_conns):
conn = self.klass.factory() conn = self.klass.factory(protocol_version=PROTOCOL_VERSION)
t = Thread(target=send_msgs, args=(conn, events[i])) t = Thread(target=send_msgs, args=(conn, events[i]))
threads.append(t) threads.append(t)
@@ -172,12 +175,18 @@ class AsyncoreConnectionTest(ConnectionTest, unittest.TestCase):
klass = AsyncoreConnection klass = AsyncoreConnection
def setUp(self):
if 'gevent.monkey' in sys.modules:
raise unittest.SkipTest("Can't test libev with gevent monkey patching")
class LibevConnectionTest(ConnectionTest, unittest.TestCase): class LibevConnectionTest(ConnectionTest, unittest.TestCase):
klass = LibevConnection klass = LibevConnection
def setUp(self): def setUp(self):
if 'gevent.monkey' in sys.modules:
raise unittest.SkipTest("Can't test libev with gevent monkey patching")
if LibevConnection is None: if LibevConnection is None:
raise unittest.SkipTest( raise unittest.SkipTest(
'libev does not appear to be installed properly') 'libev does not appear to be installed properly')

View File

@@ -1,3 +1,5 @@
from tests.integration import PROTOCOL_VERSION
try: try:
import unittest2 as unittest import unittest2 as unittest
except ImportError: except ImportError:
@@ -36,7 +38,7 @@ class TestFactories(unittest.TestCase):
''' '''
def test_tuple_factory(self): def test_tuple_factory(self):
cluster = Cluster() cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect() session = cluster.connect()
session.row_factory = tuple_factory session.row_factory = tuple_factory
@@ -58,7 +60,7 @@ class TestFactories(unittest.TestCase):
self.assertEqual(result[1][0], 2) self.assertEqual(result[1][0], 2)
def test_named_tuple_factoryy(self): def test_named_tuple_factoryy(self):
cluster = Cluster() cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect() session = cluster.connect()
session.row_factory = named_tuple_factory session.row_factory = named_tuple_factory
@@ -79,7 +81,7 @@ class TestFactories(unittest.TestCase):
self.assertEqual(result[1].k, 2) self.assertEqual(result[1].k, 2)
def test_dict_factory(self): def test_dict_factory(self):
cluster = Cluster() cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect() session = cluster.connect()
session.row_factory = dict_factory session.row_factory = dict_factory
@@ -101,7 +103,7 @@ class TestFactories(unittest.TestCase):
self.assertEqual(result[1]['k'], 2) self.assertEqual(result[1]['k'], 2)
def test_ordered_dict_factory(self): def test_ordered_dict_factory(self):
cluster = Cluster() cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect() session = cluster.connect()
session.row_factory = ordered_dict_factory session.row_factory = ordered_dict_factory

View File

@@ -15,7 +15,7 @@ from cassandra.metadata import (Metadata, KeyspaceMetadata, TableMetadata,
from cassandra.policies import SimpleConvictionPolicy from cassandra.policies import SimpleConvictionPolicy
from cassandra.pool import Host from cassandra.pool import Host
from tests.integration import get_cluster from tests.integration import get_cluster, PROTOCOL_VERSION
class SchemaMetadataTest(unittest.TestCase): class SchemaMetadataTest(unittest.TestCase):
@@ -28,7 +28,7 @@ class SchemaMetadataTest(unittest.TestCase):
@classmethod @classmethod
def setup_class(cls): def setup_class(cls):
cluster = Cluster() cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect() session = cluster.connect()
try: try:
results = session.execute("SELECT keyspace_name FROM system.schema_keyspaces") results = session.execute("SELECT keyspace_name FROM system.schema_keyspaces")
@@ -46,7 +46,8 @@ class SchemaMetadataTest(unittest.TestCase):
@classmethod @classmethod
def teardown_class(cls): def teardown_class(cls):
cluster = Cluster(['127.0.0.1']) cluster = Cluster(['127.0.0.1'],
protocol_version=PROTOCOL_VERSION)
session = cluster.connect() session = cluster.connect()
try: try:
session.execute("DROP KEYSPACE %s" % cls.ksname) session.execute("DROP KEYSPACE %s" % cls.ksname)
@@ -54,7 +55,8 @@ class SchemaMetadataTest(unittest.TestCase):
cluster.shutdown() cluster.shutdown()
def setUp(self): def setUp(self):
self.cluster = Cluster(['127.0.0.1']) self.cluster = Cluster(['127.0.0.1'],
protocol_version=PROTOCOL_VERSION)
self.session = self.cluster.connect() self.session = self.cluster.connect()
def tearDown(self): def tearDown(self):
@@ -294,7 +296,7 @@ class TestCodeCoverage(unittest.TestCase):
Test export schema functionality Test export schema functionality
""" """
cluster = Cluster() cluster = Cluster(protocol_version=PROTOCOL_VERSION)
cluster.connect() cluster.connect()
self.assertIsInstance(cluster.metadata.export_schema_as_string(), six.string_types) self.assertIsInstance(cluster.metadata.export_schema_as_string(), six.string_types)
@@ -304,7 +306,7 @@ class TestCodeCoverage(unittest.TestCase):
Test export keyspace schema functionality Test export keyspace schema functionality
""" """
cluster = Cluster() cluster = Cluster(protocol_version=PROTOCOL_VERSION)
cluster.connect() cluster.connect()
for keyspace in cluster.metadata.keyspaces: for keyspace in cluster.metadata.keyspaces:
@@ -317,7 +319,7 @@ class TestCodeCoverage(unittest.TestCase):
Test that names that need to be escaped in CREATE statements are Test that names that need to be escaped in CREATE statements are
""" """
cluster = Cluster() cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect() session = cluster.connect()
ksname = 'AnInterestingKeyspace' ksname = 'AnInterestingKeyspace'
@@ -356,7 +358,7 @@ class TestCodeCoverage(unittest.TestCase):
Ensure AlreadyExists exception is thrown when hit Ensure AlreadyExists exception is thrown when hit
""" """
cluster = Cluster() cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect() session = cluster.connect()
ksname = 'test3rf' ksname = 'test3rf'
@@ -380,7 +382,7 @@ class TestCodeCoverage(unittest.TestCase):
if murmur3 is None: if murmur3 is None:
raise unittest.SkipTest('the murmur3 extension is not available') raise unittest.SkipTest('the murmur3 extension is not available')
cluster = Cluster() cluster = Cluster(protocol_version=PROTOCOL_VERSION)
self.assertEqual(cluster.metadata.get_replicas('test3rf', 'key'), []) self.assertEqual(cluster.metadata.get_replicas('test3rf', 'key'), [])
cluster.connect('test3rf') cluster.connect('test3rf')
@@ -395,7 +397,7 @@ class TestCodeCoverage(unittest.TestCase):
Test token mappings Test token mappings
""" """
cluster = Cluster() cluster = Cluster(protocol_version=PROTOCOL_VERSION)
cluster.connect('test3rf') cluster.connect('test3rf')
ring = cluster.metadata.token_map.ring ring = cluster.metadata.token_map.ring
owners = list(cluster.metadata.token_map.token_to_host_owner[token] for token in ring) owners = list(cluster.metadata.token_map.token_to_host_owner[token] for token in ring)
@@ -418,7 +420,7 @@ class TokenMetadataTest(unittest.TestCase):
def test_token(self): def test_token(self):
expected_node_count = len(get_cluster().nodes) expected_node_count = len(get_cluster().nodes)
cluster = Cluster() cluster = Cluster(protocol_version=PROTOCOL_VERSION)
cluster.connect() cluster.connect()
tmap = cluster.metadata.token_map tmap = cluster.metadata.token_map
self.assertTrue(issubclass(tmap.token_class, Token)) self.assertTrue(issubclass(tmap.token_class, Token))

View File

@@ -7,7 +7,7 @@ from cassandra.query import SimpleStatement
from cassandra import ConsistencyLevel, WriteTimeout, Unavailable, ReadTimeout from cassandra import ConsistencyLevel, WriteTimeout, Unavailable, ReadTimeout
from cassandra.cluster import Cluster, NoHostAvailable from cassandra.cluster import Cluster, NoHostAvailable
from tests.integration import get_node, get_cluster from tests.integration import get_node, get_cluster, PROTOCOL_VERSION
class MetricsTests(unittest.TestCase): class MetricsTests(unittest.TestCase):
@@ -17,7 +17,8 @@ class MetricsTests(unittest.TestCase):
Trigger and ensure connection_errors are counted Trigger and ensure connection_errors are counted
""" """
cluster = Cluster(metrics_enabled=True) cluster = Cluster(metrics_enabled=True,
protocol_version=PROTOCOL_VERSION)
session = cluster.connect() session = cluster.connect()
session.execute("USE test3rf") session.execute("USE test3rf")
@@ -45,7 +46,8 @@ class MetricsTests(unittest.TestCase):
Attempt a write at cl.ALL and receive a WriteTimeout. Attempt a write at cl.ALL and receive a WriteTimeout.
""" """
cluster = Cluster(metrics_enabled=True) cluster = Cluster(metrics_enabled=True,
protocol_version=PROTOCOL_VERSION)
session = cluster.connect() session = cluster.connect()
# Test write # Test write
@@ -75,7 +77,8 @@ class MetricsTests(unittest.TestCase):
Attempt a read at cl.ALL and receive a ReadTimeout. Attempt a read at cl.ALL and receive a ReadTimeout.
""" """
cluster = Cluster(metrics_enabled=True) cluster = Cluster(metrics_enabled=True,
protocol_version=PROTOCOL_VERSION)
session = cluster.connect() session = cluster.connect()
# Test write # Test write
@@ -105,7 +108,8 @@ class MetricsTests(unittest.TestCase):
Attempt an insert/read at cl.ALL and receive a Unavailable Exception. Attempt an insert/read at cl.ALL and receive a Unavailable Exception.
""" """
cluster = Cluster(metrics_enabled=True) cluster = Cluster(metrics_enabled=True,
protocol_version=PROTOCOL_VERSION)
session = cluster.connect() session = cluster.connect()
# Test write # Test write

View File

@@ -1,3 +1,5 @@
from tests.integration import PROTOCOL_VERSION
try: try:
import unittest2 as unittest import unittest2 as unittest
except ImportError: except ImportError:
@@ -15,7 +17,7 @@ class PreparedStatementTests(unittest.TestCase):
Test basic PreparedStatement usage Test basic PreparedStatement usage
""" """
cluster = Cluster() cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect() session = cluster.connect()
session.execute( session.execute(
""" """
@@ -60,7 +62,7 @@ class PreparedStatementTests(unittest.TestCase):
when prepared statements are missing the primary key when prepared statements are missing the primary key
""" """
cluster = Cluster() cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect() session = cluster.connect()
prepared = session.prepare( prepared = session.prepare(
@@ -77,7 +79,7 @@ class PreparedStatementTests(unittest.TestCase):
Ensure a ValueError is thrown when attempting to bind too many variables Ensure a ValueError is thrown when attempting to bind too many variables
""" """
cluster = Cluster() cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect() session = cluster.connect()
prepared = session.prepare( prepared = session.prepare(
@@ -93,7 +95,7 @@ class PreparedStatementTests(unittest.TestCase):
Ensure binding None is handled correctly Ensure binding None is handled correctly
""" """
cluster = Cluster() cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect() session = cluster.connect()
prepared = session.prepare( prepared = session.prepare(
@@ -120,7 +122,7 @@ class PreparedStatementTests(unittest.TestCase):
Ensure None binding over async queries Ensure None binding over async queries
""" """
cluster = Cluster() cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect() session = cluster.connect()
prepared = session.prepare( prepared = session.prepare(

View File

@@ -10,13 +10,13 @@ from cassandra.query import (PreparedStatement, BoundStatement, ValueSequence,
from cassandra.cluster import Cluster from cassandra.cluster import Cluster
from cassandra.policies import HostDistance from cassandra.policies import HostDistance
from tests.integration import get_server_versions from tests.integration import get_server_versions, PROTOCOL_VERSION
class QueryTest(unittest.TestCase): class QueryTest(unittest.TestCase):
def test_query(self): def test_query(self):
cluster = Cluster() cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect() session = cluster.connect()
prepared = session.prepare( prepared = session.prepare(
@@ -44,7 +44,7 @@ class QueryTest(unittest.TestCase):
Code coverage to ensure trace prints to string without error Code coverage to ensure trace prints to string without error
""" """
cluster = Cluster() cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect() session = cluster.connect()
query = "SELECT * FROM system.local" query = "SELECT * FROM system.local"
@@ -57,7 +57,7 @@ class QueryTest(unittest.TestCase):
str(event) str(event)
def test_trace_ignores_row_factory(self): def test_trace_ignores_row_factory(self):
cluster = Cluster() cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect() session = cluster.connect()
session.row_factory = dict_factory session.row_factory = dict_factory
@@ -78,7 +78,7 @@ class PreparedStatementTests(unittest.TestCase):
Simple code coverage to ensure routing_keys can be accessed Simple code coverage to ensure routing_keys can be accessed
""" """
cluster = Cluster() cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect() session = cluster.connect()
prepared = session.prepare( prepared = session.prepare(
@@ -96,7 +96,7 @@ class PreparedStatementTests(unittest.TestCase):
the routing key should be None the routing key should be None
""" """
cluster = Cluster() cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect() session = cluster.connect()
prepared = session.prepare( prepared = session.prepare(
@@ -115,7 +115,7 @@ class PreparedStatementTests(unittest.TestCase):
overrides the current routing key overrides the current routing key
""" """
cluster = Cluster() cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect() session = cluster.connect()
prepared = session.prepare( prepared = session.prepare(
@@ -133,7 +133,7 @@ class PreparedStatementTests(unittest.TestCase):
Basic test that uses a fake routing_key_index Basic test that uses a fake routing_key_index
""" """
cluster = Cluster() cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect() session = cluster.connect()
prepared = session.prepare( prepared = session.prepare(
@@ -151,7 +151,7 @@ class PreparedStatementTests(unittest.TestCase):
Ensure that bound.keyspace works as expected Ensure that bound.keyspace works as expected
""" """
cluster = Cluster() cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect() session = cluster.connect()
prepared = session.prepare( prepared = session.prepare(
@@ -186,7 +186,7 @@ class PrintStatementTests(unittest.TestCase):
Highlight the difference between Prepared and Bound statements Highlight the difference between Prepared and Bound statements
""" """
cluster = Cluster() cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect() session = cluster.connect()
prepared = session.prepare('INSERT INTO test3rf.test (k, v) VALUES (?, ?)') prepared = session.prepare('INSERT INTO test3rf.test (k, v) VALUES (?, ?)')
@@ -202,13 +202,12 @@ class PrintStatementTests(unittest.TestCase):
class BatchStatementTests(unittest.TestCase): class BatchStatementTests(unittest.TestCase):
def setUp(self): def setUp(self):
cass_version, _ = get_server_versions() if PROTOCOL_VERSION < 2:
if cass_version < (2, 0):
raise unittest.SkipTest( raise unittest.SkipTest(
"Cassandra 2.0+ is required for BATCH operations, currently testing against %r" "Protocol 2.0+ is required for BATCH operations, currently testing against %r"
% (cass_version,)) % (PROTOCOL_VERSION,))
self.cluster = Cluster(protocol_version=2) self.cluster = Cluster(protocol_version=PROTOCOL_VERSION)
self.cluster.set_core_connections_per_host(HostDistance.LOCAL, 1) self.cluster.set_core_connections_per_host(HostDistance.LOCAL, 1)
self.session = self.cluster.connect() self.session = self.cluster.connect()
@@ -272,13 +271,12 @@ class BatchStatementTests(unittest.TestCase):
class SerialConsistencyTests(unittest.TestCase): class SerialConsistencyTests(unittest.TestCase):
def setUp(self): def setUp(self):
cass_version, _ = get_server_versions() if PROTOCOL_VERSION < 2:
if cass_version < (2, 0):
raise unittest.SkipTest( raise unittest.SkipTest(
"Cassandra 2.0+ is required for BATCH operations, currently testing against %r" "Protocol 2.0+ is required for BATCH operations, currently testing against %r"
% (cass_version,)) % (PROTOCOL_VERSION,))
self.cluster = Cluster(protocol_version=2) self.cluster = Cluster(protocol_version=PROTOCOL_VERSION)
self.cluster.set_core_connections_per_host(HostDistance.LOCAL, 1) self.cluster.set_core_connections_per_host(HostDistance.LOCAL, 1)
self.session = self.cluster.connect() self.session = self.cluster.connect()

View File

@@ -0,0 +1,112 @@
from tests.integration import PROTOCOL_VERSION
import logging
log = logging.getLogger(__name__)
try:
import unittest2 as unittest
except ImportError:
import unittest # noqa
from itertools import cycle, count
from threading import Event
from cassandra.cluster import Cluster
from cassandra.concurrent import execute_concurrent
from cassandra.policies import HostDistance
from cassandra.query import SimpleStatement
class QueryPagingTests(unittest.TestCase):
def setUp(self):
if PROTOCOL_VERSION < 2:
raise unittest.SkipTest(
"Protocol 2.0+ is required for BATCH operations, currently testing against %r"
% (PROTOCOL_VERSION,))
self.cluster = Cluster(protocol_version=PROTOCOL_VERSION)
self.cluster.set_core_connections_per_host(HostDistance.LOCAL, 1)
self.session = self.cluster.connect()
self.session.execute("TRUNCATE test3rf.test")
def test_paging(self):
statements_and_params = zip(cycle(["INSERT INTO test3rf.test (k, v) VALUES (%s, 0)"]),
[(i, ) for i in range(100)])
execute_concurrent(self.session, statements_and_params)
prepared = self.session.prepare("SELECT * FROM test3rf.test")
for fetch_size in (2, 3, 7, 10, 99, 100, 101, 10000):
self.session.default_fetch_size = fetch_size
self.assertEqual(100, len(list(self.session.execute("SELECT * FROM test3rf.test"))))
statement = SimpleStatement("SELECT * FROM test3rf.test")
self.assertEqual(100, len(list(self.session.execute(statement))))
self.assertEqual(100, len(list(self.session.execute(prepared))))
def test_async_paging(self):
statements_and_params = zip(cycle(["INSERT INTO test3rf.test (k, v) VALUES (%s, 0)"]),
[(i, ) for i in range(100)])
execute_concurrent(self.session, statements_and_params)
prepared = self.session.prepare("SELECT * FROM test3rf.test")
for fetch_size in (2, 3, 7, 10, 99, 100, 101, 10000):
self.session.default_fetch_size = fetch_size
self.assertEqual(100, len(list(self.session.execute_async("SELECT * FROM test3rf.test").result())))
statement = SimpleStatement("SELECT * FROM test3rf.test")
self.assertEqual(100, len(list(self.session.execute_async(statement).result())))
self.assertEqual(100, len(list(self.session.execute_async(prepared).result())))
def test_paging_callbacks(self):
statements_and_params = zip(cycle(["INSERT INTO test3rf.test (k, v) VALUES (%s, 0)"]),
[(i, ) for i in range(100)])
execute_concurrent(self.session, statements_and_params)
prepared = self.session.prepare("SELECT * FROM test3rf.test")
for fetch_size in (2, 3, 7, 10, 99, 100, 101, 10000):
self.session.default_fetch_size = fetch_size
future = self.session.execute_async("SELECT * FROM test3rf.test")
event = Event()
counter = count()
def handle_page(rows, future, counter):
for row in rows:
counter.next()
if future.has_more_pages:
future.start_fetching_next_page()
else:
event.set()
def handle_error(err):
event.set()
self.fail(err)
future.add_callbacks(callback=handle_page, callback_args=(future, counter), errback=handle_error)
event.wait()
self.assertEquals(counter.next(), 100)
# simple statement
future = self.session.execute_async(SimpleStatement("SELECT * FROM test3rf.test"))
event.clear()
counter = count()
future.add_callbacks(callback=handle_page, callback_args=(future, counter), errback=handle_error)
event.wait()
self.assertEquals(counter.next(), 100)
# prepared statement
future = self.session.execute_async(prepared)
event.clear()
counter = count()
future.add_callbacks(callback=handle_page, callback_args=(future, counter), errback=handle_error)
event.wait()
self.assertEquals(counter.next(), 100)

View File

@@ -18,7 +18,7 @@ from cassandra.cqltypes import Int32Type, EMPTY
from cassandra.query import dict_factory from cassandra.query import dict_factory
from cassandra.util import OrderedDict from cassandra.util import OrderedDict
from tests.integration import get_server_versions from tests.integration import get_server_versions, PROTOCOL_VERSION
class TypeTests(unittest.TestCase): class TypeTests(unittest.TestCase):
@@ -27,7 +27,7 @@ class TypeTests(unittest.TestCase):
self._cass_version, self._cql_version = get_server_versions() self._cass_version, self._cql_version = get_server_versions()
def test_blob_type_as_string(self): def test_blob_type_as_string(self):
c = Cluster() c = Cluster(protocol_version=PROTOCOL_VERSION)
s = c.connect() s = c.connect()
s.execute(""" s.execute("""
@@ -69,7 +69,7 @@ class TypeTests(unittest.TestCase):
self.assertEqual(expected, actual) self.assertEqual(expected, actual)
def test_blob_type_as_bytearray(self): def test_blob_type_as_bytearray(self):
c = Cluster() c = Cluster(protocol_version=PROTOCOL_VERSION)
s = c.connect() s = c.connect()
s.execute(""" s.execute("""
@@ -129,7 +129,7 @@ class TypeTests(unittest.TestCase):
""" """
def test_basic_types(self): def test_basic_types(self):
c = Cluster() c = Cluster(protocol_version=PROTOCOL_VERSION)
s = c.connect() s = c.connect()
s.execute(""" s.execute("""
CREATE KEYSPACE typetests CREATE KEYSPACE typetests
@@ -226,7 +226,7 @@ class TypeTests(unittest.TestCase):
self.assertEqual(expected, actual) self.assertEqual(expected, actual)
def test_empty_strings_and_nones(self): def test_empty_strings_and_nones(self):
c = Cluster() c = Cluster(protocol_version=PROTOCOL_VERSION)
s = c.connect() s = c.connect()
s.execute(""" s.execute("""
CREATE KEYSPACE test_empty_strings_and_nones CREATE KEYSPACE test_empty_strings_and_nones
@@ -329,7 +329,7 @@ class TypeTests(unittest.TestCase):
self.assertEqual([], [(name, val) for (name, val) in results[0].items() if val is not None]) self.assertEqual([], [(name, val) for (name, val) in results[0].items() if val is not None])
def test_empty_values(self): def test_empty_values(self):
c = Cluster() c = Cluster(protocol_version=PROTOCOL_VERSION)
s = c.connect() s = c.connect()
s.execute(""" s.execute("""
CREATE KEYSPACE test_empty_values CREATE KEYSPACE test_empty_values
@@ -356,7 +356,7 @@ class TypeTests(unittest.TestCase):
eastern_tz = pytz.timezone('US/Eastern') eastern_tz = pytz.timezone('US/Eastern')
eastern_tz.localize(dt) eastern_tz.localize(dt)
c = Cluster() c = Cluster(protocol_version=PROTOCOL_VERSION)
s = c.connect() s = c.connect()
s.execute("""CREATE KEYSPACE tz_aware_test s.execute("""CREATE KEYSPACE tz_aware_test

View File

@@ -136,13 +136,8 @@ class ConnectionTest(unittest.TestCase):
""" """
Ensure the following methods throw NIE's. If not, come back and test them. Ensure the following methods throw NIE's. If not, come back and test them.
""" """
c = self.make_connection() c = self.make_connection()
self.assertRaises(NotImplementedError, c.close) self.assertRaises(NotImplementedError, c.close)
self.assertRaises(NotImplementedError, c.defunct, None)
self.assertRaises(NotImplementedError, c.send_msg, None, None)
self.assertRaises(NotImplementedError, c.wait_for_response, None)
self.assertRaises(NotImplementedError, c.wait_for_responses)
self.assertRaises(NotImplementedError, c.register_watcher, None, None) self.assertRaises(NotImplementedError, c.register_watcher, None, None)
self.assertRaises(NotImplementedError, c.register_watchers, None) self.assertRaises(NotImplementedError, c.register_watchers, None)

View File

@@ -59,8 +59,8 @@ class MockCluster(object):
self.scheduler = Mock(spec=_Scheduler) self.scheduler = Mock(spec=_Scheduler)
self.executor = Mock(spec=ThreadPoolExecutor) self.executor = Mock(spec=ThreadPoolExecutor)
def add_host(self, address, signal=False): def add_host(self, address, datacenter, rack, signal=False):
host = Host(address, SimpleConvictionPolicy) host = Host(address, SimpleConvictionPolicy, datacenter, rack)
self.added_hosts.append(host) self.added_hosts.append(host)
return host return host
@@ -212,6 +212,7 @@ class ControlConnectionTest(unittest.TestCase):
self.connection.peer_results[1].append( self.connection.peer_results[1].append(
["192.168.1.3", "10.0.0.3", "a", "dc1", "rack1", ["3", "103", "203"]] ["192.168.1.3", "10.0.0.3", "a", "dc1", "rack1", ["3", "103", "203"]]
) )
self.cluster.scheduler.schedule = lambda delay, f, *args, **kwargs: f(*args, **kwargs)
self.control_connection.refresh_node_list_and_token_map() self.control_connection.refresh_node_list_and_token_map()
self.assertEqual(1, len(self.cluster.added_hosts)) self.assertEqual(1, len(self.cluster.added_hosts))
self.assertEqual(self.cluster.added_hosts[0].address, "192.168.1.3") self.assertEqual(self.cluster.added_hosts[0].address, "192.168.1.3")
@@ -250,7 +251,7 @@ class ControlConnectionTest(unittest.TestCase):
'address': ('1.2.3.4', 9000) 'address': ('1.2.3.4', 9000)
} }
self.control_connection._handle_topology_change(event) self.control_connection._handle_topology_change(event)
self.cluster.scheduler.schedule.assert_called_with(ANY, self.cluster.add_host, '1.2.3.4', signal=True) self.cluster.scheduler.schedule.assert_called_with(ANY, self.control_connection.refresh_node_list_and_token_map)
event = { event = {
'change_type': 'REMOVED_NODE', 'change_type': 'REMOVED_NODE',
@@ -272,7 +273,7 @@ class ControlConnectionTest(unittest.TestCase):
'address': ('1.2.3.4', 9000) 'address': ('1.2.3.4', 9000)
} }
self.control_connection._handle_status_change(event) self.control_connection._handle_status_change(event)
self.cluster.scheduler.schedule.assert_called_with(ANY, self.cluster.add_host, '1.2.3.4', signal=True) self.cluster.scheduler.schedule.assert_called_with(ANY, self.control_connection.refresh_node_list_and_token_map)
# do the same with a known Host # do the same with a known Host
event = { event = {

View File

@@ -35,6 +35,9 @@ class ResponseFutureTests(unittest.TestCase):
message = QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE) message = QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE)
return ResponseFuture(session, message, query) return ResponseFuture(session, message, query)
def make_mock_response(self, results):
return Mock(spec=ResultMessage, kind=RESULT_KIND_ROWS, results=results, paging_state=None)
def test_result_message(self): def test_result_message(self):
session = self.make_basic_session() session = self.make_basic_session()
session._load_balancer.make_query_plan.return_value = ['ip1', 'ip2'] session._load_balancer.make_query_plan.return_value = ['ip1', 'ip2']
@@ -49,9 +52,7 @@ class ResponseFutureTests(unittest.TestCase):
connection = pool.borrow_connection.return_value connection = pool.borrow_connection.return_value
connection.send_msg.assert_called_once_with(rf.message, cb=ANY) connection.send_msg.assert_called_once_with(rf.message, cb=ANY)
response = Mock(spec=ResultMessage, kind=RESULT_KIND_ROWS, results=[{'col': 'val'}]) rf._set_result(self.make_mock_response([{'col': 'val'}]))
rf._set_result(response)
result = rf.result() result = rf.result()
self.assertEqual(result, [{'col': 'val'}]) self.assertEqual(result, [{'col': 'val'}])
@@ -259,8 +260,7 @@ class ResponseFutureTests(unittest.TestCase):
rf = self.make_response_future(session) rf = self.make_response_future(session)
rf.send_request() rf.send_request()
response = Mock(spec=ResultMessage, kind=RESULT_KIND_ROWS, results=[{'col': 'val'}]) rf._set_result(self.make_mock_response([{'col': 'val'}]))
rf._set_result(response)
result = rf.result() result = rf.result()
self.assertEqual(result, [{'col': 'val'}]) self.assertEqual(result, [{'col': 'val'}])
@@ -280,8 +280,7 @@ class ResponseFutureTests(unittest.TestCase):
rf = self.make_response_future(session) rf = self.make_response_future(session)
rf.send_request() rf.send_request()
response = Mock(spec=ResultMessage, kind=RESULT_KIND_ROWS, results=[{'col': 'val'}]) rf._set_result(self.make_mock_response([{'col': 'val'}]))
rf._set_result(response)
self.assertEqual(rf.result(), [{'col': 'val'}]) self.assertEqual(rf.result(), [{'col': 'val'}])
# make sure the exception is recorded correctly # make sure the exception is recorded correctly
@@ -294,8 +293,7 @@ class ResponseFutureTests(unittest.TestCase):
rf.add_callback(self.assertEqual, [{'col': 'val'}]) rf.add_callback(self.assertEqual, [{'col': 'val'}])
response = Mock(spec=ResultMessage, kind=RESULT_KIND_ROWS, results=[{'col': 'val'}]) rf._set_result(self.make_mock_response([{'col': 'val'}]))
rf._set_result(response)
result = rf.result() result = rf.result()
self.assertEqual(result, [{'col': 'val'}]) self.assertEqual(result, [{'col': 'val'}])
@@ -349,8 +347,7 @@ class ResponseFutureTests(unittest.TestCase):
callback=self.assertEqual, callback_args=([{'col': 'val'}],), callback=self.assertEqual, callback_args=([{'col': 'val'}],),
errback=self.assertIsInstance, errback_args=(Exception,)) errback=self.assertIsInstance, errback_args=(Exception,))
response = Mock(spec=ResultMessage, kind=RESULT_KIND_ROWS, results=[{'col': 'val'}]) rf._set_result(self.make_mock_response([{'col': 'val'}]))
rf._set_result(response)
self.assertEqual(rf.result(), [{'col': 'val'}]) self.assertEqual(rf.result(), [{'col': 'val'}])
def test_prepared_query_not_found(self): def test_prepared_query_not_found(self):