Migrated casandra23 into cassandra library. Ported more secions.

This commit is contained in:
Tim Savage
2014-03-19 01:02:19 +11:00
parent 62e97f65f8
commit 714ef3234b
15 changed files with 166 additions and 552 deletions

View File

@@ -38,7 +38,7 @@ from cassandra.decoder import (QueryMessage, ResultMessage,
RESULT_KIND_SET_KEYSPACE, RESULT_KIND_ROWS,
RESULT_KIND_SCHEMA_CHANGE)
from cassandra.metadata import Metadata
from cassandra.metrics import Metrics
# from cassandra.metrics import Metrics
from cassandra.policies import (RoundRobinPolicy, SimpleConvictionPolicy,
ExponentialReconnectionPolicy, HostDistance,
RetryPolicy)
@@ -376,6 +376,7 @@ class Cluster(object):
self._lock = RLock()
if self.metrics_enabled:
from cassandra.metrics import Metrics
self.metrics = Metrics(weakref.proxy(self))
self.control_connection = ControlConnection(

View File

@@ -6,11 +6,12 @@ from threading import Event, RLock
from six.moves.queue import Queue
from cassandra import ConsistencyLevel, AuthenticationFailed, OperationTimedOut
from cassandra.marshal import int8_unpack, int32_pack
from cassandra.marshal import int8_unpack, int32_pack, header_unpack
from cassandra.decoder import (ReadyMessage, AuthenticateMessage, OptionsMessage,
StartupMessage, ErrorMessage, CredentialsMessage,
QueryMessage, ResultMessage, decode_response,
InvalidRequestException, SupportedMessage)
import six
log = logging.getLogger(__name__)
@@ -102,7 +103,8 @@ def defunct_on_error(f):
return f(self, *args, **kwargs)
except Exception as exc:
self.defunct(exc)
# return f(self, *args, **kwargs)
# TODO: Clean up the above test code.
return wrapper
@@ -170,7 +172,7 @@ class Connection(object):
@defunct_on_error
def process_msg(self, msg, body_len):
version, flags, stream_id, opcode = (int8_unpack(f) for f in msg[:4])
version, flags, stream_id, opcode = header_unpack(msg[:4])
if stream_id < 0:
callback = None
else:
@@ -195,7 +197,7 @@ class Connection(object):
if body_len > 0:
body = msg[8:]
elif body_len == 0:
body = ""
body = six.binary_type()
else:
raise ProtocolError("Got negative body length: %r" % body_len)

View File

@@ -23,8 +23,7 @@ from uuid import UUID
import warnings
import six
from six.moves import cStringIO as StringIO
from six.moves import xrange
from six.moves import range
from cassandra.marshal import (int8_pack, int8_unpack, uint16_pack, uint16_unpack,
int32_pack, int32_unpack, int64_pack, int64_unpack,
@@ -34,12 +33,10 @@ from cassandra.util import OrderedDict
apache_cassandra_type_prefix = 'org.apache.cassandra.db.marshal.'
## Python 3 support #########
if six.PY3:
_number_types = frozenset((int, float))
else:
_number_types = frozenset((int, long, float))
#############################
try:
from blist import sortedset
@@ -78,7 +75,7 @@ class CassandraTypeType(type):
def __new__(metacls, name, bases, dct):
dct.setdefault('cassname', name)
cls = type.__new__(metacls, name, bases, dct)
if not name.startswith('_'):
if name != 'NewBase' and not name.startswith('_'):
_casstypes[name] = cls
return cls
@@ -167,8 +164,7 @@ class EmptyValue(object):
EMPTY = EmptyValue()
class _CassandraType(object):
__metaclass__ = CassandraTypeType
class _CassandraType(six.with_metaclass(CassandraTypeType, object)):
subtypes = ()
num_subtypes = 0
empty_binary_ok = False
@@ -189,9 +185,8 @@ class _CassandraType(object):
def __init__(self, val):
self.val = self.validate(val)
def __str__(self):
def __repr__(self):
return '<%s( %r )>' % (self.cql_parameterized_type(), self.val)
__repr__ = __str__
@staticmethod
def validate(val):
@@ -211,7 +206,7 @@ class _CassandraType(object):
"""
if byts is None:
return None
elif byts == '' and not cls.empty_binary_ok:
elif len(byts) == 0 and not cls.empty_binary_ok:
return EMPTY if cls.support_empty_values else None
return cls.deserialize(byts)
@@ -222,7 +217,7 @@ class _CassandraType(object):
more information. This method differs in that if None is passed in,
the result is the empty string.
"""
return '' if val is None else cls.serialize(val)
return six.binary_type() if val is None else cls.serialize(val)
@staticmethod
def deserialize(byts):
@@ -618,7 +613,7 @@ class _SimpleParameterizedType(_ParameterizedType):
numelements = uint16_unpack(byts[:2])
p = 2
result = []
for n in xrange(numelements):
for _ in range(numelements):
itemlen = uint16_unpack(byts[p:p + 2])
p += 2
item = byts[p:p + itemlen]
@@ -632,7 +627,7 @@ class _SimpleParameterizedType(_ParameterizedType):
raise TypeError("Received a string for a type that expects a sequence")
subtype, = cls.subtypes
buf = StringIO()
buf = six.BytesIO()
buf.write(uint16_pack(len(items)))
for item in items:
itembytes = subtype.to_binary(item)
@@ -668,7 +663,7 @@ class MapType(_ParameterizedType):
numelements = uint16_unpack(byts[:2])
p = 2
themap = OrderedDict()
for n in xrange(numelements):
for _ in range(numelements):
key_len = uint16_unpack(byts[p:p + 2])
p += 2
keybytes = byts[p:p + key_len]
@@ -685,7 +680,7 @@ class MapType(_ParameterizedType):
@classmethod
def serialize_safe(cls, themap):
subkeytype, subvaltype = cls.subtypes
buf = StringIO()
buf = six.BytesIO()
buf.write(uint16_pack(len(themap)))
try:
items = themap.iteritems()

View File

@@ -2,14 +2,13 @@ import logging
import socket
from uuid import UUID
# from six.moves import cStringIO as StringIO
from six.moves import xrange
from six import BytesIO
import six
from six.moves import range
from cassandra import (Unavailable, WriteTimeout, ReadTimeout,
AlreadyExists, InvalidRequest, Unauthorized)
from cassandra.marshal import (int32_pack, int32_unpack, uint16_pack, uint16_unpack,
int8_pack, int8_unpack)
int8_pack, int8_unpack, header_pack)
from cassandra.cqltypes import (AsciiType, BytesType, BooleanType,
CounterColumnType, DateType, DecimalType,
DoubleType, FloatType, Int32Type,
@@ -34,61 +33,69 @@ HEADER_DIRECTION_FROM_CLIENT = 0x00
HEADER_DIRECTION_TO_CLIENT = 0x80
HEADER_DIRECTION_MASK = 0x80
COMPRESSED_FLAG = 0x01
TRACING_FLAG = 0x02
_message_types_by_name = {}
_message_types_by_opcode = {}
class _register_msg_type(type):
class _RegisterMessageType(type):
def __init__(cls, name, bases, dct):
if not name.startswith('_'):
if name not in ('NewBase', '_MessageType'):
_message_types_by_name[cls.name] = cls
_message_types_by_opcode[cls.opcode] = cls
class _MessageType(object):
__metaclass__ = _register_msg_type
class _MessageType(six.with_metaclass(_RegisterMessageType, object)):
tracing = False
def to_binary(self, stream_id, protocol_version, compression=None):
body = BytesIO()
body = six.BytesIO()
self.send_body(body, protocol_version)
body = body.getvalue()
version = protocol_version | HEADER_DIRECTION_FROM_CLIENT
flags = 0
if compression is not None and len(body) > 0:
body = compression(body)
flags |= 0x01
if self.tracing:
flags |= 0x02
msglen = int32_pack(len(body))
msg_parts = [int8_pack(i) for i in (version, flags, stream_id, self.opcode)] + [msglen, body]
return six.binary_type().join(msg_parts)
def __str__(self):
paramstrs = ['%s=%r' % (pname, getattr(self, pname)) for pname in _get_params(self)]
return '<%s(%s)>' % (self.__class__.__name__, ', '.join(paramstrs))
__repr__ = __str__
flags = 0
if compression and len(body) > 0:
body = compression(body)
flags |= COMPRESSED_FLAG
if self.tracing:
flags |= TRACING_FLAG
msg = six.BytesIO()
write_header(
msg,
protocol_version | HEADER_DIRECTION_FROM_CLIENT,
flags, stream_id, self.opcode, len(body)
)
msg.write(body)
return msg.getvalue()
def __repr__(self):
return '<%s(%s)>' % (self.__class__.__name__, ', '.join('%s=%r' % i for i in _get_params(self)))
def _get_params(message_obj):
base_attrs = dir(_MessageType)
return [a for a in dir(message_obj)
if a not in base_attrs and not a.startswith('_') and not callable(getattr(message_obj, a))]
return (
(n, a) for n, a in message_obj.__dict__.items()
if n not in base_attrs and not n.startswith('_') and not callable(a)
)
def decode_response(stream_id, flags, opcode, body, decompressor=None):
if flags & 0x01:
if flags & COMPRESSED_FLAG:
if decompressor is None:
raise Exception("No decompressor available for compressed frame!")
raise Exception("No de-compressor available for compressed frame!")
body = decompressor(body)
flags ^= 0x01
flags ^= COMPRESSED_FLAG
body = BytesIO(body)
if flags & 0x02:
body = six.BytesIO(body)
if flags & TRACING_FLAG:
trace_id = UUID(bytes=body.read(16))
flags ^= 0x02
flags ^= TRACING_FLAG
else:
trace_id = None
@@ -142,14 +149,13 @@ class ErrorMessage(_MessageType, Exception):
return self
class ErrorMessageSubclass(_register_msg_type):
class ErrorMessageSubclass(_RegisterMessageType):
def __init__(cls, name, bases, dct):
if cls.error_code is not None:
if name not in ('NewBase', ) and cls.error_code:
error_classes[cls.error_code] = cls
class ErrorMessageSub(ErrorMessage):
__metaclass__ = ErrorMessageSubclass
class ErrorMessageSub(six.with_metaclass(ErrorMessageSubclass, ErrorMessage)):
error_code = None
@@ -450,7 +456,7 @@ class ResultMessage(_MessageType):
def recv_results_rows(cls, f):
column_metadata = cls.recv_results_metadata(f)
rowcount = read_int(f)
rows = [cls.recv_row(f, len(column_metadata)) for x in xrange(rowcount)]
rows = [cls.recv_row(f, len(column_metadata)) for _ in range(rowcount)]
colnames = [c[2] for c in column_metadata]
coltypes = [c[3] for c in column_metadata]
return (colnames, [tuple(ctype.from_binary(val) for ctype, val in zip(coltypes, row))
@@ -471,7 +477,7 @@ class ResultMessage(_MessageType):
ksname = read_string(f)
cfname = read_string(f)
column_metadata = []
for x in xrange(colcount):
for _ in range(colcount):
if glob_tblspec:
colksname = ksname
colcfname = cfname
@@ -513,7 +519,7 @@ class ResultMessage(_MessageType):
@staticmethod
def recv_row(f, colcount):
return [read_value(f) for x in xrange(colcount)]
return [read_value(f) for _ in range(colcount)]
class PrepareMessage(_MessageType):
@@ -635,6 +641,14 @@ class EventMessage(_MessageType):
return dict(change_type=change_type, keyspace=keyspace, table=table)
def write_header(f, version, flags, stream_id, opcode, length):
"""
Write a CQL protocol frame header.
"""
f.write(header_pack(version, flags, stream_id, opcode))
write_int(f, length)
def read_byte(f):
return int8_unpack(f.read(1))
@@ -701,7 +715,7 @@ def write_longstring(f, s):
def read_stringlist(f):
numstrs = read_short(f)
return [read_string(f) for x in xrange(numstrs)]
return [read_string(f) for _ in range(numstrs)]
def write_stringlist(f, stringlist):
@@ -713,7 +727,7 @@ def write_stringlist(f, stringlist):
def read_stringmap(f):
numpairs = read_short(f)
strmap = {}
for x in xrange(numpairs):
for _ in range(numpairs):
k = read_string(f)
strmap[k] = read_string(f)
return strmap
@@ -729,7 +743,7 @@ def write_stringmap(f, strmap):
def read_stringmultimap(f):
numkeys = read_short(f)
strmmap = {}
for x in xrange(numkeys):
for _ in range(numkeys):
k = read_string(f)
strmmap[k] = read_stringlist(f)
return strmmap

View File

@@ -8,9 +8,9 @@ import six
from cassandra.util import OrderedDict
if six.PY3:
unicode = str
long = int
# if six.PY3:
# unicode = str
# long = int
def cql_quote(term):

View File

@@ -183,10 +183,10 @@ class AsyncoreConnection(Connection, asyncore.dispatcher):
return
self.is_defunct = True
trace = traceback.format_exc(exc)
trace = traceback.format_exc() #exc)
if trace != "None":
log.debug("Defuncting connection (%s) to %s: %s\n%s",
id(self), self.host, exc, traceback.format_exc(exc))
id(self), self.host, exc, traceback.format_exc())
else:
log.debug("Defuncting connection (%s) to %s: %s",
id(self), self.host, exc)

View File

@@ -23,6 +23,11 @@ uint8_pack, uint8_unpack = _make_packer('>B')
float_pack, float_unpack = _make_packer('>f')
double_pack, double_unpack = _make_packer('>d')
# Special case for cassandra header
header_struct = struct.Struct('>BBBB')
header_pack = header_struct.pack
header_unpack = header_struct.unpack
def varint_unpack(term):
val = int(term.encode('hex'), 16)

View File

@@ -7,6 +7,7 @@ import logging
import re
from threading import RLock
import weakref
import six
murmur3 = None
try:
@@ -779,7 +780,7 @@ class TableMetadata(object):
def protect_name(name):
if isinstance(name, unicode):
if isinstance(name, six.text_type):
name = name.encode('utf8')
return maybe_escape_name(name)
@@ -1047,7 +1048,7 @@ class BytesToken(Token):
def __init__(self, token_string):
""" `token_string` should be string representing the token. """
if not isinstance(token_string, basestring):
if not isinstance(token_string, six.string_types):
raise TypeError(
"Tokens for ByteOrderedPartitioner should be strings (got %s)"
% (type(token_string),))

View File

@@ -5,7 +5,7 @@ from threading import Lock
from cassandra import ConsistencyLevel
from six.moves import xrange
from six.moves import range
log = logging.getLogger(__name__)
@@ -512,7 +512,7 @@ class ExponentialReconnectionPolicy(ReconnectionPolicy):
self.max_delay = max_delay
def new_schedule(self):
return (min(self.base_delay * (2 ** i), self.max_delay) for i in xrange(64))
return (min(self.base_delay * (2 ** i), self.max_delay) for i in range(64))
class WriteType(object):

View File

@@ -1,409 +0,0 @@
import struct
import six
from six.moves import range
import uuid
# Low level byte pack and unpack methods.
def _make_pack_unpack_field(format):
s = struct.Struct(format)
return (
s.pack,
lambda b: s.unpack(b)[0]
)
_header_struct = struct.Struct('!BBBBL')
pack_cql_header = _header_struct.pack
unpack_cql_header = _header_struct.unpack
pack_cql_byte, unpack_cql_byte = _make_pack_unpack_field('!B')
pack_cql_int, unpack_cql_int = _make_pack_unpack_field('!i')
pack_cql_short, unpack_cql_short = _make_pack_unpack_field('!H')
# Maximum values for these data types.
MAX_INT = 0x7FFFFFFF
MAX_SHORT = 0xFFFF
def read_header(f):
"""
Read a CQL protocol frame header.
A frame header consists of 4 bytes for the fields version, flags, stream and opcode. This is followed by a 4
byte length field, reading a total of 8 bytes.
:returns: tuple consisting of the version, flags, stream, opcode and length fields.
"""
return unpack_cql_header(f.read(8))
def write_header(f, version, flags, stream_id, opcode, length):
"""
Write a CQL protocol frame header.
"""
f.write(pack_cql_header(version, flags, stream_id, opcode, length))
def read_byte(f):
return f.read()
def write_byte(f, v):
f.write(pack_cql_byte(v))
def read_int(f):
return unpack_cql_int(f.read(4))
def write_int(f, v):
f.write(pack_cql_int(v))
def read_short(f):
return unpack_cql_short(f.read(2))
def write_short(f, v):
f.write(pack_cql_short(v))
def read_string(f):
"""
:returns: Python 3 returns a str; Python 2 returns a unicode string.
"""
n = f.read_short()
return f.read(n).decode('UTF8')
def write_string(f, v):
# TODO: Should really check that a short string isn't longer than a 2^2.
if isinstance(v, six.text_type):
b = v.encode('UTF8')
write_short(f, len(b))
f.write(b)
elif isinstance(v, str):
# This assumes that str will be caught by the previous if statement with Python 3.
write_short(f, len(v))
f.write(v)
else:
write_string(f, str(v))
def read_long_string(f):
"""
:returns: Python 3 returns a str; Python 2 returns a unicode string.
"""
n = read_int(f)
return f.read(n).decode('UTF8')
def write_long_string(f, v):
# TODO: Should really check that a long string isn't longer than a 2^4 / 2.
if isinstance(v, six.text_type):
b = v.encode('UTF8')
write_int(f, len(b))
f.write(b)
elif isinstance(v, str):
# This assumes that str will be caught by the previous if statement with Python 3.
write_int(f, len(v))
f.write(v)
else:
write_long_string(f, str(v))
def read_uuid(f):
return uuid.UUID(bytes=f.read(16))
def write_uuid(f, v):
assert isinstance(v, uuid.UUID)
f.write(v.bytes)
def read_string_list(f):
n = read_short(f)
return [read_string(f) for _ in range(n)]
def write_string_list(f, v):
n = len(v)
for idx in range(n):
write_string(f, v[idx])
def read_bytes(f):
n = read_int(f)
return None if n < 0 else f.read(n)
def write_bytes(f, v):
if v is None:
write_int(f, -1)
f.write(v)
else:
write_int(f, len(v))
f.write(v)
def read_short_bytes(f):
n = read_short(f)
return f.read(n)
def write_short_bytes(f, v):
if v is None:
write_short(f, 0)
else:
n = len(v)
assert n <= MAX_SHORT
write_short(f, n)
f.write(v)
def read_inet(f):
n = f.read(1)
values = f.read(n)
raise NotImplementedError
def write_inet(f, v):
raise NotImplementedError
read_consistency = read_short
write_consistency = write_short
def read_string_map(f):
n = read_short(f)
return dict((read_string(f), read_string(f)) for _ in range(n))
def write_string_map(f, v):
write_short(f, len(v))
for key, value in six.iteritems(v):
write_string(f, key)
write_string(f, value)
def read_string_multimap(f):
n = read_short(f)
return dict((read_string(f), read_string_list(f)) for _ in range(n))
def write_string_multimap(f, v):
write_short(f, len(v))
for key, value in six.iteritems(v):
write_string(f, key)
write_string_list(f, value)
## Define messages ##############################
HEADER_DIRECTION_FROM_CLIENT = 0x00
HEADER_DIRECTION_TO_CLIENT = 0x80
HEADER_DIRECTION_MASK = 0x80
COMPRESSED_FLAG = 0x01
TRACING_FLAG = 0x02
_message_types_by_name = {}
_message_types_by_opcode = {}
_error_classes = {}
class _RegisterMessageType(type):
def __init__(cls, what, *args, **kwargs):
if what not in ('_MessageType', 'NewBase'):
_message_types_by_name[cls.name] = cls
_message_types_by_opcode[cls.opcode] = cls
def _get_params(message_obj):
base_attrs = dir(_MessageType)
return (
(n, a) for n, a in message_obj.__dict__.items()
if n not in base_attrs and not n.startswith('_') and not callable(a)
)
class _MessageType(six.with_metaclass(_RegisterMessageType, object)):
opcode = None
name = None
tracing = False
def __repr__(self):
return '<%s(%s)>' % (self.__class__.__name__, ', '.join('%s=%r' % i for i in _get_params(self)))
def send_body(self, buf, protocol_version):
"""
Encode the body of this message for sending.
:param buf: An instance of `ByteBuffer`.
:param protocol_version: Version of the protocol currently being used.
"""
pass
def to_binary(self, stream_id, protocol_version, compression=None):
"""
Pack this message into it's binary format.
"""
body = six.BytesIO()
self.send_body(body, protocol_version)
body = body.getvalue()
flags = 0
if compression and len(body) > 0:
body = compression(body)
flags |= COMPRESSED_FLAG
if self.tracing:
flags |= TRACING_FLAG
msg = six.BytesIO()
write_header(
msg,
protocol_version | HEADER_DIRECTION_FROM_CLIENT,
flags, stream_id, self.opcode, len(body)
)
msg.write(body)
return msg.getvalue()
def decode_response(stream_id, flags, opcode, body, decompressor=None):
"""
Build msg class.
"""
if flags & COMPRESSED_FLAG:
if callable(decompressor):
body = decompressor(body)
flags ^= COMPRESSED_FLAG
else:
raise TypeError("De-compressor not available for compressed frame!")
body = six.BytesIO(body)
if flags & TRACING_FLAG:
trace_id = read_uuid(body)
flags ^= TRACING_FLAG
else:
trace_id = None
if flags:
# TODO: log.warn("Unknown protocol flags set: %02x. May cause problems.", flags)
pass
msg_class = _message_types_by_opcode[opcode]
msg = msg_class.recv_body(body)
msg.stream_id = stream_id
msg.trace_id = trace_id
return msg
class StartupMessage(_MessageType):
opcode = 0x01
name = 'STARTUP'
KNOWN_OPTION_KEYS = set(('CQL_VERSION', 'COMPRESSION',))
def __init__(self, cqlversion, options):
self.cqlversion = cqlversion
self.options = options
def send_body(self, f, protocol_version):
opt_map = self.options.copy()
opt_map['CQL_VERSION'] = self.cqlversion
write_string_map(f, opt_map)
class ReadyMessage(_MessageType):
opcode = 0x02
name = 'READY'
@classmethod
def recv_body(cls, f):
return cls()
class AuthenticateMessage(_MessageType):
opcode = 0x03
name = 'AUTHENTICATE'
def __init__(self, authenticator):
self.authenticator = authenticator
@classmethod
def recv_body(cls, f):
authenticator = read_string(f)
return cls(authenticator)
class CredentialsMessage(_MessageType):
opcode = 0x04
name = 'CREDENTIALS'
def __init__(self, credentials):
self.credentials = credentials
def send_body(self, f, protocol_version):
write_string_map(f, self.credentials)
class OptionsMessage(_MessageType):
opcode = 0x05
name = 'OPTIONS'
class SupportedMessage(_MessageType):
opcode = 0x06
name = 'SUPPORTED'
def __init__(self, cql_versions, options):
self.cql_versions = cql_versions
self.options = options
@classmethod
def recv_body(cls, f):
options = read_string_multimap(f)
cql_versions = options.pop('CQL_VERSION')
return cls(cql_versions, options)
class QueryMessage(_MessageType):
opcode = 0x07
name = 'QUERY'
def __init__(self, query, consistency_level):
self.query = query
self.consistency_level = consistency_level
def send_body(self, f, protocol_version):
write_long_string(f, self.query)
write_consistency(f, self.consistency_level)
@classmethod
def recv_body(cls, f):
query = read_long_string(f)
consistency_level = read_consistency(f)
return cls(query, consistency_level)
class ResultMessage(_MessageType):
opcode = 0x08
name = 'RESULT'
def __init__(self, kind, results):
self.kind = kind
self.results = results
class PrepareMessage(_MessageType):
opcode = 0x09
name = 'PREPARE'
def __init__(self, query):
self.query = query
def send_body(self, f, protocol_version):
write_long_string(f, self.query)

View File

@@ -1,3 +1,5 @@
import six
try:
import unittest2 as unittest
except ImportError:
@@ -6,7 +8,7 @@ except ImportError:
import errno
import os
from six.moves import StringIO
from six import BytesIO
import socket
from socket import error as socket_error
@@ -18,7 +20,7 @@ from cassandra.connection import (HEADER_DIRECTION_TO_CLIENT,
from cassandra.decoder import (write_stringmultimap, write_int, write_string,
SupportedMessage, ReadyMessage, ServerError)
from cassandra.marshal import uint8_pack, uint32_pack, int32_pack
from cassandra.marshal import uint8_pack, uint8_unpack, uint32_pack, int32_pack
from cassandra.io.asyncorereactor import AsyncoreConnection
@@ -43,7 +45,7 @@ class AsyncoreConnectionTest(unittest.TestCase):
return c
def make_header_prefix(self, message_class, version=2, stream_id=0):
return ''.join(map(uint8_pack, [
return six.binary_type().join(map(uint8_pack, [
0xff & (HEADER_DIRECTION_TO_CLIENT | version),
0, # flags (compression)
stream_id,
@@ -51,7 +53,7 @@ class AsyncoreConnectionTest(unittest.TestCase):
]))
def make_options_body(self):
options_buf = StringIO()
options_buf = BytesIO()
write_stringmultimap(options_buf, {
'CQL_VERSION': ['3.0.1'],
'COMPRESSION': []
@@ -59,12 +61,12 @@ class AsyncoreConnectionTest(unittest.TestCase):
return options_buf.getvalue()
def make_error_body(self, code, msg):
buf = StringIO()
buf = BytesIO()
write_int(buf, code)
write_string(buf, msg)
return buf.getvalue()
def make_msg(self, header, body=""):
def make_msg(self, header, body=six.binary_type()):
return header + uint32_pack(len(body)) + body
def test_successful_connection(self, *args):
@@ -93,12 +95,12 @@ class AsyncoreConnectionTest(unittest.TestCase):
# get a connection that's already fully started
c = self.test_successful_connection()
header = '\x00\x00\x00\x00' + int32_pack(20000)
header = six.b('\x00\x00\x00\x00') + int32_pack(20000)
responses = [
header + ('a' * (4096 - len(header))),
'a' * 4096,
header + (six.b('a') * (4096 - len(header))),
six.b('a') * 4096,
socket_error(errno.EAGAIN),
'a' * 100,
six.b('a') * 100,
socket_error(errno.EAGAIN)]
def side_effect(*args):
@@ -225,14 +227,13 @@ class AsyncoreConnectionTest(unittest.TestCase):
options = self.make_options_body()
message = self.make_msg(header, options)
# read in the first byte
c.socket.recv.return_value = message[0]
c.socket.recv.return_value = message[0:1]
c.handle_read()
self.assertEquals(c._iobuf.getvalue(), message[0])
self.assertEquals(c._iobuf.getvalue(), message[0:1])
c.socket.recv.return_value = message[1:]
c.handle_read()
self.assertEquals("", c._iobuf.getvalue())
self.assertEquals(six.binary_type(), c._iobuf.getvalue())
# let it write out a StartupMessage
c.handle_write()
@@ -259,7 +260,7 @@ class AsyncoreConnectionTest(unittest.TestCase):
# ... then read in the rest
c.socket.recv.return_value = message[9:]
c.handle_read()
self.assertEquals("", c._iobuf.getvalue())
self.assertEquals(six.binary_type(), c._iobuf.getvalue())
# let it write out a StartupMessage
c.handle_write()

View File

@@ -1,9 +1,11 @@
import six
try:
import unittest2 as unittest
except ImportError:
import unittest # noqa
from six.moves import StringIO
from six import BytesIO
from mock import Mock, ANY
@@ -25,7 +27,7 @@ class ConnectionTest(unittest.TestCase):
return c
def make_header_prefix(self, message_class, version=2, stream_id=0):
return ''.join(map(uint8_pack, [
return six.binary_type().join(map(uint8_pack, [
0xff & (HEADER_DIRECTION_TO_CLIENT | version),
0, # flags (compression)
stream_id,
@@ -33,7 +35,7 @@ class ConnectionTest(unittest.TestCase):
]))
def make_options_body(self):
options_buf = StringIO()
options_buf = BytesIO()
write_stringmultimap(options_buf, {
'CQL_VERSION': ['3.0.1'],
'COMPRESSION': []
@@ -41,7 +43,7 @@ class ConnectionTest(unittest.TestCase):
return options_buf.getvalue()
def make_error_body(self, code, msg):
buf = StringIO()
buf = BytesIO()
write_int(buf, code)
write_string(buf, msg)
return buf.getvalue()
@@ -73,12 +75,12 @@ class ConnectionTest(unittest.TestCase):
c.defunct = Mock()
# read in a SupportedMessage response
header = ''.join(map(uint8_pack, [
header = six.binary_type().join(uint8_pack(i) for i in (
0xff & (HEADER_DIRECTION_FROM_CLIENT | self.protocol_version),
0, # flags (compression)
0,
SupportedMessage.opcode # opcode
]))
))
options = self.make_options_body()
message = self.make_msg(header, options)
c.process_msg(message, len(message) - 8)
@@ -115,7 +117,7 @@ class ConnectionTest(unittest.TestCase):
# read in a SupportedMessage response
header = self.make_header_prefix(SupportedMessage)
options_buf = StringIO()
options_buf = BytesIO()
write_stringmultimap(options_buf, {
'CQL_VERSION': ['7.8.9'],
'COMPRESSION': []

View File

@@ -1,3 +1,5 @@
import six
try:
import unittest2 as unittest
except ImportError:
@@ -20,55 +22,55 @@ marshalled_value_pairs = (
# binary form, type, python native type
('lorem ipsum dolor sit amet', 'AsciiType', 'lorem ipsum dolor sit amet'),
('', 'AsciiType', ''),
('\x01', 'BooleanType', True),
('\x00', 'BooleanType', False),
('', 'BooleanType', None),
('\xff\xfe\xfd\xfc\xfb', 'BytesType', '\xff\xfe\xfd\xfc\xfb'),
('', 'BytesType', ''),
('\x7f\xff\xff\xff\xff\xff\xff\xff', 'CounterColumnType', 9223372036854775807),
('\x80\x00\x00\x00\x00\x00\x00\x00', 'CounterColumnType', -9223372036854775808),
('', 'CounterColumnType', None),
('\x00\x00\x013\x7fb\xeey', 'DateType', datetime(2011, 11, 7, 18, 55, 49, 881000)),
('', 'DateType', None),
('\x00\x00\x00\r\nJ\x04"^\x91\x04\x8a\xb1\x18\xfe', 'DecimalType', Decimal('1243878957943.1234124191998')),
('\x00\x00\x00\x06\xe5\xde]\x98Y', 'DecimalType', Decimal('-112233.441191')),
('\x00\x00\x00\x14\x00\xfa\xce', 'DecimalType', Decimal('0.00000000000000064206')),
('\x00\x00\x00\x14\xff\x052', 'DecimalType', Decimal('-0.00000000000000064206')),
('\xff\xff\xff\x9c\x00\xfa\xce', 'DecimalType', Decimal('64206e100')),
('', 'DecimalType', None),
('@\xd2\xfa\x08\x00\x00\x00\x00', 'DoubleType', 19432.125),
('\xc0\xd2\xfa\x08\x00\x00\x00\x00', 'DoubleType', -19432.125),
('\x7f\xef\x00\x00\x00\x00\x00\x00', 'DoubleType', 1.7415152243978685e+308),
('', 'DoubleType', None),
('F\x97\xd0@', 'FloatType', 19432.125),
('\xc6\x97\xd0@', 'FloatType', -19432.125),
('\xc6\x97\xd0@', 'FloatType', -19432.125),
('\x7f\x7f\x00\x00', 'FloatType', 338953138925153547590470800371487866880.0),
('', 'FloatType', None),
('\x7f\x50\x00\x00', 'Int32Type', 2135949312),
('\xff\xfd\xcb\x91', 'Int32Type', -144495),
('', 'Int32Type', None),
('f\x1e\xfd\xf2\xe3\xb1\x9f|\x04_\x15', 'IntegerType', 123456789123456789123456789),
('', 'IntegerType', None),
('\x7f\xff\xff\xff\xff\xff\xff\xff', 'LongType', 9223372036854775807),
('\x80\x00\x00\x00\x00\x00\x00\x00', 'LongType', -9223372036854775808),
('', 'LongType', None),
('', 'InetAddressType', None),
('A46\xa9', 'InetAddressType', '65.52.54.169'),
('*\x00\x13(\xe1\x02\xcc\xc0\x00\x00\x00\x00\x00\x00\x01"', 'InetAddressType', '2a00:1328:e102:ccc0::122'),
('\xe3\x81\xbe\xe3\x81\x97\xe3\x81\xa6', 'UTF8Type', u'\u307e\u3057\u3066'),
('\xe3\x81\xbe\xe3\x81\x97\xe3\x81\xa6' * 1000, 'UTF8Type', u'\u307e\u3057\u3066' * 1000),
('', 'UTF8Type', u''),
('\xff' * 16, 'UUIDType', UUID('ffffffff-ffff-ffff-ffff-ffffffffffff')),
('I\x15~\xfc\xef<\x9d\xe3\x16\x98\xaf\x80\x1f\xb4\x0b*', 'UUIDType', UUID('49157efc-ef3c-9de3-1698-af801fb40b2a')),
('', 'UUIDType', None),
('', 'MapType(AsciiType, BooleanType)', None),
('', 'ListType(FloatType)', None),
('', 'SetType(LongType)', None),
('\x00\x00', 'MapType(DecimalType, BooleanType)', OrderedDict()),
('\x00\x00', 'ListType(FloatType)', ()),
('\x00\x00', 'SetType(IntegerType)', sortedset()),
('\x00\x01\x00\x10\xafYC\xa3\xea<\x11\xe1\xabc\xc4,\x03"y\xf0', 'ListType(TimeUUIDType)', (UUID(bytes='\xafYC\xa3\xea<\x11\xe1\xabc\xc4,\x03"y\xf0'),)),
(six.b('\x01'), 'BooleanType', True),
(six.b('\x00'), 'BooleanType', False),
(six.b(''), 'BooleanType', None),
(six.b('\xff\xfe\xfd\xfc\xfb'), 'BytesType', '\xff\xfe\xfd\xfc\xfb'),
(six.b(''), 'BytesType', ''),
(six.b('\x7f\xff\xff\xff\xff\xff\xff\xff'), 'CounterColumnType', 9223372036854775807),
(six.b('\x80\x00\x00\x00\x00\x00\x00\x00'), 'CounterColumnType', -9223372036854775808),
(six.b(''), 'CounterColumnType', None),
(six.b('\x00\x00\x013\x7fb\xeey'), 'DateType', datetime(2011, 11, 7, 18, 55, 49, 881000)),
(six.b(''), 'DateType', None),
(six.b('\x00\x00\x00\r\nJ\x04"^\x91\x04\x8a\xb1\x18\xfe'), 'DecimalType', Decimal('1243878957943.1234124191998')),
(six.b('\x00\x00\x00\x06\xe5\xde]\x98Y'), 'DecimalType', Decimal('-112233.441191')),
(six.b('\x00\x00\x00\x14\x00\xfa\xce'), 'DecimalType', Decimal('0.00000000000000064206')),
(six.b('\x00\x00\x00\x14\xff\x052'), 'DecimalType', Decimal('-0.00000000000000064206')),
(six.b('\xff\xff\xff\x9c\x00\xfa\xce'), 'DecimalType', Decimal('64206e100')),
(six.b(''), 'DecimalType', None),
(six.b('@\xd2\xfa\x08\x00\x00\x00\x00'), 'DoubleType', 19432.125),
(six.b('\xc0\xd2\xfa\x08\x00\x00\x00\x00'), 'DoubleType', -19432.125),
(six.b('\x7f\xef\x00\x00\x00\x00\x00\x00'), 'DoubleType', 1.7415152243978685e+308),
(six.b(''), 'DoubleType', None),
(six.b('F\x97\xd0@'), 'FloatType', 19432.125),
(six.b('\xc6\x97\xd0@'), 'FloatType', -19432.125),
(six.b('\xc6\x97\xd0@'), 'FloatType', -19432.125),
(six.b('\x7f\x7f\x00\x00'), 'FloatType', 338953138925153547590470800371487866880.0),
(six.b(''), 'FloatType', None),
(six.b('\x7f\x50\x00\x00'), 'Int32Type', 2135949312),
(six.b('\xff\xfd\xcb\x91'), 'Int32Type', -144495),
(six.b(''), 'Int32Type', None),
(six.b('f\x1e\xfd\xf2\xe3\xb1\x9f|\x04_\x15'), 'IntegerType', 123456789123456789123456789),
(six.b(''), 'IntegerType', None),
(six.b('\x7f\xff\xff\xff\xff\xff\xff\xff'), 'LongType', 9223372036854775807),
(six.b('\x80\x00\x00\x00\x00\x00\x00\x00'), 'LongType', -9223372036854775808),
(six.b(''), 'LongType', None),
(six.b(''), 'InetAddressType', None),
(six.b('A46\xa9'), 'InetAddressType', '65.52.54.169'),
(six.b('*\x00\x13(\xe1\x02\xcc\xc0\x00\x00\x00\x00\x00\x00\x01"'), 'InetAddressType', '2a00:1328:e102:ccc0::122'),
(six.b('\xe3\x81\xbe\xe3\x81\x97\xe3\x81\xa6'), 'UTF8Type', u'\u307e\u3057\u3066'),
(six.b('\xe3\x81\xbe\xe3\x81\x97\xe3\x81\xa6' * 1000), 'UTF8Type', u'\u307e\u3057\u3066' * 1000),
(six.b(''), 'UTF8Type', u''),
(six.b('\xff' * 16), 'UUIDType', UUID('ffffffff-ffff-ffff-ffff-ffffffffffff')),
(six.b('I\x15~\xfc\xef<\x9d\xe3\x16\x98\xaf\x80\x1f\xb4\x0b*'), 'UUIDType', UUID('49157efc-ef3c-9de3-1698-af801fb40b2a')),
(six.b(''), 'UUIDType', None),
(six.b(''), 'MapType(AsciiType, BooleanType)', None),
(six.b(''), 'ListType(FloatType)', None),
(six.b(''), 'SetType(LongType)', None),
(six.b('\x00\x00'), 'MapType(DecimalType, BooleanType)', OrderedDict()),
(six.b('\x00\x00'), 'ListType(FloatType)', ()),
(six.b('\x00\x00'), 'SetType(IntegerType)', sortedset()),
(six.b('\x00\x01\x00\x10\xafYC\xa3\xea<\x11\xe1\xabc\xc4,\x03"y\xf0'), 'ListType(TimeUUIDType)', (UUID(bytes=six.b('\xafYC\xa3\xea<\x11\xe1\xabc\xc4,\x03"y\xf0')),)),
)
ordered_dict_value = OrderedDict()

View File

@@ -206,8 +206,8 @@ class TestTokens(unittest.TestCase):
def test_md5_tokens(self):
md5_token = MD5Token(cassandra.metadata.MIN_LONG - 1)
self.assertEqual(md5_token.hash_fn('123'), 42767516990368493138776584305024125808L)
self.assertEqual(md5_token.hash_fn(str(cassandra.metadata.MAX_LONG)), 28528976619278518853815276204542453639L)
self.assertEqual(md5_token.hash_fn('123'), 42767516990368493138776584305024125808)
self.assertEqual(md5_token.hash_fn(str(cassandra.metadata.MAX_LONG)), 28528976619278518853815276204542453639)
self.assertEqual(str(md5_token), '<MD5Token: -9223372036854775809L>')
def test_bytes_tokens(self):