Register classes for UDT deserialization

This commit is contained in:
Tyler Hobbs
2014-06-17 17:50:24 -05:00
parent dd28b585dc
commit c4ef43c26c
6 changed files with 77 additions and 53 deletions

View File

@@ -334,6 +334,11 @@ class Cluster(object):
_prepared_statements = None
_prepared_statement_lock = None
_user_types = None
"""
A map of {keyspace: {type_name: UserType}}
"""
_listeners = None
_listener_lock = None
@@ -412,6 +417,8 @@ class Cluster(object):
self._prepared_statements = WeakValueDictionary()
self._prepared_statement_lock = Lock()
self._user_types = defaultdict(dict)
self._min_requests_per_connection = {
HostDistance.LOCAL: DEFAULT_MIN_REQUESTS,
HostDistance.REMOTE: DEFAULT_MIN_REQUESTS
@@ -444,6 +451,9 @@ class Cluster(object):
self.control_connection = ControlConnection(
self, self.control_connection_timeout)
def register_type_class(self, keyspace, user_type, klass):
self._user_types[keyspace][user_type] = klass
def get_min_requests_per_connection(self, host_distance):
return self._min_requests_per_connection[host_distance]
@@ -517,6 +527,7 @@ class Cluster(object):
kwargs_dict['ssl_options'] = self.ssl_options
kwargs_dict['cql_version'] = self.cql_version
kwargs_dict['protocol_version'] = self.protocol_version
kwargs_dict['user_type_map'] = self._user_types
return kwargs_dict

View File

@@ -155,12 +155,14 @@ class Connection(object):
is_defunct = False
is_closed = False
lock = None
user_type_map = None
is_control_connection = False
def __init__(self, host='127.0.0.1', port=9042, authenticator=None,
ssl_options=None, sockopts=None, compression=True,
cql_version=None, protocol_version=2, is_control_connection=False):
cql_version=None, protocol_version=2, is_control_connection=False,
user_type_map=None):
self.host = host
self.port = port
self.authenticator = authenticator
@@ -170,6 +172,7 @@ class Connection(object):
self.cql_version = cql_version
self.protocol_version = protocol_version
self.is_control_connection = is_control_connection
self.user_type_map = user_type_map
self._push_watchers = defaultdict(set)
if protocol_version >= 3:
self._header_unpack = v3_header_unpack
@@ -347,7 +350,8 @@ class Connection(object):
else:
raise ProtocolError("Got negative body length: %r" % body_len)
response = decode_response(given_version, stream_id, flags, opcode, body, self.decompressor)
response = decode_response(given_version, self.user_type_map, stream_id,
flags, opcode, body, self.decompressor)
except Exception as exc:
log.exception("Error decoding response from Cassandra. "
"opcode: %04x; message contents: %r", opcode, msg)
@@ -386,10 +390,10 @@ class Connection(object):
if isinstance(options_response, ConnectionException):
raise options_response
else:
log.error("Did not get expected SupportedMessage response; " \
log.error("Did not get expected SupportedMessage response; "
"instead, got: %s", options_response)
raise ConnectionException("Did not get expected SupportedMessage " \
"response; instead, got: %s" \
raise ConnectionException("Did not get expected SupportedMessage "
"response; instead, got: %s"
% (options_response,))
log.debug("Received options response on new connection (%s) from %s",

View File

@@ -778,14 +778,21 @@ class UserDefinedType(_ParameterizedType):
FIELD_LENGTH = 4
_cache = {}
@classmethod
def apply_parameters(cls, subtypes, names):
newname = subtypes[1].cassname.decode("hex")
field_names = [encoded_name.decode("hex") for encoded_name in names[2:]]
return type(newname, (cls,), {'subtypes': subtypes[2:],
'cassname': cls.cassname,
'typename': newname,
'fieldnames': field_names})
def apply_parameters(cls, udt_name, names_and_types, mapped_class):
try:
return cls._cache[udt_name]
except KeyError:
fieldnames, types = zip(*names_and_types)
instance = type(udt_name, (cls,), {'subtypes': types,
'cassname': cls.cassname,
'typename': udt_name,
'fieldnames': fieldnames,
'mapped_class': mapped_class})
cls._cache[udt_name] = instance
return instance
@classmethod
def cql_parameterized_type(cls):
@@ -795,8 +802,7 @@ class UserDefinedType(_ParameterizedType):
def deserialize_safe(cls, byts, protocol_version):
proto_version = max(3, protocol_version)
p = 0
Result = namedtuple(cls.typename, cls.fieldnames)
result = []
values = []
for col_type in cls.subtypes:
if p == len(byts):
break
@@ -806,13 +812,17 @@ class UserDefinedType(_ParameterizedType):
p += itemlen
# collections inside UDTs are always encoded with at least the
# version 3 format
result.append(col_type.from_binary(item, proto_version))
values.append(col_type.from_binary(item, proto_version))
if len(result) < len(cls.subtypes):
nones = [None] * (len(cls.subtypes) - len(result))
result = result + nones
if len(values) < len(cls.subtypes):
nones = [None] * (len(cls.subtypes) - len(values))
values = values + nones
return Result(*result)
if cls.mapped_class:
return cls.mapped_class(**dict(zip(cls.fieldnames, values)))
else:
Result = namedtuple(cls.typename, cls.fieldnames)
return Result(*values)
@classmethod
def serialize_safe(cls, val, protocol_version):

View File

@@ -100,7 +100,8 @@ def _get_params(message_obj):
)
def decode_response(protocol_version, stream_id, flags, opcode, body, decompressor=None):
def decode_response(protocol_version, user_type_map, stream_id, flags, opcode, body,
decompressor=None):
if flags & COMPRESSED_FLAG:
if decompressor is None:
raise Exception("No de-compressor available for compressed frame!")
@@ -118,7 +119,7 @@ def decode_response(protocol_version, stream_id, flags, opcode, body, decompress
log.warning("Unknown protocol flags set: %02x. May cause problems.", flags)
msg_class = _message_types_by_opcode[opcode]
msg = msg_class.recv_body(body, protocol_version)
msg = msg_class.recv_body(body, protocol_version, user_type_map)
msg.stream_id = stream_id
msg.trace_id = trace_id
return msg
@@ -138,7 +139,7 @@ class ErrorMessage(_MessageType, Exception):
self.info = info
@classmethod
def recv_body(cls, f, protocol_version):
def recv_body(cls, f, protocol_version, user_type_map):
code = read_int(f)
msg = read_string(f)
subcls = error_classes.get(code, cls)
@@ -338,7 +339,7 @@ class ReadyMessage(_MessageType):
name = 'READY'
@classmethod
def recv_body(cls, f, protocol_version):
def recv_body(cls, f, protocol_version, user_type_map):
return cls()
@@ -350,7 +351,7 @@ class AuthenticateMessage(_MessageType):
self.authenticator = authenticator
@classmethod
def recv_body(cls, f, protocol_version):
def recv_body(cls, f, protocol_version, user_type_map):
authname = read_string(f)
return cls(authenticator=authname)
@@ -382,7 +383,7 @@ class AuthChallengeMessage(_MessageType):
self.challenge = challenge
@classmethod
def recv_body(cls, f, protocol_version):
def recv_body(cls, f, protocol_version, user_type_map):
return cls(read_longstring(f))
@@ -405,7 +406,7 @@ class AuthSuccessMessage(_MessageType):
self.token = token
@classmethod
def recv_body(cls, f, protocol_version):
def recv_body(cls, f, protocol_version, user_type_map):
return cls(read_longstring(f))
@@ -426,7 +427,7 @@ class SupportedMessage(_MessageType):
self.options = options
@classmethod
def recv_body(cls, f, protocol_version):
def recv_body(cls, f, protocol_version, user_type_map):
options = read_stringmultimap(f)
cql_versions = options.pop('CQL_VERSION')
return cls(cql_versions=cql_versions, options=options)
@@ -547,13 +548,14 @@ class ResultMessage(_MessageType):
self.paging_state = paging_state
@classmethod
def recv_body(cls, f, protocol_version):
def recv_body(cls, f, protocol_version, user_type_map):
kind = read_int(f)
paging_state = None
if kind == RESULT_KIND_VOID:
results = None
elif kind == RESULT_KIND_ROWS:
paging_state, results = cls.recv_results_rows(f, protocol_version)
paging_state, results = cls.recv_results_rows(
f, protocol_version, user_type_map)
elif kind == RESULT_KIND_SET_KEYSPACE:
ksname = read_string(f)
results = ksname
@@ -564,16 +566,17 @@ class ResultMessage(_MessageType):
return cls(kind, results, paging_state)
@classmethod
def recv_results_rows(cls, f, protocol_version):
def recv_results_rows(cls, f, protocol_version, user_type_map):
paging_state, column_metadata = cls.recv_results_metadata(f)
rowcount = read_int(f)
rows = [cls.recv_row(f, len(column_metadata)) for _ in range(rowcount)]
colnames = [c[2] for c in column_metadata]
coltypes = [c[3] for c in column_metadata]
return (
paging_state,
(colnames, [tuple(ctype.from_binary(val, protocol_version) for ctype, val in zip(coltypes, row))
for row in rows]))
parsed_rows = [
tuple(ctype.from_binary(val, protocol_version)
for ctype, val in zip(coltypes, row))
for row in rows]
return (paging_state, (colnames, parsed_rows))
@classmethod
def recv_results_prepared(cls, f):
@@ -582,7 +585,7 @@ class ResultMessage(_MessageType):
return (query_id, column_metadata)
@classmethod
def recv_results_metadata(cls, f):
def recv_results_metadata(cls, f, user_type_map):
flags = read_int(f)
glob_tblspec = bool(flags & cls._FLAGS_GLOBAL_TABLES_SPEC)
colcount = read_int(f)
@@ -624,7 +627,7 @@ class ResultMessage(_MessageType):
return {'change_type': change_type, 'keyspace': keyspace, 'table': table}
@classmethod
def read_type(cls, f):
def read_type(cls, f, user_type_map):
optid = read_short(f)
try:
typeclass = cls._type_codes[optid]
@@ -638,6 +641,14 @@ class ResultMessage(_MessageType):
keysubtype = cls.read_type(f)
valsubtype = cls.read_type(f)
typeclass = typeclass.apply_parameters((keysubtype, valsubtype))
elif typeclass == UserDefinedType:
ks = cls.read_string(f)
udt_name = cls.read_string(f)
num_fields = cls.read_short(f)
names_and_types = ((cls.read_string(f), cls.read_type(f))
for _ in xrange(num_fields))
mapped_class = user_type_map.get(ks, {}).get(udt_name)
typeclass = typeclass.apply_parameters(udt_name, names_and_types, mapped_class)
elif typeclass == CUSTOM_TYPE:
classname = read_string(f)
typeclass = lookup_casstype(classname)
@@ -789,7 +800,7 @@ class EventMessage(_MessageType):
self.event_args = event_args
@classmethod
def recv_body(cls, f, protocol_version):
def recv_body(cls, f, protocol_version, user_type_map):
event_type = read_string(f).upper()
if event_type in known_event_types:
read_method = getattr(cls, 'recv_' + event_type.lower())

View File

@@ -397,27 +397,15 @@ class ControlConnectionTest(unittest.TestCase):
def test_handle_schema_change(self):
for change_type in ('CREATED', 'DROPPED'):
for change_type in ('CREATED', 'DROPPED', 'UPDATED'):
event = {
'change_type': change_type,
'keyspace': 'ks1',
'table': 'table1'
}
self.control_connection._handle_schema_change(event)
self.cluster.executor.submit.assert_called_with(self.control_connection.refresh_schema, 'ks1')
self.cluster.executor.submit.assert_called_with(self.control_connection.refresh_schema, 'ks1', 'table1', None)
event['table'] = None
self.control_connection._handle_schema_change(event)
self.cluster.executor.submit.assert_called_with(self.control_connection.refresh_schema, None)
event = {
'change_type': 'UPDATED',
'keyspace': 'ks1',
'table': 'table1'
}
self.control_connection._handle_schema_change(event)
self.cluster.executor.submit.assert_called_with(self.control_connection.refresh_schema, 'ks1', 'table1')
event['table'] = None
self.control_connection._handle_schema_change(event)
self.cluster.executor.submit.assert_called_with(self.control_connection.refresh_schema, 'ks1', None)
self.cluster.executor.submit.assert_called_with(self.control_connection.refresh_schema, 'ks1', None, None)

View File

@@ -105,7 +105,7 @@ class ResponseFutureTests(unittest.TestCase):
kind=RESULT_KIND_SCHEMA_CHANGE,
results={'keyspace': "keyspace1", "table": "table1"})
rf._set_result(result)
session.submit.assert_called_once_with(ANY, 'keyspace1', 'table1', ANY, rf)
session.submit.assert_called_once_with(ANY, 'keyspace1', 'table1', None, ANY, rf)
def test_other_result_message_kind(self):
session = self.make_session()