Migrated casandra23 into cassandra library. Ported more secions.

This commit is contained in:
Tim Savage
2014-03-19 01:02:19 +11:00
parent 62e97f65f8
commit 714ef3234b
15 changed files with 166 additions and 552 deletions

View File

@@ -38,7 +38,7 @@ from cassandra.decoder import (QueryMessage, ResultMessage,
RESULT_KIND_SET_KEYSPACE, RESULT_KIND_ROWS, RESULT_KIND_SET_KEYSPACE, RESULT_KIND_ROWS,
RESULT_KIND_SCHEMA_CHANGE) RESULT_KIND_SCHEMA_CHANGE)
from cassandra.metadata import Metadata from cassandra.metadata import Metadata
from cassandra.metrics import Metrics # from cassandra.metrics import Metrics
from cassandra.policies import (RoundRobinPolicy, SimpleConvictionPolicy, from cassandra.policies import (RoundRobinPolicy, SimpleConvictionPolicy,
ExponentialReconnectionPolicy, HostDistance, ExponentialReconnectionPolicy, HostDistance,
RetryPolicy) RetryPolicy)
@@ -376,6 +376,7 @@ class Cluster(object):
self._lock = RLock() self._lock = RLock()
if self.metrics_enabled: if self.metrics_enabled:
from cassandra.metrics import Metrics
self.metrics = Metrics(weakref.proxy(self)) self.metrics = Metrics(weakref.proxy(self))
self.control_connection = ControlConnection( self.control_connection = ControlConnection(

View File

@@ -6,11 +6,12 @@ from threading import Event, RLock
from six.moves.queue import Queue from six.moves.queue import Queue
from cassandra import ConsistencyLevel, AuthenticationFailed, OperationTimedOut from cassandra import ConsistencyLevel, AuthenticationFailed, OperationTimedOut
from cassandra.marshal import int8_unpack, int32_pack from cassandra.marshal import int8_unpack, int32_pack, header_unpack
from cassandra.decoder import (ReadyMessage, AuthenticateMessage, OptionsMessage, from cassandra.decoder import (ReadyMessage, AuthenticateMessage, OptionsMessage,
StartupMessage, ErrorMessage, CredentialsMessage, StartupMessage, ErrorMessage, CredentialsMessage,
QueryMessage, ResultMessage, decode_response, QueryMessage, ResultMessage, decode_response,
InvalidRequestException, SupportedMessage) InvalidRequestException, SupportedMessage)
import six
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@@ -102,7 +103,8 @@ def defunct_on_error(f):
return f(self, *args, **kwargs) return f(self, *args, **kwargs)
except Exception as exc: except Exception as exc:
self.defunct(exc) self.defunct(exc)
# return f(self, *args, **kwargs)
# TODO: Clean up the above test code.
return wrapper return wrapper
@@ -170,7 +172,7 @@ class Connection(object):
@defunct_on_error @defunct_on_error
def process_msg(self, msg, body_len): def process_msg(self, msg, body_len):
version, flags, stream_id, opcode = (int8_unpack(f) for f in msg[:4]) version, flags, stream_id, opcode = header_unpack(msg[:4])
if stream_id < 0: if stream_id < 0:
callback = None callback = None
else: else:
@@ -195,7 +197,7 @@ class Connection(object):
if body_len > 0: if body_len > 0:
body = msg[8:] body = msg[8:]
elif body_len == 0: elif body_len == 0:
body = "" body = six.binary_type()
else: else:
raise ProtocolError("Got negative body length: %r" % body_len) raise ProtocolError("Got negative body length: %r" % body_len)

View File

@@ -23,8 +23,7 @@ from uuid import UUID
import warnings import warnings
import six import six
from six.moves import cStringIO as StringIO from six.moves import range
from six.moves import xrange
from cassandra.marshal import (int8_pack, int8_unpack, uint16_pack, uint16_unpack, from cassandra.marshal import (int8_pack, int8_unpack, uint16_pack, uint16_unpack,
int32_pack, int32_unpack, int64_pack, int64_unpack, int32_pack, int32_unpack, int64_pack, int64_unpack,
@@ -34,12 +33,10 @@ from cassandra.util import OrderedDict
apache_cassandra_type_prefix = 'org.apache.cassandra.db.marshal.' apache_cassandra_type_prefix = 'org.apache.cassandra.db.marshal.'
## Python 3 support #########
if six.PY3: if six.PY3:
_number_types = frozenset((int, float)) _number_types = frozenset((int, float))
else: else:
_number_types = frozenset((int, long, float)) _number_types = frozenset((int, long, float))
#############################
try: try:
from blist import sortedset from blist import sortedset
@@ -78,7 +75,7 @@ class CassandraTypeType(type):
def __new__(metacls, name, bases, dct): def __new__(metacls, name, bases, dct):
dct.setdefault('cassname', name) dct.setdefault('cassname', name)
cls = type.__new__(metacls, name, bases, dct) cls = type.__new__(metacls, name, bases, dct)
if not name.startswith('_'): if name != 'NewBase' and not name.startswith('_'):
_casstypes[name] = cls _casstypes[name] = cls
return cls return cls
@@ -167,8 +164,7 @@ class EmptyValue(object):
EMPTY = EmptyValue() EMPTY = EmptyValue()
class _CassandraType(object): class _CassandraType(six.with_metaclass(CassandraTypeType, object)):
__metaclass__ = CassandraTypeType
subtypes = () subtypes = ()
num_subtypes = 0 num_subtypes = 0
empty_binary_ok = False empty_binary_ok = False
@@ -189,9 +185,8 @@ class _CassandraType(object):
def __init__(self, val): def __init__(self, val):
self.val = self.validate(val) self.val = self.validate(val)
def __str__(self): def __repr__(self):
return '<%s( %r )>' % (self.cql_parameterized_type(), self.val) return '<%s( %r )>' % (self.cql_parameterized_type(), self.val)
__repr__ = __str__
@staticmethod @staticmethod
def validate(val): def validate(val):
@@ -211,7 +206,7 @@ class _CassandraType(object):
""" """
if byts is None: if byts is None:
return None return None
elif byts == '' and not cls.empty_binary_ok: elif len(byts) == 0 and not cls.empty_binary_ok:
return EMPTY if cls.support_empty_values else None return EMPTY if cls.support_empty_values else None
return cls.deserialize(byts) return cls.deserialize(byts)
@@ -222,7 +217,7 @@ class _CassandraType(object):
more information. This method differs in that if None is passed in, more information. This method differs in that if None is passed in,
the result is the empty string. the result is the empty string.
""" """
return '' if val is None else cls.serialize(val) return six.binary_type() if val is None else cls.serialize(val)
@staticmethod @staticmethod
def deserialize(byts): def deserialize(byts):
@@ -618,7 +613,7 @@ class _SimpleParameterizedType(_ParameterizedType):
numelements = uint16_unpack(byts[:2]) numelements = uint16_unpack(byts[:2])
p = 2 p = 2
result = [] result = []
for n in xrange(numelements): for _ in range(numelements):
itemlen = uint16_unpack(byts[p:p + 2]) itemlen = uint16_unpack(byts[p:p + 2])
p += 2 p += 2
item = byts[p:p + itemlen] item = byts[p:p + itemlen]
@@ -632,7 +627,7 @@ class _SimpleParameterizedType(_ParameterizedType):
raise TypeError("Received a string for a type that expects a sequence") raise TypeError("Received a string for a type that expects a sequence")
subtype, = cls.subtypes subtype, = cls.subtypes
buf = StringIO() buf = six.BytesIO()
buf.write(uint16_pack(len(items))) buf.write(uint16_pack(len(items)))
for item in items: for item in items:
itembytes = subtype.to_binary(item) itembytes = subtype.to_binary(item)
@@ -668,7 +663,7 @@ class MapType(_ParameterizedType):
numelements = uint16_unpack(byts[:2]) numelements = uint16_unpack(byts[:2])
p = 2 p = 2
themap = OrderedDict() themap = OrderedDict()
for n in xrange(numelements): for _ in range(numelements):
key_len = uint16_unpack(byts[p:p + 2]) key_len = uint16_unpack(byts[p:p + 2])
p += 2 p += 2
keybytes = byts[p:p + key_len] keybytes = byts[p:p + key_len]
@@ -685,7 +680,7 @@ class MapType(_ParameterizedType):
@classmethod @classmethod
def serialize_safe(cls, themap): def serialize_safe(cls, themap):
subkeytype, subvaltype = cls.subtypes subkeytype, subvaltype = cls.subtypes
buf = StringIO() buf = six.BytesIO()
buf.write(uint16_pack(len(themap))) buf.write(uint16_pack(len(themap)))
try: try:
items = themap.iteritems() items = themap.iteritems()

View File

@@ -2,14 +2,13 @@ import logging
import socket import socket
from uuid import UUID from uuid import UUID
# from six.moves import cStringIO as StringIO import six
from six.moves import xrange from six.moves import range
from six import BytesIO
from cassandra import (Unavailable, WriteTimeout, ReadTimeout, from cassandra import (Unavailable, WriteTimeout, ReadTimeout,
AlreadyExists, InvalidRequest, Unauthorized) AlreadyExists, InvalidRequest, Unauthorized)
from cassandra.marshal import (int32_pack, int32_unpack, uint16_pack, uint16_unpack, from cassandra.marshal import (int32_pack, int32_unpack, uint16_pack, uint16_unpack,
int8_pack, int8_unpack) int8_pack, int8_unpack, header_pack)
from cassandra.cqltypes import (AsciiType, BytesType, BooleanType, from cassandra.cqltypes import (AsciiType, BytesType, BooleanType,
CounterColumnType, DateType, DecimalType, CounterColumnType, DateType, DecimalType,
DoubleType, FloatType, Int32Type, DoubleType, FloatType, Int32Type,
@@ -34,61 +33,69 @@ HEADER_DIRECTION_FROM_CLIENT = 0x00
HEADER_DIRECTION_TO_CLIENT = 0x80 HEADER_DIRECTION_TO_CLIENT = 0x80
HEADER_DIRECTION_MASK = 0x80 HEADER_DIRECTION_MASK = 0x80
COMPRESSED_FLAG = 0x01
TRACING_FLAG = 0x02
_message_types_by_name = {} _message_types_by_name = {}
_message_types_by_opcode = {} _message_types_by_opcode = {}
class _register_msg_type(type): class _RegisterMessageType(type):
def __init__(cls, name, bases, dct): def __init__(cls, name, bases, dct):
if not name.startswith('_'): if name not in ('NewBase', '_MessageType'):
_message_types_by_name[cls.name] = cls _message_types_by_name[cls.name] = cls
_message_types_by_opcode[cls.opcode] = cls _message_types_by_opcode[cls.opcode] = cls
class _MessageType(object): class _MessageType(six.with_metaclass(_RegisterMessageType, object)):
__metaclass__ = _register_msg_type
tracing = False tracing = False
def to_binary(self, stream_id, protocol_version, compression=None): def to_binary(self, stream_id, protocol_version, compression=None):
body = BytesIO() body = six.BytesIO()
self.send_body(body, protocol_version) self.send_body(body, protocol_version)
body = body.getvalue() body = body.getvalue()
version = protocol_version | HEADER_DIRECTION_FROM_CLIENT
flags = 0
if compression is not None and len(body) > 0:
body = compression(body)
flags |= 0x01
if self.tracing:
flags |= 0x02
msglen = int32_pack(len(body))
msg_parts = [int8_pack(i) for i in (version, flags, stream_id, self.opcode)] + [msglen, body]
return six.binary_type().join(msg_parts)
def __str__(self): flags = 0
paramstrs = ['%s=%r' % (pname, getattr(self, pname)) for pname in _get_params(self)] if compression and len(body) > 0:
return '<%s(%s)>' % (self.__class__.__name__, ', '.join(paramstrs)) body = compression(body)
__repr__ = __str__ flags |= COMPRESSED_FLAG
if self.tracing:
flags |= TRACING_FLAG
msg = six.BytesIO()
write_header(
msg,
protocol_version | HEADER_DIRECTION_FROM_CLIENT,
flags, stream_id, self.opcode, len(body)
)
msg.write(body)
return msg.getvalue()
def __repr__(self):
return '<%s(%s)>' % (self.__class__.__name__, ', '.join('%s=%r' % i for i in _get_params(self)))
def _get_params(message_obj): def _get_params(message_obj):
base_attrs = dir(_MessageType) base_attrs = dir(_MessageType)
return [a for a in dir(message_obj) return (
if a not in base_attrs and not a.startswith('_') and not callable(getattr(message_obj, a))] (n, a) for n, a in message_obj.__dict__.items()
if n not in base_attrs and not n.startswith('_') and not callable(a)
)
def decode_response(stream_id, flags, opcode, body, decompressor=None): def decode_response(stream_id, flags, opcode, body, decompressor=None):
if flags & 0x01: if flags & COMPRESSED_FLAG:
if decompressor is None: if decompressor is None:
raise Exception("No decompressor available for compressed frame!") raise Exception("No de-compressor available for compressed frame!")
body = decompressor(body) body = decompressor(body)
flags ^= 0x01 flags ^= COMPRESSED_FLAG
body = BytesIO(body) body = six.BytesIO(body)
if flags & 0x02: if flags & TRACING_FLAG:
trace_id = UUID(bytes=body.read(16)) trace_id = UUID(bytes=body.read(16))
flags ^= 0x02 flags ^= TRACING_FLAG
else: else:
trace_id = None trace_id = None
@@ -142,14 +149,13 @@ class ErrorMessage(_MessageType, Exception):
return self return self
class ErrorMessageSubclass(_register_msg_type): class ErrorMessageSubclass(_RegisterMessageType):
def __init__(cls, name, bases, dct): def __init__(cls, name, bases, dct):
if cls.error_code is not None: if name not in ('NewBase', ) and cls.error_code:
error_classes[cls.error_code] = cls error_classes[cls.error_code] = cls
class ErrorMessageSub(ErrorMessage): class ErrorMessageSub(six.with_metaclass(ErrorMessageSubclass, ErrorMessage)):
__metaclass__ = ErrorMessageSubclass
error_code = None error_code = None
@@ -450,7 +456,7 @@ class ResultMessage(_MessageType):
def recv_results_rows(cls, f): def recv_results_rows(cls, f):
column_metadata = cls.recv_results_metadata(f) column_metadata = cls.recv_results_metadata(f)
rowcount = read_int(f) rowcount = read_int(f)
rows = [cls.recv_row(f, len(column_metadata)) for x in xrange(rowcount)] rows = [cls.recv_row(f, len(column_metadata)) for _ in range(rowcount)]
colnames = [c[2] for c in column_metadata] colnames = [c[2] for c in column_metadata]
coltypes = [c[3] for c in column_metadata] coltypes = [c[3] for c in column_metadata]
return (colnames, [tuple(ctype.from_binary(val) for ctype, val in zip(coltypes, row)) return (colnames, [tuple(ctype.from_binary(val) for ctype, val in zip(coltypes, row))
@@ -471,7 +477,7 @@ class ResultMessage(_MessageType):
ksname = read_string(f) ksname = read_string(f)
cfname = read_string(f) cfname = read_string(f)
column_metadata = [] column_metadata = []
for x in xrange(colcount): for _ in range(colcount):
if glob_tblspec: if glob_tblspec:
colksname = ksname colksname = ksname
colcfname = cfname colcfname = cfname
@@ -513,7 +519,7 @@ class ResultMessage(_MessageType):
@staticmethod @staticmethod
def recv_row(f, colcount): def recv_row(f, colcount):
return [read_value(f) for x in xrange(colcount)] return [read_value(f) for _ in range(colcount)]
class PrepareMessage(_MessageType): class PrepareMessage(_MessageType):
@@ -635,6 +641,14 @@ class EventMessage(_MessageType):
return dict(change_type=change_type, keyspace=keyspace, table=table) return dict(change_type=change_type, keyspace=keyspace, table=table)
def write_header(f, version, flags, stream_id, opcode, length):
"""
Write a CQL protocol frame header.
"""
f.write(header_pack(version, flags, stream_id, opcode))
write_int(f, length)
def read_byte(f): def read_byte(f):
return int8_unpack(f.read(1)) return int8_unpack(f.read(1))
@@ -701,7 +715,7 @@ def write_longstring(f, s):
def read_stringlist(f): def read_stringlist(f):
numstrs = read_short(f) numstrs = read_short(f)
return [read_string(f) for x in xrange(numstrs)] return [read_string(f) for _ in range(numstrs)]
def write_stringlist(f, stringlist): def write_stringlist(f, stringlist):
@@ -713,7 +727,7 @@ def write_stringlist(f, stringlist):
def read_stringmap(f): def read_stringmap(f):
numpairs = read_short(f) numpairs = read_short(f)
strmap = {} strmap = {}
for x in xrange(numpairs): for _ in range(numpairs):
k = read_string(f) k = read_string(f)
strmap[k] = read_string(f) strmap[k] = read_string(f)
return strmap return strmap
@@ -729,7 +743,7 @@ def write_stringmap(f, strmap):
def read_stringmultimap(f): def read_stringmultimap(f):
numkeys = read_short(f) numkeys = read_short(f)
strmmap = {} strmmap = {}
for x in xrange(numkeys): for _ in range(numkeys):
k = read_string(f) k = read_string(f)
strmmap[k] = read_stringlist(f) strmmap[k] = read_stringlist(f)
return strmmap return strmmap

View File

@@ -8,9 +8,9 @@ import six
from cassandra.util import OrderedDict from cassandra.util import OrderedDict
if six.PY3: # if six.PY3:
unicode = str # unicode = str
long = int # long = int
def cql_quote(term): def cql_quote(term):

View File

@@ -183,10 +183,10 @@ class AsyncoreConnection(Connection, asyncore.dispatcher):
return return
self.is_defunct = True self.is_defunct = True
trace = traceback.format_exc(exc) trace = traceback.format_exc() #exc)
if trace != "None": if trace != "None":
log.debug("Defuncting connection (%s) to %s: %s\n%s", log.debug("Defuncting connection (%s) to %s: %s\n%s",
id(self), self.host, exc, traceback.format_exc(exc)) id(self), self.host, exc, traceback.format_exc())
else: else:
log.debug("Defuncting connection (%s) to %s: %s", log.debug("Defuncting connection (%s) to %s: %s",
id(self), self.host, exc) id(self), self.host, exc)

View File

@@ -23,6 +23,11 @@ uint8_pack, uint8_unpack = _make_packer('>B')
float_pack, float_unpack = _make_packer('>f') float_pack, float_unpack = _make_packer('>f')
double_pack, double_unpack = _make_packer('>d') double_pack, double_unpack = _make_packer('>d')
# Special case for cassandra header
header_struct = struct.Struct('>BBBB')
header_pack = header_struct.pack
header_unpack = header_struct.unpack
def varint_unpack(term): def varint_unpack(term):
val = int(term.encode('hex'), 16) val = int(term.encode('hex'), 16)

View File

@@ -7,6 +7,7 @@ import logging
import re import re
from threading import RLock from threading import RLock
import weakref import weakref
import six
murmur3 = None murmur3 = None
try: try:
@@ -779,7 +780,7 @@ class TableMetadata(object):
def protect_name(name): def protect_name(name):
if isinstance(name, unicode): if isinstance(name, six.text_type):
name = name.encode('utf8') name = name.encode('utf8')
return maybe_escape_name(name) return maybe_escape_name(name)
@@ -1047,7 +1048,7 @@ class BytesToken(Token):
def __init__(self, token_string): def __init__(self, token_string):
""" `token_string` should be string representing the token. """ """ `token_string` should be string representing the token. """
if not isinstance(token_string, basestring): if not isinstance(token_string, six.string_types):
raise TypeError( raise TypeError(
"Tokens for ByteOrderedPartitioner should be strings (got %s)" "Tokens for ByteOrderedPartitioner should be strings (got %s)"
% (type(token_string),)) % (type(token_string),))

View File

@@ -5,7 +5,7 @@ from threading import Lock
from cassandra import ConsistencyLevel from cassandra import ConsistencyLevel
from six.moves import xrange from six.moves import range
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@@ -512,7 +512,7 @@ class ExponentialReconnectionPolicy(ReconnectionPolicy):
self.max_delay = max_delay self.max_delay = max_delay
def new_schedule(self): def new_schedule(self):
return (min(self.base_delay * (2 ** i), self.max_delay) for i in xrange(64)) return (min(self.base_delay * (2 ** i), self.max_delay) for i in range(64))
class WriteType(object): class WriteType(object):

View File

@@ -1,409 +0,0 @@
import struct
import six
from six.moves import range
import uuid
# Low level byte pack and unpack methods.
def _make_pack_unpack_field(format):
s = struct.Struct(format)
return (
s.pack,
lambda b: s.unpack(b)[0]
)
_header_struct = struct.Struct('!BBBBL')
pack_cql_header = _header_struct.pack
unpack_cql_header = _header_struct.unpack
pack_cql_byte, unpack_cql_byte = _make_pack_unpack_field('!B')
pack_cql_int, unpack_cql_int = _make_pack_unpack_field('!i')
pack_cql_short, unpack_cql_short = _make_pack_unpack_field('!H')
# Maximum values for these data types.
MAX_INT = 0x7FFFFFFF
MAX_SHORT = 0xFFFF
def read_header(f):
"""
Read a CQL protocol frame header.
A frame header consists of 4 bytes for the fields version, flags, stream and opcode. This is followed by a 4
byte length field, reading a total of 8 bytes.
:returns: tuple consisting of the version, flags, stream, opcode and length fields.
"""
return unpack_cql_header(f.read(8))
def write_header(f, version, flags, stream_id, opcode, length):
"""
Write a CQL protocol frame header.
"""
f.write(pack_cql_header(version, flags, stream_id, opcode, length))
def read_byte(f):
return f.read()
def write_byte(f, v):
f.write(pack_cql_byte(v))
def read_int(f):
return unpack_cql_int(f.read(4))
def write_int(f, v):
f.write(pack_cql_int(v))
def read_short(f):
return unpack_cql_short(f.read(2))
def write_short(f, v):
f.write(pack_cql_short(v))
def read_string(f):
"""
:returns: Python 3 returns a str; Python 2 returns a unicode string.
"""
n = f.read_short()
return f.read(n).decode('UTF8')
def write_string(f, v):
# TODO: Should really check that a short string isn't longer than a 2^2.
if isinstance(v, six.text_type):
b = v.encode('UTF8')
write_short(f, len(b))
f.write(b)
elif isinstance(v, str):
# This assumes that str will be caught by the previous if statement with Python 3.
write_short(f, len(v))
f.write(v)
else:
write_string(f, str(v))
def read_long_string(f):
"""
:returns: Python 3 returns a str; Python 2 returns a unicode string.
"""
n = read_int(f)
return f.read(n).decode('UTF8')
def write_long_string(f, v):
# TODO: Should really check that a long string isn't longer than a 2^4 / 2.
if isinstance(v, six.text_type):
b = v.encode('UTF8')
write_int(f, len(b))
f.write(b)
elif isinstance(v, str):
# This assumes that str will be caught by the previous if statement with Python 3.
write_int(f, len(v))
f.write(v)
else:
write_long_string(f, str(v))
def read_uuid(f):
return uuid.UUID(bytes=f.read(16))
def write_uuid(f, v):
assert isinstance(v, uuid.UUID)
f.write(v.bytes)
def read_string_list(f):
n = read_short(f)
return [read_string(f) for _ in range(n)]
def write_string_list(f, v):
n = len(v)
for idx in range(n):
write_string(f, v[idx])
def read_bytes(f):
n = read_int(f)
return None if n < 0 else f.read(n)
def write_bytes(f, v):
if v is None:
write_int(f, -1)
f.write(v)
else:
write_int(f, len(v))
f.write(v)
def read_short_bytes(f):
n = read_short(f)
return f.read(n)
def write_short_bytes(f, v):
if v is None:
write_short(f, 0)
else:
n = len(v)
assert n <= MAX_SHORT
write_short(f, n)
f.write(v)
def read_inet(f):
n = f.read(1)
values = f.read(n)
raise NotImplementedError
def write_inet(f, v):
raise NotImplementedError
read_consistency = read_short
write_consistency = write_short
def read_string_map(f):
n = read_short(f)
return dict((read_string(f), read_string(f)) for _ in range(n))
def write_string_map(f, v):
write_short(f, len(v))
for key, value in six.iteritems(v):
write_string(f, key)
write_string(f, value)
def read_string_multimap(f):
n = read_short(f)
return dict((read_string(f), read_string_list(f)) for _ in range(n))
def write_string_multimap(f, v):
write_short(f, len(v))
for key, value in six.iteritems(v):
write_string(f, key)
write_string_list(f, value)
## Define messages ##############################
HEADER_DIRECTION_FROM_CLIENT = 0x00
HEADER_DIRECTION_TO_CLIENT = 0x80
HEADER_DIRECTION_MASK = 0x80
COMPRESSED_FLAG = 0x01
TRACING_FLAG = 0x02
_message_types_by_name = {}
_message_types_by_opcode = {}
_error_classes = {}
class _RegisterMessageType(type):
def __init__(cls, what, *args, **kwargs):
if what not in ('_MessageType', 'NewBase'):
_message_types_by_name[cls.name] = cls
_message_types_by_opcode[cls.opcode] = cls
def _get_params(message_obj):
base_attrs = dir(_MessageType)
return (
(n, a) for n, a in message_obj.__dict__.items()
if n not in base_attrs and not n.startswith('_') and not callable(a)
)
class _MessageType(six.with_metaclass(_RegisterMessageType, object)):
opcode = None
name = None
tracing = False
def __repr__(self):
return '<%s(%s)>' % (self.__class__.__name__, ', '.join('%s=%r' % i for i in _get_params(self)))
def send_body(self, buf, protocol_version):
"""
Encode the body of this message for sending.
:param buf: An instance of `ByteBuffer`.
:param protocol_version: Version of the protocol currently being used.
"""
pass
def to_binary(self, stream_id, protocol_version, compression=None):
"""
Pack this message into it's binary format.
"""
body = six.BytesIO()
self.send_body(body, protocol_version)
body = body.getvalue()
flags = 0
if compression and len(body) > 0:
body = compression(body)
flags |= COMPRESSED_FLAG
if self.tracing:
flags |= TRACING_FLAG
msg = six.BytesIO()
write_header(
msg,
protocol_version | HEADER_DIRECTION_FROM_CLIENT,
flags, stream_id, self.opcode, len(body)
)
msg.write(body)
return msg.getvalue()
def decode_response(stream_id, flags, opcode, body, decompressor=None):
"""
Build msg class.
"""
if flags & COMPRESSED_FLAG:
if callable(decompressor):
body = decompressor(body)
flags ^= COMPRESSED_FLAG
else:
raise TypeError("De-compressor not available for compressed frame!")
body = six.BytesIO(body)
if flags & TRACING_FLAG:
trace_id = read_uuid(body)
flags ^= TRACING_FLAG
else:
trace_id = None
if flags:
# TODO: log.warn("Unknown protocol flags set: %02x. May cause problems.", flags)
pass
msg_class = _message_types_by_opcode[opcode]
msg = msg_class.recv_body(body)
msg.stream_id = stream_id
msg.trace_id = trace_id
return msg
class StartupMessage(_MessageType):
opcode = 0x01
name = 'STARTUP'
KNOWN_OPTION_KEYS = set(('CQL_VERSION', 'COMPRESSION',))
def __init__(self, cqlversion, options):
self.cqlversion = cqlversion
self.options = options
def send_body(self, f, protocol_version):
opt_map = self.options.copy()
opt_map['CQL_VERSION'] = self.cqlversion
write_string_map(f, opt_map)
class ReadyMessage(_MessageType):
opcode = 0x02
name = 'READY'
@classmethod
def recv_body(cls, f):
return cls()
class AuthenticateMessage(_MessageType):
opcode = 0x03
name = 'AUTHENTICATE'
def __init__(self, authenticator):
self.authenticator = authenticator
@classmethod
def recv_body(cls, f):
authenticator = read_string(f)
return cls(authenticator)
class CredentialsMessage(_MessageType):
opcode = 0x04
name = 'CREDENTIALS'
def __init__(self, credentials):
self.credentials = credentials
def send_body(self, f, protocol_version):
write_string_map(f, self.credentials)
class OptionsMessage(_MessageType):
opcode = 0x05
name = 'OPTIONS'
class SupportedMessage(_MessageType):
opcode = 0x06
name = 'SUPPORTED'
def __init__(self, cql_versions, options):
self.cql_versions = cql_versions
self.options = options
@classmethod
def recv_body(cls, f):
options = read_string_multimap(f)
cql_versions = options.pop('CQL_VERSION')
return cls(cql_versions, options)
class QueryMessage(_MessageType):
opcode = 0x07
name = 'QUERY'
def __init__(self, query, consistency_level):
self.query = query
self.consistency_level = consistency_level
def send_body(self, f, protocol_version):
write_long_string(f, self.query)
write_consistency(f, self.consistency_level)
@classmethod
def recv_body(cls, f):
query = read_long_string(f)
consistency_level = read_consistency(f)
return cls(query, consistency_level)
class ResultMessage(_MessageType):
opcode = 0x08
name = 'RESULT'
def __init__(self, kind, results):
self.kind = kind
self.results = results
class PrepareMessage(_MessageType):
opcode = 0x09
name = 'PREPARE'
def __init__(self, query):
self.query = query
def send_body(self, f, protocol_version):
write_long_string(f, self.query)

View File

@@ -1,3 +1,5 @@
import six
try: try:
import unittest2 as unittest import unittest2 as unittest
except ImportError: except ImportError:
@@ -6,7 +8,7 @@ except ImportError:
import errno import errno
import os import os
from six.moves import StringIO from six import BytesIO
import socket import socket
from socket import error as socket_error from socket import error as socket_error
@@ -18,7 +20,7 @@ from cassandra.connection import (HEADER_DIRECTION_TO_CLIENT,
from cassandra.decoder import (write_stringmultimap, write_int, write_string, from cassandra.decoder import (write_stringmultimap, write_int, write_string,
SupportedMessage, ReadyMessage, ServerError) SupportedMessage, ReadyMessage, ServerError)
from cassandra.marshal import uint8_pack, uint32_pack, int32_pack from cassandra.marshal import uint8_pack, uint8_unpack, uint32_pack, int32_pack
from cassandra.io.asyncorereactor import AsyncoreConnection from cassandra.io.asyncorereactor import AsyncoreConnection
@@ -43,7 +45,7 @@ class AsyncoreConnectionTest(unittest.TestCase):
return c return c
def make_header_prefix(self, message_class, version=2, stream_id=0): def make_header_prefix(self, message_class, version=2, stream_id=0):
return ''.join(map(uint8_pack, [ return six.binary_type().join(map(uint8_pack, [
0xff & (HEADER_DIRECTION_TO_CLIENT | version), 0xff & (HEADER_DIRECTION_TO_CLIENT | version),
0, # flags (compression) 0, # flags (compression)
stream_id, stream_id,
@@ -51,7 +53,7 @@ class AsyncoreConnectionTest(unittest.TestCase):
])) ]))
def make_options_body(self): def make_options_body(self):
options_buf = StringIO() options_buf = BytesIO()
write_stringmultimap(options_buf, { write_stringmultimap(options_buf, {
'CQL_VERSION': ['3.0.1'], 'CQL_VERSION': ['3.0.1'],
'COMPRESSION': [] 'COMPRESSION': []
@@ -59,12 +61,12 @@ class AsyncoreConnectionTest(unittest.TestCase):
return options_buf.getvalue() return options_buf.getvalue()
def make_error_body(self, code, msg): def make_error_body(self, code, msg):
buf = StringIO() buf = BytesIO()
write_int(buf, code) write_int(buf, code)
write_string(buf, msg) write_string(buf, msg)
return buf.getvalue() return buf.getvalue()
def make_msg(self, header, body=""): def make_msg(self, header, body=six.binary_type()):
return header + uint32_pack(len(body)) + body return header + uint32_pack(len(body)) + body
def test_successful_connection(self, *args): def test_successful_connection(self, *args):
@@ -93,12 +95,12 @@ class AsyncoreConnectionTest(unittest.TestCase):
# get a connection that's already fully started # get a connection that's already fully started
c = self.test_successful_connection() c = self.test_successful_connection()
header = '\x00\x00\x00\x00' + int32_pack(20000) header = six.b('\x00\x00\x00\x00') + int32_pack(20000)
responses = [ responses = [
header + ('a' * (4096 - len(header))), header + (six.b('a') * (4096 - len(header))),
'a' * 4096, six.b('a') * 4096,
socket_error(errno.EAGAIN), socket_error(errno.EAGAIN),
'a' * 100, six.b('a') * 100,
socket_error(errno.EAGAIN)] socket_error(errno.EAGAIN)]
def side_effect(*args): def side_effect(*args):
@@ -225,14 +227,13 @@ class AsyncoreConnectionTest(unittest.TestCase):
options = self.make_options_body() options = self.make_options_body()
message = self.make_msg(header, options) message = self.make_msg(header, options)
# read in the first byte c.socket.recv.return_value = message[0:1]
c.socket.recv.return_value = message[0]
c.handle_read() c.handle_read()
self.assertEquals(c._iobuf.getvalue(), message[0]) self.assertEquals(c._iobuf.getvalue(), message[0:1])
c.socket.recv.return_value = message[1:] c.socket.recv.return_value = message[1:]
c.handle_read() c.handle_read()
self.assertEquals("", c._iobuf.getvalue()) self.assertEquals(six.binary_type(), c._iobuf.getvalue())
# let it write out a StartupMessage # let it write out a StartupMessage
c.handle_write() c.handle_write()
@@ -259,7 +260,7 @@ class AsyncoreConnectionTest(unittest.TestCase):
# ... then read in the rest # ... then read in the rest
c.socket.recv.return_value = message[9:] c.socket.recv.return_value = message[9:]
c.handle_read() c.handle_read()
self.assertEquals("", c._iobuf.getvalue()) self.assertEquals(six.binary_type(), c._iobuf.getvalue())
# let it write out a StartupMessage # let it write out a StartupMessage
c.handle_write() c.handle_write()

View File

@@ -1,9 +1,11 @@
import six
try: try:
import unittest2 as unittest import unittest2 as unittest
except ImportError: except ImportError:
import unittest # noqa import unittest # noqa
from six.moves import StringIO from six import BytesIO
from mock import Mock, ANY from mock import Mock, ANY
@@ -25,7 +27,7 @@ class ConnectionTest(unittest.TestCase):
return c return c
def make_header_prefix(self, message_class, version=2, stream_id=0): def make_header_prefix(self, message_class, version=2, stream_id=0):
return ''.join(map(uint8_pack, [ return six.binary_type().join(map(uint8_pack, [
0xff & (HEADER_DIRECTION_TO_CLIENT | version), 0xff & (HEADER_DIRECTION_TO_CLIENT | version),
0, # flags (compression) 0, # flags (compression)
stream_id, stream_id,
@@ -33,7 +35,7 @@ class ConnectionTest(unittest.TestCase):
])) ]))
def make_options_body(self): def make_options_body(self):
options_buf = StringIO() options_buf = BytesIO()
write_stringmultimap(options_buf, { write_stringmultimap(options_buf, {
'CQL_VERSION': ['3.0.1'], 'CQL_VERSION': ['3.0.1'],
'COMPRESSION': [] 'COMPRESSION': []
@@ -41,7 +43,7 @@ class ConnectionTest(unittest.TestCase):
return options_buf.getvalue() return options_buf.getvalue()
def make_error_body(self, code, msg): def make_error_body(self, code, msg):
buf = StringIO() buf = BytesIO()
write_int(buf, code) write_int(buf, code)
write_string(buf, msg) write_string(buf, msg)
return buf.getvalue() return buf.getvalue()
@@ -73,12 +75,12 @@ class ConnectionTest(unittest.TestCase):
c.defunct = Mock() c.defunct = Mock()
# read in a SupportedMessage response # read in a SupportedMessage response
header = ''.join(map(uint8_pack, [ header = six.binary_type().join(uint8_pack(i) for i in (
0xff & (HEADER_DIRECTION_FROM_CLIENT | self.protocol_version), 0xff & (HEADER_DIRECTION_FROM_CLIENT | self.protocol_version),
0, # flags (compression) 0, # flags (compression)
0, 0,
SupportedMessage.opcode # opcode SupportedMessage.opcode # opcode
])) ))
options = self.make_options_body() options = self.make_options_body()
message = self.make_msg(header, options) message = self.make_msg(header, options)
c.process_msg(message, len(message) - 8) c.process_msg(message, len(message) - 8)
@@ -115,7 +117,7 @@ class ConnectionTest(unittest.TestCase):
# read in a SupportedMessage response # read in a SupportedMessage response
header = self.make_header_prefix(SupportedMessage) header = self.make_header_prefix(SupportedMessage)
options_buf = StringIO() options_buf = BytesIO()
write_stringmultimap(options_buf, { write_stringmultimap(options_buf, {
'CQL_VERSION': ['7.8.9'], 'CQL_VERSION': ['7.8.9'],
'COMPRESSION': [] 'COMPRESSION': []

View File

@@ -1,3 +1,5 @@
import six
try: try:
import unittest2 as unittest import unittest2 as unittest
except ImportError: except ImportError:
@@ -20,55 +22,55 @@ marshalled_value_pairs = (
# binary form, type, python native type # binary form, type, python native type
('lorem ipsum dolor sit amet', 'AsciiType', 'lorem ipsum dolor sit amet'), ('lorem ipsum dolor sit amet', 'AsciiType', 'lorem ipsum dolor sit amet'),
('', 'AsciiType', ''), ('', 'AsciiType', ''),
('\x01', 'BooleanType', True), (six.b('\x01'), 'BooleanType', True),
('\x00', 'BooleanType', False), (six.b('\x00'), 'BooleanType', False),
('', 'BooleanType', None), (six.b(''), 'BooleanType', None),
('\xff\xfe\xfd\xfc\xfb', 'BytesType', '\xff\xfe\xfd\xfc\xfb'), (six.b('\xff\xfe\xfd\xfc\xfb'), 'BytesType', '\xff\xfe\xfd\xfc\xfb'),
('', 'BytesType', ''), (six.b(''), 'BytesType', ''),
('\x7f\xff\xff\xff\xff\xff\xff\xff', 'CounterColumnType', 9223372036854775807), (six.b('\x7f\xff\xff\xff\xff\xff\xff\xff'), 'CounterColumnType', 9223372036854775807),
('\x80\x00\x00\x00\x00\x00\x00\x00', 'CounterColumnType', -9223372036854775808), (six.b('\x80\x00\x00\x00\x00\x00\x00\x00'), 'CounterColumnType', -9223372036854775808),
('', 'CounterColumnType', None), (six.b(''), 'CounterColumnType', None),
('\x00\x00\x013\x7fb\xeey', 'DateType', datetime(2011, 11, 7, 18, 55, 49, 881000)), (six.b('\x00\x00\x013\x7fb\xeey'), 'DateType', datetime(2011, 11, 7, 18, 55, 49, 881000)),
('', 'DateType', None), (six.b(''), 'DateType', None),
('\x00\x00\x00\r\nJ\x04"^\x91\x04\x8a\xb1\x18\xfe', 'DecimalType', Decimal('1243878957943.1234124191998')), (six.b('\x00\x00\x00\r\nJ\x04"^\x91\x04\x8a\xb1\x18\xfe'), 'DecimalType', Decimal('1243878957943.1234124191998')),
('\x00\x00\x00\x06\xe5\xde]\x98Y', 'DecimalType', Decimal('-112233.441191')), (six.b('\x00\x00\x00\x06\xe5\xde]\x98Y'), 'DecimalType', Decimal('-112233.441191')),
('\x00\x00\x00\x14\x00\xfa\xce', 'DecimalType', Decimal('0.00000000000000064206')), (six.b('\x00\x00\x00\x14\x00\xfa\xce'), 'DecimalType', Decimal('0.00000000000000064206')),
('\x00\x00\x00\x14\xff\x052', 'DecimalType', Decimal('-0.00000000000000064206')), (six.b('\x00\x00\x00\x14\xff\x052'), 'DecimalType', Decimal('-0.00000000000000064206')),
('\xff\xff\xff\x9c\x00\xfa\xce', 'DecimalType', Decimal('64206e100')), (six.b('\xff\xff\xff\x9c\x00\xfa\xce'), 'DecimalType', Decimal('64206e100')),
('', 'DecimalType', None), (six.b(''), 'DecimalType', None),
('@\xd2\xfa\x08\x00\x00\x00\x00', 'DoubleType', 19432.125), (six.b('@\xd2\xfa\x08\x00\x00\x00\x00'), 'DoubleType', 19432.125),
('\xc0\xd2\xfa\x08\x00\x00\x00\x00', 'DoubleType', -19432.125), (six.b('\xc0\xd2\xfa\x08\x00\x00\x00\x00'), 'DoubleType', -19432.125),
('\x7f\xef\x00\x00\x00\x00\x00\x00', 'DoubleType', 1.7415152243978685e+308), (six.b('\x7f\xef\x00\x00\x00\x00\x00\x00'), 'DoubleType', 1.7415152243978685e+308),
('', 'DoubleType', None), (six.b(''), 'DoubleType', None),
('F\x97\xd0@', 'FloatType', 19432.125), (six.b('F\x97\xd0@'), 'FloatType', 19432.125),
('\xc6\x97\xd0@', 'FloatType', -19432.125), (six.b('\xc6\x97\xd0@'), 'FloatType', -19432.125),
('\xc6\x97\xd0@', 'FloatType', -19432.125), (six.b('\xc6\x97\xd0@'), 'FloatType', -19432.125),
('\x7f\x7f\x00\x00', 'FloatType', 338953138925153547590470800371487866880.0), (six.b('\x7f\x7f\x00\x00'), 'FloatType', 338953138925153547590470800371487866880.0),
('', 'FloatType', None), (six.b(''), 'FloatType', None),
('\x7f\x50\x00\x00', 'Int32Type', 2135949312), (six.b('\x7f\x50\x00\x00'), 'Int32Type', 2135949312),
('\xff\xfd\xcb\x91', 'Int32Type', -144495), (six.b('\xff\xfd\xcb\x91'), 'Int32Type', -144495),
('', 'Int32Type', None), (six.b(''), 'Int32Type', None),
('f\x1e\xfd\xf2\xe3\xb1\x9f|\x04_\x15', 'IntegerType', 123456789123456789123456789), (six.b('f\x1e\xfd\xf2\xe3\xb1\x9f|\x04_\x15'), 'IntegerType', 123456789123456789123456789),
('', 'IntegerType', None), (six.b(''), 'IntegerType', None),
('\x7f\xff\xff\xff\xff\xff\xff\xff', 'LongType', 9223372036854775807), (six.b('\x7f\xff\xff\xff\xff\xff\xff\xff'), 'LongType', 9223372036854775807),
('\x80\x00\x00\x00\x00\x00\x00\x00', 'LongType', -9223372036854775808), (six.b('\x80\x00\x00\x00\x00\x00\x00\x00'), 'LongType', -9223372036854775808),
('', 'LongType', None), (six.b(''), 'LongType', None),
('', 'InetAddressType', None), (six.b(''), 'InetAddressType', None),
('A46\xa9', 'InetAddressType', '65.52.54.169'), (six.b('A46\xa9'), 'InetAddressType', '65.52.54.169'),
('*\x00\x13(\xe1\x02\xcc\xc0\x00\x00\x00\x00\x00\x00\x01"', 'InetAddressType', '2a00:1328:e102:ccc0::122'), (six.b('*\x00\x13(\xe1\x02\xcc\xc0\x00\x00\x00\x00\x00\x00\x01"'), 'InetAddressType', '2a00:1328:e102:ccc0::122'),
('\xe3\x81\xbe\xe3\x81\x97\xe3\x81\xa6', 'UTF8Type', u'\u307e\u3057\u3066'), (six.b('\xe3\x81\xbe\xe3\x81\x97\xe3\x81\xa6'), 'UTF8Type', u'\u307e\u3057\u3066'),
('\xe3\x81\xbe\xe3\x81\x97\xe3\x81\xa6' * 1000, 'UTF8Type', u'\u307e\u3057\u3066' * 1000), (six.b('\xe3\x81\xbe\xe3\x81\x97\xe3\x81\xa6' * 1000), 'UTF8Type', u'\u307e\u3057\u3066' * 1000),
('', 'UTF8Type', u''), (six.b(''), 'UTF8Type', u''),
('\xff' * 16, 'UUIDType', UUID('ffffffff-ffff-ffff-ffff-ffffffffffff')), (six.b('\xff' * 16), 'UUIDType', UUID('ffffffff-ffff-ffff-ffff-ffffffffffff')),
('I\x15~\xfc\xef<\x9d\xe3\x16\x98\xaf\x80\x1f\xb4\x0b*', 'UUIDType', UUID('49157efc-ef3c-9de3-1698-af801fb40b2a')), (six.b('I\x15~\xfc\xef<\x9d\xe3\x16\x98\xaf\x80\x1f\xb4\x0b*'), 'UUIDType', UUID('49157efc-ef3c-9de3-1698-af801fb40b2a')),
('', 'UUIDType', None), (six.b(''), 'UUIDType', None),
('', 'MapType(AsciiType, BooleanType)', None), (six.b(''), 'MapType(AsciiType, BooleanType)', None),
('', 'ListType(FloatType)', None), (six.b(''), 'ListType(FloatType)', None),
('', 'SetType(LongType)', None), (six.b(''), 'SetType(LongType)', None),
('\x00\x00', 'MapType(DecimalType, BooleanType)', OrderedDict()), (six.b('\x00\x00'), 'MapType(DecimalType, BooleanType)', OrderedDict()),
('\x00\x00', 'ListType(FloatType)', ()), (six.b('\x00\x00'), 'ListType(FloatType)', ()),
('\x00\x00', 'SetType(IntegerType)', sortedset()), (six.b('\x00\x00'), 'SetType(IntegerType)', sortedset()),
('\x00\x01\x00\x10\xafYC\xa3\xea<\x11\xe1\xabc\xc4,\x03"y\xf0', 'ListType(TimeUUIDType)', (UUID(bytes='\xafYC\xa3\xea<\x11\xe1\xabc\xc4,\x03"y\xf0'),)), (six.b('\x00\x01\x00\x10\xafYC\xa3\xea<\x11\xe1\xabc\xc4,\x03"y\xf0'), 'ListType(TimeUUIDType)', (UUID(bytes=six.b('\xafYC\xa3\xea<\x11\xe1\xabc\xc4,\x03"y\xf0')),)),
) )
ordered_dict_value = OrderedDict() ordered_dict_value = OrderedDict()

View File

@@ -206,8 +206,8 @@ class TestTokens(unittest.TestCase):
def test_md5_tokens(self): def test_md5_tokens(self):
md5_token = MD5Token(cassandra.metadata.MIN_LONG - 1) md5_token = MD5Token(cassandra.metadata.MIN_LONG - 1)
self.assertEqual(md5_token.hash_fn('123'), 42767516990368493138776584305024125808L) self.assertEqual(md5_token.hash_fn('123'), 42767516990368493138776584305024125808)
self.assertEqual(md5_token.hash_fn(str(cassandra.metadata.MAX_LONG)), 28528976619278518853815276204542453639L) self.assertEqual(md5_token.hash_fn(str(cassandra.metadata.MAX_LONG)), 28528976619278518853815276204542453639)
self.assertEqual(str(md5_token), '<MD5Token: -9223372036854775809L>') self.assertEqual(str(md5_token), '<MD5Token: -9223372036854775809L>')
def test_bytes_tokens(self): def test_bytes_tokens(self):