Merge branch 'py3k' into 2.0
Conflicts: cassandra/cluster.py cassandra/encoder.py cassandra/marshal.py cassandra/pool.py setup.py tests/integration/long/test_large_data.py tests/integration/long/utils.py tests/integration/standard/test_metadata.py tests/integration/standard/test_prepared_statements.py tests/unit/io/test_asyncorereactor.py tests/unit/test_connection.py tests/unit/test_types.py
This commit is contained in:
@@ -4,11 +4,13 @@ env:
|
||||
- TOX_ENV=py26
|
||||
- TOX_ENV=py27
|
||||
- TOX_ENV=pypy
|
||||
- TOX_ENV=py33
|
||||
|
||||
before_install:
|
||||
- sudo apt-get update -y
|
||||
- sudo apt-get install -y build-essential python-dev
|
||||
- sudo apt-get install -y libev4 libev-dev
|
||||
|
||||
install:
|
||||
- pip install tox
|
||||
|
||||
|
||||
@@ -41,7 +41,7 @@ try:
|
||||
from cassandra.io.libevreactor import LibevConnection
|
||||
have_libev = True
|
||||
supported_reactors.append(LibevConnection)
|
||||
except ImportError, exc:
|
||||
except ImportError as exc:
|
||||
pass
|
||||
|
||||
KEYSPACE = "testkeyspace"
|
||||
@@ -104,7 +104,7 @@ def benchmark(thread_class):
|
||||
""".format(table=TABLE))
|
||||
values = ('key', 'a', 'b')
|
||||
|
||||
per_thread = options.num_ops / options.threads
|
||||
per_thread = options.num_ops // options.threads
|
||||
threads = []
|
||||
|
||||
log.debug("Beginning inserts...")
|
||||
|
||||
@@ -18,7 +18,7 @@ from itertools import count
|
||||
from threading import Event
|
||||
|
||||
from base import benchmark, BenchmarkThread
|
||||
|
||||
from six.moves import range
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@@ -38,17 +38,17 @@ class Runner(BenchmarkThread):
|
||||
if previous_result is not sentinel:
|
||||
if isinstance(previous_result, BaseException):
|
||||
log.error("Error on insert: %r", previous_result)
|
||||
if self.num_finished.next() >= self.num_queries:
|
||||
if next(self.num_finished) >= self.num_queries:
|
||||
self.event.set()
|
||||
|
||||
if self.num_started.next() <= self.num_queries:
|
||||
if next(self.num_started) <= self.num_queries:
|
||||
future = self.session.execute_async(self.query, self.values)
|
||||
future.add_callbacks(self.insert_next, self.insert_next)
|
||||
|
||||
def run(self):
|
||||
self.start_profile()
|
||||
|
||||
for _ in xrange(min(120, self.num_queries)):
|
||||
for _ in range(min(120, self.num_queries)):
|
||||
self.insert_next()
|
||||
|
||||
self.event.wait()
|
||||
|
||||
@@ -13,16 +13,16 @@
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import Queue
|
||||
|
||||
from base import benchmark, BenchmarkThread
|
||||
from six.moves import queue
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Runner(BenchmarkThread):
|
||||
|
||||
def run(self):
|
||||
futures = Queue.Queue(maxsize=121)
|
||||
futures = queue.Queue(maxsize=121)
|
||||
|
||||
self.start_profile()
|
||||
|
||||
@@ -32,7 +32,7 @@ class Runner(BenchmarkThread):
|
||||
while True:
|
||||
try:
|
||||
futures.get_nowait().result()
|
||||
except Queue.Empty:
|
||||
except queue.Empty:
|
||||
break
|
||||
|
||||
future = self.session.execute_async(self.query, self.values)
|
||||
@@ -41,7 +41,7 @@ class Runner(BenchmarkThread):
|
||||
while True:
|
||||
try:
|
||||
futures.get_nowait().result()
|
||||
except Queue.Empty:
|
||||
except queue.Empty:
|
||||
break
|
||||
|
||||
self.finish_profile()
|
||||
|
||||
@@ -13,16 +13,16 @@
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import Queue
|
||||
|
||||
from base import benchmark, BenchmarkThread
|
||||
from six.moves import queue
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Runner(BenchmarkThread):
|
||||
|
||||
def run(self):
|
||||
futures = Queue.Queue(maxsize=121)
|
||||
futures = queue.Queue(maxsize=121)
|
||||
|
||||
self.start_profile()
|
||||
|
||||
@@ -37,7 +37,7 @@ class Runner(BenchmarkThread):
|
||||
while True:
|
||||
try:
|
||||
futures.get_nowait().result()
|
||||
except Queue.Empty:
|
||||
except queue.Empty:
|
||||
break
|
||||
|
||||
self.finish_profile
|
||||
|
||||
@@ -25,7 +25,7 @@ class Runner(BenchmarkThread):
|
||||
|
||||
self.start_profile()
|
||||
|
||||
for i in range(self.num_queries):
|
||||
for _ in range(self.num_queries):
|
||||
future = self.session.execute_async(self.query, self.values)
|
||||
futures.append(future)
|
||||
|
||||
|
||||
@@ -13,13 +13,15 @@
|
||||
# limitations under the License.
|
||||
|
||||
from base import benchmark, BenchmarkThread
|
||||
from six.moves import range
|
||||
|
||||
|
||||
class Runner(BenchmarkThread):
|
||||
|
||||
def run(self):
|
||||
self.start_profile()
|
||||
|
||||
for i in xrange(self.num_queries):
|
||||
for _ in range(self.num_queries):
|
||||
self.session.execute(self.query, self.values)
|
||||
|
||||
self.finish_profile()
|
||||
|
||||
@@ -26,7 +26,11 @@ import socket
|
||||
import sys
|
||||
import time
|
||||
from threading import Lock, RLock, Thread, Event
|
||||
import Queue
|
||||
|
||||
import six
|
||||
from six.moves import range
|
||||
from six.moves import queue as Queue
|
||||
|
||||
import weakref
|
||||
from weakref import WeakValueDictionary
|
||||
try:
|
||||
@@ -696,7 +700,7 @@ class Cluster(object):
|
||||
|
||||
host.set_down()
|
||||
|
||||
log.warn("Host %s has been marked down", host)
|
||||
log.warning("Host %s has been marked down", host)
|
||||
|
||||
self.load_balancing_policy.on_down(host)
|
||||
self.control_connection.on_down(host)
|
||||
@@ -742,7 +746,7 @@ class Cluster(object):
|
||||
return
|
||||
|
||||
if not all(futures_results):
|
||||
log.warn("Connection pool could not be created, not marking node %s up", host)
|
||||
log.warning("Connection pool could not be created, not marking node %s up", host)
|
||||
return
|
||||
|
||||
self._finalize_add(host)
|
||||
@@ -867,7 +871,7 @@ class Cluster(object):
|
||||
# prepare 10 statements at a time
|
||||
ks_statements = list(ks_statements)
|
||||
chunks = []
|
||||
for i in xrange(0, len(ks_statements), 10):
|
||||
for i in range(0, len(ks_statements), 10):
|
||||
chunks.append(ks_statements[i:i + 10])
|
||||
|
||||
for ks_chunk in chunks:
|
||||
@@ -882,9 +886,9 @@ class Cluster(object):
|
||||
|
||||
log.debug("Done preparing all known prepared statements against host %s", host)
|
||||
except OperationTimedOut as timeout:
|
||||
log.warn("Timed out trying to prepare all statements on host %s: %s", host, timeout)
|
||||
log.warning("Timed out trying to prepare all statements on host %s: %s", host, timeout)
|
||||
except (ConnectionException, socket.error) as exc:
|
||||
log.warn("Error trying to prepare all statements on host %s: %r", host, exc)
|
||||
log.warning("Error trying to prepare all statements on host %s: %r", host, exc)
|
||||
except Exception:
|
||||
log.exception("Error trying to prepare all statements on host %s", host)
|
||||
finally:
|
||||
@@ -1088,7 +1092,7 @@ class Session(object):
|
||||
|
||||
prepared_statement = None
|
||||
|
||||
if isinstance(query, basestring):
|
||||
if isinstance(query, six.string_types):
|
||||
query = SimpleStatement(query)
|
||||
elif isinstance(query, PreparedStatement):
|
||||
query = query.bind(parameters)
|
||||
@@ -1235,8 +1239,8 @@ class Session(object):
|
||||
self.cluster.signal_connection_failure(host, conn_exc, is_host_addition)
|
||||
return False
|
||||
except Exception as conn_exc:
|
||||
log.warn("Failed to create connection pool for new host %s: %s",
|
||||
host, conn_exc)
|
||||
log.warning("Failed to create connection pool for new host %s: %s",
|
||||
host, conn_exc)
|
||||
# the host itself will still be marked down, so we need to pass
|
||||
# a special flag to make sure the reconnector is created
|
||||
self.cluster.signal_connection_failure(
|
||||
@@ -1456,11 +1460,11 @@ class ControlConnection(object):
|
||||
return self._try_connect(host)
|
||||
except ConnectionException as exc:
|
||||
errors[host.address] = exc
|
||||
log.warn("[control connection] Error connecting to %s:", host, exc_info=True)
|
||||
log.warning("[control connection] Error connecting to %s:", host, exc_info=True)
|
||||
self._cluster.signal_connection_failure(host, exc, is_host_addition=False)
|
||||
except Exception as exc:
|
||||
errors[host.address] = exc
|
||||
log.warn("[control connection] Error connecting to %s:", host, exc_info=True)
|
||||
log.warning("[control connection] Error connecting to %s:", host, exc_info=True)
|
||||
|
||||
raise NoHostAvailable("Unable to connect to any servers", errors)
|
||||
|
||||
@@ -1948,7 +1952,7 @@ class _Scheduler(object):
|
||||
def _log_if_failed(self, future):
|
||||
exc = future.exception()
|
||||
if exc:
|
||||
log.warn(
|
||||
log.warning(
|
||||
"An internally scheduled tasked failed with an unhandled exception:",
|
||||
exc_info=exc)
|
||||
|
||||
@@ -2170,8 +2174,8 @@ class ResponseFuture(object):
|
||||
if self._metrics is not None:
|
||||
self._metrics.on_other_error()
|
||||
# need to retry against a different host here
|
||||
log.warn("Host %s is overloaded, retrying against a different "
|
||||
"host", self._current_host)
|
||||
log.warning("Host %s is overloaded, retrying against a different "
|
||||
"host", self._current_host)
|
||||
self._retry(reuse_connection=False, consistency_level=None)
|
||||
return
|
||||
elif isinstance(response, IsBootstrappingErrorMessage):
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
import sys
|
||||
|
||||
from itertools import count, cycle
|
||||
from six.moves import xrange
|
||||
from threading import Event
|
||||
|
||||
|
||||
@@ -105,7 +106,7 @@ def execute_concurrent_with_args(session, statement, parameters, *args, **kwargs
|
||||
parameters = [(x,) for x in range(1000)]
|
||||
execute_concurrent_with_args(session, statement, parameters)
|
||||
"""
|
||||
return execute_concurrent(session, zip(cycle((statement,)), parameters), *args, **kwargs)
|
||||
return execute_concurrent(session, list(zip(cycle((statement,)), parameters)), *args, **kwargs)
|
||||
|
||||
|
||||
_sentinel = object()
|
||||
@@ -118,12 +119,12 @@ def _handle_error(error, result_index, event, session, statements, results, num_
|
||||
return
|
||||
else:
|
||||
results[result_index] = (False, error)
|
||||
if num_finished.next() >= to_execute:
|
||||
if next(num_finished) >= to_execute:
|
||||
event.set()
|
||||
return
|
||||
|
||||
try:
|
||||
(next_index, (statement, params)) = statements.next()
|
||||
(next_index, (statement, params)) = next(statements)
|
||||
except StopIteration:
|
||||
return
|
||||
|
||||
@@ -139,7 +140,7 @@ def _handle_error(error, result_index, event, session, statements, results, num_
|
||||
return
|
||||
else:
|
||||
results[next_index] = (False, exc)
|
||||
if num_finished.next() >= to_execute:
|
||||
if next(num_finished) >= to_execute:
|
||||
event.set()
|
||||
return
|
||||
|
||||
@@ -147,13 +148,13 @@ def _handle_error(error, result_index, event, session, statements, results, num_
|
||||
def _execute_next(result, result_index, event, session, statements, results, num_finished, to_execute, first_error):
|
||||
if result is not _sentinel:
|
||||
results[result_index] = (True, result)
|
||||
finished = num_finished.next()
|
||||
finished = next(num_finished)
|
||||
if finished >= to_execute:
|
||||
event.set()
|
||||
return
|
||||
|
||||
try:
|
||||
(next_index, (statement, params)) = statements.next()
|
||||
(next_index, (statement, params)) = next(statements)
|
||||
except StopIteration:
|
||||
return
|
||||
|
||||
@@ -169,6 +170,6 @@ def _execute_next(result, result_index, event, session, statements, results, num
|
||||
return
|
||||
else:
|
||||
results[next_index] = (False, exc)
|
||||
if num_finished.next() >= to_execute:
|
||||
if next(num_finished) >= to_execute:
|
||||
event.set()
|
||||
return
|
||||
|
||||
@@ -19,19 +19,21 @@ import logging
|
||||
import sys
|
||||
from threading import Event, RLock
|
||||
import time
|
||||
import traceback
|
||||
|
||||
if 'gevent.monkey' in sys.modules:
|
||||
from gevent.queue import Queue, Empty
|
||||
else:
|
||||
from Queue import Queue, Empty # noqa
|
||||
from six.moves.queue import Queue, Empty # noqa
|
||||
|
||||
from six.moves import range
|
||||
|
||||
from cassandra import ConsistencyLevel, AuthenticationFailed, OperationTimedOut
|
||||
from cassandra.marshal import int8_unpack, int32_pack
|
||||
from cassandra.marshal import int32_pack, header_unpack
|
||||
from cassandra.decoder import (ReadyMessage, AuthenticateMessage, OptionsMessage,
|
||||
StartupMessage, ErrorMessage, CredentialsMessage,
|
||||
QueryMessage, ResultMessage, decode_response,
|
||||
InvalidRequestException, SupportedMessage)
|
||||
import six
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
@@ -123,7 +125,6 @@ def defunct_on_error(f):
|
||||
return f(self, *args, **kwargs)
|
||||
except Exception as exc:
|
||||
self.defunct(exc)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@@ -181,12 +182,8 @@ class Connection(object):
|
||||
return
|
||||
self.is_defunct = True
|
||||
|
||||
trace = traceback.format_exc(exc)
|
||||
if trace != "None":
|
||||
log.debug("Defuncting connection (%s) to %s: %s\n%s",
|
||||
id(self), self.host, exc, traceback.format_exc(exc))
|
||||
else:
|
||||
log.debug("Defuncting connection (%s) to %s: %s", id(self), self.host, exc)
|
||||
log.debug("Defuncting connection (%s) to %s:",
|
||||
id(self), self.host, exc_info=exc)
|
||||
|
||||
self.last_error = exc
|
||||
self.close()
|
||||
@@ -203,9 +200,9 @@ class Connection(object):
|
||||
try:
|
||||
cb(new_exc)
|
||||
except Exception:
|
||||
log.warn("Ignoring unhandled exception while erroring callbacks for a "
|
||||
"failed connection (%s) to host %s:",
|
||||
id(self), self.host, exc_info=True)
|
||||
log.warning("Ignoring unhandled exception while erroring callbacks for a "
|
||||
"failed connection (%s) to host %s:",
|
||||
id(self), self.host, exc_info=True)
|
||||
|
||||
def handle_pushed(self, response):
|
||||
log.debug("Message pushed from server: %r", response)
|
||||
@@ -231,7 +228,7 @@ class Connection(object):
|
||||
request_id = self._id_queue.get()
|
||||
|
||||
self._callbacks[request_id] = cb
|
||||
self.push(msg.to_string(request_id, self.protocol_version, compression=self.compressor))
|
||||
self.push(msg.to_binary(request_id, self.protocol_version, compression=self.compressor))
|
||||
return request_id
|
||||
|
||||
def wait_for_response(self, msg, timeout=None):
|
||||
@@ -268,7 +265,7 @@ class Connection(object):
|
||||
return waiter.deliver(timeout)
|
||||
except OperationTimedOut:
|
||||
raise
|
||||
except Exception, exc:
|
||||
except Exception as exc:
|
||||
self.defunct(exc)
|
||||
raise
|
||||
|
||||
@@ -284,7 +281,7 @@ class Connection(object):
|
||||
|
||||
@defunct_on_error
|
||||
def process_msg(self, msg, body_len):
|
||||
version, flags, stream_id, opcode = map(int8_unpack, msg[:4])
|
||||
version, flags, stream_id, opcode = header_unpack(msg[:4])
|
||||
if stream_id < 0:
|
||||
callback = None
|
||||
else:
|
||||
@@ -309,7 +306,7 @@ class Connection(object):
|
||||
if body_len > 0:
|
||||
body = msg[8:]
|
||||
elif body_len == 0:
|
||||
body = ""
|
||||
body = six.binary_type()
|
||||
else:
|
||||
raise ProtocolError("Got negative body length: %r" % body_len)
|
||||
|
||||
@@ -383,7 +380,7 @@ class Connection(object):
|
||||
locally_supported_compressions.keys(),
|
||||
remote_supported_compressions)
|
||||
else:
|
||||
compression_type = iter(overlap).next() # choose any
|
||||
compression_type = next(iter(overlap)) # choose any
|
||||
# set the decompressor here, but set the compressor only after
|
||||
# a successful Ready message
|
||||
self._compressor, self.decompressor = \
|
||||
|
||||
@@ -36,10 +36,8 @@ from datetime import datetime
|
||||
from uuid import UUID
|
||||
import warnings
|
||||
|
||||
try:
|
||||
from cStringIO import StringIO
|
||||
except ImportError:
|
||||
from StringIO import StringIO # NOQA
|
||||
import six
|
||||
from six.moves import range
|
||||
|
||||
from cassandra.marshal import (int8_pack, int8_unpack, uint16_pack, uint16_unpack,
|
||||
int32_pack, int32_unpack, int64_pack, int64_unpack,
|
||||
@@ -49,7 +47,11 @@ from cassandra.util import OrderedDict
|
||||
|
||||
apache_cassandra_type_prefix = 'org.apache.cassandra.db.marshal.'
|
||||
|
||||
_number_types = frozenset((int, long, float))
|
||||
if six.PY3:
|
||||
_number_types = frozenset((int, float))
|
||||
long = int
|
||||
else:
|
||||
_number_types = frozenset((int, long, float))
|
||||
|
||||
try:
|
||||
from blist import sortedset
|
||||
@@ -69,7 +71,7 @@ def trim_if_startswith(s, prefix):
|
||||
|
||||
|
||||
def unix_time_from_uuid1(u):
|
||||
return (u.get_time() - 0x01B21DD213814000) / 10000000.0
|
||||
return (u.time - 0x01B21DD213814000) / 10000000.0
|
||||
|
||||
_casstypes = {}
|
||||
|
||||
@@ -177,8 +179,8 @@ class EmptyValue(object):
|
||||
EMPTY = EmptyValue()
|
||||
|
||||
|
||||
@six.add_metaclass(CassandraTypeType)
|
||||
class _CassandraType(object):
|
||||
__metaclass__ = CassandraTypeType
|
||||
subtypes = ()
|
||||
num_subtypes = 0
|
||||
empty_binary_ok = False
|
||||
@@ -199,9 +201,8 @@ class _CassandraType(object):
|
||||
def __init__(self, val):
|
||||
self.val = self.validate(val)
|
||||
|
||||
def __str__(self):
|
||||
def __repr__(self):
|
||||
return '<%s( %r )>' % (self.cql_parameterized_type(), self.val)
|
||||
__repr__ = __str__
|
||||
|
||||
@staticmethod
|
||||
def validate(val):
|
||||
@@ -221,7 +222,7 @@ class _CassandraType(object):
|
||||
"""
|
||||
if byts is None:
|
||||
return None
|
||||
elif byts == '' and not cls.empty_binary_ok:
|
||||
elif len(byts) == 0 and not cls.empty_binary_ok:
|
||||
return EMPTY if cls.support_empty_values else None
|
||||
return cls.deserialize(byts)
|
||||
|
||||
@@ -232,7 +233,7 @@ class _CassandraType(object):
|
||||
more information. This method differs in that if None is passed in,
|
||||
the result is the empty string.
|
||||
"""
|
||||
return '' if val is None else cls.serialize(val)
|
||||
return b'' if val is None else cls.serialize(val)
|
||||
|
||||
@staticmethod
|
||||
def deserialize(byts):
|
||||
@@ -293,7 +294,8 @@ class _CassandraType(object):
|
||||
if cls.num_subtypes != 'UNKNOWN' and len(subtypes) != cls.num_subtypes:
|
||||
raise ValueError("%s types require %d subtypes (%d given)"
|
||||
% (cls.typename, cls.num_subtypes, len(subtypes)))
|
||||
newname = cls.cass_parameterized_type_with(subtypes).encode('utf8')
|
||||
# newname = cls.cass_parameterized_type_with(subtypes).encode('utf8')
|
||||
newname = cls.cass_parameterized_type_with(subtypes)
|
||||
return type(newname, (cls,), {'subtypes': subtypes, 'cassname': cls.cassname})
|
||||
|
||||
@classmethod
|
||||
@@ -324,10 +326,16 @@ class _UnrecognizedType(_CassandraType):
|
||||
num_subtypes = 'UNKNOWN'
|
||||
|
||||
|
||||
def mkUnrecognizedType(casstypename):
|
||||
return CassandraTypeType(casstypename.encode('utf8'),
|
||||
(_UnrecognizedType,),
|
||||
{'typename': "'%s'" % casstypename})
|
||||
if six.PY3:
|
||||
def mkUnrecognizedType(casstypename):
|
||||
return CassandraTypeType(casstypename,
|
||||
(_UnrecognizedType,),
|
||||
{'typename': "'%s'" % casstypename})
|
||||
else:
|
||||
def mkUnrecognizedType(casstypename): # noqa
|
||||
return CassandraTypeType(casstypename.encode('utf8'),
|
||||
(_UnrecognizedType,),
|
||||
{'typename': "'%s'" % casstypename})
|
||||
|
||||
|
||||
class BytesType(_CassandraType):
|
||||
@@ -336,11 +344,11 @@ class BytesType(_CassandraType):
|
||||
|
||||
@staticmethod
|
||||
def validate(val):
|
||||
return buffer(val)
|
||||
return bytearray(val)
|
||||
|
||||
@staticmethod
|
||||
def serialize(val):
|
||||
return str(val)
|
||||
return six.binary_type(val)
|
||||
|
||||
|
||||
class DecimalType(_CassandraType):
|
||||
@@ -401,9 +409,25 @@ class BooleanType(_CassandraType):
|
||||
return int8_pack(truth)
|
||||
|
||||
|
||||
class AsciiType(_CassandraType):
|
||||
typename = 'ascii'
|
||||
empty_binary_ok = True
|
||||
if six.PY2:
|
||||
class AsciiType(_CassandraType):
|
||||
typename = 'ascii'
|
||||
empty_binary_ok = True
|
||||
else:
|
||||
class AsciiType(_CassandraType):
|
||||
typename = 'ascii'
|
||||
empty_binary_ok = True
|
||||
|
||||
@staticmethod
|
||||
def deserialize(byts):
|
||||
return byts.decode('ascii')
|
||||
|
||||
@staticmethod
|
||||
def serialize(var):
|
||||
try:
|
||||
return var.encode('ascii')
|
||||
except UnicodeDecodeError:
|
||||
return var
|
||||
|
||||
|
||||
class FloatType(_CassandraType):
|
||||
@@ -496,7 +520,7 @@ class DateType(_CassandraType):
|
||||
|
||||
@classmethod
|
||||
def validate(cls, date):
|
||||
if isinstance(date, basestring):
|
||||
if isinstance(date, six.string_types):
|
||||
date = cls.interpret_datestring(date)
|
||||
return date
|
||||
|
||||
@@ -628,7 +652,7 @@ class _SimpleParameterizedType(_ParameterizedType):
|
||||
numelements = uint16_unpack(byts[:2])
|
||||
p = 2
|
||||
result = []
|
||||
for n in xrange(numelements):
|
||||
for _ in range(numelements):
|
||||
itemlen = uint16_unpack(byts[p:p + 2])
|
||||
p += 2
|
||||
item = byts[p:p + itemlen]
|
||||
@@ -638,11 +662,11 @@ class _SimpleParameterizedType(_ParameterizedType):
|
||||
|
||||
@classmethod
|
||||
def serialize_safe(cls, items):
|
||||
if isinstance(items, basestring):
|
||||
if isinstance(items, six.string_types):
|
||||
raise TypeError("Received a string for a type that expects a sequence")
|
||||
|
||||
subtype, = cls.subtypes
|
||||
buf = StringIO()
|
||||
buf = six.BytesIO()
|
||||
buf.write(uint16_pack(len(items)))
|
||||
for item in items:
|
||||
itembytes = subtype.to_binary(item)
|
||||
@@ -670,7 +694,7 @@ class MapType(_ParameterizedType):
|
||||
@classmethod
|
||||
def validate(cls, val):
|
||||
subkeytype, subvaltype = cls.subtypes
|
||||
return dict((subkeytype.validate(k), subvaltype.validate(v)) for (k, v) in val.iteritems())
|
||||
return dict((subkeytype.validate(k), subvaltype.validate(v)) for (k, v) in six.iteritems(val))
|
||||
|
||||
@classmethod
|
||||
def deserialize_safe(cls, byts):
|
||||
@@ -678,7 +702,7 @@ class MapType(_ParameterizedType):
|
||||
numelements = uint16_unpack(byts[:2])
|
||||
p = 2
|
||||
themap = OrderedDict()
|
||||
for n in xrange(numelements):
|
||||
for _ in range(numelements):
|
||||
key_len = uint16_unpack(byts[p:p + 2])
|
||||
p += 2
|
||||
keybytes = byts[p:p + key_len]
|
||||
@@ -695,10 +719,10 @@ class MapType(_ParameterizedType):
|
||||
@classmethod
|
||||
def serialize_safe(cls, themap):
|
||||
subkeytype, subvaltype = cls.subtypes
|
||||
buf = StringIO()
|
||||
buf = six.BytesIO()
|
||||
buf.write(uint16_pack(len(themap)))
|
||||
try:
|
||||
items = themap.iteritems()
|
||||
items = six.iteritems(themap)
|
||||
except AttributeError:
|
||||
raise TypeError("Got a non-map object for a map value")
|
||||
for key, val in items:
|
||||
@@ -747,7 +771,7 @@ class ReversedType(_ParameterizedType):
|
||||
|
||||
|
||||
def is_counter_type(t):
|
||||
if isinstance(t, basestring):
|
||||
if isinstance(t, six.string_types):
|
||||
t = lookup_casstype(t)
|
||||
return issubclass(t, CounterColumnType)
|
||||
|
||||
|
||||
@@ -16,16 +16,14 @@ import logging
|
||||
import socket
|
||||
from uuid import UUID
|
||||
|
||||
try:
|
||||
from cStringIO import StringIO
|
||||
except ImportError:
|
||||
from StringIO import StringIO # ignore flake8 warning: # NOQA
|
||||
import six
|
||||
from six.moves import range
|
||||
|
||||
from cassandra import (Unavailable, WriteTimeout, ReadTimeout,
|
||||
AlreadyExists, InvalidRequest, Unauthorized,
|
||||
UnsupportedOperation)
|
||||
from cassandra.marshal import (int32_pack, int32_unpack, uint16_pack, uint16_unpack,
|
||||
int8_pack, int8_unpack)
|
||||
int8_pack, int8_unpack, header_pack)
|
||||
from cassandra.cqltypes import (AsciiType, BytesType, BooleanType,
|
||||
CounterColumnType, DateType, DecimalType,
|
||||
DoubleType, FloatType, Int32Type,
|
||||
@@ -48,66 +46,75 @@ HEADER_DIRECTION_FROM_CLIENT = 0x00
|
||||
HEADER_DIRECTION_TO_CLIENT = 0x80
|
||||
HEADER_DIRECTION_MASK = 0x80
|
||||
|
||||
COMPRESSED_FLAG = 0x01
|
||||
TRACING_FLAG = 0x02
|
||||
|
||||
_message_types_by_name = {}
|
||||
_message_types_by_opcode = {}
|
||||
|
||||
|
||||
class _register_msg_type(type):
|
||||
class _RegisterMessageType(type):
|
||||
def __init__(cls, name, bases, dct):
|
||||
if not name.startswith('_'):
|
||||
_message_types_by_name[cls.name] = cls
|
||||
_message_types_by_opcode[cls.opcode] = cls
|
||||
|
||||
|
||||
@six.add_metaclass(_RegisterMessageType)
|
||||
class _MessageType(object):
|
||||
__metaclass__ = _register_msg_type
|
||||
|
||||
tracing = False
|
||||
|
||||
def to_string(self, stream_id, protocol_version, compression=None):
|
||||
body = StringIO()
|
||||
def to_binary(self, stream_id, protocol_version, compression=None):
|
||||
body = six.BytesIO()
|
||||
self.send_body(body, protocol_version)
|
||||
body = body.getvalue()
|
||||
version = protocol_version | HEADER_DIRECTION_FROM_CLIENT
|
||||
flags = 0
|
||||
if compression is not None and len(body) > 0:
|
||||
body = compression(body)
|
||||
flags |= 0x01
|
||||
if self.tracing:
|
||||
flags |= 0x02
|
||||
msglen = int32_pack(len(body))
|
||||
msg_parts = map(int8_pack, (version, flags, stream_id, self.opcode)) + [msglen, body]
|
||||
return ''.join(msg_parts)
|
||||
|
||||
def __str__(self):
|
||||
paramstrs = ['%s=%r' % (pname, getattr(self, pname)) for pname in _get_params(self)]
|
||||
return '<%s(%s)>' % (self.__class__.__name__, ', '.join(paramstrs))
|
||||
__repr__ = __str__
|
||||
flags = 0
|
||||
if compression and len(body) > 0:
|
||||
body = compression(body)
|
||||
flags |= COMPRESSED_FLAG
|
||||
if self.tracing:
|
||||
flags |= TRACING_FLAG
|
||||
|
||||
msg = six.BytesIO()
|
||||
write_header(
|
||||
msg,
|
||||
protocol_version | HEADER_DIRECTION_FROM_CLIENT,
|
||||
flags, stream_id, self.opcode, len(body)
|
||||
)
|
||||
msg.write(body)
|
||||
|
||||
return msg.getvalue()
|
||||
|
||||
def __repr__(self):
|
||||
return '<%s(%s)>' % (self.__class__.__name__, ', '.join('%s=%r' % i for i in _get_params(self)))
|
||||
|
||||
|
||||
def _get_params(message_obj):
|
||||
base_attrs = dir(_MessageType)
|
||||
return [a for a in dir(message_obj)
|
||||
if a not in base_attrs and not a.startswith('_') and not callable(getattr(message_obj, a))]
|
||||
return (
|
||||
(n, a) for n, a in message_obj.__dict__.items()
|
||||
if n not in base_attrs and not n.startswith('_') and not callable(a)
|
||||
)
|
||||
|
||||
|
||||
def decode_response(stream_id, flags, opcode, body, decompressor=None):
|
||||
if flags & 0x01:
|
||||
if flags & COMPRESSED_FLAG:
|
||||
if decompressor is None:
|
||||
raise Exception("No decompressor available for compressed frame!")
|
||||
raise Exception("No de-compressor available for compressed frame!")
|
||||
body = decompressor(body)
|
||||
flags ^= 0x01
|
||||
flags ^= COMPRESSED_FLAG
|
||||
|
||||
body = StringIO(body)
|
||||
if flags & 0x02:
|
||||
body = six.BytesIO(body)
|
||||
if flags & TRACING_FLAG:
|
||||
trace_id = UUID(bytes=body.read(16))
|
||||
flags ^= 0x02
|
||||
flags ^= TRACING_FLAG
|
||||
else:
|
||||
trace_id = None
|
||||
|
||||
if flags:
|
||||
log.warn("Unknown protocol flags set: %02x. May cause problems.", flags)
|
||||
log.warning("Unknown protocol flags set: %02x. May cause problems.", flags)
|
||||
|
||||
msg_class = _message_types_by_opcode[opcode]
|
||||
msg = msg_class.recv_body(body)
|
||||
@@ -156,14 +163,14 @@ class ErrorMessage(_MessageType, Exception):
|
||||
return self
|
||||
|
||||
|
||||
class ErrorMessageSubclass(_register_msg_type):
|
||||
class ErrorMessageSubclass(_RegisterMessageType):
|
||||
def __init__(cls, name, bases, dct):
|
||||
if cls.error_code is not None:
|
||||
if cls.error_code is not None: # Server has an error code of 0.
|
||||
error_classes[cls.error_code] = cls
|
||||
|
||||
|
||||
@six.add_metaclass(ErrorMessageSubclass)
|
||||
class ErrorMessageSub(ErrorMessage):
|
||||
__metaclass__ = ErrorMessageSubclass
|
||||
error_code = None
|
||||
|
||||
|
||||
@@ -511,7 +518,7 @@ class ResultMessage(_MessageType):
|
||||
def recv_results_rows(cls, f):
|
||||
paging_state, column_metadata = cls.recv_results_metadata(f)
|
||||
rowcount = read_int(f)
|
||||
rows = [cls.recv_row(f, len(column_metadata)) for x in xrange(rowcount)]
|
||||
rows = [cls.recv_row(f, len(column_metadata)) for _ in range(rowcount)]
|
||||
colnames = [c[2] for c in column_metadata]
|
||||
coltypes = [c[3] for c in column_metadata]
|
||||
return (
|
||||
@@ -538,7 +545,7 @@ class ResultMessage(_MessageType):
|
||||
ksname = read_string(f)
|
||||
cfname = read_string(f)
|
||||
column_metadata = []
|
||||
for x in xrange(colcount):
|
||||
for _ in range(colcount):
|
||||
if glob_tblspec:
|
||||
colksname = ksname
|
||||
colcfname = cfname
|
||||
@@ -580,7 +587,7 @@ class ResultMessage(_MessageType):
|
||||
|
||||
@staticmethod
|
||||
def recv_row(f, colcount):
|
||||
return [read_value(f) for x in xrange(colcount)]
|
||||
return [read_value(f) for _ in range(colcount)]
|
||||
|
||||
|
||||
class PrepareMessage(_MessageType):
|
||||
@@ -729,6 +736,14 @@ class EventMessage(_MessageType):
|
||||
return dict(change_type=change_type, keyspace=keyspace, table=table)
|
||||
|
||||
|
||||
def write_header(f, version, flags, stream_id, opcode, length):
|
||||
"""
|
||||
Write a CQL protocol frame header.
|
||||
"""
|
||||
f.write(header_pack(version, flags, stream_id, opcode))
|
||||
write_int(f, length)
|
||||
|
||||
|
||||
def read_byte(f):
|
||||
return int8_unpack(f.read(1))
|
||||
|
||||
@@ -774,7 +789,7 @@ def read_binary_string(f):
|
||||
|
||||
|
||||
def write_string(f, s):
|
||||
if isinstance(s, unicode):
|
||||
if isinstance(s, six.text_type):
|
||||
s = s.encode('utf8')
|
||||
write_short(f, len(s))
|
||||
f.write(s)
|
||||
@@ -791,7 +806,7 @@ def read_longstring(f):
|
||||
|
||||
|
||||
def write_longstring(f, s):
|
||||
if isinstance(s, unicode):
|
||||
if isinstance(s, six.text_type):
|
||||
s = s.encode('utf8')
|
||||
write_int(f, len(s))
|
||||
f.write(s)
|
||||
@@ -799,7 +814,7 @@ def write_longstring(f, s):
|
||||
|
||||
def read_stringlist(f):
|
||||
numstrs = read_short(f)
|
||||
return [read_string(f) for x in xrange(numstrs)]
|
||||
return [read_string(f) for _ in range(numstrs)]
|
||||
|
||||
|
||||
def write_stringlist(f, stringlist):
|
||||
@@ -811,7 +826,7 @@ def write_stringlist(f, stringlist):
|
||||
def read_stringmap(f):
|
||||
numpairs = read_short(f)
|
||||
strmap = {}
|
||||
for x in xrange(numpairs):
|
||||
for _ in range(numpairs):
|
||||
k = read_string(f)
|
||||
strmap[k] = read_string(f)
|
||||
return strmap
|
||||
@@ -827,7 +842,7 @@ def write_stringmap(f, strmap):
|
||||
def read_stringmultimap(f):
|
||||
numkeys = read_short(f)
|
||||
strmmap = {}
|
||||
for x in xrange(numkeys):
|
||||
for _ in range(numkeys):
|
||||
k = read_string(f)
|
||||
strmmap[k] = read_stringlist(f)
|
||||
return strmmap
|
||||
|
||||
@@ -12,21 +12,34 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
from binascii import hexlify
|
||||
import calendar
|
||||
import datetime
|
||||
import sys
|
||||
import types
|
||||
from uuid import UUID
|
||||
import six
|
||||
|
||||
from cassandra.util import OrderedDict
|
||||
|
||||
if six.PY3:
|
||||
long = int
|
||||
|
||||
|
||||
def cql_quote(term):
|
||||
if isinstance(term, unicode):
|
||||
return "'%s'" % term.encode('utf8').replace("'", "''")
|
||||
elif isinstance(term, (str, bool)):
|
||||
# The ordering of this method is important for the result of this method to
|
||||
# be a native str type (for both Python 2 and 3)
|
||||
|
||||
# Handle quoting of native str and bool types
|
||||
if isinstance(term, (str, bool)):
|
||||
return "'%s'" % str(term).replace("'", "''")
|
||||
# This branch of the if statement will only be used by Python 2 to catch
|
||||
# unicode strings, text_type is used to prevent type errors with Python 3.
|
||||
elif isinstance(term, six.text_type):
|
||||
return "'%s'" % term.encode('utf8').replace("'", "''")
|
||||
else:
|
||||
return str(term)
|
||||
|
||||
@@ -43,13 +56,16 @@ def cql_encode_str(val):
|
||||
return cql_quote(val)
|
||||
|
||||
|
||||
if sys.version_info >= (2, 7):
|
||||
if six.PY3:
|
||||
def cql_encode_bytes(val):
|
||||
return '0x' + hexlify(val)
|
||||
return (b'0x' + hexlify(val)).decode('utf-8')
|
||||
elif sys.version_info >= (2, 7):
|
||||
def cql_encode_bytes(val): # noqa
|
||||
return b'0x' + hexlify(val)
|
||||
else:
|
||||
# python 2.6 requires string or read-only buffer for hexlify
|
||||
def cql_encode_bytes(val): # noqa
|
||||
return '0x' + hexlify(buffer(val))
|
||||
return b'0x' + hexlify(buffer(val))
|
||||
|
||||
|
||||
def cql_encode_object(val):
|
||||
@@ -71,11 +87,10 @@ def cql_encode_sequence(val):
|
||||
|
||||
|
||||
def cql_encode_map_collection(val):
|
||||
return '{ %s }' % ' , '.join(
|
||||
'%s : %s' % (
|
||||
cql_encode_all_types(k),
|
||||
cql_encode_all_types(v))
|
||||
for k, v in val.iteritems())
|
||||
return '{ %s }' % ' , '.join('%s : %s' % (
|
||||
cql_encode_all_types(k),
|
||||
cql_encode_all_types(v)
|
||||
) for k, v in six.iteritems(val))
|
||||
|
||||
|
||||
def cql_encode_list_collection(val):
|
||||
@@ -92,13 +107,9 @@ def cql_encode_all_types(val):
|
||||
|
||||
cql_encoders = {
|
||||
float: cql_encode_object,
|
||||
buffer: cql_encode_bytes,
|
||||
bytearray: cql_encode_bytes,
|
||||
str: cql_encode_str,
|
||||
unicode: cql_encode_unicode,
|
||||
types.NoneType: cql_encode_none,
|
||||
int: cql_encode_object,
|
||||
long: cql_encode_object,
|
||||
UUID: cql_encode_object,
|
||||
datetime.datetime: cql_encode_datetime,
|
||||
datetime.date: cql_encode_date,
|
||||
@@ -110,3 +121,17 @@ cql_encoders = {
|
||||
frozenset: cql_encode_set_collection,
|
||||
types.GeneratorType: cql_encode_list_collection
|
||||
}
|
||||
|
||||
if six.PY2:
|
||||
cql_encoders.update({
|
||||
unicode: cql_encode_unicode,
|
||||
buffer: cql_encode_bytes,
|
||||
long: cql_encode_object,
|
||||
types.NoneType: cql_encode_none,
|
||||
})
|
||||
else:
|
||||
cql_encoders.update({
|
||||
memoryview: cql_encode_bytes,
|
||||
bytes: cql_encode_bytes,
|
||||
type(None): cql_encode_none,
|
||||
})
|
||||
|
||||
@@ -20,6 +20,10 @@ import os
|
||||
import socket
|
||||
import sys
|
||||
from threading import Event, Lock, Thread
|
||||
|
||||
from six import BytesIO
|
||||
from six.moves import range
|
||||
|
||||
from errno import EALREADY, EINPROGRESS, EWOULDBLOCK, EINVAL, EISCONN, errorcode
|
||||
try:
|
||||
from weakref import WeakSet
|
||||
@@ -28,11 +32,6 @@ except ImportError:
|
||||
|
||||
import asyncore
|
||||
|
||||
try:
|
||||
from cStringIO import StringIO
|
||||
except ImportError:
|
||||
from StringIO import StringIO # ignore flake8 warning: # NOQA
|
||||
|
||||
try:
|
||||
import ssl
|
||||
except ImportError:
|
||||
@@ -141,7 +140,7 @@ class AsyncoreConnection(Connection, asyncore.dispatcher):
|
||||
asyncore.dispatcher.__init__(self)
|
||||
|
||||
self.connected_event = Event()
|
||||
self._iobuf = StringIO()
|
||||
self._iobuf = BytesIO()
|
||||
|
||||
self._callbacks = {}
|
||||
self.deque = deque()
|
||||
@@ -286,7 +285,7 @@ class AsyncoreConnection(Connection, asyncore.dispatcher):
|
||||
|
||||
# leave leftover in current buffer
|
||||
leftover = self._iobuf.read()
|
||||
self._iobuf = StringIO()
|
||||
self._iobuf = BytesIO()
|
||||
self._iobuf.write(leftover)
|
||||
|
||||
self._total_reqd_bytes = 0
|
||||
@@ -302,7 +301,7 @@ class AsyncoreConnection(Connection, asyncore.dispatcher):
|
||||
sabs = self.out_buffer_size
|
||||
if len(data) > sabs:
|
||||
chunks = []
|
||||
for i in xrange(0, len(data), sabs):
|
||||
for i in range(0, len(data), sabs):
|
||||
chunks.append(data[i:i + sabs])
|
||||
else:
|
||||
chunks = [data]
|
||||
|
||||
@@ -20,6 +20,8 @@ import os
|
||||
import socket
|
||||
from threading import Event, Lock, Thread
|
||||
|
||||
from six import BytesIO
|
||||
|
||||
from cassandra import OperationTimedOut
|
||||
from cassandra.connection import Connection, ConnectionShutdown, NONBLOCKING
|
||||
from cassandra.decoder import RegisterMessage
|
||||
@@ -35,10 +37,6 @@ except ImportError:
|
||||
"for instructions on installing build dependencies and building "
|
||||
"the C extension.")
|
||||
|
||||
try:
|
||||
from cStringIO import StringIO
|
||||
except ImportError:
|
||||
from StringIO import StringIO # ignore flake8 warning: # NOQA
|
||||
|
||||
try:
|
||||
import ssl
|
||||
@@ -197,7 +195,7 @@ class LibevConnection(Connection):
|
||||
Connection.__init__(self, *args, **kwargs)
|
||||
|
||||
self.connected_event = Event()
|
||||
self._iobuf = StringIO()
|
||||
self._iobuf = BytesIO()
|
||||
|
||||
self._callbacks = {}
|
||||
self.deque = deque()
|
||||
@@ -323,7 +321,7 @@ class LibevConnection(Connection):
|
||||
|
||||
# leave leftover in current buffer
|
||||
leftover = self._iobuf.read()
|
||||
self._iobuf = StringIO()
|
||||
self._iobuf = BytesIO()
|
||||
self._iobuf.write(leftover)
|
||||
|
||||
self._total_reqd_bytes = 0
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import six
|
||||
import struct
|
||||
|
||||
|
||||
@@ -37,12 +38,24 @@ uint8_pack, uint8_unpack = _make_packer('>B')
|
||||
float_pack, float_unpack = _make_packer('>f')
|
||||
double_pack, double_unpack = _make_packer('>d')
|
||||
|
||||
# Special case for cassandra header
|
||||
header_struct = struct.Struct('>BBbB')
|
||||
header_pack = header_struct.pack
|
||||
header_unpack = header_struct.unpack
|
||||
|
||||
def varint_unpack(term):
|
||||
val = int(term.encode('hex'), 16)
|
||||
if (ord(term[0]) & 128) != 0:
|
||||
val = val - (1 << (len(term) * 8))
|
||||
return val
|
||||
|
||||
if six.PY3:
|
||||
def varint_unpack(term):
|
||||
val = int(''.join("%02x" % i for i in term), 16)
|
||||
if (term[0] & 128) != 0:
|
||||
val -= 1 << (len(term) * 8)
|
||||
return val
|
||||
else:
|
||||
def varint_unpack(term): # noqa
|
||||
val = int(term.encode('hex'), 16)
|
||||
if (ord(term[0]) & 128) != 0:
|
||||
val = val - (1 << (len(term) * 8))
|
||||
return val
|
||||
|
||||
|
||||
def bitlength(n):
|
||||
@@ -56,16 +69,16 @@ def bitlength(n):
|
||||
def varint_pack(big):
|
||||
pos = True
|
||||
if big == 0:
|
||||
return '\x00'
|
||||
return b'\x00'
|
||||
if big < 0:
|
||||
bytelength = bitlength(abs(big) - 1) / 8 + 1
|
||||
bytelength = bitlength(abs(big) - 1) // 8 + 1
|
||||
big = (1 << bytelength * 8) + big
|
||||
pos = False
|
||||
revbytes = []
|
||||
revbytes = bytearray()
|
||||
while big > 0:
|
||||
revbytes.append(chr(big & 0xff))
|
||||
revbytes.append(big & 0xff)
|
||||
big >>= 8
|
||||
if pos and ord(revbytes[-1]) & 0x80:
|
||||
revbytes.append('\x00')
|
||||
if pos and revbytes[-1] & 0x80:
|
||||
revbytes.append(0)
|
||||
revbytes.reverse()
|
||||
return ''.join(revbytes)
|
||||
return six.binary_type(revbytes)
|
||||
|
||||
@@ -21,11 +21,12 @@ import logging
|
||||
import re
|
||||
from threading import RLock
|
||||
import weakref
|
||||
import six
|
||||
|
||||
murmur3 = None
|
||||
try:
|
||||
from murmur3 import murmur3
|
||||
except ImportError:
|
||||
from cassandra.murmur3 import murmur3
|
||||
except ImportError as e:
|
||||
pass
|
||||
|
||||
import cassandra.cqltypes as types
|
||||
@@ -330,7 +331,7 @@ class Metadata(object):
|
||||
|
||||
token_to_host_owner = {}
|
||||
ring = []
|
||||
for host, token_strings in token_map.iteritems():
|
||||
for host, token_strings in six.iteritems(token_map):
|
||||
for token_string in token_strings:
|
||||
token = token_class(token_string)
|
||||
ring.append(token)
|
||||
@@ -793,14 +794,18 @@ class TableMetadata(object):
|
||||
return list(sorted(ret))
|
||||
|
||||
|
||||
def protect_name(name):
|
||||
if isinstance(name, unicode):
|
||||
name = name.encode('utf8')
|
||||
return maybe_escape_name(name)
|
||||
if six.PY3:
|
||||
def protect_name(name):
|
||||
return maybe_escape_name(name)
|
||||
else:
|
||||
def protect_name(name):
|
||||
if isinstance(name, six.text_type):
|
||||
name = name.encode('utf8')
|
||||
return maybe_escape_name(name)
|
||||
|
||||
|
||||
def protect_names(names):
|
||||
return map(protect_name, names)
|
||||
return [protect_name(n) for n in names]
|
||||
|
||||
|
||||
def protect_value(value):
|
||||
@@ -1008,11 +1013,14 @@ class Token(object):
|
||||
def __eq__(self, other):
|
||||
return self.value == other.value
|
||||
|
||||
def __lt__(self, other):
|
||||
return self.value < other.value
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.value)
|
||||
|
||||
def __repr__(self):
|
||||
return "<%s: %r>" % (self.__class__.__name__, self.value)
|
||||
return "<%s: %s>" % (self.__class__.__name__, self.value)
|
||||
__str__ = __repr__
|
||||
|
||||
MIN_LONG = -(2 ** 63)
|
||||
@@ -1031,7 +1039,7 @@ class Murmur3Token(Token):
|
||||
@classmethod
|
||||
def hash_fn(cls, key):
|
||||
if murmur3 is not None:
|
||||
h = murmur3(key)
|
||||
h = int(murmur3(key))
|
||||
return h if h != MIN_LONG else MAX_LONG
|
||||
else:
|
||||
raise NoMurmur3()
|
||||
@@ -1048,6 +1056,8 @@ class MD5Token(Token):
|
||||
|
||||
@classmethod
|
||||
def hash_fn(cls, key):
|
||||
if isinstance(key, six.text_type):
|
||||
key = key.encode('UTF-8')
|
||||
return abs(varint_unpack(md5(key).digest()))
|
||||
|
||||
def __init__(self, token):
|
||||
@@ -1062,7 +1072,7 @@ class BytesToken(Token):
|
||||
|
||||
def __init__(self, token_string):
|
||||
""" `token_string` should be string representing the token. """
|
||||
if not isinstance(token_string, basestring):
|
||||
if not isinstance(token_string, six.string_types):
|
||||
raise TypeError(
|
||||
"Tokens for ByteOrderedPartitioner should be strings (got %s)"
|
||||
% (type(token_string),))
|
||||
|
||||
@@ -169,6 +169,18 @@ uint64_t MurmurHash3_x64_128 (const void * key, const int len,
|
||||
return h1;
|
||||
}
|
||||
|
||||
|
||||
struct module_state {
|
||||
PyObject *error;
|
||||
};
|
||||
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
#define GETSTATE(m) ((struct module_state*)PyModule_GetState(m))
|
||||
#else
|
||||
#define GETSTATE(m) (&_state)
|
||||
static struct module_state _state;
|
||||
#endif
|
||||
|
||||
static PyObject *
|
||||
murmur3(PyObject *self, PyObject *args)
|
||||
{
|
||||
@@ -186,23 +198,63 @@ murmur3(PyObject *self, PyObject *args)
|
||||
}
|
||||
|
||||
static PyMethodDef murmur3_methods[] = {
|
||||
{"murmur3", murmur3, METH_VARARGS,
|
||||
"Make an x64 murmur3 64-bit hash value"},
|
||||
|
||||
{"murmur3", murmur3, METH_VARARGS, "Make an x64 murmur3 64-bit hash value"},
|
||||
{NULL, NULL, 0, NULL}
|
||||
};
|
||||
|
||||
#if PY_MAJOR_VERSION <= 2
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
|
||||
PyMODINIT_FUNC
|
||||
initmurmur3(void)
|
||||
{
|
||||
(void) Py_InitModule("murmur3", murmur3_methods);
|
||||
static int murmur3_traverse(PyObject *m, visitproc visit, void *arg) {
|
||||
Py_VISIT(GETSTATE(m)->error);
|
||||
return 0;
|
||||
}
|
||||
|
||||
static int murmur3_clear(PyObject *m) {
|
||||
Py_CLEAR(GETSTATE(m)->error);
|
||||
return 0;
|
||||
}
|
||||
|
||||
static struct PyModuleDef moduledef = {
|
||||
PyModuleDef_HEAD_INIT,
|
||||
"murmur3",
|
||||
NULL,
|
||||
sizeof(struct module_state),
|
||||
murmur3_methods,
|
||||
NULL,
|
||||
murmur3_traverse,
|
||||
murmur3_clear,
|
||||
NULL
|
||||
};
|
||||
|
||||
#define INITERROR return NULL
|
||||
|
||||
PyObject *
|
||||
PyInit_murmur3(void)
|
||||
|
||||
#else
|
||||
#define INITERROR return
|
||||
|
||||
/* Python 3.x */
|
||||
// TODO
|
||||
|
||||
void
|
||||
initmurmur3(void)
|
||||
#endif
|
||||
{
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
PyObject *module = PyModule_Create(&moduledef);
|
||||
#else
|
||||
PyObject *module = Py_InitModule("murmur3", murmur3_methods);
|
||||
#endif
|
||||
|
||||
if (module == NULL)
|
||||
INITERROR;
|
||||
struct module_state *st = GETSTATE(module);
|
||||
|
||||
st->error = PyErr_NewException("murmur3.Error", NULL, NULL);
|
||||
if (st->error == NULL) {
|
||||
Py_DECREF(module);
|
||||
INITERROR;
|
||||
}
|
||||
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
return module;
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -16,9 +16,12 @@ from itertools import islice, cycle, groupby, repeat
|
||||
import logging
|
||||
from random import randint
|
||||
from threading import Lock
|
||||
import six
|
||||
|
||||
from cassandra import ConsistencyLevel
|
||||
|
||||
from six.moves import range
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -263,7 +266,7 @@ class DCAwareRoundRobinPolicy(LoadBalancingPolicy):
|
||||
for host in islice(cycle(local_live), pos, pos + len(local_live)):
|
||||
yield host
|
||||
|
||||
for dc, current_dc_hosts in self._dc_live_hosts.iteritems():
|
||||
for dc, current_dc_hosts in six.iteritems(self._dc_live_hosts):
|
||||
if dc == self.local_dc:
|
||||
continue
|
||||
|
||||
@@ -529,7 +532,7 @@ class ExponentialReconnectionPolicy(ReconnectionPolicy):
|
||||
self.max_delay = max_delay
|
||||
|
||||
def new_schedule(self):
|
||||
return (min(self.base_delay * (2 ** i), self.max_delay) for i in xrange(64))
|
||||
return (min(self.base_delay * (2 ** i), self.max_delay) for i in range(64))
|
||||
|
||||
|
||||
class WriteType(object):
|
||||
|
||||
@@ -150,8 +150,14 @@ class Host(object):
|
||||
def __eq__(self, other):
|
||||
return self.address == other.address
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.address)
|
||||
|
||||
def __lt__(self, other):
|
||||
return self.address < other.address
|
||||
|
||||
def __str__(self):
|
||||
return self.address
|
||||
return str(self.address)
|
||||
|
||||
def __repr__(self):
|
||||
dc = (" %s" % (self._datacenter,)) if self._datacenter else ""
|
||||
@@ -178,7 +184,7 @@ class _ReconnectionHandler(object):
|
||||
log.debug("Reconnection handler was cancelled before starting")
|
||||
return
|
||||
|
||||
first_delay = self.schedule.next()
|
||||
first_delay = next(self.schedule)
|
||||
self.scheduler.schedule(first_delay, self.run)
|
||||
|
||||
def run(self):
|
||||
@@ -189,7 +195,7 @@ class _ReconnectionHandler(object):
|
||||
try:
|
||||
conn = self.try_reconnect()
|
||||
except Exception as exc:
|
||||
next_delay = self.schedule.next()
|
||||
next_delay = next(self.schedule)
|
||||
if self.on_exception(exc, next_delay):
|
||||
self.scheduler.schedule(next_delay, self.run)
|
||||
else:
|
||||
@@ -260,8 +266,8 @@ class _HostReconnectionHandler(_ReconnectionHandler):
|
||||
if isinstance(exc, AuthenticationFailed):
|
||||
return False
|
||||
else:
|
||||
log.warn("Error attempting to reconnect to %s, scheduling retry in %s seconds: %s",
|
||||
self.host, next_delay, exc)
|
||||
log.warning("Error attempting to reconnect to %s, scheduling retry in %s seconds: %s",
|
||||
self.host, next_delay, exc)
|
||||
log.debug("Reconnection error details", exc_info=True)
|
||||
return True
|
||||
|
||||
@@ -371,8 +377,8 @@ class HostConnectionPool(object):
|
||||
def _create_new_connection(self):
|
||||
try:
|
||||
self._add_conn_if_under_max()
|
||||
except (ConnectionException, socket.error), exc:
|
||||
log.warn("Failed to create new connection to %s: %s", self.host, exc)
|
||||
except (ConnectionException, socket.error) as exc:
|
||||
log.warning("Failed to create new connection to %s: %s", self.host, exc)
|
||||
except Exception:
|
||||
log.exception("Unexpectedly failed to create new connection")
|
||||
finally:
|
||||
@@ -404,7 +410,7 @@ class HostConnectionPool(object):
|
||||
self._signal_available_conn()
|
||||
return True
|
||||
except (ConnectionException, socket.error) as exc:
|
||||
log.warn("Failed to add new connection to pool for host %s: %s", self.host, exc)
|
||||
log.warning("Failed to add new connection to pool for host %s: %s", self.host, exc)
|
||||
with self._lock:
|
||||
self.open_count -= 1
|
||||
if self._session.cluster.signal_connection_failure(self.host, exc, is_host_addition=False):
|
||||
|
||||
@@ -23,6 +23,7 @@ from datetime import datetime, timedelta
|
||||
import re
|
||||
import struct
|
||||
import time
|
||||
import six
|
||||
|
||||
from cassandra import ConsistencyLevel, OperationTimedOut
|
||||
from cassandra.cqltypes import unix_time_from_uuid1
|
||||
@@ -113,8 +114,8 @@ class Statement(object):
|
||||
|
||||
def _set_routing_key(self, key):
|
||||
if isinstance(key, (list, tuple)):
|
||||
self._routing_key = "".join(struct.pack("HsB", len(component), component, 0)
|
||||
for component in key)
|
||||
self._routing_key = b"".join(struct.pack("HsB", len(component), component, 0)
|
||||
for component in key)
|
||||
else:
|
||||
self._routing_key = key
|
||||
|
||||
@@ -408,7 +409,7 @@ class BoundStatement(Statement):
|
||||
val = self.values[statement_index]
|
||||
components.append(struct.pack("HsB", len(val), val, 0))
|
||||
|
||||
self._routing_key = "".join(components)
|
||||
self._routing_key = b"".join(components)
|
||||
|
||||
return self._routing_key
|
||||
|
||||
@@ -473,7 +474,7 @@ class BatchStatement(Statement):
|
||||
Statement.__init__(self, retry_policy=retry_policy, consistency_level=consistency_level)
|
||||
|
||||
def add(self, statement, parameters=None):
|
||||
if isinstance(statement, basestring):
|
||||
if isinstance(statement, six.string_types):
|
||||
if parameters:
|
||||
statement = bind_params(statement, parameters)
|
||||
self._statements_and_parameters.append((False, statement, ()))
|
||||
@@ -532,11 +533,9 @@ class ValueSequence(object):
|
||||
|
||||
def bind_params(query, params):
|
||||
if isinstance(params, dict):
|
||||
return query % dict((k, cql_encoders.get(type(v), cql_encode_object)(v))
|
||||
for k, v in params.iteritems())
|
||||
return query % dict((k, cql_encoders.get(type(v), cql_encode_object)(v)) for k, v in six.iteritems(params))
|
||||
else:
|
||||
return query % tuple(cql_encoders.get(type(v), cql_encode_object)(v)
|
||||
for v in params)
|
||||
return query % tuple(cql_encoders.get(type(v), cql_encode_object)(v) for v in params)
|
||||
|
||||
|
||||
class TraceUnavailable(Exception):
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
from __future__ import with_statement
|
||||
|
||||
from UserDict import DictMixin
|
||||
|
||||
try:
|
||||
from collections import OrderedDict
|
||||
except ImportError:
|
||||
@@ -28,6 +26,7 @@ except ImportError:
|
||||
# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
|
||||
# OTHER DEALINGS IN THE SOFTWARE.
|
||||
from UserDict import DictMixin
|
||||
|
||||
class OrderedDict(dict, DictMixin): # noqa
|
||||
""" A dictionary which maintains the insertion order of keys. """
|
||||
@@ -80,9 +79,9 @@ except ImportError:
|
||||
if not self:
|
||||
raise KeyError('dictionary is empty')
|
||||
if last:
|
||||
key = reversed(self).next()
|
||||
key = next(reversed(self))
|
||||
else:
|
||||
key = iter(self).next()
|
||||
key = next(iter(self))
|
||||
value = self.pop(key)
|
||||
return key, value
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
blist
|
||||
futures
|
||||
scales
|
||||
six >=1.6
|
||||
|
||||
26
setup.py
26
setup.py
@@ -12,16 +12,15 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
import sys
|
||||
|
||||
import ez_setup
|
||||
ez_setup.use_setuptools()
|
||||
|
||||
run_gevent_nosetests = False
|
||||
if __name__ == '__main__' and sys.argv[1] == "gevent_nosetests":
|
||||
from gevent.monkey import patch_all
|
||||
patch_all()
|
||||
run_gevent_nosetests = True
|
||||
|
||||
from setuptools import setup
|
||||
from distutils.command.build_ext import build_ext
|
||||
@@ -48,10 +47,11 @@ with open("README.rst") as f:
|
||||
long_description = f.read()
|
||||
|
||||
|
||||
gevent_nosetests = None
|
||||
if run_gevent_nosetests:
|
||||
try:
|
||||
from nose.commands import nosetests
|
||||
|
||||
except ImportError:
|
||||
gevent_nosetests = None
|
||||
else:
|
||||
class gevent_nosetests(nosetests):
|
||||
description = "run nosetests with gevent monkey patching"
|
||||
|
||||
@@ -92,11 +92,11 @@ class DocCommand(Command):
|
||||
except subprocess.CalledProcessError as exc:
|
||||
raise RuntimeError("Documentation step '%s' failed: %s: %s" % (mode, exc, exc.output))
|
||||
else:
|
||||
print output
|
||||
print(output)
|
||||
|
||||
print ""
|
||||
print "Documentation step '%s' performed, results here:" % mode
|
||||
print " %s/" % path
|
||||
print("")
|
||||
print("Documentation step '%s' performed, results here:" % mode)
|
||||
print(" %s/" % path)
|
||||
|
||||
|
||||
class BuildFailed(Exception):
|
||||
@@ -181,7 +181,7 @@ def run_setup(extensions):
|
||||
kw['cmdclass']['build_ext'] = build_extensions
|
||||
kw['ext_modules'] = extensions
|
||||
|
||||
dependencies = ['futures', 'scales', 'blist']
|
||||
dependencies = ['futures', 'scales >=1.0.5', 'blist', 'six >=1.6']
|
||||
if platform.python_implementation() != "CPython":
|
||||
dependencies.remove('blist')
|
||||
|
||||
@@ -196,7 +196,7 @@ def run_setup(extensions):
|
||||
packages=['cassandra', 'cassandra.io'],
|
||||
include_package_data=True,
|
||||
install_requires=dependencies,
|
||||
tests_require=['nose', 'mock', 'ccm', 'unittest2', 'PyYAML', 'pytz'],
|
||||
tests_require=['nose', 'mock', 'PyYAML', 'pytz'],
|
||||
classifiers=[
|
||||
'Development Status :: 5 - Production/Stable',
|
||||
'Intended Audience :: Developers',
|
||||
@@ -207,6 +207,10 @@ def run_setup(extensions):
|
||||
'Programming Language :: Python :: 2',
|
||||
'Programming Language :: Python :: 2.6',
|
||||
'Programming Language :: Python :: 2.7',
|
||||
'Programming Language :: Python :: 3',
|
||||
'Programming Language :: Python :: 3.3',
|
||||
'Programming Language :: Python :: Implementation :: CPython',
|
||||
'Programming Language :: Python :: Implementation :: PyPy',
|
||||
'Topic :: Software Development :: Libraries :: Python Modules'
|
||||
],
|
||||
**kw)
|
||||
|
||||
@@ -22,7 +22,9 @@ except ImportError:
|
||||
|
||||
import logging
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
import os
|
||||
from six import print_
|
||||
from threading import Event
|
||||
|
||||
from cassandra.cluster import Cluster
|
||||
@@ -92,7 +94,7 @@ def get_node(node_id):
|
||||
|
||||
|
||||
def setup_package():
|
||||
print 'Using Cassandra version: %s' % CASSANDRA_VERSION
|
||||
print_('Using Cassandra version: %s' % CASSANDRA_VERSION)
|
||||
try:
|
||||
try:
|
||||
cluster = CCMCluster.load(path, CLUSTER_NAME)
|
||||
|
||||
@@ -12,7 +12,11 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import Queue
|
||||
try:
|
||||
from Queue import Queue, Empty
|
||||
except ImportError:
|
||||
from queue import Queue, Empty # noqa
|
||||
|
||||
from struct import pack
|
||||
import unittest
|
||||
|
||||
@@ -31,13 +35,12 @@ def create_column_name(i):
|
||||
column_name = ''
|
||||
while True:
|
||||
column_name += letters[i % 10]
|
||||
i /= 10
|
||||
i = i // 10
|
||||
if not i:
|
||||
break
|
||||
|
||||
if column_name == 'if':
|
||||
column_name = 'special_case'
|
||||
|
||||
return column_name
|
||||
|
||||
|
||||
@@ -56,15 +59,15 @@ class LargeDataTests(unittest.TestCase):
|
||||
return session
|
||||
|
||||
def batch_futures(self, session, statement_generator):
|
||||
concurrency = 50
|
||||
futures = Queue.Queue(maxsize=concurrency)
|
||||
concurrency = 10
|
||||
futures = Queue(maxsize=concurrency)
|
||||
for i, statement in enumerate(statement_generator):
|
||||
if i > 0 and i % (concurrency - 1) == 0:
|
||||
# clear the existing queue
|
||||
while True:
|
||||
try:
|
||||
futures.get_nowait().result()
|
||||
except Queue.Empty:
|
||||
except Empty:
|
||||
break
|
||||
|
||||
future = session.execute_async(statement)
|
||||
@@ -73,7 +76,7 @@ class LargeDataTests(unittest.TestCase):
|
||||
while True:
|
||||
try:
|
||||
futures.get_nowait().result()
|
||||
except Queue.Empty:
|
||||
except Empty:
|
||||
break
|
||||
|
||||
def test_wide_rows(self):
|
||||
@@ -81,15 +84,13 @@ class LargeDataTests(unittest.TestCase):
|
||||
session = self.make_session_and_keyspace()
|
||||
session.execute('CREATE TABLE %s (k INT, i INT, PRIMARY KEY(k, i))' % table)
|
||||
|
||||
prepared = session.prepare('INSERT INTO %s (k, i) VALUES (0, ?)' % (table, ))
|
||||
|
||||
# Write via async futures
|
||||
self.batch_futures(
|
||||
session,
|
||||
(SimpleStatement('INSERT INTO %s (k, i) VALUES (0, %s)' % (table, i),
|
||||
consistency_level=ConsistencyLevel.QUORUM)
|
||||
for i in range(100000)))
|
||||
self.batch_futures(session, (prepared.bind((i, )) for i in range(100000)))
|
||||
|
||||
# Read
|
||||
results = session.execute('SELECT i FROM %s WHERE k=%s' % (table, 0))
|
||||
results = session.execute('SELECT i FROM %s WHERE k=0' % (table, ))
|
||||
|
||||
# Verify
|
||||
for i, row in enumerate(results):
|
||||
@@ -120,18 +121,13 @@ class LargeDataTests(unittest.TestCase):
|
||||
session = self.make_session_and_keyspace()
|
||||
session.execute('CREATE TABLE %s (k INT, i INT, v BLOB, PRIMARY KEY(k, i))' % table)
|
||||
|
||||
# Build small ByteBuffer sample
|
||||
bb = '0xCAFE'
|
||||
prepared = session.prepare('INSERT INTO %s (k, i, v) VALUES (0, ?, 0xCAFE)' % (table, ))
|
||||
|
||||
# Write
|
||||
self.batch_futures(
|
||||
session,
|
||||
(SimpleStatement('INSERT INTO %s (k, i, v) VALUES (0, %s, %s)' % (table, i, str(bb)),
|
||||
consistency_level=ConsistencyLevel.QUORUM)
|
||||
for i in range(100000)))
|
||||
self.batch_futures(session, (prepared.bind((i, )) for i in range(100000)))
|
||||
|
||||
# Read
|
||||
results = session.execute('SELECT i, v FROM %s WHERE k=%s' % (table, 0))
|
||||
results = session.execute('SELECT i, v FROM %s WHERE k=0' % (table, ))
|
||||
|
||||
# Verify
|
||||
bb = pack('>H', 0xCAFE)
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
import logging
|
||||
import time
|
||||
|
||||
@@ -125,7 +126,7 @@ def bootstrap(node, data_center=None, token=None):
|
||||
|
||||
|
||||
def ring(node):
|
||||
print 'From node%s:' % node
|
||||
print('From node%s:' % node)
|
||||
get_node(node).nodetool('ring')
|
||||
|
||||
|
||||
@@ -154,3 +155,5 @@ def wait_for_down(cluster, node, wait=True):
|
||||
time.sleep(10)
|
||||
log.debug("Done waiting for node %s to be down", node)
|
||||
return
|
||||
else:
|
||||
log.debug("Host is still marked up, waiting")
|
||||
|
||||
@@ -40,7 +40,7 @@ class ClusterTests(unittest.TestCase):
|
||||
CREATE KEYSPACE clustertests
|
||||
WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'}
|
||||
""")
|
||||
self.assertEquals(None, result)
|
||||
self.assertEqual(None, result)
|
||||
|
||||
result = session.execute(
|
||||
"""
|
||||
@@ -51,16 +51,16 @@ class ClusterTests(unittest.TestCase):
|
||||
PRIMARY KEY (a, b)
|
||||
)
|
||||
""")
|
||||
self.assertEquals(None, result)
|
||||
self.assertEqual(None, result)
|
||||
|
||||
result = session.execute(
|
||||
"""
|
||||
INSERT INTO clustertests.cf0 (a, b, c) VALUES ('a', 'b', 'c')
|
||||
""")
|
||||
self.assertEquals(None, result)
|
||||
self.assertEqual(None, result)
|
||||
|
||||
result = session.execute("SELECT * FROM clustertests.cf0")
|
||||
self.assertEquals([('a', 'b', 'c')], result)
|
||||
self.assertEqual([('a', 'b', 'c')], result)
|
||||
|
||||
cluster.shutdown()
|
||||
|
||||
@@ -75,15 +75,15 @@ class ClusterTests(unittest.TestCase):
|
||||
"""
|
||||
INSERT INTO test3rf.test (k, v) VALUES (8889, 8889)
|
||||
""")
|
||||
self.assertEquals(None, result)
|
||||
self.assertEqual(None, result)
|
||||
|
||||
result = session.execute("SELECT * FROM test3rf.test")
|
||||
self.assertEquals([(8889, 8889)], result)
|
||||
self.assertEqual([(8889, 8889)], result)
|
||||
|
||||
# test_connect_on_keyspace
|
||||
session2 = cluster.connect('test3rf')
|
||||
result2 = session2.execute("SELECT * FROM test")
|
||||
self.assertEquals(result, result2)
|
||||
self.assertEqual(result, result2)
|
||||
|
||||
def test_set_keyspace_twice(self):
|
||||
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||
|
||||
@@ -43,7 +43,7 @@ class ClusterTests(unittest.TestCase):
|
||||
statements = cycle(("INSERT INTO test3rf.test (k, v) VALUES (%s, %s)", ))
|
||||
parameters = [(i, i) for i in range(num_statements)]
|
||||
|
||||
results = execute_concurrent(self.session, zip(statements, parameters))
|
||||
results = execute_concurrent(self.session, list(zip(statements, parameters)))
|
||||
self.assertEqual(num_statements, len(results))
|
||||
self.assertEqual([(True, None)] * num_statements, results)
|
||||
|
||||
@@ -51,7 +51,7 @@ class ClusterTests(unittest.TestCase):
|
||||
statements = cycle(("SELECT v FROM test3rf.test WHERE k=%s", ))
|
||||
parameters = [(i, ) for i in range(num_statements)]
|
||||
|
||||
results = execute_concurrent(self.session, zip(statements, parameters))
|
||||
results = execute_concurrent(self.session, list(zip(statements, parameters)))
|
||||
self.assertEqual(num_statements, len(results))
|
||||
self.assertEqual([(True, [(i,)]) for i in range(num_statements)], results)
|
||||
|
||||
@@ -81,7 +81,7 @@ class ClusterTests(unittest.TestCase):
|
||||
|
||||
self.assertRaises(
|
||||
InvalidRequest,
|
||||
execute_concurrent, self.session, zip(statements, parameters), raise_on_first_error=True)
|
||||
execute_concurrent, self.session, list(zip(statements, parameters)), raise_on_first_error=True)
|
||||
|
||||
def test_first_failure_client_side(self):
|
||||
statements = cycle(("INSERT INTO test3rf.test (k, v) VALUES (%s, %s)", ))
|
||||
@@ -92,7 +92,7 @@ class ClusterTests(unittest.TestCase):
|
||||
|
||||
self.assertRaises(
|
||||
TypeError,
|
||||
execute_concurrent, self.session, zip(statements, parameters), raise_on_first_error=True)
|
||||
execute_concurrent, self.session, list(zip(statements, parameters)), raise_on_first_error=True)
|
||||
|
||||
def test_no_raise_on_first_failure(self):
|
||||
statements = cycle(("INSERT INTO test3rf.test (k, v) VALUES (%s, %s)", ))
|
||||
@@ -101,7 +101,7 @@ class ClusterTests(unittest.TestCase):
|
||||
# we'll get an error back from the server
|
||||
parameters[57] = ('efefef', 'awefawefawef')
|
||||
|
||||
results = execute_concurrent(self.session, zip(statements, parameters), raise_on_first_error=False)
|
||||
results = execute_concurrent(self.session, list(zip(statements, parameters)), raise_on_first_error=False)
|
||||
for i, (success, result) in enumerate(results):
|
||||
if i == 57:
|
||||
self.assertFalse(success)
|
||||
@@ -115,9 +115,9 @@ class ClusterTests(unittest.TestCase):
|
||||
parameters = [(i, i) for i in range(100)]
|
||||
|
||||
# the driver will raise an error when binding the params
|
||||
parameters[57] = i
|
||||
parameters[57] = 1
|
||||
|
||||
results = execute_concurrent(self.session, zip(statements, parameters), raise_on_first_error=False)
|
||||
results = execute_concurrent(self.session, list(zip(statements, parameters)), raise_on_first_error=False)
|
||||
for i, (success, result) in enumerate(results):
|
||||
if i == 57:
|
||||
self.assertFalse(success)
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import six
|
||||
|
||||
try:
|
||||
import unittest2 as unittest
|
||||
@@ -115,7 +116,7 @@ class SchemaMetadataTest(unittest.TestCase):
|
||||
|
||||
def check_create_statement(self, tablemeta, original):
|
||||
recreate = tablemeta.as_cql_query(formatted=False)
|
||||
self.assertEquals(original, recreate[:len(original)])
|
||||
self.assertEqual(original, recreate[:len(original)])
|
||||
self.session.execute("DROP TABLE %s.%s" % (self.ksname, self.cfname))
|
||||
self.session.execute(recreate)
|
||||
|
||||
@@ -289,7 +290,7 @@ class SchemaMetadataTest(unittest.TestCase):
|
||||
tablemeta = self.get_table_metadata()
|
||||
statements = tablemeta.export_as_string().strip()
|
||||
statements = [s.strip() for s in statements.split(';')]
|
||||
statements = filter(bool, statements)
|
||||
statements = list(filter(bool, statements))
|
||||
self.assertEqual(3, len(statements))
|
||||
self.assertEqual(d_index, statements[1])
|
||||
self.assertEqual(e_index, statements[2])
|
||||
@@ -311,7 +312,7 @@ class TestCodeCoverage(unittest.TestCase):
|
||||
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||
cluster.connect()
|
||||
|
||||
self.assertIsInstance(cluster.metadata.export_schema_as_string(), basestring)
|
||||
self.assertIsInstance(cluster.metadata.export_schema_as_string(), six.string_types)
|
||||
|
||||
def test_export_keyspace_schema(self):
|
||||
"""
|
||||
@@ -323,8 +324,8 @@ class TestCodeCoverage(unittest.TestCase):
|
||||
|
||||
for keyspace in cluster.metadata.keyspaces:
|
||||
keyspace_metadata = cluster.metadata.keyspaces[keyspace]
|
||||
self.assertIsInstance(keyspace_metadata.export_as_string(), basestring)
|
||||
self.assertIsInstance(keyspace_metadata.as_cql_query(), basestring)
|
||||
self.assertIsInstance(keyspace_metadata.export_as_string(), six.string_types)
|
||||
self.assertIsInstance(keyspace_metadata.as_cql_query(), six.string_types)
|
||||
|
||||
def test_case_sensitivity(self):
|
||||
"""
|
||||
|
||||
@@ -69,7 +69,7 @@ class PreparedStatementTests(unittest.TestCase):
|
||||
|
||||
bound = prepared.bind(('a'))
|
||||
results = session.execute(bound)
|
||||
self.assertEquals(results, [('a', 'b', 'c')])
|
||||
self.assertEqual(results, [('a', 'b', 'c')])
|
||||
|
||||
# test with new dict binding
|
||||
prepared = session.prepare(
|
||||
@@ -95,7 +95,7 @@ class PreparedStatementTests(unittest.TestCase):
|
||||
|
||||
bound = prepared.bind({'a': 'x'})
|
||||
results = session.execute(bound)
|
||||
self.assertEquals(results, [('x', 'y', 'z')])
|
||||
self.assertEqual(results, [('x', 'y', 'z')])
|
||||
|
||||
def test_missing_primary_key(self):
|
||||
"""
|
||||
@@ -148,7 +148,7 @@ class PreparedStatementTests(unittest.TestCase):
|
||||
""")
|
||||
|
||||
self.assertIsInstance(prepared, PreparedStatement)
|
||||
self.assertRaises(ValueError, prepared.bind, (1,2))
|
||||
self.assertRaises(ValueError, prepared.bind, (1, 2))
|
||||
|
||||
def test_too_many_bind_values_dicts(self):
|
||||
"""
|
||||
@@ -196,7 +196,7 @@ class PreparedStatementTests(unittest.TestCase):
|
||||
|
||||
bound = prepared.bind((1,))
|
||||
results = session.execute(bound)
|
||||
self.assertEquals(results[0].v, None)
|
||||
self.assertEqual(results[0].v, None)
|
||||
|
||||
def test_none_values_dicts(self):
|
||||
"""
|
||||
@@ -206,7 +206,6 @@ class PreparedStatementTests(unittest.TestCase):
|
||||
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||
session = cluster.connect()
|
||||
|
||||
|
||||
# test with new dict binding
|
||||
prepared = session.prepare(
|
||||
"""
|
||||
@@ -225,7 +224,7 @@ class PreparedStatementTests(unittest.TestCase):
|
||||
|
||||
bound = prepared.bind({'k': 1})
|
||||
results = session.execute(bound)
|
||||
self.assertEquals(results[0].v, None)
|
||||
self.assertEqual(results[0].v, None)
|
||||
|
||||
def test_async_binding(self):
|
||||
"""
|
||||
@@ -252,8 +251,7 @@ class PreparedStatementTests(unittest.TestCase):
|
||||
|
||||
future = session.execute_async(prepared, (873,))
|
||||
results = future.result()
|
||||
self.assertEquals(results[0].v, None)
|
||||
|
||||
self.assertEqual(results[0].v, None)
|
||||
|
||||
def test_async_binding_dicts(self):
|
||||
"""
|
||||
@@ -280,4 +278,4 @@ class PreparedStatementTests(unittest.TestCase):
|
||||
|
||||
future = session.execute_async(prepared, {'k': 873})
|
||||
results = future.result()
|
||||
self.assertEquals(results[0].v, None)
|
||||
self.assertEqual(results[0].v, None)
|
||||
|
||||
@@ -24,7 +24,7 @@ from cassandra.query import (PreparedStatement, BoundStatement, ValueSequence,
|
||||
from cassandra.cluster import Cluster
|
||||
from cassandra.policies import HostDistance
|
||||
|
||||
from tests.integration import get_server_versions, PROTOCOL_VERSION
|
||||
from tests.integration import PROTOCOL_VERSION
|
||||
|
||||
|
||||
class QueryTest(unittest.TestCase):
|
||||
@@ -43,7 +43,7 @@ class QueryTest(unittest.TestCase):
|
||||
self.assertIsInstance(bound, BoundStatement)
|
||||
self.assertEqual(2, len(bound.values))
|
||||
session.execute(bound)
|
||||
self.assertEqual(bound.routing_key, '\x00\x00\x00\x01')
|
||||
self.assertEqual(bound.routing_key, b'\x00\x00\x00\x01')
|
||||
|
||||
def test_value_sequence(self):
|
||||
"""
|
||||
@@ -102,7 +102,7 @@ class PreparedStatementTests(unittest.TestCase):
|
||||
|
||||
self.assertIsInstance(prepared, PreparedStatement)
|
||||
bound = prepared.bind((1, None))
|
||||
self.assertEqual(bound.routing_key, '\x00\x00\x00\x01')
|
||||
self.assertEqual(bound.routing_key, b'\x00\x00\x00\x01')
|
||||
|
||||
def test_empty_routing_key_indexes(self):
|
||||
"""
|
||||
@@ -158,7 +158,7 @@ class PreparedStatementTests(unittest.TestCase):
|
||||
|
||||
self.assertIsInstance(prepared, PreparedStatement)
|
||||
bound = prepared.bind((1, 2))
|
||||
self.assertEqual(bound.routing_key, '\x04\x00\x00\x00\x04\x00\x00\x00')
|
||||
self.assertEqual(bound.routing_key, b'\x04\x00\x00\x00\x04\x00\x00\x00')
|
||||
|
||||
def test_bound_keyspace(self):
|
||||
"""
|
||||
@@ -326,14 +326,14 @@ class SerialConsistencyTests(unittest.TestCase):
|
||||
"UPDATE test3rf.test SET v=1 WHERE k=0 IF v=1",
|
||||
serial_consistency_level=ConsistencyLevel.SERIAL)
|
||||
result = self.session.execute(statement)
|
||||
self.assertEquals(1, len(result))
|
||||
self.assertEqual(1, len(result))
|
||||
self.assertFalse(result[0].applied)
|
||||
|
||||
statement = SimpleStatement(
|
||||
"UPDATE test3rf.test SET v=1 WHERE k=0 IF v=0",
|
||||
serial_consistency_level=ConsistencyLevel.SERIAL)
|
||||
result = self.session.execute(statement)
|
||||
self.assertEquals(1, len(result))
|
||||
self.assertEqual(1, len(result))
|
||||
self.assertTrue(result[0].applied)
|
||||
|
||||
def test_conditional_update_with_prepared_statements(self):
|
||||
@@ -343,7 +343,7 @@ class SerialConsistencyTests(unittest.TestCase):
|
||||
|
||||
statement.serial_consistency_level = ConsistencyLevel.SERIAL
|
||||
result = self.session.execute(statement)
|
||||
self.assertEquals(1, len(result))
|
||||
self.assertEqual(1, len(result))
|
||||
self.assertFalse(result[0].applied)
|
||||
|
||||
statement = self.session.prepare(
|
||||
@@ -351,7 +351,7 @@ class SerialConsistencyTests(unittest.TestCase):
|
||||
bound = statement.bind(())
|
||||
bound.serial_consistency_level = ConsistencyLevel.SERIAL
|
||||
result = self.session.execute(statement)
|
||||
self.assertEquals(1, len(result))
|
||||
self.assertEqual(1, len(result))
|
||||
self.assertTrue(result[0].applied)
|
||||
|
||||
def test_bad_consistency_level(self):
|
||||
|
||||
@@ -23,6 +23,7 @@ except ImportError:
|
||||
import unittest # noqa
|
||||
|
||||
from itertools import cycle, count
|
||||
from six.moves import range
|
||||
from threading import Event
|
||||
|
||||
from cassandra.cluster import Cluster
|
||||
@@ -47,7 +48,7 @@ class QueryPagingTests(unittest.TestCase):
|
||||
def test_paging(self):
|
||||
statements_and_params = zip(cycle(["INSERT INTO test3rf.test (k, v) VALUES (%s, 0)"]),
|
||||
[(i, ) for i in range(100)])
|
||||
execute_concurrent(self.session, statements_and_params)
|
||||
execute_concurrent(self.session, list(statements_and_params))
|
||||
|
||||
prepared = self.session.prepare("SELECT * FROM test3rf.test")
|
||||
|
||||
@@ -153,7 +154,7 @@ class QueryPagingTests(unittest.TestCase):
|
||||
def test_async_paging(self):
|
||||
statements_and_params = zip(cycle(["INSERT INTO test3rf.test (k, v) VALUES (%s, 0)"]),
|
||||
[(i, ) for i in range(100)])
|
||||
execute_concurrent(self.session, statements_and_params)
|
||||
execute_concurrent(self.session, list(statements_and_params))
|
||||
|
||||
prepared = self.session.prepare("SELECT * FROM test3rf.test")
|
||||
|
||||
@@ -219,7 +220,7 @@ class QueryPagingTests(unittest.TestCase):
|
||||
def test_paging_callbacks(self):
|
||||
statements_and_params = zip(cycle(["INSERT INTO test3rf.test (k, v) VALUES (%s, 0)"]),
|
||||
[(i, ) for i in range(100)])
|
||||
execute_concurrent(self.session, statements_and_params)
|
||||
execute_concurrent(self.session, list(statements_and_params))
|
||||
|
||||
prepared = self.session.prepare("SELECT * FROM test3rf.test")
|
||||
|
||||
@@ -232,7 +233,7 @@ class QueryPagingTests(unittest.TestCase):
|
||||
|
||||
def handle_page(rows, future, counter):
|
||||
for row in rows:
|
||||
counter.next()
|
||||
next(counter)
|
||||
|
||||
if future.has_more_pages:
|
||||
future.start_fetching_next_page()
|
||||
@@ -245,7 +246,7 @@ class QueryPagingTests(unittest.TestCase):
|
||||
|
||||
future.add_callbacks(callback=handle_page, callback_args=(future, counter), errback=handle_error)
|
||||
event.wait()
|
||||
self.assertEquals(counter.next(), 100)
|
||||
self.assertEquals(next(counter), 100)
|
||||
|
||||
# simple statement
|
||||
future = self.session.execute_async(SimpleStatement("SELECT * FROM test3rf.test"))
|
||||
@@ -254,7 +255,7 @@ class QueryPagingTests(unittest.TestCase):
|
||||
|
||||
future.add_callbacks(callback=handle_page, callback_args=(future, counter), errback=handle_error)
|
||||
event.wait()
|
||||
self.assertEquals(counter.next(), 100)
|
||||
self.assertEquals(next(counter), 100)
|
||||
|
||||
# prepared statement
|
||||
future = self.session.execute_async(prepared)
|
||||
@@ -263,4 +264,4 @@ class QueryPagingTests(unittest.TestCase):
|
||||
|
||||
future.add_callbacks(callback=handle_page, callback_args=(future, counter), errback=handle_error)
|
||||
event.wait()
|
||||
self.assertEquals(counter.next(), 100)
|
||||
self.assertEquals(next(counter), 100)
|
||||
|
||||
@@ -17,8 +17,12 @@ try:
|
||||
except ImportError:
|
||||
import unittest # noqa
|
||||
|
||||
import logging
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
from decimal import Decimal
|
||||
from datetime import datetime
|
||||
import six
|
||||
from uuid import uuid1, uuid4
|
||||
|
||||
try:
|
||||
@@ -59,12 +63,16 @@ class TypeTests(unittest.TestCase):
|
||||
|
||||
params = [
|
||||
'key1',
|
||||
'blobyblob'.encode('hex')
|
||||
b'blobyblob'
|
||||
]
|
||||
|
||||
query = 'INSERT INTO mytable (a, b) VALUES (%s, %s)'
|
||||
|
||||
if self._cql_version >= (3, 1, 0):
|
||||
# In python 3, the 'bytes' type is treated as a blob, so we can
|
||||
# correctly encode it with hex notation.
|
||||
# In python2, we don't treat the 'str' type as a blob, so we'll encode it
|
||||
# as a string literal and have the following failure.
|
||||
if six.PY2 and self._cql_version >= (3, 1, 0):
|
||||
# Blob values can't be specified using string notation in CQL 3.1.0 and
|
||||
# above which is used by default in Cassandra 2.0.
|
||||
msg = r'.*Invalid STRING constant \(.*?\) for b of type blob.*'
|
||||
@@ -74,13 +82,13 @@ class TypeTests(unittest.TestCase):
|
||||
s.execute(query, params)
|
||||
expected_vals = [
|
||||
'key1',
|
||||
'blobyblob'
|
||||
bytearray(b'blobyblob')
|
||||
]
|
||||
|
||||
results = s.execute("SELECT * FROM mytable")
|
||||
|
||||
for expected, actual in zip(expected_vals, results[0]):
|
||||
self.assertEquals(expected, actual)
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
def test_blob_type_as_bytearray(self):
|
||||
c = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||
@@ -101,7 +109,7 @@ class TypeTests(unittest.TestCase):
|
||||
|
||||
params = [
|
||||
'key1',
|
||||
bytearray('blob1', 'hex')
|
||||
bytearray(b'blob1')
|
||||
]
|
||||
|
||||
query = 'INSERT INTO mytable (a, b) VALUES (%s, %s);'
|
||||
@@ -109,13 +117,13 @@ class TypeTests(unittest.TestCase):
|
||||
|
||||
expected_vals = [
|
||||
'key1',
|
||||
bytearray('blob1', 'hex')
|
||||
bytearray(b'blob1')
|
||||
]
|
||||
|
||||
results = s.execute("SELECT * FROM mytable")
|
||||
|
||||
for expected, actual in zip(expected_vals, results[0]):
|
||||
self.assertEquals(expected, actual)
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
create_type_table = """
|
||||
CREATE TABLE mytable (
|
||||
@@ -208,7 +216,7 @@ class TypeTests(unittest.TestCase):
|
||||
results = s.execute("SELECT * FROM mytable")
|
||||
|
||||
for expected, actual in zip(expected_vals, results[0]):
|
||||
self.assertEquals(expected, actual)
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
# try the same thing with a prepared statement
|
||||
prepared = s.prepare("""
|
||||
@@ -221,7 +229,7 @@ class TypeTests(unittest.TestCase):
|
||||
results = s.execute("SELECT * FROM mytable")
|
||||
|
||||
for expected, actual in zip(expected_vals, results[0]):
|
||||
self.assertEquals(expected, actual)
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
# query with prepared statement
|
||||
prepared = s.prepare("""
|
||||
@@ -230,14 +238,14 @@ class TypeTests(unittest.TestCase):
|
||||
results = s.execute(prepared.bind(()))
|
||||
|
||||
for expected, actual in zip(expected_vals, results[0]):
|
||||
self.assertEquals(expected, actual)
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
# query with prepared statement, no explicit columns
|
||||
prepared = s.prepare("""SELECT * FROM mytable""")
|
||||
results = s.execute(prepared.bind(()))
|
||||
|
||||
for expected, actual in zip(expected_vals, results[0]):
|
||||
self.assertEquals(expected, actual)
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
def test_empty_strings_and_nones(self):
|
||||
c = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||
@@ -265,11 +273,11 @@ class TypeTests(unittest.TestCase):
|
||||
# insert empty strings for string-like fields and fetch them
|
||||
s.execute("INSERT INTO mytable (a, b, c, o, s, l, n) VALUES ('a', 'b', %s, %s, %s, %s, %s)",
|
||||
('', '', '', [''], {'': 3}))
|
||||
self.assertEquals(
|
||||
self.assertEqual(
|
||||
{'c': '', 'o': '', 's': '', 'l': ('', ), 'n': OrderedDict({'': 3})},
|
||||
s.execute("SELECT c, o, s, l, n FROM mytable WHERE a='a' AND b='b'")[0])
|
||||
|
||||
self.assertEquals(
|
||||
self.assertEqual(
|
||||
{'c': '', 'o': '', 's': '', 'l': ('', ), 'n': OrderedDict({'': 3})},
|
||||
s.execute(s.prepare("SELECT c, o, s, l, n FROM mytable WHERE a='a' AND b='b'"), [])[0])
|
||||
|
||||
@@ -363,7 +371,7 @@ class TypeTests(unittest.TestCase):
|
||||
""" Ensure timezone-aware datetimes are converted to timestamps correctly """
|
||||
try:
|
||||
import pytz
|
||||
except ImportError, exc:
|
||||
except ImportError as exc:
|
||||
raise unittest.SkipTest('pytz is not available: %r' % (exc,))
|
||||
|
||||
dt = datetime(1997, 8, 29, 11, 14)
|
||||
@@ -381,10 +389,10 @@ class TypeTests(unittest.TestCase):
|
||||
# test non-prepared statement
|
||||
s.execute("INSERT INTO mytable (a, b) VALUES ('key1', %s)", parameters=(dt,))
|
||||
result = s.execute("SELECT b FROM mytable WHERE a='key1'")[0].b
|
||||
self.assertEquals(dt.utctimetuple(), result.utctimetuple())
|
||||
self.assertEqual(dt.utctimetuple(), result.utctimetuple())
|
||||
|
||||
# test prepared statement
|
||||
prepared = s.prepare("INSERT INTO mytable (a, b) VALUES ('key2', ?)")
|
||||
s.execute(prepared, parameters=(dt,))
|
||||
result = s.execute("SELECT b FROM mytable WHERE a='key2'")[0].b
|
||||
self.assertEquals(dt.utctimetuple(), result.utctimetuple())
|
||||
self.assertEqual(dt.utctimetuple(), result.utctimetuple())
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import six
|
||||
|
||||
try:
|
||||
import unittest2 as unittest
|
||||
@@ -19,7 +20,9 @@ except ImportError:
|
||||
|
||||
import errno
|
||||
import os
|
||||
from StringIO import StringIO
|
||||
|
||||
from six import BytesIO
|
||||
|
||||
import socket
|
||||
from socket import error as socket_error
|
||||
|
||||
@@ -55,7 +58,7 @@ class AsyncoreConnectionTest(unittest.TestCase):
|
||||
return c
|
||||
|
||||
def make_header_prefix(self, message_class, version=2, stream_id=0):
|
||||
return ''.join(map(uint8_pack, [
|
||||
return six.binary_type().join(map(uint8_pack, [
|
||||
0xff & (HEADER_DIRECTION_TO_CLIENT | version),
|
||||
0, # flags (compression)
|
||||
stream_id,
|
||||
@@ -63,7 +66,7 @@ class AsyncoreConnectionTest(unittest.TestCase):
|
||||
]))
|
||||
|
||||
def make_options_body(self):
|
||||
options_buf = StringIO()
|
||||
options_buf = BytesIO()
|
||||
write_stringmultimap(options_buf, {
|
||||
'CQL_VERSION': ['3.0.1'],
|
||||
'COMPRESSION': []
|
||||
@@ -71,12 +74,12 @@ class AsyncoreConnectionTest(unittest.TestCase):
|
||||
return options_buf.getvalue()
|
||||
|
||||
def make_error_body(self, code, msg):
|
||||
buf = StringIO()
|
||||
buf = BytesIO()
|
||||
write_int(buf, code)
|
||||
write_string(buf, msg)
|
||||
return buf.getvalue()
|
||||
|
||||
def make_msg(self, header, body=""):
|
||||
def make_msg(self, header, body=six.binary_type()):
|
||||
return header + uint32_pack(len(body)) + body
|
||||
|
||||
def test_successful_connection(self, *args):
|
||||
@@ -105,12 +108,12 @@ class AsyncoreConnectionTest(unittest.TestCase):
|
||||
# get a connection that's already fully started
|
||||
c = self.test_successful_connection()
|
||||
|
||||
header = '\x00\x00\x00\x00' + int32_pack(20000)
|
||||
header = six.b('\x00\x00\x00\x00') + int32_pack(20000)
|
||||
responses = [
|
||||
header + ('a' * (4096 - len(header))),
|
||||
'a' * 4096,
|
||||
header + (six.b('a') * (4096 - len(header))),
|
||||
six.b('a') * 4096,
|
||||
socket_error(errno.EAGAIN),
|
||||
'a' * 100,
|
||||
six.b('a') * 100,
|
||||
socket_error(errno.EAGAIN)]
|
||||
|
||||
def side_effect(*args):
|
||||
@@ -122,17 +125,17 @@ class AsyncoreConnectionTest(unittest.TestCase):
|
||||
|
||||
c.socket.recv.side_effect = side_effect
|
||||
c.handle_read()
|
||||
self.assertEquals(c._total_reqd_bytes, 20000 + len(header))
|
||||
self.assertEqual(c._total_reqd_bytes, 20000 + len(header))
|
||||
# the EAGAIN prevents it from reading the last 100 bytes
|
||||
c._iobuf.seek(0, os.SEEK_END)
|
||||
pos = c._iobuf.tell()
|
||||
self.assertEquals(pos, 4096 + 4096)
|
||||
self.assertEqual(pos, 4096 + 4096)
|
||||
|
||||
# now tell it to read the last 100 bytes
|
||||
c.handle_read()
|
||||
c._iobuf.seek(0, os.SEEK_END)
|
||||
pos = c._iobuf.tell()
|
||||
self.assertEquals(pos, 4096 + 4096 + 100)
|
||||
self.assertEqual(pos, 4096 + 4096 + 100)
|
||||
|
||||
def test_protocol_error(self, *args):
|
||||
c = self.make_connection()
|
||||
@@ -237,14 +240,13 @@ class AsyncoreConnectionTest(unittest.TestCase):
|
||||
options = self.make_options_body()
|
||||
message = self.make_msg(header, options)
|
||||
|
||||
# read in the first byte
|
||||
c.socket.recv.return_value = message[0]
|
||||
c.socket.recv.return_value = message[0:1]
|
||||
c.handle_read()
|
||||
self.assertEquals(c._iobuf.getvalue(), message[0])
|
||||
self.assertEqual(c._iobuf.getvalue(), message[0:1])
|
||||
|
||||
c.socket.recv.return_value = message[1:]
|
||||
c.handle_read()
|
||||
self.assertEquals("", c._iobuf.getvalue())
|
||||
self.assertEqual(six.binary_type(), c._iobuf.getvalue())
|
||||
|
||||
# let it write out a StartupMessage
|
||||
c.handle_write()
|
||||
@@ -266,12 +268,12 @@ class AsyncoreConnectionTest(unittest.TestCase):
|
||||
# read in the first nine bytes
|
||||
c.socket.recv.return_value = message[:9]
|
||||
c.handle_read()
|
||||
self.assertEquals(c._iobuf.getvalue(), message[:9])
|
||||
self.assertEqual(c._iobuf.getvalue(), message[:9])
|
||||
|
||||
# ... then read in the rest
|
||||
c.socket.recv.return_value = message[9:]
|
||||
c.handle_read()
|
||||
self.assertEquals("", c._iobuf.getvalue())
|
||||
self.assertEqual(six.binary_type(), c._iobuf.getvalue())
|
||||
|
||||
# let it write out a StartupMessage
|
||||
c.handle_write()
|
||||
|
||||
@@ -19,7 +19,9 @@ except ImportError:
|
||||
|
||||
import errno
|
||||
import os
|
||||
from StringIO import StringIO
|
||||
|
||||
from six.moves import StringIO
|
||||
|
||||
from socket import error as socket_error
|
||||
|
||||
from mock import patch, Mock
|
||||
@@ -122,17 +124,17 @@ class LibevConnectionTest(unittest.TestCase):
|
||||
|
||||
c._socket.recv.side_effect = side_effect
|
||||
c.handle_read(None, 0)
|
||||
self.assertEquals(c._total_reqd_bytes, 20000 + len(header))
|
||||
self.assertEqual(c._total_reqd_bytes, 20000 + len(header))
|
||||
# the EAGAIN prevents it from reading the last 100 bytes
|
||||
c._iobuf.seek(0, os.SEEK_END)
|
||||
pos = c._iobuf.tell()
|
||||
self.assertEquals(pos, 4096 + 4096)
|
||||
self.assertEqual(pos, 4096 + 4096)
|
||||
|
||||
# now tell it to read the last 100 bytes
|
||||
c.handle_read(None, 0)
|
||||
c._iobuf.seek(0, os.SEEK_END)
|
||||
pos = c._iobuf.tell()
|
||||
self.assertEquals(pos, 4096 + 4096 + 100)
|
||||
self.assertEqual(pos, 4096 + 4096 + 100)
|
||||
|
||||
def test_protocol_error(self, *args):
|
||||
c = self.make_connection()
|
||||
@@ -240,11 +242,11 @@ class LibevConnectionTest(unittest.TestCase):
|
||||
# read in the first byte
|
||||
c._socket.recv.return_value = message[0]
|
||||
c.handle_read(None, 0)
|
||||
self.assertEquals(c._iobuf.getvalue(), message[0])
|
||||
self.assertEqual(c._iobuf.getvalue(), message[0])
|
||||
|
||||
c._socket.recv.return_value = message[1:]
|
||||
c.handle_read(None, 0)
|
||||
self.assertEquals("", c._iobuf.getvalue())
|
||||
self.assertEqual("", c._iobuf.getvalue())
|
||||
|
||||
# let it write out a StartupMessage
|
||||
c.handle_write(None, 0)
|
||||
@@ -266,12 +268,12 @@ class LibevConnectionTest(unittest.TestCase):
|
||||
# read in the first nine bytes
|
||||
c._socket.recv.return_value = message[:9]
|
||||
c.handle_read(None, 0)
|
||||
self.assertEquals(c._iobuf.getvalue(), message[:9])
|
||||
self.assertEqual(c._iobuf.getvalue(), message[:9])
|
||||
|
||||
# ... then read in the rest
|
||||
c._socket.recv.return_value = message[9:]
|
||||
c.handle_read(None, 0)
|
||||
self.assertEquals("", c._iobuf.getvalue())
|
||||
self.assertEqual("", c._iobuf.getvalue())
|
||||
|
||||
# let it write out a StartupMessage
|
||||
c.handle_write(None, 0)
|
||||
|
||||
@@ -11,17 +11,18 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from cassandra.cluster import Cluster
|
||||
import six
|
||||
|
||||
try:
|
||||
import unittest2 as unittest
|
||||
except ImportError:
|
||||
import unittest # noqa
|
||||
|
||||
from StringIO import StringIO
|
||||
from six import BytesIO
|
||||
|
||||
from mock import Mock, ANY
|
||||
|
||||
from cassandra.cluster import Cluster
|
||||
from cassandra.connection import (Connection, HEADER_DIRECTION_TO_CLIENT,
|
||||
HEADER_DIRECTION_FROM_CLIENT, ProtocolError)
|
||||
from cassandra.decoder import (write_stringmultimap, write_int, write_string,
|
||||
@@ -40,7 +41,7 @@ class ConnectionTest(unittest.TestCase):
|
||||
return c
|
||||
|
||||
def make_header_prefix(self, message_class, version=2, stream_id=0):
|
||||
return ''.join(map(uint8_pack, [
|
||||
return six.binary_type().join(map(uint8_pack, [
|
||||
0xff & (HEADER_DIRECTION_TO_CLIENT | version),
|
||||
0, # flags (compression)
|
||||
stream_id,
|
||||
@@ -48,7 +49,7 @@ class ConnectionTest(unittest.TestCase):
|
||||
]))
|
||||
|
||||
def make_options_body(self):
|
||||
options_buf = StringIO()
|
||||
options_buf = BytesIO()
|
||||
write_stringmultimap(options_buf, {
|
||||
'CQL_VERSION': ['3.0.1'],
|
||||
'COMPRESSION': []
|
||||
@@ -56,7 +57,7 @@ class ConnectionTest(unittest.TestCase):
|
||||
return options_buf.getvalue()
|
||||
|
||||
def make_error_body(self, code, msg):
|
||||
buf = StringIO()
|
||||
buf = BytesIO()
|
||||
write_int(buf, code)
|
||||
write_string(buf, msg)
|
||||
return buf.getvalue()
|
||||
@@ -88,12 +89,12 @@ class ConnectionTest(unittest.TestCase):
|
||||
c.defunct = Mock()
|
||||
|
||||
# read in a SupportedMessage response
|
||||
header = ''.join(map(uint8_pack, [
|
||||
header = six.binary_type().join(uint8_pack(i) for i in (
|
||||
0xff & (HEADER_DIRECTION_FROM_CLIENT | self.protocol_version),
|
||||
0, # flags (compression)
|
||||
0,
|
||||
SupportedMessage.opcode # opcode
|
||||
]))
|
||||
))
|
||||
options = self.make_options_body()
|
||||
message = self.make_msg(header, options)
|
||||
c.process_msg(message, len(message) - 8)
|
||||
@@ -130,7 +131,7 @@ class ConnectionTest(unittest.TestCase):
|
||||
# read in a SupportedMessage response
|
||||
header = self.make_header_prefix(SupportedMessage)
|
||||
|
||||
options_buf = StringIO()
|
||||
options_buf = BytesIO()
|
||||
write_stringmultimap(options_buf, {
|
||||
'CQL_VERSION': ['7.8.9'],
|
||||
'COMPRESSION': []
|
||||
|
||||
@@ -16,7 +16,7 @@ from cassandra.marshal import bitlength
|
||||
try:
|
||||
import unittest2 as unittest
|
||||
except ImportError:
|
||||
import unittest # noqa
|
||||
import unittest # noqa
|
||||
|
||||
import platform
|
||||
from datetime import datetime
|
||||
@@ -33,57 +33,57 @@ from cassandra.util import OrderedDict
|
||||
|
||||
marshalled_value_pairs = (
|
||||
# binary form, type, python native type
|
||||
('lorem ipsum dolor sit amet', 'AsciiType', 'lorem ipsum dolor sit amet'),
|
||||
('', 'AsciiType', ''),
|
||||
('\x01', 'BooleanType', True),
|
||||
('\x00', 'BooleanType', False),
|
||||
('', 'BooleanType', None),
|
||||
('\xff\xfe\xfd\xfc\xfb', 'BytesType', '\xff\xfe\xfd\xfc\xfb'),
|
||||
('', 'BytesType', ''),
|
||||
('\x7f\xff\xff\xff\xff\xff\xff\xff', 'CounterColumnType', 9223372036854775807),
|
||||
('\x80\x00\x00\x00\x00\x00\x00\x00', 'CounterColumnType', -9223372036854775808),
|
||||
('', 'CounterColumnType', None),
|
||||
('\x00\x00\x013\x7fb\xeey', 'DateType', datetime(2011, 11, 7, 18, 55, 49, 881000)),
|
||||
('', 'DateType', None),
|
||||
('\x00\x00\x00\r\nJ\x04"^\x91\x04\x8a\xb1\x18\xfe', 'DecimalType', Decimal('1243878957943.1234124191998')),
|
||||
('\x00\x00\x00\x06\xe5\xde]\x98Y', 'DecimalType', Decimal('-112233.441191')),
|
||||
('\x00\x00\x00\x14\x00\xfa\xce', 'DecimalType', Decimal('0.00000000000000064206')),
|
||||
('\x00\x00\x00\x14\xff\x052', 'DecimalType', Decimal('-0.00000000000000064206')),
|
||||
('\xff\xff\xff\x9c\x00\xfa\xce', 'DecimalType', Decimal('64206e100')),
|
||||
('', 'DecimalType', None),
|
||||
('@\xd2\xfa\x08\x00\x00\x00\x00', 'DoubleType', 19432.125),
|
||||
('\xc0\xd2\xfa\x08\x00\x00\x00\x00', 'DoubleType', -19432.125),
|
||||
('\x7f\xef\x00\x00\x00\x00\x00\x00', 'DoubleType', 1.7415152243978685e+308),
|
||||
('', 'DoubleType', None),
|
||||
('F\x97\xd0@', 'FloatType', 19432.125),
|
||||
('\xc6\x97\xd0@', 'FloatType', -19432.125),
|
||||
('\xc6\x97\xd0@', 'FloatType', -19432.125),
|
||||
('\x7f\x7f\x00\x00', 'FloatType', 338953138925153547590470800371487866880.0),
|
||||
('', 'FloatType', None),
|
||||
('\x7f\x50\x00\x00', 'Int32Type', 2135949312),
|
||||
('\xff\xfd\xcb\x91', 'Int32Type', -144495),
|
||||
('', 'Int32Type', None),
|
||||
('f\x1e\xfd\xf2\xe3\xb1\x9f|\x04_\x15', 'IntegerType', 123456789123456789123456789),
|
||||
('', 'IntegerType', None),
|
||||
('\x7f\xff\xff\xff\xff\xff\xff\xff', 'LongType', 9223372036854775807),
|
||||
('\x80\x00\x00\x00\x00\x00\x00\x00', 'LongType', -9223372036854775808),
|
||||
('', 'LongType', None),
|
||||
('', 'InetAddressType', None),
|
||||
('A46\xa9', 'InetAddressType', '65.52.54.169'),
|
||||
('*\x00\x13(\xe1\x02\xcc\xc0\x00\x00\x00\x00\x00\x00\x01"', 'InetAddressType', '2a00:1328:e102:ccc0::122'),
|
||||
('\xe3\x81\xbe\xe3\x81\x97\xe3\x81\xa6', 'UTF8Type', u'\u307e\u3057\u3066'),
|
||||
('\xe3\x81\xbe\xe3\x81\x97\xe3\x81\xa6' * 1000, 'UTF8Type', u'\u307e\u3057\u3066' * 1000),
|
||||
('', 'UTF8Type', u''),
|
||||
('\xff' * 16, 'UUIDType', UUID('ffffffff-ffff-ffff-ffff-ffffffffffff')),
|
||||
('I\x15~\xfc\xef<\x9d\xe3\x16\x98\xaf\x80\x1f\xb4\x0b*', 'UUIDType', UUID('49157efc-ef3c-9de3-1698-af801fb40b2a')),
|
||||
('', 'UUIDType', None),
|
||||
('', 'MapType(AsciiType, BooleanType)', None),
|
||||
('', 'ListType(FloatType)', None),
|
||||
('', 'SetType(LongType)', None),
|
||||
('\x00\x00', 'MapType(DecimalType, BooleanType)', OrderedDict()),
|
||||
('\x00\x00', 'ListType(FloatType)', ()),
|
||||
('\x00\x00', 'SetType(IntegerType)', sortedset()),
|
||||
('\x00\x01\x00\x10\xafYC\xa3\xea<\x11\xe1\xabc\xc4,\x03"y\xf0', 'ListType(TimeUUIDType)', (UUID(bytes='\xafYC\xa3\xea<\x11\xe1\xabc\xc4,\x03"y\xf0'),)),
|
||||
(b'lorem ipsum dolor sit amet', 'AsciiType', 'lorem ipsum dolor sit amet'),
|
||||
(b'', 'AsciiType', ''),
|
||||
(b'\x01', 'BooleanType', True),
|
||||
(b'\x00', 'BooleanType', False),
|
||||
(b'', 'BooleanType', None),
|
||||
(b'\xff\xfe\xfd\xfc\xfb', 'BytesType', b'\xff\xfe\xfd\xfc\xfb'),
|
||||
(b'', 'BytesType', b''),
|
||||
(b'\x7f\xff\xff\xff\xff\xff\xff\xff', 'CounterColumnType', 9223372036854775807),
|
||||
(b'\x80\x00\x00\x00\x00\x00\x00\x00', 'CounterColumnType', -9223372036854775808),
|
||||
(b'', 'CounterColumnType', None),
|
||||
(b'\x00\x00\x013\x7fb\xeey', 'DateType', datetime(2011, 11, 7, 18, 55, 49, 881000)),
|
||||
(b'', 'DateType', None),
|
||||
(b'\x00\x00\x00\r\nJ\x04"^\x91\x04\x8a\xb1\x18\xfe', 'DecimalType', Decimal('1243878957943.1234124191998')),
|
||||
(b'\x00\x00\x00\x06\xe5\xde]\x98Y', 'DecimalType', Decimal('-112233.441191')),
|
||||
(b'\x00\x00\x00\x14\x00\xfa\xce', 'DecimalType', Decimal('0.00000000000000064206')),
|
||||
(b'\x00\x00\x00\x14\xff\x052', 'DecimalType', Decimal('-0.00000000000000064206')),
|
||||
(b'\xff\xff\xff\x9c\x00\xfa\xce', 'DecimalType', Decimal('64206e100')),
|
||||
(b'', 'DecimalType', None),
|
||||
(b'@\xd2\xfa\x08\x00\x00\x00\x00', 'DoubleType', 19432.125),
|
||||
(b'\xc0\xd2\xfa\x08\x00\x00\x00\x00', 'DoubleType', -19432.125),
|
||||
(b'\x7f\xef\x00\x00\x00\x00\x00\x00', 'DoubleType', 1.7415152243978685e+308),
|
||||
(b'', 'DoubleType', None),
|
||||
(b'F\x97\xd0@', 'FloatType', 19432.125),
|
||||
(b'\xc6\x97\xd0@', 'FloatType', -19432.125),
|
||||
(b'\xc6\x97\xd0@', 'FloatType', -19432.125),
|
||||
(b'\x7f\x7f\x00\x00', 'FloatType', 338953138925153547590470800371487866880.0),
|
||||
(b'', 'FloatType', None),
|
||||
(b'\x7f\x50\x00\x00', 'Int32Type', 2135949312),
|
||||
(b'\xff\xfd\xcb\x91', 'Int32Type', -144495),
|
||||
(b'', 'Int32Type', None),
|
||||
(b'f\x1e\xfd\xf2\xe3\xb1\x9f|\x04_\x15', 'IntegerType', 123456789123456789123456789),
|
||||
(b'', 'IntegerType', None),
|
||||
(b'\x7f\xff\xff\xff\xff\xff\xff\xff', 'LongType', 9223372036854775807),
|
||||
(b'\x80\x00\x00\x00\x00\x00\x00\x00', 'LongType', -9223372036854775808),
|
||||
(b'', 'LongType', None),
|
||||
(b'', 'InetAddressType', None),
|
||||
(b'A46\xa9', 'InetAddressType', '65.52.54.169'),
|
||||
(b'*\x00\x13(\xe1\x02\xcc\xc0\x00\x00\x00\x00\x00\x00\x01"', 'InetAddressType', '2a00:1328:e102:ccc0::122'),
|
||||
(b'\xe3\x81\xbe\xe3\x81\x97\xe3\x81\xa6', 'UTF8Type', u'\u307e\u3057\u3066'),
|
||||
(b'\xe3\x81\xbe\xe3\x81\x97\xe3\x81\xa6' * 1000, 'UTF8Type', u'\u307e\u3057\u3066' * 1000),
|
||||
(b'', 'UTF8Type', u''),
|
||||
(b'\xff' * 16, 'UUIDType', UUID('ffffffff-ffff-ffff-ffff-ffffffffffff')),
|
||||
(b'I\x15~\xfc\xef<\x9d\xe3\x16\x98\xaf\x80\x1f\xb4\x0b*', 'UUIDType', UUID('49157efc-ef3c-9de3-1698-af801fb40b2a')),
|
||||
(b'', 'UUIDType', None),
|
||||
(b'', 'MapType(AsciiType, BooleanType)', None),
|
||||
(b'', 'ListType(FloatType)', None),
|
||||
(b'', 'SetType(LongType)', None),
|
||||
(b'\x00\x00', 'MapType(DecimalType, BooleanType)', OrderedDict()),
|
||||
(b'\x00\x00', 'ListType(FloatType)', ()),
|
||||
(b'\x00\x00', 'SetType(IntegerType)', sortedset()),
|
||||
(b'\x00\x01\x00\x10\xafYC\xa3\xea<\x11\xe1\xabc\xc4,\x03"y\xf0', 'ListType(TimeUUIDType)', (UUID(bytes=b'\xafYC\xa3\xea<\x11\xe1\xabc\xc4,\x03"y\xf0'),)),
|
||||
)
|
||||
|
||||
ordered_dict_value = OrderedDict()
|
||||
@@ -94,9 +94,9 @@ ordered_dict_value[u'\\'] = 0
|
||||
# these following entries work for me right now, but they're dependent on
|
||||
# vagaries of internal python ordering for unordered types
|
||||
marshalled_value_pairs_unsafe = (
|
||||
('\x00\x03\x00\x06\xe3\x81\xbfbob\x00\x04\x00\x00\x00\xc7\x00\x00\x00\x04\xff\xff\xff\xff\x00\x01\\\x00\x04\x00\x00\x00\x00', 'MapType(UTF8Type, Int32Type)', ordered_dict_value),
|
||||
('\x00\x02\x00\x08@\x01\x99\x99\x99\x99\x99\x9a\x00\x08@\x14\x00\x00\x00\x00\x00\x00', 'SetType(DoubleType)', sortedset([2.2, 5.0])),
|
||||
('\x00', 'IntegerType', 0),
|
||||
(b'\x00\x03\x00\x06\xe3\x81\xbfbob\x00\x04\x00\x00\x00\xc7\x00\x00\x00\x04\xff\xff\xff\xff\x00\x01\\\x00\x04\x00\x00\x00\x00', 'MapType(UTF8Type, Int32Type)', ordered_dict_value),
|
||||
(b'\x00\x02\x00\x08@\x01\x99\x99\x99\x99\x99\x9a\x00\x08@\x14\x00\x00\x00\x00\x00\x00', 'SetType(DoubleType)', sortedset([2.2, 5.0])),
|
||||
(b'\x00', 'IntegerType', 0),
|
||||
)
|
||||
|
||||
if platform.python_implementation() == 'CPython':
|
||||
|
||||
@@ -29,6 +29,12 @@ from cassandra.pool import Host
|
||||
|
||||
class TestStrategies(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
"Hook method for setting up class fixture before running tests in the class."
|
||||
if not hasattr(cls, 'assertItemsEqual'):
|
||||
cls.assertItemsEqual = cls.assertCountEqual
|
||||
|
||||
def test_replication_strategy(self):
|
||||
"""
|
||||
Basic code coverage testing that ensures different ReplicationStrategies
|
||||
@@ -217,22 +223,22 @@ class TestTokens(unittest.TestCase):
|
||||
murmur3_token = Murmur3Token(cassandra.metadata.MIN_LONG - 1)
|
||||
self.assertEqual(murmur3_token.hash_fn('123'), -7468325962851647638)
|
||||
self.assertEqual(murmur3_token.hash_fn(str(cassandra.metadata.MAX_LONG)), 7162290910810015547)
|
||||
self.assertEqual(str(murmur3_token), '<Murmur3Token: -9223372036854775809L>')
|
||||
self.assertEqual(str(murmur3_token), '<Murmur3Token: -9223372036854775809>')
|
||||
except NoMurmur3:
|
||||
raise unittest.SkipTest('The murmur3 extension is not available')
|
||||
|
||||
def test_md5_tokens(self):
|
||||
md5_token = MD5Token(cassandra.metadata.MIN_LONG - 1)
|
||||
self.assertEqual(md5_token.hash_fn('123'), 42767516990368493138776584305024125808L)
|
||||
self.assertEqual(md5_token.hash_fn(str(cassandra.metadata.MAX_LONG)), 28528976619278518853815276204542453639L)
|
||||
self.assertEqual(str(md5_token), '<MD5Token: -9223372036854775809L>')
|
||||
self.assertEqual(md5_token.hash_fn('123'), 42767516990368493138776584305024125808)
|
||||
self.assertEqual(md5_token.hash_fn(str(cassandra.metadata.MAX_LONG)), 28528976619278518853815276204542453639)
|
||||
self.assertEqual(str(md5_token), '<MD5Token: %s>' % -9223372036854775809)
|
||||
|
||||
def test_bytes_tokens(self):
|
||||
bytes_token = BytesToken(str(cassandra.metadata.MIN_LONG - 1))
|
||||
self.assertEqual(bytes_token.hash_fn('123'), '123')
|
||||
self.assertEqual(bytes_token.hash_fn(123), 123)
|
||||
self.assertEqual(bytes_token.hash_fn(str(cassandra.metadata.MAX_LONG)), str(cassandra.metadata.MAX_LONG))
|
||||
self.assertEqual(str(bytes_token), "<BytesToken: '-9223372036854775809'>")
|
||||
self.assertEqual(str(bytes_token), "<BytesToken: -9223372036854775809>")
|
||||
|
||||
try:
|
||||
bytes_token = BytesToken(cassandra.metadata.MIN_LONG - 1)
|
||||
|
||||
@@ -22,32 +22,34 @@ from cassandra.query import PreparedStatement, BoundStatement
|
||||
from cassandra.cqltypes import Int32Type
|
||||
from cassandra.util import OrderedDict
|
||||
|
||||
from six.moves import xrange
|
||||
|
||||
|
||||
class ParamBindingTest(unittest.TestCase):
|
||||
|
||||
def test_bind_sequence(self):
|
||||
result = bind_params("%s %s %s", (1, "a", 2.0))
|
||||
self.assertEquals(result, "1 'a' 2.0")
|
||||
self.assertEqual(result, "1 'a' 2.0")
|
||||
|
||||
def test_bind_map(self):
|
||||
result = bind_params("%(a)s %(b)s %(c)s", dict(a=1, b="a", c=2.0))
|
||||
self.assertEquals(result, "1 'a' 2.0")
|
||||
self.assertEqual(result, "1 'a' 2.0")
|
||||
|
||||
def test_sequence_param(self):
|
||||
result = bind_params("%s", (ValueSequence((1, "a", 2.0)),))
|
||||
self.assertEquals(result, "( 1 , 'a' , 2.0 )")
|
||||
self.assertEqual(result, "( 1 , 'a' , 2.0 )")
|
||||
|
||||
def test_generator_param(self):
|
||||
result = bind_params("%s", ((i for i in xrange(3)),))
|
||||
self.assertEquals(result, "[ 0 , 1 , 2 ]")
|
||||
self.assertEqual(result, "[ 0 , 1 , 2 ]")
|
||||
|
||||
def test_none_param(self):
|
||||
result = bind_params("%s", (None,))
|
||||
self.assertEquals(result, "NULL")
|
||||
self.assertEqual(result, "NULL")
|
||||
|
||||
def test_list_collection(self):
|
||||
result = bind_params("%s", (['a', 'b', 'c'],))
|
||||
self.assertEquals(result, "[ 'a' , 'b' , 'c' ]")
|
||||
self.assertEqual(result, "[ 'a' , 'b' , 'c' ]")
|
||||
|
||||
def test_set_collection(self):
|
||||
result = bind_params("%s", (set(['a', 'b']),))
|
||||
@@ -59,11 +61,11 @@ class ParamBindingTest(unittest.TestCase):
|
||||
vals['b'] = 'b'
|
||||
vals['c'] = 'c'
|
||||
result = bind_params("%s", (vals,))
|
||||
self.assertEquals(result, "{ 'a' : 'a' , 'b' : 'b' , 'c' : 'c' }")
|
||||
self.assertEqual(result, "{ 'a' : 'a' , 'b' : 'b' , 'c' : 'c' }")
|
||||
|
||||
def test_quote_escaping(self):
|
||||
result = bind_params("%s", ("""'ef''ef"ef""ef'""",))
|
||||
self.assertEquals(result, """'''ef''''ef"ef""ef'''""")
|
||||
self.assertEqual(result, """'''ef''''ef"ef""ef'''""")
|
||||
|
||||
|
||||
class BoundStatementTestCase(unittest.TestCase):
|
||||
|
||||
@@ -20,6 +20,7 @@ except ImportError:
|
||||
from itertools import islice, cycle
|
||||
from mock import Mock
|
||||
from random import randint
|
||||
import six
|
||||
import sys
|
||||
import struct
|
||||
from threading import Thread
|
||||
@@ -36,6 +37,8 @@ from cassandra.policies import (RoundRobinPolicy, DCAwareRoundRobinPolicy,
|
||||
from cassandra.pool import Host
|
||||
from cassandra.query import Statement
|
||||
|
||||
from six.moves import xrange
|
||||
|
||||
|
||||
class TestLoadBalancingPolicy(unittest.TestCase):
|
||||
def test_non_implemented(self):
|
||||
@@ -137,13 +140,23 @@ class TestRoundRobinPolicy(unittest.TestCase):
|
||||
|
||||
# make the GIL switch after every instruction, maximizing
|
||||
# the chace of race conditions
|
||||
original_interval = sys.getcheckinterval()
|
||||
if six.PY2:
|
||||
original_interval = sys.getcheckinterval()
|
||||
else:
|
||||
original_interval = sys.getswitchinterval()
|
||||
|
||||
try:
|
||||
sys.setcheckinterval(0)
|
||||
if six.PY2:
|
||||
sys.setcheckinterval(0)
|
||||
else:
|
||||
sys.setswitchinterval(0.0001)
|
||||
map(lambda t: t.start(), threads)
|
||||
map(lambda t: t.join(), threads)
|
||||
finally:
|
||||
sys.setcheckinterval(original_interval)
|
||||
if six.PY2:
|
||||
sys.setcheckinterval(original_interval)
|
||||
else:
|
||||
sys.setswitchinterval(original_interval)
|
||||
|
||||
if errors:
|
||||
self.fail("Saw errors: %s" % (errors,))
|
||||
@@ -334,14 +347,14 @@ class TokenAwarePolicyTest(unittest.TestCase):
|
||||
|
||||
replicas = get_replicas(None, struct.pack('>i', i))
|
||||
other = set(h for h in hosts if h not in replicas)
|
||||
self.assertEquals(replicas, qplan[:2])
|
||||
self.assertEquals(other, set(qplan[2:]))
|
||||
self.assertEqual(replicas, qplan[:2])
|
||||
self.assertEqual(other, set(qplan[2:]))
|
||||
|
||||
# Should use the secondary policy
|
||||
for i in range(4):
|
||||
qplan = list(policy.make_query_plan())
|
||||
|
||||
self.assertEquals(set(qplan), set(hosts))
|
||||
self.assertEqual(set(qplan), set(hosts))
|
||||
|
||||
def test_wrap_dc_aware(self):
|
||||
cluster = Mock(spec=Cluster)
|
||||
@@ -374,16 +387,16 @@ class TokenAwarePolicyTest(unittest.TestCase):
|
||||
|
||||
# first should be the only local replica
|
||||
self.assertIn(qplan[0], replicas)
|
||||
self.assertEquals(qplan[0].datacenter, "dc1")
|
||||
self.assertEqual(qplan[0].datacenter, "dc1")
|
||||
|
||||
# then the local non-replica
|
||||
self.assertNotIn(qplan[1], replicas)
|
||||
self.assertEquals(qplan[1].datacenter, "dc1")
|
||||
self.assertEqual(qplan[1].datacenter, "dc1")
|
||||
|
||||
# then one of the remotes (used_hosts_per_remote_dc is 1, so we
|
||||
# shouldn't see two remotes)
|
||||
self.assertEquals(qplan[2].datacenter, "dc2")
|
||||
self.assertEquals(3, len(qplan))
|
||||
self.assertEqual(qplan[2].datacenter, "dc2")
|
||||
self.assertEqual(3, len(qplan))
|
||||
|
||||
class FakeCluster:
|
||||
def __init__(self):
|
||||
|
||||
@@ -346,7 +346,7 @@ class ResponseFutureTests(unittest.TestCase):
|
||||
rf.send_request()
|
||||
|
||||
rf.add_callbacks(
|
||||
callback=self.assertEquals, callback_args=([{'col': 'val'}],),
|
||||
callback=self.assertEqual, callback_args=([{'col': 'val'}],),
|
||||
errback=self.assertIsInstance, errback_args=(Exception,))
|
||||
|
||||
result = Mock(spec=UnavailableErrorMessage, info={})
|
||||
@@ -358,7 +358,7 @@ class ResponseFutureTests(unittest.TestCase):
|
||||
rf.send_request()
|
||||
|
||||
rf.add_callbacks(
|
||||
callback=self.assertEquals, callback_args=([{'col': 'val'}],),
|
||||
callback=self.assertEqual, callback_args=([{'col': 'val'}],),
|
||||
errback=self.assertIsInstance, errback_args=(Exception,))
|
||||
|
||||
rf._set_result(self.make_mock_response([{'col': 'val'}]))
|
||||
@@ -380,9 +380,9 @@ class ResponseFutureTests(unittest.TestCase):
|
||||
|
||||
session.submit.assert_called_once()
|
||||
args, kwargs = session.submit.call_args
|
||||
self.assertEquals(rf._reprepare, args[-2])
|
||||
self.assertEqual(rf._reprepare, args[-2])
|
||||
self.assertIsInstance(args[-1], PrepareMessage)
|
||||
self.assertEquals(args[-1].query, "SELECT * FROM foobar")
|
||||
self.assertEqual(args[-1].query, "SELECT * FROM foobar")
|
||||
|
||||
def test_prepared_query_not_found_bad_keyspace(self):
|
||||
session = self.make_session()
|
||||
|
||||
@@ -163,18 +163,17 @@ class TypeTests(unittest.TestCase):
|
||||
'7a6970:org.apache.cassandra.db.marshal.UTF8Type',
|
||||
')')))
|
||||
|
||||
self.assertEquals(FooType, ctype.__class__)
|
||||
self.assertEqual(FooType, ctype.__class__)
|
||||
|
||||
self.assertEquals(UTF8Type, ctype.subtypes[0])
|
||||
self.assertEqual(UTF8Type, ctype.subtypes[0])
|
||||
|
||||
# middle subtype should be a BarType instance with its own subtypes and names
|
||||
self.assertIsInstance(ctype.subtypes[1], BarType)
|
||||
self.assertEquals([UTF8Type], ctype.subtypes[1].subtypes)
|
||||
self.assertEquals(["address"], ctype.subtypes[1].names)
|
||||
self.assertEqual([UTF8Type], ctype.subtypes[1].subtypes)
|
||||
self.assertEqual([b"address"], ctype.subtypes[1].names)
|
||||
|
||||
self.assertEquals(UTF8Type, ctype.subtypes[2])
|
||||
|
||||
self.assertEquals(['city', None, 'zip'], ctype.names)
|
||||
self.assertEqual(UTF8Type, ctype.subtypes[2])
|
||||
self.assertEqual([b'city', None, b'zip'], ctype.names)
|
||||
|
||||
def test_empty_value(self):
|
||||
self.assertEqual(str(EmptyValue()), 'EMPTY')
|
||||
|
||||
10
tox.ini
10
tox.ini
@@ -1,5 +1,5 @@
|
||||
[tox]
|
||||
envlist = py26,py27,pypy
|
||||
envlist = py26,py27,pypy,py33
|
||||
|
||||
[testenv]
|
||||
deps = nose
|
||||
@@ -8,5 +8,13 @@ deps = nose
|
||||
unittest2
|
||||
pip
|
||||
PyYAML
|
||||
six
|
||||
commands = {envpython} setup.py build_ext --inplace
|
||||
nosetests --verbosity=2 tests/unit/
|
||||
|
||||
[testenv:py33]
|
||||
deps = nose
|
||||
mock
|
||||
pip
|
||||
PyYAML
|
||||
six
|
||||
|
||||
Reference in New Issue
Block a user