Merge branch 'master' into 2.0

This commit is contained in:
Tyler Hobbs
2014-02-28 13:16:43 -06:00
23 changed files with 716 additions and 224 deletions

1
.gitignore vendored
View File

@@ -14,3 +14,4 @@ nosetests.xml
cover/
docs/_build/
tests/integration/ccm
setuptools*.tar.gz

View File

@@ -1,6 +1,38 @@
1.0.2
=====
In Progress
Bug Fixes
---------
* With asyncorereactor, correctly handle EAGAIN/EWOULDBLOCK when the message from
Cassandra is a multiple of the read buffer size. Previously, if no more data
became available to read on the socket, the message would never be processed,
resulting in an OperationTimedOut error.
* Double quote keyspace, table and column names that require them (those using
uppercase characters or keywords) when generating CREATE statements through
KeyspaceMetadata and TableMetadata.
* Decode TimestampType as DateType. (Cassandra replaced DateType with
TimestampType to fix sorting of pre-unix epoch dates in CASSANDRA-5723.)
* Handle latest table options when parsing the schema and generating
CREATE statements.
* Avoid 'Set changed size during iteration' during query plan generation
when hosts go up or down
Other
-----
* Remove ignored ``tracing_enabled`` parameter for ``SimpleStatement``. The
correct way to trace a query is by setting the ``trace`` argument to ``True``
in ``Session.execute()`` and ``Session.execute_async()``.
* Raise TypeError instead of cassandra.query.InvalidParameterTypeError when
a parameter for a prepared statement has the wrong type; remove
cassandra.query.InvalidParameterTypeError.
* More consistent type checking for query parameters
* Add option to a return special object for empty string values for non-string
columns
1.0.1
=====
(In Progress)
Feb 19, 2014
Bug Fixes
---------
@@ -9,12 +41,24 @@ Bug Fixes
* Always close socket when defuncting error'ed connections to avoid a potential
file descriptor leak
* Handle "custom" types (such as the replaced DateType) correctly
* With libevreactor, correctly handle EAGAIN/EWOULDBLOCK when the message from
Cassandra is a multiple of the read buffer size. Previously, if no more data
became available to read on the socket, the message would never be processed,
resulting in an OperationTimedOut error.
* Don't break tracing when a Session's row_factory is not the default
namedtuple_factory.
* Handle data that is already utf8-encoded for UTF8Type values
* Fix token-aware routing for tokens that fall before the first node token in
the ring and tokens that exactly match a node's token
* Tolerate null source_elapsed values for Trace events. These may not be
set when events complete after the main operation has already completed.
Other
-----
* Skip sending OPTIONS message on connection creation if compression is
disabled or not available and a CQL version has not been explicitly
set
* Add details about errors and the last queried host to ``OperationTimedOut``
1.0.0 Final
===========

View File

