Merge branch 'py3k' into 2.0

Conflicts:
	cassandra/cluster.py
	cassandra/encoder.py
	cassandra/marshal.py
	cassandra/pool.py
	setup.py
	tests/integration/long/test_large_data.py
	tests/integration/long/utils.py
	tests/integration/standard/test_metadata.py
	tests/integration/standard/test_prepared_statements.py
	tests/unit/io/test_asyncorereactor.py
	tests/unit/test_connection.py
	tests/unit/test_types.py
Tyler Hobbs committed 2014-05-02 15:55:35 -05:00
44 changed files with 627 additions and 431 deletions


@@ -4,11 +4,13 @@ env:
 - TOX_ENV=py26
 - TOX_ENV=py27
 - TOX_ENV=pypy
+- TOX_ENV=py33
 before_install:
 - sudo apt-get update -y
 - sudo apt-get install -y build-essential python-dev
 - sudo apt-get install -y libev4 libev-dev
 install:
 - pip install tox


@@ -41,7 +41,7 @@ try:
 from cassandra.io.libevreactor import LibevConnection
 have_libev = True
 supported_reactors.append(LibevConnection)
-except ImportError, exc:
+except ImportError as exc:
 pass
 KEYSPACE = "testkeyspace"
@@ -104,7 +104,7 @@ def benchmark(thread_class):
 """.format(table=TABLE))
 values = ('key', 'a', 'b')
-per_thread = options.num_ops / options.threads
+per_thread = options.num_ops // options.threads
 threads = []
 log.debug("Beginning inserts...")
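The "/" to "//" change above matters because Python 3 makes "/" true division; a quick standalone illustration (numbers are made up, not from the benchmark):

    # '//' keeps the integer per-thread split the benchmark relies on;
    # on Python 3, '/' would return a float instead
    num_ops, threads = 10000, 3
    assert num_ops // threads == 3333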


@@ -18,7 +18,7 @@ from itertools import count
 from threading import Event
 from base import benchmark, BenchmarkThread
+from six.moves import range
 log = logging.getLogger(__name__)
@@ -38,17 +38,17 @@ class Runner(BenchmarkThread):
 if previous_result is not sentinel:
 if isinstance(previous_result, BaseException):
 log.error("Error on insert: %r", previous_result)
-if self.num_finished.next() >= self.num_queries:
+if next(self.num_finished) >= self.num_queries:
 self.event.set()
-if self.num_started.next() <= self.num_queries:
+if next(self.num_started) <= self.num_queries:
 future = self.session.execute_async(self.query, self.values)
 future.add_callbacks(self.insert_next, self.insert_next)
 def run(self):
 self.start_profile()
-for _ in xrange(min(120, self.num_queries)):
+for _ in range(min(120, self.num_queries)):
 self.insert_next()
 self.event.wait()


@@ -13,16 +13,16 @@
 # limitations under the License.
 import logging
-import Queue
 from base import benchmark, BenchmarkThread
+from six.moves import queue
 log = logging.getLogger(__name__)
 class Runner(BenchmarkThread):
 def run(self):
-futures = Queue.Queue(maxsize=121)
+futures = queue.Queue(maxsize=121)
 self.start_profile()
@@ -32,7 +32,7 @@ class Runner(BenchmarkThread):
 while True:
 try:
 futures.get_nowait().result()
-except Queue.Empty:
+except queue.Empty:
 break
 future = self.session.execute_async(self.query, self.values)
@@ -41,7 +41,7 @@ class Runner(BenchmarkThread):
 while True:
 try:
 futures.get_nowait().result()
-except Queue.Empty:
+except queue.Empty:
 break
 self.finish_profile()


@@ -13,16 +13,16 @@
 # limitations under the License.
 import logging
-import Queue
 from base import benchmark, BenchmarkThread
+from six.moves import queue
 log = logging.getLogger(__name__)
 class Runner(BenchmarkThread):
 def run(self):
-futures = Queue.Queue(maxsize=121)
+futures = queue.Queue(maxsize=121)
 self.start_profile()
@@ -37,7 +37,7 @@ class Runner(BenchmarkThread):
 while True:
 try:
 futures.get_nowait().result()
-except Queue.Empty:
+except queue.Empty:
 break
 self.finish_profile


@@ -25,7 +25,7 @@ class Runner(BenchmarkThread):
 self.start_profile()
-for i in range(self.num_queries):
+for _ in range(self.num_queries):
 future = self.session.execute_async(self.query, self.values)
 futures.append(future)


@@ -13,13 +13,15 @@
 # limitations under the License.
 from base import benchmark, BenchmarkThread
+from six.moves import range
 class Runner(BenchmarkThread):
 def run(self):
 self.start_profile()
-for i in xrange(self.num_queries):
+for _ in range(self.num_queries):
 self.session.execute(self.query, self.values)
 self.finish_profile()


@@ -26,7 +26,11 @@ import socket
 import sys
 import time
 from threading import Lock, RLock, Thread, Event
-import Queue
+import six
+from six.moves import range
+from six.moves import queue as Queue
 import weakref
 from weakref import WeakValueDictionary
 try:
@@ -696,7 +700,7 @@ class Cluster(object):
 host.set_down()
-log.warn("Host %s has been marked down", host)
+log.warning("Host %s has been marked down", host)
 self.load_balancing_policy.on_down(host)
 self.control_connection.on_down(host)
@@ -742,7 +746,7 @@ class Cluster(object):
 return
 if not all(futures_results):
-log.warn("Connection pool could not be created, not marking node %s up", host)
+log.warning("Connection pool could not be created, not marking node %s up", host)
 return
 self._finalize_add(host)
@@ -867,7 +871,7 @@ class Cluster(object):
 # prepare 10 statements at a time
 ks_statements = list(ks_statements)
 chunks = []
-for i in xrange(0, len(ks_statements), 10):
+for i in range(0, len(ks_statements), 10):
 chunks.append(ks_statements[i:i + 10])
 for ks_chunk in chunks:
@@ -882,9 +886,9 @@ class Cluster(object):
 log.debug("Done preparing all known prepared statements against host %s", host)
 except OperationTimedOut as timeout:
-log.warn("Timed out trying to prepare all statements on host %s: %s", host, timeout)
+log.warning("Timed out trying to prepare all statements on host %s: %s", host, timeout)
 except (ConnectionException, socket.error) as exc:
-log.warn("Error trying to prepare all statements on host %s: %r", host, exc)
+log.warning("Error trying to prepare all statements on host %s: %r", host, exc)
 except Exception:
 log.exception("Error trying to prepare all statements on host %s", host)
 finally:
@@ -1088,7 +1092,7 @@ class Session(object):
 prepared_statement = None
-if isinstance(query, basestring):
+if isinstance(query, six.string_types):
 query = SimpleStatement(query)
 elif isinstance(query, PreparedStatement):
 query = query.bind(parameters)
@@ -1235,8 +1239,8 @@ class Session(object):
 self.cluster.signal_connection_failure(host, conn_exc, is_host_addition)
 return False
 except Exception as conn_exc:
-log.warn("Failed to create connection pool for new host %s: %s",
+log.warning("Failed to create connection pool for new host %s: %s",
 host, conn_exc)
 # the host itself will still be marked down, so we need to pass
 # a special flag to make sure the reconnector is created
 self.cluster.signal_connection_failure(
@@ -1456,11 +1460,11 @@ class ControlConnection(object):
 return self._try_connect(host)
 except ConnectionException as exc:
 errors[host.address] = exc
-log.warn("[control connection] Error connecting to %s:", host, exc_info=True)
+log.warning("[control connection] Error connecting to %s:", host, exc_info=True)
 self._cluster.signal_connection_failure(host, exc, is_host_addition=False)
 except Exception as exc:
 errors[host.address] = exc
-log.warn("[control connection] Error connecting to %s:", host, exc_info=True)
+log.warning("[control connection] Error connecting to %s:", host, exc_info=True)
 raise NoHostAvailable("Unable to connect to any servers", errors)
@@ -1948,7 +1952,7 @@ class _Scheduler(object):
 def _log_if_failed(self, future):
 exc = future.exception()
 if exc:
-log.warn(
+log.warning(
 "An internally scheduled tasked failed with an unhandled exception:",
 exc_info=exc)
@@ -2170,8 +2174,8 @@ class ResponseFuture(object):
 if self._metrics is not None:
 self._metrics.on_other_error()
 # need to retry against a different host here
-log.warn("Host %s is overloaded, retrying against a different "
+log.warning("Host %s is overloaded, retrying against a different "
 "host", self._current_host)
 self._retry(reuse_connection=False, consistency_level=None)
 return
 elif isinstance(response, IsBootstrappingErrorMessage):


@@ -15,6 +15,7 @@
 import sys
 from itertools import count, cycle
+from six.moves import xrange
 from threading import Event
@@ -105,7 +106,7 @@ def execute_concurrent_with_args(session, statement, parameters, *args, **kwargs
 parameters = [(x,) for x in range(1000)]
 execute_concurrent_with_args(session, statement, parameters)
 """
-return execute_concurrent(session, zip(cycle((statement,)), parameters), *args, **kwargs)
+return execute_concurrent(session, list(zip(cycle((statement,)), parameters)), *args, **kwargs)
 _sentinel = object()
@@ -118,12 +119,12 @@ def _handle_error(error, result_index, event, session, statements, results, num_
 return
 else:
 results[result_index] = (False, error)
-if num_finished.next() >= to_execute:
+if next(num_finished) >= to_execute:
 event.set()
 return
 try:
-(next_index, (statement, params)) = statements.next()
+(next_index, (statement, params)) = next(statements)
 except StopIteration:
 return
@@ -139,7 +140,7 @@ def _handle_error(error, result_index, event, session, statements, results, num_
 return
 else:
 results[next_index] = (False, exc)
-if num_finished.next() >= to_execute:
+if next(num_finished) >= to_execute:
 event.set()
 return
@@ -147,13 +148,13 @@ def _handle_error(error, result_index, event, session, statements, results, num_
 def _execute_next(result, result_index, event, session, statements, results, num_finished, to_execute, first_error):
 if result is not _sentinel:
 results[result_index] = (True, result)
-finished = num_finished.next()
+finished = next(num_finished)
 if finished >= to_execute:
 event.set()
 return
 try:
-(next_index, (statement, params)) = statements.next()
+(next_index, (statement, params)) = next(statements)
 except StopIteration:
 return
@@ -169,6 +170,6 @@ def _execute_next(result, result_index, event, session, statements, results, num
 return
 else:
 results[next_index] = (False, exc)
-if num_finished.next() >= to_execute:
+if next(num_finished) >= to_execute:
 event.set()
 return
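The repeated ".next()" to "next()" rewrites in this hunk all follow the same pattern; a standalone sketch (not driver code) of why the builtin works on both Python 2 and 3:

    from itertools import count

    num_finished = count(1)          # same kind of shared counter the callbacks advance
    assert next(num_finished) == 1   # next() is 2.6+/3.x; counter.next() exists only on Python 2
    assert next(num_finished) == 2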


@@ -19,19 +19,21 @@ import logging
 import sys
 from threading import Event, RLock
 import time
-import traceback
 if 'gevent.monkey' in sys.modules:
 from gevent.queue import Queue, Empty
 else:
-from Queue import Queue, Empty # noqa
+from six.moves.queue import Queue, Empty # noqa
+from six.moves import range
 from cassandra import ConsistencyLevel, AuthenticationFailed, OperationTimedOut
-from cassandra.marshal import int8_unpack, int32_pack
+from cassandra.marshal import int32_pack, header_unpack
 from cassandra.decoder import (ReadyMessage, AuthenticateMessage, OptionsMessage,
 StartupMessage, ErrorMessage, CredentialsMessage,
 QueryMessage, ResultMessage, decode_response,
 InvalidRequestException, SupportedMessage)
+import six
 log = logging.getLogger(__name__)
@@ -123,7 +125,6 @@ def defunct_on_error(f):
 return f(self, *args, **kwargs)
 except Exception as exc:
 self.defunct(exc)
 return wrapper
@@ -181,12 +182,8 @@ class Connection(object):
 return
 self.is_defunct = True
-trace = traceback.format_exc(exc)
-if trace != "None":
-log.debug("Defuncting connection (%s) to %s: %s\n%s",
-id(self), self.host, exc, traceback.format_exc(exc))
-else:
-log.debug("Defuncting connection (%s) to %s: %s", id(self), self.host, exc)
+log.debug("Defuncting connection (%s) to %s:",
+id(self), self.host, exc_info=exc)
 self.last_error = exc
 self.close()
@@ -203,9 +200,9 @@ class Connection(object):
 try:
 cb(new_exc)
 except Exception:
-log.warn("Ignoring unhandled exception while erroring callbacks for a "
+log.warning("Ignoring unhandled exception while erroring callbacks for a "
 "failed connection (%s) to host %s:",
 id(self), self.host, exc_info=True)
 def handle_pushed(self, response):
 log.debug("Message pushed from server: %r", response)
@@ -231,7 +228,7 @@ class Connection(object):
 request_id = self._id_queue.get()
 self._callbacks[request_id] = cb
-self.push(msg.to_string(request_id, self.protocol_version, compression=self.compressor))
+self.push(msg.to_binary(request_id, self.protocol_version, compression=self.compressor))
 return request_id
 def wait_for_response(self, msg, timeout=None):
@@ -268,7 +265,7 @@ class Connection(object):
 return waiter.deliver(timeout)
 except OperationTimedOut:
 raise
-except Exception, exc:
+except Exception as exc:
 self.defunct(exc)
 raise
@@ -284,7 +281,7 @@ class Connection(object):
 @defunct_on_error
 def process_msg(self, msg, body_len):
-version, flags, stream_id, opcode = map(int8_unpack, msg[:4])
+version, flags, stream_id, opcode = header_unpack(msg[:4])
 if stream_id < 0:
 callback = None
 else:
@@ -309,7 +306,7 @@ class Connection(object):
 if body_len > 0:
 body = msg[8:]
 elif body_len == 0:
-body = ""
+body = six.binary_type()
 else:
 raise ProtocolError("Got negative body length: %r" % body_len)
@@ -383,7 +380,7 @@ class Connection(object):
 locally_supported_compressions.keys(),
 remote_supported_compressions)
 else:
-compression_type = iter(overlap).next()  # choose any
+compression_type = next(iter(overlap))  # choose any
 # set the decompressor here, but set the compressor only after
 # a successful Ready message
 self._compressor, self.decompressor = \


@@ -36,10 +36,8 @@ from datetime import datetime
 from uuid import UUID
 import warnings
-try:
-from cStringIO import StringIO
-except ImportError:
-from StringIO import StringIO # NOQA
+import six
+from six.moves import range
 from cassandra.marshal import (int8_pack, int8_unpack, uint16_pack, uint16_unpack,
 int32_pack, int32_unpack, int64_pack, int64_unpack,
@@ -49,7 +47,11 @@ from cassandra.util import OrderedDict
 apache_cassandra_type_prefix = 'org.apache.cassandra.db.marshal.'
-_number_types = frozenset((int, long, float))
+if six.PY3:
+_number_types = frozenset((int, float))
+long = int
+else:
+_number_types = frozenset((int, long, float))
 try:
 from blist import sortedset
@@ -69,7 +71,7 @@ def trim_if_startswith(s, prefix):
 def unix_time_from_uuid1(u):
-return (u.get_time() - 0x01B21DD213814000) / 10000000.0
+return (u.time - 0x01B21DD213814000) / 10000000.0
 _casstypes = {}
@@ -177,8 +179,8 @@ class EmptyValue(object):
 EMPTY = EmptyValue()
+@six.add_metaclass(CassandraTypeType)
 class _CassandraType(object):
-__metaclass__ = CassandraTypeType
 subtypes = ()
 num_subtypes = 0
 empty_binary_ok = False
@@ -199,9 +201,8 @@ class _CassandraType(object):
 def __init__(self, val):
 self.val = self.validate(val)
-def __str__(self):
+def __repr__(self):
 return '<%s( %r )>' % (self.cql_parameterized_type(), self.val)
-__repr__ = __str__
 @staticmethod
 def validate(val):
@@ -221,7 +222,7 @@ class _CassandraType(object):
 """
 if byts is None:
 return None
-elif byts == '' and not cls.empty_binary_ok:
+elif len(byts) == 0 and not cls.empty_binary_ok:
 return EMPTY if cls.support_empty_values else None
 return cls.deserialize(byts)
@@ -232,7 +233,7 @@ class _CassandraType(object):
 more information. This method differs in that if None is passed in,
 the result is the empty string.
 """
-return '' if val is None else cls.serialize(val)
+return b'' if val is None else cls.serialize(val)
 @staticmethod
 def deserialize(byts):
@@ -293,7 +294,8 @@ class _CassandraType(object):
 if cls.num_subtypes != 'UNKNOWN' and len(subtypes) != cls.num_subtypes:
 raise ValueError("%s types require %d subtypes (%d given)"
 % (cls.typename, cls.num_subtypes, len(subtypes)))
-newname = cls.cass_parameterized_type_with(subtypes).encode('utf8')
+# newname = cls.cass_parameterized_type_with(subtypes).encode('utf8')
+newname = cls.cass_parameterized_type_with(subtypes)
 return type(newname, (cls,), {'subtypes': subtypes, 'cassname': cls.cassname})
 @classmethod
@@ -324,10 +326,16 @@ class _UnrecognizedType(_CassandraType):
 num_subtypes = 'UNKNOWN'
-def mkUnrecognizedType(casstypename):
-return CassandraTypeType(casstypename.encode('utf8'),
-(_UnrecognizedType,),
-{'typename': "'%s'" % casstypename})
+if six.PY3:
+def mkUnrecognizedType(casstypename):
+return CassandraTypeType(casstypename,
+(_UnrecognizedType,),
+{'typename': "'%s'" % casstypename})
+else:
+def mkUnrecognizedType(casstypename):  # noqa
+return CassandraTypeType(casstypename.encode('utf8'),
+(_UnrecognizedType,),
+{'typename': "'%s'" % casstypename})
 class BytesType(_CassandraType):
@@ -336,11 +344,11 @@ class BytesType(_CassandraType):
 @staticmethod
 def validate(val):
-return buffer(val)
+return bytearray(val)
 @staticmethod
 def serialize(val):
-return str(val)
+return six.binary_type(val)
 class DecimalType(_CassandraType):
@@ -401,9 +409,25 @@ class BooleanType(_CassandraType):
 return int8_pack(truth)
-class AsciiType(_CassandraType):
-typename = 'ascii'
-empty_binary_ok = True
+if six.PY2:
+class AsciiType(_CassandraType):
+typename = 'ascii'
+empty_binary_ok = True
+else:
+class AsciiType(_CassandraType):
+typename = 'ascii'
+empty_binary_ok = True
+@staticmethod
+def deserialize(byts):
+return byts.decode('ascii')
+@staticmethod
+def serialize(var):
+try:
+return var.encode('ascii')
+except UnicodeDecodeError:
+return var
 class FloatType(_CassandraType):
@@ -496,7 +520,7 @@ class DateType(_CassandraType):
 @classmethod
 def validate(cls, date):
-if isinstance(date, basestring):
+if isinstance(date, six.string_types):
 date = cls.interpret_datestring(date)
 return date
@@ -628,7 +652,7 @@ class _SimpleParameterizedType(_ParameterizedType):
 numelements = uint16_unpack(byts[:2])
 p = 2
 result = []
-for n in xrange(numelements):
+for _ in range(numelements):
 itemlen = uint16_unpack(byts[p:p + 2])
 p += 2
 item = byts[p:p + itemlen]
@@ -638,11 +662,11 @@ class _SimpleParameterizedType(_ParameterizedType):
 @classmethod
 def serialize_safe(cls, items):
-if isinstance(items, basestring):
+if isinstance(items, six.string_types):
 raise TypeError("Received a string for a type that expects a sequence")
 subtype, = cls.subtypes
-buf = StringIO()
+buf = six.BytesIO()
 buf.write(uint16_pack(len(items)))
 for item in items:
 itembytes = subtype.to_binary(item)
@@ -670,7 +694,7 @@ class MapType(_ParameterizedType):
 @classmethod
 def validate(cls, val):
 subkeytype, subvaltype = cls.subtypes
-return dict((subkeytype.validate(k), subvaltype.validate(v)) for (k, v) in val.iteritems())
+return dict((subkeytype.validate(k), subvaltype.validate(v)) for (k, v) in six.iteritems(val))
 @classmethod
 def deserialize_safe(cls, byts):
@@ -678,7 +702,7 @@ class MapType(_ParameterizedType):
 numelements = uint16_unpack(byts[:2])
 p = 2
 themap = OrderedDict()
-for n in xrange(numelements):
+for _ in range(numelements):
 key_len = uint16_unpack(byts[p:p + 2])
 p += 2
 keybytes = byts[p:p + key_len]
@@ -695,10 +719,10 @@ class MapType(_ParameterizedType):
 @classmethod
 def serialize_safe(cls, themap):
 subkeytype, subvaltype = cls.subtypes
-buf = StringIO()
+buf = six.BytesIO()
 buf.write(uint16_pack(len(themap)))
 try:
-items = themap.iteritems()
+items = six.iteritems(themap)
 except AttributeError:
 raise TypeError("Got a non-map object for a map value")
 for key, val in items:
@@ -747,7 +771,7 @@ class ReversedType(_ParameterizedType):
 def is_counter_type(t):
-if isinstance(t, basestring):
+if isinstance(t, six.string_types):
 t = lookup_casstype(t)
 return issubclass(t, CounterColumnType)
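The "@six.add_metaclass" decorator above replaces the Python 2-only "__metaclass__" attribute on "_CassandraType"; a minimal standalone sketch of that pattern (the "Registry"/"Base" names are illustrative, not from the driver):

    import six

    class Registry(type):
        by_name = {}
        def __init__(cls, name, bases, dct):
            Registry.by_name[name] = cls   # runs once for every class built with this metaclass

    @six.add_metaclass(Registry)           # works the same on Python 2 and 3
    class Base(object):
        pass

    assert Registry.by_name['Base'] is Base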


@@ -16,16 +16,14 @@ import logging
 import socket
 from uuid import UUID
-try:
-from cStringIO import StringIO
-except ImportError:
-from StringIO import StringIO # ignore flake8 warning: # NOQA
+import six
+from six.moves import range
 from cassandra import (Unavailable, WriteTimeout, ReadTimeout,
 AlreadyExists, InvalidRequest, Unauthorized,
 UnsupportedOperation)
 from cassandra.marshal import (int32_pack, int32_unpack, uint16_pack, uint16_unpack,
-int8_pack, int8_unpack)
+int8_pack, int8_unpack, header_pack)
 from cassandra.cqltypes import (AsciiType, BytesType, BooleanType,
 CounterColumnType, DateType, DecimalType,
 DoubleType, FloatType, Int32Type,
@@ -48,66 +46,75 @@ HEADER_DIRECTION_FROM_CLIENT = 0x00
 HEADER_DIRECTION_TO_CLIENT = 0x80
 HEADER_DIRECTION_MASK = 0x80
+COMPRESSED_FLAG = 0x01
+TRACING_FLAG = 0x02
 _message_types_by_name = {}
 _message_types_by_opcode = {}
-class _register_msg_type(type):
+class _RegisterMessageType(type):
 def __init__(cls, name, bases, dct):
 if not name.startswith('_'):
 _message_types_by_name[cls.name] = cls
 _message_types_by_opcode[cls.opcode] = cls
+@six.add_metaclass(_RegisterMessageType)
 class _MessageType(object):
-__metaclass__ = _register_msg_type
 tracing = False
-def to_string(self, stream_id, protocol_version, compression=None):
+def to_binary(self, stream_id, protocol_version, compression=None):
-body = StringIO()
+body = six.BytesIO()
 self.send_body(body, protocol_version)
 body = body.getvalue()
-version = protocol_version | HEADER_DIRECTION_FROM_CLIENT
-flags = 0
-if compression is not None and len(body) > 0:
-body = compression(body)
-flags |= 0x01
-if self.tracing:
-flags |= 0x02
-msglen = int32_pack(len(body))
-msg_parts = map(int8_pack, (version, flags, stream_id, self.opcode)) + [msglen, body]
-return ''.join(msg_parts)
-def __str__(self):
-paramstrs = ['%s=%r' % (pname, getattr(self, pname)) for pname in _get_params(self)]
-return '<%s(%s)>' % (self.__class__.__name__, ', '.join(paramstrs))
-__repr__ = __str__
+flags = 0
+if compression and len(body) > 0:
+body = compression(body)
+flags |= COMPRESSED_FLAG
+if self.tracing:
+flags |= TRACING_FLAG
+msg = six.BytesIO()
+write_header(
+msg,
+protocol_version | HEADER_DIRECTION_FROM_CLIENT,
+flags, stream_id, self.opcode, len(body)
+)
+msg.write(body)
+return msg.getvalue()
+def __repr__(self):
+return '<%s(%s)>' % (self.__class__.__name__, ', '.join('%s=%r' % i for i in _get_params(self)))
 def _get_params(message_obj):
 base_attrs = dir(_MessageType)
-return [a for a in dir(message_obj)
-if a not in base_attrs and not a.startswith('_') and not callable(getattr(message_obj, a))]
+return (
+(n, a) for n, a in message_obj.__dict__.items()
+if n not in base_attrs and not n.startswith('_') and not callable(a)
+)
 def decode_response(stream_id, flags, opcode, body, decompressor=None):
-if flags & 0x01:
+if flags & COMPRESSED_FLAG:
 if decompressor is None:
-raise Exception("No decompressor available for compressed frame!")
+raise Exception("No de-compressor available for compressed frame!")
 body = decompressor(body)
-flags ^= 0x01
+flags ^= COMPRESSED_FLAG
-body = StringIO(body)
+body = six.BytesIO(body)
-if flags & 0x02:
+if flags & TRACING_FLAG:
 trace_id = UUID(bytes=body.read(16))
-flags ^= 0x02
+flags ^= TRACING_FLAG
 else:
 trace_id = None
 if flags:
-log.warn("Unknown protocol flags set: %02x. May cause problems.", flags)
+log.warning("Unknown protocol flags set: %02x. May cause problems.", flags)
 msg_class = _message_types_by_opcode[opcode]
 msg = msg_class.recv_body(body)
@@ -156,14 +163,14 @@ class ErrorMessage(_MessageType, Exception):
 return self
-class ErrorMessageSubclass(_register_msg_type):
+class ErrorMessageSubclass(_RegisterMessageType):
 def __init__(cls, name, bases, dct):
-if cls.error_code is not None:
+if cls.error_code is not None:  # Server has an error code of 0.
 error_classes[cls.error_code] = cls
+@six.add_metaclass(ErrorMessageSubclass)
 class ErrorMessageSub(ErrorMessage):
-__metaclass__ = ErrorMessageSubclass
 error_code = None
@@ -511,7 +518,7 @@ class ResultMessage(_MessageType):
 def recv_results_rows(cls, f):
 paging_state, column_metadata = cls.recv_results_metadata(f)
 rowcount = read_int(f)
-rows = [cls.recv_row(f, len(column_metadata)) for x in xrange(rowcount)]
+rows = [cls.recv_row(f, len(column_metadata)) for _ in range(rowcount)]
 colnames = [c[2] for c in column_metadata]
 coltypes = [c[3] for c in column_metadata]
 return (
@@ -538,7 +545,7 @@ class ResultMessage(_MessageType):
 ksname = read_string(f)
 cfname = read_string(f)
 column_metadata = []
-for x in xrange(colcount):
+for _ in range(colcount):
 if glob_tblspec:
 colksname = ksname
 colcfname = cfname
@@ -580,7 +587,7 @@ class ResultMessage(_MessageType):
 @staticmethod
 def recv_row(f, colcount):
-return [read_value(f) for x in xrange(colcount)]
+return [read_value(f) for _ in range(colcount)]
 class PrepareMessage(_MessageType):
@@ -729,6 +736,14 @@ class EventMessage(_MessageType):
 return dict(change_type=change_type, keyspace=keyspace, table=table)
+def write_header(f, version, flags, stream_id, opcode, length):
+"""
+Write a CQL protocol frame header.
+"""
+f.write(header_pack(version, flags, stream_id, opcode))
+write_int(f, length)
 def read_byte(f):
 return int8_unpack(f.read(1))
@@ -774,7 +789,7 @@ def read_binary_string(f):
 def write_string(f, s):
-if isinstance(s, unicode):
+if isinstance(s, six.text_type):
 s = s.encode('utf8')
 write_short(f, len(s))
 f.write(s)
@@ -791,7 +806,7 @@ def read_longstring(f):
 def write_longstring(f, s):
-if isinstance(s, unicode):
+if isinstance(s, six.text_type):
 s = s.encode('utf8')
 write_int(f, len(s))
 f.write(s)
@@ -799,7 +814,7 @@ def write_longstring(f, s):
 def read_stringlist(f):
 numstrs = read_short(f)
-return [read_string(f) for x in xrange(numstrs)]
+return [read_string(f) for _ in range(numstrs)]
 def write_stringlist(f, stringlist):
@@ -811,7 +826,7 @@ def write_stringlist(f, stringlist):
 def read_stringmap(f):
 numpairs = read_short(f)
 strmap = {}
-for x in xrange(numpairs):
+for _ in range(numpairs):
 k = read_string(f)
 strmap[k] = read_string(f)
 return strmap
@@ -827,7 +842,7 @@ def write_stringmap(f, strmap):
 def read_stringmultimap(f):
 numkeys = read_short(f)
 strmmap = {}
-for x in xrange(numkeys):
+for _ in range(numkeys):
 k = read_string(f)
 strmmap[k] = read_stringlist(f)
 return strmmap
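The new COMPRESSED_FLAG/TRACING_FLAG constants simply name the bits that to_binary() sets and decode_response() clears; a standalone illustration of the bit handling (not driver code):

    COMPRESSED_FLAG = 0x01
    TRACING_FLAG = 0x02

    flags = 0
    flags |= TRACING_FLAG            # request tracing for this frame
    assert flags & TRACING_FLAG
    flags ^= TRACING_FLAG            # clear the bit once the trace id has been read
    assert flags == 0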


@@ -12,21 +12,34 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
+log = logging.getLogger(__name__)
 from binascii import hexlify
 import calendar
 import datetime
 import sys
 import types
 from uuid import UUID
+import six
 from cassandra.util import OrderedDict
+if six.PY3:
+long = int
 def cql_quote(term):
-if isinstance(term, unicode):
-return "'%s'" % term.encode('utf8').replace("'", "''")
-elif isinstance(term, (str, bool)):
+# The ordering of this method is important for the result of this method to
+# be a native str type (for both Python 2 and 3)
+# Handle quoting of native str and bool types
+if isinstance(term, (str, bool)):
 return "'%s'" % str(term).replace("'", "''")
+# This branch of the if statement will only be used by Python 2 to catch
+# unicode strings, text_type is used to prevent type errors with Python 3.
+elif isinstance(term, six.text_type):
+return "'%s'" % term.encode('utf8').replace("'", "''")
 else:
 return str(term)
@@ -43,13 +56,16 @@ def cql_encode_str(val):
 return cql_quote(val)
-if sys.version_info >= (2, 7):
+if six.PY3:
 def cql_encode_bytes(val):
-return '0x' + hexlify(val)
+return (b'0x' + hexlify(val)).decode('utf-8')
+elif sys.version_info >= (2, 7):
+def cql_encode_bytes(val):  # noqa
+return b'0x' + hexlify(val)
 else:
 # python 2.6 requires string or read-only buffer for hexlify
 def cql_encode_bytes(val):  # noqa
-return '0x' + hexlify(buffer(val))
+return b'0x' + hexlify(buffer(val))
 def cql_encode_object(val):
@@ -71,11 +87,10 @@ def cql_encode_sequence(val):
 def cql_encode_map_collection(val):
-return '{ %s }' % ' , '.join(
-'%s : %s' % (
-cql_encode_all_types(k),
-cql_encode_all_types(v))
-for k, v in val.iteritems())
+return '{ %s }' % ' , '.join('%s : %s' % (
+cql_encode_all_types(k),
+cql_encode_all_types(v)
+) for k, v in six.iteritems(val))
 def cql_encode_list_collection(val):
@@ -92,13 +107,9 @@ def cql_encode_all_types(val):
 cql_encoders = {
 float: cql_encode_object,
-buffer: cql_encode_bytes,
 bytearray: cql_encode_bytes,
 str: cql_encode_str,
-unicode: cql_encode_unicode,
-types.NoneType: cql_encode_none,
 int: cql_encode_object,
-long: cql_encode_object,
 UUID: cql_encode_object,
 datetime.datetime: cql_encode_datetime,
 datetime.date: cql_encode_date,
@@ -110,3 +121,17 @@ cql_encoders = {
 frozenset: cql_encode_set_collection,
 types.GeneratorType: cql_encode_list_collection
 }
+if six.PY2:
+cql_encoders.update({
+unicode: cql_encode_unicode,
+buffer: cql_encode_bytes,
+long: cql_encode_object,
+types.NoneType: cql_encode_none,
+})
+else:
+cql_encoders.update({
+memoryview: cql_encode_bytes,
+bytes: cql_encode_bytes,
+type(None): cql_encode_none,
+})
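On Python 3, hexlify() returns bytes, which is why the PY3 branch of cql_encode_bytes above decodes before returning; a standalone sketch of that behaviour (the helper name is illustrative only):

    from binascii import hexlify

    def encode_bytes_py3(val):
        # mirrors the PY3 branch: build bytes, then decode to get a native str literal
        return (b'0x' + hexlify(val)).decode('utf-8')

    assert encode_bytes_py3(b'\x01\xff') == '0x01ff'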


@@ -20,6 +20,10 @@ import os
 import socket
 import sys
 from threading import Event, Lock, Thread
+from six import BytesIO
+from six.moves import range
 from errno import EALREADY, EINPROGRESS, EWOULDBLOCK, EINVAL, EISCONN, errorcode
 try:
 from weakref import WeakSet
@@ -28,11 +32,6 @@ except ImportError:
 import asyncore
-try:
-from cStringIO import StringIO
-except ImportError:
-from StringIO import StringIO # ignore flake8 warning: # NOQA
 try:
 import ssl
 except ImportError:
@@ -141,7 +140,7 @@ class AsyncoreConnection(Connection, asyncore.dispatcher):
 asyncore.dispatcher.__init__(self)
 self.connected_event = Event()
-self._iobuf = StringIO()
+self._iobuf = BytesIO()
 self._callbacks = {}
 self.deque = deque()
@@ -286,7 +285,7 @@ class AsyncoreConnection(Connection, asyncore.dispatcher):
 # leave leftover in current buffer
 leftover = self._iobuf.read()
-self._iobuf = StringIO()
+self._iobuf = BytesIO()
 self._iobuf.write(leftover)
 self._total_reqd_bytes = 0
@@ -302,7 +301,7 @@ class AsyncoreConnection(Connection, asyncore.dispatcher):
 sabs = self.out_buffer_size
 if len(data) > sabs:
 chunks = []
-for i in xrange(0, len(data), sabs):
+for i in range(0, len(data), sabs):
 chunks.append(data[i:i + sabs])
 else:
 chunks = [data]


@@ -20,6 +20,8 @@ import os
 import socket
 from threading import Event, Lock, Thread
+from six import BytesIO
 from cassandra import OperationTimedOut
 from cassandra.connection import Connection, ConnectionShutdown, NONBLOCKING
 from cassandra.decoder import RegisterMessage
@@ -35,10 +37,6 @@ except ImportError:
 "for instructions on installing build dependencies and building "
 "the C extension.")
-try:
-from cStringIO import StringIO
-except ImportError:
-from StringIO import StringIO # ignore flake8 warning: # NOQA
 try:
 import ssl
@@ -197,7 +195,7 @@ class LibevConnection(Connection):
 Connection.__init__(self, *args, **kwargs)
 self.connected_event = Event()
-self._iobuf = StringIO()
+self._iobuf = BytesIO()
 self._callbacks = {}
 self.deque = deque()
@@ -323,7 +321,7 @@ class LibevConnection(Connection):
 # leave leftover in current buffer
 leftover = self._iobuf.read()
-self._iobuf = StringIO()
+self._iobuf = BytesIO()
 self._iobuf.write(leftover)
 self._total_reqd_bytes = 0


@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import six
 import struct
@@ -37,12 +38,24 @@ uint8_pack, uint8_unpack = _make_packer('>B')
 float_pack, float_unpack = _make_packer('>f')
 double_pack, double_unpack = _make_packer('>d')
+# Special case for cassandra header
+header_struct = struct.Struct('>BBbB')
+header_pack = header_struct.pack
+header_unpack = header_struct.unpack
-def varint_unpack(term):
-val = int(term.encode('hex'), 16)
-if (ord(term[0]) & 128) != 0:
-val = val - (1 << (len(term) * 8))
-return val
+if six.PY3:
+def varint_unpack(term):
+val = int(''.join("%02x" % i for i in term), 16)
+if (term[0] & 128) != 0:
+val -= 1 << (len(term) * 8)
+return val
+else:
+def varint_unpack(term):  # noqa
+val = int(term.encode('hex'), 16)
+if (ord(term[0]) & 128) != 0:
+val = val - (1 << (len(term) * 8))
+return val
 def bitlength(n):
@@ -56,16 +69,16 @@ def bitlength(n):
 def varint_pack(big):
 pos = True
 if big == 0:
-return '\x00'
+return b'\x00'
 if big < 0:
-bytelength = bitlength(abs(big) - 1) / 8 + 1
+bytelength = bitlength(abs(big) - 1) // 8 + 1
 big = (1 << bytelength * 8) + big
 pos = False
-revbytes = []
+revbytes = bytearray()
 while big > 0:
-revbytes.append(chr(big & 0xff))
+revbytes.append(big & 0xff)
 big >>= 8
-if pos and ord(revbytes[-1]) & 0x80:
-revbytes.append('\x00')
+if pos and revbytes[-1] & 0x80:
+revbytes.append(0)
 revbytes.reverse()
-return ''.join(revbytes)
+return six.binary_type(revbytes)
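header_pack/header_unpack above wrap a single struct.Struct for the 4-byte frame header (unsigned version, unsigned flags, signed stream id, unsigned opcode); a standalone round-trip sketch, with an illustrative opcode value:

    import struct

    header_struct = struct.Struct('>BBbB')

    raw = header_struct.pack(0x02, 0x00, 1, 0x07)   # version 2, no flags, stream 1, opcode 0x07
    assert header_struct.unpack(raw) == (0x02, 0x00, 1, 0x07)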


@@ -21,11 +21,12 @@ import logging
 import re
 from threading import RLock
 import weakref
+import six
 murmur3 = None
 try:
-from murmur3 import murmur3
-except ImportError:
+from cassandra.murmur3 import murmur3
+except ImportError as e:
 pass
 import cassandra.cqltypes as types
@@ -330,7 +331,7 @@ class Metadata(object):
 token_to_host_owner = {}
 ring = []
-for host, token_strings in token_map.iteritems():
+for host, token_strings in six.iteritems(token_map):
 for token_string in token_strings:
 token = token_class(token_string)
 ring.append(token)
@@ -793,14 +794,18 @@ class TableMetadata(object):
 return list(sorted(ret))
-def protect_name(name):
-if isinstance(name, unicode):
-name = name.encode('utf8')
-return maybe_escape_name(name)
+if six.PY3:
+def protect_name(name):
+return maybe_escape_name(name)
+else:
+def protect_name(name):
+if isinstance(name, six.text_type):
+name = name.encode('utf8')
+return maybe_escape_name(name)
 def protect_names(names):
-return map(protect_name, names)
+return [protect_name(n) for n in names]
 def protect_value(value):
@@ -1008,11 +1013,14 @@ class Token(object):
 def __eq__(self, other):
 return self.value == other.value
+def __lt__(self, other):
+return self.value < other.value
 def __hash__(self):
 return hash(self.value)
 def __repr__(self):
-return "<%s: %r>" % (self.__class__.__name__, self.value)
+return "<%s: %s>" % (self.__class__.__name__, self.value)
 __str__ = __repr__
 MIN_LONG = -(2 ** 63)
@@ -1031,7 +1039,7 @@ class Murmur3Token(Token):
 @classmethod
 def hash_fn(cls, key):
 if murmur3 is not None:
-h = murmur3(key)
+h = int(murmur3(key))
 return h if h != MIN_LONG else MAX_LONG
 else:
 raise NoMurmur3()
@@ -1048,6 +1056,8 @@ class MD5Token(Token):
 @classmethod
 def hash_fn(cls, key):
+if isinstance(key, six.text_type):
+key = key.encode('UTF-8')
 return abs(varint_unpack(md5(key).digest()))
 def __init__(self, token):
@@ -1062,7 +1072,7 @@ class BytesToken(Token):
 def __init__(self, token_string):
 """ `token_string` should be string representing the token. """
-if not isinstance(token_string, basestring):
+if not isinstance(token_string, six.string_types):
 raise TypeError(
 "Tokens for ByteOrderedPartitioner should be strings (got %s)"
 % (type(token_string),))


@@ -169,6 +169,18 @@ uint64_t MurmurHash3_x64_128 (const void * key, const int len,
 return h1;
 }
+struct module_state {
+PyObject *error;
+};
+#if PY_MAJOR_VERSION >= 3
+#define GETSTATE(m) ((struct module_state*)PyModule_GetState(m))
+#else
+#define GETSTATE(m) (&_state)
+static struct module_state _state;
+#endif
 static PyObject *
 murmur3(PyObject *self, PyObject *args)
 {
@@ -186,23 +198,63 @@ murmur3(PyObject *self, PyObject *args)
 }
 static PyMethodDef murmur3_methods[] = {
-{"murmur3", murmur3, METH_VARARGS,
-"Make an x64 murmur3 64-bit hash value"},
+{"murmur3", murmur3, METH_VARARGS, "Make an x64 murmur3 64-bit hash value"},
 {NULL, NULL, 0, NULL}
 };
-#if PY_MAJOR_VERSION <= 2
-PyMODINIT_FUNC
-initmurmur3(void)
-{
-(void) Py_InitModule("murmur3", murmur3_methods);
-}
-#else
-/* Python 3.x */
-// TODO
-#endif
+#if PY_MAJOR_VERSION >= 3
+static int murmur3_traverse(PyObject *m, visitproc visit, void *arg) {
+Py_VISIT(GETSTATE(m)->error);
+return 0;
+}
+static int murmur3_clear(PyObject *m) {
+Py_CLEAR(GETSTATE(m)->error);
+return 0;
+}
+static struct PyModuleDef moduledef = {
+PyModuleDef_HEAD_INIT,
+"murmur3",
+NULL,
+sizeof(struct module_state),
+murmur3_methods,
+NULL,
+murmur3_traverse,
+murmur3_clear,
+NULL
+};
+#define INITERROR return NULL
+PyObject *
+PyInit_murmur3(void)
+#else
+#define INITERROR return
+void
+initmurmur3(void)
+#endif
+{
+#if PY_MAJOR_VERSION >= 3
+PyObject *module = PyModule_Create(&moduledef);
+#else
+PyObject *module = Py_InitModule("murmur3", murmur3_methods);
+#endif
+if (module == NULL)
+INITERROR;
+struct module_state *st = GETSTATE(module);
+st->error = PyErr_NewException("murmur3.Error", NULL, NULL);
+if (st->error == NULL) {
+Py_DECREF(module);
+INITERROR;
+}
+#if PY_MAJOR_VERSION >= 3
+return module;
+#endif
+}


@@ -16,9 +16,12 @@ from itertools import islice, cycle, groupby, repeat
 import logging
 from random import randint
 from threading import Lock
+import six
 from cassandra import ConsistencyLevel
+from six.moves import range
 log = logging.getLogger(__name__)
@@ -263,7 +266,7 @@ class DCAwareRoundRobinPolicy(LoadBalancingPolicy):
 for host in islice(cycle(local_live), pos, pos + len(local_live)):
 yield host
-for dc, current_dc_hosts in self._dc_live_hosts.iteritems():
+for dc, current_dc_hosts in six.iteritems(self._dc_live_hosts):
 if dc == self.local_dc:
 continue
@@ -529,7 +532,7 @@ class ExponentialReconnectionPolicy(ReconnectionPolicy):
 self.max_delay = max_delay
 def new_schedule(self):
-return (min(self.base_delay * (2 ** i), self.max_delay) for i in xrange(64))
+return (min(self.base_delay * (2 ** i), self.max_delay) for i in range(64))
 class WriteType(object):


@@ -150,8 +150,14 @@ class Host(object):
 def __eq__(self, other):
 return self.address == other.address
+def __hash__(self):
+return hash(self.address)
+def __lt__(self, other):
+return self.address < other.address
 def __str__(self):
-return self.address
+return str(self.address)
 def __repr__(self):
 dc = (" %s" % (self._datacenter,)) if self._datacenter else ""
@@ -178,7 +184,7 @@ class _ReconnectionHandler(object):
 log.debug("Reconnection handler was cancelled before starting")
 return
-first_delay = self.schedule.next()
+first_delay = next(self.schedule)
 self.scheduler.schedule(first_delay, self.run)
 def run(self):
@@ -189,7 +195,7 @@ class _ReconnectionHandler(object):
 try:
 conn = self.try_reconnect()
 except Exception as exc:
-next_delay = self.schedule.next()
+next_delay = next(self.schedule)
 if self.on_exception(exc, next_delay):
 self.scheduler.schedule(next_delay, self.run)
 else:
@@ -260,8 +266,8 @@ class _HostReconnectionHandler(_ReconnectionHandler):
 if isinstance(exc, AuthenticationFailed):
 return False
 else:
-log.warn("Error attempting to reconnect to %s, scheduling retry in %s seconds: %s",
+log.warning("Error attempting to reconnect to %s, scheduling retry in %s seconds: %s",
 self.host, next_delay, exc)
 log.debug("Reconnection error details", exc_info=True)
 return True
@@ -371,8 +377,8 @@ class HostConnectionPool(object):
 def _create_new_connection(self):
 try:
 self._add_conn_if_under_max()
-except (ConnectionException, socket.error), exc:
+except (ConnectionException, socket.error) as exc:
-log.warn("Failed to create new connection to %s: %s", self.host, exc)
+log.warning("Failed to create new connection to %s: %s", self.host, exc)
 except Exception:
 log.exception("Unexpectedly failed to create new connection")
 finally:
@@ -404,7 +410,7 @@ class HostConnectionPool(object):
 self._signal_available_conn()
 return True
 except (ConnectionException, socket.error) as exc:
-log.warn("Failed to add new connection to pool for host %s: %s", self.host, exc)
+log.warning("Failed to add new connection to pool for host %s: %s", self.host, exc)
 with self._lock:
 self.open_count -= 1
 if self._session.cluster.signal_connection_failure(self.host, exc, is_host_addition=False):
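Host gains __hash__ and __lt__ above because Python 3 drops the implicit hash once __eq__ is defined and no longer falls back to cmp() for sorting; a standalone stand-in class (not the real cassandra.pool.Host) showing the effect:

    class Host(object):
        def __init__(self, address):
            self.address = address
        def __eq__(self, other):
            return self.address == other.address
        def __hash__(self):              # keeps Host usable as a dict key on Python 3
            return hash(self.address)
        def __lt__(self, other):         # lets sorted() order hosts on Python 3
            return self.address < other.address

    hosts = sorted([Host('10.0.0.2'), Host('10.0.0.1')])
    assert hosts[0].address == '10.0.0.1'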


@@ -23,6 +23,7 @@ from datetime import datetime, timedelta
 import re
 import struct
 import time
+import six
 from cassandra import ConsistencyLevel, OperationTimedOut
 from cassandra.cqltypes import unix_time_from_uuid1
@@ -113,8 +114,8 @@ class Statement(object):
 def _set_routing_key(self, key):
 if isinstance(key, (list, tuple)):
-self._routing_key = "".join(struct.pack("HsB", len(component), component, 0)
+self._routing_key = b"".join(struct.pack("HsB", len(component), component, 0)
 for component in key)
 else:
 self._routing_key = key
@@ -408,7 +409,7 @@ class BoundStatement(Statement):
 val = self.values[statement_index]
 components.append(struct.pack("HsB", len(val), val, 0))
-self._routing_key = "".join(components)
+self._routing_key = b"".join(components)
 return self._routing_key
@@ -473,7 +474,7 @@ class BatchStatement(Statement):
 Statement.__init__(self, retry_policy=retry_policy, consistency_level=consistency_level)
 def add(self, statement, parameters=None):
-if isinstance(statement, basestring):
+if isinstance(statement, six.string_types):
 if parameters:
 statement = bind_params(statement, parameters)
 self._statements_and_parameters.append((False, statement, ()))
@@ -532,11 +533,9 @@ class ValueSequence(object):
 def bind_params(query, params):
 if isinstance(params, dict):
-return query % dict((k, cql_encoders.get(type(v), cql_encode_object)(v))
-for k, v in params.iteritems())
+return query % dict((k, cql_encoders.get(type(v), cql_encode_object)(v)) for k, v in six.iteritems(params))
 else:
-return query % tuple(cql_encoders.get(type(v), cql_encode_object)(v)
-for v in params)
+return query % tuple(cql_encoders.get(type(v), cql_encode_object)(v) for v in params)


@@ -1,7 +1,5 @@
from __future__ import with_statement from __future__ import with_statement
from UserDict import DictMixin
try: try:
from collections import OrderedDict from collections import OrderedDict
except ImportError: except ImportError:
@@ -28,6 +26,7 @@ except ImportError:
# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING # WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
# OTHER DEALINGS IN THE SOFTWARE. # OTHER DEALINGS IN THE SOFTWARE.
from UserDict import DictMixin
class OrderedDict(dict, DictMixin): # noqa class OrderedDict(dict, DictMixin): # noqa
""" A dictionary which maintains the insertion order of keys. """ """ A dictionary which maintains the insertion order of keys. """
@@ -80,9 +79,9 @@ except ImportError:
if not self: if not self:
raise KeyError('dictionary is empty') raise KeyError('dictionary is empty')
if last: if last:
key = reversed(self).next() key = next(reversed(self))
else: else:
key = iter(self).next() key = next(iter(self))
value = self.pop(key) value = self.pop(key)
return key, value return key, value
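Editor's note: the popitem() change is just the iterator-protocol rename; the next() builtin (available since Python 2.6) replaces the Python-2-only .next() method call. A quick illustration with a plain list:

keys = ["a", "b", "c"]
print(next(iter(keys)))        # 'a'  -- the popitem(last=False) case
print(next(reversed(keys)))    # 'c'  -- the popitem(last=True) case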

View File

@@ -1,3 +1,4 @@
blist blist
futures futures
scales scales
six >=1.6

View File

@@ -12,16 +12,15 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import print_function
import sys import sys
import ez_setup import ez_setup
ez_setup.use_setuptools() ez_setup.use_setuptools()
run_gevent_nosetests = False
if __name__ == '__main__' and sys.argv[1] == "gevent_nosetests": if __name__ == '__main__' and sys.argv[1] == "gevent_nosetests":
from gevent.monkey import patch_all from gevent.monkey import patch_all
patch_all() patch_all()
run_gevent_nosetests = True
from setuptools import setup from setuptools import setup
from distutils.command.build_ext import build_ext from distutils.command.build_ext import build_ext
@@ -48,10 +47,11 @@ with open("README.rst") as f:
long_description = f.read() long_description = f.read()
gevent_nosetests = None try:
if run_gevent_nosetests:
from nose.commands import nosetests from nose.commands import nosetests
except ImportError:
gevent_nosetests = None
else:
class gevent_nosetests(nosetests): class gevent_nosetests(nosetests):
description = "run nosetests with gevent monkey patching" description = "run nosetests with gevent monkey patching"
@@ -92,11 +92,11 @@ class DocCommand(Command):
except subprocess.CalledProcessError as exc: except subprocess.CalledProcessError as exc:
raise RuntimeError("Documentation step '%s' failed: %s: %s" % (mode, exc, exc.output)) raise RuntimeError("Documentation step '%s' failed: %s: %s" % (mode, exc, exc.output))
else: else:
print output print(output)
print "" print("")
print "Documentation step '%s' performed, results here:" % mode print("Documentation step '%s' performed, results here:" % mode)
print " %s/" % path print(" %s/" % path)
class BuildFailed(Exception): class BuildFailed(Exception):
@@ -181,7 +181,7 @@ def run_setup(extensions):
kw['cmdclass']['build_ext'] = build_extensions kw['cmdclass']['build_ext'] = build_extensions
kw['ext_modules'] = extensions kw['ext_modules'] = extensions
dependencies = ['futures', 'scales', 'blist'] dependencies = ['futures', 'scales >=1.0.5', 'blist', 'six >=1.6']
if platform.python_implementation() != "CPython": if platform.python_implementation() != "CPython":
dependencies.remove('blist') dependencies.remove('blist')
@@ -196,7 +196,7 @@ def run_setup(extensions):
packages=['cassandra', 'cassandra.io'], packages=['cassandra', 'cassandra.io'],
include_package_data=True, include_package_data=True,
install_requires=dependencies, install_requires=dependencies,
tests_require=['nose', 'mock', 'ccm', 'unittest2', 'PyYAML', 'pytz'], tests_require=['nose', 'mock', 'PyYAML', 'pytz'],
classifiers=[ classifiers=[
'Development Status :: 5 - Production/Stable', 'Development Status :: 5 - Production/Stable',
'Intended Audience :: Developers', 'Intended Audience :: Developers',
@@ -207,6 +207,10 @@ def run_setup(extensions):
'Programming Language :: Python :: 2', 'Programming Language :: Python :: 2',
'Programming Language :: Python :: 2.6', 'Programming Language :: Python :: 2.6',
'Programming Language :: Python :: 2.7', 'Programming Language :: Python :: 2.7',
'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3.3',
'Programming Language :: Python :: Implementation :: CPython',
'Programming Language :: Python :: Implementation :: PyPy',
'Topic :: Software Development :: Libraries :: Python Modules' 'Topic :: Software Development :: Libraries :: Python Modules'
], ],
**kw) **kw)
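Editor's note: a condensed, runnable sketch of the two setup.py changes above -- gevent monkey patching is triggered by the gevent_nosetests argument before other imports, and the command class is only defined when nose is importable, falling back to None otherwise. The argv check is written with a slice here so the sketch does not raise IndexError when run without arguments; names and wiring are illustrative, not copied from the patch.

from __future__ import print_function
import sys

# gevent must patch before other imports, as the top of setup.py does
if __name__ == "__main__" and sys.argv[1:2] == ["gevent_nosetests"]:
    from gevent.monkey import patch_all
    patch_all()

try:
    from nose.commands import nosetests
except ImportError:
    gevent_nosetests = None
else:
    class gevent_nosetests(nosetests):
        description = "run nosetests with gevent monkey patching"

print("gevent_nosetests command available:", gevent_nosetests is not None)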

View File

@@ -22,7 +22,9 @@ except ImportError:
import logging import logging
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
import os import os
from six import print_
from threading import Event from threading import Event
from cassandra.cluster import Cluster from cassandra.cluster import Cluster
@@ -92,7 +94,7 @@ def get_node(node_id):
def setup_package(): def setup_package():
print 'Using Cassandra version: %s' % CASSANDRA_VERSION print_('Using Cassandra version: %s' % CASSANDRA_VERSION)
try: try:
try: try:
cluster = CCMCluster.load(path, CLUSTER_NAME) cluster = CCMCluster.load(path, CLUSTER_NAME)

View File

@@ -12,7 +12,11 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import Queue try:
from Queue import Queue, Empty
except ImportError:
from queue import Queue, Empty # noqa
from struct import pack from struct import pack
import unittest import unittest
@@ -31,13 +35,12 @@ def create_column_name(i):
column_name = '' column_name = ''
while True: while True:
column_name += letters[i % 10] column_name += letters[i % 10]
i /= 10 i = i // 10
if not i: if not i:
break break
if column_name == 'if': if column_name == 'if':
column_name = 'special_case' column_name = 'special_case'
return column_name return column_name
@@ -56,15 +59,15 @@ class LargeDataTests(unittest.TestCase):
return session return session
def batch_futures(self, session, statement_generator): def batch_futures(self, session, statement_generator):
concurrency = 50 concurrency = 10
futures = Queue.Queue(maxsize=concurrency) futures = Queue(maxsize=concurrency)
for i, statement in enumerate(statement_generator): for i, statement in enumerate(statement_generator):
if i > 0 and i % (concurrency - 1) == 0: if i > 0 and i % (concurrency - 1) == 0:
# clear the existing queue # clear the existing queue
while True: while True:
try: try:
futures.get_nowait().result() futures.get_nowait().result()
except Queue.Empty: except Empty:
break break
future = session.execute_async(statement) future = session.execute_async(statement)
@@ -73,7 +76,7 @@ class LargeDataTests(unittest.TestCase):
while True: while True:
try: try:
futures.get_nowait().result() futures.get_nowait().result()
except Queue.Empty: except Empty:
break break
def test_wide_rows(self): def test_wide_rows(self):
@@ -81,15 +84,13 @@ class LargeDataTests(unittest.TestCase):
session = self.make_session_and_keyspace() session = self.make_session_and_keyspace()
session.execute('CREATE TABLE %s (k INT, i INT, PRIMARY KEY(k, i))' % table) session.execute('CREATE TABLE %s (k INT, i INT, PRIMARY KEY(k, i))' % table)
prepared = session.prepare('INSERT INTO %s (k, i) VALUES (0, ?)' % (table, ))
# Write via async futures # Write via async futures
self.batch_futures( self.batch_futures(session, (prepared.bind((i, )) for i in range(100000)))
session,
(SimpleStatement('INSERT INTO %s (k, i) VALUES (0, %s)' % (table, i),
consistency_level=ConsistencyLevel.QUORUM)
for i in range(100000)))
# Read # Read
results = session.execute('SELECT i FROM %s WHERE k=%s' % (table, 0)) results = session.execute('SELECT i FROM %s WHERE k=0' % (table, ))
# Verify # Verify
for i, row in enumerate(results): for i, row in enumerate(results):
@@ -120,18 +121,13 @@ class LargeDataTests(unittest.TestCase):
session = self.make_session_and_keyspace() session = self.make_session_and_keyspace()
session.execute('CREATE TABLE %s (k INT, i INT, v BLOB, PRIMARY KEY(k, i))' % table) session.execute('CREATE TABLE %s (k INT, i INT, v BLOB, PRIMARY KEY(k, i))' % table)
# Build small ByteBuffer sample prepared = session.prepare('INSERT INTO %s (k, i, v) VALUES (0, ?, 0xCAFE)' % (table, ))
bb = '0xCAFE'
# Write # Write
self.batch_futures( self.batch_futures(session, (prepared.bind((i, )) for i in range(100000)))
session,
(SimpleStatement('INSERT INTO %s (k, i, v) VALUES (0, %s, %s)' % (table, i, str(bb)),
consistency_level=ConsistencyLevel.QUORUM)
for i in range(100000)))
# Read # Read
results = session.execute('SELECT i, v FROM %s WHERE k=%s' % (table, 0)) results = session.execute('SELECT i, v FROM %s WHERE k=0' % (table, ))
# Verify # Verify
bb = pack('>H', 0xCAFE) bb = pack('>H', 0xCAFE)
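Editor's note: the batch_futures() loop above is a small bounded-concurrency pattern -- it never lets more than roughly `concurrency` execute_async() futures pile up by draining a Queue between batches. A self-contained sketch of that pattern, where `session` and `statements` stand in for the test fixtures:

try:
    from Queue import Queue, Empty       # Python 2
except ImportError:
    from queue import Queue, Empty       # Python 3

def run_bounded(session, statements, concurrency=10):
    futures = Queue(maxsize=concurrency)
    for i, statement in enumerate(statements):
        if i > 0 and i % (concurrency - 1) == 0:
            # drain what is already in flight before queueing more
            while True:
                try:
                    futures.get_nowait().result()
                except Empty:
                    break
        futures.put_nowait(session.execute_async(statement))
    # wait for the tail end
    while True:
        try:
            futures.get_nowait().result()
        except Empty:
            break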

View File

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import print_function
import logging import logging
import time import time
@@ -125,7 +126,7 @@ def bootstrap(node, data_center=None, token=None):
def ring(node): def ring(node):
print 'From node%s:' % node print('From node%s:' % node)
get_node(node).nodetool('ring') get_node(node).nodetool('ring')
@@ -154,3 +155,5 @@ def wait_for_down(cluster, node, wait=True):
time.sleep(10) time.sleep(10)
log.debug("Done waiting for node %s to be down", node) log.debug("Done waiting for node %s to be down", node)
return return
else:
log.debug("Host is still marked up, waiting")

View File

@@ -40,7 +40,7 @@ class ClusterTests(unittest.TestCase):
CREATE KEYSPACE clustertests CREATE KEYSPACE clustertests
WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'} WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'}
""") """)
self.assertEquals(None, result) self.assertEqual(None, result)
result = session.execute( result = session.execute(
""" """
@@ -51,16 +51,16 @@ class ClusterTests(unittest.TestCase):
PRIMARY KEY (a, b) PRIMARY KEY (a, b)
) )
""") """)
self.assertEquals(None, result) self.assertEqual(None, result)
result = session.execute( result = session.execute(
""" """
INSERT INTO clustertests.cf0 (a, b, c) VALUES ('a', 'b', 'c') INSERT INTO clustertests.cf0 (a, b, c) VALUES ('a', 'b', 'c')
""") """)
self.assertEquals(None, result) self.assertEqual(None, result)
result = session.execute("SELECT * FROM clustertests.cf0") result = session.execute("SELECT * FROM clustertests.cf0")
self.assertEquals([('a', 'b', 'c')], result) self.assertEqual([('a', 'b', 'c')], result)
cluster.shutdown() cluster.shutdown()
@@ -75,15 +75,15 @@ class ClusterTests(unittest.TestCase):
""" """
INSERT INTO test3rf.test (k, v) VALUES (8889, 8889) INSERT INTO test3rf.test (k, v) VALUES (8889, 8889)
""") """)
self.assertEquals(None, result) self.assertEqual(None, result)
result = session.execute("SELECT * FROM test3rf.test") result = session.execute("SELECT * FROM test3rf.test")
self.assertEquals([(8889, 8889)], result) self.assertEqual([(8889, 8889)], result)
# test_connect_on_keyspace # test_connect_on_keyspace
session2 = cluster.connect('test3rf') session2 = cluster.connect('test3rf')
result2 = session2.execute("SELECT * FROM test") result2 = session2.execute("SELECT * FROM test")
self.assertEquals(result, result2) self.assertEqual(result, result2)
def test_set_keyspace_twice(self): def test_set_keyspace_twice(self):
cluster = Cluster(protocol_version=PROTOCOL_VERSION) cluster = Cluster(protocol_version=PROTOCOL_VERSION)
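Editor's note: assertEquals is only a deprecated alias, so the renames above move to the canonical assertEqual spelling; both still pass, but the alias warns on newer unittest versions. A tiny runnable check:

import unittest
import warnings

class AliasDemo(unittest.TestCase):
    def test_spelling(self):
        self.assertEqual([(8889, 8889)], [(8889, 8889)])   # canonical name
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", DeprecationWarning)
            self.assertEquals(1, 1)                        # deprecated alias

suite = unittest.defaultTestLoader.loadTestsFromTestCase(AliasDemo)
unittest.TextTestRunner(verbosity=0).run(suite)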

View File

@@ -43,7 +43,7 @@ class ClusterTests(unittest.TestCase):
statements = cycle(("INSERT INTO test3rf.test (k, v) VALUES (%s, %s)", )) statements = cycle(("INSERT INTO test3rf.test (k, v) VALUES (%s, %s)", ))
parameters = [(i, i) for i in range(num_statements)] parameters = [(i, i) for i in range(num_statements)]
results = execute_concurrent(self.session, zip(statements, parameters)) results = execute_concurrent(self.session, list(zip(statements, parameters)))
self.assertEqual(num_statements, len(results)) self.assertEqual(num_statements, len(results))
self.assertEqual([(True, None)] * num_statements, results) self.assertEqual([(True, None)] * num_statements, results)
@@ -51,7 +51,7 @@ class ClusterTests(unittest.TestCase):
statements = cycle(("SELECT v FROM test3rf.test WHERE k=%s", )) statements = cycle(("SELECT v FROM test3rf.test WHERE k=%s", ))
parameters = [(i, ) for i in range(num_statements)] parameters = [(i, ) for i in range(num_statements)]
results = execute_concurrent(self.session, zip(statements, parameters)) results = execute_concurrent(self.session, list(zip(statements, parameters)))
self.assertEqual(num_statements, len(results)) self.assertEqual(num_statements, len(results))
self.assertEqual([(True, [(i,)]) for i in range(num_statements)], results) self.assertEqual([(True, [(i,)]) for i in range(num_statements)], results)
@@ -81,7 +81,7 @@ class ClusterTests(unittest.TestCase):
self.assertRaises( self.assertRaises(
InvalidRequest, InvalidRequest,
execute_concurrent, self.session, zip(statements, parameters), raise_on_first_error=True) execute_concurrent, self.session, list(zip(statements, parameters)), raise_on_first_error=True)
def test_first_failure_client_side(self): def test_first_failure_client_side(self):
statements = cycle(("INSERT INTO test3rf.test (k, v) VALUES (%s, %s)", )) statements = cycle(("INSERT INTO test3rf.test (k, v) VALUES (%s, %s)", ))
@@ -92,7 +92,7 @@ class ClusterTests(unittest.TestCase):
self.assertRaises( self.assertRaises(
TypeError, TypeError,
execute_concurrent, self.session, zip(statements, parameters), raise_on_first_error=True) execute_concurrent, self.session, list(zip(statements, parameters)), raise_on_first_error=True)
def test_no_raise_on_first_failure(self): def test_no_raise_on_first_failure(self):
statements = cycle(("INSERT INTO test3rf.test (k, v) VALUES (%s, %s)", )) statements = cycle(("INSERT INTO test3rf.test (k, v) VALUES (%s, %s)", ))
@@ -101,7 +101,7 @@ class ClusterTests(unittest.TestCase):
# we'll get an error back from the server # we'll get an error back from the server
parameters[57] = ('efefef', 'awefawefawef') parameters[57] = ('efefef', 'awefawefawef')
results = execute_concurrent(self.session, zip(statements, parameters), raise_on_first_error=False) results = execute_concurrent(self.session, list(zip(statements, parameters)), raise_on_first_error=False)
for i, (success, result) in enumerate(results): for i, (success, result) in enumerate(results):
if i == 57: if i == 57:
self.assertFalse(success) self.assertFalse(success)
@@ -115,9 +115,9 @@ class ClusterTests(unittest.TestCase):
parameters = [(i, i) for i in range(100)] parameters = [(i, i) for i in range(100)]
# the driver will raise an error when binding the params # the driver will raise an error when binding the params
parameters[57] = i parameters[57] = 1
results = execute_concurrent(self.session, zip(statements, parameters), raise_on_first_error=False) results = execute_concurrent(self.session, list(zip(statements, parameters)), raise_on_first_error=False)
for i, (success, result) in enumerate(results): for i, (success, result) in enumerate(results):
if i == 57: if i == 57:
self.assertFalse(success) self.assertFalse(success)
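Editor's note: the list(zip(...)) wrapping exists because zip() is a lazy, single-pass iterator on Python 3; materializing it gives execute_concurrent a reusable sequence with a length, which is the likely motivation for the change. A small illustration with the same shape of data:

from itertools import cycle

statements = cycle(("INSERT INTO test3rf.test (k, v) VALUES (%s, %s)", ))
parameters = [(i, i) for i in range(3)]

pairs = list(zip(statements, parameters))   # safe on both py2 and py3
print(len(pairs))                           # 3
print(pairs[0][1])                          # (0, 0)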

View File

@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import six
try: try:
import unittest2 as unittest import unittest2 as unittest
@@ -115,7 +116,7 @@ class SchemaMetadataTest(unittest.TestCase):
def check_create_statement(self, tablemeta, original): def check_create_statement(self, tablemeta, original):
recreate = tablemeta.as_cql_query(formatted=False) recreate = tablemeta.as_cql_query(formatted=False)
self.assertEquals(original, recreate[:len(original)]) self.assertEqual(original, recreate[:len(original)])
self.session.execute("DROP TABLE %s.%s" % (self.ksname, self.cfname)) self.session.execute("DROP TABLE %s.%s" % (self.ksname, self.cfname))
self.session.execute(recreate) self.session.execute(recreate)
@@ -289,7 +290,7 @@ class SchemaMetadataTest(unittest.TestCase):
tablemeta = self.get_table_metadata() tablemeta = self.get_table_metadata()
statements = tablemeta.export_as_string().strip() statements = tablemeta.export_as_string().strip()
statements = [s.strip() for s in statements.split(';')] statements = [s.strip() for s in statements.split(';')]
statements = filter(bool, statements) statements = list(filter(bool, statements))
self.assertEqual(3, len(statements)) self.assertEqual(3, len(statements))
self.assertEqual(d_index, statements[1]) self.assertEqual(d_index, statements[1])
self.assertEqual(e_index, statements[2]) self.assertEqual(e_index, statements[2])
@@ -311,7 +312,7 @@ class TestCodeCoverage(unittest.TestCase):
cluster = Cluster(protocol_version=PROTOCOL_VERSION) cluster = Cluster(protocol_version=PROTOCOL_VERSION)
cluster.connect() cluster.connect()
self.assertIsInstance(cluster.metadata.export_schema_as_string(), basestring) self.assertIsInstance(cluster.metadata.export_schema_as_string(), six.string_types)
def test_export_keyspace_schema(self): def test_export_keyspace_schema(self):
""" """
@@ -323,8 +324,8 @@ class TestCodeCoverage(unittest.TestCase):
for keyspace in cluster.metadata.keyspaces: for keyspace in cluster.metadata.keyspaces:
keyspace_metadata = cluster.metadata.keyspaces[keyspace] keyspace_metadata = cluster.metadata.keyspaces[keyspace]
self.assertIsInstance(keyspace_metadata.export_as_string(), basestring) self.assertIsInstance(keyspace_metadata.export_as_string(), six.string_types)
self.assertIsInstance(keyspace_metadata.as_cql_query(), basestring) self.assertIsInstance(keyspace_metadata.as_cql_query(), six.string_types)
def test_case_sensitivity(self): def test_case_sensitivity(self):
""" """

View File

@@ -69,7 +69,7 @@ class PreparedStatementTests(unittest.TestCase):
bound = prepared.bind(('a')) bound = prepared.bind(('a'))
results = session.execute(bound) results = session.execute(bound)
self.assertEquals(results, [('a', 'b', 'c')]) self.assertEqual(results, [('a', 'b', 'c')])
# test with new dict binding # test with new dict binding
prepared = session.prepare( prepared = session.prepare(
@@ -95,7 +95,7 @@ class PreparedStatementTests(unittest.TestCase):
bound = prepared.bind({'a': 'x'}) bound = prepared.bind({'a': 'x'})
results = session.execute(bound) results = session.execute(bound)
self.assertEquals(results, [('x', 'y', 'z')]) self.assertEqual(results, [('x', 'y', 'z')])
def test_missing_primary_key(self): def test_missing_primary_key(self):
""" """
@@ -148,7 +148,7 @@ class PreparedStatementTests(unittest.TestCase):
""") """)
self.assertIsInstance(prepared, PreparedStatement) self.assertIsInstance(prepared, PreparedStatement)
self.assertRaises(ValueError, prepared.bind, (1,2)) self.assertRaises(ValueError, prepared.bind, (1, 2))
def test_too_many_bind_values_dicts(self): def test_too_many_bind_values_dicts(self):
""" """
@@ -196,7 +196,7 @@ class PreparedStatementTests(unittest.TestCase):
bound = prepared.bind((1,)) bound = prepared.bind((1,))
results = session.execute(bound) results = session.execute(bound)
self.assertEquals(results[0].v, None) self.assertEqual(results[0].v, None)
def test_none_values_dicts(self): def test_none_values_dicts(self):
""" """
@@ -206,7 +206,6 @@ class PreparedStatementTests(unittest.TestCase):
cluster = Cluster(protocol_version=PROTOCOL_VERSION) cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect() session = cluster.connect()
# test with new dict binding # test with new dict binding
prepared = session.prepare( prepared = session.prepare(
""" """
@@ -225,7 +224,7 @@ class PreparedStatementTests(unittest.TestCase):
bound = prepared.bind({'k': 1}) bound = prepared.bind({'k': 1})
results = session.execute(bound) results = session.execute(bound)
self.assertEquals(results[0].v, None) self.assertEqual(results[0].v, None)
def test_async_binding(self): def test_async_binding(self):
""" """
@@ -252,8 +251,7 @@ class PreparedStatementTests(unittest.TestCase):
future = session.execute_async(prepared, (873,)) future = session.execute_async(prepared, (873,))
results = future.result() results = future.result()
self.assertEquals(results[0].v, None) self.assertEqual(results[0].v, None)
def test_async_binding_dicts(self): def test_async_binding_dicts(self):
""" """
@@ -280,4 +278,4 @@ class PreparedStatementTests(unittest.TestCase):
future = session.execute_async(prepared, {'k': 873}) future = session.execute_async(prepared, {'k': 873})
results = future.result() results = future.result()
self.assertEquals(results[0].v, None) self.assertEqual(results[0].v, None)

View File

@@ -24,7 +24,7 @@ from cassandra.query import (PreparedStatement, BoundStatement, ValueSequence,
from cassandra.cluster import Cluster from cassandra.cluster import Cluster
from cassandra.policies import HostDistance from cassandra.policies import HostDistance
from tests.integration import get_server_versions, PROTOCOL_VERSION from tests.integration import PROTOCOL_VERSION
class QueryTest(unittest.TestCase): class QueryTest(unittest.TestCase):
@@ -43,7 +43,7 @@ class QueryTest(unittest.TestCase):
self.assertIsInstance(bound, BoundStatement) self.assertIsInstance(bound, BoundStatement)
self.assertEqual(2, len(bound.values)) self.assertEqual(2, len(bound.values))
session.execute(bound) session.execute(bound)
self.assertEqual(bound.routing_key, '\x00\x00\x00\x01') self.assertEqual(bound.routing_key, b'\x00\x00\x00\x01')
def test_value_sequence(self): def test_value_sequence(self):
""" """
@@ -102,7 +102,7 @@ class PreparedStatementTests(unittest.TestCase):
self.assertIsInstance(prepared, PreparedStatement) self.assertIsInstance(prepared, PreparedStatement)
bound = prepared.bind((1, None)) bound = prepared.bind((1, None))
self.assertEqual(bound.routing_key, '\x00\x00\x00\x01') self.assertEqual(bound.routing_key, b'\x00\x00\x00\x01')
def test_empty_routing_key_indexes(self): def test_empty_routing_key_indexes(self):
""" """
@@ -158,7 +158,7 @@ class PreparedStatementTests(unittest.TestCase):
self.assertIsInstance(prepared, PreparedStatement) self.assertIsInstance(prepared, PreparedStatement)
bound = prepared.bind((1, 2)) bound = prepared.bind((1, 2))
self.assertEqual(bound.routing_key, '\x04\x00\x00\x00\x04\x00\x00\x00') self.assertEqual(bound.routing_key, b'\x04\x00\x00\x00\x04\x00\x00\x00')
def test_bound_keyspace(self): def test_bound_keyspace(self):
""" """
@@ -326,14 +326,14 @@ class SerialConsistencyTests(unittest.TestCase):
"UPDATE test3rf.test SET v=1 WHERE k=0 IF v=1", "UPDATE test3rf.test SET v=1 WHERE k=0 IF v=1",
serial_consistency_level=ConsistencyLevel.SERIAL) serial_consistency_level=ConsistencyLevel.SERIAL)
result = self.session.execute(statement) result = self.session.execute(statement)
self.assertEquals(1, len(result)) self.assertEqual(1, len(result))
self.assertFalse(result[0].applied) self.assertFalse(result[0].applied)
statement = SimpleStatement( statement = SimpleStatement(
"UPDATE test3rf.test SET v=1 WHERE k=0 IF v=0", "UPDATE test3rf.test SET v=1 WHERE k=0 IF v=0",
serial_consistency_level=ConsistencyLevel.SERIAL) serial_consistency_level=ConsistencyLevel.SERIAL)
result = self.session.execute(statement) result = self.session.execute(statement)
self.assertEquals(1, len(result)) self.assertEqual(1, len(result))
self.assertTrue(result[0].applied) self.assertTrue(result[0].applied)
def test_conditional_update_with_prepared_statements(self): def test_conditional_update_with_prepared_statements(self):
@@ -343,7 +343,7 @@ class SerialConsistencyTests(unittest.TestCase):
statement.serial_consistency_level = ConsistencyLevel.SERIAL statement.serial_consistency_level = ConsistencyLevel.SERIAL
result = self.session.execute(statement) result = self.session.execute(statement)
self.assertEquals(1, len(result)) self.assertEqual(1, len(result))
self.assertFalse(result[0].applied) self.assertFalse(result[0].applied)
statement = self.session.prepare( statement = self.session.prepare(
@@ -351,7 +351,7 @@ class SerialConsistencyTests(unittest.TestCase):
bound = statement.bind(()) bound = statement.bind(())
bound.serial_consistency_level = ConsistencyLevel.SERIAL bound.serial_consistency_level = ConsistencyLevel.SERIAL
result = self.session.execute(statement) result = self.session.execute(statement)
self.assertEquals(1, len(result)) self.assertEqual(1, len(result))
self.assertTrue(result[0].applied) self.assertTrue(result[0].applied)
def test_bad_consistency_level(self): def test_bad_consistency_level(self):
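Editor's note: the expected routing keys above become bytes literals. For a single int partition key the routing key is just the column's serialized form, i.e. a 4-byte big-endian int32, which is easy to sanity-check with struct:

import struct
assert struct.pack(">i", 1) == b"\x00\x00\x00\x01"    # routing key for k=1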

View File

@@ -23,6 +23,7 @@ except ImportError:
import unittest # noqa import unittest # noqa
from itertools import cycle, count from itertools import cycle, count
from six.moves import range
from threading import Event from threading import Event
from cassandra.cluster import Cluster from cassandra.cluster import Cluster
@@ -47,7 +48,7 @@ class QueryPagingTests(unittest.TestCase):
def test_paging(self): def test_paging(self):
statements_and_params = zip(cycle(["INSERT INTO test3rf.test (k, v) VALUES (%s, 0)"]), statements_and_params = zip(cycle(["INSERT INTO test3rf.test (k, v) VALUES (%s, 0)"]),
[(i, ) for i in range(100)]) [(i, ) for i in range(100)])
execute_concurrent(self.session, statements_and_params) execute_concurrent(self.session, list(statements_and_params))
prepared = self.session.prepare("SELECT * FROM test3rf.test") prepared = self.session.prepare("SELECT * FROM test3rf.test")
@@ -153,7 +154,7 @@ class QueryPagingTests(unittest.TestCase):
def test_async_paging(self): def test_async_paging(self):
statements_and_params = zip(cycle(["INSERT INTO test3rf.test (k, v) VALUES (%s, 0)"]), statements_and_params = zip(cycle(["INSERT INTO test3rf.test (k, v) VALUES (%s, 0)"]),
[(i, ) for i in range(100)]) [(i, ) for i in range(100)])
execute_concurrent(self.session, statements_and_params) execute_concurrent(self.session, list(statements_and_params))
prepared = self.session.prepare("SELECT * FROM test3rf.test") prepared = self.session.prepare("SELECT * FROM test3rf.test")
@@ -219,7 +220,7 @@ class QueryPagingTests(unittest.TestCase):
def test_paging_callbacks(self): def test_paging_callbacks(self):
statements_and_params = zip(cycle(["INSERT INTO test3rf.test (k, v) VALUES (%s, 0)"]), statements_and_params = zip(cycle(["INSERT INTO test3rf.test (k, v) VALUES (%s, 0)"]),
[(i, ) for i in range(100)]) [(i, ) for i in range(100)])
execute_concurrent(self.session, statements_and_params) execute_concurrent(self.session, list(statements_and_params))
prepared = self.session.prepare("SELECT * FROM test3rf.test") prepared = self.session.prepare("SELECT * FROM test3rf.test")
@@ -232,7 +233,7 @@ class QueryPagingTests(unittest.TestCase):
def handle_page(rows, future, counter): def handle_page(rows, future, counter):
for row in rows: for row in rows:
counter.next() next(counter)
if future.has_more_pages: if future.has_more_pages:
future.start_fetching_next_page() future.start_fetching_next_page()
@@ -245,7 +246,7 @@ class QueryPagingTests(unittest.TestCase):
future.add_callbacks(callback=handle_page, callback_args=(future, counter), errback=handle_error) future.add_callbacks(callback=handle_page, callback_args=(future, counter), errback=handle_error)
event.wait() event.wait()
self.assertEquals(counter.next(), 100) self.assertEquals(next(counter), 100)
# simple statement # simple statement
future = self.session.execute_async(SimpleStatement("SELECT * FROM test3rf.test")) future = self.session.execute_async(SimpleStatement("SELECT * FROM test3rf.test"))
@@ -254,7 +255,7 @@ class QueryPagingTests(unittest.TestCase):
future.add_callbacks(callback=handle_page, callback_args=(future, counter), errback=handle_error) future.add_callbacks(callback=handle_page, callback_args=(future, counter), errback=handle_error)
event.wait() event.wait()
self.assertEquals(counter.next(), 100) self.assertEquals(next(counter), 100)
# prepared statement # prepared statement
future = self.session.execute_async(prepared) future = self.session.execute_async(prepared)
@@ -263,4 +264,4 @@ class QueryPagingTests(unittest.TestCase):
future.add_callbacks(callback=handle_page, callback_args=(future, counter), errback=handle_error) future.add_callbacks(callback=handle_page, callback_args=(future, counter), errback=handle_error)
event.wait() event.wait()
self.assertEquals(counter.next(), 100) self.assertEquals(next(counter), 100)
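Editor's note: the paging callback now advances the shared counter with next(counter) instead of the Python-2-only counter.next(); itertools.count behaves the same under both spellings where they exist:

from itertools import count

counter = count()
for _ in range(100):
    next(counter)              # one call per row, as in handle_page()
print(next(counter))           # 100 -- the total the test asserts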

View File

@@ -17,8 +17,12 @@ try:
except ImportError: except ImportError:
import unittest # noqa import unittest # noqa
import logging
log = logging.getLogger(__name__)
from decimal import Decimal from decimal import Decimal
from datetime import datetime from datetime import datetime
import six
from uuid import uuid1, uuid4 from uuid import uuid1, uuid4
try: try:
@@ -59,12 +63,16 @@ class TypeTests(unittest.TestCase):
params = [ params = [
'key1', 'key1',
'blobyblob'.encode('hex') b'blobyblob'
] ]
query = 'INSERT INTO mytable (a, b) VALUES (%s, %s)' query = 'INSERT INTO mytable (a, b) VALUES (%s, %s)'
if self._cql_version >= (3, 1, 0): # In python 3, the 'bytes' type is treated as a blob, so we can
# correctly encode it with hex notation.
# In python2, we don't treat the 'str' type as a blob, so we'll encode it
# as a string literal and have the following failure.
if six.PY2 and self._cql_version >= (3, 1, 0):
# Blob values can't be specified using string notation in CQL 3.1.0 and # Blob values can't be specified using string notation in CQL 3.1.0 and
# above which is used by default in Cassandra 2.0. # above which is used by default in Cassandra 2.0.
msg = r'.*Invalid STRING constant \(.*?\) for b of type blob.*' msg = r'.*Invalid STRING constant \(.*?\) for b of type blob.*'
@@ -74,13 +82,13 @@ class TypeTests(unittest.TestCase):
s.execute(query, params) s.execute(query, params)
expected_vals = [ expected_vals = [
'key1', 'key1',
'blobyblob' bytearray(b'blobyblob')
] ]
results = s.execute("SELECT * FROM mytable") results = s.execute("SELECT * FROM mytable")
for expected, actual in zip(expected_vals, results[0]): for expected, actual in zip(expected_vals, results[0]):
self.assertEquals(expected, actual) self.assertEqual(expected, actual)
def test_blob_type_as_bytearray(self): def test_blob_type_as_bytearray(self):
c = Cluster(protocol_version=PROTOCOL_VERSION) c = Cluster(protocol_version=PROTOCOL_VERSION)
@@ -101,7 +109,7 @@ class TypeTests(unittest.TestCase):
params = [ params = [
'key1', 'key1',
bytearray('blob1', 'hex') bytearray(b'blob1')
] ]
query = 'INSERT INTO mytable (a, b) VALUES (%s, %s);' query = 'INSERT INTO mytable (a, b) VALUES (%s, %s);'
@@ -109,13 +117,13 @@ class TypeTests(unittest.TestCase):
expected_vals = [ expected_vals = [
'key1', 'key1',
bytearray('blob1', 'hex') bytearray(b'blob1')
] ]
results = s.execute("SELECT * FROM mytable") results = s.execute("SELECT * FROM mytable")
for expected, actual in zip(expected_vals, results[0]): for expected, actual in zip(expected_vals, results[0]):
self.assertEquals(expected, actual) self.assertEqual(expected, actual)
create_type_table = """ create_type_table = """
CREATE TABLE mytable ( CREATE TABLE mytable (
@@ -208,7 +216,7 @@ class TypeTests(unittest.TestCase):
results = s.execute("SELECT * FROM mytable") results = s.execute("SELECT * FROM mytable")
for expected, actual in zip(expected_vals, results[0]): for expected, actual in zip(expected_vals, results[0]):
self.assertEquals(expected, actual) self.assertEqual(expected, actual)
# try the same thing with a prepared statement # try the same thing with a prepared statement
prepared = s.prepare(""" prepared = s.prepare("""
@@ -221,7 +229,7 @@ class TypeTests(unittest.TestCase):
results = s.execute("SELECT * FROM mytable") results = s.execute("SELECT * FROM mytable")
for expected, actual in zip(expected_vals, results[0]): for expected, actual in zip(expected_vals, results[0]):
self.assertEquals(expected, actual) self.assertEqual(expected, actual)
# query with prepared statement # query with prepared statement
prepared = s.prepare(""" prepared = s.prepare("""
@@ -230,14 +238,14 @@ class TypeTests(unittest.TestCase):
results = s.execute(prepared.bind(())) results = s.execute(prepared.bind(()))
for expected, actual in zip(expected_vals, results[0]): for expected, actual in zip(expected_vals, results[0]):
self.assertEquals(expected, actual) self.assertEqual(expected, actual)
# query with prepared statement, no explicit columns # query with prepared statement, no explicit columns
prepared = s.prepare("""SELECT * FROM mytable""") prepared = s.prepare("""SELECT * FROM mytable""")
results = s.execute(prepared.bind(())) results = s.execute(prepared.bind(()))
for expected, actual in zip(expected_vals, results[0]): for expected, actual in zip(expected_vals, results[0]):
self.assertEquals(expected, actual) self.assertEqual(expected, actual)
def test_empty_strings_and_nones(self): def test_empty_strings_and_nones(self):
c = Cluster(protocol_version=PROTOCOL_VERSION) c = Cluster(protocol_version=PROTOCOL_VERSION)
@@ -265,11 +273,11 @@ class TypeTests(unittest.TestCase):
# insert empty strings for string-like fields and fetch them # insert empty strings for string-like fields and fetch them
s.execute("INSERT INTO mytable (a, b, c, o, s, l, n) VALUES ('a', 'b', %s, %s, %s, %s, %s)", s.execute("INSERT INTO mytable (a, b, c, o, s, l, n) VALUES ('a', 'b', %s, %s, %s, %s, %s)",
('', '', '', [''], {'': 3})) ('', '', '', [''], {'': 3}))
self.assertEquals( self.assertEqual(
{'c': '', 'o': '', 's': '', 'l': ('', ), 'n': OrderedDict({'': 3})}, {'c': '', 'o': '', 's': '', 'l': ('', ), 'n': OrderedDict({'': 3})},
s.execute("SELECT c, o, s, l, n FROM mytable WHERE a='a' AND b='b'")[0]) s.execute("SELECT c, o, s, l, n FROM mytable WHERE a='a' AND b='b'")[0])
self.assertEquals( self.assertEqual(
{'c': '', 'o': '', 's': '', 'l': ('', ), 'n': OrderedDict({'': 3})}, {'c': '', 'o': '', 's': '', 'l': ('', ), 'n': OrderedDict({'': 3})},
s.execute(s.prepare("SELECT c, o, s, l, n FROM mytable WHERE a='a' AND b='b'"), [])[0]) s.execute(s.prepare("SELECT c, o, s, l, n FROM mytable WHERE a='a' AND b='b'"), [])[0])
@@ -363,7 +371,7 @@ class TypeTests(unittest.TestCase):
""" Ensure timezone-aware datetimes are converted to timestamps correctly """ """ Ensure timezone-aware datetimes are converted to timestamps correctly """
try: try:
import pytz import pytz
except ImportError, exc: except ImportError as exc:
raise unittest.SkipTest('pytz is not available: %r' % (exc,)) raise unittest.SkipTest('pytz is not available: %r' % (exc,))
dt = datetime(1997, 8, 29, 11, 14) dt = datetime(1997, 8, 29, 11, 14)
@@ -381,10 +389,10 @@ class TypeTests(unittest.TestCase):
# test non-prepared statement # test non-prepared statement
s.execute("INSERT INTO mytable (a, b) VALUES ('key1', %s)", parameters=(dt,)) s.execute("INSERT INTO mytable (a, b) VALUES ('key1', %s)", parameters=(dt,))
result = s.execute("SELECT b FROM mytable WHERE a='key1'")[0].b result = s.execute("SELECT b FROM mytable WHERE a='key1'")[0].b
self.assertEquals(dt.utctimetuple(), result.utctimetuple()) self.assertEqual(dt.utctimetuple(), result.utctimetuple())
# test prepared statement # test prepared statement
prepared = s.prepare("INSERT INTO mytable (a, b) VALUES ('key2', ?)") prepared = s.prepare("INSERT INTO mytable (a, b) VALUES ('key2', ?)")
s.execute(prepared, parameters=(dt,)) s.execute(prepared, parameters=(dt,))
result = s.execute("SELECT b FROM mytable WHERE a='key2'")[0].b result = s.execute("SELECT b FROM mytable WHERE a='key2'")[0].b
self.assertEquals(dt.utctimetuple(), result.utctimetuple()) self.assertEqual(dt.utctimetuple(), result.utctimetuple())
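Editor's note: a short illustration of the py2/py3 blob split the new comment describes -- bytes is a distinct binary type on Python 3, so a b'...' parameter can be emitted in 0x... blob notation, while a Python 2 str is indistinguishable from text (hence the expected STRING-constant error on CQL 3.1+). binascii is used here only for the illustration:

import binascii
import six

value = b"blobyblob"
print(six.PY2 and isinstance(value, str))      # True only on Python 2
print(b"0x" + binascii.hexlify(value))         # b'0x626c6f6279626c6f62'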

View File

@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import six
try: try:
import unittest2 as unittest import unittest2 as unittest
@@ -19,7 +20,9 @@ except ImportError:
import errno import errno
import os import os
from StringIO import StringIO
from six import BytesIO
import socket import socket
from socket import error as socket_error from socket import error as socket_error
@@ -55,7 +58,7 @@ class AsyncoreConnectionTest(unittest.TestCase):
return c return c
def make_header_prefix(self, message_class, version=2, stream_id=0): def make_header_prefix(self, message_class, version=2, stream_id=0):
return ''.join(map(uint8_pack, [ return six.binary_type().join(map(uint8_pack, [
0xff & (HEADER_DIRECTION_TO_CLIENT | version), 0xff & (HEADER_DIRECTION_TO_CLIENT | version),
0, # flags (compression) 0, # flags (compression)
stream_id, stream_id,
@@ -63,7 +66,7 @@ class AsyncoreConnectionTest(unittest.TestCase):
])) ]))
def make_options_body(self): def make_options_body(self):
options_buf = StringIO() options_buf = BytesIO()
write_stringmultimap(options_buf, { write_stringmultimap(options_buf, {
'CQL_VERSION': ['3.0.1'], 'CQL_VERSION': ['3.0.1'],
'COMPRESSION': [] 'COMPRESSION': []
@@ -71,12 +74,12 @@ class AsyncoreConnectionTest(unittest.TestCase):
return options_buf.getvalue() return options_buf.getvalue()
def make_error_body(self, code, msg): def make_error_body(self, code, msg):
buf = StringIO() buf = BytesIO()
write_int(buf, code) write_int(buf, code)
write_string(buf, msg) write_string(buf, msg)
return buf.getvalue() return buf.getvalue()
def make_msg(self, header, body=""): def make_msg(self, header, body=six.binary_type()):
return header + uint32_pack(len(body)) + body return header + uint32_pack(len(body)) + body
def test_successful_connection(self, *args): def test_successful_connection(self, *args):
@@ -105,12 +108,12 @@ class AsyncoreConnectionTest(unittest.TestCase):
# get a connection that's already fully started # get a connection that's already fully started
c = self.test_successful_connection() c = self.test_successful_connection()
header = '\x00\x00\x00\x00' + int32_pack(20000) header = six.b('\x00\x00\x00\x00') + int32_pack(20000)
responses = [ responses = [
header + ('a' * (4096 - len(header))), header + (six.b('a') * (4096 - len(header))),
'a' * 4096, six.b('a') * 4096,
socket_error(errno.EAGAIN), socket_error(errno.EAGAIN),
'a' * 100, six.b('a') * 100,
socket_error(errno.EAGAIN)] socket_error(errno.EAGAIN)]
def side_effect(*args): def side_effect(*args):
@@ -122,17 +125,17 @@ class AsyncoreConnectionTest(unittest.TestCase):
c.socket.recv.side_effect = side_effect c.socket.recv.side_effect = side_effect
c.handle_read() c.handle_read()
self.assertEquals(c._total_reqd_bytes, 20000 + len(header)) self.assertEqual(c._total_reqd_bytes, 20000 + len(header))
# the EAGAIN prevents it from reading the last 100 bytes # the EAGAIN prevents it from reading the last 100 bytes
c._iobuf.seek(0, os.SEEK_END) c._iobuf.seek(0, os.SEEK_END)
pos = c._iobuf.tell() pos = c._iobuf.tell()
self.assertEquals(pos, 4096 + 4096) self.assertEqual(pos, 4096 + 4096)
# now tell it to read the last 100 bytes # now tell it to read the last 100 bytes
c.handle_read() c.handle_read()
c._iobuf.seek(0, os.SEEK_END) c._iobuf.seek(0, os.SEEK_END)
pos = c._iobuf.tell() pos = c._iobuf.tell()
self.assertEquals(pos, 4096 + 4096 + 100) self.assertEqual(pos, 4096 + 4096 + 100)
def test_protocol_error(self, *args): def test_protocol_error(self, *args):
c = self.make_connection() c = self.make_connection()
@@ -237,14 +240,13 @@ class AsyncoreConnectionTest(unittest.TestCase):
options = self.make_options_body() options = self.make_options_body()
message = self.make_msg(header, options) message = self.make_msg(header, options)
# read in the first byte c.socket.recv.return_value = message[0:1]
c.socket.recv.return_value = message[0]
c.handle_read() c.handle_read()
self.assertEquals(c._iobuf.getvalue(), message[0]) self.assertEqual(c._iobuf.getvalue(), message[0:1])
c.socket.recv.return_value = message[1:] c.socket.recv.return_value = message[1:]
c.handle_read() c.handle_read()
self.assertEquals("", c._iobuf.getvalue()) self.assertEqual(six.binary_type(), c._iobuf.getvalue())
# let it write out a StartupMessage # let it write out a StartupMessage
c.handle_write() c.handle_write()
@@ -266,12 +268,12 @@ class AsyncoreConnectionTest(unittest.TestCase):
# read in the first nine bytes # read in the first nine bytes
c.socket.recv.return_value = message[:9] c.socket.recv.return_value = message[:9]
c.handle_read() c.handle_read()
self.assertEquals(c._iobuf.getvalue(), message[:9]) self.assertEqual(c._iobuf.getvalue(), message[:9])
# ... then read in the rest # ... then read in the rest
c.socket.recv.return_value = message[9:] c.socket.recv.return_value = message[9:]
c.handle_read() c.handle_read()
self.assertEquals("", c._iobuf.getvalue()) self.assertEqual(six.binary_type(), c._iobuf.getvalue())
# let it write out a StartupMessage # let it write out a StartupMessage
c.handle_write() c.handle_write()

View File

@@ -19,7 +19,9 @@ except ImportError:
import errno import errno
import os import os
from StringIO import StringIO
from six.moves import StringIO
from socket import error as socket_error from socket import error as socket_error
from mock import patch, Mock from mock import patch, Mock
@@ -122,17 +124,17 @@ class LibevConnectionTest(unittest.TestCase):
c._socket.recv.side_effect = side_effect c._socket.recv.side_effect = side_effect
c.handle_read(None, 0) c.handle_read(None, 0)
self.assertEquals(c._total_reqd_bytes, 20000 + len(header)) self.assertEqual(c._total_reqd_bytes, 20000 + len(header))
# the EAGAIN prevents it from reading the last 100 bytes # the EAGAIN prevents it from reading the last 100 bytes
c._iobuf.seek(0, os.SEEK_END) c._iobuf.seek(0, os.SEEK_END)
pos = c._iobuf.tell() pos = c._iobuf.tell()
self.assertEquals(pos, 4096 + 4096) self.assertEqual(pos, 4096 + 4096)
# now tell it to read the last 100 bytes # now tell it to read the last 100 bytes
c.handle_read(None, 0) c.handle_read(None, 0)
c._iobuf.seek(0, os.SEEK_END) c._iobuf.seek(0, os.SEEK_END)
pos = c._iobuf.tell() pos = c._iobuf.tell()
self.assertEquals(pos, 4096 + 4096 + 100) self.assertEqual(pos, 4096 + 4096 + 100)
def test_protocol_error(self, *args): def test_protocol_error(self, *args):
c = self.make_connection() c = self.make_connection()
@@ -240,11 +242,11 @@ class LibevConnectionTest(unittest.TestCase):
# read in the first byte # read in the first byte
c._socket.recv.return_value = message[0] c._socket.recv.return_value = message[0]
c.handle_read(None, 0) c.handle_read(None, 0)
self.assertEquals(c._iobuf.getvalue(), message[0]) self.assertEqual(c._iobuf.getvalue(), message[0])
c._socket.recv.return_value = message[1:] c._socket.recv.return_value = message[1:]
c.handle_read(None, 0) c.handle_read(None, 0)
self.assertEquals("", c._iobuf.getvalue()) self.assertEqual("", c._iobuf.getvalue())
# let it write out a StartupMessage # let it write out a StartupMessage
c.handle_write(None, 0) c.handle_write(None, 0)
@@ -266,12 +268,12 @@ class LibevConnectionTest(unittest.TestCase):
# read in the first nine bytes # read in the first nine bytes
c._socket.recv.return_value = message[:9] c._socket.recv.return_value = message[:9]
c.handle_read(None, 0) c.handle_read(None, 0)
self.assertEquals(c._iobuf.getvalue(), message[:9]) self.assertEqual(c._iobuf.getvalue(), message[:9])
# ... then read in the rest # ... then read in the rest
c._socket.recv.return_value = message[9:] c._socket.recv.return_value = message[9:]
c.handle_read(None, 0) c.handle_read(None, 0)
self.assertEquals("", c._iobuf.getvalue()) self.assertEqual("", c._iobuf.getvalue())
# let it write out a StartupMessage # let it write out a StartupMessage
c.handle_write(None, 0) c.handle_write(None, 0)

View File

@@ -11,17 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from cassandra.cluster import Cluster import six
try: try:
import unittest2 as unittest import unittest2 as unittest
except ImportError: except ImportError:
import unittest # noqa import unittest # noqa
from StringIO import StringIO from six import BytesIO
from mock import Mock, ANY from mock import Mock, ANY
from cassandra.cluster import Cluster
from cassandra.connection import (Connection, HEADER_DIRECTION_TO_CLIENT, from cassandra.connection import (Connection, HEADER_DIRECTION_TO_CLIENT,
HEADER_DIRECTION_FROM_CLIENT, ProtocolError) HEADER_DIRECTION_FROM_CLIENT, ProtocolError)
from cassandra.decoder import (write_stringmultimap, write_int, write_string, from cassandra.decoder import (write_stringmultimap, write_int, write_string,
@@ -40,7 +41,7 @@ class ConnectionTest(unittest.TestCase):
return c return c
def make_header_prefix(self, message_class, version=2, stream_id=0): def make_header_prefix(self, message_class, version=2, stream_id=0):
return ''.join(map(uint8_pack, [ return six.binary_type().join(map(uint8_pack, [
0xff & (HEADER_DIRECTION_TO_CLIENT | version), 0xff & (HEADER_DIRECTION_TO_CLIENT | version),
0, # flags (compression) 0, # flags (compression)
stream_id, stream_id,
@@ -48,7 +49,7 @@ class ConnectionTest(unittest.TestCase):
])) ]))
def make_options_body(self): def make_options_body(self):
options_buf = StringIO() options_buf = BytesIO()
write_stringmultimap(options_buf, { write_stringmultimap(options_buf, {
'CQL_VERSION': ['3.0.1'], 'CQL_VERSION': ['3.0.1'],
'COMPRESSION': [] 'COMPRESSION': []
@@ -56,7 +57,7 @@ class ConnectionTest(unittest.TestCase):
return options_buf.getvalue() return options_buf.getvalue()
def make_error_body(self, code, msg): def make_error_body(self, code, msg):
buf = StringIO() buf = BytesIO()
write_int(buf, code) write_int(buf, code)
write_string(buf, msg) write_string(buf, msg)
return buf.getvalue() return buf.getvalue()
@@ -88,12 +89,12 @@ class ConnectionTest(unittest.TestCase):
c.defunct = Mock() c.defunct = Mock()
# read in a SupportedMessage response # read in a SupportedMessage response
header = ''.join(map(uint8_pack, [ header = six.binary_type().join(uint8_pack(i) for i in (
0xff & (HEADER_DIRECTION_FROM_CLIENT | self.protocol_version), 0xff & (HEADER_DIRECTION_FROM_CLIENT | self.protocol_version),
0, # flags (compression) 0, # flags (compression)
0, 0,
SupportedMessage.opcode # opcode SupportedMessage.opcode # opcode
])) ))
options = self.make_options_body() options = self.make_options_body()
message = self.make_msg(header, options) message = self.make_msg(header, options)
c.process_msg(message, len(message) - 8) c.process_msg(message, len(message) - 8)
@@ -130,7 +131,7 @@ class ConnectionTest(unittest.TestCase):
# read in a SupportedMessage response # read in a SupportedMessage response
header = self.make_header_prefix(SupportedMessage) header = self.make_header_prefix(SupportedMessage)
options_buf = StringIO() options_buf = BytesIO()
write_stringmultimap(options_buf, { write_stringmultimap(options_buf, {
'CQL_VERSION': ['7.8.9'], 'CQL_VERSION': ['7.8.9'],
'COMPRESSION': [] 'COMPRESSION': []
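Editor's note: six.binary_type is bytes on Python 3 and str on Python 2, so six.binary_type().join(...) above is just a version-agnostic spelling of b"".join(...):

import six

parts = [six.int2byte(i) for i in (0x82, 0x00, 0x00, 0x06)]
assert six.binary_type().join(parts) == b"\x82\x00\x00\x06"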

View File

@@ -16,7 +16,7 @@ from cassandra.marshal import bitlength
try: try:
import unittest2 as unittest import unittest2 as unittest
except ImportError: except ImportError:
import unittest # noqa import unittest # noqa
import platform import platform
from datetime import datetime from datetime import datetime
@@ -33,57 +33,57 @@ from cassandra.util import OrderedDict
marshalled_value_pairs = ( marshalled_value_pairs = (
# binary form, type, python native type # binary form, type, python native type
('lorem ipsum dolor sit amet', 'AsciiType', 'lorem ipsum dolor sit amet'), (b'lorem ipsum dolor sit amet', 'AsciiType', 'lorem ipsum dolor sit amet'),
('', 'AsciiType', ''), (b'', 'AsciiType', ''),
('\x01', 'BooleanType', True), (b'\x01', 'BooleanType', True),
('\x00', 'BooleanType', False), (b'\x00', 'BooleanType', False),
('', 'BooleanType', None), (b'', 'BooleanType', None),
('\xff\xfe\xfd\xfc\xfb', 'BytesType', '\xff\xfe\xfd\xfc\xfb'), (b'\xff\xfe\xfd\xfc\xfb', 'BytesType', b'\xff\xfe\xfd\xfc\xfb'),
('', 'BytesType', ''), (b'', 'BytesType', b''),
('\x7f\xff\xff\xff\xff\xff\xff\xff', 'CounterColumnType', 9223372036854775807), (b'\x7f\xff\xff\xff\xff\xff\xff\xff', 'CounterColumnType', 9223372036854775807),
('\x80\x00\x00\x00\x00\x00\x00\x00', 'CounterColumnType', -9223372036854775808), (b'\x80\x00\x00\x00\x00\x00\x00\x00', 'CounterColumnType', -9223372036854775808),
('', 'CounterColumnType', None), (b'', 'CounterColumnType', None),
('\x00\x00\x013\x7fb\xeey', 'DateType', datetime(2011, 11, 7, 18, 55, 49, 881000)), (b'\x00\x00\x013\x7fb\xeey', 'DateType', datetime(2011, 11, 7, 18, 55, 49, 881000)),
('', 'DateType', None), (b'', 'DateType', None),
('\x00\x00\x00\r\nJ\x04"^\x91\x04\x8a\xb1\x18\xfe', 'DecimalType', Decimal('1243878957943.1234124191998')), (b'\x00\x00\x00\r\nJ\x04"^\x91\x04\x8a\xb1\x18\xfe', 'DecimalType', Decimal('1243878957943.1234124191998')),
('\x00\x00\x00\x06\xe5\xde]\x98Y', 'DecimalType', Decimal('-112233.441191')), (b'\x00\x00\x00\x06\xe5\xde]\x98Y', 'DecimalType', Decimal('-112233.441191')),
('\x00\x00\x00\x14\x00\xfa\xce', 'DecimalType', Decimal('0.00000000000000064206')), (b'\x00\x00\x00\x14\x00\xfa\xce', 'DecimalType', Decimal('0.00000000000000064206')),
('\x00\x00\x00\x14\xff\x052', 'DecimalType', Decimal('-0.00000000000000064206')), (b'\x00\x00\x00\x14\xff\x052', 'DecimalType', Decimal('-0.00000000000000064206')),
('\xff\xff\xff\x9c\x00\xfa\xce', 'DecimalType', Decimal('64206e100')), (b'\xff\xff\xff\x9c\x00\xfa\xce', 'DecimalType', Decimal('64206e100')),
('', 'DecimalType', None), (b'', 'DecimalType', None),
('@\xd2\xfa\x08\x00\x00\x00\x00', 'DoubleType', 19432.125), (b'@\xd2\xfa\x08\x00\x00\x00\x00', 'DoubleType', 19432.125),
('\xc0\xd2\xfa\x08\x00\x00\x00\x00', 'DoubleType', -19432.125), (b'\xc0\xd2\xfa\x08\x00\x00\x00\x00', 'DoubleType', -19432.125),
('\x7f\xef\x00\x00\x00\x00\x00\x00', 'DoubleType', 1.7415152243978685e+308), (b'\x7f\xef\x00\x00\x00\x00\x00\x00', 'DoubleType', 1.7415152243978685e+308),
('', 'DoubleType', None), (b'', 'DoubleType', None),
('F\x97\xd0@', 'FloatType', 19432.125), (b'F\x97\xd0@', 'FloatType', 19432.125),
('\xc6\x97\xd0@', 'FloatType', -19432.125), (b'\xc6\x97\xd0@', 'FloatType', -19432.125),
('\xc6\x97\xd0@', 'FloatType', -19432.125), (b'\xc6\x97\xd0@', 'FloatType', -19432.125),
('\x7f\x7f\x00\x00', 'FloatType', 338953138925153547590470800371487866880.0), (b'\x7f\x7f\x00\x00', 'FloatType', 338953138925153547590470800371487866880.0),
('', 'FloatType', None), (b'', 'FloatType', None),
('\x7f\x50\x00\x00', 'Int32Type', 2135949312), (b'\x7f\x50\x00\x00', 'Int32Type', 2135949312),
('\xff\xfd\xcb\x91', 'Int32Type', -144495), (b'\xff\xfd\xcb\x91', 'Int32Type', -144495),
('', 'Int32Type', None), (b'', 'Int32Type', None),
('f\x1e\xfd\xf2\xe3\xb1\x9f|\x04_\x15', 'IntegerType', 123456789123456789123456789), (b'f\x1e\xfd\xf2\xe3\xb1\x9f|\x04_\x15', 'IntegerType', 123456789123456789123456789),
('', 'IntegerType', None), (b'', 'IntegerType', None),
('\x7f\xff\xff\xff\xff\xff\xff\xff', 'LongType', 9223372036854775807), (b'\x7f\xff\xff\xff\xff\xff\xff\xff', 'LongType', 9223372036854775807),
('\x80\x00\x00\x00\x00\x00\x00\x00', 'LongType', -9223372036854775808), (b'\x80\x00\x00\x00\x00\x00\x00\x00', 'LongType', -9223372036854775808),
('', 'LongType', None), (b'', 'LongType', None),
('', 'InetAddressType', None), (b'', 'InetAddressType', None),
('A46\xa9', 'InetAddressType', '65.52.54.169'), (b'A46\xa9', 'InetAddressType', '65.52.54.169'),
('*\x00\x13(\xe1\x02\xcc\xc0\x00\x00\x00\x00\x00\x00\x01"', 'InetAddressType', '2a00:1328:e102:ccc0::122'), (b'*\x00\x13(\xe1\x02\xcc\xc0\x00\x00\x00\x00\x00\x00\x01"', 'InetAddressType', '2a00:1328:e102:ccc0::122'),
('\xe3\x81\xbe\xe3\x81\x97\xe3\x81\xa6', 'UTF8Type', u'\u307e\u3057\u3066'), (b'\xe3\x81\xbe\xe3\x81\x97\xe3\x81\xa6', 'UTF8Type', u'\u307e\u3057\u3066'),
('\xe3\x81\xbe\xe3\x81\x97\xe3\x81\xa6' * 1000, 'UTF8Type', u'\u307e\u3057\u3066' * 1000), (b'\xe3\x81\xbe\xe3\x81\x97\xe3\x81\xa6' * 1000, 'UTF8Type', u'\u307e\u3057\u3066' * 1000),
('', 'UTF8Type', u''), (b'', 'UTF8Type', u''),
('\xff' * 16, 'UUIDType', UUID('ffffffff-ffff-ffff-ffff-ffffffffffff')), (b'\xff' * 16, 'UUIDType', UUID('ffffffff-ffff-ffff-ffff-ffffffffffff')),
('I\x15~\xfc\xef<\x9d\xe3\x16\x98\xaf\x80\x1f\xb4\x0b*', 'UUIDType', UUID('49157efc-ef3c-9de3-1698-af801fb40b2a')), (b'I\x15~\xfc\xef<\x9d\xe3\x16\x98\xaf\x80\x1f\xb4\x0b*', 'UUIDType', UUID('49157efc-ef3c-9de3-1698-af801fb40b2a')),
('', 'UUIDType', None), (b'', 'UUIDType', None),
('', 'MapType(AsciiType, BooleanType)', None), (b'', 'MapType(AsciiType, BooleanType)', None),
('', 'ListType(FloatType)', None), (b'', 'ListType(FloatType)', None),
('', 'SetType(LongType)', None), (b'', 'SetType(LongType)', None),
('\x00\x00', 'MapType(DecimalType, BooleanType)', OrderedDict()), (b'\x00\x00', 'MapType(DecimalType, BooleanType)', OrderedDict()),
('\x00\x00', 'ListType(FloatType)', ()), (b'\x00\x00', 'ListType(FloatType)', ()),
('\x00\x00', 'SetType(IntegerType)', sortedset()), (b'\x00\x00', 'SetType(IntegerType)', sortedset()),
('\x00\x01\x00\x10\xafYC\xa3\xea<\x11\xe1\xabc\xc4,\x03"y\xf0', 'ListType(TimeUUIDType)', (UUID(bytes='\xafYC\xa3\xea<\x11\xe1\xabc\xc4,\x03"y\xf0'),)), (b'\x00\x01\x00\x10\xafYC\xa3\xea<\x11\xe1\xabc\xc4,\x03"y\xf0', 'ListType(TimeUUIDType)', (UUID(bytes=b'\xafYC\xa3\xea<\x11\xe1\xabc\xc4,\x03"y\xf0'),)),
) )
ordered_dict_value = OrderedDict() ordered_dict_value = OrderedDict()
@@ -94,9 +94,9 @@ ordered_dict_value[u'\\'] = 0
# these following entries work for me right now, but they're dependent on # these following entries work for me right now, but they're dependent on
# vagaries of internal python ordering for unordered types # vagaries of internal python ordering for unordered types
marshalled_value_pairs_unsafe = ( marshalled_value_pairs_unsafe = (
('\x00\x03\x00\x06\xe3\x81\xbfbob\x00\x04\x00\x00\x00\xc7\x00\x00\x00\x04\xff\xff\xff\xff\x00\x01\\\x00\x04\x00\x00\x00\x00', 'MapType(UTF8Type, Int32Type)', ordered_dict_value), (b'\x00\x03\x00\x06\xe3\x81\xbfbob\x00\x04\x00\x00\x00\xc7\x00\x00\x00\x04\xff\xff\xff\xff\x00\x01\\\x00\x04\x00\x00\x00\x00', 'MapType(UTF8Type, Int32Type)', ordered_dict_value),
('\x00\x02\x00\x08@\x01\x99\x99\x99\x99\x99\x9a\x00\x08@\x14\x00\x00\x00\x00\x00\x00', 'SetType(DoubleType)', sortedset([2.2, 5.0])), (b'\x00\x02\x00\x08@\x01\x99\x99\x99\x99\x99\x9a\x00\x08@\x14\x00\x00\x00\x00\x00\x00', 'SetType(DoubleType)', sortedset([2.2, 5.0])),
('\x00', 'IntegerType', 0), (b'\x00', 'IntegerType', 0),
) )
if platform.python_implementation() == 'CPython': if platform.python_implementation() == 'CPython':
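Editor's note: with the fixture table rewritten to bytes literals, the fixed-width entries can be spot-checked against struct's big-endian packing. A few independent checks (not taken from the patch):

import struct

assert struct.pack(">q", 9223372036854775807) == b"\x7f\xff\xff\xff\xff\xff\xff\xff"  # LongType
assert struct.pack(">i", -144495) == b"\xff\xfd\xcb\x91"                              # Int32Type
assert struct.unpack(">f", b"F\x97\xd0@")[0] == 19432.125                             # FloatType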

View File

@@ -29,6 +29,12 @@ from cassandra.pool import Host
class TestStrategies(unittest.TestCase): class TestStrategies(unittest.TestCase):
@classmethod
def setUpClass(cls):
"Hook method for setting up class fixture before running tests in the class."
if not hasattr(cls, 'assertItemsEqual'):
cls.assertItemsEqual = cls.assertCountEqual
def test_replication_strategy(self): def test_replication_strategy(self):
""" """
Basic code coverage testing that ensures different ReplicationStrategies Basic code coverage testing that ensures different ReplicationStrategies
@@ -217,22 +223,22 @@ class TestTokens(unittest.TestCase):
murmur3_token = Murmur3Token(cassandra.metadata.MIN_LONG - 1) murmur3_token = Murmur3Token(cassandra.metadata.MIN_LONG - 1)
self.assertEqual(murmur3_token.hash_fn('123'), -7468325962851647638) self.assertEqual(murmur3_token.hash_fn('123'), -7468325962851647638)
self.assertEqual(murmur3_token.hash_fn(str(cassandra.metadata.MAX_LONG)), 7162290910810015547) self.assertEqual(murmur3_token.hash_fn(str(cassandra.metadata.MAX_LONG)), 7162290910810015547)
self.assertEqual(str(murmur3_token), '<Murmur3Token: -9223372036854775809L>') self.assertEqual(str(murmur3_token), '<Murmur3Token: -9223372036854775809>')
except NoMurmur3: except NoMurmur3:
raise unittest.SkipTest('The murmur3 extension is not available') raise unittest.SkipTest('The murmur3 extension is not available')
def test_md5_tokens(self): def test_md5_tokens(self):
md5_token = MD5Token(cassandra.metadata.MIN_LONG - 1) md5_token = MD5Token(cassandra.metadata.MIN_LONG - 1)
self.assertEqual(md5_token.hash_fn('123'), 42767516990368493138776584305024125808L) self.assertEqual(md5_token.hash_fn('123'), 42767516990368493138776584305024125808)
self.assertEqual(md5_token.hash_fn(str(cassandra.metadata.MAX_LONG)), 28528976619278518853815276204542453639L) self.assertEqual(md5_token.hash_fn(str(cassandra.metadata.MAX_LONG)), 28528976619278518853815276204542453639)
self.assertEqual(str(md5_token), '<MD5Token: -9223372036854775809L>') self.assertEqual(str(md5_token), '<MD5Token: %s>' % -9223372036854775809)
def test_bytes_tokens(self): def test_bytes_tokens(self):
bytes_token = BytesToken(str(cassandra.metadata.MIN_LONG - 1)) bytes_token = BytesToken(str(cassandra.metadata.MIN_LONG - 1))
self.assertEqual(bytes_token.hash_fn('123'), '123') self.assertEqual(bytes_token.hash_fn('123'), '123')
self.assertEqual(bytes_token.hash_fn(123), 123) self.assertEqual(bytes_token.hash_fn(123), 123)
self.assertEqual(bytes_token.hash_fn(str(cassandra.metadata.MAX_LONG)), str(cassandra.metadata.MAX_LONG)) self.assertEqual(bytes_token.hash_fn(str(cassandra.metadata.MAX_LONG)), str(cassandra.metadata.MAX_LONG))
self.assertEqual(str(bytes_token), "<BytesToken: '-9223372036854775809'>") self.assertEqual(str(bytes_token), "<BytesToken: -9223372036854775809>")
try: try:
bytes_token = BytesToken(cassandra.metadata.MIN_LONG - 1) bytes_token = BytesToken(cassandra.metadata.MIN_LONG - 1)
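The trailing L disappears from the expected strings because Python 3 has a single int type, so large values no longer carry a long suffix in their repr. A rough illustration, not driver code: formatting the value with %s (i.e. str()) instead of repr() yields the same token string on both interpreters.

MIN_LONG = -(2 ** 63)  # assumed to mirror cassandra.metadata.MIN_LONG

def describe_token(value):
    # str()-style formatting never appends the 'L' that repr() adds to
    # Python 2 longs, so the output is stable across versions.
    return '<MD5Token: %s>' % value

print(describe_token(MIN_LONG - 1))   # <MD5Token: -9223372036854775809>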

View File

@@ -22,32 +22,34 @@ from cassandra.query import PreparedStatement, BoundStatement
from cassandra.cqltypes import Int32Type from cassandra.cqltypes import Int32Type
from cassandra.util import OrderedDict from cassandra.util import OrderedDict
from six.moves import xrange
class ParamBindingTest(unittest.TestCase): class ParamBindingTest(unittest.TestCase):
def test_bind_sequence(self): def test_bind_sequence(self):
result = bind_params("%s %s %s", (1, "a", 2.0)) result = bind_params("%s %s %s", (1, "a", 2.0))
self.assertEquals(result, "1 'a' 2.0") self.assertEqual(result, "1 'a' 2.0")
def test_bind_map(self): def test_bind_map(self):
result = bind_params("%(a)s %(b)s %(c)s", dict(a=1, b="a", c=2.0)) result = bind_params("%(a)s %(b)s %(c)s", dict(a=1, b="a", c=2.0))
self.assertEquals(result, "1 'a' 2.0") self.assertEqual(result, "1 'a' 2.0")
def test_sequence_param(self): def test_sequence_param(self):
result = bind_params("%s", (ValueSequence((1, "a", 2.0)),)) result = bind_params("%s", (ValueSequence((1, "a", 2.0)),))
self.assertEquals(result, "( 1 , 'a' , 2.0 )") self.assertEqual(result, "( 1 , 'a' , 2.0 )")
def test_generator_param(self): def test_generator_param(self):
result = bind_params("%s", ((i for i in xrange(3)),)) result = bind_params("%s", ((i for i in xrange(3)),))
self.assertEquals(result, "[ 0 , 1 , 2 ]") self.assertEqual(result, "[ 0 , 1 , 2 ]")
def test_none_param(self): def test_none_param(self):
result = bind_params("%s", (None,)) result = bind_params("%s", (None,))
self.assertEquals(result, "NULL") self.assertEqual(result, "NULL")
def test_list_collection(self): def test_list_collection(self):
result = bind_params("%s", (['a', 'b', 'c'],)) result = bind_params("%s", (['a', 'b', 'c'],))
self.assertEquals(result, "[ 'a' , 'b' , 'c' ]") self.assertEqual(result, "[ 'a' , 'b' , 'c' ]")
def test_set_collection(self): def test_set_collection(self):
result = bind_params("%s", (set(['a', 'b']),)) result = bind_params("%s", (set(['a', 'b']),))
@@ -59,11 +61,11 @@ class ParamBindingTest(unittest.TestCase):
vals['b'] = 'b' vals['b'] = 'b'
vals['c'] = 'c' vals['c'] = 'c'
result = bind_params("%s", (vals,)) result = bind_params("%s", (vals,))
self.assertEquals(result, "{ 'a' : 'a' , 'b' : 'b' , 'c' : 'c' }") self.assertEqual(result, "{ 'a' : 'a' , 'b' : 'b' , 'c' : 'c' }")
def test_quote_escaping(self): def test_quote_escaping(self):
result = bind_params("%s", ("""'ef''ef"ef""ef'""",)) result = bind_params("%s", ("""'ef''ef"ef""ef'""",))
self.assertEquals(result, """'''ef''''ef"ef""ef'''""") self.assertEqual(result, """'''ef''''ef"ef""ef'''""")
class BoundStatementTestCase(unittest.TestCase): class BoundStatementTestCase(unittest.TestCase):
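The new from six.moves import xrange line lets the generator-parameter test run unchanged on Python 3, where the builtin was renamed to range. A minimal sketch of the pattern, assuming only that the six package is installed:

from six.moves import xrange

# Resolves to the builtin xrange on Python 2 and to range on Python 3,
# so a lazy integer sequence is spelled the same way everywhere.
lazy_squares = (i * i for i in xrange(3))
print(list(lazy_squares))   # [0, 1, 4]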

View File

@@ -20,6 +20,7 @@ except ImportError:
from itertools import islice, cycle from itertools import islice, cycle
from mock import Mock from mock import Mock
from random import randint from random import randint
import six
import sys import sys
import struct import struct
from threading import Thread from threading import Thread
@@ -36,6 +37,8 @@ from cassandra.policies import (RoundRobinPolicy, DCAwareRoundRobinPolicy,
from cassandra.pool import Host from cassandra.pool import Host
from cassandra.query import Statement from cassandra.query import Statement
from six.moves import xrange
class TestLoadBalancingPolicy(unittest.TestCase): class TestLoadBalancingPolicy(unittest.TestCase):
def test_non_implemented(self): def test_non_implemented(self):
@@ -137,13 +140,23 @@ class TestRoundRobinPolicy(unittest.TestCase):
# make the GIL switch after every instruction, maximizing # make the GIL switch after every instruction, maximizing
# the chance of race conditions # the chance of race conditions
original_interval = sys.getcheckinterval() if six.PY2:
original_interval = sys.getcheckinterval()
else:
original_interval = sys.getswitchinterval()
try: try:
sys.setcheckinterval(0) if six.PY2:
sys.setcheckinterval(0)
else:
sys.setswitchinterval(0.0001)
map(lambda t: t.start(), threads) map(lambda t: t.start(), threads)
map(lambda t: t.join(), threads) map(lambda t: t.join(), threads)
finally: finally:
sys.setcheckinterval(original_interval) if six.PY2:
sys.setcheckinterval(original_interval)
else:
sys.setswitchinterval(original_interval)
if errors: if errors:
self.fail("Saw errors: %s" % (errors,)) self.fail("Saw errors: %s" % (errors,))
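Python 3 dropped sys.getcheckinterval/setcheckinterval (which counted bytecode instructions) in favour of sys.getswitchinterval/setswitchinterval (which take seconds), hence the six.PY2 branches above. A sketch of the same idea wrapped in a hypothetical context manager, not part of the driver:

import sys
import six
from contextlib import contextmanager
from threading import Thread

@contextmanager
def aggressive_thread_switching(interval=0.0001):
    # Force very frequent GIL switches to shake out race conditions.
    # Python 2 counts bytecode instructions, Python 3 measures seconds.
    if six.PY2:
        original = sys.getcheckinterval()
        sys.setcheckinterval(0)
    else:
        original = sys.getswitchinterval()
        sys.setswitchinterval(interval)
    try:
        yield
    finally:
        if six.PY2:
            sys.setcheckinterval(original)
        else:
            sys.setswitchinterval(original)

threads = [Thread(target=lambda: None) for _ in range(4)]
with aggressive_thread_switching():
    for t in threads:   # explicit loops: map() is lazy on Python 3
        t.start()
    for t in threads:
        t.join()

An explicit for loop is used to start and join the threads because map() returns a lazy iterator on Python 3 and would not invoke the lambdas unless consumed.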
@@ -334,14 +347,14 @@ class TokenAwarePolicyTest(unittest.TestCase):
replicas = get_replicas(None, struct.pack('>i', i)) replicas = get_replicas(None, struct.pack('>i', i))
other = set(h for h in hosts if h not in replicas) other = set(h for h in hosts if h not in replicas)
self.assertEquals(replicas, qplan[:2]) self.assertEqual(replicas, qplan[:2])
self.assertEquals(other, set(qplan[2:])) self.assertEqual(other, set(qplan[2:]))
# Should use the secondary policy # Should use the secondary policy
for i in range(4): for i in range(4):
qplan = list(policy.make_query_plan()) qplan = list(policy.make_query_plan())
self.assertEquals(set(qplan), set(hosts)) self.assertEqual(set(qplan), set(hosts))
def test_wrap_dc_aware(self): def test_wrap_dc_aware(self):
cluster = Mock(spec=Cluster) cluster = Mock(spec=Cluster)
@@ -374,16 +387,16 @@ class TokenAwarePolicyTest(unittest.TestCase):
# first should be the only local replica # first should be the only local replica
self.assertIn(qplan[0], replicas) self.assertIn(qplan[0], replicas)
self.assertEquals(qplan[0].datacenter, "dc1") self.assertEqual(qplan[0].datacenter, "dc1")
# then the local non-replica # then the local non-replica
self.assertNotIn(qplan[1], replicas) self.assertNotIn(qplan[1], replicas)
self.assertEquals(qplan[1].datacenter, "dc1") self.assertEqual(qplan[1].datacenter, "dc1")
# then one of the remotes (used_hosts_per_remote_dc is 1, so we # then one of the remotes (used_hosts_per_remote_dc is 1, so we
# shouldn't see two remotes) # shouldn't see two remotes)
self.assertEquals(qplan[2].datacenter, "dc2") self.assertEqual(qplan[2].datacenter, "dc2")
self.assertEquals(3, len(qplan)) self.assertEqual(3, len(qplan))
class FakeCluster: class FakeCluster:
def __init__(self): def __init__(self):

View File

@@ -346,7 +346,7 @@ class ResponseFutureTests(unittest.TestCase):
rf.send_request() rf.send_request()
rf.add_callbacks( rf.add_callbacks(
callback=self.assertEquals, callback_args=([{'col': 'val'}],), callback=self.assertEqual, callback_args=([{'col': 'val'}],),
errback=self.assertIsInstance, errback_args=(Exception,)) errback=self.assertIsInstance, errback_args=(Exception,))
result = Mock(spec=UnavailableErrorMessage, info={}) result = Mock(spec=UnavailableErrorMessage, info={})
@@ -358,7 +358,7 @@ class ResponseFutureTests(unittest.TestCase):
rf.send_request() rf.send_request()
rf.add_callbacks( rf.add_callbacks(
callback=self.assertEquals, callback_args=([{'col': 'val'}],), callback=self.assertEqual, callback_args=([{'col': 'val'}],),
errback=self.assertIsInstance, errback_args=(Exception,)) errback=self.assertIsInstance, errback_args=(Exception,))
rf._set_result(self.make_mock_response([{'col': 'val'}])) rf._set_result(self.make_mock_response([{'col': 'val'}]))
@@ -380,9 +380,9 @@ class ResponseFutureTests(unittest.TestCase):
session.submit.assert_called_once() session.submit.assert_called_once()
args, kwargs = session.submit.call_args args, kwargs = session.submit.call_args
self.assertEquals(rf._reprepare, args[-2]) self.assertEqual(rf._reprepare, args[-2])
self.assertIsInstance(args[-1], PrepareMessage) self.assertIsInstance(args[-1], PrepareMessage)
self.assertEquals(args[-1].query, "SELECT * FROM foobar") self.assertEqual(args[-1].query, "SELECT * FROM foobar")
def test_prepared_query_not_found_bad_keyspace(self): def test_prepared_query_not_found_bad_keyspace(self):
session = self.make_session() session = self.make_session()

View File

@@ -163,18 +163,17 @@ class TypeTests(unittest.TestCase):
'7a6970:org.apache.cassandra.db.marshal.UTF8Type', '7a6970:org.apache.cassandra.db.marshal.UTF8Type',
')'))) ')')))
self.assertEquals(FooType, ctype.__class__) self.assertEqual(FooType, ctype.__class__)
self.assertEquals(UTF8Type, ctype.subtypes[0]) self.assertEqual(UTF8Type, ctype.subtypes[0])
# middle subtype should be a BarType instance with its own subtypes and names # middle subtype should be a BarType instance with its own subtypes and names
self.assertIsInstance(ctype.subtypes[1], BarType) self.assertIsInstance(ctype.subtypes[1], BarType)
self.assertEquals([UTF8Type], ctype.subtypes[1].subtypes) self.assertEqual([UTF8Type], ctype.subtypes[1].subtypes)
self.assertEquals(["address"], ctype.subtypes[1].names) self.assertEqual([b"address"], ctype.subtypes[1].names)
self.assertEquals(UTF8Type, ctype.subtypes[2]) self.assertEqual(UTF8Type, ctype.subtypes[2])
self.assertEquals(['city', None, 'zip'], ctype.names) self.assertEqual([b'city', None, b'zip'], ctype.names)
def test_empty_value(self): def test_empty_value(self):
self.assertEqual(str(EmptyValue()), 'EMPTY') self.assertEqual(str(EmptyValue()), 'EMPTY')
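The expected names become b'' literals because the hex-encoded field names embedded in a parameterized type string decode to bytes under Python 3. An illustrative sketch, not the driver's actual parser, decoding the 7a6970 fragment seen above:

import binascii

encoded_name = '7a6970'             # hex-encoded field name from the type string
name = binascii.unhexlify(encoded_name)
print(name)                         # b'zip' on Python 3, 'zip' on Python 2
assert name == b'zip'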

tox.ini
View File

@@ -1,5 +1,5 @@
[tox] [tox]
envlist = py26,py27,pypy envlist = py26,py27,pypy,py33
[testenv] [testenv]
deps = nose deps = nose
@@ -8,5 +8,13 @@ deps = nose
unittest2 unittest2
pip pip
PyYAML PyYAML
six
commands = {envpython} setup.py build_ext --inplace commands = {envpython} setup.py build_ext --inplace
nosetests --verbosity=2 tests/unit/ nosetests --verbosity=2 tests/unit/
[testenv:py33]
deps = nose
mock
pip
PyYAML
six