Merge branch 'py3k' into 2.0

Conflicts:
	cassandra/cluster.py
	cassandra/encoder.py
	cassandra/marshal.py
	cassandra/pool.py
	setup.py
	tests/integration/long/test_large_data.py
	tests/integration/long/utils.py
	tests/integration/standard/test_metadata.py
	tests/integration/standard/test_prepared_statements.py
	tests/unit/io/test_asyncorereactor.py
	tests/unit/test_connection.py
	tests/unit/test_types.py
This commit is contained in:
Tyler Hobbs
2014-05-02 15:55:35 -05:00
44 changed files with 627 additions and 431 deletions

View File

@@ -4,11 +4,13 @@ env:
- TOX_ENV=py26
- TOX_ENV=py27
- TOX_ENV=pypy
- TOX_ENV=py33
before_install:
- sudo apt-get update -y
- sudo apt-get install -y build-essential python-dev
- sudo apt-get install -y libev4 libev-dev
install:
- pip install tox

View File

@@ -41,7 +41,7 @@ try:
from cassandra.io.libevreactor import LibevConnection
have_libev = True
supported_reactors.append(LibevConnection)
except ImportError, exc:
except ImportError as exc:
pass
KEYSPACE = "testkeyspace"
@@ -104,7 +104,7 @@ def benchmark(thread_class):
""".format(table=TABLE))
values = ('key', 'a', 'b')
per_thread = options.num_ops / options.threads
per_thread = options.num_ops // options.threads
threads = []
log.debug("Beginning inserts...")

View File

@@ -18,7 +18,7 @@ from itertools import count
from threading import Event
from base import benchmark, BenchmarkThread
from six.moves import range
log = logging.getLogger(__name__)
@@ -38,17 +38,17 @@ class Runner(BenchmarkThread):
if previous_result is not sentinel:
if isinstance(previous_result, BaseException):
log.error("Error on insert: %r", previous_result)
if self.num_finished.next() >= self.num_queries:
if next(self.num_finished) >= self.num_queries:
self.event.set()
if self.num_started.next() <= self.num_queries:
if next(self.num_started) <= self.num_queries:
future = self.session.execute_async(self.query, self.values)
future.add_callbacks(self.insert_next, self.insert_next)
def run(self):
self.start_profile()
for _ in xrange(min(120, self.num_queries)):
for _ in range(min(120, self.num_queries)):
self.insert_next()
self.event.wait()

View File

@@ -13,16 +13,16 @@
# limitations under the License.
import logging
import Queue
from base import benchmark, BenchmarkThread
from six.moves import queue
log = logging.getLogger(__name__)
class Runner(BenchmarkThread):
def run(self):
futures = Queue.Queue(maxsize=121)
futures = queue.Queue(maxsize=121)
self.start_profile()
@@ -32,7 +32,7 @@ class Runner(BenchmarkThread):
while True:
try:
futures.get_nowait().result()
except Queue.Empty:
except queue.Empty:
break
future = self.session.execute_async(self.query, self.values)
@@ -41,7 +41,7 @@ class Runner(BenchmarkThread):
while True:
try:
futures.get_nowait().result()
except Queue.Empty:
except queue.Empty:
break
self.finish_profile()

View File

@@ -13,16 +13,16 @@
# limitations under the License.
import logging
import Queue
from base import benchmark, BenchmarkThread
from six.moves import queue
log = logging.getLogger(__name__)
class Runner(BenchmarkThread):
def run(self):
futures = Queue.Queue(maxsize=121)
futures = queue.Queue(maxsize=121)
self.start_profile()
@@ -37,7 +37,7 @@ class Runner(BenchmarkThread):
while True:
try:
futures.get_nowait().result()
except Queue.Empty:
except queue.Empty:
break
self.finish_profile

View File

@@ -25,7 +25,7 @@ class Runner(BenchmarkThread):
self.start_profile()
for i in range(self.num_queries):
for _ in range(self.num_queries):
future = self.session.execute_async(self.query, self.values)
futures.append(future)

View File

@@ -13,13 +13,15 @@
# limitations under the License.
from base import benchmark, BenchmarkThread
from six.moves import range
class Runner(BenchmarkThread):
def run(self):
self.start_profile()
for i in xrange(self.num_queries):
for _ in range(self.num_queries):
self.session.execute(self.query, self.values)
self.finish_profile()

View File

@@ -26,7 +26,11 @@ import socket
import sys
import time
from threading import Lock, RLock, Thread, Event
import Queue
import six
from six.moves import range
from six.moves import queue as Queue
import weakref
from weakref import WeakValueDictionary
try:
@@ -696,7 +700,7 @@ class Cluster(object):
host.set_down()
log.warn("Host %s has been marked down", host)
log.warning("Host %s has been marked down", host)
self.load_balancing_policy.on_down(host)
self.control_connection.on_down(host)
@@ -742,7 +746,7 @@ class Cluster(object):
return
if not all(futures_results):
log.warn("Connection pool could not be created, not marking node %s up", host)
log.warning("Connection pool could not be created, not marking node %s up", host)
return
self._finalize_add(host)
@@ -867,7 +871,7 @@ class Cluster(object):
# prepare 10 statements at a time
ks_statements = list(ks_statements)
chunks = []
for i in xrange(0, len(ks_statements), 10):
for i in range(0, len(ks_statements), 10):
chunks.append(ks_statements[i:i + 10])
for ks_chunk in chunks:
@@ -882,9 +886,9 @@ class Cluster(object):
log.debug("Done preparing all known prepared statements against host %s", host)
except OperationTimedOut as timeout:
log.warn("Timed out trying to prepare all statements on host %s: %s", host, timeout)
log.warning("Timed out trying to prepare all statements on host %s: %s", host, timeout)
except (ConnectionException, socket.error) as exc:
log.warn("Error trying to prepare all statements on host %s: %r", host, exc)
log.warning("Error trying to prepare all statements on host %s: %r", host, exc)
except Exception:
log.exception("Error trying to prepare all statements on host %s", host)
finally:
@@ -1088,7 +1092,7 @@ class Session(object):
prepared_statement = None
if isinstance(query, basestring):
if isinstance(query, six.string_types):
query = SimpleStatement(query)
elif isinstance(query, PreparedStatement):
query = query.bind(parameters)
@@ -1235,8 +1239,8 @@ class Session(object):
self.cluster.signal_connection_failure(host, conn_exc, is_host_addition)
return False
except Exception as conn_exc:
log.warn("Failed to create connection pool for new host %s: %s",
host, conn_exc)
log.warning("Failed to create connection pool for new host %s: %s",
host, conn_exc)
# the host itself will still be marked down, so we need to pass
# a special flag to make sure the reconnector is created
self.cluster.signal_connection_failure(
@@ -1456,11 +1460,11 @@ class ControlConnection(object):
return self._try_connect(host)
except ConnectionException as exc:
errors[host.address] = exc
log.warn("[control connection] Error connecting to %s:", host, exc_info=True)
log.warning("[control connection] Error connecting to %s:", host, exc_info=True)
self._cluster.signal_connection_failure(host, exc, is_host_addition=False)
except Exception as exc:
errors[host.address] = exc
log.warn("[control connection] Error connecting to %s:", host, exc_info=True)
log.warning("[control connection] Error connecting to %s:", host, exc_info=True)
raise NoHostAvailable("Unable to connect to any servers", errors)
@@ -1948,7 +1952,7 @@ class _Scheduler(object):
def _log_if_failed(self, future):
exc = future.exception()
if exc:
log.warn(
log.warning(
"An internally scheduled tasked failed with an unhandled exception:",
exc_info=exc)
@@ -2170,8 +2174,8 @@ class ResponseFuture(object):
if self._metrics is not None:
self._metrics.on_other_error()
# need to retry against a different host here
log.warn("Host %s is overloaded, retrying against a different "
"host", self._current_host)
log.warning("Host %s is overloaded, retrying against a different "
"host", self._current_host)
self._retry(reuse_connection=False, consistency_level=None)
return
elif isinstance(response, IsBootstrappingErrorMessage):

View File

@@ -15,6 +15,7 @@
import sys
from itertools import count, cycle
from six.moves import xrange
from threading import Event
@@ -105,7 +106,7 @@ def execute_concurrent_with_args(session, statement, parameters, *args, **kwargs
parameters = [(x,) for x in range(1000)]
execute_concurrent_with_args(session, statement, parameters)
"""
return execute_concurrent(session, zip(cycle((statement,)), parameters), *args, **kwargs)
return execute_concurrent(session, list(zip(cycle((statement,)), parameters)), *args, **kwargs)
_sentinel = object()
@@ -118,12 +119,12 @@ def _handle_error(error, result_index, event, session, statements, results, num_
return
else:
results[result_index] = (False, error)
if num_finished.next() >= to_execute:
if next(num_finished) >= to_execute:
event.set()
return
try:
(next_index, (statement, params)) = statements.next()
(next_index, (statement, params)) = next(statements)
except StopIteration:
return
@@ -139,7 +140,7 @@ def _handle_error(error, result_index, event, session, statements, results, num_
return
else:
results[next_index] = (False, exc)
if num_finished.next() >= to_execute:
if next(num_finished) >= to_execute:
event.set()
return
@@ -147,13 +148,13 @@ def _handle_error(error, result_index, event, session, statements, results, num_
def _execute_next(result, result_index, event, session, statements, results, num_finished, to_execute, first_error):
if result is not _sentinel:
results[result_index] = (True, result)
finished = num_finished.next()
finished = next(num_finished)
if finished >= to_execute:
event.set()
return
try:
(next_index, (statement, params)) = statements.next()
(next_index, (statement, params)) = next(statements)
except StopIteration:
return
@@ -169,6 +170,6 @@ def _execute_next(result, result_index, event, session, statements, results, num
return
else:
results[next_index] = (False, exc)
if num_finished.next() >= to_execute:
if next(num_finished) >= to_execute:
event.set()
return

View File

@@ -19,19 +19,21 @@ import logging
import sys
from threading import Event, RLock
import time
import traceback
if 'gevent.monkey' in sys.modules:
from gevent.queue import Queue, Empty
else:
from Queue import Queue, Empty # noqa
from six.moves.queue import Queue, Empty # noqa
from six.moves import range
from cassandra import ConsistencyLevel, AuthenticationFailed, OperationTimedOut
from cassandra.marshal import int8_unpack, int32_pack
from cassandra.marshal import int32_pack, header_unpack
from cassandra.decoder import (ReadyMessage, AuthenticateMessage, OptionsMessage,
StartupMessage, ErrorMessage, CredentialsMessage,
QueryMessage, ResultMessage, decode_response,
InvalidRequestException, SupportedMessage)
import six
log = logging.getLogger(__name__)
@@ -123,7 +125,6 @@ def defunct_on_error(f):
return f(self, *args, **kwargs)
except Exception as exc:
self.defunct(exc)
return wrapper
@@ -181,12 +182,8 @@ class Connection(object):
return
self.is_defunct = True
trace = traceback.format_exc(exc)
if trace != "None":
log.debug("Defuncting connection (%s) to %s: %s\n%s",
id(self), self.host, exc, traceback.format_exc(exc))
else:
log.debug("Defuncting connection (%s) to %s: %s", id(self), self.host, exc)
log.debug("Defuncting connection (%s) to %s:",
id(self), self.host, exc_info=exc)
self.last_error = exc
self.close()
@@ -203,9 +200,9 @@ class Connection(object):
try:
cb(new_exc)
except Exception:
log.warn("Ignoring unhandled exception while erroring callbacks for a "
"failed connection (%s) to host %s:",
id(self), self.host, exc_info=True)
log.warning("Ignoring unhandled exception while erroring callbacks for a "
"failed connection (%s) to host %s:",
id(self), self.host, exc_info=True)
def handle_pushed(self, response):
log.debug("Message pushed from server: %r", response)
@@ -231,7 +228,7 @@ class Connection(object):
request_id = self._id_queue.get()
self._callbacks[request_id] = cb
self.push(msg.to_string(request_id, self.protocol_version, compression=self.compressor))
self.push(msg.to_binary(request_id, self.protocol_version, compression=self.compressor))
return request_id
def wait_for_response(self, msg, timeout=None):
@@ -268,7 +265,7 @@ class Connection(object):
return waiter.deliver(timeout)
except OperationTimedOut:
raise
except Exception, exc:
except Exception as exc:
self.defunct(exc)
raise
@@ -284,7 +281,7 @@ class Connection(object):
@defunct_on_error
def process_msg(self, msg, body_len):
version, flags, stream_id, opcode = map(int8_unpack, msg[:4])
version, flags, stream_id, opcode = header_unpack(msg[:4])
if stream_id < 0:
callback = None
else:
@@ -309,7 +306,7 @@ class Connection(object):
if body_len > 0:
body = msg[8:]
elif body_len == 0:
body = ""
body = six.binary_type()
else:
raise ProtocolError("Got negative body length: %r" % body_len)
@@ -383,7 +380,7 @@ class Connection(object):
locally_supported_compressions.keys(),
remote_supported_compressions)
else:
compression_type = iter(overlap).next() # choose any
compression_type = next(iter(overlap)) # choose any
# set the decompressor here, but set the compressor only after
# a successful Ready message
self._compressor, self.decompressor = \

View File

@@ -36,10 +36,8 @@ from datetime import datetime
from uuid import UUID
import warnings
try:
from cStringIO import StringIO
except ImportError:
from StringIO import StringIO # NOQA
import six
from six.moves import range
from cassandra.marshal import (int8_pack, int8_unpack, uint16_pack, uint16_unpack,
int32_pack, int32_unpack, int64_pack, int64_unpack,
@@ -49,7 +47,11 @@ from cassandra.util import OrderedDict
apache_cassandra_type_prefix = 'org.apache.cassandra.db.marshal.'
_number_types = frozenset((int, long, float))
if six.PY3:
_number_types = frozenset((int, float))
long = int
else:
_number_types = frozenset((int, long, float))
try:
from blist import sortedset
@@ -69,7 +71,7 @@ def trim_if_startswith(s, prefix):
def unix_time_from_uuid1(u):
return (u.get_time() - 0x01B21DD213814000) / 10000000.0
return (u.time - 0x01B21DD213814000) / 10000000.0
_casstypes = {}
@@ -177,8 +179,8 @@ class EmptyValue(object):
EMPTY = EmptyValue()
@six.add_metaclass(CassandraTypeType)
class _CassandraType(object):
__metaclass__ = CassandraTypeType
subtypes = ()
num_subtypes = 0
empty_binary_ok = False
@@ -199,9 +201,8 @@ class _CassandraType(object):
def __init__(self, val):
self.val = self.validate(val)
def __str__(self):
def __repr__(self):
return '<%s( %r )>' % (self.cql_parameterized_type(), self.val)
__repr__ = __str__
@staticmethod
def validate(val):
@@ -221,7 +222,7 @@ class _CassandraType(object):
"""
if byts is None:
return None
elif byts == '' and not cls.empty_binary_ok:
elif len(byts) == 0 and not cls.empty_binary_ok:
return EMPTY if cls.support_empty_values else None
return cls.deserialize(byts)
@@ -232,7 +233,7 @@ class _CassandraType(object):
more information. This method differs in that if None is passed in,
the result is the empty string.
"""
return '' if val is None else cls.serialize(val)
return b'' if val is None else cls.serialize(val)
@staticmethod
def deserialize(byts):
@@ -293,7 +294,8 @@ class _CassandraType(object):
if cls.num_subtypes != 'UNKNOWN' and len(subtypes) != cls.num_subtypes:
raise ValueError("%s types require %d subtypes (%d given)"
% (cls.typename, cls.num_subtypes, len(subtypes)))
newname = cls.cass_parameterized_type_with(subtypes).encode('utf8')
# newname = cls.cass_parameterized_type_with(subtypes).encode('utf8')
newname = cls.cass_parameterized_type_with(subtypes)
return type(newname, (cls,), {'subtypes': subtypes, 'cassname': cls.cassname})
@classmethod
@@ -324,10 +326,16 @@ class _UnrecognizedType(_CassandraType):
num_subtypes = 'UNKNOWN'
def mkUnrecognizedType(casstypename):
return CassandraTypeType(casstypename.encode('utf8'),
(_UnrecognizedType,),
{'typename': "'%s'" % casstypename})
if six.PY3:
def mkUnrecognizedType(casstypename):
return CassandraTypeType(casstypename,
(_UnrecognizedType,),
{'typename': "'%s'" % casstypename})
else:
def mkUnrecognizedType(casstypename): # noqa
return CassandraTypeType(casstypename.encode('utf8'),
(_UnrecognizedType,),
{'typename': "'%s'" % casstypename})
class BytesType(_CassandraType):
@@ -336,11 +344,11 @@ class BytesType(_CassandraType):
@staticmethod
def validate(val):
return buffer(val)
return bytearray(val)
@staticmethod
def serialize(val):
return str(val)
return six.binary_type(val)
class DecimalType(_CassandraType):
@@ -401,9 +409,25 @@ class BooleanType(_CassandraType):
return int8_pack(truth)
class AsciiType(_CassandraType):
typename = 'ascii'
empty_binary_ok = True
if six.PY2:
class AsciiType(_CassandraType):
typename = 'ascii'
empty_binary_ok = True
else:
class AsciiType(_CassandraType):
typename = 'ascii'
empty_binary_ok = True
@staticmethod
def deserialize(byts):
return byts.decode('ascii')
@staticmethod
def serialize(var):
try:
return var.encode('ascii')
except UnicodeDecodeError:
return var
class FloatType(_CassandraType):
@@ -496,7 +520,7 @@ class DateType(_CassandraType):
@classmethod
def validate(cls, date):
if isinstance(date, basestring):
if isinstance(date, six.string_types):
date = cls.interpret_datestring(date)
return date
@@ -628,7 +652,7 @@ class _SimpleParameterizedType(_ParameterizedType):
numelements = uint16_unpack(byts[:2])
p = 2
result = []
for n in xrange(numelements):
for _ in range(numelements):
itemlen = uint16_unpack(byts[p:p + 2])
p += 2
item = byts[p:p + itemlen]
@@ -638,11 +662,11 @@ class _SimpleParameterizedType(_ParameterizedType):
@classmethod
def serialize_safe(cls, items):
if isinstance(items, basestring):
if isinstance(items, six.string_types):
raise TypeError("Received a string for a type that expects a sequence")
subtype, = cls.subtypes
buf = StringIO()
buf = six.BytesIO()
buf.write(uint16_pack(len(items)))
for item in items:
itembytes = subtype.to_binary(item)
@@ -670,7 +694,7 @@ class MapType(_ParameterizedType):
@classmethod
def validate(cls, val):
subkeytype, subvaltype = cls.subtypes
return dict((subkeytype.validate(k), subvaltype.validate(v)) for (k, v) in val.iteritems())
return dict((subkeytype.validate(k), subvaltype.validate(v)) for (k, v) in six.iteritems(val))
@classmethod
def deserialize_safe(cls, byts):
@@ -678,7 +702,7 @@ class MapType(_ParameterizedType):
numelements = uint16_unpack(byts[:2])
p = 2
themap = OrderedDict()
for n in xrange(numelements):
for _ in range(numelements):
key_len = uint16_unpack(byts[p:p + 2])
p += 2
keybytes = byts[p:p + key_len]
@@ -695,10 +719,10 @@ class MapType(_ParameterizedType):
@classmethod
def serialize_safe(cls, themap):
subkeytype, subvaltype = cls.subtypes
buf = StringIO()
buf = six.BytesIO()
buf.write(uint16_pack(len(themap)))
try:
items = themap.iteritems()
items = six.iteritems(themap)
except AttributeError:
raise TypeError("Got a non-map object for a map value")
for key, val in items:
@@ -747,7 +771,7 @@ class ReversedType(_ParameterizedType):
def is_counter_type(t):
if isinstance(t, basestring):
if isinstance(t, six.string_types):
t = lookup_casstype(t)
return issubclass(t, CounterColumnType)

View File

@@ -16,16 +16,14 @@ import logging
import socket
from uuid import UUID
try:
from cStringIO import StringIO
except ImportError:
from StringIO import StringIO # ignore flake8 warning: # NOQA
import six
from six.moves import range
from cassandra import (Unavailable, WriteTimeout, ReadTimeout,
AlreadyExists, InvalidRequest, Unauthorized,
UnsupportedOperation)
from cassandra.marshal import (int32_pack, int32_unpack, uint16_pack, uint16_unpack,
int8_pack, int8_unpack)
int8_pack, int8_unpack, header_pack)
from cassandra.cqltypes import (AsciiType, BytesType, BooleanType,
CounterColumnType, DateType, DecimalType,
DoubleType, FloatType, Int32Type,
@@ -48,66 +46,75 @@ HEADER_DIRECTION_FROM_CLIENT = 0x00
HEADER_DIRECTION_TO_CLIENT = 0x80
HEADER_DIRECTION_MASK = 0x80
COMPRESSED_FLAG = 0x01
TRACING_FLAG = 0x02
_message_types_by_name = {}
_message_types_by_opcode = {}
class _register_msg_type(type):
class _RegisterMessageType(type):
def __init__(cls, name, bases, dct):
if not name.startswith('_'):
_message_types_by_name[cls.name] = cls
_message_types_by_opcode[cls.opcode] = cls
@six.add_metaclass(_RegisterMessageType)
class _MessageType(object):
__metaclass__ = _register_msg_type
tracing = False
def to_string(self, stream_id, protocol_version, compression=None):
body = StringIO()
def to_binary(self, stream_id, protocol_version, compression=None):
body = six.BytesIO()
self.send_body(body, protocol_version)
body = body.getvalue()
version = protocol_version | HEADER_DIRECTION_FROM_CLIENT
flags = 0
if compression is not None and len(body) > 0:
body = compression(body)
flags |= 0x01
if self.tracing:
flags |= 0x02
msglen = int32_pack(len(body))
msg_parts = map(int8_pack, (version, flags, stream_id, self.opcode)) + [msglen, body]
return ''.join(msg_parts)
def __str__(self):
paramstrs = ['%s=%r' % (pname, getattr(self, pname)) for pname in _get_params(self)]
return '<%s(%s)>' % (self.__class__.__name__, ', '.join(paramstrs))
__repr__ = __str__
flags = 0
if compression and len(body) > 0:
body = compression(body)
flags |= COMPRESSED_FLAG
if self.tracing:
flags |= TRACING_FLAG
msg = six.BytesIO()
write_header(
msg,
protocol_version | HEADER_DIRECTION_FROM_CLIENT,
flags, stream_id, self.opcode, len(body)
)
msg.write(body)
return msg.getvalue()
def __repr__(self):
return '<%s(%s)>' % (self.__class__.__name__, ', '.join('%s=%r' % i for i in _get_params(self)))
def _get_params(message_obj):
base_attrs = dir(_MessageType)
return [a for a in dir(message_obj)
if a not in base_attrs and not a.startswith('_') and not callable(getattr(message_obj, a))]
return (
(n, a) for n, a in message_obj.__dict__.items()
if n not in base_attrs and not n.startswith('_') and not callable(a)
)
def decode_response(stream_id, flags, opcode, body, decompressor=None):
if flags & 0x01:
if flags & COMPRESSED_FLAG:
if decompressor is None:
raise Exception("No decompressor available for compressed frame!")
raise Exception("No de-compressor available for compressed frame!")
body = decompressor(body)
flags ^= 0x01
flags ^= COMPRESSED_FLAG
body = StringIO(body)
if flags & 0x02:
body = six.BytesIO(body)
if flags & TRACING_FLAG:
trace_id = UUID(bytes=body.read(16))
flags ^= 0x02
flags ^= TRACING_FLAG
else:
trace_id = None
if flags:
log.warn("Unknown protocol flags set: %02x. May cause problems.", flags)
log.warning("Unknown protocol flags set: %02x. May cause problems.", flags)
msg_class = _message_types_by_opcode[opcode]
msg = msg_class.recv_body(body)
@@ -156,14 +163,14 @@ class ErrorMessage(_MessageType, Exception):
return self
class ErrorMessageSubclass(_register_msg_type):
class ErrorMessageSubclass(_RegisterMessageType):
def __init__(cls, name, bases, dct):
if cls.error_code is not None:
if cls.error_code is not None: # Server has an error code of 0.
error_classes[cls.error_code] = cls
@six.add_metaclass(ErrorMessageSubclass)
class ErrorMessageSub(ErrorMessage):
__metaclass__ = ErrorMessageSubclass
error_code = None
@@ -511,7 +518,7 @@ class ResultMessage(_MessageType):
def recv_results_rows(cls, f):
paging_state, column_metadata = cls.recv_results_metadata(f)
rowcount = read_int(f)
rows = [cls.recv_row(f, len(column_metadata)) for x in xrange(rowcount)]
rows = [cls.recv_row(f, len(column_metadata)) for _ in range(rowcount)]
colnames = [c[2] for c in column_metadata]
coltypes = [c[3] for c in column_metadata]
return (
@@ -538,7 +545,7 @@ class ResultMessage(_MessageType):
ksname = read_string(f)
cfname = read_string(f)
column_metadata = []
for x in xrange(colcount):
for _ in range(colcount):
if glob_tblspec:
colksname = ksname
colcfname = cfname
@@ -580,7 +587,7 @@ class ResultMessage(_MessageType):
@staticmethod
def recv_row(f, colcount):
return [read_value(f) for x in xrange(colcount)]
return [read_value(f) for _ in range(colcount)]
class PrepareMessage(_MessageType):
@@ -729,6 +736,14 @@ class EventMessage(_MessageType):
return dict(change_type=change_type, keyspace=keyspace, table=table)
def write_header(f, version, flags, stream_id, opcode, length):
"""
Write a CQL protocol frame header.
"""
f.write(header_pack(version, flags, stream_id, opcode))
write_int(f, length)
def read_byte(f):
return int8_unpack(f.read(1))
@@ -774,7 +789,7 @@ def read_binary_string(f):
def write_string(f, s):
if isinstance(s, unicode):
if isinstance(s, six.text_type):
s = s.encode('utf8')
write_short(f, len(s))
f.write(s)
@@ -791,7 +806,7 @@ def read_longstring(f):
def write_longstring(f, s):
if isinstance(s, unicode):
if isinstance(s, six.text_type):
s = s.encode('utf8')
write_int(f, len(s))
f.write(s)
@@ -799,7 +814,7 @@ def write_longstring(f, s):
def read_stringlist(f):
numstrs = read_short(f)
return [read_string(f) for x in xrange(numstrs)]
return [read_string(f) for _ in range(numstrs)]
def write_stringlist(f, stringlist):
@@ -811,7 +826,7 @@ def write_stringlist(f, stringlist):
def read_stringmap(f):
numpairs = read_short(f)
strmap = {}
for x in xrange(numpairs):
for _ in range(numpairs):
k = read_string(f)
strmap[k] = read_string(f)
return strmap
@@ -827,7 +842,7 @@ def write_stringmap(f, strmap):
def read_stringmultimap(f):
numkeys = read_short(f)
strmmap = {}
for x in xrange(numkeys):
for _ in range(numkeys):
k = read_string(f)
strmmap[k] = read_stringlist(f)
return strmmap

View File

@@ -12,21 +12,34 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
log = logging.getLogger(__name__)
from binascii import hexlify
import calendar
import datetime
import sys
import types
from uuid import UUID
import six
from cassandra.util import OrderedDict
if six.PY3:
long = int
def cql_quote(term):
if isinstance(term, unicode):
return "'%s'" % term.encode('utf8').replace("'", "''")
elif isinstance(term, (str, bool)):
# The ordering of this method is important for the result of this method to
# be a native str type (for both Python 2 and 3)
# Handle quoting of native str and bool types
if isinstance(term, (str, bool)):
return "'%s'" % str(term).replace("'", "''")
# This branch of the if statement will only be used by Python 2 to catch
# unicode strings, text_type is used to prevent type errors with Python 3.
elif isinstance(term, six.text_type):
return "'%s'" % term.encode('utf8').replace("'", "''")
else:
return str(term)
@@ -43,13 +56,16 @@ def cql_encode_str(val):
return cql_quote(val)
if sys.version_info >= (2, 7):
if six.PY3:
def cql_encode_bytes(val):
return '0x' + hexlify(val)
return (b'0x' + hexlify(val)).decode('utf-8')
elif sys.version_info >= (2, 7):
def cql_encode_bytes(val): # noqa
return b'0x' + hexlify(val)
else:
# python 2.6 requires string or read-only buffer for hexlify
def cql_encode_bytes(val): # noqa
return '0x' + hexlify(buffer(val))
return b'0x' + hexlify(buffer(val))
def cql_encode_object(val):
@@ -71,11 +87,10 @@ def cql_encode_sequence(val):
def cql_encode_map_collection(val):
return '{ %s }' % ' , '.join(
'%s : %s' % (
cql_encode_all_types(k),
cql_encode_all_types(v))
for k, v in val.iteritems())
return '{ %s }' % ' , '.join('%s : %s' % (
cql_encode_all_types(k),
cql_encode_all_types(v)
) for k, v in six.iteritems(val))
def cql_encode_list_collection(val):
@@ -92,13 +107,9 @@ def cql_encode_all_types(val):
cql_encoders = {
float: cql_encode_object,
buffer: cql_encode_bytes,
bytearray: cql_encode_bytes,
str: cql_encode_str,
unicode: cql_encode_unicode,
types.NoneType: cql_encode_none,
int: cql_encode_object,
long: cql_encode_object,
UUID: cql_encode_object,
datetime.datetime: cql_encode_datetime,
datetime.date: cql_encode_date,
@@ -110,3 +121,17 @@ cql_encoders = {
frozenset: cql_encode_set_collection,
types.GeneratorType: cql_encode_list_collection
}
if six.PY2:
cql_encoders.update({
unicode: cql_encode_unicode,
buffer: cql_encode_bytes,
long: cql_encode_object,
types.NoneType: cql_encode_none,
})
else:
cql_encoders.update({
memoryview: cql_encode_bytes,
bytes: cql_encode_bytes,
type(None): cql_encode_none,
})

View File

@@ -20,6 +20,10 @@ import os
import socket
import sys
from threading import Event, Lock, Thread
from six import BytesIO
from six.moves import range
from errno import EALREADY, EINPROGRESS, EWOULDBLOCK, EINVAL, EISCONN, errorcode
try:
from weakref import WeakSet
@@ -28,11 +32,6 @@ except ImportError:
import asyncore
try:
from cStringIO import StringIO
except ImportError:
from StringIO import StringIO # ignore flake8 warning: # NOQA
try:
import ssl
except ImportError:
@@ -141,7 +140,7 @@ class AsyncoreConnection(Connection, asyncore.dispatcher):
asyncore.dispatcher.__init__(self)
self.connected_event = Event()
self._iobuf = StringIO()
self._iobuf = BytesIO()
self._callbacks = {}
self.deque = deque()
@@ -286,7 +285,7 @@ class AsyncoreConnection(Connection, asyncore.dispatcher):
# leave leftover in current buffer
leftover = self._iobuf.read()
self._iobuf = StringIO()
self._iobuf = BytesIO()
self._iobuf.write(leftover)
self._total_reqd_bytes = 0
@@ -302,7 +301,7 @@ class AsyncoreConnection(Connection, asyncore.dispatcher):
sabs = self.out_buffer_size
if len(data) > sabs:
chunks = []
for i in xrange(0, len(data), sabs):
for i in range(0, len(data), sabs):
chunks.append(data[i:i + sabs])
else:
chunks = [data]

View File

@@ -20,6 +20,8 @@ import os
import socket
from threading import Event, Lock, Thread
from six import BytesIO
from cassandra import OperationTimedOut
from cassandra.connection import Connection, ConnectionShutdown, NONBLOCKING
from cassandra.decoder import RegisterMessage
@@ -35,10 +37,6 @@ except ImportError:
"for instructions on installing build dependencies and building "
"the C extension.")
try:
from cStringIO import StringIO
except ImportError:
from StringIO import StringIO # ignore flake8 warning: # NOQA
try:
import ssl
@@ -197,7 +195,7 @@ class LibevConnection(Connection):
Connection.__init__(self, *args, **kwargs)
self.connected_event = Event()
self._iobuf = StringIO()
self._iobuf = BytesIO()
self._callbacks = {}
self.deque = deque()
@@ -323,7 +321,7 @@ class LibevConnection(Connection):
# leave leftover in current buffer
leftover = self._iobuf.read()
self._iobuf = StringIO()
self._iobuf = BytesIO()
self._iobuf.write(leftover)
self._total_reqd_bytes = 0

View File

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import six
import struct
@@ -37,12 +38,24 @@ uint8_pack, uint8_unpack = _make_packer('>B')
float_pack, float_unpack = _make_packer('>f')
double_pack, double_unpack = _make_packer('>d')
# Special case for cassandra header
header_struct = struct.Struct('>BBbB')
header_pack = header_struct.pack
header_unpack = header_struct.unpack
def varint_unpack(term):
val = int(term.encode('hex'), 16)
if (ord(term[0]) & 128) != 0:
val = val - (1 << (len(term) * 8))
return val
if six.PY3:
def varint_unpack(term):
val = int(''.join("%02x" % i for i in term), 16)
if (term[0] & 128) != 0:
val -= 1 << (len(term) * 8)
return val
else:
def varint_unpack(term): # noqa
val = int(term.encode('hex'), 16)
if (ord(term[0]) & 128) != 0:
val = val - (1 << (len(term) * 8))
return val
def bitlength(n):
@@ -56,16 +69,16 @@ def bitlength(n):
def varint_pack(big):
pos = True
if big == 0:
return '\x00'
return b'\x00'
if big < 0:
bytelength = bitlength(abs(big) - 1) / 8 + 1
bytelength = bitlength(abs(big) - 1) // 8 + 1
big = (1 << bytelength * 8) + big
pos = False
revbytes = []
revbytes = bytearray()
while big > 0:
revbytes.append(chr(big & 0xff))
revbytes.append(big & 0xff)
big >>= 8
if pos and ord(revbytes[-1]) & 0x80:
revbytes.append('\x00')
if pos and revbytes[-1] & 0x80:
revbytes.append(0)
revbytes.reverse()
return ''.join(revbytes)
return six.binary_type(revbytes)

View File

@@ -21,11 +21,12 @@ import logging
import re
from threading import RLock
import weakref
import six
murmur3 = None
try:
from murmur3 import murmur3
except ImportError:
from cassandra.murmur3 import murmur3
except ImportError as e:
pass
import cassandra.cqltypes as types
@@ -330,7 +331,7 @@ class Metadata(object):
token_to_host_owner = {}
ring = []
for host, token_strings in token_map.iteritems():
for host, token_strings in six.iteritems(token_map):
for token_string in token_strings:
token = token_class(token_string)
ring.append(token)
@@ -793,14 +794,18 @@ class TableMetadata(object):
return list(sorted(ret))
def protect_name(name):
if isinstance(name, unicode):
name = name.encode('utf8')
return maybe_escape_name(name)
if six.PY3:
def protect_name(name):
return maybe_escape_name(name)
else:
def protect_name(name):
if isinstance(name, six.text_type):
name = name.encode('utf8')
return maybe_escape_name(name)
def protect_names(names):
return map(protect_name, names)
return [protect_name(n) for n in names]
def protect_value(value):
@@ -1008,11 +1013,14 @@ class Token(object):
def __eq__(self, other):
return self.value == other.value
def __lt__(self, other):
return self.value < other.value
def __hash__(self):
return hash(self.value)
def __repr__(self):
return "<%s: %r>" % (self.__class__.__name__, self.value)
return "<%s: %s>" % (self.__class__.__name__, self.value)
__str__ = __repr__
MIN_LONG = -(2 ** 63)
@@ -1031,7 +1039,7 @@ class Murmur3Token(Token):
@classmethod
def hash_fn(cls, key):
if murmur3 is not None:
h = murmur3(key)
h = int(murmur3(key))
return h if h != MIN_LONG else MAX_LONG
else:
raise NoMurmur3()
@@ -1048,6 +1056,8 @@ class MD5Token(Token):
@classmethod
def hash_fn(cls, key):
if isinstance(key, six.text_type):
key = key.encode('UTF-8')
return abs(varint_unpack(md5(key).digest()))
def __init__(self, token):
@@ -1062,7 +1072,7 @@ class BytesToken(Token):
def __init__(self, token_string):
""" `token_string` should be string representing the token. """
if not isinstance(token_string, basestring):
if not isinstance(token_string, six.string_types):
raise TypeError(
"Tokens for ByteOrderedPartitioner should be strings (got %s)"
% (type(token_string),))

View File

@@ -169,6 +169,18 @@ uint64_t MurmurHash3_x64_128 (const void * key, const int len,
return h1;
}
struct module_state {
PyObject *error;
};
#if PY_MAJOR_VERSION >= 3
#define GETSTATE(m) ((struct module_state*)PyModule_GetState(m))
#else
#define GETSTATE(m) (&_state)
static struct module_state _state;
#endif
static PyObject *
murmur3(PyObject *self, PyObject *args)
{
@@ -186,23 +198,63 @@ murmur3(PyObject *self, PyObject *args)
}
static PyMethodDef murmur3_methods[] = {
{"murmur3", murmur3, METH_VARARGS,
"Make an x64 murmur3 64-bit hash value"},
{"murmur3", murmur3, METH_VARARGS, "Make an x64 murmur3 64-bit hash value"},
{NULL, NULL, 0, NULL}
};
#if PY_MAJOR_VERSION <= 2
#if PY_MAJOR_VERSION >= 3
PyMODINIT_FUNC
initmurmur3(void)
{
(void) Py_InitModule("murmur3", murmur3_methods);
static int murmur3_traverse(PyObject *m, visitproc visit, void *arg) {
Py_VISIT(GETSTATE(m)->error);
return 0;
}
static int murmur3_clear(PyObject *m) {
Py_CLEAR(GETSTATE(m)->error);
return 0;
}
static struct PyModuleDef moduledef = {
PyModuleDef_HEAD_INIT,
"murmur3",
NULL,
sizeof(struct module_state),
murmur3_methods,
NULL,
murmur3_traverse,
murmur3_clear,
NULL
};
#define INITERROR return NULL
PyObject *
PyInit_murmur3(void)
#else
#define INITERROR return
/* Python 3.x */
// TODO
void
initmurmur3(void)
#endif
{
#if PY_MAJOR_VERSION >= 3
PyObject *module = PyModule_Create(&moduledef);
#else
PyObject *module = Py_InitModule("murmur3", murmur3_methods);
#endif
if (module == NULL)
INITERROR;
struct module_state *st = GETSTATE(module);
st->error = PyErr_NewException("murmur3.Error", NULL, NULL);
if (st->error == NULL) {
Py_DECREF(module);
INITERROR;
}
#if PY_MAJOR_VERSION >= 3
return module;
#endif
}

View File

@@ -16,9 +16,12 @@ from itertools import islice, cycle, groupby, repeat
import logging
from random import randint
from threading import Lock
import six
from cassandra import ConsistencyLevel
from six.moves import range
log = logging.getLogger(__name__)
@@ -263,7 +266,7 @@ class DCAwareRoundRobinPolicy(LoadBalancingPolicy):
for host in islice(cycle(local_live), pos, pos + len(local_live)):
yield host
for dc, current_dc_hosts in self._dc_live_hosts.iteritems():
for dc, current_dc_hosts in six.iteritems(self._dc_live_hosts):
if dc == self.local_dc:
continue
@@ -529,7 +532,7 @@ class ExponentialReconnectionPolicy(ReconnectionPolicy):
self.max_delay = max_delay
def new_schedule(self):
return (min(self.base_delay * (2 ** i), self.max_delay) for i in xrange(64))
return (min(self.base_delay * (2 ** i), self.max_delay) for i in range(64))
class WriteType(object):

View File

@@ -150,8 +150,14 @@ class Host(object):
def __eq__(self, other):
return self.address == other.address
def __hash__(self):
return hash(self.address)
def __lt__(self, other):
return self.address < other.address
def __str__(self):
return self.address
return str(self.address)
def __repr__(self):
dc = (" %s" % (self._datacenter,)) if self._datacenter else ""
@@ -178,7 +184,7 @@ class _ReconnectionHandler(object):
log.debug("Reconnection handler was cancelled before starting")
return
first_delay = self.schedule.next()
first_delay = next(self.schedule)
self.scheduler.schedule(first_delay, self.run)
def run(self):
@@ -189,7 +195,7 @@ class _ReconnectionHandler(object):
try:
conn = self.try_reconnect()
except Exception as exc:
next_delay = self.schedule.next()
next_delay = next(self.schedule)
if self.on_exception(exc, next_delay):
self.scheduler.schedule(next_delay, self.run)
else:
@@ -260,8 +266,8 @@ class _HostReconnectionHandler(_ReconnectionHandler):
if isinstance(exc, AuthenticationFailed):
return False
else:
log.warn("Error attempting to reconnect to %s, scheduling retry in %s seconds: %s",
self.host, next_delay, exc)
log.warning("Error attempting to reconnect to %s, scheduling retry in %s seconds: %s",
self.host, next_delay, exc)
log.debug("Reconnection error details", exc_info=True)
return True
@@ -371,8 +377,8 @@ class HostConnectionPool(object):
def _create_new_connection(self):
try:
self._add_conn_if_under_max()
except (ConnectionException, socket.error), exc:
log.warn("Failed to create new connection to %s: %s", self.host, exc)
except (ConnectionException, socket.error) as exc:
log.warning("Failed to create new connection to %s: %s", self.host, exc)
except Exception:
log.exception("Unexpectedly failed to create new connection")
finally:
@@ -404,7 +410,7 @@ class HostConnectionPool(object):
self._signal_available_conn()
return True
except (ConnectionException, socket.error) as exc:
log.warn("Failed to add new connection to pool for host %s: %s", self.host, exc)
log.warning("Failed to add new connection to pool for host %s: %s", self.host, exc)
with self._lock:
self.open_count -= 1
if self._session.cluster.signal_connection_failure(self.host, exc, is_host_addition=False):

View File

@@ -23,6 +23,7 @@ from datetime import datetime, timedelta
import re
import struct
import time
import six
from cassandra import ConsistencyLevel, OperationTimedOut
from cassandra.cqltypes import unix_time_from_uuid1
@@ -113,8 +114,8 @@ class Statement(object):
def _set_routing_key(self, key):
if isinstance(key, (list, tuple)):
self._routing_key = "".join(struct.pack("HsB", len(component), component, 0)
for component in key)
self._routing_key = b"".join(struct.pack("HsB", len(component), component, 0)
for component in key)
else:
self._routing_key = key
@@ -408,7 +409,7 @@ class BoundStatement(Statement):
val = self.values[statement_index]
components.append(struct.pack("HsB", len(val), val, 0))
self._routing_key = "".join(components)
self._routing_key = b"".join(components)
return self._routing_key
@@ -473,7 +474,7 @@ class BatchStatement(Statement):
Statement.__init__(self, retry_policy=retry_policy, consistency_level=consistency_level)
def add(self, statement, parameters=None):
if isinstance(statement, basestring):
if isinstance(statement, six.string_types):
if parameters:
statement = bind_params(statement, parameters)
self._statements_and_parameters.append((False, statement, ()))
@@ -532,11 +533,9 @@ class ValueSequence(object):
def bind_params(query, params):
if isinstance(params, dict):
return query % dict((k, cql_encoders.get(type(v), cql_encode_object)(v))
for k, v in params.iteritems())
return query % dict((k, cql_encoders.get(type(v), cql_encode_object)(v)) for k, v in six.iteritems(params))
else:
return query % tuple(cql_encoders.get(type(v), cql_encode_object)(v)
for v in params)
return query % tuple(cql_encoders.get(type(v), cql_encode_object)(v) for v in params)
class TraceUnavailable(Exception):

View File

@@ -1,7 +1,5 @@
from __future__ import with_statement
from UserDict import DictMixin
try:
from collections import OrderedDict
except ImportError:
@@ -28,6 +26,7 @@ except ImportError:
# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
# OTHER DEALINGS IN THE SOFTWARE.
from UserDict import DictMixin
class OrderedDict(dict, DictMixin): # noqa
""" A dictionary which maintains the insertion order of keys. """
@@ -80,9 +79,9 @@ except ImportError:
if not self:
raise KeyError('dictionary is empty')
if last:
key = reversed(self).next()
key = next(reversed(self))
else:
key = iter(self).next()
key = next(iter(self))
value = self.pop(key)
return key, value

View File

@@ -1,3 +1,4 @@
blist
futures
scales
six >=1.6

View File

@@ -12,16 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import sys
import ez_setup
ez_setup.use_setuptools()
run_gevent_nosetests = False
if __name__ == '__main__' and sys.argv[1] == "gevent_nosetests":
from gevent.monkey import patch_all
patch_all()
run_gevent_nosetests = True
from setuptools import setup
from distutils.command.build_ext import build_ext
@@ -48,10 +47,11 @@ with open("README.rst") as f:
long_description = f.read()
gevent_nosetests = None
if run_gevent_nosetests:
try:
from nose.commands import nosetests
except ImportError:
gevent_nosetests = None
else:
class gevent_nosetests(nosetests):
description = "run nosetests with gevent monkey patching"
@@ -92,11 +92,11 @@ class DocCommand(Command):
except subprocess.CalledProcessError as exc:
raise RuntimeError("Documentation step '%s' failed: %s: %s" % (mode, exc, exc.output))
else:
print output
print(output)
print ""
print "Documentation step '%s' performed, results here:" % mode
print " %s/" % path
print("")
print("Documentation step '%s' performed, results here:" % mode)
print(" %s/" % path)
class BuildFailed(Exception):
@@ -181,7 +181,7 @@ def run_setup(extensions):
kw['cmdclass']['build_ext'] = build_extensions
kw['ext_modules'] = extensions
dependencies = ['futures', 'scales', 'blist']
dependencies = ['futures', 'scales >=1.0.5', 'blist', 'six >=1.6']
if platform.python_implementation() != "CPython":
dependencies.remove('blist')
@@ -196,7 +196,7 @@ def run_setup(extensions):
packages=['cassandra', 'cassandra.io'],
include_package_data=True,
install_requires=dependencies,
tests_require=['nose', 'mock', 'ccm', 'unittest2', 'PyYAML', 'pytz'],
tests_require=['nose', 'mock', 'PyYAML', 'pytz'],
classifiers=[
'Development Status :: 5 - Production/Stable',
'Intended Audience :: Developers',
@@ -207,6 +207,10 @@ def run_setup(extensions):
'Programming Language :: Python :: 2',
'Programming Language :: Python :: 2.6',
'Programming Language :: Python :: 2.7',
'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3.3',
'Programming Language :: Python :: Implementation :: CPython',
'Programming Language :: Python :: Implementation :: PyPy',
'Topic :: Software Development :: Libraries :: Python Modules'
],
**kw)

View File

@@ -22,7 +22,9 @@ except ImportError:
import logging
log = logging.getLogger(__name__)
import os
from six import print_
from threading import Event
from cassandra.cluster import Cluster
@@ -92,7 +94,7 @@ def get_node(node_id):
def setup_package():
print 'Using Cassandra version: %s' % CASSANDRA_VERSION
print_('Using Cassandra version: %s' % CASSANDRA_VERSION)
try:
try:
cluster = CCMCluster.load(path, CLUSTER_NAME)

View File

@@ -12,7 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import Queue
try:
from Queue import Queue, Empty
except ImportError:
from queue import Queue, Empty # noqa
from struct import pack
import unittest
@@ -31,13 +35,12 @@ def create_column_name(i):
column_name = ''
while True:
column_name += letters[i % 10]
i /= 10
i = i // 10
if not i:
break
if column_name == 'if':
column_name = 'special_case'
return column_name
@@ -56,15 +59,15 @@ class LargeDataTests(unittest.TestCase):
return session
def batch_futures(self, session, statement_generator):
concurrency = 50
futures = Queue.Queue(maxsize=concurrency)
concurrency = 10
futures = Queue(maxsize=concurrency)
for i, statement in enumerate(statement_generator):
if i > 0 and i % (concurrency - 1) == 0:
# clear the existing queue
while True:
try:
futures.get_nowait().result()
except Queue.Empty:
except Empty:
break
future = session.execute_async(statement)
@@ -73,7 +76,7 @@ class LargeDataTests(unittest.TestCase):
while True:
try:
futures.get_nowait().result()
except Queue.Empty:
except Empty:
break
def test_wide_rows(self):
@@ -81,15 +84,13 @@ class LargeDataTests(unittest.TestCase):
session = self.make_session_and_keyspace()
session.execute('CREATE TABLE %s (k INT, i INT, PRIMARY KEY(k, i))' % table)
prepared = session.prepare('INSERT INTO %s (k, i) VALUES (0, ?)' % (table, ))
# Write via async futures
self.batch_futures(
session,
(SimpleStatement('INSERT INTO %s (k, i) VALUES (0, %s)' % (table, i),
consistency_level=ConsistencyLevel.QUORUM)
for i in range(100000)))
self.batch_futures(session, (prepared.bind((i, )) for i in range(100000)))
# Read
results = session.execute('SELECT i FROM %s WHERE k=%s' % (table, 0))
results = session.execute('SELECT i FROM %s WHERE k=0' % (table, ))
# Verify
for i, row in enumerate(results):
@@ -120,18 +121,13 @@ class LargeDataTests(unittest.TestCase):
session = self.make_session_and_keyspace()
session.execute('CREATE TABLE %s (k INT, i INT, v BLOB, PRIMARY KEY(k, i))' % table)
# Build small ByteBuffer sample
bb = '0xCAFE'
prepared = session.prepare('INSERT INTO %s (k, i, v) VALUES (0, ?, 0xCAFE)' % (table, ))
# Write
self.batch_futures(
session,
(SimpleStatement('INSERT INTO %s (k, i, v) VALUES (0, %s, %s)' % (table, i, str(bb)),
consistency_level=ConsistencyLevel.QUORUM)
for i in range(100000)))
self.batch_futures(session, (prepared.bind((i, )) for i in range(100000)))
# Read
results = session.execute('SELECT i, v FROM %s WHERE k=%s' % (table, 0))
results = session.execute('SELECT i, v FROM %s WHERE k=0' % (table, ))
# Verify
bb = pack('>H', 0xCAFE)

View File

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import logging
import time
@@ -125,7 +126,7 @@ def bootstrap(node, data_center=None, token=None):
def ring(node):
print 'From node%s:' % node
print('From node%s:' % node)
get_node(node).nodetool('ring')
@@ -154,3 +155,5 @@ def wait_for_down(cluster, node, wait=True):
time.sleep(10)
log.debug("Done waiting for node %s to be down", node)
return
else:
log.debug("Host is still marked up, waiting")

View File

@@ -40,7 +40,7 @@ class ClusterTests(unittest.TestCase):
CREATE KEYSPACE clustertests
WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'}
""")
self.assertEquals(None, result)
self.assertEqual(None, result)
result = session.execute(
"""
@@ -51,16 +51,16 @@ class ClusterTests(unittest.TestCase):
PRIMARY KEY (a, b)
)
""")
self.assertEquals(None, result)
self.assertEqual(None, result)
result = session.execute(
"""
INSERT INTO clustertests.cf0 (a, b, c) VALUES ('a', 'b', 'c')
""")
self.assertEquals(None, result)
self.assertEqual(None, result)
result = session.execute("SELECT * FROM clustertests.cf0")
self.assertEquals([('a', 'b', 'c')], result)
self.assertEqual([('a', 'b', 'c')], result)
cluster.shutdown()
@@ -75,15 +75,15 @@ class ClusterTests(unittest.TestCase):
"""
INSERT INTO test3rf.test (k, v) VALUES (8889, 8889)
""")
self.assertEquals(None, result)
self.assertEqual(None, result)
result = session.execute("SELECT * FROM test3rf.test")
self.assertEquals([(8889, 8889)], result)
self.assertEqual([(8889, 8889)], result)
# test_connect_on_keyspace
session2 = cluster.connect('test3rf')
result2 = session2.execute("SELECT * FROM test")
self.assertEquals(result, result2)
self.assertEqual(result, result2)
def test_set_keyspace_twice(self):
cluster = Cluster(protocol_version=PROTOCOL_VERSION)

View File

@@ -43,7 +43,7 @@ class ClusterTests(unittest.TestCase):
statements = cycle(("INSERT INTO test3rf.test (k, v) VALUES (%s, %s)", ))
parameters = [(i, i) for i in range(num_statements)]
results = execute_concurrent(self.session, zip(statements, parameters))
results = execute_concurrent(self.session, list(zip(statements, parameters)))
self.assertEqual(num_statements, len(results))
self.assertEqual([(True, None)] * num_statements, results)
@@ -51,7 +51,7 @@ class ClusterTests(unittest.TestCase):
statements = cycle(("SELECT v FROM test3rf.test WHERE k=%s", ))
parameters = [(i, ) for i in range(num_statements)]
results = execute_concurrent(self.session, zip(statements, parameters))
results = execute_concurrent(self.session, list(zip(statements, parameters)))
self.assertEqual(num_statements, len(results))
self.assertEqual([(True, [(i,)]) for i in range(num_statements)], results)
@@ -81,7 +81,7 @@ class ClusterTests(unittest.TestCase):
self.assertRaises(
InvalidRequest,
execute_concurrent, self.session, zip(statements, parameters), raise_on_first_error=True)
execute_concurrent, self.session, list(zip(statements, parameters)), raise_on_first_error=True)
def test_first_failure_client_side(self):
statements = cycle(("INSERT INTO test3rf.test (k, v) VALUES (%s, %s)", ))
@@ -92,7 +92,7 @@ class ClusterTests(unittest.TestCase):
self.assertRaises(
TypeError,
execute_concurrent, self.session, zip(statements, parameters), raise_on_first_error=True)
execute_concurrent, self.session, list(zip(statements, parameters)), raise_on_first_error=True)
def test_no_raise_on_first_failure(self):
statements = cycle(("INSERT INTO test3rf.test (k, v) VALUES (%s, %s)", ))
@@ -101,7 +101,7 @@ class ClusterTests(unittest.TestCase):
# we'll get an error back from the server
parameters[57] = ('efefef', 'awefawefawef')
results = execute_concurrent(self.session, zip(statements, parameters), raise_on_first_error=False)
results = execute_concurrent(self.session, list(zip(statements, parameters)), raise_on_first_error=False)
for i, (success, result) in enumerate(results):
if i == 57:
self.assertFalse(success)
@@ -115,9 +115,9 @@ class ClusterTests(unittest.TestCase):
parameters = [(i, i) for i in range(100)]
# the driver will raise an error when binding the params
parameters[57] = i
parameters[57] = 1
results = execute_concurrent(self.session, zip(statements, parameters), raise_on_first_error=False)
results = execute_concurrent(self.session, list(zip(statements, parameters)), raise_on_first_error=False)
for i, (success, result) in enumerate(results):
if i == 57:
self.assertFalse(success)

View File

@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import six
try:
import unittest2 as unittest
@@ -115,7 +116,7 @@ class SchemaMetadataTest(unittest.TestCase):
def check_create_statement(self, tablemeta, original):
recreate = tablemeta.as_cql_query(formatted=False)
self.assertEquals(original, recreate[:len(original)])
self.assertEqual(original, recreate[:len(original)])
self.session.execute("DROP TABLE %s.%s" % (self.ksname, self.cfname))
self.session.execute(recreate)
@@ -289,7 +290,7 @@ class SchemaMetadataTest(unittest.TestCase):
tablemeta = self.get_table_metadata()
statements = tablemeta.export_as_string().strip()
statements = [s.strip() for s in statements.split(';')]
statements = filter(bool, statements)
statements = list(filter(bool, statements))
self.assertEqual(3, len(statements))
self.assertEqual(d_index, statements[1])
self.assertEqual(e_index, statements[2])
@@ -311,7 +312,7 @@ class TestCodeCoverage(unittest.TestCase):
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
cluster.connect()
self.assertIsInstance(cluster.metadata.export_schema_as_string(), basestring)
self.assertIsInstance(cluster.metadata.export_schema_as_string(), six.string_types)
def test_export_keyspace_schema(self):
"""
@@ -323,8 +324,8 @@ class TestCodeCoverage(unittest.TestCase):
for keyspace in cluster.metadata.keyspaces:
keyspace_metadata = cluster.metadata.keyspaces[keyspace]
self.assertIsInstance(keyspace_metadata.export_as_string(), basestring)
self.assertIsInstance(keyspace_metadata.as_cql_query(), basestring)
self.assertIsInstance(keyspace_metadata.export_as_string(), six.string_types)
self.assertIsInstance(keyspace_metadata.as_cql_query(), six.string_types)
def test_case_sensitivity(self):
"""

View File

@@ -69,7 +69,7 @@ class PreparedStatementTests(unittest.TestCase):
bound = prepared.bind(('a'))
results = session.execute(bound)
self.assertEquals(results, [('a', 'b', 'c')])
self.assertEqual(results, [('a', 'b', 'c')])
# test with new dict binding
prepared = session.prepare(
@@ -95,7 +95,7 @@ class PreparedStatementTests(unittest.TestCase):
bound = prepared.bind({'a': 'x'})
results = session.execute(bound)
self.assertEquals(results, [('x', 'y', 'z')])
self.assertEqual(results, [('x', 'y', 'z')])
def test_missing_primary_key(self):
"""
@@ -148,7 +148,7 @@ class PreparedStatementTests(unittest.TestCase):
""")
self.assertIsInstance(prepared, PreparedStatement)
self.assertRaises(ValueError, prepared.bind, (1,2))
self.assertRaises(ValueError, prepared.bind, (1, 2))
def test_too_many_bind_values_dicts(self):
"""
@@ -196,7 +196,7 @@ class PreparedStatementTests(unittest.TestCase):
bound = prepared.bind((1,))
results = session.execute(bound)
self.assertEquals(results[0].v, None)
self.assertEqual(results[0].v, None)
def test_none_values_dicts(self):
"""
@@ -206,7 +206,6 @@ class PreparedStatementTests(unittest.TestCase):
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect()
# test with new dict binding
prepared = session.prepare(
"""
@@ -225,7 +224,7 @@ class PreparedStatementTests(unittest.TestCase):
bound = prepared.bind({'k': 1})
results = session.execute(bound)
self.assertEquals(results[0].v, None)
self.assertEqual(results[0].v, None)
def test_async_binding(self):
"""
@@ -252,8 +251,7 @@ class PreparedStatementTests(unittest.TestCase):
future = session.execute_async(prepared, (873,))
results = future.result()
self.assertEquals(results[0].v, None)
self.assertEqual(results[0].v, None)
def test_async_binding_dicts(self):
"""
@@ -280,4 +278,4 @@ class PreparedStatementTests(unittest.TestCase):
future = session.execute_async(prepared, {'k': 873})
results = future.result()
self.assertEquals(results[0].v, None)
self.assertEqual(results[0].v, None)

View File

@@ -24,7 +24,7 @@ from cassandra.query import (PreparedStatement, BoundStatement, ValueSequence,
from cassandra.cluster import Cluster
from cassandra.policies import HostDistance
from tests.integration import get_server_versions, PROTOCOL_VERSION
from tests.integration import PROTOCOL_VERSION
class QueryTest(unittest.TestCase):
@@ -43,7 +43,7 @@ class QueryTest(unittest.TestCase):
self.assertIsInstance(bound, BoundStatement)
self.assertEqual(2, len(bound.values))
session.execute(bound)
self.assertEqual(bound.routing_key, '\x00\x00\x00\x01')
self.assertEqual(bound.routing_key, b'\x00\x00\x00\x01')
def test_value_sequence(self):
"""
@@ -102,7 +102,7 @@ class PreparedStatementTests(unittest.TestCase):
self.assertIsInstance(prepared, PreparedStatement)
bound = prepared.bind((1, None))
self.assertEqual(bound.routing_key, '\x00\x00\x00\x01')
self.assertEqual(bound.routing_key, b'\x00\x00\x00\x01')
def test_empty_routing_key_indexes(self):
"""
@@ -158,7 +158,7 @@ class PreparedStatementTests(unittest.TestCase):
self.assertIsInstance(prepared, PreparedStatement)
bound = prepared.bind((1, 2))
self.assertEqual(bound.routing_key, '\x04\x00\x00\x00\x04\x00\x00\x00')
self.assertEqual(bound.routing_key, b'\x04\x00\x00\x00\x04\x00\x00\x00')
def test_bound_keyspace(self):
"""
@@ -326,14 +326,14 @@ class SerialConsistencyTests(unittest.TestCase):
"UPDATE test3rf.test SET v=1 WHERE k=0 IF v=1",
serial_consistency_level=ConsistencyLevel.SERIAL)
result = self.session.execute(statement)
self.assertEquals(1, len(result))
self.assertEqual(1, len(result))
self.assertFalse(result[0].applied)
statement = SimpleStatement(
"UPDATE test3rf.test SET v=1 WHERE k=0 IF v=0",
serial_consistency_level=ConsistencyLevel.SERIAL)
result = self.session.execute(statement)
self.assertEquals(1, len(result))
self.assertEqual(1, len(result))
self.assertTrue(result[0].applied)
def test_conditional_update_with_prepared_statements(self):
@@ -343,7 +343,7 @@ class SerialConsistencyTests(unittest.TestCase):
statement.serial_consistency_level = ConsistencyLevel.SERIAL
result = self.session.execute(statement)
self.assertEquals(1, len(result))
self.assertEqual(1, len(result))
self.assertFalse(result[0].applied)
statement = self.session.prepare(
@@ -351,7 +351,7 @@ class SerialConsistencyTests(unittest.TestCase):
bound = statement.bind(())
bound.serial_consistency_level = ConsistencyLevel.SERIAL
result = self.session.execute(statement)
self.assertEquals(1, len(result))
self.assertEqual(1, len(result))
self.assertTrue(result[0].applied)
def test_bad_consistency_level(self):

View File

@@ -23,6 +23,7 @@ except ImportError:
import unittest # noqa
from itertools import cycle, count
from six.moves import range
from threading import Event
from cassandra.cluster import Cluster
@@ -47,7 +48,7 @@ class QueryPagingTests(unittest.TestCase):
def test_paging(self):
statements_and_params = zip(cycle(["INSERT INTO test3rf.test (k, v) VALUES (%s, 0)"]),
[(i, ) for i in range(100)])
execute_concurrent(self.session, statements_and_params)
execute_concurrent(self.session, list(statements_and_params))
prepared = self.session.prepare("SELECT * FROM test3rf.test")
@@ -153,7 +154,7 @@ class QueryPagingTests(unittest.TestCase):
def test_async_paging(self):
statements_and_params = zip(cycle(["INSERT INTO test3rf.test (k, v) VALUES (%s, 0)"]),
[(i, ) for i in range(100)])
execute_concurrent(self.session, statements_and_params)
execute_concurrent(self.session, list(statements_and_params))
prepared = self.session.prepare("SELECT * FROM test3rf.test")
@@ -219,7 +220,7 @@ class QueryPagingTests(unittest.TestCase):
def test_paging_callbacks(self):
statements_and_params = zip(cycle(["INSERT INTO test3rf.test (k, v) VALUES (%s, 0)"]),
[(i, ) for i in range(100)])
execute_concurrent(self.session, statements_and_params)
execute_concurrent(self.session, list(statements_and_params))
prepared = self.session.prepare("SELECT * FROM test3rf.test")
@@ -232,7 +233,7 @@ class QueryPagingTests(unittest.TestCase):
def handle_page(rows, future, counter):
for row in rows:
counter.next()
next(counter)
if future.has_more_pages:
future.start_fetching_next_page()
@@ -245,7 +246,7 @@ class QueryPagingTests(unittest.TestCase):
future.add_callbacks(callback=handle_page, callback_args=(future, counter), errback=handle_error)
event.wait()
self.assertEquals(counter.next(), 100)
self.assertEquals(next(counter), 100)
# simple statement
future = self.session.execute_async(SimpleStatement("SELECT * FROM test3rf.test"))
@@ -254,7 +255,7 @@ class QueryPagingTests(unittest.TestCase):
future.add_callbacks(callback=handle_page, callback_args=(future, counter), errback=handle_error)
event.wait()
self.assertEquals(counter.next(), 100)
self.assertEquals(next(counter), 100)
# prepared statement
future = self.session.execute_async(prepared)
@@ -263,4 +264,4 @@ class QueryPagingTests(unittest.TestCase):
future.add_callbacks(callback=handle_page, callback_args=(future, counter), errback=handle_error)
event.wait()
self.assertEquals(counter.next(), 100)
self.assertEquals(next(counter), 100)

View File

@@ -17,8 +17,12 @@ try:
except ImportError:
import unittest # noqa
import logging
log = logging.getLogger(__name__)
from decimal import Decimal
from datetime import datetime
import six
from uuid import uuid1, uuid4
try:
@@ -59,12 +63,16 @@ class TypeTests(unittest.TestCase):
params = [
'key1',
'blobyblob'.encode('hex')
b'blobyblob'
]
query = 'INSERT INTO mytable (a, b) VALUES (%s, %s)'
if self._cql_version >= (3, 1, 0):
# In python 3, the 'bytes' type is treated as a blob, so we can
# correctly encode it with hex notation.
# In python2, we don't treat the 'str' type as a blob, so we'll encode it
# as a string literal and have the following failure.
if six.PY2 and self._cql_version >= (3, 1, 0):
# Blob values can't be specified using string notation in CQL 3.1.0 and
# above which is used by default in Cassandra 2.0.
msg = r'.*Invalid STRING constant \(.*?\) for b of type blob.*'
@@ -74,13 +82,13 @@ class TypeTests(unittest.TestCase):
s.execute(query, params)
expected_vals = [
'key1',
'blobyblob'
bytearray(b'blobyblob')
]
results = s.execute("SELECT * FROM mytable")
for expected, actual in zip(expected_vals, results[0]):
self.assertEquals(expected, actual)
self.assertEqual(expected, actual)
def test_blob_type_as_bytearray(self):
c = Cluster(protocol_version=PROTOCOL_VERSION)
@@ -101,7 +109,7 @@ class TypeTests(unittest.TestCase):
params = [
'key1',
bytearray('blob1', 'hex')
bytearray(b'blob1')
]
query = 'INSERT INTO mytable (a, b) VALUES (%s, %s);'
@@ -109,13 +117,13 @@ class TypeTests(unittest.TestCase):
expected_vals = [
'key1',
bytearray('blob1', 'hex')
bytearray(b'blob1')
]
results = s.execute("SELECT * FROM mytable")
for expected, actual in zip(expected_vals, results[0]):
self.assertEquals(expected, actual)
self.assertEqual(expected, actual)
create_type_table = """
CREATE TABLE mytable (
@@ -208,7 +216,7 @@ class TypeTests(unittest.TestCase):
results = s.execute("SELECT * FROM mytable")
for expected, actual in zip(expected_vals, results[0]):
self.assertEquals(expected, actual)
self.assertEqual(expected, actual)
# try the same thing with a prepared statement
prepared = s.prepare("""
@@ -221,7 +229,7 @@ class TypeTests(unittest.TestCase):
results = s.execute("SELECT * FROM mytable")
for expected, actual in zip(expected_vals, results[0]):
self.assertEquals(expected, actual)
self.assertEqual(expected, actual)
# query with prepared statement
prepared = s.prepare("""
@@ -230,14 +238,14 @@ class TypeTests(unittest.TestCase):
results = s.execute(prepared.bind(()))
for expected, actual in zip(expected_vals, results[0]):
self.assertEquals(expected, actual)
self.assertEqual(expected, actual)
# query with prepared statement, no explicit columns
prepared = s.prepare("""SELECT * FROM mytable""")
results = s.execute(prepared.bind(()))
for expected, actual in zip(expected_vals, results[0]):
self.assertEquals(expected, actual)
self.assertEqual(expected, actual)
def test_empty_strings_and_nones(self):
c = Cluster(protocol_version=PROTOCOL_VERSION)
@@ -265,11 +273,11 @@ class TypeTests(unittest.TestCase):
# insert empty strings for string-like fields and fetch them
s.execute("INSERT INTO mytable (a, b, c, o, s, l, n) VALUES ('a', 'b', %s, %s, %s, %s, %s)",
('', '', '', [''], {'': 3}))
self.assertEquals(
self.assertEqual(
{'c': '', 'o': '', 's': '', 'l': ('', ), 'n': OrderedDict({'': 3})},
s.execute("SELECT c, o, s, l, n FROM mytable WHERE a='a' AND b='b'")[0])
self.assertEquals(
self.assertEqual(
{'c': '', 'o': '', 's': '', 'l': ('', ), 'n': OrderedDict({'': 3})},
s.execute(s.prepare("SELECT c, o, s, l, n FROM mytable WHERE a='a' AND b='b'"), [])[0])
@@ -363,7 +371,7 @@ class TypeTests(unittest.TestCase):
""" Ensure timezone-aware datetimes are converted to timestamps correctly """
try:
import pytz
except ImportError, exc:
except ImportError as exc:
raise unittest.SkipTest('pytz is not available: %r' % (exc,))
dt = datetime(1997, 8, 29, 11, 14)
@@ -381,10 +389,10 @@ class TypeTests(unittest.TestCase):
# test non-prepared statement
s.execute("INSERT INTO mytable (a, b) VALUES ('key1', %s)", parameters=(dt,))
result = s.execute("SELECT b FROM mytable WHERE a='key1'")[0].b
self.assertEquals(dt.utctimetuple(), result.utctimetuple())
self.assertEqual(dt.utctimetuple(), result.utctimetuple())
# test prepared statement
prepared = s.prepare("INSERT INTO mytable (a, b) VALUES ('key2', ?)")
s.execute(prepared, parameters=(dt,))
result = s.execute("SELECT b FROM mytable WHERE a='key2'")[0].b
self.assertEquals(dt.utctimetuple(), result.utctimetuple())
self.assertEqual(dt.utctimetuple(), result.utctimetuple())

View File

@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import six
try:
import unittest2 as unittest
@@ -19,7 +20,9 @@ except ImportError:
import errno
import os
from StringIO import StringIO
from six import BytesIO
import socket
from socket import error as socket_error
@@ -55,7 +58,7 @@ class AsyncoreConnectionTest(unittest.TestCase):
return c
def make_header_prefix(self, message_class, version=2, stream_id=0):
return ''.join(map(uint8_pack, [
return six.binary_type().join(map(uint8_pack, [
0xff & (HEADER_DIRECTION_TO_CLIENT | version),
0, # flags (compression)
stream_id,
@@ -63,7 +66,7 @@ class AsyncoreConnectionTest(unittest.TestCase):
]))
def make_options_body(self):
options_buf = StringIO()
options_buf = BytesIO()
write_stringmultimap(options_buf, {
'CQL_VERSION': ['3.0.1'],
'COMPRESSION': []
@@ -71,12 +74,12 @@ class AsyncoreConnectionTest(unittest.TestCase):
return options_buf.getvalue()
def make_error_body(self, code, msg):
buf = StringIO()
buf = BytesIO()
write_int(buf, code)
write_string(buf, msg)
return buf.getvalue()
def make_msg(self, header, body=""):
def make_msg(self, header, body=six.binary_type()):
return header + uint32_pack(len(body)) + body
def test_successful_connection(self, *args):
@@ -105,12 +108,12 @@ class AsyncoreConnectionTest(unittest.TestCase):
# get a connection that's already fully started
c = self.test_successful_connection()
header = '\x00\x00\x00\x00' + int32_pack(20000)
header = six.b('\x00\x00\x00\x00') + int32_pack(20000)
responses = [
header + ('a' * (4096 - len(header))),
'a' * 4096,
header + (six.b('a') * (4096 - len(header))),
six.b('a') * 4096,
socket_error(errno.EAGAIN),
'a' * 100,
six.b('a') * 100,
socket_error(errno.EAGAIN)]
def side_effect(*args):
@@ -122,17 +125,17 @@ class AsyncoreConnectionTest(unittest.TestCase):
c.socket.recv.side_effect = side_effect
c.handle_read()
self.assertEquals(c._total_reqd_bytes, 20000 + len(header))
self.assertEqual(c._total_reqd_bytes, 20000 + len(header))
# the EAGAIN prevents it from reading the last 100 bytes
c._iobuf.seek(0, os.SEEK_END)
pos = c._iobuf.tell()
self.assertEquals(pos, 4096 + 4096)
self.assertEqual(pos, 4096 + 4096)
# now tell it to read the last 100 bytes
c.handle_read()
c._iobuf.seek(0, os.SEEK_END)
pos = c._iobuf.tell()
self.assertEquals(pos, 4096 + 4096 + 100)
self.assertEqual(pos, 4096 + 4096 + 100)
def test_protocol_error(self, *args):
c = self.make_connection()
@@ -237,14 +240,13 @@ class AsyncoreConnectionTest(unittest.TestCase):
options = self.make_options_body()
message = self.make_msg(header, options)
# read in the first byte
c.socket.recv.return_value = message[0]
c.socket.recv.return_value = message[0:1]
c.handle_read()
self.assertEquals(c._iobuf.getvalue(), message[0])
self.assertEqual(c._iobuf.getvalue(), message[0:1])
c.socket.recv.return_value = message[1:]
c.handle_read()
self.assertEquals("", c._iobuf.getvalue())
self.assertEqual(six.binary_type(), c._iobuf.getvalue())
# let it write out a StartupMessage
c.handle_write()
@@ -266,12 +268,12 @@ class AsyncoreConnectionTest(unittest.TestCase):
# read in the first nine bytes
c.socket.recv.return_value = message[:9]
c.handle_read()
self.assertEquals(c._iobuf.getvalue(), message[:9])
self.assertEqual(c._iobuf.getvalue(), message[:9])
# ... then read in the rest
c.socket.recv.return_value = message[9:]
c.handle_read()
self.assertEquals("", c._iobuf.getvalue())
self.assertEqual(six.binary_type(), c._iobuf.getvalue())
# let it write out a StartupMessage
c.handle_write()

View File

@@ -19,7 +19,9 @@ except ImportError:
import errno
import os
from StringIO import StringIO
from six.moves import StringIO
from socket import error as socket_error
from mock import patch, Mock
@@ -122,17 +124,17 @@ class LibevConnectionTest(unittest.TestCase):
c._socket.recv.side_effect = side_effect
c.handle_read(None, 0)
self.assertEquals(c._total_reqd_bytes, 20000 + len(header))
self.assertEqual(c._total_reqd_bytes, 20000 + len(header))
# the EAGAIN prevents it from reading the last 100 bytes
c._iobuf.seek(0, os.SEEK_END)
pos = c._iobuf.tell()
self.assertEquals(pos, 4096 + 4096)
self.assertEqual(pos, 4096 + 4096)
# now tell it to read the last 100 bytes
c.handle_read(None, 0)
c._iobuf.seek(0, os.SEEK_END)
pos = c._iobuf.tell()
self.assertEquals(pos, 4096 + 4096 + 100)
self.assertEqual(pos, 4096 + 4096 + 100)
def test_protocol_error(self, *args):
c = self.make_connection()
@@ -240,11 +242,11 @@ class LibevConnectionTest(unittest.TestCase):
# read in the first byte
c._socket.recv.return_value = message[0]
c.handle_read(None, 0)
self.assertEquals(c._iobuf.getvalue(), message[0])
self.assertEqual(c._iobuf.getvalue(), message[0])
c._socket.recv.return_value = message[1:]
c.handle_read(None, 0)
self.assertEquals("", c._iobuf.getvalue())
self.assertEqual("", c._iobuf.getvalue())
# let it write out a StartupMessage
c.handle_write(None, 0)
@@ -266,12 +268,12 @@ class LibevConnectionTest(unittest.TestCase):
# read in the first nine bytes
c._socket.recv.return_value = message[:9]
c.handle_read(None, 0)
self.assertEquals(c._iobuf.getvalue(), message[:9])
self.assertEqual(c._iobuf.getvalue(), message[:9])
# ... then read in the rest
c._socket.recv.return_value = message[9:]
c.handle_read(None, 0)
self.assertEquals("", c._iobuf.getvalue())
self.assertEqual("", c._iobuf.getvalue())
# let it write out a StartupMessage
c.handle_write(None, 0)

View File

@@ -11,17 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from cassandra.cluster import Cluster
import six
try:
import unittest2 as unittest
except ImportError:
import unittest # noqa
from StringIO import StringIO
from six import BytesIO
from mock import Mock, ANY
from cassandra.cluster import Cluster
from cassandra.connection import (Connection, HEADER_DIRECTION_TO_CLIENT,
HEADER_DIRECTION_FROM_CLIENT, ProtocolError)
from cassandra.decoder import (write_stringmultimap, write_int, write_string,
@@ -40,7 +41,7 @@ class ConnectionTest(unittest.TestCase):
return c
def make_header_prefix(self, message_class, version=2, stream_id=0):
return ''.join(map(uint8_pack, [
return six.binary_type().join(map(uint8_pack, [
0xff & (HEADER_DIRECTION_TO_CLIENT | version),
0, # flags (compression)
stream_id,
@@ -48,7 +49,7 @@ class ConnectionTest(unittest.TestCase):
]))
def make_options_body(self):
options_buf = StringIO()
options_buf = BytesIO()
write_stringmultimap(options_buf, {
'CQL_VERSION': ['3.0.1'],
'COMPRESSION': []
@@ -56,7 +57,7 @@ class ConnectionTest(unittest.TestCase):
return options_buf.getvalue()
def make_error_body(self, code, msg):
buf = StringIO()
buf = BytesIO()
write_int(buf, code)
write_string(buf, msg)
return buf.getvalue()
@@ -88,12 +89,12 @@ class ConnectionTest(unittest.TestCase):
c.defunct = Mock()
# read in a SupportedMessage response
header = ''.join(map(uint8_pack, [
header = six.binary_type().join(uint8_pack(i) for i in (
0xff & (HEADER_DIRECTION_FROM_CLIENT | self.protocol_version),
0, # flags (compression)
0,
SupportedMessage.opcode # opcode
]))
))
options = self.make_options_body()
message = self.make_msg(header, options)
c.process_msg(message, len(message) - 8)
@@ -130,7 +131,7 @@ class ConnectionTest(unittest.TestCase):
# read in a SupportedMessage response
header = self.make_header_prefix(SupportedMessage)
options_buf = StringIO()
options_buf = BytesIO()
write_stringmultimap(options_buf, {
'CQL_VERSION': ['7.8.9'],
'COMPRESSION': []

View File

@@ -16,7 +16,7 @@ from cassandra.marshal import bitlength
try:
import unittest2 as unittest
except ImportError:
import unittest # noqa
import unittest # noqa
import platform
from datetime import datetime
@@ -33,57 +33,57 @@ from cassandra.util import OrderedDict
marshalled_value_pairs = (
# binary form, type, python native type
('lorem ipsum dolor sit amet', 'AsciiType', 'lorem ipsum dolor sit amet'),
('', 'AsciiType', ''),
('\x01', 'BooleanType', True),
('\x00', 'BooleanType', False),
('', 'BooleanType', None),
('\xff\xfe\xfd\xfc\xfb', 'BytesType', '\xff\xfe\xfd\xfc\xfb'),
('', 'BytesType', ''),
('\x7f\xff\xff\xff\xff\xff\xff\xff', 'CounterColumnType', 9223372036854775807),
('\x80\x00\x00\x00\x00\x00\x00\x00', 'CounterColumnType', -9223372036854775808),
('', 'CounterColumnType', None),
('\x00\x00\x013\x7fb\xeey', 'DateType', datetime(2011, 11, 7, 18, 55, 49, 881000)),
('', 'DateType', None),
('\x00\x00\x00\r\nJ\x04"^\x91\x04\x8a\xb1\x18\xfe', 'DecimalType', Decimal('1243878957943.1234124191998')),
('\x00\x00\x00\x06\xe5\xde]\x98Y', 'DecimalType', Decimal('-112233.441191')),
('\x00\x00\x00\x14\x00\xfa\xce', 'DecimalType', Decimal('0.00000000000000064206')),
('\x00\x00\x00\x14\xff\x052', 'DecimalType', Decimal('-0.00000000000000064206')),
('\xff\xff\xff\x9c\x00\xfa\xce', 'DecimalType', Decimal('64206e100')),
('', 'DecimalType', None),
('@\xd2\xfa\x08\x00\x00\x00\x00', 'DoubleType', 19432.125),
('\xc0\xd2\xfa\x08\x00\x00\x00\x00', 'DoubleType', -19432.125),
('\x7f\xef\x00\x00\x00\x00\x00\x00', 'DoubleType', 1.7415152243978685e+308),
('', 'DoubleType', None),
('F\x97\xd0@', 'FloatType', 19432.125),
('\xc6\x97\xd0@', 'FloatType', -19432.125),
('\xc6\x97\xd0@', 'FloatType', -19432.125),
('\x7f\x7f\x00\x00', 'FloatType', 338953138925153547590470800371487866880.0),
('', 'FloatType', None),
('\x7f\x50\x00\x00', 'Int32Type', 2135949312),
('\xff\xfd\xcb\x91', 'Int32Type', -144495),
('', 'Int32Type', None),
('f\x1e\xfd\xf2\xe3\xb1\x9f|\x04_\x15', 'IntegerType', 123456789123456789123456789),
('', 'IntegerType', None),
('\x7f\xff\xff\xff\xff\xff\xff\xff', 'LongType', 9223372036854775807),
('\x80\x00\x00\x00\x00\x00\x00\x00', 'LongType', -9223372036854775808),
('', 'LongType', None),
('', 'InetAddressType', None),
('A46\xa9', 'InetAddressType', '65.52.54.169'),
('*\x00\x13(\xe1\x02\xcc\xc0\x00\x00\x00\x00\x00\x00\x01"', 'InetAddressType', '2a00:1328:e102:ccc0::122'),
('\xe3\x81\xbe\xe3\x81\x97\xe3\x81\xa6', 'UTF8Type', u'\u307e\u3057\u3066'),
('\xe3\x81\xbe\xe3\x81\x97\xe3\x81\xa6' * 1000, 'UTF8Type', u'\u307e\u3057\u3066' * 1000),
('', 'UTF8Type', u''),
('\xff' * 16, 'UUIDType', UUID('ffffffff-ffff-ffff-ffff-ffffffffffff')),
('I\x15~\xfc\xef<\x9d\xe3\x16\x98\xaf\x80\x1f\xb4\x0b*', 'UUIDType', UUID('49157efc-ef3c-9de3-1698-af801fb40b2a')),
('', 'UUIDType', None),
('', 'MapType(AsciiType, BooleanType)', None),
('', 'ListType(FloatType)', None),
('', 'SetType(LongType)', None),
('\x00\x00', 'MapType(DecimalType, BooleanType)', OrderedDict()),
('\x00\x00', 'ListType(FloatType)', ()),
('\x00\x00', 'SetType(IntegerType)', sortedset()),
('\x00\x01\x00\x10\xafYC\xa3\xea<\x11\xe1\xabc\xc4,\x03"y\xf0', 'ListType(TimeUUIDType)', (UUID(bytes='\xafYC\xa3\xea<\x11\xe1\xabc\xc4,\x03"y\xf0'),)),
(b'lorem ipsum dolor sit amet', 'AsciiType', 'lorem ipsum dolor sit amet'),
(b'', 'AsciiType', ''),
(b'\x01', 'BooleanType', True),
(b'\x00', 'BooleanType', False),
(b'', 'BooleanType', None),
(b'\xff\xfe\xfd\xfc\xfb', 'BytesType', b'\xff\xfe\xfd\xfc\xfb'),
(b'', 'BytesType', b''),
(b'\x7f\xff\xff\xff\xff\xff\xff\xff', 'CounterColumnType', 9223372036854775807),
(b'\x80\x00\x00\x00\x00\x00\x00\x00', 'CounterColumnType', -9223372036854775808),
(b'', 'CounterColumnType', None),
(b'\x00\x00\x013\x7fb\xeey', 'DateType', datetime(2011, 11, 7, 18, 55, 49, 881000)),
(b'', 'DateType', None),
(b'\x00\x00\x00\r\nJ\x04"^\x91\x04\x8a\xb1\x18\xfe', 'DecimalType', Decimal('1243878957943.1234124191998')),
(b'\x00\x00\x00\x06\xe5\xde]\x98Y', 'DecimalType', Decimal('-112233.441191')),
(b'\x00\x00\x00\x14\x00\xfa\xce', 'DecimalType', Decimal('0.00000000000000064206')),
(b'\x00\x00\x00\x14\xff\x052', 'DecimalType', Decimal('-0.00000000000000064206')),
(b'\xff\xff\xff\x9c\x00\xfa\xce', 'DecimalType', Decimal('64206e100')),
(b'', 'DecimalType', None),
(b'@\xd2\xfa\x08\x00\x00\x00\x00', 'DoubleType', 19432.125),
(b'\xc0\xd2\xfa\x08\x00\x00\x00\x00', 'DoubleType', -19432.125),
(b'\x7f\xef\x00\x00\x00\x00\x00\x00', 'DoubleType', 1.7415152243978685e+308),
(b'', 'DoubleType', None),
(b'F\x97\xd0@', 'FloatType', 19432.125),
(b'\xc6\x97\xd0@', 'FloatType', -19432.125),
(b'\xc6\x97\xd0@', 'FloatType', -19432.125),
(b'\x7f\x7f\x00\x00', 'FloatType', 338953138925153547590470800371487866880.0),
(b'', 'FloatType', None),
(b'\x7f\x50\x00\x00', 'Int32Type', 2135949312),
(b'\xff\xfd\xcb\x91', 'Int32Type', -144495),
(b'', 'Int32Type', None),
(b'f\x1e\xfd\xf2\xe3\xb1\x9f|\x04_\x15', 'IntegerType', 123456789123456789123456789),
(b'', 'IntegerType', None),
(b'\x7f\xff\xff\xff\xff\xff\xff\xff', 'LongType', 9223372036854775807),
(b'\x80\x00\x00\x00\x00\x00\x00\x00', 'LongType', -9223372036854775808),
(b'', 'LongType', None),
(b'', 'InetAddressType', None),
(b'A46\xa9', 'InetAddressType', '65.52.54.169'),
(b'*\x00\x13(\xe1\x02\xcc\xc0\x00\x00\x00\x00\x00\x00\x01"', 'InetAddressType', '2a00:1328:e102:ccc0::122'),
(b'\xe3\x81\xbe\xe3\x81\x97\xe3\x81\xa6', 'UTF8Type', u'\u307e\u3057\u3066'),
(b'\xe3\x81\xbe\xe3\x81\x97\xe3\x81\xa6' * 1000, 'UTF8Type', u'\u307e\u3057\u3066' * 1000),
(b'', 'UTF8Type', u''),
(b'\xff' * 16, 'UUIDType', UUID('ffffffff-ffff-ffff-ffff-ffffffffffff')),
(b'I\x15~\xfc\xef<\x9d\xe3\x16\x98\xaf\x80\x1f\xb4\x0b*', 'UUIDType', UUID('49157efc-ef3c-9de3-1698-af801fb40b2a')),
(b'', 'UUIDType', None),
(b'', 'MapType(AsciiType, BooleanType)', None),
(b'', 'ListType(FloatType)', None),
(b'', 'SetType(LongType)', None),
(b'\x00\x00', 'MapType(DecimalType, BooleanType)', OrderedDict()),
(b'\x00\x00', 'ListType(FloatType)', ()),
(b'\x00\x00', 'SetType(IntegerType)', sortedset()),
(b'\x00\x01\x00\x10\xafYC\xa3\xea<\x11\xe1\xabc\xc4,\x03"y\xf0', 'ListType(TimeUUIDType)', (UUID(bytes=b'\xafYC\xa3\xea<\x11\xe1\xabc\xc4,\x03"y\xf0'),)),
)
ordered_dict_value = OrderedDict()
@@ -94,9 +94,9 @@ ordered_dict_value[u'\\'] = 0
# these following entries work for me right now, but they're dependent on
# vagaries of internal python ordering for unordered types
marshalled_value_pairs_unsafe = (
('\x00\x03\x00\x06\xe3\x81\xbfbob\x00\x04\x00\x00\x00\xc7\x00\x00\x00\x04\xff\xff\xff\xff\x00\x01\\\x00\x04\x00\x00\x00\x00', 'MapType(UTF8Type, Int32Type)', ordered_dict_value),
('\x00\x02\x00\x08@\x01\x99\x99\x99\x99\x99\x9a\x00\x08@\x14\x00\x00\x00\x00\x00\x00', 'SetType(DoubleType)', sortedset([2.2, 5.0])),
('\x00', 'IntegerType', 0),
(b'\x00\x03\x00\x06\xe3\x81\xbfbob\x00\x04\x00\x00\x00\xc7\x00\x00\x00\x04\xff\xff\xff\xff\x00\x01\\\x00\x04\x00\x00\x00\x00', 'MapType(UTF8Type, Int32Type)', ordered_dict_value),
(b'\x00\x02\x00\x08@\x01\x99\x99\x99\x99\x99\x9a\x00\x08@\x14\x00\x00\x00\x00\x00\x00', 'SetType(DoubleType)', sortedset([2.2, 5.0])),
(b'\x00', 'IntegerType', 0),
)
if platform.python_implementation() == 'CPython':

View File

@@ -29,6 +29,12 @@ from cassandra.pool import Host
class TestStrategies(unittest.TestCase):
@classmethod
def setUpClass(cls):
"Hook method for setting up class fixture before running tests in the class."
if not hasattr(cls, 'assertItemsEqual'):
cls.assertItemsEqual = cls.assertCountEqual
def test_replication_strategy(self):
"""
Basic code coverage testing that ensures different ReplicationStrategies
@@ -217,22 +223,22 @@ class TestTokens(unittest.TestCase):
murmur3_token = Murmur3Token(cassandra.metadata.MIN_LONG - 1)
self.assertEqual(murmur3_token.hash_fn('123'), -7468325962851647638)
self.assertEqual(murmur3_token.hash_fn(str(cassandra.metadata.MAX_LONG)), 7162290910810015547)
self.assertEqual(str(murmur3_token), '<Murmur3Token: -9223372036854775809L>')
self.assertEqual(str(murmur3_token), '<Murmur3Token: -9223372036854775809>')
except NoMurmur3:
raise unittest.SkipTest('The murmur3 extension is not available')
def test_md5_tokens(self):
md5_token = MD5Token(cassandra.metadata.MIN_LONG - 1)
self.assertEqual(md5_token.hash_fn('123'), 42767516990368493138776584305024125808L)
self.assertEqual(md5_token.hash_fn(str(cassandra.metadata.MAX_LONG)), 28528976619278518853815276204542453639L)
self.assertEqual(str(md5_token), '<MD5Token: -9223372036854775809L>')
self.assertEqual(md5_token.hash_fn('123'), 42767516990368493138776584305024125808)
self.assertEqual(md5_token.hash_fn(str(cassandra.metadata.MAX_LONG)), 28528976619278518853815276204542453639)
self.assertEqual(str(md5_token), '<MD5Token: %s>' % -9223372036854775809)
def test_bytes_tokens(self):
bytes_token = BytesToken(str(cassandra.metadata.MIN_LONG - 1))
self.assertEqual(bytes_token.hash_fn('123'), '123')
self.assertEqual(bytes_token.hash_fn(123), 123)
self.assertEqual(bytes_token.hash_fn(str(cassandra.metadata.MAX_LONG)), str(cassandra.metadata.MAX_LONG))
self.assertEqual(str(bytes_token), "<BytesToken: '-9223372036854775809'>")
self.assertEqual(str(bytes_token), "<BytesToken: -9223372036854775809>")
try:
bytes_token = BytesToken(cassandra.metadata.MIN_LONG - 1)

View File

@@ -22,32 +22,34 @@ from cassandra.query import PreparedStatement, BoundStatement
from cassandra.cqltypes import Int32Type
from cassandra.util import OrderedDict
from six.moves import xrange
class ParamBindingTest(unittest.TestCase):
def test_bind_sequence(self):
result = bind_params("%s %s %s", (1, "a", 2.0))
self.assertEquals(result, "1 'a' 2.0")
self.assertEqual(result, "1 'a' 2.0")
def test_bind_map(self):
result = bind_params("%(a)s %(b)s %(c)s", dict(a=1, b="a", c=2.0))
self.assertEquals(result, "1 'a' 2.0")
self.assertEqual(result, "1 'a' 2.0")
def test_sequence_param(self):
result = bind_params("%s", (ValueSequence((1, "a", 2.0)),))
self.assertEquals(result, "( 1 , 'a' , 2.0 )")
self.assertEqual(result, "( 1 , 'a' , 2.0 )")
def test_generator_param(self):
result = bind_params("%s", ((i for i in xrange(3)),))
self.assertEquals(result, "[ 0 , 1 , 2 ]")
self.assertEqual(result, "[ 0 , 1 , 2 ]")
def test_none_param(self):
result = bind_params("%s", (None,))
self.assertEquals(result, "NULL")
self.assertEqual(result, "NULL")
def test_list_collection(self):
result = bind_params("%s", (['a', 'b', 'c'],))
self.assertEquals(result, "[ 'a' , 'b' , 'c' ]")
self.assertEqual(result, "[ 'a' , 'b' , 'c' ]")
def test_set_collection(self):
result = bind_params("%s", (set(['a', 'b']),))
@@ -59,11 +61,11 @@ class ParamBindingTest(unittest.TestCase):
vals['b'] = 'b'
vals['c'] = 'c'
result = bind_params("%s", (vals,))
self.assertEquals(result, "{ 'a' : 'a' , 'b' : 'b' , 'c' : 'c' }")
self.assertEqual(result, "{ 'a' : 'a' , 'b' : 'b' , 'c' : 'c' }")
def test_quote_escaping(self):
result = bind_params("%s", ("""'ef''ef"ef""ef'""",))
self.assertEquals(result, """'''ef''''ef"ef""ef'''""")
self.assertEqual(result, """'''ef''''ef"ef""ef'''""")
class BoundStatementTestCase(unittest.TestCase):

View File

@@ -20,6 +20,7 @@ except ImportError:
from itertools import islice, cycle
from mock import Mock
from random import randint
import six
import sys
import struct
from threading import Thread
@@ -36,6 +37,8 @@ from cassandra.policies import (RoundRobinPolicy, DCAwareRoundRobinPolicy,
from cassandra.pool import Host
from cassandra.query import Statement
from six.moves import xrange
class TestLoadBalancingPolicy(unittest.TestCase):
def test_non_implemented(self):
@@ -137,13 +140,23 @@ class TestRoundRobinPolicy(unittest.TestCase):
# make the GIL switch after every instruction, maximizing
# the chace of race conditions
original_interval = sys.getcheckinterval()
if six.PY2:
original_interval = sys.getcheckinterval()
else:
original_interval = sys.getswitchinterval()
try:
sys.setcheckinterval(0)
if six.PY2:
sys.setcheckinterval(0)
else:
sys.setswitchinterval(0.0001)
map(lambda t: t.start(), threads)
map(lambda t: t.join(), threads)
finally:
sys.setcheckinterval(original_interval)
if six.PY2:
sys.setcheckinterval(original_interval)
else:
sys.setswitchinterval(original_interval)
if errors:
self.fail("Saw errors: %s" % (errors,))
@@ -334,14 +347,14 @@ class TokenAwarePolicyTest(unittest.TestCase):
replicas = get_replicas(None, struct.pack('>i', i))
other = set(h for h in hosts if h not in replicas)
self.assertEquals(replicas, qplan[:2])
self.assertEquals(other, set(qplan[2:]))
self.assertEqual(replicas, qplan[:2])
self.assertEqual(other, set(qplan[2:]))
# Should use the secondary policy
for i in range(4):
qplan = list(policy.make_query_plan())
self.assertEquals(set(qplan), set(hosts))
self.assertEqual(set(qplan), set(hosts))
def test_wrap_dc_aware(self):
cluster = Mock(spec=Cluster)
@@ -374,16 +387,16 @@ class TokenAwarePolicyTest(unittest.TestCase):
# first should be the only local replica
self.assertIn(qplan[0], replicas)
self.assertEquals(qplan[0].datacenter, "dc1")
self.assertEqual(qplan[0].datacenter, "dc1")
# then the local non-replica
self.assertNotIn(qplan[1], replicas)
self.assertEquals(qplan[1].datacenter, "dc1")
self.assertEqual(qplan[1].datacenter, "dc1")
# then one of the remotes (used_hosts_per_remote_dc is 1, so we
# shouldn't see two remotes)
self.assertEquals(qplan[2].datacenter, "dc2")
self.assertEquals(3, len(qplan))
self.assertEqual(qplan[2].datacenter, "dc2")
self.assertEqual(3, len(qplan))
class FakeCluster:
def __init__(self):

View File

@@ -346,7 +346,7 @@ class ResponseFutureTests(unittest.TestCase):
rf.send_request()
rf.add_callbacks(
callback=self.assertEquals, callback_args=([{'col': 'val'}],),
callback=self.assertEqual, callback_args=([{'col': 'val'}],),
errback=self.assertIsInstance, errback_args=(Exception,))
result = Mock(spec=UnavailableErrorMessage, info={})
@@ -358,7 +358,7 @@ class ResponseFutureTests(unittest.TestCase):
rf.send_request()
rf.add_callbacks(
callback=self.assertEquals, callback_args=([{'col': 'val'}],),
callback=self.assertEqual, callback_args=([{'col': 'val'}],),
errback=self.assertIsInstance, errback_args=(Exception,))
rf._set_result(self.make_mock_response([{'col': 'val'}]))
@@ -380,9 +380,9 @@ class ResponseFutureTests(unittest.TestCase):
session.submit.assert_called_once()
args, kwargs = session.submit.call_args
self.assertEquals(rf._reprepare, args[-2])
self.assertEqual(rf._reprepare, args[-2])
self.assertIsInstance(args[-1], PrepareMessage)
self.assertEquals(args[-1].query, "SELECT * FROM foobar")
self.assertEqual(args[-1].query, "SELECT * FROM foobar")
def test_prepared_query_not_found_bad_keyspace(self):
session = self.make_session()

View File

@@ -163,18 +163,17 @@ class TypeTests(unittest.TestCase):
'7a6970:org.apache.cassandra.db.marshal.UTF8Type',
')')))
self.assertEquals(FooType, ctype.__class__)
self.assertEqual(FooType, ctype.__class__)
self.assertEquals(UTF8Type, ctype.subtypes[0])
self.assertEqual(UTF8Type, ctype.subtypes[0])
# middle subtype should be a BarType instance with its own subtypes and names
self.assertIsInstance(ctype.subtypes[1], BarType)
self.assertEquals([UTF8Type], ctype.subtypes[1].subtypes)
self.assertEquals(["address"], ctype.subtypes[1].names)
self.assertEqual([UTF8Type], ctype.subtypes[1].subtypes)
self.assertEqual([b"address"], ctype.subtypes[1].names)
self.assertEquals(UTF8Type, ctype.subtypes[2])
self.assertEquals(['city', None, 'zip'], ctype.names)
self.assertEqual(UTF8Type, ctype.subtypes[2])
self.assertEqual([b'city', None, b'zip'], ctype.names)
def test_empty_value(self):
self.assertEqual(str(EmptyValue()), 'EMPTY')

10
tox.ini
View File

@@ -1,5 +1,5 @@
[tox]
envlist = py26,py27,pypy
envlist = py26,py27,pypy,py33
[testenv]
deps = nose
@@ -8,5 +8,13 @@ deps = nose
unittest2
pip
PyYAML
six
commands = {envpython} setup.py build_ext --inplace
nosetests --verbosity=2 tests/unit/
[testenv:py33]
deps = nose
mock
pip
PyYAML
six