@@ -16,12 +16,6 @@ Releasing
so that it looks like ``(x, y, z, 'post')``
* Commit and push
Running the Tests
=================
In order for the extensions to be built and used in the test, run::
python setup.py nosetests
Building the Docs
=================
Sphinx is required to build the docs. You probably want to install through apt,
@@ -48,3 +42,26 @@ For example::
cp -R docs/_build/1.0.0-beta1/* ~/python-driver-docs/
cd ~/python-driver-docs
git push origin gh-pages
Running the Tests
=================
In order for the extensions to be built and used in the test, run::
python setup.py nosetests
You can run a specific test module or package like so::
python setup.py nosetests -w tests/unit/
If you want to test all of python 2.6, 2.7, and pypy, use tox (this is what
TravisCI runs)::
tox
By default, tox only runs the unit tests because I haven't put in the effort
to get the integration tests to run on TravicCI. However, the integration
tests should work locally. To run them, edit the following line in tox.ini::
commands = {envpython} setup.py build_ext --inplace nosetests --verbosity=2 tests/unit/
and change ``tests/unit/`` to ``tests/``.

View File

@@ -9,7 +9,7 @@ class NullHandler(logging.Handler):
logging.getLogger('cassandra').addHandler(NullHandler())
__version_info__ = (1, 0, 0, 'post')
__version_info__ = (1, 0, 1, 'post')
__version__ = '.'.join(map(str, __version_info__))
@@ -226,4 +226,19 @@ class OperationTimedOut(Exception):
to complete. This is not an error generated by Cassandra, only
the driver.
"""
pass
errors = None
"""
A dict of errors keyed by the :class:`~.Host` against which they occurred.
"""
last_host = None
"""
The last :class:`~.Host` this operation was attempted against.
"""
def __init__(self, errors=None, last_host=None):
self.errors = errors
self.last_host = last_host
message = "errors=%s, last_host=%s" % (self.errors, self.last_host)
Exception.__init__(self, message)

View File

@@ -858,8 +858,8 @@ class Cluster(object):
"statement on host %s: %r", host, response)
log.debug("Done preparing all known prepared statements against host %s", host)
except OperationTimedOut:
log.warn("Timed out trying to prepare all statements on host %s", host)
except OperationTimedOut as timeout:
log.warn("Timed out trying to prepare all statements on host %s: %s", host, timeout)
except (ConnectionException, socket.error) as exc:
log.warn("Error trying to prepare all statements on host %s: %r", host, exc)
except Exception:
@@ -1045,6 +1045,12 @@ class Session(object):
... log.exception("Operation failed:")
"""
future = self._create_response_future(query, parameters, trace)
future.send_request()
return future
def _create_response_future(self, query, parameters, trace):
""" Returns the ResponseFuture before calling send_request() on it """
prepared_statement = None
if isinstance(query, basestring):
query = SimpleStatement(query)
@@ -1066,11 +1072,9 @@ class Session(object):
if trace:
message.tracing = True
future = ResponseFuture(
return ResponseFuture(
self, message, query, self.default_timeout, metrics=self._metrics,
prepared_statement=prepared_statement)
future.send_request()
return future
def prepare(self, query):
"""
@@ -1650,8 +1654,9 @@ class ControlConnection(object):
timeout = min(2.0, total_timeout - elapsed)
peers_result, local_result = connection.wait_for_responses(
peers_query, local_query, timeout=timeout)
except OperationTimedOut:
log.debug("[control connection] Timed out waiting for response during schema agreement check")
except OperationTimedOut as timeout:
log.debug("[control connection] Timed out waiting for " \
"response during schema agreement check: %s", timeout)
elapsed = self._time.time() - start
continue
@@ -2195,7 +2200,7 @@ class ResponseFuture(object):
elif self._final_exception:
raise self._final_exception
else:
raise OperationTimedOut()
raise OperationTimedOut(errors=self._errors, last_host=self._current_host)
def get_query_trace(self, max_wait=None):
"""

View File

@@ -152,12 +152,35 @@ def lookup_casstype(casstype):
raise ValueError("Don't know how to parse type string %r: %s" % (casstype, e))
class EmptyValue(object):
""" See _CassandraType.support_empty_values """
def __str__(self):
return "EMPTY"
__repr__ = __str__
EMPTY = EmptyValue()
class _CassandraType(object):
__metaclass__ = CassandraTypeType
subtypes = ()
num_subtypes = 0
empty_binary_ok = False
support_empty_values = False
"""
Back in the Thrift days, empty strings were used for "null" values of
all types, including non-string types. For most users, an empty
string value in an int column is the same as being null/not present,
so the driver normally returns None in this case. (For string-like
types, it *will* return an empty string by default instead of None.)
To avoid this behavior, set this to :const:`True`. Instead of returning
None for empty string values, the EMPTY singleton (an instance
of EmptyValue) will be returned.
"""
def __init__(self, val):
self.val = self.validate(val)
@@ -181,8 +204,10 @@ class _CassandraType(object):
for more information. This method differs in that if None or the empty
string is passed in, None may be returned.
"""
if byts is None or (byts == '' and not cls.empty_binary_ok):
if byts is None:
return None
elif byts == '' and not cls.empty_binary_ok:
return EMPTY if cls.support_empty_values else None
return cls.deserialize(byts)
@classmethod
@@ -315,7 +340,10 @@ class DecimalType(_CassandraType):
@staticmethod
def serialize(dec):
sign, digits, exponent = dec.as_tuple()
try:
sign, digits, exponent = dec.as_tuple()
except AttributeError:
raise TypeError("Non-Decimal type received for Decimal value")
unscaled = int(''.join([str(digit) for digit in digits]))
if sign:
unscaled *= -1
@@ -333,7 +361,10 @@ class UUIDType(_CassandraType):
@staticmethod
def serialize(uuid):
return uuid.bytes
try:
return uuid.bytes
except AttributeError:
raise TypeError("Got a non-UUID object for a UUID value")
class BooleanType(_CassandraType):
@@ -349,7 +380,7 @@ class BooleanType(_CassandraType):
@staticmethod
def serialize(truth):
return int8_pack(bool(truth))
return int8_pack(truth)
class AsciiType(_CassandraType):
@@ -503,6 +534,10 @@ class DateType(_CassandraType):
return int64_pack(long(converted))
class TimestampType(DateType):
pass
class TimeUUIDType(DateType):
typename = 'timeuuid'
@@ -515,7 +550,10 @@ class TimeUUIDType(DateType):
@staticmethod
def serialize(timeuuid):
return timeuuid.bytes
try:
return timeuuid.bytes
except AttributeError:
raise TypeError("Got a non-UUID object for a UUID value")
class UTF8Type(_CassandraType):
@@ -528,7 +566,11 @@ class UTF8Type(_CassandraType):
@staticmethod
def serialize(ustr):
return ustr.encode('utf8')
try:
return ustr.encode('utf-8')
except UnicodeDecodeError:
# already utf-8
return ustr
class VarcharType(UTF8Type):
@@ -578,6 +620,9 @@ class _SimpleParameterizedType(_ParameterizedType):
@classmethod
def serialize_safe(cls, items):
if isinstance(items, basestring):
raise TypeError("Received a string for a type that expects a sequence")
subtype, = cls.subtypes
buf = StringIO()
buf.write(uint16_pack(len(items)))
@@ -634,7 +679,11 @@ class MapType(_ParameterizedType):
subkeytype, subvaltype = cls.subtypes
buf = StringIO()
buf.write(uint16_pack(len(themap)))
for key, val in themap.iteritems():
try:
items = themap.iteritems()
except AttributeError:
raise TypeError("Got a non-map object for a map value")
for key, val in items:
keybytes = subkeytype.to_binary(key)
valbytes = subvaltype.to_binary(val)
buf.write(uint16_pack(len(keybytes)))

View File

@@ -260,7 +260,7 @@ class AsyncoreConnection(Connection, asyncore.dispatcher):
except socket.error as err:
if err.args[0] not in NONBLOCKING:
self.defunct(err)
return
return
if self._iobuf.tell():
while True:

View File

@@ -313,7 +313,7 @@ class LibevConnection(Connection):
except socket.error as err:
if err.args[0] not in NONBLOCKING:
self.defunct(err)
return
return
if self._iobuf.tell():
while True:

View File

@@ -1,4 +1,4 @@
from bisect import bisect_left
from bisect import bisect_right
from collections import defaultdict
try:
from collections import OrderedDict
@@ -274,13 +274,13 @@ class Metadata(object):
column_meta = self._build_column_metadata(table_meta, col_row)
table_meta.columns[column_meta.name] = column_meta
table_meta.options = self._build_table_options(row, is_compact)
table_meta.options = self._build_table_options(row)
table_meta.is_compact_storage = is_compact
return table_meta
def _build_table_options(self, row, is_compact_storage):
def _build_table_options(self, row):
""" Setup the mostly-non-schema table options, like caching settings """
options = dict((o, row.get(o)) for o in TableMetadata.recognized_options)
options["is_compact_storage"] = is_compact_storage
options = dict((o, row.get(o)) for o in TableMetadata.recognized_options if o in row)
return options
def _build_column_metadata(self, table_metadata, row):
@@ -564,9 +564,10 @@ class KeyspaceMetadata(object):
return "\n".join([self.as_cql_query()] + [t.export_as_string() for t in self.tables.values()])
def as_cql_query(self):
ret = "CREATE KEYSPACE %s WITH REPLICATION = %s " % \
(self.name, self.replication_strategy.export_for_schema())
return ret + (' AND DURABLE_WRITES = %s;' % ("true" if self.durable_writes else "false"))
ret = "CREATE KEYSPACE %s WITH replication = %s " % (
protect_name(self.name),
self.replication_strategy.export_for_schema())
return ret + (' AND durable_writes = %s;' % ("true" if self.durable_writes else "false"))
class TableMetadata(object):
@@ -610,6 +611,8 @@ class TableMetadata(object):
A dict mapping column names to :class:`.ColumnMetadata` instances.
"""
is_compact_storage = False
options = None
"""
A dict mapping table option names to their specific settings for this
@@ -617,11 +620,28 @@ class TableMetadata(object):
"""
recognized_options = (
"comment", "read_repair_chance", # "local_read_repair_chance",
"replicate_on_write", "gc_grace_seconds", "bloom_filter_fp_chance",
"caching", "compaction_strategy_class", "compaction_strategy_options",
"min_compaction_threshold", "max_compression_threshold",
"compression_parameters")
"comment",
"read_repair_chance",
"dclocal_read_repair_chance",
"replicate_on_write",
"gc_grace_seconds",
"bloom_filter_fp_chance",
"caching",
"compaction_strategy_class",
"compaction_strategy_options",
"min_compaction_threshold",
"max_compression_threshold",
"compression_parameters",
"min_index_interval",
"max_index_interval",
"index_interval",
"speculative_retry",
"rows_per_partition_to_cache",
"memtable_flush_period_in_ms",
"populate_io_cache_on_flush",
"compaction",
"compression",
"default_time_to_live")
def __init__(self, keyspace_metadata, name, partition_key=None, clustering_key=None, columns=None, options=None):
self.keyspace = keyspace_metadata
@@ -653,7 +673,10 @@ class TableMetadata(object):
creations are not included). If `formatted` is set to :const:`True`,
extra whitespace will be added to make the query human readable.
"""
ret = "CREATE TABLE %s.%s (%s" % (self.keyspace.name, self.name, "\n" if formatted else "")
ret = "CREATE TABLE %s.%s (%s" % (
protect_name(self.keyspace.name),
protect_name(self.name),
"\n" if formatted else "")
if formatted:
column_join = ",\n"
@@ -664,7 +687,7 @@ class TableMetadata(object):
columns = []
for col in self.columns.values():
columns.append("%s %s" % (col.name, col.typestring))
columns.append("%s %s" % (protect_name(col.name), col.typestring))
if len(self.partition_key) == 1 and not self.clustering_key:
columns[0] += " PRIMARY KEY"
@@ -676,12 +699,12 @@ class TableMetadata(object):
ret += "%s%sPRIMARY KEY (" % (column_join, padding)
if len(self.partition_key) > 1:
ret += "(%s)" % ", ".join(col.name for col in self.partition_key)
ret += "(%s)" % ", ".join(protect_name(col.name) for col in self.partition_key)
else:
ret += self.partition_key[0].name
if self.clustering_key:
ret += ", %s" % ", ".join(col.name for col in self.clustering_key)
ret += ", %s" % ", ".join(protect_name(col.name) for col in self.clustering_key)
ret += ")"
@@ -689,15 +712,15 @@ class TableMetadata(object):
ret += "%s) WITH " % ("\n" if formatted else "")
option_strings = []
if self.options.get("is_compact_storage"):
if self.is_compact_storage:
option_strings.append("COMPACT STORAGE")
if self.clustering_key:
cluster_str = "CLUSTERING ORDER BY "
clustering_names = self.protect_names([c.name for c in self.clustering_key])
clustering_names = protect_names([c.name for c in self.clustering_key])
if self.options.get("is_compact_storage") and \
if self.is_compact_storage and \
not issubclass(self.comparator, types.CompositeType):
subtypes = [self.comparator]
else:
@@ -711,52 +734,61 @@ class TableMetadata(object):
cluster_str += "(%s)" % ", ".join(inner)
option_strings.append(cluster_str)
option_strings.extend(map(self._make_option_str, self.recognized_options))
option_strings = filter(lambda x: x is not None, option_strings)
option_strings.extend(self._make_option_strings())
join_str = "\n AND " if formatted else " AND "
ret += join_str.join(option_strings)
return ret
def _make_option_str(self, name):
value = self.options.get(name)
if value is not None:
if name == "comment":
value = value or ""
return "%s = %s" % (name, self.protect_value(value))
def _make_option_strings(self):
ret = []
for name, value in sorted(self.options.items()):
if value is not None:
if name == "comment":
value = value or ""
ret.append("%s = %s" % (name, protect_value(value)))
def protect_name(self, name):
if isinstance(name, unicode):
name = name.encode('utf8')
return self.maybe_escape_name(name)
return ret
def protect_names(self, names):
return map(self.protect_name, names)
def protect_value(self, value):
if value is None:
return 'NULL'
if isinstance(value, (int, float, bool)):
return str(value)
return "'%s'" % value.replace("'", "''")
def protect_name(name):
if isinstance(name, unicode):
name = name.encode('utf8')
return maybe_escape_name(name)
valid_cql3_word_re = re.compile(r'^[a-z][0-9a-z_]*$')
def is_valid_name(self, name):
if name is None:
return False
if name.lower() in _keywords - _unreserved_keywords:
return False
return self.valid_cql3_word_re.match(name) is not None
def protect_names(names):
return map(protect_name, names)
def maybe_escape_name(self, name):
if self.is_valid_name(name):
return name
return self.escape_name(name)
def escape_name(self, name):
return '"%s"' % (name.replace('"', '""'),)
def protect_value(value):
if value is None:
return 'NULL'
if isinstance(value, (int, float, bool)):
return str(value)
return "'%s'" % value.replace("'", "''")
valid_cql3_word_re = re.compile(r'^[a-z][0-9a-z_]*$')
def is_valid_name(name):
if name is None:
return False
if name.lower() in _keywords - _unreserved_keywords:
return False
return valid_cql3_word_re.match(name) is not None
def maybe_escape_name(name):
if is_valid_name(name):
return name
return escape_name(name)
def escape_name(name):
return '"%s"' % (name.replace('"', '""'),)
class ColumnMetadata(object):
@@ -825,7 +857,11 @@ class IndexMetadata(object):
Returns a CQL query that can be used to recreate this index.
"""
table = self.column.table
return "CREATE INDEX %s ON %s.%s (%s)" % (self.name, table.keyspace.name, table.name, self.column.name)
return "CREATE INDEX %s ON %s.%s (%s)" % (
self.name, # Cassandra doesn't like quoted index names for some reason
protect_name(table.keyspace.name),
protect_name(table.name),
protect_name(self.column.name))
class TokenMap(object):
@@ -890,10 +926,11 @@ class TokenMap(object):
if tokens_to_hosts is None:
return []
point = bisect_left(self.ring, token)
if point == 0 and token != self.ring[0]:
return tokens_to_hosts[self.ring[-1]]
elif point == len(self.ring):
# token range ownership is exclusive on the LHS (the start token), so
# we use bisect_right, which, in the case of a tie/exact match,
# picks an insertion point to the right of the existing match
point = bisect_right(self.ring, token)
if point == len(self.ring):
return tokens_to_hosts[self.ring[0]]
else:
return tokens_to_hosts[self.ring[point]]

View File

@@ -1,6 +1,7 @@
from itertools import islice, cycle, groupby, repeat
import logging
from random import randint
from threading import Lock
from cassandra import ConsistencyLevel
@@ -78,6 +79,11 @@ class LoadBalancingPolicy(HostStateListener):
custom behavior.
"""
_hosts_lock = None
def __init__(self):
self._hosts_lock = Lock()
def distance(self, host):
"""
Returns a measure of how remote a :class:`~.pool.Host` is in
@@ -130,7 +136,7 @@ class RoundRobinPolicy(LoadBalancingPolicy):
"""
def populate(self, cluster, hosts):
self._live_hosts = set(hosts)
self._live_hosts = frozenset(hosts)
if len(hosts) <= 1:
self._position = 0
else:
@@ -145,24 +151,29 @@ class RoundRobinPolicy(LoadBalancingPolicy):
pos = self._position
self._position += 1
length = len(self._live_hosts)
hosts = self._live_hosts
length = len(hosts)
if length:
pos %= length
return list(islice(cycle(self._live_hosts), pos, pos + length))
return list(islice(cycle(hosts), pos, pos + length))
else:
return []
def on_up(self, host):
self._live_hosts.add(host)
with self._hosts_lock:
self._live_hosts = self._live_hosts.union((host, ))
def on_down(self, host):
self._live_hosts.discard(host)
with self._hosts_lock:
self._live_hosts = self._live_hosts.difference((host, ))
def on_add(self, host):
self._live_hosts.add(host)
with self._hosts_lock:
self._live_hosts = self._live_hosts.union((host, ))
def on_remove(self, host):
self._live_hosts.remove(host)
with self._hosts_lock:
self._live_hosts = self._live_hosts.difference((host, ))
class DCAwareRoundRobinPolicy(LoadBalancingPolicy):
@@ -191,13 +202,14 @@ class DCAwareRoundRobinPolicy(LoadBalancingPolicy):
self.local_dc = local_dc
self.used_hosts_per_remote_dc = used_hosts_per_remote_dc
self._dc_live_hosts = {}
LoadBalancingPolicy.__init__(self)
def _dc(self, host):
return host.datacenter or self.local_dc
def populate(self, cluster, hosts):
for dc, dc_hosts in groupby(hosts, lambda h: self._dc(h)):
self._dc_live_hosts[dc] = set(dc_hosts)
self._dc_live_hosts[dc] = frozenset(dc_hosts)
# position is currently only used for local hosts
local_live = self._dc_live_hosts.get(self.local_dc)
@@ -244,16 +256,28 @@ class DCAwareRoundRobinPolicy(LoadBalancingPolicy):
yield host
def on_up(self, host):
self._dc_live_hosts.setdefault(self._dc(host), set()).add(host)
dc = self._dc(host)
with self._hosts_lock:
current_hosts = self._dc_live_hosts.setdefault(dc, frozenset())
self._dc_live_hosts[dc] = current_hosts.union((host, ))
def on_down(self, host):
self._dc_live_hosts.setdefault(self._dc(host), set()).discard(host)
dc = self._dc(host)
with self._hosts_lock:
current_hosts = self._dc_live_hosts.setdefault(dc, frozenset())
self._dc_live_hosts[dc] = current_hosts.difference((host, ))
def on_add(self, host):
self._dc_live_hosts.setdefault(self._dc(host), set()).add(host)
dc = self._dc(host)
with self._hosts_lock:
current_hosts = self._dc_live_hosts.setdefault(dc, frozenset())
self._dc_live_hosts[dc] = current_hosts.union((host, ))
def on_remove(self, host):
self._dc_live_hosts.setdefault(self._dc(host), set()).discard(host)
dc = self._dc(host)
with self._hosts_lock:
current_hosts = self._dc_live_hosts.setdefault(dc, frozenset())
self._dc_live_hosts[dc] = current_hosts.difference((host, ))
class TokenAwarePolicy(LoadBalancingPolicy):
@@ -351,12 +375,10 @@ class WhiteListRoundRobinPolicy(RoundRobinPolicy):
:param hosts: List of hosts
"""
self._allowed_hosts = hosts
RoundRobinPolicy.__init__(self)
def populate(self, cluster, hosts):
self._live_hosts = set()
for host in hosts:
if host.address in self._allowed_hosts:
self._live_hosts.add(host)
self._live_hosts = frozenset(h for h in hosts if h.address in self._allowed_hosts)
if len(hosts) <= 1:
self._position = 0
@@ -371,14 +393,11 @@ class WhiteListRoundRobinPolicy(RoundRobinPolicy):
def on_up(self, host):
if host.address in self._allowed_hosts:
self._live_hosts.add(host)
RoundRobinPolicy.on_up(self, host)
def on_add(self, host):
if host.address in self._allowed_hosts:
self._live_hosts.add(host)
def on_remove(self, host):
self._live_hosts.discard(host)
RoundRobinPolicy.on_add(self, host)
class ConvictionPolicy(object):

View File

@@ -8,10 +8,10 @@ from datetime import datetime, timedelta
import struct
import time
from cassandra import ConsistencyLevel
from cassandra import ConsistencyLevel, OperationTimedOut
from cassandra.cqltypes import unix_time_from_uuid1
from cassandra.decoder import (cql_encoders, cql_encode_object,
cql_encode_sequence)
cql_encode_sequence, named_tuple_factory)
import logging
log = logging.getLogger(__name__)
@@ -45,10 +45,8 @@ class Statement(object):
_routing_key = None
def __init__(self, retry_policy=None, tracing_enabled=False,
consistency_level=None, routing_key=None):
def __init__(self, retry_policy=None, consistency_level=None, routing_key=None):
self.retry_policy = retry_policy
self.tracing_enabled = tracing_enabled
if consistency_level is not None:
self.consistency_level = consistency_level
self._routing_key = routing_key
@@ -240,10 +238,9 @@ class BoundStatement(Statement):
expected_type = col_type
actual_type = type(value)
err = InvalidParameterTypeError(col_name=col_name,
expected_type=expected_type,
actual_type=actual_type)
raise err
message = ('Received an argument of invalid type for column "%s". '
'Expected: %s, Got: %s' % (col_name, expected_type, actual_type))
raise TypeError(message)
return self
@@ -321,24 +318,6 @@ class TraceUnavailable(Exception):
pass
class InvalidParameterTypeError(TypeError):
"""
Raised when a used tries to bind a prepared statement with an argument of an
invalid type.
"""
def __init__(self, col_name, expected_type, actual_type):
self.col_name = col_name
self.expected_type = expected_type
self.actual_type = actual_type
values = (self.col_name, self.expected_type, self.actual_type)
message = ('Received an argument of invalid type for column "%s". '
'Expected: %s, Got: %s' % values)
super(InvalidParameterTypeError, self).__init__(message)
class QueryTrace(object):
"""
A trace of the duration and events that occurred when executing
@@ -409,9 +388,13 @@ class QueryTrace(object):
attempt = 0
start = time.time()
while True:
if max_wait is not None and time.time() - start >= max_wait:
time_spent = time.time() - start
if max_wait is not None and time_spent >= max_wait:
raise TraceUnavailable("Trace information was not available within %f seconds" % (max_wait,))
session_results = self._session.execute(self._SELECT_SESSIONS_FORMAT, (self.trace_id,))
session_results = self._execute(
self._SELECT_SESSIONS_FORMAT, (self.trace_id,), time_spent, max_wait)
if not session_results or session_results[0].duration is None:
time.sleep(self._BASE_RETRY_SLEEP * (2 ** attempt))
attempt += 1
@@ -424,11 +407,25 @@ class QueryTrace(object):
self.coordinator = session_row.coordinator
self.parameters = session_row.parameters
event_results = self._session.execute(self._SELECT_EVENTS_FORMAT, (self.trace_id,))
time_spent = time.time() - start
event_results = self._execute(
self._SELECT_EVENTS_FORMAT, (self.trace_id,), time_spent, max_wait)
self.events = tuple(TraceEvent(r.activity, r.event_id, r.source, r.source_elapsed, r.thread)
for r in event_results)
break
def _execute(self, query, parameters, time_spent, max_wait):
# in case the user switched the row factory, set it to namedtuple for this query
future = self._session._create_response_future(query, parameters, trace=False)
future.row_factory = named_tuple_factory
future.send_request()
timeout = (max_wait - time_spent) if max_wait is not None else None
try:
return future.result(timeout=timeout)
except OperationTimedOut:
raise TraceUnavailable("Trace information was not available within %f seconds" % (max_wait,))
def __str__(self):
return "%s [%s] coordinator: %s, started at: %s, duration: %s, parameters: %s" \
% (self.request_type, self.trace_id, self.coordinator, self.started_at,
@@ -471,7 +468,10 @@ class TraceEvent(object):
self.description = description
self.datetime = datetime.utcfromtimestamp(unix_time_from_uuid1(timeuuid))
self.source = source
self.source_elapsed = timedelta(microseconds=source_elapsed)
if source_elapsed is not None:
self.source_elapsed = timedelta(microseconds=source_elapsed)
else:
self.source_elapsed = None
self.thread_name = thread_name
def __str__(self):

View File

@@ -1,9 +1,6 @@
API Documentation
=================
Cassandra Modules
-----------------
.. toctree::
:maxdepth: 2

View File

@@ -123,7 +123,7 @@ when you execute:
It is translated to the following CQL query:
.. code-block:: SQL
.. code-block::
INSERT INTO users (name, credits, user_id)
VALUES ('John O''Reilly', 42, 2644bada-852c-11e3-89fb-e0b9a54a6d93)
@@ -297,7 +297,7 @@ There are a few important things to remember when working with callbacks:
Setting a Consistency Level
^^^^^^^^^^^^^^^^^^^^^^^^^^^
---------------------------
The consistency level used for a query determines how many of the
replicas of the data you are interacting with need to respond for
the query to be considered a success.
@@ -345,3 +345,34 @@ is different than for simple, non-prepared statements (although future versions
of the driver may use the same placeholders for both). Cassandra 2.0 added
support for named placeholders; the 1.0 version of the driver does not support
them, but the 2.0 version will.
Setting a Consistency Level with Prepared Statements
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
To specify a consistency level for prepared statements, you have two options.
The first is to set a default consistency level for every execution of the
prepared statement:
.. code-block:: python
from cassandra import ConsistencyLevel
cluster = Cluster()
session = cluster.connect("mykeyspace")
user_lookup_stmt = session.prepare("SELECT * FROM users WHERE user_id=?")
user_lookup_stmt.consistency_level = ConsistencyLevel.QUORUM
# these will both use QUORUM
user1 = session.execute(user_lookup_stmt, [user_id1])[0]
user2 = session.execute(user_lookup_stmt, [user_id2])[0]
The second option is to create a :class:`~.BoundStatement` from the
:class:`~.PreparedStatement` and binding paramaters and set a consistency
level on that:
.. code-block:: python
# override the QUORUM default
user3_lookup = user_lookup_stmt.bind([user_id3]
user3_lookup.consistency_level = ConsistencyLevel.ALL
user3 = session.execute(user3_lookup)

View File

@@ -1,8 +1,6 @@
Python Cassandra Driver
=======================
Contents:
.. toctree::
:maxdepth: 2

View File

@@ -31,6 +31,7 @@ class LargeDataTests(unittest.TestCase):
def make_session_and_keyspace(self):
cluster = Cluster()
session = cluster.connect()
session.default_timeout = 20.0 # increase the default timeout
session.row_factory = dict_factory
create_schema(session, self.keyspace)

View File

@@ -138,8 +138,8 @@ class SchemaMetadataTest(unittest.TestCase):
self.assertEqual([], tablemeta.clustering_key)
self.assertEqual([u'a', u'b', u'c'], sorted(tablemeta.columns.keys()))
for option in TableMetadata.recognized_options:
self.assertTrue(option in tablemeta.options)
for option in tablemeta.options:
self.assertIn(option, TableMetadata.recognized_options)
self.check_create_statement(tablemeta, create_statement)
@@ -295,7 +295,7 @@ class TestCodeCoverage(unittest.TestCase):
cluster = Cluster()
cluster.connect()
self.assertIsInstance(cluster.metadata.export_schema_as_string(), unicode)
self.assertIsInstance(cluster.metadata.export_schema_as_string(), basestring)
def test_export_keyspace_schema(self):
"""
@@ -307,8 +307,47 @@ class TestCodeCoverage(unittest.TestCase):
for keyspace in cluster.metadata.keyspaces:
keyspace_metadata = cluster.metadata.keyspaces[keyspace]
self.assertIsInstance(keyspace_metadata.export_as_string(), unicode)
self.assertIsInstance(keyspace_metadata.as_cql_query(), unicode)
self.assertIsInstance(keyspace_metadata.export_as_string(), basestring)
self.assertIsInstance(keyspace_metadata.as_cql_query(), basestring)
def test_case_sensitivity(self):
"""
Test that names that need to be escaped in CREATE statements are
"""
cluster = Cluster()
session = cluster.connect()
ksname = 'AnInterestingKeyspace'
cfname = 'AnInterestingTable'
session.execute("""
CREATE KEYSPACE "%s"
WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'}
""" % (ksname,))
session.execute("""
CREATE TABLE "%s"."%s" (
k int,
"A" int,
"B" int,
"MyColumn" int,
PRIMARY KEY (k, "A"))
WITH CLUSTERING ORDER BY ("A" DESC)
""" % (ksname, cfname))
session.execute("""
CREATE INDEX myindex ON "%s"."%s" ("MyColumn")
""" % (ksname, cfname))
ksmeta = cluster.metadata.keyspaces[ksname]
schema = ksmeta.export_as_string()
self.assertIn('CREATE KEYSPACE "AnInterestingKeyspace"', schema)
self.assertIn('CREATE TABLE "AnInterestingKeyspace"."AnInterestingTable"', schema)
self.assertIn('"A" int', schema)
self.assertIn('"B" int', schema)
self.assertIn('"MyColumn" int', schema)
self.assertIn('PRIMARY KEY (k, "A")', schema)
self.assertIn('WITH CLUSTERING ORDER BY ("A" DESC)', schema)
self.assertIn('CREATE INDEX myindex ON "AnInterestingKeyspace"."AnInterestingTable" ("MyColumn")', schema)
def test_already_exists_exceptions(self):
"""
@@ -365,8 +404,8 @@ class TestCodeCoverage(unittest.TestCase):
for i, token in enumerate(ring):
self.assertEqual(set(get_replicas('test3rf', token)), set(owners))
self.assertEqual(set(get_replicas('test2rf', token)), set([owners[i], owners[(i + 1) % 3]]))
self.assertEqual(set(get_replicas('test1rf', token)), set([owners[i]]))
self.assertEqual(set(get_replicas('test2rf', token)), set([owners[(i + 1) % 3], owners[(i + 2) % 3]]))
self.assertEqual(set(get_replicas('test1rf', token)), set([owners[(i + 1) % 3]]))
class TokenMetadataTest(unittest.TestCase):
@@ -393,12 +432,13 @@ class TokenMetadataTest(unittest.TestCase):
token_map = TokenMap(MD5Token, token_to_primary_replica, tokens, metadata)
# tokens match node tokens exactly
for token, expected_host in zip(tokens, hosts):
for i, token in enumerate(tokens):
expected_host = hosts[(i + 1) % len(hosts)]
replicas = token_map.get_replicas("ks", token)
self.assertEqual(set(replicas), set([expected_host]))
# shift the tokens back by one
for token, expected_host in zip(tokens[1:], hosts[1:]):
for token, expected_host in zip(tokens, hosts):
replicas = token_map.get_replicas("ks", MD5Token(str(token.value - 1)))
self.assertEqual(set(replicas), set([expected_host]))

View File

@@ -5,6 +5,7 @@ except ImportError:
from cassandra.query import PreparedStatement, BoundStatement, ValueSequence, SimpleStatement
from cassandra.cluster import Cluster
from cassandra.decoder import dict_factory
class QueryTest(unittest.TestCase):
@@ -49,6 +50,20 @@ class QueryTest(unittest.TestCase):
for event in statement.trace.events:
str(event)
def test_trace_ignores_row_factory(self):
cluster = Cluster()
session = cluster.connect()
session.row_factory = dict_factory
query = "SELECT * FROM system.local"
statement = SimpleStatement(query)
session.execute(statement, trace=True)
# Ensure this does not throw an exception
str(statement.trace)
for event in statement.trace.events:
str(event)
class PreparedStatementTests(unittest.TestCase):

View File

@@ -14,6 +14,12 @@ except ImportError:
from cassandra import InvalidRequest
from cassandra.cluster import Cluster
from cassandra.cqltypes import Int32Type, EMPTY
from cassandra.decoder import dict_factory
try:
from collections import OrderedDict
except ImportError:
from cassandra.util import OrderedDict # noqa
from tests.integration import get_server_versions
@@ -100,16 +106,8 @@ class TypeTests(unittest.TestCase):
for expected, actual in zip(expected_vals, results[0]):
self.assertEquals(expected, actual)
def test_basic_types(self):
c = Cluster()
s = c.connect()
s.execute("""
CREATE KEYSPACE typetests
WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1'}
""")
s.set_keyspace("typetests")
s.execute("""
CREATE TABLE mytable (
create_type_table = """
CREATE TABLE mytable (
a text,
b text,
c ascii,
@@ -131,7 +129,17 @@ class TypeTests(unittest.TestCase):
t varint,
PRIMARY KEY (a, b)
)
""")
"""
def test_basic_types(self):
c = Cluster()
s = c.connect()
s.execute("""
CREATE KEYSPACE typetests
WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1'}
""")
s.set_keyspace("typetests")
s.execute(self.create_type_table)
v1_uuid = uuid1()
v4_uuid = uuid4()
@@ -220,6 +228,126 @@ class TypeTests(unittest.TestCase):
for expected, actual in zip(expected_vals, results[0]):
self.assertEquals(expected, actual)
def test_empty_strings_and_nones(self):
c = Cluster()
s = c.connect()
s.execute("""
CREATE KEYSPACE test_empty_strings_and_nones
WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1'}
""")
s.set_keyspace("test_empty_strings_and_nones")
s.execute(self.create_type_table)
s.execute("INSERT INTO mytable (a, b) VALUES ('a', 'b')")
s.row_factory = dict_factory
results = s.execute("""
SELECT c, d, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t FROM mytable
""")
self.assertTrue(all(x is None for x in results[0].values()))
prepared = s.prepare("""
SELECT c, d, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t FROM mytable
""")
results = s.execute(prepared.bind(()))
self.assertTrue(all(x is None for x in results[0].values()))
# insert empty strings for string-like fields and fetch them
s.execute("INSERT INTO mytable (a, b, c, o, s, l, n) VALUES ('a', 'b', %s, %s, %s, %s, %s)",
('', '', '', [''], {'': 3}))
self.assertEquals(
{'c': '', 'o': '', 's': '', 'l': ('', ), 'n': OrderedDict({'': 3})},
s.execute("SELECT c, o, s, l, n FROM mytable WHERE a='a' AND b='b'")[0])
self.assertEquals(
{'c': '', 'o': '', 's': '', 'l': ('', ), 'n': OrderedDict({'': 3})},
s.execute(s.prepare("SELECT c, o, s, l, n FROM mytable WHERE a='a' AND b='b'"), [])[0])
# non-string types shouldn't accept empty strings
for col in ('d', 'f', 'g', 'h', 'i', 'k', 'l', 'm', 'n', 'q', 'r', 't'):
query = "INSERT INTO mytable (a, b, %s) VALUES ('a', 'b', %%s)" % (col, )
try:
s.execute(query, [''])
except InvalidRequest:
pass
else:
self.fail("Expected an InvalidRequest error when inserting an "
"emptry string for column %s" % (col, ))
prepared = s.prepare("INSERT INTO mytable (a, b, %s) VALUES ('a', 'b', ?)" % (col, ))
try:
s.execute(prepared, [''])
except TypeError:
pass
else:
self.fail("Expected an InvalidRequest error when inserting an "
"emptry string for column %s with a prepared statement" % (col, ))
# insert values for all columns
values = ['a', 'b', 'a', 1, True, Decimal('1.0'), 0.1, 0.1,
"1.2.3.4", 1, ['a'], set([1]), {'a': 1}, 'a',
datetime.now(), uuid4(), uuid1(), 'a', 1]
s.execute("""
INSERT INTO mytable (a, b, c, d, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
""", values)
# then insert None, which should null them out
null_values = values[:2] + ([None] * (len(values) - 2))
s.execute("""
INSERT INTO mytable (a, b, c, d, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
""", null_values)
results = s.execute("""
SELECT c, d, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t FROM mytable
""")
self.assertEqual([], [(name, val) for (name, val) in results[0].items() if val is not None])
prepared = s.prepare("""
SELECT c, d, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t FROM mytable
""")
results = s.execute(prepared.bind(()))
self.assertEqual([], [(name, val) for (name, val) in results[0].items() if val is not None])
# do the same thing again, but use a prepared statement to insert the nulls
s.execute("""
INSERT INTO mytable (a, b, c, d, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
""", values)
prepared = s.prepare("""
INSERT INTO mytable (a, b, c, d, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""")
s.execute(prepared, null_values)
results = s.execute("""
SELECT c, d, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t FROM mytable
""")
self.assertEqual([], [(name, val) for (name, val) in results[0].items() if val is not None])
prepared = s.prepare("""
SELECT c, d, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t FROM mytable
""")
results = s.execute(prepared.bind(()))
self.assertEqual([], [(name, val) for (name, val) in results[0].items() if val is not None])
def test_empty_values(self):
c = Cluster()
s = c.connect()
s.execute("""
CREATE KEYSPACE test_empty_values
WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1'}
""")
s.set_keyspace("test_empty_values")
s.execute("CREATE TABLE mytable (a text PRIMARY KEY, b int)")
s.execute("INSERT INTO mytable (a, b) VALUES ('a', blobAsInt(0x))")
try:
Int32Type.support_empty_values = True
results = s.execute("SELECT b FROM mytable WHERE a='a'")[0]
self.assertIs(EMPTY, results.b)
finally:
Int32Type.support_empty_values = False
def test_timezone_aware_datetimes(self):
""" Ensure timezone-aware datetimes are converted to timestamps correctly """
try:

View File

@@ -4,6 +4,7 @@ except ImportError:
import unittest # noqa
import errno
import os
from StringIO import StringIO
import socket
from socket import error as socket_error
@@ -15,7 +16,7 @@ from cassandra.connection import (HEADER_DIRECTION_TO_CLIENT,
from cassandra.decoder import (write_stringmultimap, write_int, write_string,
SupportedMessage, ReadyMessage, ServerError)
from cassandra.marshal import uint8_pack, uint32_pack
from cassandra.marshal import uint8_pack, uint32_pack, int32_pack
from cassandra.io.asyncorereactor import AsyncoreConnection
@@ -84,6 +85,40 @@ class AsyncoreConnectionTest(unittest.TestCase):
c.handle_read()
self.assertTrue(c.connected_event.is_set())
return c
def test_egain_on_buffer_size(self, *args):
# get a connection that's already fully started
c = self.test_successful_connection()
header = '\x00\x00\x00\x00' + int32_pack(20000)
responses = [
header + ('a' * (4096 - len(header))),
'a' * 4096,
socket_error(errno.EAGAIN),
'a' * 100,
socket_error(errno.EAGAIN)]
def side_effect(*args):
response = responses.pop(0)
if isinstance(response, socket_error):
raise response
else:
return response
c.socket.recv.side_effect = side_effect
c.handle_read()
self.assertEquals(c._total_reqd_bytes, 20000 + len(header))
# the EAGAIN prevents it from reading the last 100 bytes
c._iobuf.seek(0, os.SEEK_END)
pos = c._iobuf.tell()
self.assertEquals(pos, 4096 + 4096)
# now tell it to read the last 100 bytes
c.handle_read()
c._iobuf.seek(0, os.SEEK_END)
pos = c._iobuf.tell()
self.assertEquals(pos, 4096 + 4096 + 100)
def test_protocol_error(self, *args):
c = self.make_connection()

View File

@@ -4,6 +4,7 @@ except ImportError:
import unittest # noqa
import errno
import os
from StringIO import StringIO
from socket import error as socket_error
@@ -14,7 +15,7 @@ from cassandra.connection import (HEADER_DIRECTION_TO_CLIENT,
from cassandra.decoder import (write_stringmultimap, write_int, write_string,
SupportedMessage, ReadyMessage, ServerError)
from cassandra.marshal import uint8_pack, uint32_pack
from cassandra.marshal import uint8_pack, uint32_pack, int32_pack
try:
from cassandra.io.libevreactor import LibevConnection
@@ -84,6 +85,40 @@ class LibevConnectionTest(unittest.TestCase):
c.handle_read(None, 0)
self.assertTrue(c.connected_event.is_set())
return c
def test_egain_on_buffer_size(self, *args):
# get a connection that's already fully started
c = self.test_successful_connection()
header = '\x00\x00\x00\x00' + int32_pack(20000)
responses = [
header + ('a' * (4096 - len(header))),
'a' * 4096,
socket_error(errno.EAGAIN),
'a' * 100,
socket_error(errno.EAGAIN)]
def side_effect(*args):
response = responses.pop(0)
if isinstance(response, socket_error):
raise response
else:
return response
c._socket.recv.side_effect = side_effect
c.handle_read(None, 0)
self.assertEquals(c._total_reqd_bytes, 20000 + len(header))
# the EAGAIN prevents it from reading the last 100 bytes
c._iobuf.seek(0, os.SEEK_END)
pos = c._iobuf.tell()
self.assertEquals(pos, 4096 + 4096)
# now tell it to read the last 100 bytes
c.handle_read(None, 0)
c._iobuf.seek(0, os.SEEK_END)
pos = c._iobuf.tell()
self.assertEquals(pos, 4096 + 4096 + 100)
def test_protocol_error(self, *args):
c = self.make_connection()

View File

@@ -4,10 +4,11 @@ except ImportError:
import unittest # noqa
import cassandra
from cassandra.metadata import (TableMetadata, Murmur3Token, MD5Token,
from cassandra.metadata import (Murmur3Token, MD5Token,
BytesToken, ReplicationStrategy,
NetworkTopologyStrategy, SimpleStrategy,
LocalStrategy, NoMurmur3)
LocalStrategy, NoMurmur3, protect_name,
protect_names, protect_value, is_valid_name)
from cassandra.policies import SimpleConvictionPolicy
from cassandra.pool import Host
@@ -132,31 +133,25 @@ class TestStrategies(unittest.TestCase):
self.assertItemsEqual(rf3_replicas[MD5Token(200)], [host3, host1, host2])
class TestTokens(unittest.TestCase):
class TestNameEscaping(unittest.TestCase):
def test_protect_name(self):
"""
Test TableMetadata.protect_name output
Test cassandra.metadata.protect_name output
"""
table_metadata = TableMetadata('ks_name', 'table_name')
self.assertEqual(table_metadata.protect_name('tests'), 'tests')
self.assertEqual(table_metadata.protect_name('test\'s'), '"test\'s"')
self.assertEqual(table_metadata.protect_name('test\'s'), "\"test's\"")
self.assertEqual(table_metadata.protect_name('tests ?!@#$%^&*()'), '"tests ?!@#$%^&*()"')
self.assertEqual(table_metadata.protect_name('1'), '"1"')
self.assertEqual(table_metadata.protect_name('1test'), '"1test"')
self.assertEqual(protect_name('tests'), 'tests')
self.assertEqual(protect_name('test\'s'), '"test\'s"')
self.assertEqual(protect_name('test\'s'), "\"test's\"")
self.assertEqual(protect_name('tests ?!@#$%^&*()'), '"tests ?!@#$%^&*()"')
self.assertEqual(protect_name('1'), '"1"')
self.assertEqual(protect_name('1test'), '"1test"')
def test_protect_names(self):
"""
Test TableMetadata.protect_names output
Test cassandra.metadata.protect_names output
"""
table_metadata = TableMetadata('ks_name', 'table_name')
self.assertEqual(table_metadata.protect_names(['tests']), ['tests'])
self.assertEqual(table_metadata.protect_names(
self.assertEqual(protect_names(['tests']), ['tests'])
self.assertEqual(protect_names(
[
'tests',
'test\'s',
@@ -172,36 +167,33 @@ class TestTokens(unittest.TestCase):
def test_protect_value(self):
"""
Test TableMetadata.protect_value output
Test cassandra.metadata.protect_value output
"""
table_metadata = TableMetadata('ks_name', 'table_name')
self.assertEqual(table_metadata.protect_value(True), "True")
self.assertEqual(table_metadata.protect_value(False), "False")
self.assertEqual(table_metadata.protect_value(3.14), '3.14')
self.assertEqual(table_metadata.protect_value(3), '3')
self.assertEqual(table_metadata.protect_value('test'), "'test'")
self.assertEqual(table_metadata.protect_value('test\'s'), "'test''s'")
self.assertEqual(table_metadata.protect_value(None), 'NULL')
self.assertEqual(protect_value(True), "True")
self.assertEqual(protect_value(False), "False")
self.assertEqual(protect_value(3.14), '3.14')
self.assertEqual(protect_value(3), '3')
self.assertEqual(protect_value('test'), "'test'")
self.assertEqual(protect_value('test\'s'), "'test''s'")
self.assertEqual(protect_value(None), 'NULL')
def test_is_valid_name(self):
"""
Test TableMetadata.is_valid_name output
Test cassandra.metadata.is_valid_name output
"""
table_metadata = TableMetadata('ks_name', 'table_name')
self.assertEqual(table_metadata.is_valid_name(None), False)
self.assertEqual(table_metadata.is_valid_name('test'), True)
self.assertEqual(table_metadata.is_valid_name('Test'), False)
self.assertEqual(table_metadata.is_valid_name('t_____1'), True)
self.assertEqual(table_metadata.is_valid_name('test1'), True)
self.assertEqual(table_metadata.is_valid_name('1test1'), False)
self.assertEqual(is_valid_name(None), False)
self.assertEqual(is_valid_name('test'), True)
self.assertEqual(is_valid_name('Test'), False)
self.assertEqual(is_valid_name('t_____1'), True)
self.assertEqual(is_valid_name('test1'), True)
self.assertEqual(is_valid_name('1test1'), False)
non_valid_keywords = cassandra.metadata._keywords - cassandra.metadata._unreserved_keywords
for keyword in non_valid_keywords:
self.assertEqual(table_metadata.is_valid_name(keyword), False)
self.assertEqual(is_valid_name(keyword), False)
class TestTokens(unittest.TestCase):
def test_murmur3_tokens(self):
try:

View File

@@ -5,7 +5,6 @@ except ImportError:
from cassandra.query import bind_params, ValueSequence
from cassandra.query import PreparedStatement, BoundStatement
from cassandra.query import InvalidParameterTypeError
from cassandra.cqltypes import Int32Type
try:
@@ -77,21 +76,12 @@ class BoundStatementTestCase(unittest.TestCase):
values = ['nonint', 1]
try:
bound_statement.bind(values)
except InvalidParameterTypeError as e:
self.assertEqual(e.col_name, 'foo1')
self.assertEqual(e.expected_type, Int32Type)
self.assertEqual(e.actual_type, str)
else:
self.fail('Passed invalid type but exception was not thrown')
try:
bound_statement.bind(values)
except TypeError as e:
self.assertEqual(e.col_name, 'foo1')
self.assertEqual(e.expected_type, Int32Type)
self.assertEqual(e.actual_type, str)
self.assertIn('foo1', str(e))
self.assertIn('Int32Type', str(e))
self.assertIn('str', str(e))
else:
self.fail('Passed invalid type but exception was not thrown')
@@ -99,9 +89,9 @@ class BoundStatementTestCase(unittest.TestCase):
try:
bound_statement.bind(values)
except InvalidParameterTypeError as e:
self.assertEqual(e.col_name, 'foo2')
self.assertEqual(e.expected_type, Int32Type)
self.assertEqual(e.actual_type, list)
except TypeError as e:
self.assertIn('foo2', str(e))
self.assertIn('Int32Type', str(e))
self.assertIn('list', str(e))
else:
self.fail('Passed invalid type but exception was not thrown')

View File

@@ -5,6 +5,8 @@ except ImportError:
from itertools import islice, cycle
from mock import Mock
from random import randint
import sys
import struct
from threading import Thread
@@ -88,11 +90,51 @@ class TestRoundRobinPolicy(unittest.TestCase):
map(lambda t: t.start(), threads)
map(lambda t: t.join(), threads)
def test_thread_safety_during_modification(self):
hosts = range(100)
policy = RoundRobinPolicy()
policy.populate(None, hosts)
errors = []
def check_query_plan():
try:
for i in xrange(100):
list(policy.make_query_plan())
except Exception as exc:
errors.append(exc)
def host_up():
for i in xrange(1000):
policy.on_up(randint(0, 99))
def host_down():
for i in xrange(1000):
policy.on_down(randint(0, 99))
threads = []
for i in range(5):
threads.append(Thread(target=check_query_plan))
threads.append(Thread(target=host_up))
threads.append(Thread(target=host_down))
# make the GIL switch after every instruction, maximizing
# the chace of race conditions
original_interval = sys.getcheckinterval()
try:
sys.setcheckinterval(0)
map(lambda t: t.start(), threads)
map(lambda t: t.join(), threads)
finally:
sys.setcheckinterval(original_interval)
if errors:
self.fail("Saw errors: %s" % (errors,))
def test_no_live_nodes(self):
"""
Ensure query plan for a downed cluster will execute without errors
"""
hosts = [0, 1, 2, 3]
policy = RoundRobinPolicy()
policy.populate(None, hosts)
@@ -408,6 +450,7 @@ class TokenAwarePolicyTest(unittest.TestCase):
qplan = list(policy.make_query_plan())
self.assertEqual(qplan, [])
class ConvictionPolicyTest(unittest.TestCase):
def test_not_implemented(self):
"""