Merge branch 'master' into 2.0

Conflicts:
	cassandra/cluster.py
	cassandra/concurrent.py
	cassandra/decoder.py
	tests/integration/__init__.py
	tests/integration/long/test_loadbalancingpolicies.py
	tests/integration/standard/test_cluster.py
	tests/integration/standard/test_concurrent.py
	tests/integration/standard/test_connection.py
	tests/integration/standard/test_factories.py
	tests/integration/standard/test_prepared_statements.py
	tests/unit/test_control_connection.py
This commit is contained in:
Tyler Hobbs
2014-04-16 15:13:19 -05:00
25 changed files with 533 additions and 60 deletions

1
.gitignore vendored
View File

@@ -2,6 +2,7 @@
*.swp
*.swo
*.so
*.egg
*.egg-info
.tox
.idea/

View File

@@ -8,6 +8,9 @@ Features
github issue #46)
* Support static columns in schemas, which are available starting in
Cassandra 2.1. (github issue #91)
* Add debian packaging (github issue #101)
* Add utility methods for easy concurrent execution of statements. See
the new cassandra.concurrent module. (github issue #7)
Bug Fixes
---------
@@ -27,6 +30,8 @@ Bug Fixes
and rack information has been set, if possible.
* Avoid KeyError when updating metadata after dropping a table (github issues
#97, #98)
* Use tuples instead of sets for DCAwareRoundRobinPolicy to ensure equal
distribution of requests
Other
-----
@@ -34,6 +39,10 @@ Other
user-defined type support. (github issue #90)
* Better error message when libevwrapper is not found
* Only try to import scales when metrics are enabled (github issue #92)
* Cut down on the number of queries executed when a new Cluster
connects and when the control connection has to reconnect (github issue #104,
PYTHON-59)
* Issue warning log when schema versions do not match
1.0.2
=====

View File

@@ -18,6 +18,7 @@ This module houses the main classes you will interact with,
"""
from __future__ import absolute_import
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
import logging
import socket
@@ -54,7 +55,7 @@ from cassandra.policies import (RoundRobinPolicy, SimpleConvictionPolicy,
ExponentialReconnectionPolicy, HostDistance,
RetryPolicy)
from cassandra.pool import (_ReconnectionHandler, _HostReconnectionHandler,
HostConnectionPool)
HostConnectionPool, NoConnectionsAvailable)
from cassandra.query import (SimpleStatement, PreparedStatement, BoundStatement,
BatchStatement, bind_params, QueryTrace, Statement,
named_tuple_factory, dict_factory)
@@ -1380,8 +1381,8 @@ class ControlConnection(object):
_SELECT_COLUMN_FAMILIES = "SELECT * FROM system.schema_columnfamilies"
_SELECT_COLUMNS = "SELECT * FROM system.schema_columns"
_SELECT_PEERS = "SELECT peer, data_center, rack, tokens, rpc_address FROM system.peers"
_SELECT_LOCAL = "SELECT cluster_name, data_center, rack, tokens, partitioner FROM system.local WHERE key='local'"
_SELECT_PEERS = "SELECT peer, data_center, rack, tokens, rpc_address, schema_version FROM system.peers"
_SELECT_LOCAL = "SELECT cluster_name, data_center, rack, tokens, partitioner, schema_version FROM system.local WHERE key='local'"
_SELECT_SCHEMA_PEERS = "SELECT rpc_address, schema_version FROM system.peers"
_SELECT_SCHEMA_LOCAL = "SELECT schema_version FROM system.local WHERE key='local'"
@@ -1464,8 +1465,13 @@ class ControlConnection(object):
"SCHEMA_CHANGE": self._handle_schema_change
})
self._refresh_node_list_and_token_map(connection)
self._refresh_schema(connection)
peers_query = QueryMessage(query=self._SELECT_PEERS, consistency_level=ConsistencyLevel.ONE)
local_query = QueryMessage(query=self._SELECT_LOCAL, consistency_level=ConsistencyLevel.ONE)
shared_results = connection.wait_for_responses(
peers_query, local_query, timeout=self._timeout)
self._refresh_node_list_and_token_map(connection, preloaded_results=shared_results)
self._refresh_schema(connection, preloaded_results=shared_results)
except Exception:
connection.close()
raise
@@ -1547,11 +1553,11 @@ class ControlConnection(object):
log.debug("[control connection] Error refreshing schema", exc_info=True)
self._signal_error()
def _refresh_schema(self, connection, keyspace=None, table=None):
def _refresh_schema(self, connection, keyspace=None, table=None, preloaded_results=None):
if self._cluster.is_shutdown:
return
self.wait_for_schema_agreement(connection)
self.wait_for_schema_agreement(connection, preloaded_results=preloaded_results)
where_clause = ""
if keyspace:
@@ -1598,13 +1604,19 @@ class ControlConnection(object):
log.debug("[control connection] Error refreshing node list and token map", exc_info=True)
self._signal_error()
def _refresh_node_list_and_token_map(self, connection):
log.debug("[control connection] Refreshing node list and token map")
cl = ConsistencyLevel.ONE
peers_query = QueryMessage(query=self._SELECT_PEERS, consistency_level=cl)
local_query = QueryMessage(query=self._SELECT_LOCAL, consistency_level=cl)
peers_result, local_result = connection.wait_for_responses(
peers_query, local_query, timeout=self._timeout)
def _refresh_node_list_and_token_map(self, connection, preloaded_results=None):
if preloaded_results:
log.debug("[control connection] Refreshing node list and token map using preloaded results")
peers_result = preloaded_results[0]
local_result = preloaded_results[1]
else:
log.debug("[control connection] Refreshing node list and token map")
cl = ConsistencyLevel.ONE
peers_query = QueryMessage(query=self._SELECT_PEERS, consistency_level=cl)
local_query = QueryMessage(query=self._SELECT_LOCAL, consistency_level=cl)
peers_result, local_result = connection.wait_for_responses(
peers_query, local_query, timeout=self._timeout)
peers_result = dict_factory(*peers_result.results)
partitioner = None
@@ -1712,7 +1724,7 @@ class ControlConnection(object):
elif event['change_type'] == "UPDATED":
self._submit(self.refresh_schema, keyspace, table)
def wait_for_schema_agreement(self, connection=None):
def wait_for_schema_agreement(self, connection=None, preloaded_results=None):
# Each schema change typically generates two schema refreshes, one
# from the response type and one from the pushed notification. Holding
# a lock is just a simple way to cut down on the number of schema queries
@@ -1721,14 +1733,24 @@ class ControlConnection(object):
if self._is_shutdown:
return
log.debug("[control connection] Waiting for schema agreement")
if not connection:
connection = self._connection
if preloaded_results:
log.debug("[control connection] Attempting to use preloaded results for schema agreement")
peers_result = preloaded_results[0]
local_result = preloaded_results[1]
schema_mismatches = self._get_schema_mismatches(peers_result, local_result, connection.host)
if schema_mismatches is None:
return True
log.debug("[control connection] Waiting for schema agreement")
start = self._time.time()
elapsed = 0
cl = ConsistencyLevel.ONE
total_timeout = self._cluster.max_schema_agreement_wait
schema_mismatches = None
while elapsed < total_timeout:
peers_query = QueryMessage(query=self._SELECT_SCHEMA_PEERS, consistency_level=cl)
local_query = QueryMessage(query=self._SELECT_SCHEMA_LOCAL, consistency_level=cl)
@@ -1742,36 +1764,45 @@ class ControlConnection(object):
elapsed = self._time.time() - start
continue
peers_result = dict_factory(*peers_result.results)
versions = set()
if local_result.results:
local_row = dict_factory(*local_result.results)[0]
if local_row.get("schema_version"):
versions.add(local_row.get("schema_version"))
for row in peers_result:
if not row.get("rpc_address") or not row.get("schema_version"):
continue
rpc = row.get("rpc_address")
if rpc == "0.0.0.0": # TODO ipv6 check
rpc = row.get("peer")
peer = self._cluster.metadata.get_host(rpc)
if peer and peer.is_up:
versions.add(row.get("schema_version"))
if len(versions) == 1:
log.debug("[control connection] Schemas match")
schema_mismatches = self._get_schema_mismatches(peers_result, local_result, connection.host)
if schema_mismatches is None:
return True
log.debug("[control connection] Schemas mismatched, trying again")
self._time.sleep(0.2)
elapsed = self._time.time() - start
log.warn("Node %s is reporting a schema disagreement: %s",
connection.host, schema_mismatches)
return False
def _get_schema_mismatches(self, peers_result, local_result, local_address):
# Group nodes by the schema version they report and decide whether they agree.
# Returns None when exactly one distinct version was collected. Otherwise
# returns a dict mapping each schema version to a list of the addresses that
# reported it (an empty dict when no version was collected at all).
#
# Re-shape the raw peers result into rows that are read with .get() below.
peers_result = dict_factory(*peers_result.results)
# schema_version -> set of addresses reporting that version
versions = defaultdict(set)
# The local node's version, when present, is recorded under local_address.
if local_result.results:
local_row = dict_factory(*local_result.results)[0]
if local_row.get("schema_version"):
versions[local_row.get("schema_version")].add(local_address)
for row in peers_result:
# Skip peers lacking either an rpc_address or a schema_version.
if not row.get("rpc_address") or not row.get("schema_version"):
continue
rpc = row.get("rpc_address")
# An rpc_address of 0.0.0.0 is replaced by the row's "peer" value.
if rpc == "0.0.0.0": # TODO ipv6 check
rpc = row.get("peer")
# Only peers known to the cluster metadata and marked up are counted.
peer = self._cluster.metadata.get_host(rpc)
if peer and peer.is_up:
versions[row.get("schema_version")].add(rpc)
# Exactly one distinct version means every counted node agrees.
if len(versions) == 1:
log.debug("[control connection] Schemas match")
return None
# Otherwise report which addresses sit on which version.
return dict((version, list(nodes)) for version, nodes in versions.iteritems())
def _signal_error(self):
# try just signaling the cluster, as this will trigger a reconnect
# as part of marking the host down
@@ -1976,6 +2007,10 @@ class ResponseFuture(object):
# TODO get connectTimeout from cluster settings
connection = pool.borrow_connection(timeout=2.0)
request_id = connection.send_msg(message, cb=cb)
except NoConnectionsAvailable as exc:
log.debug("All connections for host %s are at capacity, moving to the next host", host)
self._errors[host] = exc
return None
except Exception as exc:
log.debug("Error querying host %s", host, exc_info=True)
self._errors[host] = exc

View File

@@ -65,9 +65,16 @@ def execute_concurrent(session, statements_and_parameters, concurrency=100, rais
if not statements_and_parameters:
return []
# TODO handle iterators and generators naturally without converting the
# whole thing to a list. This would require not building a result
# list of Nones up front (we don't know how many results there will be),
# so a dict keyed by index should be used instead. The tricky part is
# knowing when you're the final statement to finish.
statements_and_parameters = list(statements_and_parameters)
event = Event()
first_error = [] if raise_on_first_error else None
to_execute = len(statements_and_parameters) # TODO handle iterators/generators
to_execute = len(statements_and_parameters)
results = [None] * to_execute
num_finished = count(start=1)
statements = enumerate(iter(statements_and_parameters))
@@ -76,7 +83,12 @@ def execute_concurrent(session, statements_and_parameters, concurrency=100, rais
event.wait()
if first_error:
raise first_error[0]
exc = first_error[0]
if isinstance(exc, tuple):
(exc_type, value, traceback) = exc
raise exc_type, value, traceback
else:
raise exc
else:
return results

View File

@@ -224,7 +224,7 @@ class DCAwareRoundRobinPolicy(LoadBalancingPolicy):
def populate(self, cluster, hosts):
for dc, dc_hosts in groupby(hosts, lambda h: self._dc(h)):
self._dc_live_hosts[dc] = frozenset(dc_hosts)
self._dc_live_hosts[dc] = tuple(set(dc_hosts))
# position is currently only used for local hosts
local_live = self._dc_live_hosts.get(self.local_dc)
@@ -258,7 +258,7 @@ class DCAwareRoundRobinPolicy(LoadBalancingPolicy):
pos = self._position
self._position += 1
local_live = list(self._dc_live_hosts.get(self.local_dc, ()))
local_live = self._dc_live_hosts.get(self.local_dc, ())
pos = (pos % len(local_live)) if local_live else 0
for host in islice(cycle(local_live), pos, pos + len(local_live)):
yield host
@@ -267,32 +267,36 @@ class DCAwareRoundRobinPolicy(LoadBalancingPolicy):
if dc == self.local_dc:
continue
for host in list(current_dc_hosts)[:self.used_hosts_per_remote_dc]:
for host in current_dc_hosts[:self.used_hosts_per_remote_dc]:
yield host
def on_up(self, host):
dc = self._dc(host)
with self._hosts_lock:
current_hosts = self._dc_live_hosts.setdefault(dc, frozenset())
self._dc_live_hosts[dc] = current_hosts.union((host, ))
current_hosts = self._dc_live_hosts.setdefault(dc, ())
if host not in current_hosts:
self._dc_live_hosts[dc] = current_hosts + (host, )
def on_down(self, host):
dc = self._dc(host)
with self._hosts_lock:
current_hosts = self._dc_live_hosts.setdefault(dc, frozenset())
self._dc_live_hosts[dc] = current_hosts.difference((host, ))
current_hosts = self._dc_live_hosts.setdefault(dc, ())
if host in current_hosts:
self._dc_live_hosts[dc] = tuple(h for h in current_hosts if h != host)
def on_add(self, host):
dc = self._dc(host)
with self._hosts_lock:
current_hosts = self._dc_live_hosts.setdefault(dc, frozenset())
self._dc_live_hosts[dc] = current_hosts.union((host, ))
current_hosts = self._dc_live_hosts.setdefault(dc, ())
if host not in current_hosts:
self._dc_live_hosts[dc] = current_hosts + (host, )
def on_remove(self, host):
dc = self._dc(host)
with self._hosts_lock:
current_hosts = self._dc_live_hosts.setdefault(dc, frozenset())
self._dc_live_hosts[dc] = current_hosts.difference((host, ))
current_hosts = self._dc_live_hosts.setdefault(dc, ())
if host in current_hosts:
self._dc_live_hosts[dc] = tuple(h for h in current_hosts if h != host)
class TokenAwarePolicy(LoadBalancingPolicy):

View File

@@ -264,7 +264,7 @@ class _HostReconnectionHandler(_ReconnectionHandler):
if isinstance(exc, AuthenticationFailed):
return False
else:
log.warn("Error attempting to reconnect to %s, scheduling retry in %f seconds: %s",
log.warn("Error attempting to reconnect to %s, scheduling retry in %s seconds: %s",
self.host, next_delay, exc)
log.debug("Reconnection error details", exc_info=True)
return True

View File

@@ -322,12 +322,49 @@ class BoundStatement(Statement):
def bind(self, values):
"""
Binds a sequence of values for the prepared statement parameters
and returns this instance. Note that `values` *must* be a
sequence, even if you are only binding one value.
and returns this instance. Note that `values` *must* be:
* a sequence, even if you are only binding one value, or
* a dict that relates 1-to-1 between dict keys and columns
"""
if values is None:
values = ()
col_meta = self.prepared_statement.column_metadata
# special case for binding dicts
if isinstance(values, dict):
dict_values = values
values = []
# sort values accordingly
for col in col_meta:
try:
values.append(dict_values[col[2]])
except KeyError:
raise KeyError(
'Column name `%s` not found in bound dict.' %
(col[2]))
# ensure a 1-to-1 dict keys to columns relationship
if len(dict_values) != len(col_meta):
# find expected columns
columns = set()
for col in col_meta:
columns.add(col[2])
# generate error message
if len(dict_values) > len(col_meta):
difference = set(dict_values.keys()).difference(columns)
msg = "Too many arguments provided to bind() (got %d, expected %d). " + \
"Unexpected keys %s."
else:
difference = set(columns).difference(dict_values.keys())
msg = "Too few arguments provided to bind() (got %d, expected %d). " + \
"Expected keys %s."
# exit with error message
msg = msg % (len(values), len(col_meta), difference)
raise ValueError(msg)
if len(values) > len(col_meta):
raise ValueError(
"Too many arguments provided to bind() (got %d, expected %d)" %
@@ -581,7 +618,8 @@ class QueryTrace(object):
while True:
time_spent = time.time() - start
if max_wait is not None and time_spent >= max_wait:
raise TraceUnavailable("Trace information was not available within %f seconds" % (max_wait,))
raise TraceUnavailable(
"Trace information was not available within %f seconds. Consider raising Session.max_trace_wait." % (max_wait,))
log.debug("Attempting to fetch trace info for trace ID: %s", self.trace_id)
session_results = self._execute(

5
debian/changelog vendored Normal file
View File

@@ -0,0 +1,5 @@
python-cassandra-driver (1.1.0~prerelease-1) unstable; urgency=low
* Initial packaging
-- paul cannon <pik@debian.org> Thu, 03 Apr 2014 10:30:11 -0600

1
debian/compat vendored Normal file
View File

@@ -0,0 +1 @@
9

46
debian/control vendored Normal file
View File

@@ -0,0 +1,46 @@
Source: python-cassandra-driver
Maintainer: paul cannon <pik@debian.org>
Section: python
Priority: optional
Build-Depends: python-all-dev (>= 2.6.6-3), python-all-dbg, debhelper (>= 9),
python-sphinx (>= 1.0.7+dfsg) | python3-sphinx, libev-dev,
python-concurrent.futures | python-futures, python-setuptools,
python-nose, python-mock, python-yaml, python-gevent,
python-blist, python-tz
X-Python-Version: >= 2.7
Standards-Version: 3.9.4
Package: python-cassandra-driver
Architecture: any
Depends: ${misc:Depends}, ${python:Depends}, ${shlibs:Depends}, python-blist,
python-concurrent.futures | python-futures
Provides: ${python:Provides}
Recommends: python-scales
Suggests: python-cassandra-driver-doc
Description: Python driver for Apache Cassandra
This driver works exclusively with the Cassandra Query Language v3 (CQL3)
and Cassandra's native protocol. As such, only Cassandra 1.2+ is supported.
Package: python-cassandra-driver-dbg
Architecture: any
Depends: ${misc:Depends}, ${python:Depends}, ${shlibs:Depends},
python-cassandra-driver (= ${binary:Version})
Provides: ${python:Provides}
Section: debug
Priority: extra
Description: Python driver for Apache Cassandra (debug build and symbols)
This driver works exclusively with the Cassandra Query Language v3 (CQL3)
and Cassandra's native protocol. As such, only Cassandra 1.2+ is supported.
.
This package contains debug builds of the extensions and debug symbols for
the extensions in the main package.
Package: python-cassandra-driver-doc
Architecture: all
Section: doc
Priority: extra
Depends: ${misc:Depends}, ${sphinxdoc:Depends}
Suggests: python-cassandra-driver
Description: Python driver for Apache Cassandra (documentation)
This contains HTML documentation for the use of the Python Cassandra
driver.

28
debian/copyright vendored Normal file
View File

@@ -0,0 +1,28 @@
Format: http://www.debian.org/doc/packaging-manuals/copyright-format/1.0/
Upstream-Name: python-driver
Upstream-Contact: Tyler Hobbs <tyler@datastax.com>
Source: https://github.com/datastax/python-driver
Files: *
Copyright: Copyright 2013, DataStax
License: Apache-2.0
Files: debian/*
Copyright: Copyright (c) 2014 by Space Monkey, Inc.
License: Apache-2.0
License: Apache-2.0
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
.
http://www.apache.org/licenses/LICENSE-2.0
.
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
.
On Debian systems, the full text of the Apache License version 2.0
can be found in the file `/usr/share/common-licenses/Apache-2.0'.

View File

@@ -0,0 +1,40 @@
From: paul cannon <paul@spacemonkey.com>
Date: Thu, 3 Apr 2014 11:27:09 -0600
Subject: don't use ez_setup
Debian packages aren't supposed to download stuff while building, and
since the version of setuptools in stable is less than the one ez_setup
wants, and since some system python packages don't ship their .egg-info
directories, it might try.
It's ok though, we can rely on the Depends and Build-Depends for making
sure python-setuptools and the various other deps are around at the right
times.
---
setup.py | 7 ++-----
1 file changed, 2 insertions(+), 5 deletions(-)
diff --git a/setup.py b/setup.py
index 0c28d3d..c0fd6c1 100644
--- a/setup.py
+++ b/setup.py
@@ -1,8 +1,5 @@
import sys
-import ez_setup
-ez_setup.use_setuptools()
-
if __name__ == '__main__' and sys.argv[1] == "gevent_nosetests":
from gevent.monkey import patch_all
patch_all()
@@ -174,8 +171,8 @@ def run_setup(extensions):
author_email='tyler@datastax.com',
packages=['cassandra', 'cassandra.io'],
include_package_data=True,
- install_requires=dependencies,
- tests_require=['nose', 'mock', 'ccm', 'unittest2', 'PyYAML', 'pytz'],
+ install_requires=(),
+ tests_require=(),
classifiers=[
'Development Status :: 5 - Production/Stable',
'Intended Audience :: Developers',

1
debian/patches/series vendored Normal file
View File

@@ -0,0 +1 @@
0001-don-t-use-ez_setup.patch

View File

@@ -0,0 +1,2 @@
usr/lib/python2*/*-packages/cassandra/*_d.so
usr/lib/python2*/*-packages/cassandra/io/*_d.so

View File

@@ -0,0 +1 @@
docs/_build/*/*

View File

@@ -0,0 +1,4 @@
usr/lib/python2*/*-packages/cassandra/*[!_][!_].so
usr/lib/python2*/*-packages/cassandra/*.py
usr/lib/python2*/*-packages/cassandra/io/*[!_][!_].so
usr/lib/python2*/*-packages/cassandra/io/*.py

16
debian/rules vendored Executable file
View File

@@ -0,0 +1,16 @@
#!/usr/bin/make -f
%:
dh $@ --with python2,sphinxdoc
override_dh_auto_build:
dh_auto_build
python setup.py doc
ifeq (,$(filter nocheck,$(DEB_BUILD_OPTIONS)))
override_dh_auto_test:
python setup.py gevent_nosetests
endif
override_dh_strip:
dh_strip --dbg-package=python-cassandra-driver-dbg

1
debian/source/format vendored Normal file
View File

@@ -0,0 +1 @@
3.0 (quilt)

View File

@@ -0,0 +1,8 @@
``cassandra.concurrent`` - Utilities for Concurrent Statement Execution
=======================================================================
.. module:: cassandra.concurrent
.. autofunction:: execute_concurrent
.. autofunction:: execute_concurrent_with_args

View File

@@ -10,6 +10,7 @@ API Documentation
cassandra/metadata
cassandra/query
cassandra/pool
cassandra/concurrent
cassandra/connection
cassandra/io/asyncorereactor
cassandra/io/libevreactor

16
example.py Executable file → Normal file
View File

@@ -1,3 +1,17 @@
# Copyright 2013-2014 DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#!/usr/bin/env python
import logging
@@ -56,7 +70,7 @@ def main():
for i in range(10):
log.info("inserting row %d" % i)
session.execute(query, dict(key="key%d" % i, a='a', b='b'))
session.execute(prepared.bind(("key%d" % i, 'b', 'b')))
session.execute(prepared, ("key%d" % i, 'b', 'b'))
future = session.execute_async("SELECT * FROM mytable")
log.info("key\tcol1\tcol2")

View File

@@ -19,6 +19,7 @@ from cassandra.concurrent import execute_concurrent_with_args
from cassandra.policies import (RoundRobinPolicy, DCAwareRoundRobinPolicy,
TokenAwarePolicy, WhiteListRoundRobinPolicy)
from cassandra.query import SimpleStatement
from tests.integration import use_multidc, use_singledc, PROTOCOL_VERSION
from tests.integration.long.utils import (wait_for_up, create_schema,
CoordinatorStats, force_stop,

View File

@@ -416,7 +416,7 @@ class TestCodeCoverage(unittest.TestCase):
get_replicas = cluster.metadata.token_map.get_replicas
for ksname in ('test1rf', 'test2rf', 'test3rf'):
self.assertNotEqual(list(get_replicas('test3rf', ring[0])), [])
self.assertNotEqual(list(get_replicas(ksname, ring[0])), [])
for i, token in enumerate(ring):
self.assertEqual(set(get_replicas('test3rf', token)), set(owners))

View File

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from tests.integration import PROTOCOL_VERSION
try:
@@ -23,6 +24,7 @@ from cassandra import InvalidRequest
from cassandra.cluster import Cluster
from cassandra.query import PreparedStatement
class PreparedStatementTests(unittest.TestCase):
def test_basic(self):
@@ -69,6 +71,32 @@ class PreparedStatementTests(unittest.TestCase):
results = session.execute(bound)
self.assertEquals(results, [('a', 'b', 'c')])
# test with new dict binding
prepared = session.prepare(
"""
INSERT INTO cf0 (a, b, c) VALUES (?, ?, ?)
""")
self.assertIsInstance(prepared, PreparedStatement)
bound = prepared.bind({
'a': 'x',
'b': 'y',
'c': 'z'
})
session.execute(bound)
prepared = session.prepare(
"""
SELECT * FROM cf0 WHERE a=?
""")
self.assertIsInstance(prepared, PreparedStatement)
bound = prepared.bind({'a': 'x'})
results = session.execute(bound)
self.assertEquals(results, [('x', 'y', 'z')])
def test_missing_primary_key(self):
"""
Ensure an InvalidRequest is thrown
@@ -87,6 +115,25 @@ class PreparedStatementTests(unittest.TestCase):
bound = prepared.bind((1,))
self.assertRaises(InvalidRequest, session.execute, bound)
def test_missing_primary_key_dicts(self):
"""
Ensure an InvalidRequest is thrown
when prepared statements are missing the primary key
with dict bindings
"""
cluster = Cluster()
session = cluster.connect()
# The INSERT names only column v, so the dict below binds a single value.
prepared = session.prepare(
"""
INSERT INTO test3rf.test (v) VALUES (?)
""")
self.assertIsInstance(prepared, PreparedStatement)
# bind() is called outside assertRaises: the error is expected from
# session.execute, not from binding.
bound = prepared.bind({'v': 1})
self.assertRaises(InvalidRequest, session.execute, bound)
def test_too_many_bind_values(self):
"""
Ensure a ValueError is thrown when attempting to bind too many variables
@@ -103,6 +150,27 @@ class PreparedStatementTests(unittest.TestCase):
self.assertIsInstance(prepared, PreparedStatement)
self.assertRaises(ValueError, prepared.bind, (1,2))
def test_too_many_bind_values_dicts(self):
"""
Ensure a ValueError is thrown when attempting to bind too many variables
with dict bindings
"""
cluster = Cluster()
session = cluster.connect()
# The statement has a single bind marker (column v).
prepared = session.prepare(
"""
INSERT INTO test3rf.test (v) VALUES (?)
""")
self.assertIsInstance(prepared, PreparedStatement)
# Two dict keys for one bind marker: bind() is expected to raise ValueError.
self.assertRaises(ValueError, prepared.bind, {'k': 1, 'v': 2})
# also catch too few variables with dicts
# An empty dict is expected to raise KeyError (not ValueError).
# NOTE(review): this second assertIsInstance repeats the check made above.
self.assertIsInstance(prepared, PreparedStatement)
self.assertRaises(KeyError, prepared.bind, {})
def test_none_values(self):
"""
Ensure binding None is handled correctly
@@ -130,6 +198,35 @@ class PreparedStatementTests(unittest.TestCase):
results = session.execute(bound)
self.assertEquals(results[0].v, None)
def test_none_values_dicts(self):
"""
Ensure binding None is handled correctly with dict bindings
"""
cluster = Cluster()
session = cluster.connect()
# test with new dict binding
# Write a row with k=1 and v bound to None.
prepared = session.prepare(
"""
INSERT INTO test3rf.test (k, v) VALUES (?, ?)
""")
self.assertIsInstance(prepared, PreparedStatement)
bound = prepared.bind({'k': 1, 'v': None})
session.execute(bound)
# Read the same key back, again binding through a dict.
prepared = session.prepare(
"""
SELECT * FROM test3rf.test WHERE k=?
""")
self.assertIsInstance(prepared, PreparedStatement)
bound = prepared.bind({'k': 1})
results = session.execute(bound)
# v of the first returned row must come back as None.
self.assertEquals(results[0].v, None)
def test_async_binding(self):
"""
Ensure None binding over async queries
@@ -156,3 +253,31 @@ class PreparedStatementTests(unittest.TestCase):
future = session.execute_async(prepared, (873,))
results = future.result()
self.assertEquals(results[0].v, None)
def test_async_binding_dicts(self):
"""
Ensure None binding over async queries with dict bindings
"""
cluster = Cluster()
session = cluster.connect()
prepared = session.prepare(
"""
INSERT INTO test3rf.test (k, v) VALUES (?, ?)
""")
self.assertIsInstance(prepared, PreparedStatement)
# The dict is handed to execute_async directly instead of calling bind() first.
future = session.execute_async(prepared, {'k': 873, 'v': None})
# Block on the insert's future before issuing the read.
future.result()
prepared = session.prepare(
"""
SELECT * FROM test3rf.test WHERE k=?
""")
self.assertIsInstance(prepared, PreparedStatement)
future = session.execute_async(prepared, {'k': 873})
results = future.result()
# v of the first returned row must come back as None.
self.assertEquals(results[0].v, None)

View File

@@ -104,13 +104,12 @@ class MockConnection(object):
[["192.168.1.1", "10.0.0.1", "a", "dc1", "rack1", ["1", "101", "201"]],
["192.168.1.2", "10.0.0.2", "a", "dc1", "rack1", ["2", "102", "202"]]]
]
def wait_for_responses(self, peer_query, local_query, timeout=None):
local_response = ResultMessage(
kind=RESULT_KIND_ROWS, results=self.local_results)
peer_response = ResultMessage(
kind=RESULT_KIND_ROWS, results=self.peer_results)
return (peer_response, local_response)
self.wait_for_responses = Mock(return_value=(peer_response, local_response))
class FakeTime(object):
@@ -136,6 +135,38 @@ class ControlConnectionTest(unittest.TestCase):
self.control_connection._connection = self.connection
self.control_connection._time = self.time
def _get_matching_schema_preloaded_results(self):
# Test fixture: returns a (peer_response, local_response) tuple in which the
# local row and both peer rows carry schema_version "a".
# Each results value is a two-item list: the column names, then the rows.
local_results = [
["schema_version", "cluster_name", "data_center", "rack", "partitioner", "tokens"],
[["a", "foocluster", "dc1", "rack1", "Murmur3Partitioner", ["0", "100", "200"]]]
]
local_response = ResultMessage(kind=RESULT_KIND_ROWS, results=local_results)
peer_results = [
["rpc_address", "peer", "schema_version", "data_center", "rack", "tokens"],
[["192.168.1.1", "10.0.0.1", "a", "dc1", "rack1", ["1", "101", "201"]],
["192.168.1.2", "10.0.0.2", "a", "dc1", "rack1", ["2", "102", "202"]]]
]
peer_response = ResultMessage(kind=RESULT_KIND_ROWS, results=peer_results)
# Peers come first in the tuple, then local.
return (peer_response, local_response)
def _get_nonmatching_schema_preloaded_results(self):
# Test fixture: returns a (peer_response, local_response) tuple in which the
# local row and the first peer carry schema_version "a" while the second
# peer carries "b".
# Each results value is a two-item list: the column names, then the rows.
local_results = [
["schema_version", "cluster_name", "data_center", "rack", "partitioner", "tokens"],
[["a", "foocluster", "dc1", "rack1", "Murmur3Partitioner", ["0", "100", "200"]]]
]
local_response = ResultMessage(kind=RESULT_KIND_ROWS, results=local_results)
peer_results = [
["rpc_address", "peer", "schema_version", "data_center", "rack", "tokens"],
[["192.168.1.1", "10.0.0.1", "a", "dc1", "rack1", ["1", "101", "201"]],
["192.168.1.2", "10.0.0.2", "b", "dc1", "rack1", ["2", "102", "202"]]]
]
peer_response = ResultMessage(kind=RESULT_KIND_ROWS, results=peer_results)
# Peers come first in the tuple, then local.
return (peer_response, local_response)
def test_wait_for_schema_agreement(self):
"""
Basic test with all schema versions agreeing
@@ -144,6 +175,29 @@ class ControlConnectionTest(unittest.TestCase):
# the control connection should not have slept at all
self.assertEqual(self.time.clock, 0)
def test_wait_for_schema_agreement_uses_preloaded_results_if_given(self):
"""
wait_for_schema_agreement uses preloaded results if given for shared table queries
"""
# Every row in this fixture reports the same schema version.
preloaded_results = self._get_matching_schema_preloaded_results()
# No connection argument is passed here; only the preloaded results are supplied.
self.assertTrue(self.control_connection.wait_for_schema_agreement(preloaded_results=preloaded_results))
# the control connection should not have slept at all
self.assertEqual(self.time.clock, 0)
# the connection should not have made any queries if given preloaded results
self.assertEqual(self.connection.wait_for_responses.call_count, 0)
def test_wait_for_schema_agreement_falls_back_to_querying_if_schemas_dont_match_preloaded_result(self):
"""
wait_for_schema_agreement requery if schema does not match using preloaded results
"""
# In this fixture one peer reports a different schema version from the rest.
preloaded_results = self._get_nonmatching_schema_preloaded_results()
# Agreement is still expected overall; presumably the mocked connection's
# responses agree once they are queried -- confirm against MockConnection.
self.assertTrue(self.control_connection.wait_for_schema_agreement(preloaded_results=preloaded_results))
# the control connection should not have slept at all
self.assertEqual(self.time.clock, 0)
# Exactly one query round is expected after the preloaded results are rejected.
self.assertEqual(self.connection.wait_for_responses.call_count, 1)
def test_wait_for_schema_agreement_fails(self):
"""
Make sure the control connection sleeps and retries
@@ -211,6 +265,32 @@ class ControlConnectionTest(unittest.TestCase):
self.assertEqual(host.datacenter, "dc1")
self.assertEqual(host.rack, "rack1")
self.assertEqual(self.connection.wait_for_responses.call_count, 1)
def test_refresh_nodes_and_tokens_uses_preloaded_results_if_given(self):
"""
refresh_nodes_and_tokens uses preloaded results if given for shared table queries
"""
preloaded_results = self._get_matching_schema_preloaded_results()
# Call the private refresh directly, handing it the fixture results.
self.control_connection._refresh_node_list_and_token_map(self.connection, preloaded_results=preloaded_results)
meta = self.cluster.metadata
# These values are the ones in the fixture's local row.
self.assertEqual(meta.partitioner, 'Murmur3Partitioner')
self.assertEqual(meta.cluster_name, 'foocluster')
# check token map
self.assertEqual(sorted(meta.all_hosts()), sorted(meta.token_map.keys()))
# Every row in the fixture lists three tokens.
for token_list in meta.token_map.values():
self.assertEqual(3, len(token_list))
# check datacenter/rack
for host in meta.all_hosts():
self.assertEqual(host.datacenter, "dc1")
self.assertEqual(host.rack, "rack1")
# the connection should not have made any queries if given preloaded results
self.assertEqual(self.connection.wait_for_responses.call_count, 0)
def test_refresh_nodes_and_tokens_no_partitioner(self):
"""
Test handling of an unknown partitioner.