Merge branch '2.0' into py3k
Conflicts:
	cassandra/cluster.py
	cassandra/connection.py
	cassandra/io/asyncorereactor.py
	cassandra/io/libevreactor.py
	setup.py
	tests/integration/long/test_large_data.py
CHANGELOG.rst:

@@ -1,9 +1,11 @@
1.0.3
1.1.0
=====
In Progress

Features
--------
* Gevent is now supported through monkey-patching the stdlib (PYTHON-7,
  github issue #46)
* Support static columns in schemas, which are available starting in
  Cassandra 2.1. (github issue #91)

@@ -15,12 +17,23 @@ Bug Fixes
* Ignore SSL_ERROR_WANT_READ and SSL_ERROR_WANT_WRITE socket errors. Previously,
  these resulted in the connection being defuncted, but they can safely be
  ignored by the driver.
* Don't reconnect the control connection every time Cluster.connect() is
  called
* Avoid race condition that could leave ResponseFuture callbacks uncalled
  if the callback was added outside of the event loop thread (github issue #95)
* Properly escape keyspace name in Session.set_keyspace(). Previously, the
  keyspace name was quoted, but any quotes in the string were not escaped.
* Avoid adding hosts to the load balancing policy before their datacenter
  and rack information has been set, if possible.
* Avoid KeyError when updating metadata after dropping a table (github issues
  #97, #98)

Other
-----
* Don't ignore column names when parsing typestrings. This is needed for
  user-defined type support. (github issue #90)
* Better error message when libevwrapper is not found
* Only try to import scales when metrics are enabled (github issue #92)

1.0.2
=====
cassandra/cluster.py:

@@ -2,6 +2,8 @@
This module houses the main classes you will interact with,
:class:`.Cluster` and :class:`.Session`.
"""
from __future__ import absolute_import

from concurrent.futures import ThreadPoolExecutor
import logging
import socket

@@ -37,8 +39,7 @@ from cassandra.decoder import (QueryMessage, ResultMessage,
                               BatchMessage, RESULT_KIND_PREPARED,
                               RESULT_KIND_SET_KEYSPACE, RESULT_KIND_ROWS,
                               RESULT_KIND_SCHEMA_CHANGE)
from cassandra.metadata import Metadata
# from cassandra.metrics import Metrics
from cassandra.metadata import Metadata, protect_name
from cassandra.policies import (RoundRobinPolicy, SimpleConvictionPolicy,
                                ExponentialReconnectionPolicy, HostDistance,
                                RetryPolicy)

@@ -48,11 +49,15 @@ from cassandra.query import (SimpleStatement, PreparedStatement, BoundStatement,
                             BatchStatement, bind_params, QueryTrace, Statement,
                             named_tuple_factory, dict_factory)

# libev is all around faster, so we want to try and default to using that when we can
try:
    from cassandra.io.libevreactor import LibevConnection as DefaultConnection
except ImportError:
    from cassandra.io.asyncorereactor import AsyncoreConnection as DefaultConnection  # NOQA
# default to gevent when we are monkey patched, otherwise if libev is available, use that as the
# default because it's faster than asyncore
if 'gevent.monkey' in sys.modules:
    from cassandra.io.geventreactor import GeventConnection as DefaultConnection
else:
    try:
        from cassandra.io.libevreactor import LibevConnection as DefaultConnection  # NOQA
    except ImportError:
        from cassandra.io.asyncorereactor import AsyncoreConnection as DefaultConnection  # NOQA

# Forces load of utf8 encoding module to avoid deadlock that occurs
# if code that is being imported tries to import the module in a separate
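A minimal sketch of pinning the reactor explicitly instead of relying on the import-time fallback above (assumes the ``Cluster.connection_class`` attribute the driver exposes for this; contact point is illustrative)::

    from cassandra.cluster import Cluster
    from cassandra.io.asyncorereactor import AsyncoreConnection

    cluster = Cluster(['127.0.0.1'])
    # Force the asyncore event loop even when libev or gevent is available.
    cluster.connection_class = AsyncoreConnection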
@@ -147,8 +152,15 @@ class Cluster(object):
    server will be automatically used.
    """

    # TODO: docs
    protocol_version = 2
    """
    The version of the native protocol to use. Protocol version 2
    adds support for lightweight transactions, batch operations, and
    automatic query paging, but is only supported by Cassandra 2.0+. When
    working with Cassandra 1.2, this must be set to 1. You can also set
    this to 1 when working with Cassandra 2.0+, but features that require
    the version 2 protocol will not be enabled.
    """

    compression = True
    """
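Based on the docstring above, a short usage sketch for a Cassandra 1.2 cluster, where the v2 protocol is unavailable (contact point is illustrative)::

    from cassandra.cluster import Cluster

    # Protocol v1 forgoes paging, lightweight transactions, and batches,
    # but is required when talking to Cassandra 1.2.
    cluster = Cluster(['127.0.0.1'], protocol_version=1)
    session = cluster.connect()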
@@ -287,13 +299,6 @@ class Cluster(object):
        Any of the mutable Cluster attributes may be set as keyword arguments
        to the constructor.
        """
        if 'gevent.monkey' in sys.modules:
            raise Exception(
                "gevent monkey-patching detected. This driver does not currently "
                "support gevent, and monkey patching will break the driver "
                "completely. You can track progress towards adding gevent "
                "support here: https://datastax-oss.atlassian.net/browse/PYTHON-7.")

        self.contact_points = contact_points
        self.port = port
        self.compression = compression

@@ -478,20 +483,21 @@ class Cluster(object):

        self.load_balancing_policy.populate(
            weakref.proxy(self), self.metadata.all_hosts())

        if self.control_connection:
            try:
                self.control_connection.connect()
                log.debug("Control connection created")
            except Exception:
                log.exception("Control connection failed to connect, "
                              "shutting down Cluster:")
                self.shutdown()
                raise

        self.load_balancing_policy.check_supported()

        self._is_setup = True

        if self.control_connection:
            try:
                self.control_connection.connect()
                log.debug("Control connection created")
            except Exception:
                log.exception("Control connection failed to connect, "
                              "shutting down Cluster:")
                self.shutdown()
                raise

        self.load_balancing_policy.check_supported()

        session = self._new_session()
        if keyspace:
            session.set_keyspace(keyspace)

@@ -772,13 +778,13 @@ class Cluster(object):
            self.on_down(host, is_host_addition, force_if_down=True)
        return is_down

    def add_host(self, address, signal):
    def add_host(self, address, datacenter=None, rack=None, signal=True):
        """
        Called when adding initial contact points and when the control
        connection subsequently discovers a new node. Intended for internal
        use only.
        """
        new_host = self.metadata.add_host(address)
        new_host = self.metadata.add_host(address, datacenter, rack)
        if new_host and signal:
            log.info("New Cassandra host %s added", address)
            self.on_add(new_host)

@@ -947,10 +953,10 @@ class Session(object):
    default_fetch_size = 5000
    """
    By default, this many rows will be fetched at a time. This can be
    specified per-query through :attr:`~Statement.fetch_size`.
    specified per-query through :attr:`.Statement.fetch_size`.

    This only takes effect when protocol version 2 or higher is used.
    See :attr:`~Cluster.protocol_version` for details.
    See :attr:`.Cluster.protocol_version` for details.
    """

    _lock = None
@@ -1293,7 +1299,7 @@ class Session(object):
        Set the default keyspace for all queries made through this Session.
        This operation blocks until complete.
        """
        self.execute('USE "%s"' % (keyspace,))
        self.execute('USE %s' % (protect_name(keyspace),))

    def _set_keyspace_for_all_pools(self, keyspace, callback):
        """
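The fix above routes the keyspace name through protect_name(), which double-quotes the identifier and doubles any quotes embedded in it, rather than blindly wrapping it in quotes. A sketch of the intended behavior (return values are illustrative)::

    from cassandra.metadata import protect_name

    protect_name('MyKeyspace')   # -> '"MyKeyspace"'
    protect_name('ks"evil')      # -> '"ks""evil"'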
@@ -1602,7 +1608,9 @@ class ControlConnection(object):

        host = self._cluster.metadata.get_host(connection.host)
        if host:
            host.set_location_info(local_row["data_center"], local_row["rack"])
            datacenter = local_row.get("data_center")
            rack = local_row.get("rack")
            self._update_location_info(host, datacenter, rack)

        partitioner = local_row.get("partitioner")
        tokens = local_row.get("tokens")

@@ -1620,10 +1628,13 @@ class ControlConnection(object):
            found_hosts.add(addr)

            host = self._cluster.metadata.get_host(addr)
            datacenter = row.get("data_center")
            rack = row.get("rack")
            if host is None:
                log.debug("[control connection] Found new host to connect to: %s", addr)
                host = self._cluster.add_host(addr, signal=True)
                host.set_location_info(row.get("data_center"), row.get("rack"))
                host = self._cluster.add_host(addr, datacenter, rack, signal=True)
            else:
                self._update_location_info(host, datacenter, rack)

            tokens = row.get("tokens")
            if partitioner and tokens:

@@ -1640,11 +1651,22 @@ class ControlConnection(object):
        log.debug("[control connection] Fetched ring info, rebuilding metadata")
        self._cluster.metadata.rebuild_token_map(partitioner, token_map)

    def _update_location_info(self, host, datacenter, rack):
        if host.datacenter == datacenter and host.rack == rack:
            return

        # If the dc/rack information changes, we need to update the load balancing policy.
        # For that, we remove and re-add the node against the policy. Not the most elegant, and assumes
        # that the policy will update correctly, but in practice this should work.
        self._cluster.load_balancing_policy.on_down(host)
        host.set_location_info(datacenter, rack)
        self._cluster.load_balancing_policy.on_up(host)

    def _handle_topology_change(self, event):
        change_type = event["change_type"]
        addr, port = event["address"]
        if change_type == "NEW_NODE":
            self._cluster.scheduler.schedule(10, self._cluster.add_host, addr, signal=True)
            self._cluster.scheduler.schedule(10, self.refresh_node_list_and_token_map)
        elif change_type == "REMOVED_NODE":
            host = self._cluster.metadata.get_host(addr)
            self._cluster.scheduler.schedule(0, self._cluster.remove_host, host)

@@ -1658,7 +1680,7 @@ class ControlConnection(object):
        if change_type == "UP":
            if host is None:
                # this is the first time we've seen the node
                self._cluster.scheduler.schedule(1, self._cluster.add_host, addr, signal=True)
                self._cluster.scheduler.schedule(1, self.refresh_node_list_and_token_map)
            else:
                # this will be run by the scheduler
                self._cluster.scheduler.schedule(1, self._cluster.on_up, host)
@@ -1897,23 +1919,19 @@ class ResponseFuture(object):
        self.default_timeout = default_timeout
        self._metrics = metrics
        self.prepared_statement = prepared_statement
        self._callback_lock = Lock()
        if metrics is not None:
            self._start_time = time.time()

        # convert the list/generator/etc to an iterator so that subsequent
        # calls to send_request (which retries may do) will resume where
        # they last left off
        self.query_plan = iter(session._load_balancer.make_query_plan(
            session.keyspace, query))

        self._make_query_plan()
        self._event = Event()
        self._errors = {}

    def __del__(self):
        try:
            del self.session
        except AttributeError:
            pass
    def _make_query_plan(self):
        # convert the list/generator/etc to an iterator so that subsequent
        # calls to send_request (which retries may do) will resume where
        # they last left off
        self.query_plan = iter(self.session._load_balancer.make_query_plan(
            self.session.keyspace, self.query))

    def send_request(self):
        """ Internal """

@@ -1951,6 +1969,8 @@ class ResponseFuture(object):
        except Exception as exc:
            log.debug("Error querying host %s", host, exc_info=True)
            self._errors[host] = exc
            if self._metrics is not None:
                self._metrics.on_connection_error()
            if connection:
                pool.return_connection(connection)
            return None

@@ -1960,6 +1980,33 @@ class ResponseFuture(object):
            self._connection = connection
            return request_id

    @property
    def has_more_pages(self):
        """
        Returns :const:`True` if there are more pages left in the
        query results, :const:`False` otherwise. This should only
        be checked after the first page has been returned.
        """
        return self._paging_state is not None

    def start_fetching_next_page(self):
        """
        If there are more pages left in the query result, this asynchronously
        starts fetching the next page. If there are no pages left, :exc:`.QueryExhausted`
        is raised. Also see :attr:`.has_more_pages`.

        This should only be called after the first page has been returned.
        """
        if not self._paging_state:
            raise QueryExhausted()

        self._make_query_plan()
        self.message.paging_state = self._paging_state
        self._event.clear()
        self._final_result = _NOT_SET
        self._final_exception = None
        self.send_request()

    def _reprepare(self, prepare_message):
        cb = partial(self.session.submit, self._execute_after_prepare)
        request_id = self._query(self._current_host, prepare_message, cb=cb)

@@ -2163,12 +2210,10 @@ class ResponseFuture(object):
    def _set_final_result(self, response):
        if self._metrics is not None:
            self._metrics.request_timer.addValue(time.time() - self._start_time)
        if hasattr(self, 'session'):
            try:
                del self.session  # clear reference cycles
            except AttributeError:
                pass
        self._final_result = response

        with self._callback_lock:
            self._final_result = response

        self._event.set()
        if self._callback:
            fn, args, kwargs = self._callback

@@ -2177,11 +2222,9 @@ class ResponseFuture(object):
    def _set_final_exception(self, response):
        if self._metrics is not None:
            self._metrics.request_timer.addValue(time.time() - self._start_time)
        try:
            del self.session  # clear reference cycles
        except AttributeError:
            pass
        self._final_exception = response

        with self._callback_lock:
            self._final_exception = response
        self._event.set()
        if self._errback:
            fn, args, kwargs = self._errback

@@ -2242,13 +2285,19 @@ class ResponseFuture(object):
            timeout = self.default_timeout

        if self._final_result is not _NOT_SET:
            return self._final_result
            if self._paging_state is None:
                return self._final_result
            else:
                return PagedResult(self, self._final_result)
        elif self._final_exception:
            raise self._final_exception
        else:
            self._event.wait(timeout=timeout)
            if self._final_result is not _NOT_SET:
                return self._final_result
                if self._paging_state is None:
                    return self._final_result
                else:
                    return PagedResult(self, self._final_result)
            elif self._final_exception:
                raise self._final_exception
            else:
@@ -2301,10 +2350,14 @@ class ResponseFuture(object):
        >>> future.add_callback(handle_results, time.time(), should_log=True)

        """
        if self._final_result is not _NOT_SET:
        run_now = False
        with self._callback_lock:
            if self._final_result is not _NOT_SET:
                run_now = True
            else:
                self._callback = (fn, args, kwargs)
        if run_now:
            fn(self._final_result, *args, **kwargs)
        else:
            self._callback = (fn, args, kwargs)
        return self

    def add_errback(self, fn, *args, **kwargs):

@@ -2313,10 +2366,14 @@ class ResponseFuture(object):
        An Exception instance will be passed as the first positional argument
        to `fn`.
        """
        if self._final_exception:
        run_now = False
        with self._callback_lock:
            if self._final_exception:
                run_now = True
            else:
                self._errback = (fn, args, kwargs)
        if run_now:
            fn(self._final_exception, *args, **kwargs)
        else:
            self._errback = (fn, args, kwargs)
        return self

    def add_callbacks(self, callback, errback,
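The rewritten add_callback/add_errback above closes the race noted in the changelog (github issue #95): the check of the final result and the registration of the callback now happen under one lock, and an already-delivered result runs the callback immediately, outside the lock. A stripped-down sketch of the pattern (hypothetical class, not the driver's code)::

    from threading import Lock

    _NOT_SET = object()

    class MiniFuture(object):
        def __init__(self):
            self._lock = Lock()
            self._result = _NOT_SET
            self._callback = None

        def add_callback(self, fn):
            run_now = False
            with self._lock:
                if self._result is not _NOT_SET:
                    run_now = True          # result already set; don't call fn under the lock
                else:
                    self._callback = fn     # the delivery thread will invoke it later
            if run_now:
                fn(self._result)
            return self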
@@ -2352,3 +2409,58 @@ class ResponseFuture(object):
        return "<ResponseFuture: query='%s' request_id=%s result=%s exception=%s host=%s>" \
               % (self.query, self._req_id, result, self._final_exception, self._current_host)
    __repr__ = __str__


class QueryExhausted(Exception):
    """
    Raised when :meth:`.ResultSet.start_fetching_next_page()` is called and
    there are no more pages. You can check :attr:`.ResultSet.has_more_pages`
    before calling to avoid this.
    """
    pass


class PagedResult(object):
    """
    An iterator over the rows from a paged query result. Whenever the number
    of result rows for a query exceeds the :attr:`~.query.Statement.fetch_size`
    (or :attr:`~.Session.default_fetch_size`, if not set) an instance of this
    class will be returned.

    You can treat this as a normal iterator over rows::

        >>> from cassandra.query import SimpleStatement
        >>> statement = SimpleStatement("SELECT * FROM users", fetch_size=10)
        >>> for user_row in session.execute(statement):
        ...     process_user(user_row)

    Whenever there are no more rows in the current page, the next page will
    be fetched transparently. However, note that it _is_ possible for
    an :class:`Exception` to be raised while fetching the next page, just
    like you might see on a normal call to ``session.execute()``.
    """

    def __init__(self, response_future, initial_response):
        self.response_future = response_future
        self.current_response = iter(initial_response)

    def __iter__(self):
        return self

    def next(self):
        try:
            return next(self.current_response)
        except StopIteration:
            if self.response_future._paging_state is None:
                raise

            self.response_future.start_fetching_next_page()
            result = self.response_future.result()
            if self.response_future.has_more_pages:
                self.current_response = result.current_response
            else:
                self.current_response = iter(result)

            return next(self.current_response)

    __next__ = next
cassandra/concurrent.py (new file, 148 lines):

@@ -0,0 +1,148 @@
import sys

from itertools import count, cycle
from threading import Event


def execute_concurrent(session, statements_and_parameters, concurrency=100, raise_on_first_error=True):
    """
    Executes a sequence of (statement, parameters) tuples concurrently. Each
    ``parameters`` item must be a sequence or :const:`None`.

    A sequence of ``(success, result_or_exc)`` tuples is returned in the same
    order that the statements were passed in. If ``success`` is :const:`False`,
    there was an error executing the statement, and ``result_or_exc`` will be
    an :class:`Exception`. If ``success`` is :const:`True`, ``result_or_exc``
    will be the query result.

    If `raise_on_first_error` is left as :const:`True`, execution will stop
    after the first failed statement and the corresponding exception will be
    raised.

    The `concurrency` parameter controls how many statements will be executed
    concurrently. It is recommended that this be kept below the number of
    core connections per host times the number of connected hosts (see
    :meth:`.Cluster.set_core_connections_per_host`). If that amount is exceeded,
    the event loop thread may attempt to block on new connection creation,
    substantially impacting throughput.

    Example usage::

        select_statement = session.prepare("SELECT * FROM users WHERE id=?")

        statements_and_params = []
        for user_id in user_ids:
            statements_and_params.append(
                (select_statement, user_id))

        results = execute_concurrent(
            session, statements_and_params, raise_on_first_error=False)

        for (success, result) in results:
            if not success:
                handle_error(result)  # result will be an Exception
            else:
                process_user(result[0])  # result will be a list of rows
    """
    if concurrency <= 0:
        raise ValueError("concurrency must be greater than 0")

    if not statements_and_parameters:
        return []

    event = Event()
    first_error = [] if raise_on_first_error else None
    to_execute = len(statements_and_parameters)  # TODO handle iterators/generators
    results = [None] * to_execute
    num_finished = count(start=1)
    statements = enumerate(iter(statements_and_parameters))
    for i in xrange(min(concurrency, len(statements_and_parameters))):
        _execute_next(_sentinel, i, event, session, statements, results, num_finished, to_execute, first_error)

    event.wait()
    if first_error:
        raise first_error[0]
    else:
        return results


def execute_concurrent_with_args(session, statement, parameters, *args, **kwargs):
    """
    Like :meth:`~.execute_concurrent`, but takes a single statement and a
    sequence of parameters. Each item in ``parameters`` should be a sequence
    or :const:`None`.

    Example usage::

        statement = session.prepare("INSERT INTO mytable (a, b) VALUES (1, ?)")
        parameters = [(x,) for x in range(1000)]
        execute_concurrent_with_args(session, statement, parameters)
    """
    return execute_concurrent(session, zip(cycle((statement,)), parameters), *args, **kwargs)


_sentinel = object()


def _handle_error(error, result_index, event, session, statements, results, num_finished, to_execute, first_error):
    if first_error is not None:
        first_error.append(error)
        event.set()
        return
    else:
        results[result_index] = (False, error)
        if num_finished.next() >= to_execute:
            event.set()
            return

    try:
        (next_index, (statement, params)) = statements.next()
    except StopIteration:
        return

    args = (next_index, event, session, statements, results, num_finished, to_execute, first_error)
    try:
        session.execute_async(statement, params).add_callbacks(
            callback=_execute_next, callback_args=args,
            errback=_handle_error, errback_args=args)
    except Exception as exc:
        if first_error is not None:
            first_error.append(sys.exc_info())
            event.set()
            return
        else:
            results[next_index] = (False, exc)
            if num_finished.next() >= to_execute:
                event.set()
                return


def _execute_next(result, result_index, event, session, statements, results, num_finished, to_execute, first_error):
    if result is not _sentinel:
        results[result_index] = (True, result)
        finished = num_finished.next()
        if finished >= to_execute:
            event.set()
            return

    try:
        (next_index, (statement, params)) = statements.next()
    except StopIteration:
        return

    args = (next_index, event, session, statements, results, num_finished, to_execute, first_error)
    try:
        session.execute_async(statement, params).add_callbacks(
            callback=_execute_next, callback_args=args,
            errback=_handle_error, errback_args=args)
    except Exception as exc:
        if first_error is not None:
            first_error.append(sys.exc_info())
            event.set()
            return
        else:
            results[next_index] = (False, exc)
            if num_finished.next() >= to_execute:
                event.set()
                return
cassandra/connection.py:

@@ -1,12 +1,20 @@
import errno
from functools import wraps, partial
import logging
import sys
from threading import Event, RLock
import time
import traceback

from six.moves.queue import Queue
if 'gevent.monkey' in sys.modules:
    from gevent.queue import Queue, Empty
else:
    from six.moves.queue import Queue, Empty  # noqa

from six.moves import range

from cassandra import ConsistencyLevel, AuthenticationFailed, OperationTimedOut
from cassandra.marshal import int8_unpack, int32_pack, header_unpack
from cassandra.marshal import int32_pack, header_unpack
from cassandra.decoder import (ReadyMessage, AuthenticateMessage, OptionsMessage,
                               StartupMessage, ErrorMessage, CredentialsMessage,
                               QueryMessage, ResultMessage, decode_response,

@@ -151,16 +159,99 @@ class Connection(object):
        raise NotImplementedError()

    def defunct(self, exc):
        raise NotImplementedError()
        with self.lock:
            if self.is_defunct or self.is_closed:
                return
            self.is_defunct = True

    def send_msg(self, msg, cb):
        raise NotImplementedError()
        trace = traceback.format_exc(exc)
        if trace != "None":
            log.debug("Defuncting connection (%s) to %s: %s\n%s",
                      id(self), self.host, exc, traceback.format_exc(exc))
        else:
            log.debug("Defuncting connection (%s) to %s: %s", id(self), self.host, exc)

    def wait_for_response(self, msg, **kwargs):
        raise NotImplementedError()
        self.last_error = exc
        self.close()
        self.error_all_callbacks(exc)
        self.connected_event.set()
        return exc

    def error_all_callbacks(self, exc):
        with self.lock:
            callbacks = self._callbacks
            self._callbacks = {}
        new_exc = ConnectionShutdown(str(exc))
        for cb in callbacks.values():
            try:
                cb(new_exc)
            except Exception:
                log.warn("Ignoring unhandled exception while erroring callbacks for a "
                         "failed connection (%s) to host %s:",
                         id(self), self.host, exc_info=True)

    def handle_pushed(self, response):
        log.debug("Message pushed from server: %r", response)
        for cb in self._push_watchers.get(response.event_type, []):
            try:
                cb(response.event_args)
            except Exception:
                log.exception("Pushed event handler errored, ignoring:")

    def send_msg(self, msg, cb, wait_for_id=False):
        if self.is_defunct:
            raise ConnectionShutdown("Connection to %s is defunct" % self.host)
        elif self.is_closed:
            raise ConnectionShutdown("Connection to %s is closed" % self.host)

        if not wait_for_id:
            try:
                request_id = self._id_queue.get_nowait()
            except Empty:
                raise ConnectionBusy(
                    "Connection to %s is at the max number of requests" % self.host)
        else:
            request_id = self._id_queue.get()

        self._callbacks[request_id] = cb
        self.push(msg.to_binary(request_id, self.protocol_version, compression=self.compressor))
        return request_id

    def wait_for_response(self, msg, timeout=None):
        return self.wait_for_responses(msg, timeout=timeout)[0]

    def wait_for_responses(self, *msgs, **kwargs):
        raise NotImplementedError()
        timeout = kwargs.get('timeout')
        waiter = ResponseWaiter(self, len(msgs))

        # busy wait for sufficient space on the connection
        messages_sent = 0
        while True:
            needed = len(msgs) - messages_sent
            with self.lock:
                available = min(needed, MAX_STREAM_PER_CONNECTION - self.in_flight)
                self.in_flight += available

            for i in range(messages_sent, messages_sent + available):
                self.send_msg(msgs[i], partial(waiter.got_response, index=i), wait_for_id=True)
            messages_sent += available

            if messages_sent == len(msgs):
                break
            else:
                if timeout is not None:
                    timeout -= 0.01
                    if timeout <= 0.0:
                        raise OperationTimedOut()
                time.sleep(0.01)

        try:
            return waiter.deliver(timeout)
        except OperationTimedOut:
            raise
        except Exception as exc:
            self.defunct(exc)
            raise

    def register_watcher(self, event_type, callback):
        raise NotImplementedError()
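A usage sketch for the wait_for_responses() implementation pulled up into the base class above (message objects are hypothetical; responses come back in the same order the requests were passed)::

    # Block until both requests complete, or raise OperationTimedOut
    # after roughly ten seconds.
    response_a, response_b = connection.wait_for_responses(msg_a, msg_b, timeout=10)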
cassandra/io/asyncorereactor.py:

@@ -1,16 +1,13 @@
from collections import defaultdict, deque
from functools import partial
import logging
import os
import socket
import sys
from threading import Event, Lock, Thread
import time
import traceback

from six import BytesIO
from six.moves import queue as Queue
from six.moves import xrange
from six.moves import range

from errno import EALREADY, EINPROGRESS, EWOULDBLOCK, EINVAL, EISCONN, errorcode

import asyncore

@@ -21,9 +18,8 @@ except ImportError:
    ssl = None  # NOQA

from cassandra import OperationTimedOut
from cassandra.connection import (Connection, ResponseWaiter, ConnectionShutdown,
                                  ConnectionBusy, ConnectionException, NONBLOCKING,
                                  MAX_STREAM_PER_CONNECTION)
from cassandra.connection import (Connection, ConnectionShutdown,
                                  ConnectionException, NONBLOCKING)
from cassandra.decoder import RegisterMessage
from cassandra.marshal import int32_unpack

@@ -172,44 +168,11 @@ class AsyncoreConnection(Connection, asyncore.dispatcher):
            _starting_conns.discard(self)

        if not self.is_defunct:
            self._error_all_callbacks(
            self.error_all_callbacks(
                ConnectionShutdown("Connection to %s was closed" % self.host))
            # don't leave in-progress operations hanging
            self.connected_event.set()

    def defunct(self, exc):
        with self.lock:
            if self.is_defunct or self.is_closed:
                return
            self.is_defunct = True

        trace = traceback.format_exc()  # exc)
        if trace != "None":
            log.debug("Defuncting connection (%s) to %s: %s\n%s",
                      id(self), self.host, exc, traceback.format_exc())
        else:
            log.debug("Defuncting connection (%s) to %s: %s",
                      id(self), self.host, exc)

        self.last_error = exc
        self.close()
        self._error_all_callbacks(exc)
        self.connected_event.set()
        return exc

    def _error_all_callbacks(self, exc):
        with self.lock:
            callbacks = self._callbacks
            self._callbacks = {}
        new_exc = ConnectionShutdown(str(exc))
        for cb in callbacks.values():
            try:
                cb(new_exc)
            except Exception:
                log.warn("Ignoring unhandled exception while erroring callbacks for a "
                         "failed connection (%s) to host %s:",
                         id(self), self.host, exc_info=True)

    def handle_connect(self):
        with _starting_conns_lock:
            _starting_conns.discard(self)

@@ -300,19 +263,11 @@ class AsyncoreConnection(Connection, asyncore.dispatcher):
        if not self._callbacks:
            self._readable = False

    def handle_pushed(self, response):
        log.debug("Message pushed from server: %r", response)
        for cb in self._push_watchers.get(response.event_type, []):
            try:
                cb(response.event_args)
            except Exception:
                log.exception("Pushed event handler errored, ignoring:")

    def push(self, data):
        sabs = self.out_buffer_size
        if len(data) > sabs:
            chunks = []
            for i in xrange(0, len(data), sabs):
            for i in range(0, len(data), sabs):
                chunks.append(data[i:i + sabs])
        else:
            chunks = [data]

@@ -328,61 +283,6 @@ class AsyncoreConnection(Connection, asyncore.dispatcher):
    def readable(self):
        return self._readable or (self._have_listeners and not (self.is_defunct or self.is_closed))

    def send_msg(self, msg, cb, wait_for_id=False):
        if self.is_defunct:
            raise ConnectionShutdown("Connection to %s is defunct" % self.host)
        elif self.is_closed:
            raise ConnectionShutdown("Connection to %s is closed" % self.host)

        if not wait_for_id:
            try:
                request_id = self._id_queue.get_nowait()
            except Queue.Empty:
                raise ConnectionBusy(
                    "Connection to %s is at the max number of requests" % self.host)
        else:
            request_id = self._id_queue.get()

        self._callbacks[request_id] = cb
        self.push(msg.to_binary(request_id, self.protocol_version, compression=self.compressor))
        return request_id

    def wait_for_response(self, msg, timeout=None):
        return self.wait_for_responses(msg, timeout=timeout)[0]

    def wait_for_responses(self, *msgs, **kwargs):
        timeout = kwargs.get('timeout')
        waiter = ResponseWaiter(self, len(msgs))

        # busy wait for sufficient space on the connection
        messages_sent = 0
        while True:
            needed = len(msgs) - messages_sent
            with self.lock:
                available = min(needed, MAX_STREAM_PER_CONNECTION - self.in_flight)
                self.in_flight += available

            for i in range(messages_sent, messages_sent + available):
                self.send_msg(msgs[i], partial(waiter.got_response, index=i), wait_for_id=True)
            messages_sent += available

            if messages_sent == len(msgs):
                break
            else:
                if timeout is not None:
                    timeout -= 0.01
                    if timeout <= 0.0:
                        raise OperationTimedOut()
                time.sleep(0.01)

        try:
            return waiter.deliver(timeout)
        except OperationTimedOut:
            raise
        except Exception as exc:
            self.defunct(exc)
            raise

    def register_watcher(self, event_type, callback):
        self._push_watchers[event_type].add(callback)
        self._have_listeners = True
cassandra/io/geventreactor.py (new file, 190 lines):

@@ -0,0 +1,190 @@
import gevent
from gevent import select, socket
from gevent.event import Event
from gevent.queue import Queue

from collections import defaultdict
from functools import partial
import logging
import os

try:
    from cStringIO import StringIO
except ImportError:
    from StringIO import StringIO  # ignore flake8 warning: # NOQA

from errno import EALREADY, EINPROGRESS, EWOULDBLOCK, EINVAL

from cassandra import OperationTimedOut
from cassandra.connection import Connection, ConnectionShutdown
from cassandra.decoder import RegisterMessage
from cassandra.marshal import int32_unpack


log = logging.getLogger(__name__)


def is_timeout(err):
    return (
        err in (EINPROGRESS, EALREADY, EWOULDBLOCK) or
        (err == EINVAL and os.name in ('nt', 'ce'))
    )


class GeventConnection(Connection):
    """
    An implementation of :class:`.Connection` that utilizes ``gevent``.
    """

    _total_reqd_bytes = 0
    _read_watcher = None
    _write_watcher = None
    _socket = None

    @classmethod
    def factory(cls, *args, **kwargs):
        timeout = kwargs.pop('timeout', 5.0)
        conn = cls(*args, **kwargs)
        conn.connected_event.wait(timeout)
        if conn.last_error:
            raise conn.last_error
        elif not conn.connected_event.is_set():
            conn.close()
            raise OperationTimedOut("Timed out creating connection")
        else:
            return conn

    def __init__(self, *args, **kwargs):
        Connection.__init__(self, *args, **kwargs)

        self.connected_event = Event()
        self._iobuf = StringIO()
        self._write_queue = Queue()

        self._callbacks = {}
        self._push_watchers = defaultdict(set)

        self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self._socket.settimeout(1.0)
        self._socket.connect((self.host, self.port))

        if self.sockopts:
            for args in self.sockopts:
                self._socket.setsockopt(*args)

        self._read_watcher = gevent.spawn(lambda: self.handle_read())
        self._write_watcher = gevent.spawn(lambda: self.handle_write())
        self._send_options_message()

    def close(self):
        with self.lock:
            if self.is_closed:
                return
            self.is_closed = True

        log.debug("Closing connection (%s) to %s" % (id(self), self.host))
        if self._read_watcher:
            self._read_watcher.kill()
        if self._write_watcher:
            self._write_watcher.kill()
        if self._socket:
            self._socket.close()
        log.debug("Closed socket to %s" % (self.host,))

        if not self.is_defunct:
            self.error_all_callbacks(
                ConnectionShutdown("Connection to %s was closed" % self.host))
            # don't leave in-progress operations hanging
            self.connected_event.set()

    def handle_close(self):
        log.debug("connection closed by server")
        self.close()

    def handle_write(self):
        run_select = partial(select.select, (), (self._socket,), ())
        while True:
            try:
                next_msg = self._write_queue.get()
                run_select()
            except Exception as exc:
                log.debug("Exception during write select() for %s: %s", self, exc)
                self.defunct(exc)
                return

            try:
                self._socket.sendall(next_msg)
            except socket.error as err:
                log.debug("Exception during socket sendall for %s: %s", self, err)
                self.defunct(err)
                return  # Leave the write loop

    def handle_read(self):
        run_select = partial(select.select, (self._socket,), (), ())
        while True:
            try:
                run_select()
            except Exception as exc:
                log.debug("Exception during read select() for %s: %s", self, exc)
                self.defunct(exc)
                return

            try:
                buf = self._socket.recv(self.in_buffer_size)
                self._iobuf.write(buf)
            except socket.error as err:
                if not is_timeout(err):
                    log.debug("Exception during socket recv for %s: %s", self, err)
                    self.defunct(err)
                    return  # leave the read loop

            if self._iobuf.tell():
                while True:
                    pos = self._iobuf.tell()
                    if pos < 8 or (self._total_reqd_bytes > 0 and pos < self._total_reqd_bytes):
                        # we don't have a complete header yet or we
                        # already saw a header, but we don't have a
                        # complete message yet
                        break
                    else:
                        # have enough for header, read body len from header
                        self._iobuf.seek(4)
                        body_len = int32_unpack(self._iobuf.read(4))

                        # seek to end to get length of current buffer
                        self._iobuf.seek(0, os.SEEK_END)
                        pos = self._iobuf.tell()

                        if pos >= body_len + 8:
                            # read message header and body
                            self._iobuf.seek(0)
                            msg = self._iobuf.read(8 + body_len)

                            # leave leftover in current buffer
                            leftover = self._iobuf.read()
                            self._iobuf = StringIO()
                            self._iobuf.write(leftover)

                            self._total_reqd_bytes = 0
                            self.process_msg(msg, body_len)
                        else:
                            self._total_reqd_bytes = body_len + 8
                            break
            else:
                log.debug("connection closed by server")
                self.close()
                return

    def push(self, data):
        chunk_size = self.out_buffer_size
        for i in xrange(0, len(data), chunk_size):
            self._write_queue.put(data[i:i + chunk_size])

    def register_watcher(self, event_type, callback):
        self._push_watchers[event_type].add(callback)
        self.wait_for_response(RegisterMessage(event_list=[event_type]))

    def register_watchers(self, type_callback_dict):
        for event_type, callback in type_callback_dict.items():
            self._push_watchers[event_type].add(callback)
        self.wait_for_response(RegisterMessage(event_list=type_callback_dict.keys()))
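The framing loop in handle_read() above relies on the 8-byte frame header of the v1/v2 native protocol, whose last four bytes carry the body length that int32_unpack() reads at offset 4. An equivalent standalone sketch using only the stdlib::

    import struct

    def parse_frame_header(header):
        # version, flags, stream id (signed), opcode: one byte each,
        # followed by the big-endian 32-bit body length.
        version, flags, stream_id, opcode = struct.unpack('>BBbB', header[:4])
        body_len = struct.unpack('>i', header[4:8])[0]
        return version, flags, stream_id, opcode, body_len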
cassandra/io/libevreactor.py:

@@ -1,20 +1,13 @@
from collections import defaultdict, deque
from functools import partial, wraps
import logging
import os
import socket
from threading import Event, Lock, Thread
import time
import traceback

from six.moves.queue import Queue
from six.moves import cStringIO as StringIO
from six.moves import xrange
from six import BytesIO

from cassandra import OperationTimedOut
from cassandra.connection import (Connection, ResponseWaiter, ConnectionShutdown,
                                  ConnectionBusy, NONBLOCKING,
                                  MAX_STREAM_PER_CONNECTION)
from cassandra.connection import Connection, ConnectionShutdown, NONBLOCKING
from cassandra.decoder import RegisterMessage
from cassandra.marshal import int32_unpack
try:

@@ -80,18 +73,6 @@ def _start_loop():
    return should_start


def defunct_on_error(f):

    @wraps(f)
    def wrapper(self, *args, **kwargs):
        try:
            return f(self, *args, **kwargs)
        except Exception as exc:
            self.defunct(exc)

    return wrapper


class LibevConnection(Connection):
    """
    An implementation of :class:`.Connection` that uses libev for its event loop.

@@ -192,7 +173,7 @@ class LibevConnection(Connection):
        Connection.__init__(self, *args, **kwargs)

        self.connected_event = Event()
        self._iobuf = StringIO()
        self._iobuf = BytesIO()

        self._callbacks = {}
        self._push_watchers = defaultdict(set)

@@ -237,41 +218,9 @@ class LibevConnection(Connection):

        # don't leave in-progress operations hanging
        if not self.is_defunct:
            self._error_all_callbacks(
            self.error_all_callbacks(
                ConnectionShutdown("Connection to %s was closed" % self.host))

    def defunct(self, exc):
        with self.lock:
            if self.is_defunct or self.is_closed:
                return
            self.is_defunct = True

        trace = traceback.format_exc(exc)
        if trace != "None":
            log.debug("Defuncting connection (%s) to %s: %s\n%s",
                      id(self), self.host, exc, traceback.format_exc(exc))
        else:
            log.debug("Defuncting connection (%s) to %s: %s", id(self), self.host, exc)

        self.last_error = exc
        self.close()
        self._error_all_callbacks(exc)
        self.connected_event.set()
        return exc

    def _error_all_callbacks(self, exc):
        with self.lock:
            callbacks = self._callbacks
            self._callbacks = {}
        new_exc = ConnectionShutdown(str(exc))
        for cb in callbacks.values():
            try:
                cb(new_exc)
            except Exception:
                log.warn("Ignoring unhandled exception while erroring callbacks for a "
                         "failed connection (%s) to host %s:",
                         id(self), self.host, exc_info=True)

    def handle_write(self, watcher, revents, errno=None):
        if revents & libev.EV_ERROR:
            if errno:

@@ -351,7 +300,7 @@ class LibevConnection(Connection):

        # leave leftover in current buffer
        leftover = self._iobuf.read()
        self._iobuf = StringIO()
        self._iobuf = BytesIO()
        self._iobuf.write(leftover)

        self._total_reqd_bytes = 0

@@ -363,14 +312,6 @@ class LibevConnection(Connection):
        log.debug("Connection %s closed by server", self)
        self.close()

    def handle_pushed(self, response):
        log.debug("Message pushed from server: %r", response)
        for cb in self._push_watchers.get(response.event_type, []):
            try:
                cb(response.event_args)
            except Exception:
                log.exception("Pushed event handler errored, ignoring:")

    def push(self, data):
        sabs = self.out_buffer_size
        if len(data) > sabs:

@@ -384,61 +325,6 @@ class LibevConnection(Connection):
        self.deque.extend(chunks)
        _loop_notifier.send()

    def send_msg(self, msg, cb, wait_for_id=False):
        if self.is_defunct:
            raise ConnectionShutdown("Connection to %s is defunct" % self.host)
        elif self.is_closed:
            raise ConnectionShutdown("Connection to %s is closed" % self.host)

        if not wait_for_id:
            try:
                request_id = self._id_queue.get_nowait()
            except Queue.Empty:
                raise ConnectionBusy(
                    "Connection to %s is at the max number of requests" % self.host)
        else:
            request_id = self._id_queue.get()

        self._callbacks[request_id] = cb
        self.push(msg.to_binary(request_id, self.protocol_version, compression=self.compressor))
        return request_id

    def wait_for_response(self, msg, timeout=None):
        return self.wait_for_responses(msg, timeout=timeout)[0]

    def wait_for_responses(self, *msgs, **kwargs):
        timeout = kwargs.get('timeout')
        waiter = ResponseWaiter(self, len(msgs))

        # busy wait for sufficient space on the connection
        messages_sent = 0
        while True:
            needed = len(msgs) - messages_sent
            with self.lock:
                available = min(needed, MAX_STREAM_PER_CONNECTION - self.in_flight)
                self.in_flight += available

            for i in range(messages_sent, messages_sent + available):
                self.send_msg(msgs[i], partial(waiter.got_response, index=i), wait_for_id=True)
            messages_sent += available

            if messages_sent == len(msgs):
                break
            else:
                if timeout is not None:
                    timeout -= 0.01
                    if timeout <= 0.0:
                        raise OperationTimedOut()
                time.sleep(0.01)

        try:
            return waiter.deliver(timeout)
        except OperationTimedOut:
            raise
        except Exception as exc:
            self.defunct(exc)
            raise

    def register_watcher(self, event_type, callback):
        self._push_watchers[event_type].add(callback)
        self.wait_for_response(RegisterMessage(event_list=[event_type]))
cassandra/metadata.py:

@@ -157,7 +157,7 @@ class Metadata(object):

        if not cf_results:
            # the table was removed
            del keyspace_meta.tables[table]
            keyspace_meta.tables.pop(table, None)
        else:
            assert len(cf_results) == 1
            keyspace_meta.tables[table] = self._build_table_metadata(

@@ -346,11 +346,12 @@ class Metadata(object):
        else:
            return True

    def add_host(self, address):
    def add_host(self, address, datacenter, rack):
        cluster = self.cluster_ref()
        with self._hosts_lock:
            if address not in self._hosts:
                new_host = Host(address, cluster.conviction_policy_factory)
                new_host = Host(
                    address, cluster.conviction_policy_factory, datacenter, rack)
                self._hosts[address] = new_host
            else:
                return None
cassandra/policies.py:

@@ -137,6 +137,7 @@ class RoundRobinPolicy(LoadBalancingPolicy):

    This load balancing policy is used by default.
    """
    _live_hosts = frozenset(())

    def populate(self, cluster, hosts):
        self._live_hosts = frozenset(hosts)
cassandra/pool.py:

@@ -71,7 +71,7 @@ class Host(object):
    _currently_handling_node_up = False
    _handle_node_up_condition = None

    def __init__(self, inet_address, conviction_policy_factory):
    def __init__(self, inet_address, conviction_policy_factory, datacenter=None, rack=None):
        if inet_address is None:
            raise ValueError("inet_address may not be None")
        if conviction_policy_factory is None:

@@ -79,6 +79,7 @@ class Host(object):

        self.address = inet_address
        self.conviction_policy = conviction_policy_factory(self)
        self.set_location_info(datacenter, rack)
        self.lock = RLock()
        self._handle_node_up_condition = Condition()
cassandra/query.py:

@@ -73,6 +73,13 @@ class Statement(object):
    """

    fetch_size = None
    """
    How many rows will be fetched at a time. This overrides the default
    of :attr:`.Session.default_fetch_size`.

    This only takes effect when protocol version 2 or higher is used.
    See :attr:`.Cluster.protocol_version` for details.
    """

    _serial_consistency_level = None
    _routing_key = None

@@ -561,6 +568,7 @@ class QueryTrace(object):
            if max_wait is not None and time_spent >= max_wait:
                raise TraceUnavailable("Trace information was not available within %f seconds" % (max_wait,))

            log.debug("Attempting to fetch trace info for trace ID: %s", self.trace_id)
            session_results = self._execute(
                self._SELECT_SESSIONS_FORMAT, (self.trace_id,), time_spent, max_wait)

@@ -568,6 +576,7 @@ class QueryTrace(object):
                time.sleep(self._BASE_RETRY_SLEEP * (2 ** attempt))
                attempt += 1
                continue
            log.debug("Fetched trace info for trace ID: %s", self.trace_id)

            session_row = session_results[0]
            self.request_type = session_row.request

@@ -576,9 +585,11 @@ class QueryTrace(object):
            self.coordinator = session_row.coordinator
            self.parameters = session_row.parameters

            log.debug("Attempting to fetch trace events for trace ID: %s", self.trace_id)
            time_spent = time.time() - start
            event_results = self._execute(
                self._SELECT_EVENTS_FORMAT, (self.trace_id,), time_spent, max_wait)
            log.debug("Fetched trace events for trace ID: %s", self.trace_id)
            self.events = tuple(TraceEvent(r.activity, r.event_id, r.source, r.source_elapsed, r.thread)
                                for r in event_results)
            break
docs/api/cassandra/cluster.rst:

@@ -7,6 +7,8 @@

   .. autoattribute:: cql_version

   .. autoattribute:: protocol_version

   .. autoattribute:: port

   .. autoattribute:: compression

@@ -59,6 +61,8 @@

   .. autoattribute:: row_factory

   .. autoattribute:: default_fetch_size

   .. automethod:: execute(statement[, parameters][, timeout][, trace])

   .. automethod:: execute_async(statement[, parameters][, trace])

@@ -77,11 +81,20 @@

   .. automethod:: get_query_trace()

   .. autoattribute:: has_more_pages

   .. automethod:: start_fetching_next_page()

   .. automethod:: add_callback(fn, *args, **kwargs)

   .. automethod:: add_errback(fn, *args, **kwargs)

   .. automethod:: add_callbacks(callback, errback, callback_args=(), callback_kwargs=None, errback_args=(), errback_kwargs=None)

.. autoclass:: PagedResult ()
   :members:

.. autoexception:: QueryExhausted ()

.. autoexception:: NoHostAvailable ()
   :members:
docs/index.rst:

@@ -8,6 +8,7 @@ Python Cassandra Driver
   installation
   getting_started
   performance
   query_paging

Indices and Tables
==================
docs/query_paging.rst (new file, 74 lines):

@@ -0,0 +1,74 @@
Paging Large Queries
====================
Cassandra 2.0+ offers support for automatic query paging. Starting with
version 2.0 of the driver, if :attr:`~.Cluster.protocol_version` is set to
:const:`2` (it is by default), queries returning large result sets will be
automatically paged.

Controlling the Page Size
-------------------------
By default, :attr:`.Session.default_fetch_size` controls how many rows will
be fetched per page. This can be overridden per-query by setting
:attr:`~.fetch_size` on a :class:`~.Statement`. By default, each page
will contain at most 5000 rows.

Handling Paged Results
----------------------
Whenever the number of result rows for a query exceeds the page size, an
instance of :class:`~.PagedResult` will be returned instead of a normal
list. This class implements the iterator interface, so you can treat
it like a normal iterator over rows::

    from cassandra.query import SimpleStatement
    query = "SELECT * FROM users"  # users contains 100 rows
    statement = SimpleStatement(query, fetch_size=10)
    for user_row in session.execute(statement):
        process_user(user_row)

Whenever there are no more rows in the current page, the next page will
be fetched transparently. However, note that it *is* possible for
an :class:`Exception` to be raised while fetching the next page, just
like you might see on a normal call to ``session.execute()``.

If you use :meth:`.Session.execute_async()` along with
:meth:`.ResponseFuture.result()`, the first page will be fetched before
:meth:`~.ResponseFuture.result()` returns, but later pages will be
transparently fetched synchronously while iterating the result.

Handling Paged Results with Callbacks
-------------------------------------
If callbacks are attached to a query that returns a paged result,
the callback will be called once per page with a normal list of rows.

Use :attr:`.ResponseFuture.has_more_pages` and
:meth:`.ResponseFuture.start_fetching_next_page()` to continue fetching
pages. For example::

    class PagedResultHandler(object):

        def __init__(self, future):
            self.error = None
            self.finished_event = Event()
            self.future = future
            self.future.add_callbacks(
                callback=self.handle_page,
                errback=self.handle_error)

        def handle_page(self, rows):
            for row in rows:
                process_row(row)

            if self.future.has_more_pages:
                self.future.start_fetching_next_page()
            else:
                self.finished_event.set()

        def handle_error(self, exc):
            self.error = exc
            self.finished_event.set()

    future = session.execute_async("SELECT * FROM users")
    handler = PagedResultHandler(future)
    handler.finished_event.wait()
    if handler.error:
        raise handler.error
setup.py (36 changed lines):
@@ -1,18 +1,13 @@
from __future__ import print_function
import platform
import os
import sys
import warnings

try:
    import subprocess
    has_subprocess = True
except ImportError:
    has_subprocess = False

import ez_setup
ez_setup.use_setuptools()

if __name__ == '__main__' and sys.argv[1] == "gevent_nosetests":
    from gevent.monkey import patch_all
    patch_all()

from setuptools import setup
from distutils.command.build_ext import build_ext
from distutils.core import Extension

@@ -20,6 +15,19 @@ from distutils.errors import (CCompilerError, DistutilsPlatformError,
                              DistutilsExecError)
from distutils.cmd import Command


import platform
import os
import warnings

try:
    import subprocess
    has_subprocess = True
except ImportError:
    has_subprocess = False

from nose.commands import nosetests

from cassandra import __version__

long_description = ""

@@ -27,6 +35,10 @@ with open("README.rst") as f:
    long_description = f.read()


class gevent_nosetests(nosetests):
    description = "run nosetests with gevent monkey patching"


class DocCommand(Command):

    description = "generate or test documentation"
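With the command registered in run_setup() below, the test suite can presumably be run under gevent monkey patching via ``python setup.py gevent_nosetests`` (inferred from the command name; the sys.argv check at the top of the file applies patch_all() before the driver is imported).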
@@ -144,12 +156,12 @@ On OSX, via homebrew:


def run_setup(extensions):
    kw = {'cmdclass': {'doc': DocCommand}}
    kw = {'cmdclass': {'doc': DocCommand, 'gevent_nosetests': gevent_nosetests}}
    if extensions:
        kw['cmdclass']['build_ext'] = build_extensions
        kw['ext_modules'] = extensions

    dependencies = ['futures', 'scales', 'blist', 'six >=1.6']
    dependencies = ['futures', 'scales >=1.0.5', 'blist', 'six >=1.6']
    if platform.python_implementation() != "CPython":
        dependencies.remove('blist')

@@ -164,7 +176,7 @@ def run_setup(extensions):
        packages=['cassandra', 'cassandra.io'],
        include_package_data=True,
        install_requires=dependencies,
        tests_require=['nose', 'mock', 'PyYAML'],
        tests_require=['nose', 'mock', 'PyYAML', 'pytz'],
        classifiers=[
            'Development Status :: 5 - Production/Stable',
            'Intended Audience :: Developers',
@@ -18,7 +18,14 @@ except ImportError as e:

CLUSTER_NAME = 'test_cluster'
CCM_CLUSTER = None
DEFAULT_CASSANDRA_VERSION = '2.0.5'

CASSANDRA_VERSION = os.getenv('CASSANDRA_VERSION', '2.0.6')

if CASSANDRA_VERSION.startswith('1'):
    DEFAULT_PROTOCOL_VERSION = 1
else:
    DEFAULT_PROTOCOL_VERSION = 2
PROTOCOL_VERSION = int(os.getenv('PROTOCOL_VERSION', DEFAULT_PROTOCOL_VERSION))

path = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'ccm')
if not os.path.exists(path):
@@ -38,7 +45,7 @@ def get_server_versions():
    if cass_version is not None:
        return (cass_version, cql_version)

    c = Cluster()
    c = Cluster(protocol_version=PROTOCOL_VERSION)
    s = c.connect()
    s.set_keyspace('system')
    row = s.execute('SELECT cql_version, release_version FROM local')[0]
@@ -67,16 +74,16 @@ def get_node(node_id):


def setup_package():
    version = os.getenv("CASSANDRA_VERSION", DEFAULT_CASSANDRA_VERSION)
    print 'Using Cassandra version: %s' % CASSANDRA_VERSION
    try:
        try:
            cluster = CCMCluster.load(path, CLUSTER_NAME)
            log.debug("Found existing ccm test cluster, clearing")
            cluster.clear()
            cluster.set_cassandra_dir(cassandra_version=version)
            cluster.set_cassandra_dir(cassandra_version=CASSANDRA_VERSION)
        except Exception:
            log.debug("Creating new ccm test cluster with version %s", version)
            cluster = CCMCluster(path, CLUSTER_NAME, cassandra_version=version)
            log.debug("Creating new ccm test cluster with version %s", CASSANDRA_VERSION)
            cluster = CCMCluster(path, CLUSTER_NAME, cassandra_version=CASSANDRA_VERSION)
        cluster.set_configuration_options({'start_native_transport': True})
        common.switch_cluster(path, CLUSTER_NAME)
        cluster.populate(3)
@@ -93,7 +100,7 @@ def setup_package():


def setup_test_keyspace():
    cluster = Cluster()
    cluster = Cluster(protocol_version=PROTOCOL_VERSION)
    session = cluster.connect()

    try:
@@ -7,6 +7,7 @@ from cassandra.cluster import Cluster
from cassandra.policies import TokenAwarePolicy, RoundRobinPolicy, \
    DowngradingConsistencyRetryPolicy
from cassandra.query import SimpleStatement
from tests.integration import PROTOCOL_VERSION

from tests.integration.long.utils import force_stop, create_schema, \
    wait_for_down, wait_for_up, start, CoordinatorStats
@@ -98,7 +99,8 @@ class ConsistencyTests(unittest.TestCase):

    def _test_tokenaware_one_node_down(self, keyspace, rf, accepted):
        cluster = Cluster(
            load_balancing_policy=TokenAwarePolicy(RoundRobinPolicy()))
            load_balancing_policy=TokenAwarePolicy(RoundRobinPolicy()),
            protocol_version=PROTOCOL_VERSION)
        session = cluster.connect()
        wait_for_up(cluster, 1, wait=False)
        wait_for_up(cluster, 2)
@@ -147,7 +149,8 @@ class ConsistencyTests(unittest.TestCase):
    def test_rfthree_tokenaware_none_down(self):
        keyspace = 'test_rfthree_tokenaware_none_down'
        cluster = Cluster(
            load_balancing_policy=TokenAwarePolicy(RoundRobinPolicy()))
            load_balancing_policy=TokenAwarePolicy(RoundRobinPolicy()),
            protocol_version=PROTOCOL_VERSION)
        session = cluster.connect()
        wait_for_up(cluster, 1, wait=False)
        wait_for_up(cluster, 2)
@@ -169,7 +172,8 @@ class ConsistencyTests(unittest.TestCase):
    def _test_downgrading_cl(self, keyspace, rf, accepted):
        cluster = Cluster(
            load_balancing_policy=TokenAwarePolicy(RoundRobinPolicy()),
            default_retry_policy=DowngradingConsistencyRetryPolicy())
            default_retry_policy=DowngradingConsistencyRetryPolicy(),
            protocol_version=PROTOCOL_VERSION)
        session = cluster.connect()

        create_schema(session, keyspace, replication_factor=rf)
@@ -210,14 +214,16 @@ class ConsistencyTests(unittest.TestCase):
        keyspace = 'test_rfthree_roundrobin_downgradingcl'
        cluster = Cluster(
            load_balancing_policy=RoundRobinPolicy(),
            default_retry_policy=DowngradingConsistencyRetryPolicy())
            default_retry_policy=DowngradingConsistencyRetryPolicy(),
            protocol_version=PROTOCOL_VERSION)
        self.rfthree_downgradingcl(cluster, keyspace, True)

    def test_rfthree_tokenaware_downgradingcl(self):
        keyspace = 'test_rfthree_tokenaware_downgradingcl'
        cluster = Cluster(
            load_balancing_policy=TokenAwarePolicy(RoundRobinPolicy()),
            default_retry_policy=DowngradingConsistencyRetryPolicy())
            default_retry_policy=DowngradingConsistencyRetryPolicy(),
            protocol_version=PROTOCOL_VERSION)
        self.rfthree_downgradingcl(cluster, keyspace, False)

    def rfthree_downgradingcl(self, cluster, keyspace, roundrobin):
@@ -1,14 +1,15 @@
try:
    from Queue import Queue, Empty
except ImportError:
    from queue import Queue, Empty
    from queue import Queue, Empty  # noqa
from struct import pack
import unittest

from cassandra import ConsistencyLevel
from cassandra.cluster import Cluster
from cassandra.decoder import dict_factory
from cassandra.query import dict_factory
from cassandra.query import SimpleStatement
from tests.integration import PROTOCOL_VERSION
from tests.integration.long.utils import create_schema


@@ -32,7 +33,7 @@ class LargeDataTests(unittest.TestCase):
        self.keyspace = 'large_data'

    def make_session_and_keyspace(self):
        cluster = Cluster()
        cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        session = cluster.connect()
        session.default_timeout = 20.0  # increase the default timeout
        session.row_factory = dict_factory
@@ -41,9 +42,10 @@
        return session

    def batch_futures(self, session, statement_generator):
        futures = Queue(maxsize=121)
        concurrency = 50
        futures = Queue.Queue(maxsize=concurrency)
        for i, statement in enumerate(statement_generator):
            if i > 0 and i % 120 == 0:
            if i > 0 and i % (concurrency - 1) == 0:
                # clear the existing queue
                while True:
                    try:
@@ -70,7 +72,7 @@ class LargeDataTests(unittest.TestCase):
            session,
            (SimpleStatement('INSERT INTO %s (k, i) VALUES (0, %s)' % (table, i),
                             consistency_level=ConsistencyLevel.QUORUM)
             for i in range(1000000)))
             for i in range(100000)))

        # Read
        results = session.execute('SELECT i FROM %s WHERE k=%s' % (table, 0))
@@ -112,7 +114,7 @@ class LargeDataTests(unittest.TestCase):
            session,
            (SimpleStatement('INSERT INTO %s (k, i, v) VALUES (0, %s, %s)' % (table, i, str(bb)),
                             consistency_level=ConsistencyLevel.QUORUM)
             for i in range(1000000)))
             for i in range(100000)))

        # Read
        results = session.execute('SELECT i, v FROM %s WHERE k=%s' % (table, 0))
@@ -3,6 +3,7 @@ import logging
from cassandra import ConsistencyLevel
from cassandra.cluster import Cluster
from cassandra.query import SimpleStatement
from tests.integration import PROTOCOL_VERSION

try:
    import unittest2 as unittest
@@ -15,7 +16,7 @@ log = logging.getLogger(__name__)
class SchemaTests(unittest.TestCase):

    def test_recreates(self):
        cluster = Cluster()
        cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        session = cluster.connect()
        replication_factor = 3
@@ -4,7 +4,7 @@ import time

from collections import defaultdict

from cassandra.decoder import named_tuple_factory
from cassandra.query import named_tuple_factory

from tests.integration import get_node

@@ -74,7 +74,9 @@ def stop(node):


def force_stop(node):
    log.debug("Forcing stop of node %s", node)
    get_node(node).stop(wait=False, gently=False)
    log.debug("Node %s was stopped", node)


def ring(node):
@@ -85,6 +87,7 @@ def ring(node):
def wait_for_up(cluster, node, wait=True):
    while True:
        host = cluster.metadata.get_host('127.0.0.%s' % node)
        time.sleep(0.1)
        if host and host.is_up:
            # BUG: shouldn't have to, but we do
            if wait:
@@ -93,10 +96,14 @@ def wait_for_up(cluster, node, wait=True):


def wait_for_down(cluster, node, wait=True):
    log.debug("Waiting for node %s to be down", node)
    while True:
        host = cluster.metadata.get_host('127.0.0.%s' % node)
        time.sleep(0.1)
        if not host or not host.is_up:
            # BUG: shouldn't have to, but we do
            if wait:
                log.debug("Sleeping 5s until host is up")
                time.sleep(5)
            log.debug("Done waiting for node %s to be down", node)
            return
@@ -1,3 +1,5 @@
from tests.integration import PROTOCOL_VERSION

try:
    import unittest2 as unittest
except ImportError:
@@ -5,7 +7,6 @@ except ImportError:

import cassandra
from cassandra.query import SimpleStatement, TraceUnavailable
from cassandra.io.asyncorereactor import AsyncoreConnection
from cassandra.policies import RoundRobinPolicy, ExponentialReconnectionPolicy, RetryPolicy, SimpleConvictionPolicy, HostDistance

from cassandra.cluster import Cluster, NoHostAvailable
@@ -18,7 +19,7 @@ class ClusterTests(unittest.TestCase):
        Test basic connection and usage
        """

        cluster = Cluster()
        cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        session = cluster.connect()
        result = session.execute(
            """
@@ -54,7 +55,7 @@ class ClusterTests(unittest.TestCase):
        Ensure clusters that connect on a keyspace, do
        """

        cluster = Cluster()
        cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        session = cluster.connect()
        result = session.execute(
            """
@@ -71,7 +72,7 @@ class ClusterTests(unittest.TestCase):
        self.assertEqual(result, result2)

    def test_set_keyspace_twice(self):
        cluster = Cluster()
        cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        session = cluster.connect()
        session.execute("USE system")
        session.execute("USE system")
@@ -86,7 +87,7 @@ class ClusterTests(unittest.TestCase):
            reconnection_policy=ExponentialReconnectionPolicy(1.0, 600.0),
            default_retry_policy=RetryPolicy(),
            conviction_policy_factory=SimpleConvictionPolicy,
            connection_class=AsyncoreConnection
            protocol_version=PROTOCOL_VERSION
        )

    def test_double_shutdown(self):
@@ -94,7 +95,7 @@ class ClusterTests(unittest.TestCase):
        Ensure that a cluster can be shutdown twice, without error
        """

        cluster = Cluster()
        cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        cluster.shutdown()

        try:
@@ -108,7 +109,7 @@ class ClusterTests(unittest.TestCase):
        Ensure you cannot connect to a cluster that's been shutdown
        """

        cluster = Cluster()
        cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        cluster.shutdown()
        self.assertRaises(Exception, cluster.connect)

@@ -132,7 +133,8 @@ class ClusterTests(unittest.TestCase):
        when a cluster cannot connect to given hosts
        """

        cluster = Cluster(['127.1.2.9', '127.1.2.10'])
        cluster = Cluster(['127.1.2.9', '127.1.2.10'],
                          protocol_version=PROTOCOL_VERSION)
        self.assertRaises(NoHostAvailable, cluster.connect)

    def test_cluster_settings(self):
@@ -140,7 +142,7 @@ class ClusterTests(unittest.TestCase):
        Test connection setting getters and setters
        """

        cluster = Cluster()
        cluster = Cluster(protocol_version=PROTOCOL_VERSION)

        min_requests_per_connection = cluster.get_min_requests_per_connection(HostDistance.LOCAL)
        self.assertEqual(cassandra.cluster.DEFAULT_MIN_REQUESTS, min_requests_per_connection)
@@ -167,11 +169,11 @@ class ClusterTests(unittest.TestCase):
        Ensure new new schema is refreshed after submit_schema_refresh()
        """

        cluster = Cluster()
        cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        cluster.connect()
        self.assertNotIn("newkeyspace", cluster.metadata.keyspaces)

        other_cluster = Cluster()
        other_cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        session = other_cluster.connect()
        session.execute(
            """
@@ -189,15 +191,22 @@ class ClusterTests(unittest.TestCase):
        Ensure trace can be requested for async and non-async queries
        """

        cluster = Cluster()
        cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        session = cluster.connect()

        self.assertRaises(TypeError, session.execute, "SELECT * FROM system.local", trace=True)

        def check_trace(trace):
            self.assertIsNot(None, trace.request_type)
            self.assertIsNot(None, trace.duration)
            self.assertIsNot(None, trace.started_at)
            self.assertIsNot(None, trace.coordinator)
            self.assertIsNot(None, trace.events)

        query = "SELECT * FROM system.local"
        statement = SimpleStatement(query)
        session.execute(statement, trace=True)
        self.assertEqual(query, statement.trace.parameters['query'])
        check_trace(statement.trace)

        query = "SELECT * FROM system.local"
        statement = SimpleStatement(query)
@@ -207,15 +216,20 @@ class ClusterTests(unittest.TestCase):
        statement2 = SimpleStatement(query)
        future = session.execute_async(statement2, trace=True)
        future.result()
        self.assertEqual(query, future.get_query_trace().parameters['query'])
        check_trace(future.get_query_trace())

        statement2 = SimpleStatement(query)
        future = session.execute_async(statement2)
        future.result()
        self.assertEqual(None, future.get_query_trace())

        prepared = session.prepare("SELECT * FROM system.local")
        future = session.execute_async(prepared, parameters=(), trace=True)
        future.result()
        check_trace(future.get_query_trace())

    def test_trace_timeout(self):
        cluster = Cluster()
        cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        session = cluster.connect()

        query = "SELECT * FROM system.local"
@@ -229,7 +243,7 @@ class ClusterTests(unittest.TestCase):
        Ensure str(future) returns without error
        """

        cluster = Cluster()
        cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        session = cluster.connect()

        query = "SELECT * FROM system.local"
113 tests/integration/standard/test_concurrent.py (new file)
@@ -0,0 +1,113 @@
from tests.integration import PROTOCOL_VERSION

try:
    import unittest2 as unittest
except ImportError:
    import unittest  # noqa

from itertools import cycle

from cassandra import InvalidRequest
from cassandra.cluster import Cluster
from cassandra.concurrent import (execute_concurrent,
                                  execute_concurrent_with_args)
from cassandra.policies import HostDistance
from cassandra.query import tuple_factory


class ClusterTests(unittest.TestCase):

    def setUp(self):
        self.cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        self.cluster.set_core_connections_per_host(HostDistance.LOCAL, 1)
        self.session = self.cluster.connect()
        self.session.row_factory = tuple_factory

    def test_execute_concurrent(self):
        for num_statements in (0, 1, 2, 7, 10, 99, 100, 101, 199, 200, 201):
            # write
            statements = cycle(("INSERT INTO test3rf.test (k, v) VALUES (%s, %s)", ))
            parameters = [(i, i) for i in range(num_statements)]

            results = execute_concurrent(self.session, zip(statements, parameters))
            self.assertEqual(num_statements, len(results))
            self.assertEqual([(True, None)] * num_statements, results)

            # read
            statements = cycle(("SELECT v FROM test3rf.test WHERE k=%s", ))
            parameters = [(i, ) for i in range(num_statements)]

            results = execute_concurrent(self.session, zip(statements, parameters))
            self.assertEqual(num_statements, len(results))
            self.assertEqual([(True, [(i,)]) for i in range(num_statements)], results)

    def test_execute_concurrent_with_args(self):
        for num_statements in (0, 1, 2, 7, 10, 99, 100, 101, 199, 200, 201):
            statement = "INSERT INTO test3rf.test (k, v) VALUES (%s, %s)"
            parameters = [(i, i) for i in range(num_statements)]

            results = execute_concurrent_with_args(self.session, statement, parameters)
            self.assertEqual(num_statements, len(results))
            self.assertEqual([(True, None)] * num_statements, results)

            # read
            statement = "SELECT v FROM test3rf.test WHERE k=%s"
            parameters = [(i, ) for i in range(num_statements)]

            results = execute_concurrent_with_args(self.session, statement, parameters)
            self.assertEqual(num_statements, len(results))
            self.assertEqual([(True, [(i,)]) for i in range(num_statements)], results)

    def test_first_failure(self):
        statements = cycle(("INSERT INTO test3rf.test (k, v) VALUES (%s, %s)", ))
        parameters = [(i, i) for i in range(100)]

        # we'll get an error back from the server
        parameters[57] = ('efefef', 'awefawefawef')

        self.assertRaises(
            InvalidRequest,
            execute_concurrent, self.session, zip(statements, parameters), raise_on_first_error=True)

    def test_first_failure_client_side(self):
        statements = cycle(("INSERT INTO test3rf.test (k, v) VALUES (%s, %s)", ))
        parameters = [(i, i) for i in range(100)]

        # the driver will raise an error when binding the params
        parameters[57] = 1

        self.assertRaises(
            TypeError,
            execute_concurrent, self.session, zip(statements, parameters), raise_on_first_error=True)

    def test_no_raise_on_first_failure(self):
        statements = cycle(("INSERT INTO test3rf.test (k, v) VALUES (%s, %s)", ))
        parameters = [(i, i) for i in range(100)]

        # we'll get an error back from the server
        parameters[57] = ('efefef', 'awefawefawef')

        results = execute_concurrent(self.session, zip(statements, parameters), raise_on_first_error=False)
        for i, (success, result) in enumerate(results):
            if i == 57:
                self.assertFalse(success)
                self.assertIsInstance(result, InvalidRequest)
            else:
                self.assertTrue(success)
                self.assertEqual(None, result)

    def test_no_raise_on_first_failure_client_side(self):
        statements = cycle(("INSERT INTO test3rf.test (k, v) VALUES (%s, %s)", ))
        parameters = [(i, i) for i in range(100)]

        # the driver will raise an error when binding the params
        parameters[57] = i

        results = execute_concurrent(self.session, zip(statements, parameters), raise_on_first_error=False)
        for i, (success, result) in enumerate(results):
            if i == 57:
                self.assertFalse(success)
                self.assertIsInstance(result, TypeError)
            else:
                self.assertTrue(success)
                self.assertEqual(None, result)
@@ -1,9 +1,12 @@
from tests.integration import PROTOCOL_VERSION

try:
    import unittest2 as unittest
except ImportError:
    import unittest  # noqa

from functools import partial
import sys
from threading import Thread, Event

from cassandra import ConsistencyLevel
@@ -24,7 +27,7 @@ class ConnectionTest(object):
        """
        Test a single connection with sequential requests.
        """
        conn = self.klass.factory()
        conn = self.klass.factory(protocol_version=PROTOCOL_VERSION)
        query = "SELECT keyspace_name FROM system.schema_keyspaces LIMIT 1"
        event = Event()

@@ -47,7 +50,7 @@ class ConnectionTest(object):
        """
        Test a single connection with pipelined requests.
        """
        conn = self.klass.factory()
        conn = self.klass.factory(protocol_version=PROTOCOL_VERSION)
        query = "SELECT keyspace_name FROM system.schema_keyspaces LIMIT 1"
        responses = [False] * 100
        event = Event()
@@ -69,7 +72,7 @@ class ConnectionTest(object):
        """
        Test multiple connections with pipelined requests.
        """
        conns = [self.klass.factory() for i in range(5)]
        conns = [self.klass.factory(protocol_version=PROTOCOL_VERSION) for i in range(5)]
        events = [Event() for i in range(5)]
        query = "SELECT keyspace_name FROM system.schema_keyspaces LIMIT 1"

@@ -100,7 +103,7 @@ class ConnectionTest(object):
        num_threads = 5
        event = Event()

        conn = self.klass.factory()
        conn = self.klass.factory(protocol_version=PROTOCOL_VERSION)
        query = "SELECT keyspace_name FROM system.schema_keyspaces LIMIT 1"

        def cb(all_responses, thread_responses, request_num, *args, **kwargs):
@@ -157,7 +160,7 @@ class ConnectionTest(object):

        threads = []
        for i in range(num_conns):
            conn = self.klass.factory()
            conn = self.klass.factory(protocol_version=PROTOCOL_VERSION)
            t = Thread(target=send_msgs, args=(conn, events[i]))
            threads.append(t)

@@ -172,12 +175,18 @@ class AsyncoreConnectionTest(ConnectionTest, unittest.TestCase):

    klass = AsyncoreConnection

    def setUp(self):
        if 'gevent.monkey' in sys.modules:
            raise unittest.SkipTest("Can't test libev with gevent monkey patching")


class LibevConnectionTest(ConnectionTest, unittest.TestCase):

    klass = LibevConnection

    def setUp(self):
        if 'gevent.monkey' in sys.modules:
            raise unittest.SkipTest("Can't test libev with gevent monkey patching")
        if LibevConnection is None:
            raise unittest.SkipTest(
                'libev does not appear to be installed properly')
@@ -1,3 +1,5 @@
from tests.integration import PROTOCOL_VERSION

try:
    import unittest2 as unittest
except ImportError:
@@ -36,7 +38,7 @@ class TestFactories(unittest.TestCase):
    '''

    def test_tuple_factory(self):
        cluster = Cluster()
        cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        session = cluster.connect()
        session.row_factory = tuple_factory

@@ -58,7 +60,7 @@ class TestFactories(unittest.TestCase):
        self.assertEqual(result[1][0], 2)

    def test_named_tuple_factoryy(self):
        cluster = Cluster()
        cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        session = cluster.connect()
        session.row_factory = named_tuple_factory

@@ -79,7 +81,7 @@ class TestFactories(unittest.TestCase):
        self.assertEqual(result[1].k, 2)

    def test_dict_factory(self):
        cluster = Cluster()
        cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        session = cluster.connect()
        session.row_factory = dict_factory

@@ -101,7 +103,7 @@ class TestFactories(unittest.TestCase):
        self.assertEqual(result[1]['k'], 2)

    def test_ordered_dict_factory(self):
        cluster = Cluster()
        cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        session = cluster.connect()
        session.row_factory = ordered_dict_factory
@@ -15,7 +15,7 @@ from cassandra.metadata import (Metadata, KeyspaceMetadata, TableMetadata,
from cassandra.policies import SimpleConvictionPolicy
from cassandra.pool import Host

from tests.integration import get_cluster
from tests.integration import get_cluster, PROTOCOL_VERSION


class SchemaMetadataTest(unittest.TestCase):
@@ -28,7 +28,7 @@ class SchemaMetadataTest(unittest.TestCase):

    @classmethod
    def setup_class(cls):
        cluster = Cluster()
        cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        session = cluster.connect()
        try:
            results = session.execute("SELECT keyspace_name FROM system.schema_keyspaces")
@@ -46,7 +46,8 @@ class SchemaMetadataTest(unittest.TestCase):

    @classmethod
    def teardown_class(cls):
        cluster = Cluster(['127.0.0.1'])
        cluster = Cluster(['127.0.0.1'],
                          protocol_version=PROTOCOL_VERSION)
        session = cluster.connect()
        try:
            session.execute("DROP KEYSPACE %s" % cls.ksname)
@@ -54,7 +55,8 @@ class SchemaMetadataTest(unittest.TestCase):
            cluster.shutdown()

    def setUp(self):
        self.cluster = Cluster(['127.0.0.1'])
        self.cluster = Cluster(['127.0.0.1'],
                               protocol_version=PROTOCOL_VERSION)
        self.session = self.cluster.connect()

    def tearDown(self):
@@ -294,7 +296,7 @@ class TestCodeCoverage(unittest.TestCase):
        Test export schema functionality
        """

        cluster = Cluster()
        cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        cluster.connect()

        self.assertIsInstance(cluster.metadata.export_schema_as_string(), six.string_types)
@@ -304,7 +306,7 @@ class TestCodeCoverage(unittest.TestCase):
        Test export keyspace schema functionality
        """

        cluster = Cluster()
        cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        cluster.connect()

        for keyspace in cluster.metadata.keyspaces:
@@ -317,7 +319,7 @@ class TestCodeCoverage(unittest.TestCase):
        Test that names that need to be escaped in CREATE statements are
        """

        cluster = Cluster()
        cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        session = cluster.connect()

        ksname = 'AnInterestingKeyspace'
@@ -356,7 +358,7 @@ class TestCodeCoverage(unittest.TestCase):
        Ensure AlreadyExists exception is thrown when hit
        """

        cluster = Cluster()
        cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        session = cluster.connect()

        ksname = 'test3rf'
@@ -380,7 +382,7 @@ class TestCodeCoverage(unittest.TestCase):
        if murmur3 is None:
            raise unittest.SkipTest('the murmur3 extension is not available')

        cluster = Cluster()
        cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        self.assertEqual(cluster.metadata.get_replicas('test3rf', 'key'), [])

        cluster.connect('test3rf')
@@ -395,7 +397,7 @@ class TestCodeCoverage(unittest.TestCase):
        Test token mappings
        """

        cluster = Cluster()
        cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        cluster.connect('test3rf')
        ring = cluster.metadata.token_map.ring
        owners = list(cluster.metadata.token_map.token_to_host_owner[token] for token in ring)
@@ -418,7 +420,7 @@ class TokenMetadataTest(unittest.TestCase):
    def test_token(self):
        expected_node_count = len(get_cluster().nodes)

        cluster = Cluster()
        cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        cluster.connect()
        tmap = cluster.metadata.token_map
        self.assertTrue(issubclass(tmap.token_class, Token))
@@ -7,7 +7,7 @@ from cassandra.query import SimpleStatement
from cassandra import ConsistencyLevel, WriteTimeout, Unavailable, ReadTimeout

from cassandra.cluster import Cluster, NoHostAvailable
from tests.integration import get_node, get_cluster
from tests.integration import get_node, get_cluster, PROTOCOL_VERSION


class MetricsTests(unittest.TestCase):
@@ -17,7 +17,8 @@ class MetricsTests(unittest.TestCase):
        Trigger and ensure connection_errors are counted
        """

        cluster = Cluster(metrics_enabled=True)
        cluster = Cluster(metrics_enabled=True,
                          protocol_version=PROTOCOL_VERSION)
        session = cluster.connect()
        session.execute("USE test3rf")

@@ -45,7 +46,8 @@ class MetricsTests(unittest.TestCase):
        Attempt a write at cl.ALL and receive a WriteTimeout.
        """

        cluster = Cluster(metrics_enabled=True)
        cluster = Cluster(metrics_enabled=True,
                          protocol_version=PROTOCOL_VERSION)
        session = cluster.connect()

        # Test write
@@ -75,7 +77,8 @@ class MetricsTests(unittest.TestCase):
        Attempt a read at cl.ALL and receive a ReadTimeout.
        """

        cluster = Cluster(metrics_enabled=True)
        cluster = Cluster(metrics_enabled=True,
                          protocol_version=PROTOCOL_VERSION)
        session = cluster.connect()

        # Test write
@@ -105,7 +108,8 @@ class MetricsTests(unittest.TestCase):
        Attempt an insert/read at cl.ALL and receive a Unavailable Exception.
        """

        cluster = Cluster(metrics_enabled=True)
        cluster = Cluster(metrics_enabled=True,
                          protocol_version=PROTOCOL_VERSION)
        session = cluster.connect()

        # Test write
@@ -1,3 +1,5 @@
from tests.integration import PROTOCOL_VERSION

try:
    import unittest2 as unittest
except ImportError:
@@ -15,7 +17,7 @@ class PreparedStatementTests(unittest.TestCase):
        Test basic PreparedStatement usage
        """

        cluster = Cluster()
        cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        session = cluster.connect()
        session.execute(
            """
@@ -60,7 +62,7 @@ class PreparedStatementTests(unittest.TestCase):
        when prepared statements are missing the primary key
        """

        cluster = Cluster()
        cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        session = cluster.connect()

        prepared = session.prepare(
@@ -77,7 +79,7 @@ class PreparedStatementTests(unittest.TestCase):
        Ensure a ValueError is thrown when attempting to bind too many variables
        """

        cluster = Cluster()
        cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        session = cluster.connect()

        prepared = session.prepare(
@@ -93,7 +95,7 @@ class PreparedStatementTests(unittest.TestCase):
        Ensure binding None is handled correctly
        """

        cluster = Cluster()
        cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        session = cluster.connect()

        prepared = session.prepare(
@@ -120,7 +122,7 @@ class PreparedStatementTests(unittest.TestCase):
        Ensure None binding over async queries
        """

        cluster = Cluster()
        cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        session = cluster.connect()

        prepared = session.prepare(
@@ -10,13 +10,13 @@ from cassandra.query import (PreparedStatement, BoundStatement, ValueSequence,
from cassandra.cluster import Cluster
from cassandra.policies import HostDistance

from tests.integration import get_server_versions
from tests.integration import get_server_versions, PROTOCOL_VERSION


class QueryTest(unittest.TestCase):

    def test_query(self):
        cluster = Cluster()
        cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        session = cluster.connect()

        prepared = session.prepare(
@@ -44,7 +44,7 @@ class QueryTest(unittest.TestCase):
        Code coverage to ensure trace prints to string without error
        """

        cluster = Cluster()
        cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        session = cluster.connect()

        query = "SELECT * FROM system.local"
@@ -57,7 +57,7 @@ class QueryTest(unittest.TestCase):
            str(event)

    def test_trace_ignores_row_factory(self):
        cluster = Cluster()
        cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        session = cluster.connect()
        session.row_factory = dict_factory

@@ -78,7 +78,7 @@ class PreparedStatementTests(unittest.TestCase):
        Simple code coverage to ensure routing_keys can be accessed
        """

        cluster = Cluster()
        cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        session = cluster.connect()

        prepared = session.prepare(
@@ -96,7 +96,7 @@ class PreparedStatementTests(unittest.TestCase):
        the routing key should be None
        """

        cluster = Cluster()
        cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        session = cluster.connect()

        prepared = session.prepare(
@@ -115,7 +115,7 @@ class PreparedStatementTests(unittest.TestCase):
        overrides the current routing key
        """

        cluster = Cluster()
        cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        session = cluster.connect()

        prepared = session.prepare(
@@ -133,7 +133,7 @@ class PreparedStatementTests(unittest.TestCase):
        Basic test that uses a fake routing_key_index
        """

        cluster = Cluster()
        cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        session = cluster.connect()

        prepared = session.prepare(
@@ -151,7 +151,7 @@ class PreparedStatementTests(unittest.TestCase):
        Ensure that bound.keyspace works as expected
        """

        cluster = Cluster()
        cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        session = cluster.connect()

        prepared = session.prepare(
@@ -186,7 +186,7 @@ class PrintStatementTests(unittest.TestCase):
        Highlight the difference between Prepared and Bound statements
        """

        cluster = Cluster()
        cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        session = cluster.connect()

        prepared = session.prepare('INSERT INTO test3rf.test (k, v) VALUES (?, ?)')
@@ -202,13 +202,12 @@ class PrintStatementTests(unittest.TestCase):
class BatchStatementTests(unittest.TestCase):

    def setUp(self):
        cass_version, _ = get_server_versions()
        if cass_version < (2, 0):
        if PROTOCOL_VERSION < 2:
            raise unittest.SkipTest(
                "Cassandra 2.0+ is required for BATCH operations, currently testing against %r"
                % (cass_version,))
                "Protocol 2.0+ is required for BATCH operations, currently testing against %r"
                % (PROTOCOL_VERSION,))

        self.cluster = Cluster(protocol_version=2)
        self.cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        self.cluster.set_core_connections_per_host(HostDistance.LOCAL, 1)
        self.session = self.cluster.connect()

@@ -272,13 +271,12 @@ class BatchStatementTests(unittest.TestCase):
class SerialConsistencyTests(unittest.TestCase):

    def setUp(self):
        cass_version, _ = get_server_versions()
        if cass_version < (2, 0):
        if PROTOCOL_VERSION < 2:
            raise unittest.SkipTest(
                "Cassandra 2.0+ is required for BATCH operations, currently testing against %r"
                % (cass_version,))
                "Protocol 2.0+ is required for BATCH operations, currently testing against %r"
                % (PROTOCOL_VERSION,))

        self.cluster = Cluster(protocol_version=2)
        self.cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        self.cluster.set_core_connections_per_host(HostDistance.LOCAL, 1)
        self.session = self.cluster.connect()
112 tests/integration/standard/test_query_paging.py (new file)
@@ -0,0 +1,112 @@
from tests.integration import PROTOCOL_VERSION

import logging
log = logging.getLogger(__name__)

try:
    import unittest2 as unittest
except ImportError:
    import unittest  # noqa

from itertools import cycle, count
from threading import Event

from cassandra.cluster import Cluster
from cassandra.concurrent import execute_concurrent
from cassandra.policies import HostDistance
from cassandra.query import SimpleStatement


class QueryPagingTests(unittest.TestCase):

    def setUp(self):
        if PROTOCOL_VERSION < 2:
            raise unittest.SkipTest(
                "Protocol 2.0+ is required for BATCH operations, currently testing against %r"
                % (PROTOCOL_VERSION,))

        self.cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        self.cluster.set_core_connections_per_host(HostDistance.LOCAL, 1)
        self.session = self.cluster.connect()
        self.session.execute("TRUNCATE test3rf.test")

    def test_paging(self):
        statements_and_params = zip(cycle(["INSERT INTO test3rf.test (k, v) VALUES (%s, 0)"]),
                                    [(i, ) for i in range(100)])
        execute_concurrent(self.session, statements_and_params)

        prepared = self.session.prepare("SELECT * FROM test3rf.test")

        for fetch_size in (2, 3, 7, 10, 99, 100, 101, 10000):
            self.session.default_fetch_size = fetch_size
            self.assertEqual(100, len(list(self.session.execute("SELECT * FROM test3rf.test"))))

            statement = SimpleStatement("SELECT * FROM test3rf.test")
            self.assertEqual(100, len(list(self.session.execute(statement))))

            self.assertEqual(100, len(list(self.session.execute(prepared))))

    def test_async_paging(self):
        statements_and_params = zip(cycle(["INSERT INTO test3rf.test (k, v) VALUES (%s, 0)"]),
                                    [(i, ) for i in range(100)])
        execute_concurrent(self.session, statements_and_params)

        prepared = self.session.prepare("SELECT * FROM test3rf.test")

        for fetch_size in (2, 3, 7, 10, 99, 100, 101, 10000):
            self.session.default_fetch_size = fetch_size
            self.assertEqual(100, len(list(self.session.execute_async("SELECT * FROM test3rf.test").result())))

            statement = SimpleStatement("SELECT * FROM test3rf.test")
            self.assertEqual(100, len(list(self.session.execute_async(statement).result())))

            self.assertEqual(100, len(list(self.session.execute_async(prepared).result())))

    def test_paging_callbacks(self):
        statements_and_params = zip(cycle(["INSERT INTO test3rf.test (k, v) VALUES (%s, 0)"]),
                                    [(i, ) for i in range(100)])
        execute_concurrent(self.session, statements_and_params)

        prepared = self.session.prepare("SELECT * FROM test3rf.test")

        for fetch_size in (2, 3, 7, 10, 99, 100, 101, 10000):
            self.session.default_fetch_size = fetch_size
            future = self.session.execute_async("SELECT * FROM test3rf.test")

            event = Event()
            counter = count()

            def handle_page(rows, future, counter):
                for row in rows:
                    counter.next()

                if future.has_more_pages:
                    future.start_fetching_next_page()
                else:
                    event.set()

            def handle_error(err):
                event.set()
                self.fail(err)

            future.add_callbacks(callback=handle_page, callback_args=(future, counter), errback=handle_error)
            event.wait()
            self.assertEquals(counter.next(), 100)

            # simple statement
            future = self.session.execute_async(SimpleStatement("SELECT * FROM test3rf.test"))
            event.clear()
            counter = count()

            future.add_callbacks(callback=handle_page, callback_args=(future, counter), errback=handle_error)
            event.wait()
            self.assertEquals(counter.next(), 100)

            # prepared statement
            future = self.session.execute_async(prepared)
            event.clear()
            counter = count()

            future.add_callbacks(callback=handle_page, callback_args=(future, counter), errback=handle_error)
            event.wait()
            self.assertEquals(counter.next(), 100)
@@ -18,7 +18,7 @@ from cassandra.cqltypes import Int32Type, EMPTY
from cassandra.query import dict_factory
from cassandra.util import OrderedDict

from tests.integration import get_server_versions
from tests.integration import get_server_versions, PROTOCOL_VERSION


class TypeTests(unittest.TestCase):
@@ -27,7 +27,7 @@ class TypeTests(unittest.TestCase):
        self._cass_version, self._cql_version = get_server_versions()

    def test_blob_type_as_string(self):
        c = Cluster()
        c = Cluster(protocol_version=PROTOCOL_VERSION)
        s = c.connect()

        s.execute("""
@@ -69,7 +69,7 @@ class TypeTests(unittest.TestCase):
        self.assertEqual(expected, actual)

    def test_blob_type_as_bytearray(self):
        c = Cluster()
        c = Cluster(protocol_version=PROTOCOL_VERSION)
        s = c.connect()

        s.execute("""
@@ -129,7 +129,7 @@ class TypeTests(unittest.TestCase):
        """

    def test_basic_types(self):
        c = Cluster()
        c = Cluster(protocol_version=PROTOCOL_VERSION)
        s = c.connect()
        s.execute("""
            CREATE KEYSPACE typetests
@@ -226,7 +226,7 @@ class TypeTests(unittest.TestCase):
        self.assertEqual(expected, actual)

    def test_empty_strings_and_nones(self):
        c = Cluster()
        c = Cluster(protocol_version=PROTOCOL_VERSION)
        s = c.connect()
        s.execute("""
            CREATE KEYSPACE test_empty_strings_and_nones
@@ -329,7 +329,7 @@ class TypeTests(unittest.TestCase):
        self.assertEqual([], [(name, val) for (name, val) in results[0].items() if val is not None])

    def test_empty_values(self):
        c = Cluster()
        c = Cluster(protocol_version=PROTOCOL_VERSION)
        s = c.connect()
        s.execute("""
            CREATE KEYSPACE test_empty_values
@@ -356,7 +356,7 @@ class TypeTests(unittest.TestCase):
        eastern_tz = pytz.timezone('US/Eastern')
        eastern_tz.localize(dt)

        c = Cluster()
        c = Cluster(protocol_version=PROTOCOL_VERSION)
        s = c.connect()

        s.execute("""CREATE KEYSPACE tz_aware_test
@@ -136,13 +136,8 @@ class ConnectionTest(unittest.TestCase):
        """
        Ensure the following methods throw NIE's. If not, come back and test them.
        """

        c = self.make_connection()

        self.assertRaises(NotImplementedError, c.close)
        self.assertRaises(NotImplementedError, c.defunct, None)
        self.assertRaises(NotImplementedError, c.send_msg, None, None)
        self.assertRaises(NotImplementedError, c.wait_for_response, None)
        self.assertRaises(NotImplementedError, c.wait_for_responses)
        self.assertRaises(NotImplementedError, c.register_watcher, None, None)
        self.assertRaises(NotImplementedError, c.register_watchers, None)
@@ -59,8 +59,8 @@ class MockCluster(object):
        self.scheduler = Mock(spec=_Scheduler)
        self.executor = Mock(spec=ThreadPoolExecutor)

    def add_host(self, address, signal=False):
        host = Host(address, SimpleConvictionPolicy)
    def add_host(self, address, datacenter, rack, signal=False):
        host = Host(address, SimpleConvictionPolicy, datacenter, rack)
        self.added_hosts.append(host)
        return host

@@ -212,6 +212,7 @@ class ControlConnectionTest(unittest.TestCase):
        self.connection.peer_results[1].append(
            ["192.168.1.3", "10.0.0.3", "a", "dc1", "rack1", ["3", "103", "203"]]
        )
        self.cluster.scheduler.schedule = lambda delay, f, *args, **kwargs: f(*args, **kwargs)
        self.control_connection.refresh_node_list_and_token_map()
        self.assertEqual(1, len(self.cluster.added_hosts))
        self.assertEqual(self.cluster.added_hosts[0].address, "192.168.1.3")
@@ -250,7 +251,7 @@ class ControlConnectionTest(unittest.TestCase):
            'address': ('1.2.3.4', 9000)
        }
        self.control_connection._handle_topology_change(event)
        self.cluster.scheduler.schedule.assert_called_with(ANY, self.cluster.add_host, '1.2.3.4', signal=True)
        self.cluster.scheduler.schedule.assert_called_with(ANY, self.control_connection.refresh_node_list_and_token_map)

        event = {
            'change_type': 'REMOVED_NODE',
@@ -272,7 +273,7 @@ class ControlConnectionTest(unittest.TestCase):
            'address': ('1.2.3.4', 9000)
        }
        self.control_connection._handle_status_change(event)
        self.cluster.scheduler.schedule.assert_called_with(ANY, self.cluster.add_host, '1.2.3.4', signal=True)
        self.cluster.scheduler.schedule.assert_called_with(ANY, self.control_connection.refresh_node_list_and_token_map)

        # do the same with a known Host
        event = {
@@ -35,6 +35,9 @@ class ResponseFutureTests(unittest.TestCase):
        message = QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE)
        return ResponseFuture(session, message, query)

    def make_mock_response(self, results):
        return Mock(spec=ResultMessage, kind=RESULT_KIND_ROWS, results=results, paging_state=None)

    def test_result_message(self):
        session = self.make_basic_session()
        session._load_balancer.make_query_plan.return_value = ['ip1', 'ip2']
@@ -49,9 +52,7 @@ class ResponseFutureTests(unittest.TestCase):
        connection = pool.borrow_connection.return_value
        connection.send_msg.assert_called_once_with(rf.message, cb=ANY)

        response = Mock(spec=ResultMessage, kind=RESULT_KIND_ROWS, results=[{'col': 'val'}])
        rf._set_result(response)

        rf._set_result(self.make_mock_response([{'col': 'val'}]))
        result = rf.result()
        self.assertEqual(result, [{'col': 'val'}])

@@ -259,8 +260,7 @@ class ResponseFutureTests(unittest.TestCase):
        rf = self.make_response_future(session)
        rf.send_request()

        response = Mock(spec=ResultMessage, kind=RESULT_KIND_ROWS, results=[{'col': 'val'}])
        rf._set_result(response)
        rf._set_result(self.make_mock_response([{'col': 'val'}]))

        result = rf.result()
        self.assertEqual(result, [{'col': 'val'}])
@@ -280,8 +280,7 @@ class ResponseFutureTests(unittest.TestCase):
        rf = self.make_response_future(session)
        rf.send_request()

        response = Mock(spec=ResultMessage, kind=RESULT_KIND_ROWS, results=[{'col': 'val'}])
        rf._set_result(response)
        rf._set_result(self.make_mock_response([{'col': 'val'}]))
        self.assertEqual(rf.result(), [{'col': 'val'}])

        # make sure the exception is recorded correctly
@@ -294,8 +293,7 @@ class ResponseFutureTests(unittest.TestCase):

        rf.add_callback(self.assertEqual, [{'col': 'val'}])

        response = Mock(spec=ResultMessage, kind=RESULT_KIND_ROWS, results=[{'col': 'val'}])
        rf._set_result(response)
        rf._set_result(self.make_mock_response([{'col': 'val'}]))

        result = rf.result()
        self.assertEqual(result, [{'col': 'val'}])
@@ -349,8 +347,7 @@ class ResponseFutureTests(unittest.TestCase):
            callback=self.assertEqual, callback_args=([{'col': 'val'}],),
            errback=self.assertIsInstance, errback_args=(Exception,))

        response = Mock(spec=ResultMessage, kind=RESULT_KIND_ROWS, results=[{'col': 'val'}])
        rf._set_result(response)
        rf._set_result(self.make_mock_response([{'col': 'val'}]))
        self.assertEqual(rf.result(), [{'col': 'val'}])

    def test_prepared_query_not_found(self):